#!/usr/bin/env python3

# import spaces first
import spaces
import gradio as gr
import os
from main import load_moondream, process_video, load_sam_model
import shutil
import torch
from visualization import visualize_detections
from persistence import load_detection_data
import matplotlib.pyplot as plt
import io
from PIL import Image
import pandas as pd
from video_visualization import create_video_visualization
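# main, visualization, persistence, and video_visualization are assumed to be local
# modules bundled with this Space rather than PyPI packages.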

# Get absolute path to workspace root
WORKSPACE_ROOT = os.path.dirname(os.path.abspath(__file__))

# Check CUDA availability (we want this to be True on a GPU Space)
print(f"Is CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # GPU name
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    print("No CUDA device available at startup")

# Initialize Moondream model globally for reuse (will be loaded on first use)
model, tokenizer = None, None


# Uncomment for Hugging Face Spaces
def process_video_file(
    video_file, target_object, box_style, ffmpeg_preset, grid_rows, grid_cols, test_mode, test_duration
):
    """Process a video file through the Gradio interface."""
    try:
        if not video_file:
            raise gr.Error("Please upload a video file")

        # Load models if not already loaded
        global model, tokenizer
        if model is None or tokenizer is None:
            model, tokenizer = load_moondream()

        # Ensure input/output directories exist using absolute paths
        inputs_dir = os.path.join(WORKSPACE_ROOT, "inputs")
        outputs_dir = os.path.join(WORKSPACE_ROOT, "outputs")
        os.makedirs(inputs_dir, exist_ok=True)
        os.makedirs(outputs_dir, exist_ok=True)

        # Copy uploaded video to inputs directory
        video_filename = f"input_{os.path.basename(video_file)}"
        input_video_path = os.path.join(inputs_dir, video_filename)
        shutil.copy2(video_file, input_video_path)
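        # The upload arrives as a temporary Gradio file path; it is copied into
        # inputs/ so the pipeline works from a stable, predictably named path
        # (the same name is reused below to locate the generated outputs).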

        try:
            # Process the video
            output_path = process_video(
                input_video_path,
                target_object,
                test_mode=test_mode,
                test_duration=test_duration,
                ffmpeg_preset=ffmpeg_preset,
                grid_rows=grid_rows,
                grid_cols=grid_cols,
                box_style=box_style,
            )

            # Get the corresponding JSON path
            base_name = os.path.splitext(os.path.basename(video_filename))[0]
            json_path = os.path.join(
                outputs_dir, f"{box_style}_{target_object}_{base_name}_detections.json"
            )

            # Verify output exists and is readable
            if not output_path or not os.path.exists(output_path):
                print(f"Warning: Output path {output_path} does not exist")
                # Try to find the output based on expected naming convention
                expected_output = os.path.join(
                    outputs_dir, f"{box_style}_{target_object}_{video_filename}"
                )
                if os.path.exists(expected_output):
                    output_path = expected_output
                else:
                    # Try searching in outputs directory for any matching file
                    matching_files = [
                        f
                        for f in os.listdir(outputs_dir)
                        if f.startswith(f"{box_style}_{target_object}_")
                    ]
                    if matching_files:
                        output_path = os.path.join(outputs_dir, matching_files[0])
                    else:
                        raise gr.Error("Failed to locate output video")

            # Convert output path to absolute path if it isn't already
            if not os.path.isabs(output_path):
                output_path = os.path.join(WORKSPACE_ROOT, output_path)

            print(f"Returning output path: {output_path}")
            return output_path, json_path

        finally:
            # Clean up the copied input file
            try:
                if os.path.exists(input_video_path):
                    os.remove(input_video_path)
            except OSError:
                pass

    except Exception as e:
        print(f"Error in process_video_file: {str(e)}")
        raise gr.Error(f"Error processing video: {str(e)}")


def create_visualization_plots(json_path):
    """Create visualization plots and return them as images."""
    try:
        # Load the data
        data = load_detection_data(json_path)
        if not data:
            return None, None, None, None, None, None, None, None, "No data found"

        # Convert to DataFrame
        rows = []
        for frame_data in data["frame_detections"]:
            frame = frame_data["frame"]
            timestamp = frame_data["timestamp"]
            for obj in frame_data["objects"]:
                rows.append({
                    "frame": frame,
                    "timestamp": timestamp,
                    "keyword": obj["keyword"],
                    "x1": obj["bbox"][0],
                    "y1": obj["bbox"][1],
                    "x2": obj["bbox"][2],
                    "y2": obj["bbox"][3],
                    "area": (obj["bbox"][2] - obj["bbox"][0]) * (obj["bbox"][3] - obj["bbox"][1]),
                    "center_x": (obj["bbox"][0] + obj["bbox"][2]) / 2,
                    "center_y": (obj["bbox"][1] + obj["bbox"][3]) / 2,
                })

        if not rows:
            return None, None, None, None, None, None, None, None, "No detections found in the data"

        df = pd.DataFrame(rows)
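        # bbox is assumed to be [x1, y1, x2, y2]; the "(normalized)" axis label below
        # suggests 0-1 coordinates, so "area" is a fraction of the frame rather than pixels.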

        plots = []

        # Create each plot and convert to image
        for plot_num in range(8):  # Increased to 8 plots
            plt.figure(figsize=(8, 6))

            if plot_num == 0:
                # Plot 1: Number of detections per frame (Original)
                detections_per_frame = df.groupby("frame").size()
                plt.plot(detections_per_frame.index, detections_per_frame.values)
                plt.xlabel("Frame")
                plt.ylabel("Number of Detections")
                plt.title("Detections Per Frame")

            elif plot_num == 1:
                # Plot 2: Distribution of detection areas (Original)
                df["area"].hist(bins=30)
                plt.xlabel("Detection Area (normalized)")
                plt.ylabel("Count")
                plt.title("Distribution of Detection Areas")

            elif plot_num == 2:
                # Plot 3: Average detection area over time (Original)
                avg_area = df.groupby("frame")["area"].mean()
                plt.plot(avg_area.index, avg_area.values)
                plt.xlabel("Frame")
                plt.ylabel("Average Detection Area")
                plt.title("Average Detection Area Over Time")

            elif plot_num == 3:
                # Plot 4: Heatmap of detection centers (Original)
                plt.hist2d(df["center_x"], df["center_y"], bins=30)
                plt.colorbar()
                plt.xlabel("X Position")
                plt.ylabel("Y Position")
                plt.title("Detection Center Heatmap")

            elif plot_num == 4:
                # Plot 5: Time-based Detection Density
                # Shows when in the video most detections occur
                df["time_bucket"] = pd.qcut(df["timestamp"], q=20, labels=False)
                time_density = df.groupby("time_bucket").size()
                plt.bar(time_density.index, time_density.values)
                plt.xlabel("Video Timeline (20 segments)")
                plt.ylabel("Number of Detections")
                plt.title("Detection Density Over Video Duration")

            elif plot_num == 5:
                # Plot 6: Screen Region Analysis
                # Divide screen into 3x3 grid and show detection counts
                try:
                    df["grid_x"] = pd.qcut(
                        df["center_x"], q=3, labels=["Left", "Center", "Right"], duplicates="drop"
                    )
                    df["grid_y"] = pd.qcut(
                        df["center_y"], q=3, labels=["Top", "Middle", "Bottom"], duplicates="drop"
                    )
                    region_counts = df.groupby(["grid_y", "grid_x"]).size().unstack(fill_value=0)
                    plt.imshow(region_counts, cmap="YlOrRd")
                    plt.colorbar(label="Detection Count")
                    for i in range(3):
                        for j in range(3):
                            plt.text(j, i, region_counts.iloc[i, j], ha="center", va="center")
                    plt.xticks(range(3), ["Left", "Center", "Right"])
                    plt.yticks(range(3), ["Top", "Middle", "Bottom"])
                    plt.title("Screen Region Analysis")
                except Exception:
                    plt.text(0.5, 0.5, "Insufficient variation in detection positions",
                             ha="center", va="center")
                    plt.title("Screen Region Analysis (Not Available)")

            elif plot_num == 6:
                # Plot 7: Detection Size Categories
                # Categorize detections by size for content moderation
                try:
                    size_labels = [
                        "Small (likely far/background)",
                        "Medium-small",
                        "Medium-large",
                        "Large (likely foreground/close)",
                    ]
                    # Handle cases with limited unique values
                    unique_areas = df["area"].nunique()
                    if unique_areas >= 4:
                        df["size_category"] = pd.qcut(df["area"], q=4, labels=size_labels, duplicates="drop")
                    else:
                        # Alternative binning for limited unique values
                        df["size_category"] = pd.cut(
                            df["area"], bins=unique_areas, labels=size_labels[:unique_areas]
                        )
                    size_dist = df["size_category"].value_counts()
                    plt.pie(size_dist.values, labels=size_dist.index, autopct="%1.1f%%")
                    plt.title("Detection Size Distribution")
                except Exception:
                    plt.text(0.5, 0.5, "Insufficient variation in detection sizes",
                             ha="center", va="center")
                    plt.title("Detection Size Distribution (Not Available)")

            elif plot_num == 7:
                # Plot 8: Temporal Pattern Analysis
                # Show patterns of when detections occur in sequence
                try:
                    detection_gaps = df.sort_values("frame")["frame"].diff()
                    if len(detection_gaps.dropna().unique()) > 1:
                        plt.hist(detection_gaps.dropna(),
                                 bins=min(30, len(detection_gaps.dropna().unique())),
                                 edgecolor="black")
                        plt.xlabel("Frames Between Detections")
                        plt.ylabel("Frequency")
                        plt.title("Detection Temporal Pattern Analysis")
                    else:
                        plt.text(0.5, 0.5, "Uniform detection intervals", ha="center", va="center")
                        plt.title("Temporal Pattern Analysis (Uniform)")
                except Exception:
                    plt.text(0.5, 0.5, "Insufficient temporal data", ha="center", va="center")
                    plt.title("Temporal Pattern Analysis (Not Available)")

            # Save plot to bytes
            buf = io.BytesIO()
            plt.savefig(buf, format="png", bbox_inches="tight")
            buf.seek(0)
            plots.append(Image.open(buf))
            plt.close()

        # Enhanced summary text
        summary = f"""Summary Statistics:
Total frames analyzed: {len(data['frame_detections'])}
Total detections: {len(df)}
Average detections per frame: {len(df) / len(data['frame_detections']):.2f}

Detection Patterns:
- Peak detection count: {df.groupby('frame').size().max()} (in a single frame)
- Most common screen region: {df.groupby(['grid_y', 'grid_x']).size().idxmax()}
- Average detection size: {df['area'].mean():.3f}
- Median frames between detections: {detection_gaps.median():.1f}

Video metadata:
"""
        for key, value in data["video_metadata"].items():
            summary += f"{key}: {value}\n"

        return plots[0], plots[1], plots[2], plots[3], plots[4], plots[5], plots[6], plots[7], summary

    except Exception as e:
        print(f"Error creating visualization: {str(e)}")
        import traceback

        traceback.print_exc()
        return None, None, None, None, None, None, None, None, f"Error creating visualization: {str(e)}"
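

# Hypothetical helper (not wired into the UI): a sketch of the rough processing-time
# estimate described in the Advanced Settings notes below, i.e.
# frame rate * seconds * grid rows * grid columns * ~0.12 seconds per detection pass
# (the 0.12 s figure comes from the UI text, measured there on a 4090 GPU).
def estimate_processing_seconds(duration_s, fps, grid_rows, grid_cols, per_detection_s=0.12):
    """Rough, assumption-based estimate of processing time for a clip."""
    return duration_s * fps * grid_rows * grid_cols * per_detection_s


# Example: estimate_processing_seconds(3, 30, 2, 2) -> ~43.2 seconds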


# Create the Gradio interface
with gr.Blocks(title="Promptable Content Moderation") as app:
    with gr.Tabs():
        with gr.Tab("Process Video"):
            gr.Markdown("# Promptable Content Moderation with Moondream")
            gr.Markdown(
                """
                Powered by [Moondream 2B](https://github.com/vikhyat/moondream).

                Upload a video and specify what to moderate. The app will process each frame and moderate any visual content that matches the prompt. For help, join the [Moondream Discord](https://discord.com/invite/tRUdpjDQfH).
                """
            )

            with gr.Row():
                with gr.Column():
                    # Input components
                    video_input = gr.Video(label="Upload Video")
                    detect_input = gr.Textbox(
                        label="What to Moderate",
                        placeholder="e.g. face, cigarette, gun, etc.",
                        value="face",
                        info="Moondream can moderate anything that you can describe in natural language",
                    )

                    gr.Examples(
                        examples=[
                            ["examples/cig.mp4", "cigarette"],
                            ["examples/gun.mp4", "gun"],
                            ["examples/homealone.mp4", "face"],
                            ["examples/conf.mp4", "confederate flag"],
                        ],
                        inputs=[video_input, detect_input],
                        label="Try these examples",
                    )
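                    # The example clips are assumed to ship with the Space under
                    # examples/ (paths are resolved relative to the app's root).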

                    process_btn = gr.Button("Process Video", variant="primary")

                    with gr.Accordion("Advanced Settings", open=False):
                        box_style_input = gr.Radio(
                            choices=[
                                "censor",
                                "bounding-box",
                                "hitmarker",
                                "sam",
                                "sam-fast",
                                "fuzzy-blur",
                                "pixelated-blur",
                                "intense-pixelated-blur",
                                "obfuscated-pixel",
                            ],
                            value="obfuscated-pixel",
                            label="Visualization Style",
                            info="Choose how to display moderations: censor (black boxes), bounding-box (red boxes with labels), hitmarker (COD-style markers), sam (precise segmentation), sam-fast (faster but less precise segmentation), fuzzy-blur (Gaussian blur), pixelated-blur (pixelated with blur), obfuscated-pixel (advanced pixelation with neighborhood averaging)",
                        )
                        preset_input = gr.Dropdown(
                            choices=[
                                "ultrafast",
                                "superfast",
                                "veryfast",
                                "faster",
                                "fast",
                                "medium",
                                "slow",
                                "slower",
                                "veryslow",
                            ],
                            value="medium",
                            label="Processing Speed (faster = lower quality)",
                        )
                        with gr.Row():
                            rows_input = gr.Slider(
                                minimum=1, maximum=4, value=1, step=1, label="Grid Rows"
                            )
                            cols_input = gr.Slider(
                                minimum=1, maximum=4, value=1, step=1, label="Grid Columns"
                            )
                        test_mode_input = gr.Checkbox(
                            label="Test Mode (process only the first few seconds)",
                            value=True,
                            info="Enable to quickly test settings on a short clip before processing the full video (recommended). Disable it if you plan to use the data visualizations.",
                        )
                        test_duration_input = gr.Slider(
                            minimum=1,
                            maximum=10,
                            value=3,
                            step=1,
                            label="Test Mode Duration (seconds)",
                            info="Number of seconds to process in test mode",
                        )
                        gr.Markdown(
                            """
                            Note: Test mode processes only the first few seconds of the video (3 by default) and is recommended while dialing in settings.
                            """
                        )
                        gr.Markdown(
                            """
                            You can roughly estimate processing time as: frame rate * seconds * grid rows * grid columns * 0.12 seconds per detection.

                            For example, a 3-second clip at 30 fps with a 2x2 grid gives 3 * 30 * 2 * 2 * 0.12 = 43.2 seconds (measured on a 4090 GPU).

                            Note: The SAM visualization style increases processing time significantly because it runs an additional segmentation pass for each detection. The sam-fast option uses a smaller model for faster processing at the cost of some accuracy.
                            """
                        )

                with gr.Column():
                    # Output components
                    video_output = gr.Video(label="Processed Video")
                    json_output = gr.Text(label="Detection Data Path", visible=False)

                    # About section under the video output
                    gr.Markdown(
                        """
                        ### Links:
                        - [GitHub Repository](https://github.com/vikhyat/moondream)
                        - [Hugging Face](https://huggingface.co/vikhyatk/moondream2)
                        - [Quick Start](https://docs.moondream.ai/quick-start)
                        - [Moondream Recipes](https://docs.moondream.ai/recipes)
                        """
                    )
| with gr.Tab("Analyze Results"): | |
| gr.Markdown("# Detection Analysis") | |
| gr.Markdown( | |
| """ | |
| Analyze the detection results from processed videos. The analysis includes: | |
| - Basic detection statistics and patterns | |
| - Temporal and spatial distribution analysis | |
| - Size-based categorization | |
| - Screen region analysis | |
| - Detection density patterns | |
| """ | |
| ) | |
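
            # The eight plots below correspond to: detections per frame, area
            # distribution, average area over time, center heatmap, density timeline,
            # screen regions, size categories, and temporal gaps
            # (all produced by create_visualization_plots above).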

            with gr.Row():
                json_input = gr.File(
                    label="Upload Detection Data (JSON)",
                    file_types=[".json"],
                )
                analyze_btn = gr.Button("Analyze", variant="primary")

            with gr.Row():
                with gr.Column():
                    plot1 = gr.Image(label="Detections Per Frame")
                    plot2 = gr.Image(label="Detection Areas Distribution")
                    plot5 = gr.Image(label="Detection Density Timeline")
                    plot6 = gr.Image(label="Screen Region Analysis")
                with gr.Column():
                    plot3 = gr.Image(label="Average Detection Area Over Time")
                    plot4 = gr.Image(label="Detection Center Heatmap")
                    plot7 = gr.Image(label="Detection Size Categories")
                    plot8 = gr.Image(label="Temporal Pattern Analysis")

            stats_output = gr.Textbox(
                label="Statistics",
                info="Summary of key metrics and patterns found in the detection data.",
                lines=12,
                max_lines=15,
                interactive=False,
            )

        # with gr.Tab("Video Visualizations"):
        #     gr.Markdown("# Real-time Detection Visualization")
        #     gr.Markdown(
        #         """
        #         Watch the detection patterns unfold in real-time. Choose from:
        #         - Timeline: Shows number of detections over time
        #         - Gauge: Simple yes/no indicator for current frame detections
        #         """
        #     )
        #     with gr.Row():
        #         json_input_realtime = gr.File(
        #             label="Upload Detection Data (JSON)",
        #             file_types=[".json"],
        #         )
        #         viz_style = gr.Radio(
        #             choices=["timeline", "gauge"],
        #             value="timeline",
        #             label="Visualization Style",
        #             info="Choose between timeline view or simple gauge indicator",
        #         )
        #     visualize_btn = gr.Button("Visualize", variant="primary")
        #     with gr.Row():
        #         video_visualization = gr.Video(
        #             label="Detection Visualization",
        #             interactive=False,
        #         )
        #         stats_realtime = gr.Textbox(
        #             label="Video Statistics",
        #             lines=6,
        #             max_lines=8,
        #             interactive=False,
        #         )

    # Event handlers
    process_outputs = process_btn.click(
        fn=process_video_file,
        inputs=[
            video_input,
            detect_input,
            box_style_input,
            preset_input,
            rows_input,
            cols_input,
            test_mode_input,
            test_duration_input,
        ],
        outputs=[video_output, json_output],
    )

    # Auto-analyze after processing
    process_outputs.then(
        fn=create_visualization_plots,
        inputs=[json_output],
        outputs=[plot1, plot2, plot3, plot4, plot5, plot6, plot7, plot8, stats_output],
    )

    # Manual analysis button
    analyze_btn.click(
        fn=create_visualization_plots,
        inputs=[json_input],
        outputs=[plot1, plot2, plot3, plot4, plot5, plot6, plot7, plot8, stats_output],
    )

    # Video visualization button
    # visualize_btn.click(
    #     fn=lambda json_file, style: create_video_visualization(json_file.name if json_file else None, style),
    #     inputs=[json_input_realtime, viz_style],
    #     outputs=[video_visualization, stats_realtime],
    # )
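
# Note: on Hugging Face Spaces the app is already hosted publicly, so share=True
# below mainly matters when running locally.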

if __name__ == "__main__":
    app.launch(share=True)