File size: 3,386 Bytes
1527d63
cf074ed
 
9e119d3
1527d63
 
 
37f8717
cf074ed
1527d63
 
cf074ed
37f8717
6e65bcd
baa12d9
cf074ed
37f8717
6e65bcd
 
 
cf074ed
3ba8ee5
cf074ed
37f8717
cf074ed
2054f57
d167772
114187c
d167772
cf074ed
 
 
 
6e65bcd
0120905
 
 
 
4b46a93
 
d167772
 
 
 
baa12d9
 
 
 
4b46a93
cf074ed
 
d167772
0120905
 
baa12d9
cf074ed
 
 
37f8717
 
cf074ed
37f8717
cf074ed
 
d92f8b6
37f8717
cf074ed
 
 
 
02d27db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import gradio as gr
import numpy as np
from torchvision.ops import nms
from PIL import Image
import cv2

# Load the TorchScript detector once at import time; detect_salmon reuses this
# module-level object for every call.
# map_location="cpu" keeps the weights on the CPU, where the input tensors
# built in detect_salmon (via torch.from_numpy) live. Without it, a model
# exported from a CUDA session would fail to load on a CPU-only host.
model = torch.jit.load("best.torchscript", map_location="cpu")
# Switch to evaluation mode (disables training-only behaviour in layers that
# have any).
model.eval()

# Class names indexed by the argmax over each detection's class-score columns
# (column 5 onward of a prediction row).
# NOTE(review): this must match the class order the model was trained with;
# the list (including "background" at index 0) is carried over unchanged from
# the original code -- confirm against the training config.
CLASS_NAMES = ["background", "farmed", "wild"]

# Side length, in pixels, of the square image fed to the model.
MODEL_INPUT_SIZE = 640


def _preprocess(image, size=MODEL_INPUT_SIZE):
    """Return *image* as a (1, 3, size, size) float32 tensor scaled to [0, 1].

    The image is stretched (not letterboxed) to ``size x size``, so the x and
    y axes are rescaled independently.
    """
    # convert("RGB") guarantees exactly three channels, so grayscale (H, W)
    # or RGBA (H, W, 4) uploads do not break the HWC -> CHW transpose.
    resized = Image.fromarray(image).convert("RGB").resize((size, size))
    chw = np.ascontiguousarray(np.asarray(resized).transpose(2, 0, 1))
    return torch.from_numpy(chw).unsqueeze(0).float() / 255.0


def _decode_detections(rows, conf_threshold, scale_x, scale_y):
    """Filter raw prediction rows and convert them to corner-format boxes.

    Args:
        rows: 2-D array whose columns are read as ``x_center, y_center,
            width, height, confidence, class_score_0, class_score_1, ...``
            in model-input pixels.
        conf_threshold: rows whose ``confidence`` is below this are dropped.
        scale_x: factor mapping model-input x pixels to original-image pixels.
        scale_y: factor mapping model-input y pixels to original-image pixels.

    Returns:
        ``(boxes, scores, labels)``: a float32 (N, 4) array of
        ``[x_min, y_min, x_max, y_max]`` boxes in original-image pixels, a
        float32 (N,) array of confidences, and a list of N class names.

    Raises:
        ValueError: if *rows* is not 2-D with at least two class columns.
    """
    if rows.ndim != 2 or rows.shape[1] < 7:
        raise ValueError(f"Unexpected model output shape: {tuple(rows.shape)}")

    kept = rows[rows[:, 4] >= conf_threshold]
    centers = kept[:, 0:2]
    # Half of (width, height): the distance from the centre to each edge.
    half_sizes = kept[:, 2:4] / 2
    scale = np.array([scale_x, scale_y, scale_x, scale_y])
    boxes = (
        np.concatenate([centers - half_sizes, centers + half_sizes], axis=1) * scale
    ).astype(np.float32)
    scores = kept[:, 4].astype(np.float32)

    class_indices = kept[:, 5:].argmax(axis=1).tolist()
    # Fall back to a generic name rather than raising IndexError if the model
    # emits more class columns than CLASS_NAMES lists.
    labels = [
        CLASS_NAMES[i] if i < len(CLASS_NAMES) else f"class {i}"
        for i in class_indices
    ]
    return boxes, scores, labels


def _draw_boxes(image, boxes, labels):
    """Return a copy of *image* with each integer box outlined and labelled."""
    canvas = image.copy()
    color = (255, 0, 0)  # red, assuming Gradio's RGB channel order
    for (x_min, y_min, x_max, y_max), label in zip(boxes, labels):
        cv2.rectangle(canvas, (x_min, y_min), (x_max, y_max), color, 2)
        # Keep the text inside the frame when a box touches the top edge.
        text_y = max(y_min - 10, 10)
        cv2.putText(canvas, label, (x_min, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    return canvas


def detect_salmon(image, *, conf_threshold=0.5, iou_threshold=0.5):
    """Detect salmon in *image* and return an annotated copy.

    Args:
        image: image array as delivered by ``gr.Image(type="numpy")``
            (presumably uint8 RGB -- confirm), or ``None`` when the input
            has been cleared.
        conf_threshold: minimum confidence a detection needs to be kept.
        iou_threshold: IoU above which overlapping boxes are suppressed by NMS.

    Returns:
        A copy of *image* with one labelled box per detection surviving NMS,
        drawn in the original image's pixel coordinates. On any failure
        (including "no detections") the copy carries an ``"Error: ..."``
        overlay instead. ``None`` is returned for a ``None`` input.
    """
    if image is None:
        # Nothing to annotate; returning here also avoids calling .copy() on
        # None in the error handler below.
        return None

    try:
        input_tensor = _preprocess(image)
        with torch.no_grad():
            output = model(input_tensor)
        # NOTE(review): assumes output[0] is a (batch, N, 5 + num_classes)
        # prediction tensor; the trailing [0] selects the single batch item.
        rows = output[0][0].detach().cpu().numpy()

        # Predictions are in MODEL_INPUT_SIZE-pixel coordinates; map them back
        # onto the original image, which can be any size.
        height, width = image.shape[:2]
        boxes, scores, labels = _decode_detections(
            rows,
            conf_threshold,
            width / MODEL_INPUT_SIZE,
            height / MODEL_INPUT_SIZE,
        )
        if not labels:
            raise ValueError("No detections with sufficient confidence.")

        keep = nms(torch.from_numpy(boxes), torch.from_numpy(scores), iou_threshold).tolist()
        # cv2 drawing calls need plain Python ints for point coordinates.
        kept_boxes = np.rint(boxes[keep]).astype(int).tolist()
        kept_labels = [labels[i] for i in keep]
        return _draw_boxes(image, kept_boxes, kept_labels)

    except Exception as e:
        # Deliberate catch-all at the UI boundary: report the failure on the
        # image instead of raising inside the Gradio handler.
        image_with_error = image.copy()
        cv2.putText(image_with_error, f"Error: {str(e)}", (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
        return image_with_error

# Gradio front end: one numpy image goes in, the annotated copy comes out.
interface = gr.Interface(
    title="Salmon Detection",
    description="Upload an image to detect whether the salmon is farmed or wild.",
    fn=detect_salmon,
    inputs=gr.Image(type="numpy", label="Upload Image"),
    outputs=gr.Image(type="numpy", label="Output Image"),
)

# Start serving only when run as a script; importing the module just builds
# `interface` without launching it.
if __name__ == "__main__":
    interface.launch()