Ayesha352 commited on
Commit
37c726a
·
verified ·
1 Parent(s): b1fdbd7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -75
app.py CHANGED
@@ -3,20 +3,22 @@ import torch.nn as nn
3
  from torchvision import transforms
4
  from torchvision.models import convnext_tiny
5
  from ultralytics import YOLO
6
- from PIL import Image, ImageDraw
7
  import numpy as np
8
  import cv2
9
  import gradio as gr
 
10
  from fast_alpr import ALPR
11
 
12
- # ---------- 1. Class labels ----------
13
  class_names = [
14
  'beige', 'black', 'blue', 'brown', 'gold',
15
  'green', 'grey', 'orange', 'pink', 'purple',
16
  'red', 'silver', 'tan', 'white', 'yellow'
17
  ]
18
 
19
- # ---------- 2. Load ConvNeXt-Tiny Model ----------
 
 
20
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
  model = convnext_tiny(pretrained=False)
22
  model.classifier[2] = nn.Linear(768, len(class_names))
@@ -24,7 +26,6 @@ model.load_state_dict(torch.load("convnext_best_model.pth", map_location=device)
24
  model = model.to(device)
25
  model.eval()
26
 
27
- # ---------- 3. Image Transform ----------
28
  transform = transforms.Compose([
29
  transforms.Resize((512, 512)),
30
  transforms.ToTensor(),
@@ -32,91 +33,93 @@ transform = transforms.Compose([
32
  [0.229, 0.224, 0.225])
33
  ])
34
 
35
- # ---------- 4. Load YOLOv8 Model ----------
36
  yolo_model = YOLO("yolo11x.pt")
37
 
38
- # ---------- 5. ALPR Setup ----------
39
- DETECTOR_MODEL = "yolo-v9-s-608-license-plate-end2end"
40
- OCR_MODEL = "global-plates-mobile-vit-v2-model"
41
-
42
- def detect_vehicle_and_plate(input_img):
43
- if input_img is None:
44
- return None, None, None, None, "Please upload an image."
45
-
46
- # Convert to RGB
47
- img_original = input_img.convert("RGB")
48
- img_cv2 = cv2.cvtColor(np.array(img_original), cv2.COLOR_RGB2BGR)
49
-
50
- # ---------- Vehicle Detection ----------
51
- results = yolo_model(img_cv2)
52
- boxes = results[0].boxes
53
- vehicle_class_ids = {2, 3, 5, 7}
54
- vehicle_boxes = [box for box in boxes if int(box.cls.item()) in vehicle_class_ids]
55
-
56
- if not vehicle_boxes:
57
- return "No vehicle detected", img_original, img_original, img_original, "No plate detected."
58
-
59
- def box_area(box):
60
- x1, y1, x2, y2 = box.xyxy[0].tolist()
61
- return (x2 - x1) * (y2 - y1)
62
-
63
- largest_vehicle = max(vehicle_boxes, key=box_area)
64
- x1, y1, x2, y2 = map(int, largest_vehicle.xyxy[0].tolist())
65
- cropped = img_original.crop((x1, y1, x2, y2))
66
 
67
- input_tensor = transform(cropped).unsqueeze(0).to(device)
68
- with torch.no_grad():
69
- output = model(input_tensor)
70
- probs = torch.softmax(output, dim=1)[0]
71
- pred_idx = torch.argmax(probs).item()
72
- pred_class = class_names[pred_idx]
73
- confidence = probs[pred_idx].item()
74
-
75
- img_with_box = np.array(img_original).copy()
76
- cv2.rectangle(img_with_box, (x1, y1), (x2, y2), (255, 0, 0), 3)
77
- img_with_box_pil = Image.fromarray(img_with_box)
78
-
79
- # ---------- License Plate Detection ----------
80
  alpr = ALPR(detector_model=DETECTOR_MODEL, ocr_model=OCR_MODEL)
81
- results = alpr.predict(np.array(img_original))
82
- annotated_img = img_original.copy()
 
83
  draw = ImageDraw.Draw(annotated_img)
84
- final_text = ""
85
 
 
86
  for result in results:
87
  detection = getattr(result, 'detection', None)
88
  ocr = getattr(result, 'ocr', None)
89
- if detection and getattr(detection, 'bounding_box', None):
90
- bbox = detection.bounding_box
91
- box_coords = [int(bbox.x1), int(bbox.y1), int(bbox.x2), int(bbox.y2)]
92
- draw.rectangle(box_coords, outline="red", width=3)
93
- if ocr:
94
- text = ocr.text
95
- final_text += text + " "
96
- draw.text((box_coords[0], max(box_coords[1] - 10, 0)), text, fill="red")
97
-
98
- plate_result = f"Detected plate(s): {final_text.strip()}" if final_text else "No license plate detected."
99
-
100
- return f"{pred_class} ({confidence*100:.1f}%)", img_with_box_pil, cropped, annotated_img, plate_result
 
 
 
 
101
 
102
- # ---------- Gradio UI ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  with gr.Blocks() as demo:
104
- gr.Markdown("# 🚗 Vehicle Color + License Plate Detection")
105
  with gr.Row():
106
  with gr.Column():
107
- img_input = gr.Image(type="pil", label="Upload Vehicle Image")
108
- btn = gr.Button("Run Detection")
109
  with gr.Column():
110
- color_result = gr.Text(label="Predicted Vehicle Color")
111
- image_with_box = gr.Image(label="Detected Vehicle in Original")
112
- cropped_img = gr.Image(label="Cropped Vehicle Region")
113
- plate_annotated = gr.Image(label="License Plate Detection")
114
- plate_result = gr.Text(label="License Plate OCR")
115
-
116
- btn.click(
117
- detect_vehicle_and_plate,
118
- inputs=[img_input],
119
- outputs=[color_result, image_with_box, cropped_img, plate_annotated, plate_result]
 
 
 
 
 
 
 
 
 
 
120
  )
121
 
122
  if __name__ == "__main__":
 
3
  from torchvision import transforms
4
  from torchvision.models import convnext_tiny
5
  from ultralytics import YOLO
 
6
  import numpy as np
7
  import cv2
8
  import gradio as gr
9
+ from PIL import Image, ImageDraw
10
  from fast_alpr import ALPR
11
 
12
+ # ------------------ Constants and Models ------------------
13
  class_names = [
14
  'beige', 'black', 'blue', 'brown', 'gold',
15
  'green', 'grey', 'orange', 'pink', 'purple',
16
  'red', 'silver', 'tan', 'white', 'yellow'
17
  ]
18
 
19
+ DETECTOR_MODEL = "yolo-v9-s-608-license-plate-end2end"
20
+ OCR_MODEL = "global-plates-mobile-vit-v2-model"
21
+
22
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23
  model = convnext_tiny(pretrained=False)
24
  model.classifier[2] = nn.Linear(768, len(class_names))
 
26
  model = model.to(device)
27
  model.eval()
28
 
 
29
  transform = transforms.Compose([
30
  transforms.Resize((512, 512)),
31
  transforms.ToTensor(),
 
33
  [0.229, 0.224, 0.225])
34
  ])
35
 
 
36
  yolo_model = YOLO("yolo11x.pt")
37
 
38
+ # ------------------ Unified Inference Function ------------------
39
+ def alpr_color_inference(image):
40
+ if image is None:
41
+ return None, None, None, "Please upload an image to continue."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ img = image.convert("RGB")
44
+ img_array = np.array(img)
 
 
 
 
 
 
 
 
 
 
 
45
  alpr = ALPR(detector_model=DETECTOR_MODEL, ocr_model=OCR_MODEL)
46
+ results = alpr.predict(img_array)
47
+
48
+ annotated_img = Image.fromarray(img_array.copy())
49
  draw = ImageDraw.Draw(annotated_img)
 
50
 
51
+ plate_texts = []
52
  for result in results:
53
  detection = getattr(result, 'detection', None)
54
  ocr = getattr(result, 'ocr', None)
55
+ if detection is not None:
56
+ bbox_obj = getattr(detection, 'bounding_box', None)
57
+ if bbox_obj is not None:
58
+ bbox = [int(bbox_obj.x1), int(bbox_obj.y1), int(bbox_obj.x2), int(bbox_obj.y2)]
59
+ draw.rectangle(bbox, outline="red", width=3)
60
+ if ocr is not None:
61
+ text = getattr(ocr, 'text', '')
62
+ plate_texts.append(text)
63
+ draw.text((bbox[0], max(bbox[1] - 10, 0)), text, fill="red")
64
+
65
+ # Color Detection
66
+ img_cv2 = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
67
+ yolo_results = yolo_model(img_cv2)
68
+ boxes = yolo_results[0].boxes
69
+ vehicle_class_ids = {2, 3, 5, 7} # car, motorcycle, bus, truck
70
+ vehicle_boxes = [box for box in boxes if int(box.cls.item()) in vehicle_class_ids]
71
 
72
+ if not vehicle_boxes:
73
+ color_text = "No vehicle detected"
74
+ cropped_img = img
75
+ vehicle_img = img
76
+ else:
77
+ largest_vehicle = max(vehicle_boxes, key=lambda box: (box.xyxy[0][2] - box.xyxy[0][0]) * (box.xyxy[0][3] - box.xyxy[0][1]))
78
+ x1, y1, x2, y2 = map(int, largest_vehicle.xyxy[0].tolist())
79
+ cropped_img = img.crop((x1, y1, x2, y2))
80
+ input_tensor = transform(cropped_img).unsqueeze(0).to(device)
81
+ with torch.no_grad():
82
+ output = model(input_tensor)
83
+ probs = torch.softmax(output, dim=1)[0]
84
+ pred_idx = torch.argmax(probs).item()
85
+ pred_class = class_names[pred_idx]
86
+ confidence = probs[pred_idx].item()
87
+ vehicle_img = Image.fromarray(cv2.rectangle(np.array(img), (x1, y1), (x2, y2), (255, 0, 0), 3))
88
+ color_text = f"{pred_class} ({confidence*100:.1f}%)"
89
+
90
+ detection_results = (f"Detected {len(results)} license plate(s): {', '.join(plate_texts)}"
91
+ if results else "No license plate detected 😔.")
92
+
93
+ return annotated_img, vehicle_img, cropped_img, f"{detection_results}\nVehicle Color: {color_text}"
94
+
95
+ # ------------------ Gradio UI ------------------
96
  with gr.Blocks() as demo:
97
+ gr.Markdown("# License Plate + Vehicle Color Detection")
98
  with gr.Row():
99
  with gr.Column():
100
+ image_input = gr.Image(type="pil", label="Upload an image")
101
+ submit_btn = gr.Button("Run Detection")
102
  with gr.Column():
103
+ plate_output = gr.Image(label="License Plate Detection")
104
+ vehicle_output = gr.Image(label="Detected Vehicle in Original")
105
+ cropped_output = gr.Image(label="Cropped Vehicle Region")
106
+ result_text = gr.Markdown(label="Results")
107
+
108
+ submit_btn.click(
109
+ alpr_color_inference,
110
+ inputs=[image_input],
111
+ outputs=[plate_output, vehicle_output, cropped_output, result_text]
112
+ )
113
+
114
+ gr.Examples(
115
+ examples=[
116
+ "examples/car1.jpg",
117
+ "examples/car2.jpg",
118
+ "examples/car3.jpg",
119
+ "examples/car4.jpg",
120
+ ],
121
+ inputs=[image_input],
122
+ label="Example Images"
123
  )
124
 
125
  if __name__ == "__main__":