Spaces:
Sleeping
Sleeping
| # app.py | |
| import io | |
| import uvicorn | |
| import numpy as np | |
| from PIL import Image, ImageDraw, ImageFont | |
| from typing import List | |
| from fastapi import FastAPI, UploadFile, File, Request, Form | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| import onnxruntime as ort | |
| import cv2 | |
| from huggingface_hub import hf_hub_download | |
| import os | |
| import uuid | |
# --- FastAPI and Template Setup ---
app = FastAPI(title="YOLOv8 ONNX Object Detection Demo")
# predict_web saves annotated images into static/output. Create that folder at
# import time: StaticFiles refuses to mount a directory that does not exist, and
# the `__main__` guard at the bottom of this file does not run when the app is
# started with `uvicorn app:app`.
os.makedirs(os.path.join("static", "output"), exist_ok=True)
# Mount a static directory to serve saved images
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
# --- Model Loading and Configuration ---
def _input_size_from_shape(shape, default=(640, 640)):
    """Return the image size an ONNX model input expects, as ``(width, height)``.

    ``shape`` is read as ``[batch, channels, height, width]`` -- the NCHW layout
    that ``preprocess_image`` produces. A spatial dimension that is not a
    positive int (e.g. a symbolic name from a dynamic-axis export, or ``None``)
    falls back to the matching entry of ``default``, which is itself given as
    ``(width, height)``. Shapes with fewer than four dimensions return
    ``default`` unchanged.
    """
    default_width, default_height = default
    if len(shape) < 4:
        return (default_width, default_height)
    height, width = shape[2], shape[3]

    def _concrete(dim, fallback):
        return dim if isinstance(dim, int) and dim > 0 else fallback

    return (_concrete(width, default_width), _concrete(height, default_height))


# Download the ONNX model file and get its path
try:
    onnx_model_path = hf_hub_download(repo_id="tententgc/Iskyn", filename="best.onnx")
    session = ort.InferenceSession(onnx_model_path)
    print("ONNX model loaded successfully.")
except Exception as e:
    # Deliberately non-fatal: the app still starts, and predict_web answers with
    # "ONNX model not loaded." while `session` is None.
    print(f"Failed to load ONNX model: {e}")
    session = None

if session:
    model_input = session.get_inputs()[0]
    input_name = model_input.name
    output_names = [output.name for output in session.get_outputs()]
    # (width, height): the order PIL's resize() and postprocess_output() consume.
    input_shape = _input_size_from_shape(model_input.shape)
else:
    input_name = None
    output_names = []
    input_shape = (640, 640)  # Default size if model fails to load
# Class names indexed by the model's class id (the argmax position over the
# per-class scores in the model output).
# IMPORTANT: Update this with the actual class names your model was trained on
CLASSES = ["melasma", "acne", "wrinkle"]

# Outline / label-background colour for each class name, as PIL colour strings.
# Add more classes and colors as needed
COLORS = {"melasma": "red", "acne": "green", "wrinkle": "blue"}
# --- Helper Functions ---
def preprocess_image(image: Image.Image, size: tuple) -> np.ndarray:
    """Convert a PIL image into the model's input tensor.

    Args:
        image: Source image in any PIL mode. It is converted to RGB first so
            greyscale, palette and RGBA uploads all yield exactly three channels.
        size: Target ``(width, height)`` in pixels, as accepted by
            ``PIL.Image.resize`` (any 2-element sequence). The image is stretched
            to this size; the aspect ratio is not preserved.

    Returns:
        A C-contiguous ``float32`` array of shape ``(1, 3, height, width)`` with
        pixel values scaled to ``[0, 1]``.
    """
    resized = image.convert("RGB").resize(tuple(size))
    pixels = np.asarray(resized, dtype=np.float32) / 255.0  # (H, W, 3), normalised
    chw = pixels.transpose(2, 0, 1)  # HWC to CHW
    # Add the batch dimension and make the memory layout explicitly C-contiguous
    # (transpose alone only produces a strided view).
    return np.ascontiguousarray(chw[np.newaxis])
def _cxcywh_to_xyxy(boxes: np.ndarray) -> np.ndarray:
    """Convert (N, 4) centre-format boxes ``(cx, cy, w, h)`` to corners ``(x1, y1, x2, y2)``.

    Returns a new array; *boxes* is left unchanged.
    """
    centres = boxes[:, :2]
    half_sizes = boxes[:, 2:4] / 2
    return np.concatenate([centres - half_sizes, centres + half_sizes], axis=1)


