Spaces:
Sleeping
Sleeping
import torch | |
import gradio as gr | |
import numpy as np | |
from torchvision.ops import nms | |
from PIL import Image | |
import cv2 | |
# Load the exported TorchScript detector once at import time and switch it to
# inference mode so every request reuses the same module.
model = torch.jit.load("best.torchscript").eval()
# Define the detection function
def detect_salmon(image):
    """Detect salmon in *image* and return a copy with labelled boxes drawn on it.

    The image is converted to RGB, stretched to 640x640 (no letterboxing), run
    through the module-level TorchScript ``model``, filtered by confidence and
    de-duplicated with NMS. Surviving boxes are mapped back from the 640x640
    model space to the original image size before drawing.

    Args:
        image: Image as a NumPy array (as delivered by the Gradio ``Image``
            input), or ``None`` when nothing was uploaded.

    Returns:
        ``None`` if *image* is ``None``. Otherwise an RGB NumPy array with the
        detections drawn on it, or, if no detection passes the threshold or
        any step fails, a copy of the input with the error message drawn in
        its top-left corner.
    """
    # Gradio passes None when the user submits without an image; there is
    # nothing to annotate and nothing to draw an error message on.
    if image is None:
        return None

    input_size = 640
    conf_threshold = 0.5
    iou_threshold = 0.5
    # Define class names (update based on your classes)
    class_names = ["background", "farmed", "wild"]

    try:
        # Force 3-channel RGB so grayscale / RGBA inputs neither break the
        # channel transpose below nor feed the model an unexpected channel count.
        rgb_image = np.array(Image.fromarray(image).convert("RGB"))
        orig_height, orig_width = rgb_image.shape[:2]

        # Preprocess: resize, HWC -> CHW, scale to [0, 1], add batch dimension.
        resized = Image.fromarray(rgb_image).resize((input_size, input_size))
        input_tensor = (
            torch.from_numpy(np.array(resized).transpose(2, 0, 1) / 255.0)
            .unsqueeze(0)
            .float()
        )

        # Run inference; gradients are never needed here.
        with torch.no_grad():
            output = model(input_tensor)
        # Remove batch dimension.
        # NOTE(review): assumes each row is
        # [x_center, y_center, width, height, confidence, class scores...]
        # expressed in pixels of the 640x640 input -- confirm against the
        # exported model.
        detection_data = output[0][0].detach().cpu().numpy()

        # Filter detections by confidence threshold.
        detections = detection_data[detection_data[:, 4] >= conf_threshold]

        # Rows need 4 box values, 1 confidence and at least 2 class scores.
        if detections.shape[0] == 0 or detections.shape[1] < 7:
            raise ValueError("No detections with sufficient confidence.")

        # Centre/size -> corner boxes: a corner is the centre +/- half the extent.
        centers = detections[:, :2]
        half_extents = detections[:, 2:4] / 2.0
        boxes = np.concatenate([centers - half_extents, centers + half_extents], axis=1)
        scores = detections[:, 4]
        class_ids = detections[:, 5:].argmax(axis=1)

        # Apply NMS (expects [x_min, y_min, x_max, y_max] boxes).
        kept = nms(
            torch.tensor(boxes, dtype=torch.float32),
            torch.tensor(scores, dtype=torch.float32),
            iou_threshold,
        ).tolist()

        # Boxes live in the resized 640x640 space; map them back onto the
        # original image, which generally has a different size and aspect ratio.
        scale = np.array([orig_width, orig_height, orig_width, orig_height]) / input_size
        upper = [orig_width - 1, orig_height - 1, orig_width - 1, orig_height - 1]

        # Draw bounding boxes
        image_with_boxes = rgb_image.copy()
        for idx in kept:
            x_min, y_min, x_max, y_max = (
                int(v) for v in np.clip(np.rint(boxes[idx] * scale), 0, upper)
            )
            class_id = int(class_ids[idx])
            # The model may expose more class columns than we have names for.
            if class_id < len(class_names):
                label = class_names[class_id]
            else:
                label = f"class {class_id}"
            cv2.rectangle(image_with_boxes, (x_min, y_min), (x_max, y_max), (255, 0, 0), 2)
            # Keep the label inside the image when the box touches the top edge.
            text_origin = (x_min, max(y_min - 10, 15))
            cv2.putText(image_with_boxes, label, text_origin, cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
        return image_with_boxes
    except Exception as e:
        # UI boundary: surface the failure to the user as a text overlay
        # on the image instead of crashing the request.
        image_with_error = image.copy()
        cv2.putText(image_with_error, f"Error: {str(e)}", (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
        return image_with_error
# Build the Gradio UI: one uploaded image in, one annotated image out.
# The first three positional arguments are fn, inputs and outputs.
interface = gr.Interface(
    detect_salmon,
    gr.Image(type="numpy", label="Upload Image"),
    gr.Image(type="numpy", label="Output Image"),
    title="Salmon Detection",
    description="Upload an image to detect whether the salmon is farmed or wild.",
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()