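"""Gradio demo combining license-plate detection/OCR (fast-alpr) with
vehicle color classification (YOLO for vehicle localization plus a
fine-tuned ConvNeXt-Tiny color classifier)."""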
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import convnext_tiny
from ultralytics import YOLO
import numpy as np
import cv2
import gradio as gr
from PIL import Image, ImageDraw
from fast_alpr import ALPR
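
# Runtime dependencies implied by the imports above: torch, torchvision,
# ultralytics, fast-alpr, opencv-python, gradio, pillow, numpy.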

# ------------------ Constants and Models ------------------
class_names = [
    'beige', 'black', 'blue', 'brown', 'gold',
    'green', 'grey', 'orange', 'pink', 'purple',
    'red', 'silver', 'tan', 'white', 'yellow'
]

DETECTOR_MODEL = "yolo-v9-s-608-license-plate-end2end"
OCR_MODEL = "global-plates-mobile-vit-v2-model"
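
# fast-alpr resolves the two identifiers above to pretrained ONNX detector/OCR
# models, downloaded on first use (see the fast-alpr project for available names).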

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Vehicle-color classifier: ConvNeXt-Tiny with its classification head swapped
# for a 15-way color head, loaded from a locally fine-tuned checkpoint.
model = convnext_tiny(weights=None)  # weights=None replaces the deprecated pretrained=False
model.classifier[2] = nn.Linear(768, len(class_names))
model.load_state_dict(torch.load("convnext_best_model.pth", map_location=device))
model = model.to(device)
model.eval()

# Preprocessing for the color classifier (ImageNet mean/std)
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

yolo_model = YOLO("yolo11x.pt")  # COCO-pretrained YOLO11 used to localize vehicles
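
# Instantiate the ALPR pipeline once at import time so the plate detector and
# OCR weights are loaded a single time, not on every request.
alpr = ALPR(detector_model=DETECTOR_MODEL, ocr_model=OCR_MODEL)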

# ------------------ Unified Inference Function ------------------
def alpr_color_inference(image):
    if image is None:
        return None, None, None, "Please upload an image to continue."

    img = image.convert("RGB")
    img_array = np.array(img)

    # License-plate detection + OCR
    results = alpr.predict(img_array)

    # Draw plate boxes and recognized text on a copy of the input
    annotated_img = Image.fromarray(img_array.copy())
    draw = ImageDraw.Draw(annotated_img)
    plate_texts = []
    for result in results:
        detection = getattr(result, 'detection', None)
        ocr = getattr(result, 'ocr', None)
        if detection is None:
            continue
        bbox_obj = getattr(detection, 'bounding_box', None)
        if bbox_obj is None:
            continue
        bbox = [int(bbox_obj.x1), int(bbox_obj.y1), int(bbox_obj.x2), int(bbox_obj.y2)]
        draw.rectangle(bbox, outline="red", width=3)
        if ocr is not None:
            text = getattr(ocr, 'text', '')
            if text:
                plate_texts.append(text)
                # Place the text just above the box, clamped to the image top
                draw.text((bbox[0], max(bbox[1] - 10, 0)), text, fill="red")

    # Color Detection
    img_cv2 = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    yolo_results = yolo_model(img_cv2)
    boxes = yolo_results[0].boxes
    vehicle_class_ids = {2, 3, 5, 7}  # COCO classes: car, motorcycle, bus, truck
    vehicle_boxes = [box for box in boxes if int(box.cls.item()) in vehicle_class_ids]
    if not vehicle_boxes:
        color_text = "No vehicle detected"
        cropped_img = img
        vehicle_img = img
    else:
        # Classify the color of the largest detected vehicle (by box area)
        largest_vehicle = max(
            vehicle_boxes,
            key=lambda box: (box.xyxy[0][2] - box.xyxy[0][0]) * (box.xyxy[0][3] - box.xyxy[0][1])
        )
        x1, y1, x2, y2 = map(int, largest_vehicle.xyxy[0].tolist())
        cropped_img = img.crop((x1, y1, x2, y2))

        input_tensor = transform(cropped_img).unsqueeze(0).to(device)
        with torch.no_grad():
            output = model(input_tensor)
        probs = torch.softmax(output, dim=1)[0]
        pred_idx = torch.argmax(probs).item()
        pred_class = class_names[pred_idx]
        confidence = probs[pred_idx].item()

        vehicle_img = Image.fromarray(cv2.rectangle(np.array(img), (x1, y1), (x2, y2), (255, 0, 0), 3))
        color_text = f"{pred_class} ({confidence * 100:.1f}%)"

    detection_results = (f"Detected {len(results)} license plate(s): {', '.join(plate_texts)}"
                         if results else "No license plate detected.")

    return annotated_img, vehicle_img, cropped_img, f"{detection_results}\nVehicle Color: {color_text}"

# ------------------ Gradio UI ------------------
with gr.Blocks() as demo:
    gr.Markdown("# License Plate + Vehicle Color Detection")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload an image")
            submit_btn = gr.Button("Run Detection")
        with gr.Column():
            plate_output = gr.Image(label="License Plate Detection")
            vehicle_output = gr.Image(label="Detected Vehicle in Original")
            cropped_output = gr.Image(label="Cropped Vehicle Region")
            result_text = gr.Markdown(label="Results")

    submit_btn.click(
        alpr_color_inference,
        inputs=[image_input],
        outputs=[plate_output, vehicle_output, cropped_output, result_text]
    )
    gr.Examples(
        examples=[
            "examples/car1.jpg",
            "examples/car2.jpg",
            "examples/car3.jpg",
            "examples/car4.jpg",
        ],
        inputs=[image_input],
        label="Example Images"
    )

if __name__ == "__main__":
    demo.launch(share=True)  # share=True only affects local runs; Hugging Face Spaces ignores it