feat: Enable MCP

#9
by multimodalart HF Staff - opened
Files changed (1) hide show
  1. app.py +119 -1
app.py CHANGED
@@ -51,6 +51,22 @@ def add_contour(img, mask, color=(1., 1., 1.)):
51
 
52
  @spaces.GPU(duration=120)
53
  def generate_masks(image, mask_list, mask_raw_list):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  image['image'] = image['background'].convert('RGB')
55
  # del image['background'], image['composite']
56
  assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"
@@ -77,6 +93,23 @@ def generate_masks(image, mask_list, mask_raw_list):
77
 
78
  @spaces.GPU(duration=120)
79
  def generate_masks_video(image, mask_list_video, mask_raw_list_video):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  image['image'] = image['background'].convert('RGB')
81
  # del image['background'], image['composite']
82
  assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"
@@ -104,6 +137,21 @@ def generate_masks_video(image, mask_list_video, mask_raw_list_video):
104
 
105
  @spaces.GPU(duration=120)
106
  def describe(image, mode, query, masks):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  # Create an image object from the uploaded image
108
  # print(image.keys())
109
 
@@ -194,6 +242,18 @@ def describe(image, mode, query, masks):
194
 
195
 
196
  def load_first_frame(video_path):
 
 
 
 
 
 
 
 
 
 
 
 
197
  cap = cv2.VideoCapture(video_path)
198
  ret, frame = cap.read()
199
  cap.release()
@@ -205,6 +265,25 @@ def load_first_frame(video_path):
205
 
206
  @spaces.GPU(duration=120)
207
  def describe_video(video_path, mode, query, annotated_frame, masks, mask_list_video):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  # Create a temporary directory to save extracted video frames
209
  cap = cv2.VideoCapture(video_path)
210
 
@@ -312,6 +391,18 @@ def describe_video(video_path, mode, query, annotated_frame, masks, mask_list_vi
312
 
313
  @spaces.GPU(duration=120)
314
  def apply_sam(image, input_points):
 
 
 
 
 
 
 
 
 
 
 
 
315
  inputs = sam_processor(image, input_points=input_points, return_tensors="pt").to(device)
316
 
317
  with torch.no_grad():
@@ -328,6 +419,13 @@ def apply_sam(image, input_points):
328
 
329
 
330
  def clear_masks():
 
 
 
 
 
 
 
331
  return [], [], []
332
 
333
 
@@ -459,6 +557,16 @@ if __name__ == "__main__":
459
 
460
 
461
  def toggle_query_and_generate_button(mode):
 
 
 
 
 
 
 
 
 
 
462
  query_visible = mode == "QA"
463
  caption_visible = mode == "Caption"
464
  return gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=caption_visible), gr.update(visible=caption_visible), [], "", [], [],[],[]
@@ -468,6 +576,16 @@ if __name__ == "__main__":
468
  mode.change(toggle_query_and_generate_button, inputs=mode, outputs=[query, generate_mask_btn, clear_masks_btn, submit_btn1, mask_output, output_image, submit_btn, mask_output, description, mask_list, mask_raw_list, mask_list_video, mask_raw_list_video])
469
 
470
  def toggle_query_and_generate_button_video(mode):
 
 
 
 
 
 
 
 
 
 
471
  query_visible = mode == "QA"
472
  caption_visible = mode == "Caption"
473
  return gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=caption_visible), [], [], [], [], []
@@ -537,4 +655,4 @@ if __name__ == "__main__":
537
  model, processor, tokenizer = model_init(args_cli.model_path)
538
 
539
 
540
- demo.launch()
 
51
 
52
  @spaces.GPU(duration=120)
53
  def generate_masks(image, mask_list, mask_raw_list):
54
+ """
55
+ Generates segmentation masks for selected regions in an image using SAM.
56
+
57
+ Args:
58
+ image (dict): A dictionary containing image data, typically from a Gradio ImageEditor,
59
+ with 'background' (PIL Image) and 'layers' (list of PIL Image layers).
60
+ mask_list (list): A list to accumulate (mask_image, label) tuples for display in a gallery.
61
+ mask_raw_list (list): A list to accumulate raw NumPy mask arrays.
62
+
63
+ Returns:
64
+ tuple: A tuple containing:
65
+ - mask_list (list): Updated list of mask images for display.
66
+ - image (dict): Updated image dictionary with layers cleared.
67
+ - mask_list (list): Redundant return of mask_list (for Gradio update).
68
+ - mask_raw_list (list): Updated list of raw mask arrays.
69
+ """
70
  image['image'] = image['background'].convert('RGB')
71
  # del image['background'], image['composite']
72
  assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"
 
93
 
94
  @spaces.GPU(duration=120)
95
  def generate_masks_video(image, mask_list_video, mask_raw_list_video):
96
+ """
97
+ Generates segmentation masks for selected regions in the first frame of a video using SAM.
98
+
99
+ Args:
100
+ image (dict): A dictionary containing image data (first frame of video),
101
+ typically from a Gradio ImageEditor, with 'background' (PIL Image)
102
+ and 'layers' (list of PIL Image layers).
103
+ mask_list_video (list): A list to accumulate (mask_image, label) tuples for display.
104
+ mask_raw_list_video (list): A list to accumulate raw NumPy mask arrays for video processing.
105
+
106
+ Returns:
107
+ tuple: A tuple containing:
108
+ - mask_list_video (list): Updated list of mask images for display.
109
+ - image (dict): Updated image dictionary with layers cleared.
110
+ - mask_list_video (list): Redundant return of mask_list_video (for Gradio update).
111
+ - mask_raw_list_video (list): Updated list of raw mask arrays.
112
+ """
113
  image['image'] = image['background'].convert('RGB')
