lixin4ever CircleRadon committed on
Commit
7dede7c
·
verified ·
1 Parent(s): 69a1bb9

Update app.py (#6)

Browse files

- Update app.py (60ccb1ef6748c9246dc32f76f0cd54040003ebde)


Co-authored-by: YuqianYuan <[email protected]>

Files changed (1) hide show
  1. app.py +36 -62
app.py CHANGED
@@ -1,10 +1,10 @@
1
  import spaces
2
  import gradio as gr
3
  import numpy as np
4
- import os
5
  import torch
6
  from transformers import SamModel, SamProcessor
7
  from PIL import Image
 
8
  import cv2
9
  import argparse
10
  import sys
@@ -25,11 +25,6 @@ color_rgbs = [
25
  (1.0, 0.0, 1.0),
26
  ]
27
 
28
- mask_list = []
29
- mask_raw_list = []
30
- mask_list_video = []
31
- mask_raw_list_video = []
32
-
33
  def extract_first_frame_from_video(video):
34
  cap = cv2.VideoCapture(video)
35
  success, frame = cap.read()
@@ -55,9 +50,7 @@ def add_contour(img, mask, color=(1., 1., 1.)):
55
  return img
56
 
57
  @spaces.GPU(duration=120)
58
- def generate_masks(image):
59
- global mask_list
60
- global mask_raw_list
61
  image['image'] = image['background'].convert('RGB')
62
  # del image['background'], image['composite']
63
  assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"
@@ -80,13 +73,10 @@ def generate_masks(image):
80
  # Return a list containing the mask image.
81
  image['layers'] = []
82
  image['composite'] = image['background']
83
- return mask_list, image
84
-
85
 
86
  @spaces.GPU(duration=120)
87
- def generate_masks_video(image):
88
- global mask_list_video
89
- global mask_raw_list_video
90
  image['image'] = image['background'].convert('RGB')
91
  # del image['background'], image['composite']
92
  assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"
@@ -109,7 +99,7 @@ def generate_masks_video(image):
109
  # Return a list containing the mask image.
110
  image['layers'] = []
111
  image['composite'] = image['background']
112
- return mask_list_video, image
113
 
114
 
115
  @spaces.GPU(duration=120)
@@ -152,13 +142,13 @@ def describe(image, mode, query, masks):
152
  img_with_contour_np = add_contour(img_np, mask_np, color=color_rgb)
153
  img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
154
  else:
155
- masks = mask_raw_list
156
  img_with_contour_np = img_np.copy()
157
 
158
  mask_ids = []
159
  for i, mask_np in enumerate(masks):
160
- img_with_contour_np = add_contour(img_with_contour_np, mask_np, color=color_rgbs[i])
161
- img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
 
162
  mask_ids.append(0)
163
 
164
  masks = np.stack(masks, axis=0)
@@ -214,8 +204,7 @@ def load_first_frame(video_path):
214
  return image
215
 
216
  @spaces.GPU(duration=120)
217
- def describe_video(video_path, mode, query, annotated_frame, masks):
218
- global mask_list_video
219
  # Create a temporary directory to save extracted video frames
220
  cap = cv2.VideoCapture(video_path)
221
 
@@ -267,7 +256,6 @@ def describe_video(video_path, mode, query, annotated_frame, masks):
267
 
268
 
269
  else:
270
- masks = mask_raw_list_video
271
  img_with_contour_np = img_np.copy()
272
 
273
  mask_ids = []
@@ -306,7 +294,7 @@ def describe_video(video_path, mode, query, annotated_frame, masks):
306
  mask_image = Image.fromarray((mask_np[:,:,np.newaxis] * np.array(annotated_frame['image'])).astype(np.uint8))
307
  mask_list_video.append((mask_image, f"<object{len(mask_list_video)}>"))
308
  text = ""
309
- yield frame_img, text, mask_list_video
310
 
311
  for token in get_model_output(
312
  video_tensor,
@@ -319,7 +307,7 @@ def describe_video(video_path, mode, query, annotated_frame, masks):
319
  streaming=True,
320
  ):
321
  text += token
322
- yield gr.update(), text, gr.update()
323
 
324
 
325
  @spaces.GPU(duration=120)
@@ -338,20 +326,9 @@ def apply_sam(image, input_points):
338
 
339
  return mask_np
340
 
341
- def clear_masks():
342
- global mask_list
343
- global mask_raw_list
344
- mask_list = []
345
- mask_raw_list = []
346
- return []
347
-
348
 
349
- def clear_masks_video():
350
- global mask_list_video
351
- global mask_raw_list_video
352
- mask_list_video = []
353
- mask_raw_list_video = []
354
- return []
355
 
356
 
357
  if __name__ == "__main__":
@@ -363,10 +340,15 @@ if __name__ == "__main__":
363
  parser.add_argument("--top_p", type=float, default=0.5, help="Top-p for sampling")
364
 
365
  args_cli = parser.parse_args()
366
- print(args_cli.model_path)
367
 
368
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="amber")) as demo:
369
 
 
 
 
 
 
 
