File size: 7,033 Bytes
d25d814
 
 
2c4e7b0
d25d814
 
 
 
 
 
 
2c4e7b0
 
d25d814
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c4e7b0
d25d814
 
 
 
 
2c4e7b0
d25d814
 
2c4e7b0
d25d814
 
 
 
 
 
 
2c4e7b0
d25d814
 
 
 
 
 
2c4e7b0
d25d814
 
 
 
 
 
2c4e7b0
d25d814
 
 
 
 
 
 
 
2c4e7b0
d25d814
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c4e7b0
d25d814
 
 
 
 
2c4e7b0
d25d814
 
 
 
 
 
 
 
2c4e7b0
d25d814
 
 
 
2c4e7b0
d25d814
 
 
 
2c4e7b0
d25d814
 
2c4e7b0
d25d814
 
 
 
 
 
 
 
 
 
 
2c4e7b0
d25d814
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# app.py
import io
import uvicorn
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from typing import List
from fastapi import FastAPI, UploadFile, File, Request, Form
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import onnxruntime as ort
import cv2
from huggingface_hub import hf_hub_download
import os
import uuid

# --- FastAPI and Template Setup ---
app = FastAPI(title="YOLOv8 ONNX Object Detection Demo")

# StaticFiles checks that its directory exists when it is constructed, and
# this module is also imported by ASGI servers (e.g. `uvicorn app:app`) that
# never execute the `__main__` block. Create the folder tree up front so the
# mount below and later image saves work on a fresh checkout.
os.makedirs(os.path.join("static", "output"), exist_ok=True)

# Mount a static directory to serve saved images
app.mount("/static", StaticFiles(directory="static"), name="static")

# HTML templates (index.html) are looked up in ./templates
templates = Jinja2Templates(directory="templates")

# --- Model Loading and Configuration ---
# Side length used for any spatial dimension the model does not pin down.
DEFAULT_INPUT_SIZE = 640


def _static_dim(dim, default=DEFAULT_INPUT_SIZE):
    """Return *dim* when it is a concrete positive int, otherwise *default*.

    ONNX Runtime reports dynamic axes as ``None`` or as a symbolic string
    name; neither can be used as a resize target.
    """
    return dim if isinstance(dim, int) and dim > 0 else default


# Download the ONNX model file and get its path
try:
    onnx_model_path = hf_hub_download(repo_id="tententgc/Iskyn", filename="best.onnx")
    session = ort.InferenceSession(onnx_model_path)
    print("ONNX model loaded successfully.")
except Exception as e:
    # Keep the app importable so the UI can report the problem instead of
    # the whole server failing to start.
    print(f"Failed to load ONNX model: {e}")
    session = None

if session is not None:
    model_input = session.get_inputs()[0]
    input_name = model_input.name
    output_names = [output.name for output in session.get_outputs()]
    # preprocess_image feeds the model (batch, channels, height, width)
    # tensors, so the trailing two dims are (height, width). Every consumer
    # of input_shape (PIL's resize and postprocess_output) reads it as
    # (width, height), hence the swap.
    spatial_dims = list(model_input.shape[2:4])
    model_height, model_width = spatial_dims if len(spatial_dims) == 2 else (None, None)
    input_shape = (_static_dim(model_width), _static_dim(model_height))
else:
    input_name = None
    output_names = []
    input_shape = (DEFAULT_INPUT_SIZE, DEFAULT_INPUT_SIZE)  # Default size if model fails to load

# Class labels, indexed by the class id produced in post-processing.
# IMPORTANT: Update this with the actual class names your model was trained on
CLASSES = ["melasma", "acne", "wrinkle"]

# Plot color per class name; classes without an entry are drawn in white
# by draw_boxes_on_image.
COLORS = dict(
    melasma="red",
    acne="green",
    wrinkle="blue",
    # Add more classes and colors as needed
)

