# 6554898F / app.py
# Mitthoon97's picture
# Update app.py
# 8108c95 verified
import gradio as gr
from ultralytics import YOLO
import cv2
import numpy as np
import tempfile
import os
# Load model
# The weights file 'best.pt' is loaded once, at import time, into the
# module-level `model`; process_image and process_video both call it, so no
# request pays the loading cost again. The prints bracket the load so startup
# logs show whether it completed.
print("Loading model...")
model = YOLO('best.pt')
print("Model loaded successfully")
def process_image(image):
    """Run cavity detection on one image and return an annotated copy.

    Args:
        image: Image array supplied by the Gradio image input, or ``None``
            when the button is clicked without an upload. Three-channel
            arrays are treated as RGB, which is what Gradio's numpy image
            type delivers.

    Returns:
        A ``(annotated_image, summary)`` tuple. ``annotated_image`` has the
        same height/width as the input (and the same RGB channel order for
        three-channel input); when nothing is detected or an error occurs the
        original ``image`` is returned unchanged. ``summary`` is a
        human-readable result or an ``"Error: ..."`` message.
    """
    # Guard the "no upload" case explicitly instead of letting OpenCV fail
    # with an opaque assertion message.
    if image is None:
        return None, "Error: no image provided"

    try:
        print("\nProcessing image...")

        # Ultralytics interprets numpy input as BGR and plot() returns BGR,
        # so convert on the way in and back on the way out. Anything that is
        # not a plain 3-channel image is passed through as before.
        is_rgb = image.ndim == 3 and image.shape[2] == 3
        model_input = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) if is_rgb else image

        # Resize to match training size
        processed_image = cv2.resize(model_input, (1024, 1024))

        # Run detection
        results = model(processed_image, conf=0.25)
        obb_results = results[0].obb

        if obb_results is None or len(obb_results) == 0:
            return image, "No detections found"

        # Count detections by class (assuming class 0 is caries; every other
        # class id is reported as a normal tooth).
        caries_count = sum(1 for cls in obb_results.cls if int(cls) == 0)
        non_caries_count = len(obb_results) - caries_count

        # Get annotated image
        annotated_image = results[0].plot(
            conf=True,
            line_width=2,
            font_size=15,
            labels=True
        )

        # Resize back to original size (cv2 takes (width, height)).
        annotated_image = cv2.resize(
            annotated_image, (image.shape[1], image.shape[0])
        )
        if is_rgb:
            annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)

        summary = f"Found {caries_count} cavities and {non_caries_count} normal teeth"
        return annotated_image, summary
    except Exception as e:
        # UI boundary: report the failure in the summary box rather than
        # crashing the request.
        print(f"Error in image processing: {str(e)}")
        return image, f"Error: {str(e)}"