370
  HEADER = ("""
371
  <div>
372
  <h1>VideoRefer X VideoLLaMA3 Demo</h1>
@@ -479,75 +461,67 @@ if __name__ == "__main__":
479
  def toggle_query_and_generate_button(mode):
480
  query_visible = mode == "QA"
481
  caption_visible = mode == "Caption"
482
- global mask_list
483
- global mask_raw_list
484
- mask_list = []
485
- mask_raw_list = []
486
- return gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=caption_visible), gr.update(visible=caption_visible), [], ""
487
 
488
  video_input.change(load_first_frame, inputs=video_input, outputs=first_frame)
489
 
490
- mode.change(toggle_query_and_generate_button, inputs=mode, outputs=[query, generate_mask_btn, clear_masks_btn, submit_btn1, mask_output, output_image, submit_btn, mask_output, description])
491
 
492
  def toggle_query_and_generate_button_video(mode):
493
  query_visible = mode == "QA"
494
  caption_visible = mode == "Caption"
495
- global mask_list_video
496
- global mask_raw_list_video
497
- mask_list_video = []
498
- mask_raw_list_video = []
499
- return gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=caption_visible), []
500
 
501
 
502
- mode_video.change(toggle_query_and_generate_button_video, inputs=mode_video, outputs=[query_video, generate_mask_btn_video, submit_btn_video1, submit_btn_video, mask_output_video])
503
 
504
  submit_btn.click(
505
  fn=describe,
506
- inputs=[image_input, mode, query],
507
  outputs=[output_image, description, image_input],
508
  api_name="describe"
509
  )
510
 
511
  submit_btn1.click(
512
  fn=describe,
513
- inputs=[image_input, mode, query],
514
  outputs=[output_image, description, image_input],
515
  api_name="describe"
516
  )
517
 
518
  generate_mask_btn.click(
519
  fn=generate_masks,
520
- inputs=[image_input],
521
- outputs=[mask_output, image_input]
522
  )
523
 
524
  generate_mask_btn_video.click(
525
  fn=generate_masks_video,
526
- inputs=[first_frame],
527
- outputs=[mask_output_video, first_frame]
528
  )
529
 
530
  clear_masks_btn.click(
531
  fn=clear_masks,
532
- outputs=[mask_output]
533
  )
534
 
535
  clear_masks_btn_video.click(
536
- fn=clear_masks_video,
537
- outputs=[mask_output_video]
538
  )
539
 
540
  submit_btn_video.click(
541
  fn=describe_video,
542
- inputs=[video_input, mode_video, query_video, first_frame],
543
- outputs=[first_frame, description_video, mask_output_video],
544
  api_name="describe_video"
545
  )
546
 
547
  submit_btn_video1.click(
548
  fn=describe_video,
549
- inputs=[video_input, mode_video, query_video, first_frame],
550
- outputs=[first_frame, description_video, mask_output_video],
551
  api_name="describe_video"
552
  )
553
 
 
1
  import spaces
2
  import gradio as gr
3
  import numpy as np
 
4
  import torch
5
  from transformers import SamModel, SamProcessor
6
  from PIL import Image
7
+ import os
8
  import cv2
9
  import argparse
10
  import sys
 
25
  (1.0, 0.0, 1.0),
26
  ]
27
 
 
 
 
 
 
28
  def extract_first_frame_from_video(video):
29
  cap = cv2.VideoCapture(video)
30
  success, frame = cap.read()
 
50
  return img
51
 
52
  @spaces.GPU(duration=120)
53
+ def generate_masks(image, mask_list, mask_raw_list):
 
 
54
  image['image'] = image['background'].convert('RGB')
55
  # del image['background'], image['composite']
56
  assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"
 
73
  # Return a list containing the mask image.
74
  image['layers'] = []
75
  image['composite'] = image['background']
76
+ return mask_list, image, mask_list, mask_raw_list
 
77
 
78
  @spaces.GPU(duration=120)
79
+ def generate_masks_video(image, mask_list_video, mask_raw_list_video):
 
 
80
  image['image'] = image['background'].convert('RGB')
81
  # del image['background'], image['composite']
82
  assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"
 
99
  # Return a list containing the mask image.
100
  image['layers'] = []
101
  image['composite'] = image['background']
102
+ return mask_list_video, image, mask_list_video, mask_raw_list_video
103
 
104
 
105
  @spaces.GPU(duration=120)
 
142
  img_with_contour_np = add_contour(img_np, mask_np, color=color_rgb)
143
  img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
144
  else:
 
145
  img_with_contour_np = img_np.copy()
146
 
147
  mask_ids = []
148
  for i, mask_np in enumerate(masks):
149
+ # img_with_contour_np = add_contour(img_with_contour_np, mask_np, color=color_rgbs[i])
150
+ # img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
151
+ img_with_contour_pil = Image.fromarray((img_with_contour_np* 255.).astype(np.uint8))
152
  mask_ids.append(0)
153
 
154
  masks = np.stack(masks, axis=0)
 
204
  return image
205
 
206
  @spaces.GPU(duration=120)
207
+ def describe_video(video_path, mode, query, annotated_frame, masks, mask_list_video):
 
208
  # Create a temporary directory to save extracted video frames
209
  cap = cv2.VideoCapture(video_path)
210
 
 
256
 
257
 
258
  else:
 
259
  img_with_contour_np = img_np.copy()
260
 
261
  mask_ids = []
 
294
  mask_image = Image.fromarray((mask_np[:,:,np.newaxis] * np.array(annotated_frame['image'])).astype(np.uint8))
295
  mask_list_video.append((mask_image, f"<object{len(mask_list_video)}>"))
296
  text = ""
297
+ yield frame_img, text, mask_list_video, mask_list_video
298
 
299
  for token in get_model_output(
300
  video_tensor,
 
307
  streaming=True,
308
  ):
309
  text += token
310
+ yield gr.update(), text, gr.update(), gr.update()
311
 
312
 
313
  @spaces.GPU(duration=120)
 
326
 
327
  return mask_np
328
 
 
 
 
 
 
 
 
329
 
330
+ def clear_masks():
331
+ return [], [], []
 
 
 
 
332
 
333
 
334
  if __name__ == "__main__":
 
340
  parser.add_argument("--top_p", type=float, default=0.5, help="Top-p for sampling")
341
 
342
  args_cli = parser.parse_args()
 
343
 
344
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="amber")) as demo:
345
 
346
+ mask_list = gr.State([])
347
+ mask_raw_list = gr.State([])
348
+ mask_list_video = gr.State([])
349
+ mask_raw_list_video = gr.State([])
350
+
351
+
352
  HEADER = ("""
353
  <div>
354
  <h1>VideoRefer X VideoLLaMA3 Demo</h1>
 
461
  def toggle_query_and_generate_button(mode):
462
  query_visible = mode == "QA"
463
  caption_visible = mode == "Caption"
464
+ return gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=caption_visible), gr.update(visible=caption_visible), [], "", [], [],[],[]
 
 
 
 
465
 
