multimodalart (HF Staff) committed
Commit f3050ba · verified · 1 Parent(s): e2b49c1

Update app.py

Files changed (1):
  app.py (+344 -40)
app.py CHANGED
@@ -52,15 +52,20 @@ def add_contour(img, mask, color=(1., 1., 1.)):
 @spaces.GPU(duration=120)
 def generate_masks(image, mask_list, mask_raw_list):
     """
-    Generate masks from user-drawn annotations on an image.
-
+    Generates segmentation masks for selected regions in an image using SAM.
+
     Args:
-        image: Dictionary containing the image editor state with background and layers
-        mask_list: List of generated mask images with labels
-        mask_raw_list: List of raw numpy arrays of masks
-
+        image (dict): A dictionary containing image data, typically from a Gradio ImageEditor,
+            with 'background' (PIL Image) and 'layers' (list of PIL Image layers).
+        mask_list (list): A list to accumulate (mask_image, label) tuples for display in a gallery.
+        mask_raw_list (list): A list to accumulate raw NumPy mask arrays.
+
     Returns:
-        Tuple containing updated mask_list, image editor state, mask_list, and mask_raw_list
+        tuple: A tuple containing:
+            - mask_list (list): Updated list of mask images for display.
+            - image (dict): Updated image dictionary with layers cleared.
+            - mask_list (list): Redundant return of mask_list (for Gradio update).
+            - mask_raw_list (list): Updated list of raw mask arrays.
     """
     image['image'] = image['background'].convert('RGB')
     # del image['background'], image['composite']
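Note: the `image` dict consumed above is Gradio's ImageEditor value; a sketch of its shape (assuming `type="pil"`, as the UI code later in this diff configures):

    image = {
        "background": ...,   # PIL.Image: the uploaded picture
        "layers": [...],     # list of PIL.Image: user-drawn brush strokes
        "composite": ...,    # PIL.Image: background with layers flattened
    }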
@@ -89,15 +94,21 @@ def generate_masks(image, mask_list, mask_raw_list):
 @spaces.GPU(duration=120)
 def generate_masks_video(image, mask_list_video, mask_raw_list_video):
     """
-    Generate masks from user-drawn annotations on a video frame.
-
+    Generates segmentation masks for selected regions in the first frame of a video using SAM.
+
     Args:
-        image: Dictionary containing the image editor state with background and layers
-        mask_list_video: List of generated mask images with labels for video
-        mask_raw_list_video: List of raw numpy arrays of masks for video
-
+        image (dict): A dictionary containing image data (first frame of video),
+            typically from a Gradio ImageEditor, with 'background' (PIL Image)
+            and 'layers' (list of PIL Image layers).
+        mask_list_video (list): A list to accumulate (mask_image, label) tuples for display.
+        mask_raw_list_video (list): A list to accumulate raw NumPy mask arrays for video processing.
+
     Returns:
-        Tuple containing updated mask_list_video, image editor state, mask_list_video, and mask_raw_list_video
+        tuple: A tuple containing:
+            - mask_list_video (list): Updated list of mask images for display.
+            - image (dict): Updated image dictionary with layers cleared.
+            - mask_list_video (list): Redundant return of mask_list_video (for Gradio update).
+            - mask_raw_list_video (list): Updated list of raw mask arrays.
     """
     image['image'] = image['background'].convert('RGB')
     # del image['background'], image['composite']
@@ -127,16 +138,19 @@ def generate_masks_video(image, mask_list_video, mask_raw_list_video):
 @spaces.GPU(duration=120)
 def describe(image, mode, query, masks):
     """
-    Generate descriptions or answer questions about regions in an image.
-
+    Describes an image based on selected regions or answers a question about them.
+
     Args:
-        image: Dictionary containing the image editor state
-        mode: Either "Caption" or "QA" mode
-        query: Question to ask about the image (used in QA mode)
-        masks: List of mask arrays for the regions
-
-    Returns:
-        Generator yielding image with contours, generated text, and updated image state
+        image (dict): A dictionary containing image data, typically from a Gradio ImageEditor,
+            with 'background' (PIL Image) and 'layers' (list of PIL Image layers).
+        mode (str): The operational mode, either "Caption" (to describe a selected region)
+            or "QA" (to answer a question about one or more regions).
+        query (str): The question to ask in "QA" mode. Ignored in "Caption" mode.
+        masks (list): A list of raw NumPy mask arrays representing previously generated masks.
+
+    Yields:
+        tuple: An image with contours and the generated text description/answer,
+            or updates for Gradio components during streaming.
     """
     # Create an image object from the uploaded image
     # print(image.keys())
@@ -229,13 +243,16 @@ def describe(image, mode, query, masks):
 
 def load_first_frame(video_path):
     """
-    Load and return the first frame of a video.
-
+    Loads the first frame of a given video file.
+
     Args:
-        video_path: Path to the video file
-
+        video_path (str): The file path to the video.
+
     Returns:
-        PIL Image of the first frame
+        PIL.Image.Image: The first frame of the video as a PIL Image.
+
+    Raises:
+        gr.Error: If the video file cannot be read.
     """
     cap = cv2.VideoCapture(video_path)
     ret, frame = cap.read()
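A minimal sketch of what the rest of this helper likely does, assuming the usual OpenCV BGR-to-RGB conversion (the raise follows the gr.Error documented in the new docstring):

    cap = cv2.VideoCapture(video_path)
    ret, frame = cap.read()
    cap.release()
    if not ret:
        raise gr.Error("Could not read the video file.")
    return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))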
@@ -249,18 +266,23 @@ def load_first_frame(video_path):
 @spaces.GPU(duration=120)
 def describe_video(video_path, mode, query, annotated_frame, masks, mask_list_video):
     """
-    Generate descriptions or answer questions about regions in a video.
-
+    Describes a video based on selected regions in its first frame or answers a question about them.
+
     Args:
-        video_path: Path to the video file
-        mode: Either "Caption" or "QA" mode
-        query: Question to ask about the video (used in QA mode)
-        annotated_frame: Dictionary containing the annotated first frame
-        masks: List of mask arrays for the regions
-        mask_list_video: List of mask images with labels
-
-    Returns:
-        Generator yielding frame image, generated text, and updated mask lists
+        video_path (str): The file path to the video.
+        mode (str): The operational mode, either "Caption" (to describe a selected region)
+            or "QA" (to answer a question about one or more regions).
+        query (str): The question to ask in "QA" mode. Ignored in "Caption" mode.
+        annotated_frame (dict): A dictionary containing the first frame's image data
+            from a Gradio ImageEditor, with 'background' (PIL Image)
+            and 'layers' (list of PIL Image layers).
+        masks (list): A list of raw NumPy mask arrays representing previously generated masks
+            for objects in the video.
+        mask_list_video (list): A list to accumulate (mask_image, label) tuples for display.
+
+    Yields:
+        tuple: The annotated first frame, the generated text description/answer,
+            and updated mask lists for Gradio components during streaming.
     """
     # Create a temporary directory to save extracted video frames
     cap = cv2.VideoCapture(video_path)
@@ -351,4 +373,286 @@ def describe_video(video_path, mode, query, annotated_frame, masks, mask_list_video):
     mask_image = Image.fromarray((mask_np[:,:,np.newaxis] * np.array(annotated_frame['image'])).astype(np.uint8))
     mask_list_video.append((mask_image, f"<object{len(mask_list_video)}>"))
     text = ""
-    yield frame_img, text, mask_list_video, mask_list_video
+    yield frame_img, text, mask_list_video, mask_list_video
+
+    for token in get_model_output(
+        video_tensor,
+        query,
+        model=model,
+        tokenizer=tokenizer,
+        masks=masks,
+        mask_ids=mask_ids,
+        modal='video',
+        streaming=True,
+    ):
+        text += token
+        yield gr.update(), text, gr.update(), gr.update()
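+        # Note: a bare gr.update() is a no-op for a component, so while
+        # streaming only the text box refreshes; the annotated frame and the
+        # mask gallery keep their previously yielded values.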
+
+
+@spaces.GPU(duration=120)
+def apply_sam(image, input_points):
+    """
+    Applies the Segment Anything Model (SAM) to an image based on input points
+    to generate a segmentation mask.
+
+    Args:
+        image (PIL.Image.Image): The input image.
+        input_points (list): A list of lists, where each inner list contains
+            [x, y] coordinates representing points used for segmentation.
+
+    Returns:
+        numpy.ndarray: The selected binary segmentation mask as a NumPy array (H, W).
+    """
+    inputs = sam_processor(image, input_points=input_points, return_tensors="pt").to(device)
+
+    with torch.no_grad():
+        outputs = sam_model(**inputs)
+
+    masks = sam_processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())[0][0]
+    scores = outputs.iou_scores[0, 0]
+
+    mask_selection_index = scores.argmax()
+
+    mask_np = masks[mask_selection_index].numpy()
+
+    return mask_np
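+
+# Illustrative usage of apply_sam (hypothetical click coordinates, nested as
+# described in the docstring above):
+#   mask = apply_sam(img, input_points=[[[250, 140]]])  # -> (H, W) numpy mask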
+
+
+def clear_masks():
+    """
+    Clears the stored lists of masks and raw masks.
+
+    Returns:
+        tuple: Three empty lists, intended to reset Gradio components
+            displaying masks.
+    """
+    return [], [], []
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="VideoRefer gradio demo")
+    parser.add_argument("--model-path", type=str, default="DAMO-NLP-SG/VideoRefer-VideoLLaMA3-7B", help="Path to the model checkpoint")
+    parser.add_argument("--prompt-mode", type=str, default="focal_prompt", help="Prompt mode")
+    parser.add_argument("--conv-mode", type=str, default="v1", help="Conversation mode")
+    parser.add_argument("--temperature", type=float, default=0.2, help="Sampling temperature")
+    parser.add_argument("--top_p", type=float, default=0.5, help="Top-p for sampling")
+
+    args_cli = parser.parse_args()
+
+    with gr.Blocks(theme=gr.themes.Soft(primary_hue="amber")) as demo:
+
+        mask_list = gr.State([])
+        mask_raw_list = gr.State([])
+        mask_list_video = gr.State([])
+        mask_raw_list_video = gr.State([])
+
+
+        HEADER = ("""
+        <div>
+            <h1>VideoRefer X VideoLLaMA3 Demo</h1>
+            <h5 style="margin: 0;">Feel free to click on anything that grabs your interest!</h5>
+            <h5 style="margin: 0;">If this demo pleases you, please give us a star ⭐ on GitHub or 💖 on this Space.</h5>
+        </div>
+        <div style="display: flex; justify-content: left; margin-top: 10px;">
+            <a href="https://arxiv.org/pdf/2501.00599"><img src="https://img.shields.io/badge/Arxiv-2501.00599-ECA8A7" style="margin-right: 5px;"></a>
+            <a href="https://github.com/DAMO-NLP-SG/VideoRefer"><img src='https://img.shields.io/badge/Github-VideoRefer-F7C97E' style="margin-right: 5px;"></a>
+            <a href="https://github.com/DAMO-NLP-SG/VideoLLaMA3"><img src='https://img.shields.io/badge/Github-VideoLLaMA3-9DC3E6' style="margin-right: 5px;"></a>
+        </div>
+        """)
+
+        with gr.Row():
+            with gr.Column():
+                gr.HTML(HEADER)
+
+
+        image_tips = """
+        ### 💡 Tips:
+
+        🧸 Upload an image, and you can use the drawing tool ✍️ to highlight the areas you're interested in.
+
+        🔖 For single-object caption mode, simply select the area and click the 'Generate Caption' button to receive a caption for the object.
+
+        🔔 In QA mode, you can generate multiple masks by clicking the 'Generate Mask' button multiple times. Afterward, use the corresponding object id to ask questions.
+
+        📌 Click the 'Clear Masks' button to clear the currently generated masks.
+
+        """
+
+        video_tips = """
+        ### 💡 Tips:
+        ⚠️ For video mode, we only support masking on the first frame in this demo.
+
+        🧸 Upload a video, and you can use the drawing tool ✍️ to highlight the areas you're interested in on the first frame.
+
+        🔖 For single-object caption mode, simply select the area and click the 'Generate Caption' button to receive a caption for the object.
+
+        🔔 In QA mode, you can generate multiple masks by clicking the 'Generate Mask' button multiple times. Afterward, use the corresponding object id to ask questions.
+
+        📌 Click the 'Clear Masks' button to clear the currently generated masks.
+
+        """
+
+
+        with gr.TabItem("Image"):
+            with gr.Row():
+                with gr.Column():
+                    image_input = gr.ImageEditor(
+                        label="Image",
+                        type="pil",
+                        sources=['upload'],
+                        brush=gr.Brush(colors=["#ED7D31"], color_mode="fixed", default_size=10),
+                        eraser=True,
+                        layers=False,
+                        transforms=[],
+                        height=300,
+                    )
+                    generate_mask_btn = gr.Button("1️⃣ Generate Mask", visible=False, variant="primary")
+                    mode = gr.Radio(label="Mode", choices=["Caption", "QA"], value="Caption")
+                    query = gr.Textbox(label="Question", value="What is the relationship between <region0> and <region1>?", interactive=True, visible=False)
+
+                    submit_btn = gr.Button("Generate Caption", variant="primary")
+                    submit_btn1 = gr.Button("2️⃣ Generate Answer", variant="primary", visible=False)
+                    gr.Examples([f"./demo/images/{i+1}.jpg" for i in range(8)], inputs=image_input, label="Examples")
+
+                with gr.Column():
+                    mask_output = gr.Gallery(label="Referred Masks", object_fit='scale-down', visible=False)
+                    output_image = gr.Image(label="Image with Mask", visible=True, height=400)
+                    description = gr.Textbox(label="Output", visible=True)
+
+                    clear_masks_btn = gr.Button("Clear Masks", variant="secondary", visible=False)
+                    gr.Markdown(image_tips)
+
+        with gr.TabItem("Video"):
+            with gr.Row():
+                with gr.Column():
+                    video_input = gr.Video(label="Video")
+                    # load_btn = gr.Button("🖼️ Load First Frame", variant="secondary")
+                    first_frame = gr.ImageEditor(
+                        label="Annotate First Frame",
+                        type="pil",
+                        sources=['upload'],
+                        brush=gr.Brush(colors=["#ED7D31"], color_mode="fixed", default_size=10),
+                        eraser=True,
+                        layers=False,
+                        transforms=[],
+                        height=300,
+                    )
+                    generate_mask_btn_video = gr.Button("1️⃣ Generate Mask", visible=False, variant="primary")
+                    gr.Examples([f"./demo/videos/{i+1}.mp4" for i in range(4)], inputs=video_input, label="Examples")
+
+                with gr.Column():
+                    mode_video = gr.Radio(label="Mode", choices=["Caption", "QA"], value="Caption")
+                    mask_output_video = gr.Gallery(label="Referred Masks", object_fit='scale-down')
+
+                    query_video = gr.Textbox(label="Question", value="What is the relationship between <object0> and <object1>?", interactive=True, visible=False)
+
+                    submit_btn_video = gr.Button("Generate Caption", variant="primary")
+                    submit_btn_video1 = gr.Button("2️⃣ Generate Answer", variant="primary", visible=False)
+                    description_video = gr.Textbox(label="Output", visible=True)
+
+                    clear_masks_btn_video = gr.Button("Clear Masks", variant="secondary")
+
+                    gr.Markdown(video_tips)
+
+
+        def toggle_query_and_generate_button(mode):
+            """
+            Toggles the visibility of query-related Gradio components based on the selected mode.
+            Also clears mask states.
+
+            Args:
+                mode (str): The selected mode ("Caption" or "QA").
+
+            Returns:
+                tuple: A series of gr.update() calls and empty lists to update Gradio components.
+            """
+            query_visible = mode == "QA"
+            caption_visible = mode == "Caption"
+            return gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=caption_visible), gr.update(visible=caption_visible), [], "", [], [], [], []
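+        # The 13 values returned above correspond positionally to the 13
+        # components wired in mode.change(outputs=[...]) below.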
+
+        video_input.change(load_first_frame, inputs=video_input, outputs=first_frame)
+
+        mode.change(toggle_query_and_generate_button, inputs=mode, outputs=[query, generate_mask_btn, clear_masks_btn, submit_btn1, mask_output, output_image, submit_btn, mask_output, description, mask_list, mask_raw_list, mask_list_video, mask_raw_list_video])
+
+        def toggle_query_and_generate_button_video(mode):
+            """
+            Toggles the visibility of query-related Gradio components for video mode
+            based on the selected mode. Also clears mask states.
+
+            Args:
+                mode (str): The selected mode ("Caption" or "QA").
+
+            Returns:
+                tuple: A series of gr.update() calls and empty lists to update Gradio components.
+            """
+            query_visible = mode == "QA"
+            caption_visible = mode == "Caption"
+            return gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=caption_visible), [], [], [], [], []
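+        # Likewise, these 9 return values correspond positionally to the 9
+        # components wired in mode_video.change(outputs=[...]) below.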
+
+
+        mode_video.change(toggle_query_and_generate_button_video, inputs=mode_video, outputs=[query_video, generate_mask_btn_video, submit_btn_video1, submit_btn_video, mask_output_video, mask_list, mask_raw_list, mask_list_video, mask_raw_list_video])
+
+        submit_btn.click(
+            fn=describe,
+            inputs=[image_input, mode, query, mask_raw_list],
+            outputs=[output_image, description, image_input],
+            api_name="describe"
+        )
+
+        submit_btn1.click(
+            fn=describe,
+            inputs=[image_input, mode, query, mask_raw_list],
+            outputs=[output_image, description, image_input],
+            api_name="describe"
+        )
+
+        generate_mask_btn.click(
+            fn=generate_masks,
+            inputs=[image_input, mask_list, mask_raw_list],
+            outputs=[mask_output, image_input, mask_list, mask_raw_list]
+        )
+
+        generate_mask_btn_video.click(
+            fn=generate_masks_video,
+            inputs=[first_frame, mask_list_video, mask_raw_list_video],
+            outputs=[mask_output_video, first_frame, mask_list_video, mask_raw_list_video]
+        )
+
+        clear_masks_btn.click(
+            fn=clear_masks,
+            outputs=[mask_output, mask_list, mask_raw_list]
+        )
+
+        clear_masks_btn_video.click(
+            fn=clear_masks,
+            outputs=[mask_output_video, mask_list_video, mask_raw_list_video]
+        )
+
+        submit_btn_video.click(
+            fn=describe_video,
+            inputs=[video_input, mode_video, query_video, first_frame, mask_raw_list_video, mask_list_video],
+            outputs=[first_frame, description_video, mask_output_video, mask_list_video],
+            api_name="describe_video"
+        )
+
+        submit_btn_video1.click(
+            fn=describe_video,
+            inputs=[video_input, mode_video, query_video, first_frame, mask_raw_list_video, mask_list_video],
+            outputs=[first_frame, description_video, mask_output_video, mask_list_video],
+            api_name="describe_video"
+        )
+
+
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
+    sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
+
+    disable_torch_init()
+
+
+    model, processor, tokenizer = model_init(args_cli.model_path)
+
+
+    demo.launch(mcp_server=True)
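+    # Note: describe, describe_video and apply_sam read these names (device,
+    # sam_model, sam_processor, model, tokenizer) as module-level globals, so
+    # they must be initialized before demo.launch() serves the first request.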