def postprocess_output(output, original_size, input_shape, conf_threshold=0.25, iou_threshold=0.45):
    """Turn raw model output into a list of detections in original-image pixels.

    Args:
        output: The list returned by ``session.run`` (or a single array). The
            first entry is treated as a YOLOv8-style tensor of shape
            ``(1, 4 + num_classes, num_candidates)``: rows ``cx, cy, w, h`` in
            model-input pixels followed by one score per class.
            TODO confirm against the exported model.
        original_size: ``(width, height)`` of the uploaded image (``PIL.Image.size``).
        input_shape: ``(width, height)`` the image was resized to for inference.
        conf_threshold: Minimum best-class score for a candidate to be kept.
        iou_threshold: IoU above which overlapping boxes are suppressed. NMS is
            class-agnostic: boxes of different classes can suppress each other.

    Returns:
        A list of dicts with ``class_name``, ``confidence`` and ``box``
        (``[x1, y1, x2, y2]`` ints, clipped to the original image). Empty when no
        candidate passes ``conf_threshold``.
    """
    # session.run returns one array per requested output; the detection tensor is
    # taken to be the first. A bare array is accepted as well.
    predictions = output[0] if isinstance(output, (list, tuple)) else output
    predictions = np.asarray(predictions)
    if predictions.ndim == 3:  # drop the batch dimension (single image)
        predictions = predictions[0]
    predictions = predictions.T  # -> (num_candidates, 4 + num_classes)

    class_scores = predictions[:, 4:]
    scores = class_scores.max(axis=1)
    keep = scores > conf_threshold
    if not keep.any():
        return []

    boxes_cxcywh = predictions[keep, :4]
    scores = scores[keep]
    class_ids = class_scores[keep].argmax(axis=1)
    boxes_xyxy = _cxcywh_to_xyxy(boxes_cxcywh)

    # cv2.dnn.NMSBoxes takes (x, y, width, height) rectangles, not corner pairs:
    # pass each box's top-left corner together with its size.
    nms_rects = np.concatenate([boxes_xyxy[:, :2], boxes_cxcywh[:, 2:4]], axis=1)
    indices = cv2.dnn.NMSBoxes(nms_rects.astype(np.int32), scores.astype(np.float32), conf_threshold, iou_threshold)
    # Depending on the OpenCV version the result may be a flat array, a (K, 1)
    # array, or an empty tuple; normalise to a flat integer array.
    indices = np.asarray(indices, dtype=int).reshape(-1)

    original_width, original_height = original_size
    resized_width, resized_height = input_shape
    # Per-coordinate factors mapping model-input pixels back to the original image.
    scale = np.array([original_width / resized_width, original_height / resized_height] * 2)
    upper = np.array([original_width, original_height] * 2)

    detections = []
    for i in indices:
        x1, y1, x2, y2 = np.clip(boxes_xyxy[i] * scale, 0, upper).astype(int).tolist()
        class_id = int(class_ids[i])
        # Fall back to a generic label if the model has more classes than CLASSES lists.
        class_name = CLASSES[class_id] if class_id < len(CLASSES) else f"class_{class_id}"
        detections.append({
            "class_name": class_name,
            "confidence": float(scores[i]),
            "box": [x1, y1, x2, y2],
        })
    return detections
def _load_label_font(size=30):
    """Return the font used for detection labels.

    Tries Arial at *size*; if it is not installed, falls back to Pillow's
    built-in font -- at *size* when the installed Pillow accepts a size
    argument, otherwise its fixed-size bitmap default.
    """
    try:
        return ImageFont.truetype("arial.ttf", size)
    except OSError:
        print("Arial font not found, using default font.")
    try:
        return ImageFont.load_default(size=size)
    except TypeError:  # older Pillow: load_default() takes no arguments
        return ImageFont.load_default()