# --- Helper Functions ---
def preprocess_image(image: Image.Image, size: tuple) -> np.ndarray:
    """Convert a PIL image into the float32 tensor fed to the model.

    Args:
        image: Source picture. Images that are not already in RGB mode
            (grayscale, RGBA, palette, ...) are converted to RGB first so
            the result always has exactly three channels.
        size: Target ``(width, height)`` handed to ``Image.resize``; any
            two-item sequence is accepted.

    Returns:
        A C-contiguous array of shape ``(1, 3, height, width)`` with dtype
        ``float32`` and values scaled to ``[0, 1]``.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    resized = image.resize(tuple(size))
    pixels = np.asarray(resized, dtype=np.float32) / 255.0  # Normalize
    chw = pixels.transpose(2, 0, 1)  # HWC to CHW
    # Add the batch dimension and hand back a contiguous buffer, since the
    # transpose above only produces a strided view.
    return np.ascontiguousarray(chw[np.newaxis, ...])

def _scale_and_clamp(value, scale, upper):
    """Scale a model-space coordinate and clamp it to ``[0, upper]`` as an int."""
    return int(min(max(round(float(value) * scale), 0), upper))


def postprocess_output(output, original_size, input_shape, conf_threshold=0.25, iou_threshold=0.45):
    """Post-process the raw model output into a list of detections.

    Args:
        output: Raw result of ``session.run``. After ``np.squeeze`` it is
            interpreted as a 2-D array of shape
            ``(4 + num_classes, num_candidates)`` whose first four rows are
            ``cx, cy, w, h`` in model-input pixels and whose remaining rows
            are per-class scores.
        original_size: ``(width, height)`` of the image the user uploaded.
        input_shape: ``(width, height)`` the image was resized to for the model.
        conf_threshold: Candidates whose best class score is not above this
            value are dropped.
        iou_threshold: Overlap threshold used by non-maximum suppression.

    Returns:
        A list of dicts with keys ``class_name`` (str), ``confidence``
        (float) and ``box`` (``[x1, y1, x2, y2]`` ints in original-image
        pixels, clamped to the image bounds). Empty when nothing passes the
        confidence threshold.
    """
    predictions = np.squeeze(output).T
    scores = np.max(predictions[:, 4:], axis=1)
    confident = scores > conf_threshold
    predictions = predictions[confident]
    scores = scores[confident]

    if not len(predictions):
        return []

    class_ids = np.argmax(predictions[:, 4:], axis=1)

    # Centre/size -> top-left corner plus size.
    widths = predictions[:, 2]
    heights = predictions[:, 3]
    lefts = predictions[:, 0] - widths / 2
    tops = predictions[:, 1] - heights / 2

    # cv2.dnn.NMSBoxes takes rectangles as (x, y, width, height); passing
    # corner coordinates would make it compute IoU on the wrong boxes.
    # Suppression is done across all classes at once, as before.
    nms_boxes = np.stack([lefts, tops, widths, heights], axis=1)
    kept = cv2.dnn.NMSBoxes(
        nms_boxes.tolist(),
        scores.astype(float).tolist(),
        conf_threshold,
        iou_threshold,
    )
    # Depending on the OpenCV version the result is an (N, 1) array, a flat
    # array, or an empty tuple; normalise all of them to a flat int array.
    kept = np.asarray(kept, dtype=int).reshape(-1)

    original_width, original_height = original_size
    resized_width, resized_height = input_shape
    scale_x = original_width / resized_width
    scale_y = original_height / resized_height

    detections = []
    for i in kept:
        # Scale in floating point first and round once, so small boxes are
        # not distorted by truncating before the multiplication.
        x1 = _scale_and_clamp(lefts[i], scale_x, original_width)
        y1 = _scale_and_clamp(tops[i], scale_y, original_height)
        x2 = _scale_and_clamp(lefts[i] + widths[i], scale_x, original_width)
        y2 = _scale_and_clamp(tops[i] + heights[i], scale_y, original_height)

        class_id = int(class_ids[i])
        # Fall back to a generic label instead of raising when the model
        # has more classes than CLASSES lists.
        class_name = CLASSES[class_id] if class_id < len(CLASSES) else f"class_{class_id}"

        detections.append({
            "class_name": class_name,
            "confidence": float(scores[i]),
            "box": [x1, y1, x2, y2]
        })
    return detections

def draw_boxes_on_image(image, detections):
    """Annotate *image* in place with one labelled box per detection.

    Each detection is a dict with ``box`` (``[x1, y1, x2, y2]``),
    ``class_name`` and ``confidence``. The label is drawn on a filled
    background above the box, or just inside its top edge when there is no
    room above. Returns the same image object that was passed in.
    """
    canvas = ImageDraw.Draw(image)
    try:
        label_font = ImageFont.truetype("arial.ttf", 30)
    except IOError:
        label_font = ImageFont.load_default()
        print("Arial font not found, using default font.")

    for det in detections:
        rect = det['box']
        name = det['class_name']
        score = det['confidence']

        outline_color = COLORS.get(name, "white")
        canvas.rectangle(rect, outline=outline_color, width=3)

        caption = f"{name}: {score:.2f}"

        # textbbox() gives (left, top, right, bottom); measured from the
        # origin, the right/bottom edges are the extent the caption occupies.
        _, _, caption_right, caption_bottom = canvas.textbbox((0, 0), caption, font=label_font)

        left_x, top_y = rect[0], rect[1]
        above_y = top_y - caption_bottom - 5
        # Prefer the slot above the box; drop below its top edge otherwise
        caption_y = above_y if above_y >= 0 else top_y + 5

        background = [left_x, caption_y, left_x + caption_right, caption_y + caption_bottom]
        canvas.rectangle(background, fill=outline_color)
        canvas.text((left_x, caption_y), caption, fill="black", font=label_font)
    return image

# --- FastAPI Endpoints ---
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Render the upload page in its initial state (no result image, no error)."""
    page_context = {
        "request": request,
        "image_url": None,
        "error_message": None,
    }
    return templates.TemplateResponse("index.html", page_context)

