fastapi-demo / app.py
{{Your Name}}
add demo
d25d814
# app.py
import io
import uvicorn
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from typing import List
from fastapi import FastAPI, UploadFile, File, Request, Form
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
import onnxruntime as ort
import cv2
from huggingface_hub import hf_hub_download
import os
import uuid
# --- FastAPI and Template Setup ---
app = FastAPI(title="YOLOv8 ONNX Object Detection Demo")
# Mount a static directory to serve saved images
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
# --- Model Loading and Configuration ---
# Download the ONNX model file and get its path
try:
onnx_model_path = hf_hub_download(repo_id="tententgc/Iskyn", filename="best.onnx")
session = ort.InferenceSession(onnx_model_path)
print("ONNX model loaded successfully.")
except Exception as e:
print(f"Failed to load ONNX model: {e}")
session = None
if session:
input_name = session.get_inputs()[0].name
output_names = [output.name for output in session.get_outputs()]
input_shape = session.get_inputs()[0].shape[2:] # Get the expected image size
else:
input_name = None
output_names = []
input_shape = (640, 640) # Default size if model fails to load
# Define the class names for your model
# IMPORTANT: Update this with the actual class names your model was trained on
CLASSES = [
"melasma", "acne", "wrinkle"
]
# A dictionary to map class names to colors for plotting
COLORS = {
"melasma": "red",
"acne": "green",
"wrinkle": "blue",
# Add more classes and colors as needed
}
# --- Helper Functions ---
def preprocess_image(image: Image.Image, size: tuple) -> np.ndarray:
"""Preprocesses an image for model inference."""
image = image.resize(size)
image = np.array(image)
image = image.transpose(2, 0, 1) # HWC to CHW
image = np.expand_dims(image, axis=0) # Add batch dimension
image = image.astype(np.float32) / 255.0 # Normalize
return image
def postprocess_output(output, original_size, input_shape, conf_threshold=0.25, iou_threshold=0.45):
"""Post-processes the model output to get bounding boxes, scores, and class IDs."""
output = np.squeeze(output).T
scores = np.max(output[:, 4:], axis=1)
filtered_indices = scores > conf_threshold
output = output[filtered_indices]
scores = scores[filtered_indices]
if not len(output):
return []
boxes = output[:, :4]
boxes[:, 0] -= boxes[:, 2] / 2
boxes[:, 1] -= boxes[:, 3] / 2
boxes[:, 2] += boxes[:, 0]
boxes[:, 3] += boxes[:, 1]
class_ids = np.argmax(output[:, 4:], axis=1)
indices = cv2.dnn.NMSBoxes(boxes.astype(np.int32), scores.astype(np.float32), conf_threshold, iou_threshold)
detections = []
if len(indices) > 0:
for i in indices.flatten():
box = boxes[i]
x1, y1, x2, y2 = box.astype(int)
class_id = class_ids[i]
score = scores[i]
original_width, original_height = original_size
resized_width, resized_height = input_shape
x1 = int(x1 * original_width / resized_width)
y1 = int(y1 * original_height / resized_height)
x2 = int(x2 * original_width / resized_width)
y2 = int(y2 * original_height / resized_height)
detections.append({
"class_name": CLASSES[class_id],
"confidence": float(score),
"box": [x1, y1, x2, y2]
})
return detections
def draw_boxes_on_image(image, detections):
"""Draws bounding boxes, class names, and confidence scores on an image."""
draw = ImageDraw.Draw(image)
try:
font = ImageFont.truetype("arial.ttf", 30)
except IOError:
font = ImageFont.load_default()
print("Arial font not found, using default font.")
for detection in detections:
box = detection['box']
class_name = detection['class_name']
confidence = detection['confidence']
color = COLORS.get(class_name, "white")
draw.rectangle(box, outline=color, width=3)
label = f"{class_name}: {confidence:.2f}"
# Use textbbox() to get text dimensions
text_x, text_y, text_width, text_height = draw.textbbox((0, 0), label, font=font)
# Position text slightly above the top-left corner
text_position_y = box[1] - text_height - 5
if text_position_y < 0:
text_position_y = box[1] + 5 # Draw below if not enough space above
draw.rectangle([box[0], text_position_y, box[0] + text_width, text_position_y + text_height], fill=color)
draw.text((box[0], text_position_y), label, fill="black", font=font)
return image
# --- FastAPI Endpoints ---
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
"""Serve the HTML interface."""
return templates.TemplateResponse("index.html", {"request": request, "image_url": None, "error_message": None})
@app.post("/predict_web", response_class=HTMLResponse)
async def predict_web(request: Request, file: UploadFile = File(...)):
"""Handle image upload, run detection, and return plotted image."""
if not session:
return templates.TemplateResponse("index.html", {"request": request, "error_message": "ONNX model not loaded."})
if not file.content_type.startswith("image/"):
return templates.TemplateResponse("index.html", {"request": request, "error_message": "Invalid file type. Please upload an image."})
try:
image_data = await file.read()
image = Image.open(io.BytesIO(image_data)).convert("RGB")
original_size = image.size
# Preprocess, run inference, and post-process
preprocessed_image = preprocess_image(image, size=input_shape)
outputs = session.run(output_names, {input_name: preprocessed_image})
detections = postprocess_output(outputs, original_size, input_shape)
# Draw boxes on the original image
plotted_image = draw_boxes_on_image(image.copy(), detections)
# Create a unique filename and save the plotted image
unique_filename = f"{uuid.uuid4()}.jpg"
output_image_path = os.path.join("static", "output", unique_filename)
plotted_image.save(output_image_path)
image_url = f"/static/output/{unique_filename}"
return templates.TemplateResponse("index.html", {"request": request, "image_url": image_url})
except Exception as e:
return templates.TemplateResponse("index.html", {"request": request, "error_message": f"An error occurred: {e}"})
if __name__ == "__main__":
# Create the static/output directory if it doesn't exist
os.makedirs(os.path.join("static", "output"), exist_ok=True)
uvicorn.run(app, host="127.0.0.1", port=8000)