Spaces:
Running
on
Zero
Running
on
Zero
Update app.py (#6)
Browse files- Update app.py (60ccb1ef6748c9246dc32f76f0cd54040003ebde)
Co-authored-by: YuqianYuan <[email protected]>
app.py
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
import spaces
|
2 |
import gradio as gr
|
3 |
import numpy as np
|
4 |
-
import os
|
5 |
import torch
|
6 |
from transformers import SamModel, SamProcessor
|
7 |
from PIL import Image
|
|
|
8 |
import cv2
|
9 |
import argparse
|
10 |
import sys
|
@@ -25,11 +25,6 @@ color_rgbs = [
|
|
25 |
(1.0, 0.0, 1.0),
|
26 |
]
|
27 |
|
28 |
-
mask_list = []
|
29 |
-
mask_raw_list = []
|
30 |
-
mask_list_video = []
|
31 |
-
mask_raw_list_video = []
|
32 |
-
|
33 |
def extract_first_frame_from_video(video):
|
34 |
cap = cv2.VideoCapture(video)
|
35 |
success, frame = cap.read()
|
@@ -55,9 +50,7 @@ def add_contour(img, mask, color=(1., 1., 1.)):
|
|
55 |
return img
|
56 |
|
57 |
@spaces.GPU(duration=120)
|
58 |
-
def generate_masks(image):
|
59 |
-
global mask_list
|
60 |
-
global mask_raw_list
|
61 |
image['image'] = image['background'].convert('RGB')
|
62 |
# del image['background'], image['composite']
|
63 |
assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"
|
@@ -80,13 +73,10 @@ def generate_masks(image):
|
|
80 |
# Return a list containing the mask image.
|
81 |
image['layers'] = []
|
82 |
image['composite'] = image['background']
|
83 |
-
return mask_list, image
|
84 |
-
|
85 |
|
86 |
@spaces.GPU(duration=120)
|
87 |
-
def generate_masks_video(image):
|
88 |
-
global mask_list_video
|
89 |
-
global mask_raw_list_video
|
90 |
image['image'] = image['background'].convert('RGB')
|
91 |
# del image['background'], image['composite']
|
92 |
assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"
|
@@ -109,7 +99,7 @@ def generate_masks_video(image):
|
|
109 |
# Return a list containing the mask image.
|
110 |
image['layers'] = []
|
111 |
image['composite'] = image['background']
|
112 |
-
return mask_list_video, image
|
113 |
|
114 |
|
115 |
@spaces.GPU(duration=120)
|
@@ -152,13 +142,13 @@ def describe(image, mode, query, masks):
|
|
152 |
img_with_contour_np = add_contour(img_np, mask_np, color=color_rgb)
|
153 |
img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
|
154 |
else:
|
155 |
-
masks = mask_raw_list
|
156 |
img_with_contour_np = img_np.copy()
|
157 |
|
158 |
mask_ids = []
|
159 |
for i, mask_np in enumerate(masks):
|
160 |
-
img_with_contour_np = add_contour(img_with_contour_np, mask_np, color=color_rgbs[i])
|
161 |
-
img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
|
|
|
162 |
mask_ids.append(0)
|
163 |
|
164 |
masks = np.stack(masks, axis=0)
|
@@ -214,8 +204,7 @@ def load_first_frame(video_path):
|
|
214 |
return image
|
215 |
|
216 |
@spaces.GPU(duration=120)
|
217 |
-
def describe_video(video_path, mode, query, annotated_frame, masks):
|
218 |
-
global mask_list_video
|
219 |
# Create a temporary directory to save extracted video frames
|
220 |
cap = cv2.VideoCapture(video_path)
|
221 |
|
@@ -267,7 +256,6 @@ def describe_video(video_path, mode, query, annotated_frame, masks):
|
|
267 |
|
268 |
|
269 |
else:
|
270 |
-
masks = mask_raw_list_video
|
271 |
img_with_contour_np = img_np.copy()
|
272 |
|
273 |
mask_ids = []
|
@@ -306,7 +294,7 @@ def describe_video(video_path, mode, query, annotated_frame, masks):
|
|
306 |
mask_image = Image.fromarray((mask_np[:,:,np.newaxis] * np.array(annotated_frame['image'])).astype(np.uint8))
|
307 |
mask_list_video.append((mask_image, f"<object{len(mask_list_video)}>"))
|
308 |
text = ""
|
309 |
-
yield frame_img, text, mask_list_video
|
310 |
|
311 |
for token in get_model_output(
|
312 |
video_tensor,
|
@@ -319,7 +307,7 @@ def describe_video(video_path, mode, query, annotated_frame, masks):
|
|
319 |
streaming=True,
|
320 |
):
|
321 |
text += token
|
322 |
-
yield gr.update(), text, gr.update()
|
323 |
|
324 |
|
325 |
@spaces.GPU(duration=120)
|
@@ -338,20 +326,9 @@ def apply_sam(image, input_points):
|
|
338 |
|
339 |
return mask_np
|
340 |
|
341 |
-
def clear_masks():
|
342 |
-
global mask_list
|
343 |
-
global mask_raw_list
|
344 |
-
mask_list = []
|
345 |
-
mask_raw_list = []
|
346 |
-
return []
|
347 |
-
|
348 |
|
349 |
-
def
|
350 |
-
|
351 |
-
global mask_raw_list_video
|
352 |
-
mask_list_video = []
|
353 |
-
mask_raw_list_video = []
|
354 |
-
return []
|
355 |
|
356 |
|
357 |
if __name__ == "__main__":
|
@@ -363,10 +340,15 @@ if __name__ == "__main__":
|
|
363 |
parser.add_argument("--top_p", type=float, default=0.5, help="Top-p for sampling")
|
364 |
|
365 |
args_cli = parser.parse_args()
|
366 |
-
print(args_cli.model_path)
|
367 |
|
368 |
with gr.Blocks(theme=gr.themes.Soft(primary_hue="amber")) as demo:
|
369 |
|
|
|
|
|
|
|
|
|
|
|
|
|
370 |
HEADER = ("""
|
371 |
<div>
|
372 |
<h1>VideoRefer X VideoLLaMA3 Demo</h1>
|
@@ -479,75 +461,67 @@ if __name__ == "__main__":
|
|
479 |
def toggle_query_and_generate_button(mode):
|
480 |
query_visible = mode == "QA"
|
481 |
caption_visible = mode == "Caption"
|
482 |
-
|
483 |
-
global mask_raw_list
|
484 |
-
mask_list = []
|
485 |
-
mask_raw_list = []
|
486 |
-
return gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=caption_visible), gr.update(visible=caption_visible), [], ""
|
487 |
|
488 |
video_input.change(load_first_frame, inputs=video_input, outputs=first_frame)
|
489 |
|
490 |
-
mode.change(toggle_query_and_generate_button, inputs=mode, outputs=[query, generate_mask_btn, clear_masks_btn, submit_btn1, mask_output, output_image, submit_btn, mask_output, description])
|
491 |
|
492 |
def toggle_query_and_generate_button_video(mode):
|
493 |
query_visible = mode == "QA"
|
494 |
caption_visible = mode == "Caption"
|
495 |
-
|
496 |
-
global mask_raw_list_video
|
497 |
-
mask_list_video = []
|
498 |
-
mask_raw_list_video = []
|
499 |
-
return gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=caption_visible), []
|
500 |
|
501 |
|
502 |
-
mode_video.change(toggle_query_and_generate_button_video, inputs=mode_video, outputs=[query_video, generate_mask_btn_video, submit_btn_video1, submit_btn_video, mask_output_video])
|
503 |
|
504 |
submit_btn.click(
|
505 |
fn=describe,
|
506 |
-
inputs=[image_input, mode, query],
|
507 |
outputs=[output_image, description, image_input],
|
508 |
api_name="describe"
|
509 |
)
|
510 |
|
511 |
submit_btn1.click(
|
512 |
fn=describe,
|
513 |
-
inputs=[image_input, mode, query],
|
514 |
outputs=[output_image, description, image_input],
|
515 |
api_name="describe"
|
516 |
)
|
517 |
|
518 |
generate_mask_btn.click(
|
519 |
fn=generate_masks,
|
520 |
-
inputs=[image_input],
|
521 |
-
outputs=[mask_output, image_input]
|
522 |
)
|
523 |
|
524 |
generate_mask_btn_video.click(
|
525 |
fn=generate_masks_video,
|
526 |
-
inputs=[first_frame],
|
527 |
-
outputs=[mask_output_video, first_frame]
|
528 |
)
|
529 |
|
530 |
clear_masks_btn.click(
|
531 |
fn=clear_masks,
|
532 |
-
outputs=[mask_output]
|
533 |
)
|
534 |
|
535 |
clear_masks_btn_video.click(
|
536 |
-
fn=
|
537 |
-
outputs=[mask_output_video]
|
538 |
)
|
539 |
|
540 |
submit_btn_video.click(
|
541 |
fn=describe_video,
|
542 |
-
inputs=[video_input, mode_video, query_video, first_frame],
|
543 |
-
outputs=[first_frame, description_video, mask_output_video],
|
544 |
api_name="describe_video"
|
545 |
)
|
546 |
|
547 |
submit_btn_video1.click(
|
548 |
fn=describe_video,
|
549 |
-
inputs=[video_input, mode_video, query_video, first_frame],
|
550 |
-
outputs=[first_frame, description_video, mask_output_video],
|
551 |
api_name="describe_video"
|
552 |
)
|
553 |
|
|
|
1 |
import spaces
|
2 |
import gradio as gr
|
3 |
import numpy as np
|
|
|
4 |
import torch
|
5 |
from transformers import SamModel, SamProcessor
|
6 |
from PIL import Image
|
7 |
+
import os
|
8 |
import cv2
|
9 |
import argparse
|
10 |
import sys
|
|
|
25 |
(1.0, 0.0, 1.0),
|
26 |
]
|
27 |
|
|
|
|
|
|
|
|
|
|
|
28 |
def extract_first_frame_from_video(video):
|
29 |
cap = cv2.VideoCapture(video)
|
30 |
success, frame = cap.read()
|
|
|
50 |
return img
|
51 |
|
52 |
@spaces.GPU(duration=120)
|
53 |
+
def generate_masks(image, mask_list, mask_raw_list):
|
|
|
|
|
54 |
image['image'] = image['background'].convert('RGB')
|
55 |
# del image['background'], image['composite']
|
56 |
assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"
|
|
|
73 |
# Return a list containing the mask image.
|
74 |
image['layers'] = []
|
75 |
image['composite'] = image['background']
|
76 |
+
return mask_list, image, mask_list, mask_raw_list
|
|
|
77 |
|
78 |
@spaces.GPU(duration=120)
|
79 |
+
def generate_masks_video(image, mask_list_video, mask_raw_list_video):
|
|
|
|
|
80 |
image['image'] = image['background'].convert('RGB')
|
81 |
# del image['background'], image['composite']
|
82 |
assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"
|
|
|
99 |
# Return a list containing the mask image.
|
100 |
image['layers'] = []
|
101 |
image['composite'] = image['background']
|
102 |
+
return mask_list_video, image, mask_list_video, mask_raw_list_video
|
103 |
|
104 |
|
105 |
@spaces.GPU(duration=120)
|
|
|
142 |
img_with_contour_np = add_contour(img_np, mask_np, color=color_rgb)
|
143 |
img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
|
144 |
else:
|
|
|
145 |
img_with_contour_np = img_np.copy()
|
146 |
|
147 |
mask_ids = []
|
148 |
for i, mask_np in enumerate(masks):
|
149 |
+
# img_with_contour_np = add_contour(img_with_contour_np, mask_np, color=color_rgbs[i])
|
150 |
+
# img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
|
151 |
+
img_with_contour_pil = Image.fromarray((img_with_contour_np* 255.).astype(np.uint8))
|
152 |
mask_ids.append(0)
|
153 |
|
154 |
masks = np.stack(masks, axis=0)
|
|
|
204 |
return image
|
205 |
|
206 |
@spaces.GPU(duration=120)
|
207 |
+
def describe_video(video_path, mode, query, annotated_frame, masks, mask_list_video):
|
|
|
208 |
# Create a temporary directory to save extracted video frames
|
209 |
cap = cv2.VideoCapture(video_path)
|
210 |
|
|
|
256 |
|
257 |
|
258 |
else:
|
|
|
259 |
img_with_contour_np = img_np.copy()
|
260 |
|
261 |
mask_ids = []
|
|
|
294 |
mask_image = Image.fromarray((mask_np[:,:,np.newaxis] * np.array(annotated_frame['image'])).astype(np.uint8))
|
295 |
mask_list_video.append((mask_image, f"<object{len(mask_list_video)}>"))
|
296 |
text = ""
|
297 |
+
yield frame_img, text, mask_list_video, mask_list_video
|
298 |
|
299 |
for token in get_model_output(
|
300 |
video_tensor,
|
|
|
307 |
streaming=True,
|
308 |
):
|
309 |
text += token
|
310 |
+
yield gr.update(), text, gr.update(), gr.update()
|
311 |
|
312 |
|
313 |
@spaces.GPU(duration=120)
|
|
|
326 |
|
327 |
return mask_np
|
328 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
329 |
|
330 |
+
def clear_masks():
|
331 |
+
return [], [], []
|
|
|
|
|
|
|
|
|
332 |
|
333 |
|
334 |
if __name__ == "__main__":
|
|
|
340 |
parser.add_argument("--top_p", type=float, default=0.5, help="Top-p for sampling")
|
341 |
|
342 |
args_cli = parser.parse_args()
|
|
|
343 |
|
344 |
with gr.Blocks(theme=gr.themes.Soft(primary_hue="amber")) as demo:
|
345 |
|
346 |
+
mask_list = gr.State([])
|
347 |
+
mask_raw_list = gr.State([])
|
348 |
+
mask_list_video = gr.State([])
|
349 |
+
mask_raw_list_video = gr.State([])
|
350 |
+
|
351 |
+
|
352 |
HEADER = ("""
|
353 |
<div>
|
354 |
<h1>VideoRefer X VideoLLaMA3 Demo</h1>
|
|
|
461 |
def toggle_query_and_generate_button(mode):
|
462 |
query_visible = mode == "QA"
|
463 |
caption_visible = mode == "Caption"
|
464 |
+
return gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=caption_visible), gr.update(visible=caption_visible), [], "", [], [],[],[]
|
|
|
|
|
|
|
|
|
465 |
|
466 |
video_input.change(load_first_frame, inputs=video_input, outputs=first_frame)
|
467 |
|
468 |
+
mode.change(toggle_query_and_generate_button, inputs=mode, outputs=[query, generate_mask_btn, clear_masks_btn, submit_btn1, mask_output, output_image, submit_btn, mask_output, description, mask_list, mask_raw_list, mask_list_video, mask_raw_list_video])
|
469 |
|
470 |
def toggle_query_and_generate_button_video(mode):
|
471 |
query_visible = mode == "QA"
|
472 |
caption_visible = mode == "Caption"
|
473 |
+
return gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=caption_visible), [], [], [], [], []
|
|
|
|
|
|
|
|
|
474 |
|
475 |
|
476 |
+
mode_video.change(toggle_query_and_generate_button_video, inputs=mode_video, outputs=[query_video, generate_mask_btn_video, submit_btn_video1, submit_btn_video, mask_output_video, mask_list, mask_raw_list, mask_list_video, mask_raw_list_video])
|
477 |
|
478 |
submit_btn.click(
|
479 |
fn=describe,
|
480 |
+
inputs=[image_input, mode, query, mask_raw_list],
|
481 |
outputs=[output_image, description, image_input],
|
482 |
api_name="describe"
|
483 |
)
|
484 |
|
485 |
submit_btn1.click(
|
486 |
fn=describe,
|
487 |
+
inputs=[image_input, mode, query, mask_raw_list],
|
488 |
outputs=[output_image, description, image_input],
|
489 |
api_name="describe"
|
490 |
)
|
491 |
|
492 |
generate_mask_btn.click(
|
493 |
fn=generate_masks,
|
494 |
+
inputs=[image_input, mask_list, mask_raw_list],
|
495 |
+
outputs=[mask_output, image_input, mask_list, mask_raw_list]
|
496 |
)
|
497 |
|
498 |
generate_mask_btn_video.click(
|
499 |
fn=generate_masks_video,
|
500 |
+
inputs=[first_frame, mask_list_video, mask_raw_list_video],
|
501 |
+
outputs=[mask_output_video, first_frame, mask_list_video, mask_raw_list_video]
|
502 |
)
|
503 |
|
504 |
clear_masks_btn.click(
|
505 |
fn=clear_masks,
|
506 |
+
outputs=[mask_output, mask_list, mask_raw_list]
|
507 |
)
|
508 |
|
509 |
clear_masks_btn_video.click(
|
510 |
+
fn=clear_masks,
|
511 |
+
outputs=[mask_output_video, mask_list_video, mask_raw_list_video]
|
512 |
)
|
513 |
|
514 |
submit_btn_video.click(
|
515 |
fn=describe_video,
|
516 |
+
inputs=[video_input, mode_video, query_video, first_frame, mask_raw_list_video, mask_list_video],
|
517 |
+
outputs=[first_frame, description_video, mask_output_video, mask_list_video],
|
518 |
api_name="describe_video"
|
519 |
)
|
520 |
|
521 |
submit_btn_video1.click(
|
522 |
fn=describe_video,
|
523 |
+
inputs=[video_input, mode_video, query_video, first_frame, mask_raw_list_video, mask_list_video],
|
524 |
+
outputs=[first_frame, description_video, mask_output_video, mask_list_video],
|
525 |
api_name="describe_video"
|
526 |
)
|
527 |
|