import gradio as gr
import cv2
import torch
import torchvision.transforms as transforms
from torchvision import models
from ultralytics import YOLO
import numpy as np
from collections import OrderedDict, Counter
import time
import tempfile
import os
# Core pipeline classes (simplified for web deployment)
class SceneClassifier:
    """Places365-style scene classifier built on a ResNet-50 backbone."""

    def __init__(self, model_path='scene.pth.tar', categories_file='categories_places365.txt'):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load scene category labels (e.g. '/a/abbey' -> 'abbey')
        self.classes = []
        try:
            with open(categories_file) as class_file:
                for line in class_file:
                    self.classes.append(line.strip().split(' ')[0][3:])
        except FileNotFoundError:
            # Fallback labels if the categories file is missing; note the class
            # count will then no longer match the checkpoint loaded below
            self.classes = ['indoor', 'outdoor', 'street', 'kitchen', 'bedroom']
        self.classes = tuple(self.classes)
        # Build the ResNet-50 backbone and load the checkpoint weights
        self.model = models.resnet50(num_classes=len(self.classes))
        try:
            checkpoint = torch.load(model_path, map_location=self.device)
            state_dict = checkpoint['state_dict']
            # Strip the 'module.' prefix left by DataParallel training
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                new_key = k.replace('module.', '')
                new_state_dict[new_key] = v
            self.model.load_state_dict(new_state_dict)
        except Exception as e:
            print(f"Warning: Could not load scene model: {e}")
        self.model.to(self.device)
        self.model.eval()

        # Standard ImageNet preprocessing, as used by the Places365 ResNet models
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    def predict(self, frame):
        """Return (scene_label, confidence) for a single frame."""
        try:
            # OpenCV frames arrive as BGR; convert to RGB for the PIL-based transform
            if len(frame.shape) == 3:
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            else:
                frame_rgb = frame
            input_tensor = self.transform(frame_rgb).unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(input_tensor)
                probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
                top_prob, top_class = torch.topk(probabilities, 1)
            scene_class = self.classes[top_class.item()]
            confidence = top_prob.item()
            return scene_class, confidence
        except Exception:
            return "unknown", 0.0
class NaturalLanguageGenerator:
    """Turns object detections and a scene label into a short spoken-style sentence."""

    def __init__(self, confidence_threshold=0.3):
        self.confidence_threshold = confidence_threshold
        self.templates = {
            'basic': [
                "I can see {objects} in this {scene}.",
                "This {scene} contains {objects}.",
                "In this {scene} setting, there are {objects}.",
            ],
            'no_objects': [
                "This is a {scene} scene with no prominent objects.",
                "The {scene} appears empty.",
            ]
        }
    def format_objects(self, objects):
        """Join object names into natural English, e.g. 'a person, 2 cars, and a dog'."""
        if not objects:
            return ""
        object_counts = Counter(objects)
        formatted = []
        for obj, count in object_counts.items():
            if count == 1:
                formatted.append(f"a {obj}")
            else:
                # Naive pluralisation; adequate for common object labels
                formatted.append(f"{count} {obj}s")
        if len(formatted) == 1:
            return formatted[0]
        elif len(formatted) == 2:
            return f"{formatted[0]} and {formatted[1]}"
        else:
            return ", ".join(formatted[:-1]) + f", and {formatted[-1]}"
    def generate_description(self, detected_objects, scene_label):
        """Build a one-sentence description from (name, confidence) pairs and a scene label."""
        filtered_objects = [
            obj_name for obj_name, confidence in detected_objects
            if confidence >= self.confidence_threshold
        ]
        scene = scene_label.lower().replace('_', ' ')
        if filtered_objects:
            # Mention at most five objects to keep the sentence short
            object_string = self.format_objects(filtered_objects[:5])
            description = f"I can see {object_string} in this {scene}."
        else:
            description = f"This is a {scene} scene with no prominent objects."
        return description.capitalize()
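# For example, with the default 0.3 confidence threshold:
#   NaturalLanguageGenerator().generate_description([("person", 0.9), ("dog", 0.8)], "living_room")
#   -> "I can see a person and a dog in this living room."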
# Initialize the models once at startup so every streamed frame reuses them
print("Loading models...")
yolo_model = YOLO('best.pt')
scene_classifier = SceneClassifier('scene.pth.tar')
nlg = NaturalLanguageGenerator()
print("Models loaded!")
def process_video_frame(frame):
    """Run detection and scene classification on a BGR frame; return (annotated_frame, description)."""
    try:
        # Object detection
        results = yolo_model(frame, verbose=False)
        detected_objects = []
        if results and len(results) > 0:
            result = results[0]
            if hasattr(result, 'boxes') and result.boxes is not None:
                for box in result.boxes:
                    if hasattr(box, 'cls') and hasattr(box, 'conf'):
                        class_id = int(box.cls.item())
                        confidence = float(box.conf.item())
                        if hasattr(result, 'names'):
                            class_name = result.names[class_id]
                            detected_objects.append((class_name, confidence))

        # Scene classification
        scene, scene_confidence = scene_classifier.predict(frame)

        # Generate the natural-language description
        description = nlg.generate_description(detected_objects, scene)

        # Draw annotations: YOLO's plotted boxes plus a scene/object summary overlay
        annotated_frame = results[0].plot() if results else frame
        cv2.putText(annotated_frame, f"Scene: {scene}",
                    (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
        cv2.putText(annotated_frame, f"Objects: {len(detected_objects)}",
                    (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
        return annotated_frame, description
    except Exception as e:
        return frame, f"Error: {str(e)}"
def analyze_video_live(video):
    """Process one frame from the live webcam stream (RGB numpy array from Gradio)."""
    if video is None:
        return None, "No video input"
    # Gradio delivers RGB; OpenCV and the models above expect BGR
    frame = cv2.cvtColor(video, cv2.COLOR_RGB2BGR)
    # Run detection, scene classification, and description generation
    annotated_frame, description = process_video_frame(frame)
    # Convert back to RGB for display in Gradio
    annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
    return annotated_frame, description
# Create the Gradio interface
with gr.Blocks(title="Live Video Narrator for the Blind") as demo:
    gr.Markdown("# 🎥 Live Video Narrator for the Blind")
    gr.Markdown("This AI system provides real-time narration of live video feeds to assist visually impaired users.")

    with gr.Row():
        with gr.Column():
            # Gradio 4+ uses the plural `sources` argument; older 3.x builds used `source="webcam"`
            video_input = gr.Image(sources=["webcam"], streaming=True, type="numpy")
        with gr.Column():
            video_output = gr.Image(label="Processed Video")
            description_output = gr.Textbox(label="Live Description", lines=3)

    # Stream webcam frames through the pipeline roughly twice per second
    video_input.stream(
        fn=analyze_video_live,
        inputs=[video_input],
        outputs=[video_output, description_output],
        stream_every=0.5,
        show_progress="hidden"
    )
gr.Markdown(""" | |
## How to Use: | |
1. Allow camera access when prompted | |
2. Point your camera at objects/scenes | |
3. Get real-time audio descriptions | |
## For Mobile App Integration: | |
- API endpoint: `/api/predict` | |
- Send POST request with base64 encoded frame | |
- Receive JSON response with description and annotations | |
""") | |
# Launch the app
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )
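# --- Example client (sketch) -------------------------------------------------
# A minimal, hypothetical client for the mobile-integration flow described in
# the UI text above, assuming the Space exposes the legacy `/api/predict` route
# that accepts `{"data": [<base64 image>]}`. The exact URL and payload shape
# differ between Gradio versions, so treat this as a starting point only.
#
#   import base64
#   import requests
#
#   with open("frame.jpg", "rb") as f:
#       b64_frame = "data:image/jpeg;base64," + base64.b64encode(f.read()).decode()
#
#   resp = requests.post(
#       "http://localhost:7860/api/predict",  # or the public Space URL
#       json={"data": [b64_frame]},
#   )
#   print(resp.json())  # expected to contain the annotated frame and description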