# vedioNarrator / app.py
# Uploaded by chirag484587 — commit bbb5d57 ("Upload 5 files")
import gradio as gr
import cv2
import torch
import torchvision.transforms as transforms
from torchvision import models
from ultralytics import YOLO
import numpy as np
from collections import OrderedDict, Counter
import time
import tempfile
import os
# Model wrappers used by the app (simplified for web deployment): scene classification and sentence generation.
class SceneClassifier:
    """ResNet-50 scene classifier applied to single video frames.

    Class names come from ``categories_file`` and weights from ``model_path``.
    Both loads are best-effort. A missing categories file falls back to a
    small built-in label list. A checkpoint that cannot be loaded only prints
    a warning and leaves the network with its freshly initialised (untrained)
    weights, so predictions are then not meaningful.
    """

    # Labels used when the categories file is missing. NOTE(review): a
    # checkpoint trained on the full category list will not fit a head of
    # this size, so in this fallback the weights will also fail to load.
    _FALLBACK_CLASSES = ('indoor', 'outdoor', 'street', 'kitchen', 'bedroom')

    def __init__(self, model_path='scene.pth.tar', categories_file='categories_places365.txt'):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.classes = self._load_classes(categories_file)

        # Size the classifier head to the label list so output index i maps to self.classes[i].
        self.model = models.resnet50(num_classes=len(self.classes))
        try:
            # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
            checkpoint = torch.load(model_path, map_location=self.device)
            # Accept both training checkpoints ({'state_dict': ...}) and bare state dicts.
            if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint
            # Keys from a wrapped model (presumably nn.DataParallel) start with 'module.'.
            self.model.load_state_dict(self._strip_prefix(state_dict, 'module.'))
        except Exception as e:
            print(f"Warning: Could not load scene model: {e}")
        self.model.to(self.device)
        self.model.eval()

        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            # Commonly used ImageNet per-channel mean/std.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    @classmethod
    def _load_classes(cls, categories_file):
        """Return the tuple of class names parsed from *categories_file*.

        Blank lines are skipped. For every other line, the label is the first
        whitespace-separated token minus its first 3 characters. Lines
        presumably look like ``/a/airfield 0`` (Places365 format).

        If the file is missing or yields no labels, returns ``_FALLBACK_CLASSES``.
        """
        try:
            with open(categories_file, encoding='utf-8') as class_file:
                names = [line.split()[0][3:] for line in class_file if line.strip()]
        except FileNotFoundError:
            return cls._FALLBACK_CLASSES
        return tuple(names) or cls._FALLBACK_CLASSES

    @staticmethod
    def _strip_prefix(state_dict, prefix):
        """Return a copy of *state_dict* with *prefix* removed from the start of each key.

        Only a leading prefix is removed. A key that merely contains it is left unchanged.
        """
        return OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )

    @staticmethod
    def _to_rgb(frame):
        """Convert an OpenCV-style frame (grayscale, BGR or BGRA) to 3-channel RGB."""
        if len(frame.shape) == 2 or (len(frame.shape) == 3 and frame.shape[2] == 1):
            return cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
        if len(frame.shape) == 3 and frame.shape[2] == 4:
            return cv2.cvtColor(frame, cv2.COLOR_BGRA2RGB)
        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    def predict(self, frame):
        """Classify the scene shown in one frame.

        Args:
            frame: image array in OpenCV channel order. It can be 2-D
                grayscale, or 3-D with 1, 3 (BGR) or 4 (BGRA) channels.

        Returns:
            ``(scene_label, confidence)`` for the top-1 class, where
            ``confidence`` is its softmax probability. Returns
            ``("unknown", 0.0)`` if anything fails, so a bad frame does not
            stop the stream.
        """
        try:
            frame_rgb = self._to_rgb(frame)
            input_tensor = self.transform(frame_rgb).unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(input_tensor)
            probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
            top_prob, top_class = torch.topk(probabilities, 1)
            return self.classes[top_class.item()], top_prob.item()
        except Exception as e:
            # Best-effort per frame: report the failure instead of hiding it.
            print(f"Warning: scene prediction failed: {e}")
            return "unknown", 0.0
