Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -52,15 +52,20 @@ def add_contour(img, mask, color=(1., 1., 1.)):
 @spaces.GPU(duration=120)
 def generate_masks(image, mask_list, mask_raw_list):
     """
-
-
+    Generates segmentation masks for selected regions in an image using SAM.
+
     Args:
-        image:
-
-
-
+        image (dict): A dictionary containing image data, typically from a Gradio ImageEditor,
+            with 'background' (PIL Image) and 'layers' (list of PIL Image layers).
+        mask_list (list): A list to accumulate (mask_image, label) tuples for display in a gallery.
+        mask_raw_list (list): A list to accumulate raw NumPy mask arrays.
+
     Returns:
-
+        tuple: A tuple containing:
+            - mask_list (list): Updated list of mask images for display.
+            - image (dict): Updated image dictionary with layers cleared.
+            - mask_list (list): Redundant return of mask_list (for Gradio update).
+            - mask_raw_list (list): Updated list of raw mask arrays.
     """
     image['image'] = image['background'].convert('RGB')
     # del image['background'], image['composite']
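For context, the `image` dict above follows Gradio's ImageEditor value format ('background', 'layers', 'composite'). A minimal sketch of how SAM point prompts might be derived from the user's brush stroke; the helper name and the random-sampling strategy are assumptions, not code from this commit:

import numpy as np

def points_from_layer(editor_value, num_points=5):
    # Hypothetical helper: the drawn layer is an RGBA image whose alpha
    # channel is non-zero wherever the brush touched.
    layer = editor_value["layers"][0]
    alpha = np.array(layer)[:, :, 3]
    ys, xs = np.nonzero(alpha)
    idx = np.random.choice(len(xs), size=min(num_points, len(xs)), replace=False)
    # Nested list in the shape SamProcessor expects: image -> object -> [x, y].
    return [[[int(x), int(y)] for x, y in zip(xs[idx], ys[idx])]]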
@@ -89,15 +94,21 @@ def generate_masks(image, mask_list, mask_raw_list):
 @spaces.GPU(duration=120)
 def generate_masks_video(image, mask_list_video, mask_raw_list_video):
     """
-
-
+    Generates segmentation masks for selected regions in the first frame of a video using SAM.
+
     Args:
-        image:
-
-
-
+        image (dict): A dictionary containing image data (first frame of video),
+            typically from a Gradio ImageEditor, with 'background' (PIL Image)
+            and 'layers' (list of PIL Image layers).
+        mask_list_video (list): A list to accumulate (mask_image, label) tuples for display.
+        mask_raw_list_video (list): A list to accumulate raw NumPy mask arrays for video processing.
+
     Returns:
-
+        tuple: A tuple containing:
+            - mask_list_video (list): Updated list of mask images for display.
+            - image (dict): Updated image dictionary with layers cleared.
+            - mask_list_video (list): Redundant return of mask_list_video (for Gradio update).
+            - mask_raw_list_video (list): Updated list of raw mask arrays.
     """
     image['image'] = image['background'].convert('RGB')
     # del image['background'], image['composite']
@@ -127,16 +138,19 @@ def generate_masks_video(image, mask_list_video, mask_raw_list_video):
 @spaces.GPU(duration=120)
 def describe(image, mode, query, masks):
     """
-
-
+    Describes an image based on selected regions or answers a question about them.
+
     Args:
-        image:
-
-
-
-
-
-
+        image (dict): A dictionary containing image data, typically from a Gradio ImageEditor,
+            with 'background' (PIL Image) and 'layers' (list of PIL Image layers).
+        mode (str): The operational mode, either "Caption" (to describe a selected region)
+            or "QA" (to answer a question about one or more regions).
+        query (str): The question to ask in "QA" mode. Ignored in "Caption" mode.
+        masks (list): A list of raw NumPy mask arrays representing previously generated masks.
+
+    Yields:
+        tuple: An image with contours and the generated text description/answer,
+            or updates for Gradio components during streaming.
     """
     # Create an image object from the uploaded image
    # print(image.keys())
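The docstring above mentions returning "an image with contours", and the first hunk header references add_contour(img, mask, color=(1., 1., 1.)). A minimal sketch of contour overlay in that spirit, using OpenCV; the actual add_contour in app.py may differ:

import cv2
import numpy as np

def draw_contour(img, mask, color=(1., 1., 1.), thickness=2):
    # Trace the boundary of a binary mask and draw it onto a float RGB image in place.
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(img, contours, -1, color, thickness)
    return img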
@@ -229,13 +243,16 @@ def describe(image, mode, query, masks):
 
 def load_first_frame(video_path):
     """
-
-
+    Loads the first frame of a given video file.
+
     Args:
-        video_path:
-
+        video_path (str): The file path to the video.
+
     Returns:
-        PIL
+        PIL.Image.Image: The first frame of the video as a PIL Image.
+
+    Raises:
+        gr.Error: If the video file cannot be read.
     """
     cap = cv2.VideoCapture(video_path)
     ret, frame = cap.read()
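The body of load_first_frame is truncated in this view. A standalone sketch of the same idea, with a hypothetical function name and a plain RuntimeError in place of the gr.Error the docstring mentions:

import cv2
from PIL import Image

def first_frame_as_pil(video_path):
    cap = cv2.VideoCapture(video_path)
    ret, frame = cap.read()
    cap.release()
    if not ret:
        raise RuntimeError(f"Could not read a frame from {video_path}")
    # OpenCV decodes to BGR; convert to RGB before building the PIL image.
    return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))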
@@ -249,18 +266,23 @@ def load_first_frame(video_path):
 @spaces.GPU(duration=120)
 def describe_video(video_path, mode, query, annotated_frame, masks, mask_list_video):
     """
-
-
+    Describes a video based on selected regions in its first frame or answers a question about them.
+
     Args:
-        video_path:
-        mode:
-
-
-
-
-
-
-
+        video_path (str): The file path to the video.
+        mode (str): The operational mode, either "Caption" (to describe a selected region)
+            or "QA" (to answer a question about one or more regions).
+        query (str): The question to ask in "QA" mode. Ignored in "Caption" mode.
+        annotated_frame (dict): A dictionary containing the first frame's image data
+            from a Gradio ImageEditor, with 'background' (PIL Image)
+            and 'layers' (list of PIL Image layers).
+        masks (list): A list of raw NumPy mask arrays representing previously generated masks
+            for objects in the video.
+        mask_list_video (list): A list to accumulate (mask_image, label) tuples for display.
+
+    Yields:
+        tuple: The annotated first frame, the generated text description/answer,
+            and updated mask lists for Gradio components during streaming.
     """
     # Create a temporary directory to save extracted video frames
     cap = cv2.VideoCapture(video_path)
@@ -351,4 +373,286 @@ def describe_video(video_path, mode, query, annotated_frame, masks, mask_list_video):
     mask_image = Image.fromarray((mask_np[:,:,np.newaxis] * np.array(annotated_frame['image'])).astype(np.uint8))
     mask_list_video.append((mask_image, f"<object{len(mask_list_video)}>"))
     text = ""
-    yield frame_img, text, mask_list_video, mask_list_video
+    yield frame_img, text, mask_list_video, mask_list_video
+
+    for token in get_model_output(
+        video_tensor,
+        query,
+        model=model,
+        tokenizer=tokenizer,
+        masks=masks,
+        mask_ids=mask_ids,
+        modal='video',
+        streaming=True,
+    ):
+        text += token
+        yield gr.update(), text, gr.update(), gr.update()
+
+
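A note on the yield pattern above: a Gradio event handler written as a generator streams its outputs, and each yield must supply one value per output component; gr.update() with no arguments leaves a component untouched. A minimal self-contained sketch of that pattern (component names and tokens are illustrative):

import gradio as gr

def stream_answer(question):
    # The first yield sets all three outputs; later yields refresh only the text.
    yield "thinking...", gr.update(), gr.update()
    text = ""
    for token in ["Hello", ", ", "world", "."]:
        text += token
        yield text, gr.update(), gr.update()

with gr.Blocks() as demo:
    q = gr.Textbox(label="Question")
    a = gr.Textbox(label="Answer")
    img = gr.Image()
    gallery = gr.Gallery()
    gr.Button("Ask").click(stream_answer, inputs=q, outputs=[a, img, gallery])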
+@spaces.GPU(duration=120)
+def apply_sam(image, input_points):
+    """
+    Applies the Segment Anything Model (SAM) to an image based on input points
+    to generate a segmentation mask.
+
+    Args:
+        image (PIL.Image.Image): The input image.
+        input_points (list): A list of lists, where each inner list contains
+            [x, y] coordinates representing points used for segmentation.
+
+    Returns:
+        numpy.ndarray: The selected binary segmentation mask as a NumPy array (H, W).
+    """
+    inputs = sam_processor(image, input_points=input_points, return_tensors="pt").to(device)
+
+    with torch.no_grad():
+        outputs = sam_model(**inputs)
+
+    masks = sam_processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())[0][0]
+    scores = outputs.iou_scores[0, 0]
+
+    mask_selection_index = scores.argmax()
+
+    mask_np = masks[mask_selection_index].numpy()
+
+    return mask_np
+
+
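For reference, a hypothetical call to apply_sam; the coordinates are made up, and the nesting of input_points follows the Hugging Face SamProcessor convention of image -> object -> [x, y] points (reusing first_frame_as_pil from the sketch above):

# Two point prompts on one object; apply_sam returns the candidate mask with
# the highest IoU score as an (H, W) NumPy array.
frame_pil = first_frame_as_pil("./demo/videos/1.mp4")
mask_np = apply_sam(frame_pil, input_points=[[[320, 180], [340, 200]]])
print(mask_np.shape)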
+def clear_masks():
+    """
+    Clears the stored lists of masks and raw masks.
+
+    Returns:
+        tuple: Three empty lists, intended to reset Gradio components
+            displaying masks.
+    """
+    return [], [], []
+
+
+
if __name__ == "__main__":
|
| 433 |
+
parser = argparse.ArgumentParser(description="VideoRefer gradio demo")
|
| 434 |
+
parser.add_argument("--model-path", type=str, default="DAMO-NLP-SG/VideoRefer-VideoLLaMA3-7B", help="Path to the model checkpoint")
|
| 435 |
+
parser.add_argument("--prompt-mode", type=str, default="focal_prompt", help="Prompt mode")
|
| 436 |
+
parser.add_argument("--conv-mode", type=str, default="v1", help="Conversation mode")
|
| 437 |
+
parser.add_argument("--temperature", type=float, default=0.2, help="Sampling temperature")
|
| 438 |
+
parser.add_argument("--top_p", type=float, default=0.5, help="Top-p for sampling")
|
| 439 |
+
|
| 440 |
+
args_cli = parser.parse_args()
|
| 441 |
+
|
| 442 |
+
with gr.Blocks(theme=gr.themes.Soft(primary_hue="amber")) as demo:
|
| 443 |
+
|
| 444 |
+
mask_list = gr.State([])
|
| 445 |
+
mask_raw_list = gr.State([])
|
| 446 |
+
mask_list_video = gr.State([])
|
| 447 |
+
mask_raw_list_video = gr.State([])
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
HEADER = ("""
|
| 451 |
+
<div>
|
| 452 |
+
<h1>VideoRefer X VideoLLaMA3 Demo</h1>
|
| 453 |
+
<h5 style="margin: 0;">Feel free to click on anything that grabs your interest!</h5>
|
| 454 |
+
<h5 style="margin: 0;">If this demo please you, please give us a star ⭐ on Github or 💖 on this space.</h5>
|
| 455 |
+
</div>
|
| 456 |
+
</div>
|
| 457 |
+
<div style="display: flex; justify-content: left; margin-top: 10px;">
|
| 458 |
+
<a href="https://arxiv.org/pdf/2501.00599"><img src="https://img.shields.io/badge/Arxiv-2501.00599-ECA8A7" style="margin-right: 5px;"></a>
|
| 459 |
+
<a href="https://github.com/DAMO-NLP-SG/VideoRefer"><img src='https://img.shields.io/badge/Github-VideoRefer-F7C97E' style="margin-right: 5px;"></a>
|
| 460 |
+
<a href="https://github.com/DAMO-NLP-SG/VideoLLaMA3"><img src='https://img.shields.io/badge/Github-VideoLLaMA3-9DC3E6' style="margin-right: 5px;"></a>
|
| 461 |
+
</div>
|
| 462 |
+
""")
|
| 463 |
+
|
| 464 |
+
with gr.Row():
|
| 465 |
+
with gr.Column():
|
| 466 |
+
gr.HTML(HEADER)
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
image_tips = """
|
| 470 |
+
### 💡 Tips:
|
| 471 |
+
|
| 472 |
+
🧸 Upload an image, and you can use the drawing tool✍️ to highlight the areas you're interested in.
|
| 473 |
+
|
| 474 |
+
🔖 For single-object caption mode, simply select the area and click the 'Generate Caption' button to receive a caption for the object.
|
| 475 |
+
|
| 476 |
+
🔔 In QA mode, you can generate multiple masks by clicking the 'Generate Mask' button multiple times. Afterward, use the corresponding object id to ask questions.
|
| 477 |
+
|
| 478 |
+
📌 Click the button 'Clear Masks' to clear the current generated masks.
|
| 479 |
+
|
| 480 |
+
"""
|
| 481 |
+
|
| 482 |
+
video_tips = """
|
| 483 |
+
### 💡 Tips:
|
| 484 |
+
⚠️ For video mode, we only support masking on the first frame in this demo.
|
| 485 |
+
|
| 486 |
+
🧸 Upload an video, and you can use the drawing tool✍️ to highlight the areas you're interested in the first frame.
|
| 487 |
+
|
| 488 |
+
🔖 For single-object caption mode, simply select the area and click the 'Generate Caption' button to receive a caption for the object.
|
| 489 |
+
|
| 490 |
+
🔔 In QA mode, you can generate multiple masks by clicking the 'Generate Mask' button multiple times. Afterward, use the corresponding object id to ask questions.
|
| 491 |
+
|
| 492 |
+
📌 Click the button 'Clear Masks' to clear the current generated masks.
|
| 493 |
+
|
| 494 |
+
"""
|
| 495 |
+
|
| 496 |
+
|
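The four gr.State([]) values above are per-session containers: they persist only when an event handler lists them in both inputs and outputs, which is why generate_masks receives and returns mask_list and mask_raw_list. A minimal sketch of that round-trip pattern:

import gradio as gr

with gr.Blocks() as demo:
    history = gr.State([])          # session-local list, like mask_list above
    log = gr.Textbox(label="History")

    def record(history):
        history = history + ["clicked"]    # return a new list to update the state
        return history, ", ".join(history)

    gr.Button("Add").click(record, inputs=history, outputs=[history, log])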
+        with gr.TabItem("Image"):
+            with gr.Row():
+                with gr.Column():
+                    image_input = gr.ImageEditor(
+                        label="Image",
+                        type="pil",
+                        sources=['upload'],
+                        brush=gr.Brush(colors=["#ED7D31"], color_mode="fixed", default_size=10),
+                        eraser=True,
+                        layers=False,
+                        transforms=[],
+                        height=300,
+                    )
+                    generate_mask_btn = gr.Button("1️⃣ Generate Mask", visible=False, variant="primary")
+                    mode = gr.Radio(label="Mode", choices=["Caption", "QA"], value="Caption")
+                    query = gr.Textbox(label="Question", value="What is the relationship between <region0> and <region1>?", interactive=True, visible=False)
+
+                    submit_btn = gr.Button("Generate Caption", variant="primary")
+                    submit_btn1 = gr.Button("2️⃣ Generate Answer", variant="primary", visible=False)
+                    gr.Examples([f"./demo/images/{i+1}.jpg" for i in range(8)], inputs=image_input, label="Examples")
+
+                with gr.Column():
+                    mask_output = gr.Gallery(label="Referred Masks", object_fit='scale-down', visible=False)
+                    output_image = gr.Image(label="Image with Mask", visible=True, height=400)
+                    description = gr.Textbox(label="Output", visible=True)
+
+                    clear_masks_btn = gr.Button("Clear Masks", variant="secondary", visible=False)
+                    gr.Markdown(image_tips)
+
+        with gr.TabItem("Video"):
+            with gr.Row():
+                with gr.Column():
+                    video_input = gr.Video(label="Video")
+                    # load_btn = gr.Button("🖼️ Load First Frame", variant="secondary")
+                    first_frame = gr.ImageEditor(
+                        label="Annotate First Frame",
+                        type="pil",
+                        sources=['upload'],
+                        brush=gr.Brush(colors=["#ED7D31"], color_mode="fixed", default_size=10),
+                        eraser=True,
+                        layers=False,
+                        transforms=[],
+                        height=300,
+                    )
+                    generate_mask_btn_video = gr.Button("1️⃣ Generate Mask", visible=False, variant="primary")
+                    gr.Examples([f"./demo/videos/{i+1}.mp4" for i in range(4)], inputs=video_input, label="Examples")
+
+                with gr.Column():
+                    mode_video = gr.Radio(label="Mode", choices=["Caption", "QA"], value="Caption")
+                    mask_output_video = gr.Gallery(label="Referred Masks", object_fit='scale-down')
+
+                    query_video = gr.Textbox(label="Question", value="What is the relationship between <object0> and <object1>?", interactive=True, visible=False)
+
+                    submit_btn_video = gr.Button("Generate Caption", variant="primary")
+                    submit_btn_video1 = gr.Button("2️⃣ Generate Answer", variant="primary", visible=False)
+                    description_video = gr.Textbox(label="Output", visible=True)
+
+                    clear_masks_btn_video = gr.Button("Clear Masks", variant="secondary")
+
+                    gr.Markdown(video_tips)
+
+
+
def toggle_query_and_generate_button(mode):
|
| 560 |
+
"""
|
| 561 |
+
Toggles the visibility of query-related Gradio components based on the selected mode.
|
| 562 |
+
Also clears mask states.
|
| 563 |
+
|
| 564 |
+
Args:
|
| 565 |
+
mode (str): The selected mode ("Caption" or "QA").
|
| 566 |
+
|
| 567 |
+
Returns:
|
| 568 |
+
tuple: A series of gr.update() calls and empty lists to update Gradio components.
|
| 569 |
+
"""
|
| 570 |
+
query_visible = mode == "QA"
|
| 571 |
+
caption_visible = mode == "Caption"
|
| 572 |
+
return gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=caption_visible), gr.update(visible=caption_visible), [], "", [], [],[],[]
|
| 573 |
+
|
| 574 |
+
video_input.change(load_first_frame, inputs=video_input, outputs=first_frame)
|
| 575 |
+
|
| 576 |
+
mode.change(toggle_query_and_generate_button, inputs=mode, outputs=[query, generate_mask_btn, clear_masks_btn, submit_btn1, mask_output, output_image, submit_btn, mask_output, description, mask_list, mask_raw_list, mask_list_video, mask_raw_list_video])
|
| 577 |
+
|
| 578 |
+
def toggle_query_and_generate_button_video(mode):
|
| 579 |
+
"""
|
| 580 |
+
Toggles the visibility of query-related Gradio components for video mode
|
| 581 |
+
based on the selected mode. Also clears mask states.
|
| 582 |
+
|
| 583 |
+
Args:
|
| 584 |
+
mode (str): The selected mode ("Caption" or "QA").
|
| 585 |
+
|
| 586 |
+
Returns:
|
| 587 |
+
tuple: A series of gr.update() calls and empty lists to update Gradio components.
|
| 588 |
+
"""
|
| 589 |
+
query_visible = mode == "QA"
|
| 590 |
+
caption_visible = mode == "Caption"
|
| 591 |
+
return gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=caption_visible), [], [], [], [], []
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
mode_video.change(toggle_query_and_generate_button_video, inputs=mode_video, outputs=[query_video, generate_mask_btn_video, submit_btn_video1, submit_btn_video, mask_output_video, mask_list, mask_raw_list, mask_list_video, mask_raw_list_video])
|
| 595 |
+
|
| 596 |
+
submit_btn.click(
|
| 597 |
+
fn=describe,
|
| 598 |
+
inputs=[image_input, mode, query, mask_raw_list],
|
| 599 |
+
outputs=[output_image, description, image_input],
|
| 600 |
+
api_name="describe"
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
submit_btn1.click(
|
| 604 |
+
fn=describe,
|
| 605 |
+
inputs=[image_input, mode, query, mask_raw_list],
|
| 606 |
+
outputs=[output_image, description, image_input],
|
| 607 |
+
api_name="describe"
|
| 608 |
+
)
|
| 609 |
+
|
| 610 |
+
generate_mask_btn.click(
|
| 611 |
+
fn=generate_masks,
|
| 612 |
+
inputs=[image_input, mask_list, mask_raw_list],
|
| 613 |
+
outputs=[mask_output, image_input, mask_list, mask_raw_list]
|
| 614 |
+
)
|
| 615 |
+
|
| 616 |
+
generate_mask_btn_video.click(
|
| 617 |
+
fn=generate_masks_video,
|
| 618 |
+
inputs=[first_frame, mask_list_video, mask_raw_list_video],
|
| 619 |
+
outputs=[mask_output_video, first_frame, mask_list_video, mask_raw_list_video]
|
| 620 |
+
)
|
| 621 |
+
|
| 622 |
+
clear_masks_btn.click(
|
| 623 |
+
fn=clear_masks,
|
| 624 |
+
outputs=[mask_output, mask_list, mask_raw_list]
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
clear_masks_btn_video.click(
|
| 628 |
+
fn=clear_masks,
|
| 629 |
+
outputs=[mask_output_video, mask_list_video, mask_raw_list_video]
|
| 630 |
+
)
|
| 631 |
+
|
| 632 |
+
submit_btn_video.click(
|
| 633 |
+
fn=describe_video,
|
| 634 |
+
inputs=[video_input, mode_video, query_video, first_frame, mask_raw_list_video, mask_list_video],
|
| 635 |
+
outputs=[first_frame, description_video, mask_output_video, mask_list_video],
|
| 636 |
+
api_name="describe_video"
|
| 637 |
+
)
|
| 638 |
+
|
| 639 |
+
submit_btn_video1.click(
|
| 640 |
+
fn=describe_video,
|
| 641 |
+
inputs=[video_input, mode_video, query_video, first_frame, mask_raw_list_video, mask_list_video],
|
| 642 |
+
outputs=[first_frame, description_video, mask_output_video, mask_list_video],
|
| 643 |
+
api_name="describe_video"
|
| 644 |
+
)
|
| 645 |
+
|
| 646 |
+
|
| 647 |
+
|
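One detail worth noting in the wiring above: Gradio matches a handler's returned values to the outputs list positionally, so toggle_query_and_generate_button returns exactly 13 values for the 13 components passed to mode.change (mask_output is listed twice and thus receives two of them), and the video variant returns 9 values for its 9 outputs. A reduced sketch of the same pattern:

# Two outputs, two returned values, matched by position.
def toggle_minimal(mode):
    show_qa = mode == "QA"
    return gr.update(visible=show_qa), gr.update(visible=not show_qa)

# e.g. mode.change(toggle_minimal, inputs=mode, outputs=[query, submit_btn])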
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
+    sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
+
+    disable_torch_init()
+
+    model, processor, tokenizer = model_init(args_cli.model_path)
+
+    demo.launch(mcp_server=True)