import gradio as gr
import cv2
import torch
import torchvision.transforms as transforms
from torchvision import models
from ultralytics import YOLO
import numpy as np
from collections import OrderedDict, Counter
import time
import tempfile
import os
# Core pipeline classes (simplified for web deployment)
class SceneClassifier:
    def __init__(self, model_path='scene.pth.tar', categories_file='categories_places365.txt'):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load scene labels (Places365 format: "/x/<label> <index>"; [3:] strips the "/x/" prefix)
        self.classes = []
        try:
            with open(categories_file) as class_file:
                for line in class_file:
                    self.classes.append(line.strip().split(' ')[0][3:])
        except FileNotFoundError:
            self.classes = ['indoor', 'outdoor', 'street', 'kitchen', 'bedroom']
        self.classes = tuple(self.classes)

        # Load model weights; strip the 'module.' prefix left by DataParallel checkpoints
        self.model = models.resnet50(num_classes=len(self.classes))
        try:
            checkpoint = torch.load(model_path, map_location=self.device)
            state_dict = checkpoint['state_dict']
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                new_key = k.replace('module.', '')
                new_state_dict[new_key] = v
            self.model.load_state_dict(new_state_dict)
        except Exception as e:
            print(f"Warning: Could not load scene model: {e}")
        self.model.to(self.device)
        self.model.eval()

        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def predict(self, frame):
        try:
            # OpenCV frames are BGR; convert to RGB before the torchvision transform
            if len(frame.shape) == 3:
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            else:
                frame_rgb = frame
            input_tensor = self.transform(frame_rgb).unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(input_tensor)
                probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
                top_prob, top_class = torch.topk(probabilities, 1)
            scene_class = self.classes[top_class.item()]
            confidence = top_prob.item()
            return scene_class, confidence
        except Exception:
            return "unknown", 0.0

class NaturalLanguageGenerator:
    def __init__(self, confidence_threshold=0.3):
        self.confidence_threshold = confidence_threshold
        # Alternative phrasings; generate_description currently uses fixed sentences,
        # so these templates are kept only for future variation.
        self.templates = {
            'basic': [
                "I can see {objects} in this {scene}.",
                "This {scene} contains {objects}.",
                "In this {scene} setting, there are {objects}.",
            ],
            'no_objects': [
                "This is a {scene} scene with no prominent objects.",
                "The {scene} appears empty.",
            ]
        }

    def format_objects(self, objects):
        if not objects:
            return ""
        object_counts = Counter(objects)
        formatted = []
        for obj, count in object_counts.items():
            if count == 1:
                formatted.append(f"a {obj}")
            else:
                formatted.append(f"{count} {obj}s")
        if len(formatted) == 1:
            return formatted[0]
        elif len(formatted) == 2:
            return f"{formatted[0]} and {formatted[1]}"
        else:
            return ", ".join(formatted[:-1]) + f", and {formatted[-1]}"

    def generate_description(self, detected_objects, scene_label):
        # Keep only detections above the confidence threshold
        filtered_objects = [
            obj_name for obj_name, confidence in detected_objects
            if confidence >= self.confidence_threshold
        ]
        scene = scene_label.lower().replace('_', ' ')
        if filtered_objects:
            object_string = self.format_objects(filtered_objects[:5])
            description = f"I can see {object_string} in this {scene}."
        else:
            description = f"This is a {scene} scene with no prominent objects."
        return description.capitalize()
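
# Example output with the defaults above:
#   nlg = NaturalLanguageGenerator()
#   nlg.generate_description([("person", 0.9), ("dog", 0.8), ("dog", 0.7)], "living_room")
#   -> "I can see a person and 2 dogs in this living room."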
# Initialize models
print("Loading models...")
yolo_model = YOLO('best.pt')
scene_classifier = SceneClassifier('scene.pth.tar')
nlg = NaturalLanguageGenerator()
print("Models loaded!")

def process_video_frame(frame):
    """Process a single BGR frame: detect objects, classify the scene, and build a description."""
    try:
        # Object detection
        results = yolo_model(frame, verbose=False)
        detected_objects = []
        if results and len(results) > 0:
            result = results[0]
            if hasattr(result, 'boxes') and result.boxes is not None:
                for box in result.boxes:
                    if hasattr(box, 'cls') and hasattr(box, 'conf'):
                        class_id = int(box.cls.item())      # .item() avoids int() on a numpy array
                        confidence = float(box.conf.item())
                        if hasattr(result, 'names'):
                            class_name = result.names[class_id]
                            detected_objects.append((class_name, confidence))
        # Scene classification
        scene, scene_confidence = scene_classifier.predict(frame)
        # Generate description
        description = nlg.generate_description(detected_objects, scene)
        # Draw annotations
        annotated_frame = results[0].plot() if results else frame
        cv2.putText(annotated_frame, f"Scene: {scene}",
                    (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
        cv2.putText(annotated_frame, f"Objects: {len(detected_objects)}",
                    (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
        return annotated_frame, description
    except Exception as e:
        return frame, f"Error: {str(e)}"

def analyze_video_live(video):
    """Process a single webcam frame delivered by the Gradio stream."""
    if video is None:
        return None, "No video input"
    # Gradio delivers RGB numpy arrays; convert to BGR for OpenCV
    frame = cv2.cvtColor(video, cv2.COLOR_RGB2BGR)
    # Process frame
    annotated_frame, description = process_video_frame(frame)
    # Convert back to RGB for Gradio
    annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
    return annotated_frame, description

# Create Gradio interface
with gr.Blocks(title="Live Video Narrator for the Blind") as demo:
    gr.Markdown("# 🎥 Live Video Narrator for the Blind")
    gr.Markdown("This AI system provides real-time narration of live video feeds to assist visually impaired users.")

    with gr.Row():
        with gr.Column():
            # Assumes Gradio 4+/5, where the webcam source is given as `sources=[...]`
            # (older 3.x releases used `source="webcam"`).
            video_input = gr.Image(sources=["webcam"], streaming=True, type="numpy")
        with gr.Column():
            video_output = gr.Image(label="Processed Video")
            description_output = gr.Textbox(label="Live Description", lines=3)

    # Set up real-time processing
    video_input.stream(
        fn=analyze_video_live,
        inputs=[video_input],
        outputs=[video_output, description_output],
        stream_every=0.5,          # process a frame every 0.5 seconds
        show_progress="hidden"     # Gradio 4+/5 expects a string here, not a boolean
    )
gr.Markdown("""
## How to Use:
1. Allow camera access when prompted
2. Point your camera at objects/scenes
3. Get real-time audio descriptions
## For Mobile App Integration:
- API endpoint: `/api/predict`
- Send POST request with base64 encoded frame
- Receive JSON response with description and annotations
""")

# Launch the app
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True  # share links are ignored on Hugging Face Spaces; kept for local runs
    )