class NaturalLanguageGenerator:
    """Build a one-sentence English description from detections and a scene label."""

    # Plurals that the suffix rules in _pluralize would get wrong. Keys are
    # lower-case; entries were chosen with COCO-style detector labels in mind.
    _IRREGULAR_PLURALS = {
        'person': 'people',
        'man': 'men',
        'woman': 'women',
        'child': 'children',
        'mouse': 'mice',
        'knife': 'knives',
        'sheep': 'sheep',
        'scissors': 'scissors',
        'skis': 'skis',
        'broccoli': 'broccoli',
    }

    def __init__(self, confidence_threshold=0.3):
        # Detections scoring below this are left out of the description.
        self.confidence_threshold = confidence_threshold
        # Sentence templates with {objects} / {scene} placeholders.
        # generate_description uses the first entry of each list.
        self.templates = {
            'basic': [
                "I can see {objects} in this {scene}.",
                "This {scene} contains {objects}.",
                "In this {scene} setting, there are {objects}.",
            ],
            'no_objects': [
                "This is a {scene} scene with no prominent objects.",
                "The {scene} appears empty.",
            ]
        }

    @staticmethod
    def _article(word):
        """Return "an" if *word* starts with a vowel letter, otherwise "a".

        This is a spelling-based heuristic. It does not handle words such as "hour" or "unicorn".
        """
        return "an" if word[:1].lower() in ('a', 'e', 'i', 'o', 'u') else "a"

    @classmethod
    def _pluralize(cls, name):
        """Return the plural of *name*, inflecting only its last word.

        Example: "wine glass" -> "wine glasses".
        """
        head, sep, last = name.rpartition(' ')
        lower = last.lower()
        if lower in cls._IRREGULAR_PLURALS:
            plural = cls._IRREGULAR_PLURALS[lower]
        elif lower.endswith(('s', 'x', 'z', 'ch', 'sh')):
            plural = last + 'es'
        elif len(lower) > 1 and lower.endswith('y') and lower[-2] not in 'aeiou':
            plural = last[:-1] + 'ies'
        else:
            plural = last + 's'
        return head + sep + plural

    def format_objects(self, objects):
        """Render object names as an English list, e.g. "2 people, a dog, and an apple".

        Repeated names are counted and pluralised, keeping the order of first
        appearance. A single item gets "a"/"an". Returns "" for no objects.
        """
        if not objects:
            return ""
        phrases = [
            f"{self._article(name)} {name}" if count == 1 else f"{count} {self._pluralize(name)}"
            for name, count in Counter(objects).items()
        ]
        if len(phrases) == 1:
            return phrases[0]
        if len(phrases) == 2:
            return f"{phrases[0]} and {phrases[1]}"
        return ", ".join(phrases[:-1]) + f", and {phrases[-1]}"

    def generate_description(self, detected_objects, scene_label, max_objects=5):
        """Describe a frame in one sentence.

        Args:
            detected_objects: iterable of ``(name, confidence)`` pairs.
            scene_label: scene class name; it is lower-cased and underscores
                become spaces.
            max_objects: at most this many of the detections that pass the
                confidence threshold (the first ones, in input order) are described.

        Returns:
            The sentence from the first 'basic' template, or from the first
            'no_objects' template when no detection passes the threshold.
            Only the first character is upper-cased.
        """
        filtered_objects = [
            obj_name for obj_name, confidence in detected_objects
            if confidence >= self.confidence_threshold
        ]
        scene = scene_label.lower().replace('_', ' ')
        if filtered_objects:
            object_string = self.format_objects(filtered_objects[:max_objects])
            description = self.templates['basic'][0].format(objects=object_string, scene=scene)
        else:
            description = self.templates['no_objects'][0].format(objects='', scene=scene)
        # Upper-case only the first character. str.capitalize() would also
        # lower-case the rest of the sentence (e.g. turn a "TV" label into "tv").
        return description[:1].upper() + description[1:]
# Build every model once, at import time, so each streamed frame reuses them.
print("Loading models...")
yolo_model = YOLO('best.pt')  # object detector used by process_video_frame
scene_classifier = SceneClassifier(model_path='scene.pth.tar')
nlg = NaturalLanguageGenerator(confidence_threshold=0.3)
print("Models loaded!")
def _collect_detections(result):
    """Return ``[(class_name, confidence), ...]`` for every box in one YOLO result.

    Returns an empty list when *result* is ``None`` or has no boxes or names.
    """
    boxes = getattr(result, 'boxes', None)
    names = getattr(result, 'names', None)
    if boxes is None or names is None:
        return []
    # .tolist() converts all boxes to plain Python numbers at once, instead of
    # calling int()/float() on a 1-element array per box (deprecated in NumPy >= 1.25).
    return [
        (names[int(class_id)], float(confidence))
        for class_id, confidence in zip(boxes.cls.tolist(), boxes.conf.tolist())
    ]


