import gradio as gr
from ultralytics import YOLO
import cv2
import numpy as np
import tempfile
import os

# Load model
print("Loading model...")
model = YOLO('best.pt')
print("Model loaded successfully")
def process_image(image):
    try:
        print("\nProcessing image...")
        # Gradio supplies RGB arrays, while Ultralytics expects BGR numpy
        # input, so convert before inference
        bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        # Resize to match training size
        processed_image = cv2.resize(bgr_image, (1024, 1024))
        # Run detection
        results = model(processed_image, conf=0.25)
        if results[0].obb is not None and len(results[0].obb) > 0:
            obb_results = results[0].obb
            # Count detections by class
            caries_count = 0
            non_caries_count = 0
            # Process detections
            for i in range(len(obb_results)):
                cls = int(obb_results.cls[i])
                if cls == 0:  # assuming class 0 is caries
                    caries_count += 1
                else:
                    non_caries_count += 1
            # Get annotated image; plot() returns BGR, so convert back to
            # RGB for display in Gradio
            annotated_image = results[0].plot(
                conf=True,
                line_width=2,
                font_size=15,
                labels=True
            )
            annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
            # Resize back to original size
            annotated_image = cv2.resize(annotated_image, (image.shape[1], image.shape[0]))
            summary = f"Found {caries_count} cavities and {non_caries_count} normal teeth"
        else:
            annotated_image = image
            summary = "No detections found"
        return annotated_image, summary
    except Exception as e:
        print(f"Error in image processing: {str(e)}")
        return image, f"Error: {str(e)}"
def process_video(video_path):
    try:
        print("Processing video...")
        cap = cv2.VideoCapture(video_path)
        # Get video properties
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Fall back to 30 FPS if the container reports none, so the
        # 7-second interval below never becomes zero
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Calculate frame interval (7 seconds * fps)
        interval = 7 * fps
        # Create temporary output file
        temp_dir = tempfile.gettempdir()
        output_path = os.path.join(temp_dir, 'output_video.mp4')
        # Initialize video writer
        output = cv2.VideoWriter(
            output_path,
            cv2.VideoWriter_fourcc(*'mp4v'),
            fps,
            (width, height)
        )
        frame_count = 0
        last_processed_frame = None
        frames_since_last_process = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # Run detection on the first frame, then once per interval
            if frame_count == 0 or frames_since_last_process >= interval:
                print(f"Processing frame at {frame_count / fps:.1f} seconds")
                # Process frame (OpenCV frames are already BGR)
                processed_frame = cv2.resize(frame, (1024, 1024))
                results = model(processed_frame, conf=0.25)
                if results[0].obb is not None:
                    # Draw detections; plot() returns BGR, which is what
                    # VideoWriter expects
                    annotated_frame = results[0].plot(
                        conf=True,
                        line_width=2,
                        font_size=15,
                        labels=True
                    )
                    annotated_frame = cv2.resize(annotated_frame, (width, height))
                    last_processed_frame = annotated_frame
                else:
                    last_processed_frame = frame
                frames_since_last_process = 0
            # Write the most recent annotated frame so detections persist
            # between processing intervals
            if last_processed_frame is not None:
                output.write(last_processed_frame)
            else:
                output.write(frame)
            frame_count += 1
            frames_since_last_process += 1
            print(f"Processed {frame_count}/{total_frames} frames", end='\r')
        cap.release()
        output.release()
        summary = f"Processed video with {total_frames} frames\nDetection ran at 7-second intervals"
        return output_path, summary
    except Exception as e:
        print(f"Error in video processing: {str(e)}")
        return None, f"Error: {str(e)}"
# Create Gradio interface with examples
with gr.Blocks(title="Dental Cavity Detection") as demo:
    gr.Markdown("""
    # Dental Cavity Detection System
    Upload or select an example to detect dental cavities
    """)
    with gr.Tabs():
        with gr.Tab("Image Detection"):
            with gr.Row():
                # Input Column
                with gr.Column():
                    image_input = gr.Image(label="Upload Image")
                    image_button = gr.Button("Detect Cavities", variant="primary")
                # Output Column
                with gr.Column():
                    image_output = gr.Image(label="Detection Result")
                    image_summary = gr.Textbox(label="Detection Summary")
            # Add examples using Gradio's examples feature
            gr.Examples(
                examples=[
                    ["image1.jpg"],
                    ["image2.jpg"]
                ],
                inputs=image_input,
                label="Example Images - Click to use"
            )
            # Process button click
            image_button.click(
                fn=process_image,
                inputs=image_input,
                outputs=[image_output, image_summary]
            )
        with gr.Tab("Video Detection"):
            with gr.Row():
                # Input Column
                with gr.Column():
                    video_input = gr.Video(label="Upload Video")
                    video_button = gr.Button("Process Video", variant="primary")
                # Output Column
                with gr.Column():
                    video_output = gr.Video(label="Processed Video")
                    gr.Markdown("""⚠️ Note: Video preview may not work in the browser.
                    Please use the download button below the video to view results.""")
                    video_summary = gr.Textbox(label="Processing Summary")
            # Add video example
            gr.Examples(
                examples=[
                    ["video1.mp4"]
                ],
                inputs=video_input,
                label="Example Video - Click to use"
            )
            # Process button click
            video_button.click(
                fn=process_video,
                inputs=video_input,
                outputs=[video_output, video_summary]
            )
    # Add footer with instructions
    gr.Markdown("""
    ### Instructions:
    1. Choose the Image or Video tab
    2. Upload your own file or click an example
    3. Click the detection button to process
    """)
if __name__ == "__main__":
    demo.launch(debug=True)