Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -52,15 +52,20 @@ def add_contour(img, mask, color=(1., 1., 1.)):
 @spaces.GPU(duration=120)
 def generate_masks(image, mask_list, mask_raw_list):
     """
-
-
+    Generates segmentation masks for selected regions in an image using SAM.
+
     Args:
-        image:
-
-
-
+        image (dict): A dictionary containing image data, typically from a Gradio ImageEditor,
+            with 'background' (PIL Image) and 'layers' (list of PIL Image layers).
+        mask_list (list): A list to accumulate (mask_image, label) tuples for display in a gallery.
+        mask_raw_list (list): A list to accumulate raw NumPy mask arrays.
+
     Returns:
-
+        tuple: A tuple containing:
+            - mask_list (list): Updated list of mask images for display.
+            - image (dict): Updated image dictionary with layers cleared.
+            - mask_list (list): Redundant return of mask_list (for Gradio update).
+            - mask_raw_list (list): Updated list of raw mask arrays.
     """
     image['image'] = image['background'].convert('RGB')
     # del image['background'], image['composite']
@@ -89,15 +94,21 @@ def generate_masks(image, mask_list, mask_raw_list):
 @spaces.GPU(duration=120)
 def generate_masks_video(image, mask_list_video, mask_raw_list_video):
     """
-
-
+    Generates segmentation masks for selected regions in the first frame of a video using SAM.
+
     Args:
-        image:
-
-
-
+        image (dict): A dictionary containing image data (first frame of video),
+            typically from a Gradio ImageEditor, with 'background' (PIL Image)
+            and 'layers' (list of PIL Image layers).
+        mask_list_video (list): A list to accumulate (mask_image, label) tuples for display.
+        mask_raw_list_video (list): A list to accumulate raw NumPy mask arrays for video processing.
+
     Returns:
-
+        tuple: A tuple containing:
+            - mask_list_video (list): Updated list of mask images for display.
+            - image (dict): Updated image dictionary with layers cleared.
+            - mask_list_video (list): Redundant return of mask_list_video (for Gradio update).
+            - mask_raw_list_video (list): Updated list of raw mask arrays.
     """
     image['image'] = image['background'].convert('RGB')
     # del image['background'], image['composite']
@@ -127,16 +138,19 @@ def generate_masks_video(image, mask_list_video, mask_raw_list_video):
 @spaces.GPU(duration=120)
 def describe(image, mode, query, masks):
     """
-
-
+    Describes an image based on selected regions or answers a question about them.
+
     Args:
-        image:
-
-
-
-
-
-
+        image (dict): A dictionary containing image data, typically from a Gradio ImageEditor,
+            with 'background' (PIL Image) and 'layers' (list of PIL Image layers).
+        mode (str): The operational mode, either "Caption" (to describe a selected region)
+            or "QA" (to answer a question about one or more regions).
+        query (str): The question to ask in "QA" mode. Ignored in "Caption" mode.
+        masks (list): A list of raw NumPy mask arrays representing previously generated masks.
+
+    Yields:
+        tuple: An image with contours and the generated text description/answer,
+            or updates for Gradio components during streaming.
     """
     # Create an image object from the uploaded image
     # print(image.keys())
@@ -229,13 +243,16 @@ def describe(image, mode, query, masks):
 
 def load_first_frame(video_path):
     """
-
-
+    Loads the first frame of a given video file.
+
     Args:
-        video_path:
-
+        video_path (str): The file path to the video.
+
     Returns:
-        PIL
+        PIL.Image.Image: The first frame of the video as a PIL Image.
+
+    Raises:
+        gr.Error: If the video file cannot be read.
     """
     cap = cv2.VideoCapture(video_path)
     ret, frame = cap.read()
@@ -249,18 +266,23 @@ def load_first_frame(video_path):
 @spaces.GPU(duration=120)
 def describe_video(video_path, mode, query, annotated_frame, masks, mask_list_video):
     """
-
-
+    Describes a video based on selected regions in its first frame or answers a question about them.
+
     Args:
-        video_path:
-        mode:
-
-
-
-
-
-
-
+        video_path (str): The file path to the video.
+        mode (str): The operational mode, either "Caption" (to describe a selected region)
+            or "QA" (to answer a question about one or more regions).
+        query (str): The question to ask in "QA" mode. Ignored in "Caption" mode.
+        annotated_frame (dict): A dictionary containing the first frame's image data
+            from a Gradio ImageEditor, with 'background' (PIL Image)
+            and 'layers' (list of PIL Image layers).
+        masks (list): A list of raw NumPy mask arrays representing previously generated masks
+            for objects in the video.
+        mask_list_video (list): A list to accumulate (mask_image, label) tuples for display.
+
+    Yields:
+        tuple: The annotated first frame, the generated text description/answer,
+            and updated mask lists for Gradio components during streaming.
     """
     # Create a temporary directory to save extracted video frames
     cap = cv2.VideoCapture(video_path)
@@ -351,4 +373,286 @@ def describe_video(video_path, mode, query, annotated_frame, masks, mask_list_video):
     mask_image = Image.fromarray((mask_np[:,:,np.newaxis] * np.array(annotated_frame['image'])).astype(np.uint8))
     mask_list_video.append((mask_image, f"<object{len(mask_list_video)}>"))
     text = ""
-    yield frame_img, text, mask_list_video, mask_list_video
+    yield frame_img, text, mask_list_video, mask_list_video
+
+    for token in get_model_output(
+        video_tensor,
+        query,
+        model=model,
+        tokenizer=tokenizer,
+        masks=masks,
+        mask_ids=mask_ids,
+        modal='video',
+        streaming=True,
+    ):
+        text += token
+        yield gr.update(), text, gr.update(), gr.update()
+
+
+@spaces.GPU(duration=120)
+def apply_sam(image, input_points):
+    """
+    Applies the Segment Anything Model (SAM) to an image based on input points
+    to generate a segmentation mask.
+
+    Args:
+        image (PIL.Image.Image): The input image.
+        input_points (list): A list of lists, where each inner list contains
+            [x, y] coordinates representing points used for segmentation.
+
+    Returns:
+        numpy.ndarray: The selected binary segmentation mask as a NumPy array (H, W).
+    """
+    inputs = sam_processor(image, input_points=input_points, return_tensors="pt").to(device)
+
+    with torch.no_grad():
+        outputs = sam_model(**inputs)
+
+    masks = sam_processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())[0][0]
+    scores = outputs.iou_scores[0, 0]
+
+    mask_selection_index = scores.argmax()
+
+    mask_np = masks[mask_selection_index].numpy()
+
+    return mask_np
+
+
+def clear_masks():
+    """
+    Clears the stored lists of masks and raw masks.
+
+    Returns:
+        tuple: Three empty lists, intended to reset Gradio components
+            displaying masks.
+    """
+    return [], [], []
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="VideoRefer Gradio demo")
+    parser.add_argument("--model-path", type=str, default="DAMO-NLP-SG/VideoRefer-VideoLLaMA3-7B", help="Path to the model checkpoint")
+    parser.add_argument("--prompt-mode", type=str, default="focal_prompt", help="Prompt mode")
+    parser.add_argument("--conv-mode", type=str, default="v1", help="Conversation mode")
+    parser.add_argument("--temperature", type=float, default=0.2, help="Sampling temperature")
+    parser.add_argument("--top_p", type=float, default=0.5, help="Top-p for sampling")
+
+    args_cli = parser.parse_args()
+
+    with gr.Blocks(theme=gr.themes.Soft(primary_hue="amber")) as demo:
+
+        mask_list = gr.State([])
+        mask_raw_list = gr.State([])
+        mask_list_video = gr.State([])
+        mask_raw_list_video = gr.State([])
+
+
+        HEADER = ("""
+        <div>
+            <h1>VideoRefer X VideoLLaMA3 Demo</h1>
+            <h5 style="margin: 0;">Feel free to click on anything that grabs your interest!</h5>
+            <h5 style="margin: 0;">If this demo pleases you, please give us a star ⭐ on GitHub or 💖 on this Space.</h5>
+        </div>
+        </div>
+        <div style="display: flex; justify-content: left; margin-top: 10px;">
+            <a href="https://arxiv.org/pdf/2501.00599"><img src="https://img.shields.io/badge/Arxiv-2501.00599-ECA8A7" style="margin-right: 5px;"></a>
+            <a href="https://github.com/DAMO-NLP-SG/VideoRefer"><img src='https://img.shields.io/badge/Github-VideoRefer-F7C97E' style="margin-right: 5px;"></a>
+            <a href="https://github.com/DAMO-NLP-SG/VideoLLaMA3"><img src='https://img.shields.io/badge/Github-VideoLLaMA3-9DC3E6' style="margin-right: 5px;"></a>
+        </div>
+        """)
+
+        with gr.Row():
+            with gr.Column():
+                gr.HTML(HEADER)
+
+
+        image_tips = """
+        ### 💡 Tips:
+
+        🧸 Upload an image, and you can use the drawing tool✍️ to highlight the areas you're interested in.
+
+        🔖 For single-object caption mode, simply select the area and click the 'Generate Caption' button to receive a caption for the object.
+
+        🔔 In QA mode, you can generate multiple masks by clicking the 'Generate Mask' button multiple times. Afterward, use the corresponding object ID to ask questions.
+
+        📌 Click the 'Clear Masks' button to clear the currently generated masks.
+
+        """
+
+        video_tips = """
+        ### 💡 Tips:
+        ⚠️ For video mode, we only support masking on the first frame in this demo.
+
+        🧸 Upload a video, and you can use the drawing tool✍️ to highlight the areas you're interested in on the first frame.
+
+        🔖 For single-object caption mode, simply select the area and click the 'Generate Caption' button to receive a caption for the object.
+
+        🔔 In QA mode, you can generate multiple masks by clicking the 'Generate Mask' button multiple times. Afterward, use the corresponding object ID to ask questions.
+
+        📌 Click the 'Clear Masks' button to clear the currently generated masks.
+
+        """
+
+
+        with gr.TabItem("Image"):
+            with gr.Row():
+                with gr.Column():
+                    image_input = gr.ImageEditor(
+                        label="Image",
+                        type="pil",
+                        sources=['upload'],
+                        brush=gr.Brush(colors=["#ED7D31"], color_mode="fixed", default_size=10),
+                        eraser=True,
+                        layers=False,
+                        transforms=[],
+                        height=300,
+                    )
+                    generate_mask_btn = gr.Button("1️⃣ Generate Mask", visible=False, variant="primary")
+                    mode = gr.Radio(label="Mode", choices=["Caption", "QA"], value="Caption")
+                    query = gr.Textbox(label="Question", value="What is the relationship between <region0> and <region1>?", interactive=True, visible=False)
+
+                    submit_btn = gr.Button("Generate Caption", variant="primary")
+                    submit_btn1 = gr.Button("2️⃣ Generate Answer", variant="primary", visible=False)
+                    gr.Examples([f"./demo/images/{i+1}.jpg" for i in range(8)], inputs=image_input, label="Examples")
+
+                with gr.Column():
+                    mask_output = gr.Gallery(label="Referred Masks", object_fit='scale-down', visible=False)
+                    output_image = gr.Image(label="Image with Mask", visible=True, height=400)
+                    description = gr.Textbox(label="Output", visible=True)
+
+                    clear_masks_btn = gr.Button("Clear Masks", variant="secondary", visible=False)
+                    gr.Markdown(image_tips)
+
+        with gr.TabItem("Video"):
+            with gr.Row():
+                with gr.Column():
+                    video_input = gr.Video(label="Video")
+                    # load_btn = gr.Button("🖼️ Load First Frame", variant="secondary")
+                    first_frame = gr.ImageEditor(
+                        label="Annotate First Frame",
+                        type="pil",
+                        sources=['upload'],
+                        brush=gr.Brush(colors=["#ED7D31"], color_mode="fixed", default_size=10),
+                        eraser=True,
+                        layers=False,
+                        transforms=[],
+                        height=300,
+                    )
+                    generate_mask_btn_video = gr.Button("1️⃣ Generate Mask", visible=False, variant="primary")
+                    gr.Examples([f"./demo/videos/{i+1}.mp4" for i in range(4)], inputs=video_input, label="Examples")
+
+                with gr.Column():
+                    mode_video = gr.Radio(label="Mode", choices=["Caption", "QA"], value="Caption")
+                    mask_output_video = gr.Gallery(label="Referred Masks", object_fit='scale-down')
+
+                    query_video = gr.Textbox(label="Question", value="What is the relationship between <object0> and <object1>?", interactive=True, visible=False)
+
+                    submit_btn_video = gr.Button("Generate Caption", variant="primary")
+                    submit_btn_video1 = gr.Button("2️⃣ Generate Answer", variant="primary", visible=False)
+                    description_video = gr.Textbox(label="Output", visible=True)
+
+                    clear_masks_btn_video = gr.Button("Clear Masks", variant="secondary")
+
+                    gr.Markdown(video_tips)
+
+
+        def toggle_query_and_generate_button(mode):
+            """
+            Toggles the visibility of query-related Gradio components based on the selected mode.
+            Also clears mask states.
+
+            Args:
+                mode (str): The selected mode ("Caption" or "QA").
+
+            Returns:
+                tuple: A series of gr.update() calls and empty lists to update Gradio components.
+            """
+            query_visible = mode == "QA"
+            caption_visible = mode == "Caption"
+            return gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=caption_visible), gr.update(visible=caption_visible), [], "", [], [], [], []
+
+        video_input.change(load_first_frame, inputs=video_input, outputs=first_frame)
+
+        mode.change(toggle_query_and_generate_button, inputs=mode, outputs=[query, generate_mask_btn, clear_masks_btn, submit_btn1, mask_output, output_image, submit_btn, mask_output, description, mask_list, mask_raw_list, mask_list_video, mask_raw_list_video])
+
+        def toggle_query_and_generate_button_video(mode):
+            """
+            Toggles the visibility of query-related Gradio components for video mode
+            based on the selected mode. Also clears mask states.
+
+            Args:
+                mode (str): The selected mode ("Caption" or "QA").
+
+            Returns:
+                tuple: A series of gr.update() calls and empty lists to update Gradio components.
+            """
+            query_visible = mode == "QA"
+            caption_visible = mode == "Caption"
+            return gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=caption_visible), [], [], [], [], []
+
+
+        mode_video.change(toggle_query_and_generate_button_video, inputs=mode_video, outputs=[query_video, generate_mask_btn_video, submit_btn_video1, submit_btn_video, mask_output_video, mask_list, mask_raw_list, mask_list_video, mask_raw_list_video])
+
+        submit_btn.click(
+            fn=describe,
+            inputs=[image_input, mode, query, mask_raw_list],
+            outputs=[output_image, description, image_input],
+            api_name="describe"
+        )
+
+        submit_btn1.click(
+            fn=describe,
+            inputs=[image_input, mode, query, mask_raw_list],
+            outputs=[output_image, description, image_input],
+            api_name="describe"
+        )
+
+        generate_mask_btn.click(
+            fn=generate_masks,
+            inputs=[image_input, mask_list, mask_raw_list],
+            outputs=[mask_output, image_input, mask_list, mask_raw_list]
+        )
+
+        generate_mask_btn_video.click(
+            fn=generate_masks_video,
+            inputs=[first_frame, mask_list_video, mask_raw_list_video],
+            outputs=[mask_output_video, first_frame, mask_list_video, mask_raw_list_video]
+        )
+
+        clear_masks_btn.click(
+            fn=clear_masks,
+            outputs=[mask_output, mask_list, mask_raw_list]
+        )
+
+        clear_masks_btn_video.click(
+            fn=clear_masks,
+            outputs=[mask_output_video, mask_list_video, mask_raw_list_video]
+        )
+
+        submit_btn_video.click(
+            fn=describe_video,
+            inputs=[video_input, mode_video, query_video, first_frame, mask_raw_list_video, mask_list_video],
+            outputs=[first_frame, description_video, mask_output_video, mask_list_video],
+            api_name="describe_video"
+        )
+
+        submit_btn_video1.click(
+            fn=describe_video,
+            inputs=[video_input, mode_video, query_video, first_frame, mask_raw_list_video, mask_list_video],
+            outputs=[first_frame, description_video, mask_output_video, mask_list_video],
+            api_name="describe_video"
+        )
+
+
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
+    sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
+
+    disable_torch_init()
+
+
+    model, processor, tokenizer = model_init(args_cli.model_path)
+
+
+    demo.launch(mcp_server=True)
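
Note on the SAM call this commit adds: the point-prompt flow that apply_sam wraps can be exercised standalone with the same transformers classes the app loads (SamModel/SamProcessor, checkpoint facebook/sam-vit-huge). A minimal sketch; the image path and click coordinates below are placeholders, not values from the app:

    import torch
    from PIL import Image
    from transformers import SamModel, SamProcessor

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
    sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

    image = Image.open("example.jpg").convert("RGB")  # placeholder input image
    input_points = [[[250, 180]]]  # one (x, y) click; nesting is [image][point set][point]

    inputs = sam_processor(image, input_points=input_points, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = sam_model(**inputs)

    # Resize predicted masks back to the original resolution, then keep the
    # candidate with the highest predicted IoU, the same selection apply_sam performs.
    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )[0][0]
    best = outputs.iou_scores[0, 0].argmax()
    mask_np = masks[best].numpy()  # (H, W) boolean mask

The describe/describe_video handlers stream output by being generators: each yield pushes a partial update to the wired outputs, and gr.update() leaves a component unchanged. A toy reduction of that wiring (function and component names here are illustrative, not from the app):

    import gradio as gr

    def stream_answer(question):
        # Yielding a growing string makes Gradio re-render the output on every step.
        text = ""
        for token in question.split():  # stand-in for a model's token stream
            text += token + " "
            yield text

    with gr.Blocks() as demo:
        q = gr.Textbox(label="Question")
        a = gr.Textbox(label="Answer")
        q.submit(stream_answer, inputs=q, outputs=a)

    demo.launch()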