def process_video_frame(frame):
    """Process a single frame and return description.

    Runs the module-level ``yolo_model`` for objects and ``scene_classifier``
    for the scene on *frame*. It then builds a sentence with ``nlg`` and draws
    "Scene"/"Objects" overlays.

    Args:
        frame: image array in BGR order (analyze_video_live converts to BGR first).

    Returns:
        ``(annotated_frame, description)``. Overlays are drawn on the image
        from ``result.plot()``, or on a copy of *frame* when there is no
        result. This function never calls putText on *frame* itself. On any
        exception it returns ``(frame, "Error: <message>")`` so the stream
        keeps running.
    """
    try:
        results = yolo_model(frame, verbose=False)
        result = results[0] if results else None

        detected_objects = _collect_detections(result)
        scene, _ = scene_classifier.predict(frame)
        description = nlg.generate_description(detected_objects, scene)

        # Assumes plot() returns a new image; copy the input otherwise so the
        # caller's frame is not modified in place.
        annotated_frame = result.plot() if result is not None else frame.copy()
        cv2.putText(annotated_frame, f"Scene: {scene}",
                    (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
        # Counts every detection in the YOLO result, including ones below the
        # narrator's confidence threshold.
        cv2.putText(annotated_frame, f"Objects: {len(detected_objects)}",
                    (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
        return annotated_frame, description
    except Exception as e:
        return frame, f"Error: {str(e)}"
def analyze_video_live(video):
    """Gradio stream callback: narrate one webcam frame.

    Args:
        video: RGB frame (numpy array) from the streaming ``gr.Image``, or
            ``None`` when no frame is available.

    Returns:
        ``(annotated_rgb_frame, description)``, or ``(None, "No video input")``
        when *video* is ``None``.
    """
    if video is not None:
        # The detection/classification pipeline expects OpenCV's BGR order.
        annotated_bgr, description = process_video_frame(cv2.cvtColor(video, cv2.COLOR_RGB2BGR))
        # Gradio displays RGB, so swap the channels back before returning.
        return cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB), description
    return None, "No video input"
# Create Gradio interface.
# NOTE(review): this uses the Gradio 5 API: `sources=[...]` on gr.Image and
# `stream_every=` on .stream(). Gradio 4+ rejects the old Gradio 3 keyword
# `source="webcam"`, and Gradio 3 has no `stream_every`, so the two APIs cannot be mixed.
with gr.Blocks(title="Live Video Narrator for the Blind") as demo:
    gr.Markdown("# 🎥 Live Video Narrator for the Blind")
    gr.Markdown("This AI system provides real-time narration of live video feeds to assist visually impaired users.")
    with gr.Row():
        with gr.Column():
            video_input = gr.Image(sources=["webcam"], streaming=True, type="numpy")
        with gr.Column():
            video_output = gr.Image(label="Processed Video")
            description_output = gr.Textbox(label="Live Description", lines=3)
    # Set up real-time processing: each streamed webcam frame goes through
    # analyze_video_live, which fills the processed image and the description box.
    video_input.stream(
        fn=analyze_video_live,
        inputs=[video_input],
        outputs=[video_output, description_output],
        stream_every=0.5,  # Process every 0.5 seconds
        show_progress="hidden",  # no progress overlay on the outputs for every frame
    )
    # NOTE(review): this file does not define the "/api/predict" endpoint
    # described below. Verify it against the deployed Gradio version's API
    # page before documenting it to users.
    gr.Markdown("""
## How to Use:
1. Allow camera access when prompted
2. Point your camera at objects/scenes
3. Read the live text description (a screen reader can speak it aloud)
## For Mobile App Integration:
- API endpoint: `/api/predict`
- Send POST request with base64 encoded frame
- Receive JSON response with description and annotations
""")
# Launch the app only when run as a script, not when the module is imported.
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
        "share": True,
    }
    demo.launch(**launch_options)