def draw_boxes_on_image(image, detections):
    """Draw bounding boxes with ``"<class>: <confidence>"`` labels onto an image.

    Args:
        image: PIL image to annotate. It is drawn on in place and returned.
        detections: Dicts with ``"box"`` (``[x1, y1, x2, y2]`` in image pixels),
            ``"class_name"`` and ``"confidence"``, as built by postprocess_output.

    Returns:
        The same image object, now annotated.
    """
    draw = ImageDraw.Draw(image)
    font = _load_label_font(30)
    for detection in detections:
        box = detection["box"]
        class_name = detection["class_name"]
        color = COLORS.get(class_name, "white")  # classes missing from COLORS are drawn in white
        draw.rectangle(box, outline=color, width=3)

        label = f"{class_name}: {detection['confidence']:.2f}"
        # textbbox() returns (left, top, right, bottom), not (x, y, w, h). Anchored
        # at (0, 0), `bottom` is how far below the anchor the rendered text reaches.
        _, _, _, text_bottom = draw.textbbox((0, 0), label, font=font)
        # Place the label so the text ends 5 px above the box; if that would leave
        # the image, put it just inside the box's top edge instead.
        label_y = box[1] - text_bottom - 5
        if label_y < 0:
            label_y = box[1] + 5
        # Measure at the real draw position so the filled background matches the text.
        draw.rectangle(draw.textbbox((box[0], label_y), label, font=font), fill=color)
        draw.text((box[0], label_y), label, fill="black", font=font)
    return image
# --- FastAPI Endpoints ---
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Render the upload page (templates/index.html) with no result and no error."""
    return templates.TemplateResponse("index.html", {"request": request, "image_url": None, "error_message": None})
def _detect_and_save(image_data):
    """Run detection on uploaded image bytes and save the annotated result.

    Decodes *image_data* with PIL (as RGB), runs the ONNX session, draws the
    detections on a copy, and writes it as a JPEG with a random UUID filename
    under ``static/output``.

    Returns:
        The URL path of the saved image (``/static/output/<uuid>.jpg``).

    Raises:
        Whatever PIL, ONNX Runtime or the file system raise; predict_web turns
        these into an error message on the page.
    """
    image = Image.open(io.BytesIO(image_data)).convert("RGB")
    original_size = image.size
    # Preprocess, run inference, and post-process
    preprocessed_image = preprocess_image(image, size=input_shape)
    outputs = session.run(output_names, {input_name: preprocessed_image})
    detections = postprocess_output(outputs, original_size, input_shape)
    # Draw boxes on a copy of the original image
    plotted_image = draw_boxes_on_image(image.copy(), detections)
    # Make sure the target folder exists however the app was started.
    output_dir = os.path.join("static", "output")
    os.makedirs(output_dir, exist_ok=True)
    unique_filename = f"{uuid.uuid4()}.jpg"
    # NOTE(review): nothing in this file removes old results, so static/output
    # grows with every upload -- consider a cleanup policy.
    plotted_image.save(os.path.join(output_dir, unique_filename))
    return f"/static/output/{unique_filename}"


# NOTE(review): route path assumed -- it must match the upload form's `action`
# in templates/index.html (not shown here); confirm.
@app.post("/predict/", response_class=HTMLResponse)
async def predict_web(request: Request, file: UploadFile = File(...)):
    """Handle an image upload: run detection and render the page with the result.

    Renders ``index.html`` with ``image_url`` pointing at the annotated image on
    success, or with ``error_message`` when the model is not loaded, the upload
    is not an image, or any step of the pipeline fails.
    """
    # Local import keeps this handler self-contained.
    from fastapi.concurrency import run_in_threadpool

    if not session:
        return templates.TemplateResponse("index.html", {"request": request, "error_message": "ONNX model not loaded."})
    # UploadFile.content_type is optional (None when the part has no Content-Type).
    if not (file.content_type or "").startswith("image/"):
        return templates.TemplateResponse("index.html", {"request": request, "error_message": "Invalid file type. Please upload an image."})
    try:
        image_data = await file.read()
        # Decoding, inference, drawing and saving are blocking work; run them in a
        # worker thread so the event loop keeps serving other requests meanwhile.
        image_url = await run_in_threadpool(_detect_and_save, image_data)
        return templates.TemplateResponse("index.html", {"request": request, "image_url": image_url})
    except Exception as e:
        return templates.TemplateResponse("index.html", {"request": request, "error_message": f"An error occurred: {e}"})
if __name__ == "__main__":
    # Local entry point: make sure the folder that receives annotated images
    # exists, then serve the app on the loopback interface.
    output_dir = os.path.join("static", "output")
    os.makedirs(output_dir, exist_ok=True)
    uvicorn.run(app, host="127.0.0.1", port=8000)