114
  # del image['background'], image['composite']
115
  assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"
 
137
 
138
  @spaces.GPU(duration=120)
139
  def describe(image, mode, query, masks):
140
+ """
141
+ Describes an image based on selected regions or answers a question about them.
142
+
143
+ Args:
144
+ image (dict): A dictionary containing image data, typically from a Gradio ImageEditor,
145
+ with 'background' (PIL Image) and 'layers' (list of PIL Image layers).
146
+ mode (str): The operational mode, either "Caption" (to describe a selected region)
147
+ or "QA" (to answer a question about one or more regions).
148
+ query (str): The question to ask in "QA" mode. Ignored in "Caption" mode.
149
+ masks (list): A list of raw NumPy mask arrays representing previously generated masks.
150
+
151
+ Yields:
152
+ tuple: An image with contours and the generated text description/answer,
153
+ or updates for Gradio components during streaming.
154
+ """
155
  # Create an image object from the uploaded image
156
  # print(image.keys())
157
 
 
242
 
243
 
244
  def load_first_frame(video_path):
245
+ """
246
+ Loads the first frame of a given video file.
247
+
248
+ Args:
249
+ video_path (str): The file path to the video.
250
+
251
+ Returns:
252
+ PIL.Image.Image: The first frame of the video as a PIL Image.
253
+
254
+ Raises:
255
+ gr.Error: If the video file cannot be read.
256
+ """
257
  cap = cv2.VideoCapture(video_path)
258
  ret, frame = cap.read()
259
  cap.release()
 
265
 
266
  @spaces.GPU(duration=120)
267
  def describe_video(video_path, mode, query, annotated_frame, masks, mask_list_video):
268
+ """
269
+ Describes a video based on selected regions in its first frame or answers a question about them.
270
+
271
+ Args:
272
+ video_path (str): The file path to the video.
273
+ mode (str): The operational mode, either "Caption" (to describe a selected region)
274
+ or "QA" (to answer a question about one or more regions).
275
+ query (str): The question to ask in "QA" mode. Ignored in "Caption" mode.
276
+ annotated_frame (dict): A dictionary containing the first frame's image data
277
+ from a Gradio ImageEditor, with 'background' (PIL Image)
278
+ and 'layers' (list of PIL Image layers).
279
+ masks (list): A list of raw NumPy mask arrays representing previously generated masks
280
+ for objects in the video.
281
+ mask_list_video (list): A list to accumulate (mask_image, label) tuples for display.
282
+
283
+ Yields:
284
+ tuple: The annotated first frame, the generated text description/answer,
285
+ and updated mask lists for Gradio components during streaming.
286
+ """
287
  # Create a temporary directory to save extracted video frames
288
  cap = cv2.VideoCapture(video_path)
289
 
 
391
 
392
  @spaces.GPU(duration=120)
393
  def apply_sam(image, input_points):
394
+ """
395
+ Applies the Segment Anything Model (SAM) to an image based on input points
396
+ to generate a segmentation mask.
397
+
398
+ Args:
399
+ image (PIL.Image.Image): The input image.
400
+ input_points (list): A list of lists, where each inner list contains
401
+ [x, y] coordinates representing points used for segmentation.
402
+
403
+ Returns:
404
+ numpy.ndarray: The selected binary segmentation mask as a NumPy array (H, W).
405
+ """
406
  inputs = sam_processor(image, input_points=input_points, return_tensors="pt").to(device)
407
 
408
  with torch.no_grad():
 
419
 
420
 
421
  def clear_masks():
422
+ """
423
+ Clears the stored lists of masks and raw masks.
424
+
425
+ Returns:
426
+ tuple: Three empty lists, intended to reset Gradio components
427
+ displaying masks.
428
+ """
429
  return [], [], []
430
 
431
 
 
557
 
558
 
559
  def toggle_query_and_generate_button(mode):
560
+ """
561
+ Toggles the visibility of the QA-only and Caption-only Gradio components based on the selected mode.
562
+ Also clears mask states.
563
+
564
+ Args:
565
+ mode (str): The selected mode ("Caption" or "QA").
566
+
567
+ Returns:
568
+ tuple: Seven gr.update() visibility updates followed by six reset values (empty lists and one empty string) for the remaining Gradio outputs.
569
+ """
570
  query_visible = mode == "QA"
571
  caption_visible = mode == "Caption"
572
  return gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=caption_visible), gr.update(visible=caption_visible), [], "", [], [],[],[]
 
576
  mode.change(toggle_query_and_generate_button, inputs=mode, outputs=[query, generate_mask_btn, clear_masks_btn, submit_btn1, mask_output, output_image, submit_btn, mask_output, description, mask_list, mask_raw_list, mask_list_video, mask_raw_list_video])
577
 
578
  def toggle_query_and_generate_button_video(mode):
579
+ """
580
+ Toggles the visibility of the QA-only and Caption-only Gradio components for video mode
581
+ based on the selected mode. Also clears mask states.
582
+
583
+ Args:
584
+ mode (str): The selected mode ("Caption" or "QA").
585
+
586
+ Returns:
587
+ tuple: Four gr.update() visibility updates followed by five empty lists that reset the remaining Gradio outputs.
588
+ """
589
  query_visible = mode == "QA"
590
  caption_visible = mode == "Caption"
591
  return gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=caption_visible), [], [], [], [], []
 
655
  model, processor, tokenizer = model_init(args_cli.model_path)
656
 
657
 
658
+ demo.launch(mcp_server=True)