BrunoMelicio committed on
Commit
8d420f3
·
1 Parent(s): 048ce99

Added visualization for bbox predictions

Browse files
Files changed (2) hide show
  1. app.py +35 -10
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,19 +1,44 @@
1
  import gradio as gr
2
  import mediapipe as mp
3
- #from mediapipe.tasks import python
4
- #from mediapipe.tasks.python import vision
5
 
6
  BaseOptions = mp.tasks.BaseOptions
7
  ObjectDetector = mp.tasks.vision.ObjectDetector
8
  ObjectDetectorOptions = mp.tasks.vision.ObjectDetectorOptions
9
  VisionRunningMode = mp.tasks.vision.RunningMode
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def analyze_image(image):
12
  model_path = "efficientdet_lite0.tflite"
13
 
14
  options = ObjectDetectorOptions(
15
  base_options=BaseOptions(model_asset_path=model_path),
16
- max_results=1,
17
  running_mode=VisionRunningMode.IMAGE)
18
 
19
  mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image)
@@ -21,14 +46,14 @@ def analyze_image(image):
21
  with ObjectDetector.create_from_options(options) as detector:
22
  detection_result = detector.detect(mp_image)
23
 
24
- results = ""
25
- for i,detection in enumerate(detection_result.detections):
26
- results += f'Detection {i}: {detection.categories[0].category_name}\n'
27
 
28
- return results
29
 
30
- img = gr.Image()
31
- txt = gr.Text()
32
 
33
- iface = gr.Interface(fn=analyze_image, inputs=img, outputs=txt)
34
  iface.launch()
 
import gradio as gr
import mediapipe as mp
import cv2
import numpy as np

# Short aliases for the MediaPipe Tasks object-detection API used below.
BaseOptions = mp.tasks.BaseOptions
ObjectDetector = mp.tasks.vision.ObjectDetector
ObjectDetectorOptions = mp.tasks.vision.ObjectDetectorOptions
VisionRunningMode = mp.tasks.vision.RunningMode
 
# Drawing parameters used by visualize() when annotating detections.
MARGIN = 10  # pixels; offset of the label text from the box origin
ROW_SIZE = 10  # pixels; vertical offset so the label sits below the box top edge
FONT_SIZE = 1  # cv2.putText fontScale
FONT_THICKNESS = 1  # cv2.putText line thickness
TEXT_COLOR = (255, 0, 0)  # red (in the channel order of the input image)
17
+ def visualize(image, detection_result) -> np.ndarray:
18
+ for detection in detection_result.detections:
19
+ # Draw bounding_box
20
+ bbox = detection.bounding_box
21
+ start_point = bbox.origin_x, bbox.origin_y
22
+ end_point = bbox.origin_x + bbox.width, bbox.origin_y + bbox.height
23
+ cv2.rectangle(image, start_point, end_point, TEXT_COLOR, 3)
24
+
25
+ # Draw label and score
26
+ category = detection.categories[0]
27
+ category_name = category.category_name
28
+ probability = round(category.score, 2)
29
+ result_text = category_name + ' (' + str(probability) + ')'
30
+ text_location = (MARGIN + bbox.origin_x,
31
+ MARGIN + ROW_SIZE + bbox.origin_y)
32
+ cv2.putText(image, result_text, text_location, cv2.FONT_HERSHEY_PLAIN,
33
+ FONT_SIZE, TEXT_COLOR, FONT_THICKNESS)
34
+ return image
35
+
36
def analyze_image(image):
    """Run MediaPipe object detection on an image and draw the predictions.

    Args:
        image: RGB numpy array as supplied by the gradio ``Image`` input.

    Returns:
        RGB numpy array: a copy of the input with up to 5 detection
        bounding boxes and labels drawn on it.
    """
    model_path = "efficientdet_lite0.tflite"

    options = ObjectDetectorOptions(
        base_options=BaseOptions(model_asset_path=model_path),
        max_results=5,
        running_mode=VisionRunningMode.IMAGE)

    mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image)

    with ObjectDetector.create_from_options(options) as detector:
        detection_result = detector.detect(mp_image)

    # BUG FIX: `image` is the raw numpy array from gradio and has no
    # `numpy_view()` method (that is an mp.Image method), so the original
    # `image.numpy_view()` raised AttributeError. Take the view from
    # `mp_image`, and copy it because numpy_view() is read-only while
    # visualize() draws in place.
    image_copy = np.copy(mp_image.numpy_view())
    annotated_image = visualize(image_copy, detection_result)

    # NOTE(review): gradio supplies RGB already, so the cv2.COLOR_BGR2RGB
    # conversion from the MediaPipe sample (which loads BGR via cv2.imread)
    # would swap the channels here — return the annotated image as-is.
    return annotated_image
 
# Wire the detector into a minimal gradio UI: image in, annotated image out.
img_in = gr.Image()

iface = gr.Interface(fn=analyze_image, inputs=img_in, outputs="image")
iface.launch()
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  numpy
2
- mediapipe
 
 
1
  numpy
2
+ mediapipe
3
+ opencv-python