466
  video_input.change(load_first_frame, inputs=video_input, outputs=first_frame)
467
 
468
+ mode.change(toggle_query_and_generate_button, inputs=mode, outputs=[query, generate_mask_btn, clear_masks_btn, submit_btn1, mask_output, output_image, submit_btn, mask_output, description, mask_list, mask_raw_list, mask_list_video, mask_raw_list_video])
469
 
470
  def toggle_query_and_generate_button_video(mode):
471
  query_visible = mode == "QA"
472
  caption_visible = mode == "Caption"
473
+ return gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=caption_visible), [], [], [], [], []
 
 
 
 
474
 
475
 
476
+ mode_video.change(toggle_query_and_generate_button_video, inputs=mode_video, outputs=[query_video, generate_mask_btn_video, submit_btn_video1, submit_btn_video, mask_output_video, mask_list, mask_raw_list, mask_list_video, mask_raw_list_video])
477
 
478
  submit_btn.click(
479
  fn=describe,
480
+ inputs=[image_input, mode, query, mask_raw_list],
481
  outputs=[output_image, description, image_input],
482
  api_name="describe"
483
  )
484
 
485
  submit_btn1.click(
486
  fn=describe,
487
+ inputs=[image_input, mode, query, mask_raw_list],
488
  outputs=[output_image, description, image_input],
489
  api_name="describe"
490
  )
491
 
492
  generate_mask_btn.click(
493
  fn=generate_masks,
494
+ inputs=[image_input, mask_list, mask_raw_list],
495
+ outputs=[mask_output, image_input, mask_list, mask_raw_list]
496
  )
497
 
498
  generate_mask_btn_video.click(
499
  fn=generate_masks_video,
500
+ inputs=[first_frame, mask_list_video, mask_raw_list_video],
501
+ outputs=[mask_output_video, first_frame, mask_list_video, mask_raw_list_video]
502
  )
503
 
504
  clear_masks_btn.click(
505
  fn=clear_masks,
506
+ outputs=[mask_output, mask_list, mask_raw_list]
507
  )
508
 
509
  clear_masks_btn_video.click(
510
+ fn=clear_masks,
511
+ outputs=[mask_output_video, mask_list_video, mask_raw_list_video]
512
  )
513
 
514
  submit_btn_video.click(
515
  fn=describe_video,
516
+ inputs=[video_input, mode_video, query_video, first_frame, mask_raw_list_video, mask_list_video],
517
+ outputs=[first_frame, description_video, mask_output_video, mask_list_video],
518
  api_name="describe_video"
519
  )
520
 
521
  submit_btn_video1.click(
522
  fn=describe_video,
523
+ inputs=[video_input, mode_video, query_video, first_frame, mask_raw_list_video, mask_list_video],
524
+ outputs=[first_frame, description_video, mask_output_video, mask_list_video],
525
  api_name="describe_video"
526
  )
527