Spaces:
Sleeping
Sleeping
File size: 3,386 Bytes
1527d63 cf074ed 9e119d3 1527d63 37f8717 cf074ed 1527d63 cf074ed 37f8717 6e65bcd baa12d9 cf074ed 37f8717 6e65bcd cf074ed 3ba8ee5 cf074ed 37f8717 cf074ed 2054f57 d167772 114187c d167772 cf074ed 6e65bcd 0120905 4b46a93 d167772 baa12d9 4b46a93 cf074ed d167772 0120905 baa12d9 cf074ed 37f8717 cf074ed 37f8717 cf074ed d92f8b6 37f8717 cf074ed 02d27db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import torch
import gradio as gr
import numpy as np
from torchvision.ops import nms
from PIL import Image
import cv2
# Load the TorchScript model once at import time and put it in inference mode.
# map_location="cpu" keeps loading working on CPU-only hosts even if the model
# was exported from a GPU session, and matches the CPU tensors that the
# detection function builds with torch.from_numpy.
model = torch.jit.load("best.torchscript", map_location="cpu")
model.eval()
# Define the detection function and its helpers.

# Side length of the square image the model is fed.
_INPUT_SIZE = 640
# Minimum value of column 4 (confidence) for a prediction row to be kept.
_CONF_THRESHOLD = 0.5
# IoU above which overlapping boxes are suppressed by NMS.
_IOU_THRESHOLD = 0.5
# Index -> label for the class-probability columns (update based on your classes).
# NOTE(review): YOLO-style heads have no background class; if the model was
# trained with only ["farmed", "wild"], index 0 is mislabeled here. Confirm
# against the training config before trusting these labels.
_CLASS_NAMES = ("background", "farmed", "wild")
# Colour for boxes and text (red when the array is RGB).
_BOX_COLOR = (255, 0, 0)


def _preprocess(image):
    """Prepare an uploaded image for the model.

    Args:
        image: array accepted by ``PIL.Image.fromarray`` (HxW or HxWxC).

    Returns:
        ``(input_tensor, rgb_image)``: a float32 tensor of shape
        (1, 3, _INPUT_SIZE, _INPUT_SIZE) with pixel values divided by 255,
        and the image at its original resolution as a 3-channel RGB array.
    """
    # convert("RGB") turns grayscale / RGBA uploads into 3 channels, which the
    # (2, 0, 1) transpose below and the model input both require.
    pil_image = Image.fromarray(image).convert("RGB")
    rgb_image = np.array(pil_image)
    resized = np.array(pil_image.resize((_INPUT_SIZE, _INPUT_SIZE)))
    input_tensor = torch.from_numpy(resized.transpose(2, 0, 1) / 255.0).unsqueeze(0).float()
    return input_tensor, rgb_image


def _decode_detections(detection_data):
    """Filter raw prediction rows and convert them to corner-format boxes.

    Assumes each row is ``[x_center, y_center, width, height, confidence,
    class_prob_0, class_prob_1, ...]`` in model-input pixel coordinates
    (YOLOv5-style layout; TODO confirm against the exported model).

    Returns:
        ``(boxes, scores, labels)``: float32 array (K, 4) of
        ``[x_min, y_min, x_max, y_max]``, float32 array (K,) of confidences,
        and a list of K class names. K is 0 when no row passes the threshold
        or rows are too short to hold at least two class columns.
    """
    if detection_data.ndim != 2 or detection_data.shape[1] < 7:
        return np.empty((0, 4), np.float32), np.empty(0, np.float32), []
    rows = detection_data[detection_data[:, 4] >= _CONF_THRESHOLD]
    centers = rows[:, 0:2]
    # Corners are center -/+ half the size; any divisor other than 2 distorts the box.
    half_sizes = rows[:, 2:4] / 2
    boxes = np.concatenate([centers - half_sizes, centers + half_sizes], axis=1)
    labels = [_CLASS_NAMES[i] for i in np.argmax(rows[:, 5:], axis=1)]
    return boxes.astype(np.float32), rows[:, 4].astype(np.float32), labels


def _draw_detections(rgb_image, boxes, labels):
    """Return a copy of *rgb_image* with each box (model-input coords) and its label drawn."""
    annotated = rgb_image.copy()
    height, width = annotated.shape[:2]
    # Boxes are in the _INPUT_SIZE x _INPUT_SIZE space the model saw; map each
    # axis back to the original resolution before drawing.
    scale_x = width / _INPUT_SIZE
    scale_y = height / _INPUT_SIZE
    for (x_min, y_min, x_max, y_max), label in zip(boxes, labels):
        top_left = (int(x_min * scale_x), int(y_min * scale_y))
        bottom_right = (int(x_max * scale_x), int(y_max * scale_y))
        cv2.rectangle(annotated, top_left, bottom_right, _BOX_COLOR, 2)
        # Keep the label inside the frame for boxes touching the top edge.
        text_origin = (top_left[0], max(top_left[1] - 10, 10))
        cv2.putText(annotated, label, text_origin, cv2.FONT_HERSHEY_SIMPLEX, 0.5, _BOX_COLOR, 2)
    return annotated


def _overlay_error(image, message):
    """Return a copy of *image* with ``Error: <message>`` written near the top-left corner."""
    image_with_error = np.array(image, order="C")  # np.array copies; the input is untouched
    cv2.putText(image_with_error, f"Error: {message}", (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, _BOX_COLOR, 2)
    return image_with_error


def detect_salmon(image):
    """Detect salmon in *image* and label each detection with a class name.

    Args:
        image: array from the ``gr.Image(type="numpy")`` input, or ``None``
            when the form is submitted without an upload.

    Returns:
        An RGB array at the input's resolution with a box and label drawn for
        every detection that passes the confidence threshold and NMS. If
        nothing is detected, or any step fails, the image is returned with an
        ``Error: ...`` message overlaid instead. ``None`` input returns ``None``.
    """
    if image is None:
        return None
    try:
        input_tensor, rgb_image = _preprocess(image)
        # Inference only: no_grad avoids recording an autograd graph per request.
        with torch.no_grad():
            output = model(input_tensor)
        # Accept either a (predictions, ...) tuple (what output[0][0] assumed)
        # or a bare batched predictions tensor, then drop the batch dimension.
        predictions = output[0] if isinstance(output, (tuple, list)) else output
        detection_data = predictions[0].detach().cpu().numpy()
        boxes, scores, labels = _decode_detections(detection_data)
        if not labels:
            return _overlay_error(rgb_image, "No detections with sufficient confidence.")
        # Class-agnostic NMS; IoU is invariant under the later per-axis rescale.
        keep = nms(torch.from_numpy(boxes), torch.from_numpy(scores), _IOU_THRESHOLD).tolist()
        return _draw_detections(rgb_image, boxes[keep], [labels[i] for i in keep])
    except Exception as e:
        # UI boundary: log the traceback, then show the message on the image.
        import logging
        logging.getLogger(__name__).exception("Salmon detection failed")
        return _overlay_error(image, str(e))
# Gradio UI wrapping detect_salmon: one numpy image in, one annotated numpy image out.
upload_component = gr.Image(type="numpy", label="Upload Image")
result_component = gr.Image(type="numpy", label="Output Image")
interface = gr.Interface(
    fn=detect_salmon,
    inputs=upload_component,
    outputs=result_component,
    title="Salmon Detection",
    description="Upload an image to detect whether the salmon is farmed or wild.",
)


def _main():
    """Start the Gradio server for the salmon detection demo."""
    interface.launch()


if __name__ == "__main__":
    _main()
|