merve (HF Staff) committed
Commit 234a585 · verified · 1 parent: ec8a99d

add app.py

Files changed (1):
  1. app.py +439 -0
app.py ADDED
@@ -0,0 +1,439 @@
import json
import os
import tempfile
import time

import cv2
import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from transformers import AutoModelForCausalLM

model_id = "moondream/moondream3-preview"

# Load the Moondream3 preview weights; trust_remote_code enables the custom
# detect/point/caption/query methods used below.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map={"": "cuda"},
)
model.compile()

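# Draws labelled bounding boxes on a PIL image from a Moondream detection result.
# Moondream reports box corners as fractions of the image size, so they are
# scaled to pixel coordinates and clamped before annotation with supervision.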
def create_annotated_image(image, detection_result, object_name="Object"):
    if not isinstance(detection_result, dict) or "objects" not in detection_result:
        return image

    original_width, original_height = image.size
    annotated_image = np.array(image.convert("RGB"))

    bboxes = []
    labels = []

    for i, obj in enumerate(detection_result["objects"]):
        x_min = int(obj["x_min"] * original_width)
        y_min = int(obj["y_min"] * original_height)
        x_max = int(obj["x_max"] * original_width)
        y_max = int(obj["y_max"] * original_height)

        x_min = max(0, min(x_min, original_width))
        y_min = max(0, min(y_min, original_height))
        x_max = max(0, min(x_max, original_width))
        y_max = max(0, min(y_max, original_height))

        if x_max > x_min and y_max > y_min:
            bboxes.append([x_min, y_min, x_max, y_max])
            labels.append(f"{object_name} {i + 1}")
            print(f"Box {i + 1}: ({x_min}, {y_min}, {x_max}, {y_max})")

    # Guard against an empty result: sv.Detections expects an (N, 4) array.
    if not bboxes:
        return image

    detections = sv.Detections(
        xyxy=np.array(bboxes, dtype=np.float32),
        class_id=np.arange(len(bboxes)),
    )

    bounding_box_annotator = sv.BoxAnnotator(
        thickness=3,
        color_lookup=sv.ColorLookup.INDEX,
    )
    label_annotator = sv.LabelAnnotator(
        text_thickness=2,
        text_scale=0.6,
        color_lookup=sv.ColorLookup.INDEX,
    )

    annotated_image = bounding_box_annotator.annotate(
        scene=annotated_image, detections=detections
    )
    annotated_image = label_annotator.annotate(
        scene=annotated_image, detections=detections, labels=labels
    )

    return Image.fromarray(annotated_image)

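# Processes a video by running Moondream detection every `detection_interval`
# frames and feeding the boxes to supervision's ByteTrack so objects keep a
# stable track ID; intermediate frames update the tracker with an empty
# detection set, and annotated frames are written to a temporary MP4 file.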
def process_video_with_tracking(video_path, prompt, detection_interval=3):
    cap = cv2.VideoCapture(video_path)

    fps = int(cap.get(cv2.CAP_PROP_FPS))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    byte_tracker = sv.ByteTrack()

    temp_dir = tempfile.mkdtemp()
    output_path = os.path.join(temp_dir, "tracked_video.mp4")
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    frame_count = 0
    detection_count = 0
    last_detections = None

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            run_detection = frame_count % detection_interval == 0

            if run_detection:
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                pil_image = Image.fromarray(frame_rgb)

                result = model.detect(pil_image, prompt)
                detection_count += 1

                if "objects" in result and result["objects"]:
                    bboxes = []
                    confidences = []

                    for obj in result["objects"]:
                        # Normalized coordinates -> pixel coordinates, clamped to [0, 1] first.
                        x_min = max(0.0, min(1.0, obj["x_min"])) * width
                        y_min = max(0.0, min(1.0, obj["y_min"])) * height
                        x_max = max(0.0, min(1.0, obj["x_max"])) * width
                        y_max = max(0.0, min(1.0, obj["y_max"])) * height

                        if x_max > x_min and y_max > y_min:
                            bboxes.append([x_min, y_min, x_max, y_max])
                            # Fixed score so ByteTrack accepts the boxes (detect() returns no confidence).
                            confidences.append(0.8)

                    if bboxes:
                        detections = sv.Detections(
                            xyxy=np.array(bboxes, dtype=np.float32),
                            confidence=np.array(confidences, dtype=np.float32),
                            class_id=np.zeros(len(bboxes), dtype=int),
                        )
                        detections = byte_tracker.update_with_detections(detections)
                        last_detections = detections
                    else:
                        detections = byte_tracker.update_with_detections(sv.Detections.empty())
                        last_detections = detections
                else:
                    detections = byte_tracker.update_with_detections(sv.Detections.empty())
                    last_detections = detections
            else:
                detections = byte_tracker.update_with_detections(sv.Detections.empty())

            if detections is not None and len(detections) > 0:
                box_annotator = sv.BoxAnnotator(
                    thickness=3,
                    color_lookup=sv.ColorLookup.TRACK,
                )
                label_annotator = sv.LabelAnnotator(
                    text_scale=0.6,
                    text_thickness=2,
                    color_lookup=sv.ColorLookup.TRACK,
                )

                labels = []
                for tracker_id in detections.tracker_id:
                    if tracker_id is not None:
                        labels.append(f"{prompt} ID: {tracker_id}")
                    else:
                        labels.append(f"{prompt} Unknown")

                frame = box_annotator.annotate(scene=frame, detections=detections)
                frame = label_annotator.annotate(scene=frame, detections=detections, labels=labels)

            out.write(frame)
            frame_count += 1

            if frame_count % 30 == 0:
                progress = (frame_count / total_frames) * 100
                print(f"Processing: {progress:.1f}% ({frame_count}/{total_frames}) - Detections: {detection_count}")

    finally:
        cap.release()
        out.release()

    summary = f"""Video processing complete:
- Total frames processed: {frame_count}
- Detection runs: {detection_count} (every {detection_interval} frames)
- Objects tracked: {prompt}
- Detection rate: ~{detection_count / max(frame_count, 1) * 100:.1f}% of frames"""

    return output_path, summary

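# Overlays Moondream point results (normalized x/y coordinates) on the image
# using supervision's VertexAnnotator.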
def create_point_annotated_image(image, point_result):
    """Create annotated image with points for detected objects."""
    if not isinstance(point_result, dict) or "points" not in point_result:
        return image

    original_width, original_height = image.size
    annotated_image = np.array(image.convert("RGB"))

    points = []
    for point in point_result["points"]:
        x = int(point["x"] * original_width)
        y = int(point["y"] * original_height)
        points.append([x, y])

    if points:
        points_array = np.array(points).reshape(1, -1, 2)
        key_points = sv.KeyPoints(xy=points_array)
        vertex_annotator = sv.VertexAnnotator(radius=8, color=sv.Color.RED)
        annotated_image = vertex_annotator.annotate(
            scene=annotated_image, key_points=key_points
        )

    return Image.fromarray(annotated_image)

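# Single-image entry point for the Gradio UI: dispatches to Moondream's
# detect / point / caption / query calls depending on the selected task and
# formats the result text along with the measured inference time.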
def detect_objects(image, prompt, task_type, max_objects):
    if image is None:
        return None, "Please upload an image", ""

    STANDARD_SIZE = (1024, 1024)
    image.thumbnail(STANDARD_SIZE)

    t0 = time.perf_counter()

    if task_type == "Object Detection":
        settings = {"max_objects": int(max_objects)} if max_objects and max_objects > 0 else {}
        result = model.detect(image, prompt, settings=settings)
        annotated_image = create_annotated_image(image, result, prompt)

    elif task_type == "Point Detection":
        result = model.point(image, prompt)
        annotated_image = create_point_annotated_image(image, result)

    elif task_type == "Caption":
        result = model.caption(image, length="normal")
        annotated_image = image

    else:
        result = model.query(image=image, question=prompt, reasoning=True)
        annotated_image = image

    elapsed_ms = (time.perf_counter() - t0) * 1_000

    if isinstance(result, dict):
        if "objects" in result:
            output_text = f"Found {len(result['objects'])} objects:\n"
            for i, obj in enumerate(result["objects"], 1):
                output_text += f"\n{i}. Bounding box: "
                output_text += f"({obj['x_min']:.3f}, {obj['y_min']:.3f}, {obj['x_max']:.3f}, {obj['y_max']:.3f})"
        elif "points" in result:
            output_text = f"Found {len(result['points'])} points:\n"
            for i, point in enumerate(result["points"], 1):
                output_text += f"\n{i}. Point: ({point['x']:.3f}, {point['y']:.3f})"
        elif "caption" in result:
            output_text = result["caption"]
        elif "answer" in result:
            if "reasoning" in result:
                output_text = f"Reasoning: {result['reasoning']}\n\nAnswer: {result['answer']}"
            else:
                output_text = result["answer"]
        else:
            output_text = json.dumps(result, indent=2)
    else:
        output_text = str(result)

    timing_text = f"Inference time: {elapsed_ms:.0f} ms"

    return annotated_image, output_text, timing_text

def process_video(video_file, prompt, detection_interval):
    if video_file is None:
        return None, "Please upload a video file"

    output_path, summary = process_video_with_tracking(
        video_file, prompt, detection_interval
    )
    return output_path, summary

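# Gradio interface: one tab for single-image tasks, one for video object tracking.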
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Moondream3 🌝")
    gr.Markdown("""
*Try [Moondream3 Preview](https://huggingface.co/moondream/moondream3-preview) for the following tasks:*

- **Object Detection**
- **Point Detection**
- **Captioning**
- **Visual Question Answering**
- **Video Object Tracking**
""")

    with gr.Tabs() as tabs:
        with gr.Tab("Image Processing"):
            with gr.Row():
                with gr.Column(scale=2):
                    image_input = gr.Image(label="Upload an image", type="pil", height=400)

                    task_type = gr.Radio(
                        choices=["Object Detection", "Point Detection", "Caption", "Visual Question Answering"],
                        label="Task Type",
                        value="Object Detection"
                    )

                    prompt_input = gr.Textbox(
                        label="Prompt (object to detect/question to ask)",
                        placeholder="e.g., 'car', 'person', 'What's in this image?'",
                        value="objects"
                    )

                    max_objects = gr.Number(
                        label="Max Objects (for Object Detection only)",
                        value=10,
                        minimum=1,
                        maximum=50,
                        step=1,
                        visible=True
                    )

                    generate_btn = gr.Button(value="Generate", variant="primary")

                with gr.Column(scale=2):
                    output_image = gr.Image(
                        type="pil",
                        label="Result",
                        height=400
                    )
                    output_textbox = gr.Textbox(
                        label="Model Response",
                        lines=10,
                        show_copy_button=True
                    )
                    output_time = gr.Markdown()

            gr.Markdown("### Examples")

            example_prompts = [
                [
                    "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/candy.JPG",
                    "Object Detection",
                    "candy",
                    5
                ],
                [
                    "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/candy.JPG",
                    "Point Detection",
                    "candy",
                    5
                ],
                [
                    "https://moondream.ai/images/blog/moondream-3-preview/benchmarks.jpg",
                    "Caption",
                    "",
                    5
                ],
                [
                    "https://moondream.ai/images/blog/moondream-3-preview/benchmarks.jpg",
                    "Visual Question Answering",
                    "how well does moondream 3 perform in chartvqa?",
                    5
                ],
            ]

            gr.Examples(
                examples=example_prompts,
                inputs=[image_input, task_type, prompt_input, max_objects],
                label="Click an example to populate inputs"
            )

        with gr.Tab("Video Object Tracking"):
            with gr.Row():
                with gr.Column(scale=2):
                    video_input = gr.Video(
                        label="Upload a video file",
                        height=400
                    )

                    video_prompt = gr.Textbox(
                        label="Object to track",
                        placeholder="e.g., 'person', 'car', 'ball'",
                        value="person"
                    )

                    detection_interval = gr.Slider(
                        minimum=1,
                        maximum=30,
                        value=5,
                        step=1,
                        label="Detection Interval (frames)",
                        info="Run detection every N frames (1 = every frame, slower but more accurate)"
                    )

                    process_video_btn = gr.Button(value="Process Video", variant="primary")

                with gr.Column(scale=2):
                    output_video = gr.Video(
                        label="Tracked Video Result",
                        height=400
                    )
                    video_summary = gr.Textbox(
                        label="Processing Summary",
                        lines=8,
                        show_copy_button=True
                    )

            gr.Markdown("### Examples")

            video_examples = [
                [
                    "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_8137.mp4",
                    "snowboarder",
                    5
                ],
            ]

            gr.Examples(
                examples=video_examples,
                inputs=[video_input, video_prompt, detection_interval],
                label="Click an example to populate inputs"
            )

    def update_max_objects_visibility(task):
        return gr.Number(visible=(task == "Object Detection"))

    task_type.change(
        fn=update_max_objects_visibility,
        inputs=[task_type],
        outputs=[max_objects]
    )

    generate_btn.click(
        fn=detect_objects,
        inputs=[image_input, prompt_input, task_type, max_objects],
        outputs=[output_image, output_textbox, output_time]
    )

    process_video_btn.click(
        fn=process_video,
        inputs=[video_input, video_prompt, detection_interval],
        outputs=[output_video, video_summary]
    )

if __name__ == "__main__":
    demo.launch(share=True, debug=True)