# 8186276W / app.py
# WgJL — Update app.py (commit 85c8f23, verified)
from ultralytics import YOLO
from PIL import Image
import gradio as gr
from huggingface_hub import snapshot_download
import os
import cv2
import numpy as np
def load_model(repo_id):
    """Download the model snapshot from the Hub and return a YOLO detector.

    Args:
        repo_id: Hugging Face Hub repository id holding the exported model.

    Returns:
        A YOLO detection model loaded from the OpenVINO export directory.
    """
    snapshot_dir = snapshot_download(repo_id)
    # The exported OpenVINO model lives in a fixed subdirectory of the snapshot.
    model_path = os.path.join(
        snapshot_dir, "model_- 5 january 2025 0_48_openvino_model"
    )
    return YOLO(model_path, task='detect')
def predict_image(pilimg, conf_threshold, iou_threshold):
    """Run detection on a PIL image and return the annotated RGB image.

    Args:
        pilimg: Input image as a PIL Image.
        conf_threshold: Minimum detection confidence.
        iou_threshold: NMS IoU threshold.

    Returns:
        Annotated PIL image in RGB order.
    """
    results = detection_model.predict(
        pilimg, conf=conf_threshold, iou=iou_threshold
    )
    annotated_bgr = results[0].plot()  # ultralytics plots in BGR
    # Reverse the channel axis (BGR -> RGB) before wrapping in PIL.
    annotated_rgb = annotated_bgr[..., ::-1]
    return Image.fromarray(annotated_rgb)
def predict_video(video_file, conf_threshold, iou_threshold, start_time, end_time):
    """Process a video with user-defined thresholds and time range.

    Args:
        video_file: Path to the input video file.
        conf_threshold: Minimum detection confidence.
        iou_threshold: NMS IoU threshold.
        start_time: Segment start in seconds (0 / falsy = start of video).
        end_time: Segment end in seconds (0 / falsy = end of video).

    Returns:
        Path to the annotated output video ("output_video.mp4").
    """
    cap = cv2.VideoCapture(video_file)
    out = None
    output_path = "output_video.mp4"
    try:
        # Source video properties.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Guard against broken metadata reporting 0 fps, which would make the
        # writer invalid and the start/end frame math collapse to 0.
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        # Convert the requested time window to a frame range; 0 means "no limit".
        start_frame = int(start_time * fps) if start_time else 0
        end_frame = int(end_time * fps) if end_time else total_frames
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)  # seek to segment start
        out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
        while cap.isOpened():
            current_frame = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
            if current_frame >= end_frame:
                break
            ret, frame = cap.read()
            if not ret:
                break
            # Resize to 640x640 (the model's input size) for faster inference.
            resized_frame = cv2.resize(frame, (640, 640))
            result = detection_model.predict(
                resized_frame, conf=conf_threshold, iou=iou_threshold
            )
            # Scale the annotated frame back to the original resolution so the
            # output video keeps the source dimensions.
            output_frame = cv2.resize(result[0].plot(), (frame_width, frame_height))
            out.write(output_frame)
    finally:
        # BUG FIX: the original never released the VideoWriter, which can leave
        # the MP4 unfinalized/corrupted; also ensure the capture is released
        # even if inference raises mid-loop.
        cap.release()
        if out is not None:
            out.release()
    return output_path
# Load the model
# Hub repository hosting the exported YOLO (OpenVINO) weights.
REPO_ID = "WgJL/107_Assignment.v5i.yolov11"
# Module-level model instance shared by predict_image() and predict_video().
# NOTE: this triggers a network download at import time.
detection_model = load_model(REPO_ID)
# Gradio UI: two tabs (image / video), each with its own threshold sliders.
with gr.Blocks() as demo:
    gr.Markdown("## Lemons and Crabs Detection")
    # --- Image tab: single-image inference with adjustable thresholds. ---
    with gr.Tab("Image Input"):
        img_input = gr.Image(type="pil", label="Upload an Image")
        conf_slider_img = gr.Slider(0.1, 1.0, value=0.5, step=0.05, label="Confidence Threshold")
        iou_slider_img = gr.Slider(0.1, 1.0, value=0.6, step=0.05, label="IoU Threshold")
        img_output = gr.Image(type="pil", label="Processed Image")
        img_submit = gr.Button("Process Image")
        img_submit.click(
            predict_image,
            inputs=[img_input, conf_slider_img, iou_slider_img],
            outputs=img_output
        )
    # --- Video tab: frame-by-frame inference over an optional time window. ---
    with gr.Tab("Video Input"):
        video_input = gr.Video(label="Upload a Video")
        conf_slider_video = gr.Slider(0.1, 1.0, value=0.5, step=0.05, label="Confidence Threshold")
        iou_slider_video = gr.Slider(0.1, 1.0, value=0.6, step=0.05, label="IoU Threshold")
        # Times are in seconds; end_time == 0 means "process to end of video".
        start_time = gr.Number(value=0, label="Start Time (seconds)")
        end_time = gr.Number(value=0, label="End Time (seconds, 0 for full video)")
        video_output = gr.Video(label="Processed Video")
        video_submit = gr.Button("Process Video")
        video_submit.click(
            predict_video,
            inputs=[video_input, conf_slider_video, iou_slider_video, start_time, end_time],
            outputs=video_output
        )
# share=True exposes a public Gradio link in addition to the local server.
demo.launch(share=True)