def process_video(video_path, interval_seconds=7):
    """Annotate a video by running detection on periodically sampled frames.

    Detection runs on the first frame and then once every
    ``interval_seconds`` of video. Every output frame is the most recently
    processed frame, so the result holds each annotated frame until the next
    sample.

    Args:
        video_path: Path of the input video file, or ``None``/empty when
            nothing was uploaded.
        interval_seconds: Seconds of video between two detection runs.
            Defaults to 7.

    Returns:
        A ``(output_path, summary)`` tuple. ``output_path`` is a newly
        created ``.mp4`` file in the system temp directory, or ``None`` on
        failure, in which case ``summary`` is an ``"Error: ..."`` message.
    """
    if not video_path:
        return None, "Error: no video provided"

    cap = None
    output = None
    output_path = None
    succeeded = False
    try:
        print("Processing video...")
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"could not open video: {video_path}")

        # Get video properties
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep fps fractional (e.g. 29.97) so the output plays at the source
        # speed. `not fps > 0` also catches NaN; 30.0 is an arbitrary
        # fallback for containers that report no frame rate.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 30.0
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Number of frames between detection runs; at least 1 so the modulo
        # below is always defined.
        interval = max(1, int(round(interval_seconds * fps)))

        # Unique temporary output file so concurrent requests do not
        # overwrite each other's result.
        fd, output_path = tempfile.mkstemp(prefix='output_video_', suffix='.mp4')
        os.close(fd)

        # Initialize video writer
        output = cv2.VideoWriter(
            output_path,
            cv2.VideoWriter_fourcc(*'mp4v'),
            fps,
            (width, height)
        )
        if not output.isOpened():
            raise RuntimeError("could not open video writer for output file")

        frame_count = 0
        last_processed_frame = None

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Frame 0 always matches, so last_processed_frame is set before
            # the first write.
            if frame_count % interval == 0:
                print(f"Processing frame at {frame_count / fps:.1f} seconds")

                # Resize to the model's input size, detect, then scale the
                # annotated frame back to the source resolution.
                processed_frame = cv2.resize(frame, (1024, 1024))
                results = model(processed_frame, conf=0.25)

                if results[0].obb is not None:
                    # Draw detections
                    annotated_frame = results[0].plot(
                        conf=True,
                        line_width=2,
                        font_size=15,
                        labels=True
                    )
                    last_processed_frame = cv2.resize(
                        annotated_frame, (width, height)
                    )
                else:
                    last_processed_frame = frame

            # Write the last processed frame
            output.write(last_processed_frame)

            frame_count += 1
            print(f"Processed {frame_count}/{total_frames} frames", end='\r')

        # Report the frames actually written; the container's frame count is
        # metadata and is only used for the progress line above.
        summary = (
            f"Processed video with {frame_count} frames\n"
            f"Detected frames at {interval_seconds}-second intervals"
        )
        succeeded = True
        return output_path, summary
    except Exception as e:
        # UI boundary: surface the failure in the summary box.
        print(f"Error in video processing: {str(e)}")
        return None, f"Error: {str(e)}"
    finally:
        # Always release the handles; the writer must be released before the
        # caller reads the output file.
        if cap is not None:
            cap.release()
        if output is not None:
            output.release()
        if not succeeded and output_path is not None:
            # Best-effort removal of a partial output file.
            try:
                os.remove(output_path)
            except OSError:
                pass
# Create Gradio interface with examples
# Two tabs (image / video), each with an input column, an output column, a
# set of clickable examples, and a button wired to the matching process_*
# function.
# NOTE(review): the nesting below is reconstructed from a source whose
# indentation was lost -- confirm the placement of the examples, the click
# handlers and the footer against the intended layout.
with gr.Blocks(title="Dental Cavity Detection") as demo:
    # Page heading.
    gr.Markdown("""
# Dental Cavity Detection System
Upload or select an example to detect dental cavities
""")
    with gr.Tabs():
        with gr.Tab("Image Detection"):
            with gr.Row():
                # Input Column
                with gr.Column():
                    image_input = gr.Image(label="Upload Image")
                    image_button = gr.Button("Detect Cavities", variant="primary")
                # Output Column
                with gr.Column():
                    image_output = gr.Image(label="Detection Result")
                    image_summary = gr.Textbox(label="Detection Summary")
            # Add examples using Gradio's examples feature
            # The example files are referenced by relative path.
            gr.Examples(
                examples=[
                    ["image1.jpg"],
                    ["image2.jpg"]
                ],
                inputs=image_input,
                label="Example Images - Click to use"
            )
            # Process button click
            # process_image returns (annotated image, summary text), matching
            # the two outputs listed here in order.
            image_button.click(
                fn=process_image,
                inputs=image_input,
                outputs=[image_output, image_summary]
            )
        with gr.Tab("Video Detection"):
            with gr.Row():
                # Input Column
                with gr.Column():
                    video_input = gr.Video(label="Upload Video")
                    video_button = gr.Button("Process Video", variant="primary")
                # Output Column
                with gr.Column():
                    video_output = gr.Video(label="Processed Video")
                    # User-facing caveat about in-browser playback of the
                    # processed video.
                    gr.Markdown("""⚠️ Note: Video preview may not work in browser.
Please use the download button below the video to view results.""")
                    video_summary = gr.Textbox(label="Processing Summary")
            # Add video example
            gr.Examples(
                examples=[
                    ["video1.mp4"]
                ],
                inputs=video_input,
                label="Example Video - Click to use"
            )
            # Process button click
            # process_video returns (output path, summary text), matching the
            # two outputs listed here in order.
            video_button.click(
                fn=process_video,
                inputs=video_input,
                outputs=[video_output, video_summary]
            )
    # Add footer with instructions
    gr.Markdown("""
### Instructions:
1. Choose Image or Video tab
2. Upload your own file or click an example below
3. Click the detection button to process
""")
# Start the Gradio app only when this file is executed as a script, so that
# importing the module does not launch a server.
if __name__ == "__main__":
    demo.launch(debug=True)