@app.post("/predict_web", response_class=HTMLResponse)
async def predict_web(request: Request, file: UploadFile = File(...)):
    """Handle an image upload, run detection, and render the annotated result.

    The page is always rendered from ``index.html`` with ``image_url`` set
    on success or ``error_message`` set on failure; the other key is passed
    as ``None``, matching the context used by the root page.
    """
    def render(image_url=None, error_message=None):
        return templates.TemplateResponse(
            "index.html",
            {"request": request, "image_url": image_url, "error_message": error_message},
        )

    if session is None:
        return render(error_message="ONNX model not loaded.")

    # The content type is optional on an upload; treat a missing one as
    # "not an image" instead of failing on None.startswith.
    if not (file.content_type or "").startswith("image/"):
        return render(error_message="Invalid file type. Please upload an image.")

    try:
        image_data = await file.read()
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        original_size = image.size

        # Preprocess, run inference, and post-process
        preprocessed_image = preprocess_image(image, size=input_shape)
        outputs = session.run(output_names, {input_name: preprocessed_image})
        detections = postprocess_output(outputs, original_size, input_shape)

        # Draw boxes on a copy so the decoded upload itself stays untouched
        plotted_image = draw_boxes_on_image(image.copy(), detections)

        # The output folder may not exist when the app is served by an
        # external ASGI runner rather than through the __main__ block.
        output_dir = os.path.join("static", "output")
        os.makedirs(output_dir, exist_ok=True)

        # Create a unique filename and save the plotted image
        unique_filename = f"{uuid.uuid4()}.jpg"
        plotted_image.save(os.path.join(output_dir, unique_filename))

        return render(image_url=f"/static/output/{unique_filename}")

    except Exception as e:
        # Request boundary: report decoding/inference/save failures on the
        # page instead of returning a bare 500.
        return render(error_message=f"An error occurred: {e}")

if __name__ == "__main__":
    # Ensure the folder that receives annotated images is present, then
    # start the development server on localhost.
    output_dir = os.path.join("static", "output")
    os.makedirs(output_dir, exist_ok=True)
    uvicorn.run(app, host="127.0.0.1", port=8000)