multimodalart's picture
feat: Enable MCP
e2b49c1 verified
raw
history blame
12.7 kB
import spaces
import gradio as gr
import numpy as np
import torch
from transformers import SamModel, SamProcessor
from PIL import Image
import os
import cv2
import argparse
import sys
# This is for making model initialization faster and has no effect since we are loading the weights
sys.path.append('./')
from videollama3 import disable_torch_init, model_init, mm_infer, get_model_output
from videollama3.mm_utils import load_images
from videollama3.mm_utils import load_video
# Contour colours as RGB triples on a [0, 1] scale (the same scale as the
# float images that `describe` passes to `add_contour`).
# `color_rgbs` looks intended as a per-region palette (it is only referenced
# from commented-out code); `color_rgb` is the single-region default (white).
color_rgbs = [
    (1.0, 1.0, 1.0),  # white
    (1.0, 0.0, 0.0),  # red
    (0.0, 1.0, 1.0),  # cyan
    (0.0, 1.0, 0.0),  # green
    (0.0, 0.0, 1.0),  # blue
    (1.0, 0.0, 1.0),  # magenta
]
color_rgb = color_rgbs[0]
def extract_first_frame_from_video(video):
    """Return the first frame of ``video`` as an RGB PIL image.

    Args:
        video: Anything ``cv2.VideoCapture`` accepts (e.g. a file path).

    Returns:
        ``PIL.Image.Image`` built from the first decoded frame, or ``None``
        when no frame could be read.
    """
    capture = cv2.VideoCapture(video)
    ok, bgr_frame = capture.read()
    capture.release()
    if not ok:
        return None
    # OpenCV decodes frames in BGR channel order; PIL expects RGB.
    rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb_frame)
def extract_points_from_mask(mask_pil):
    """Return the (x, y) pixel coordinates of every non-zero pixel of a mask.

    Args:
        mask_pil: Mask as a PIL image or array-like. A multi-channel mask of
            shape ``(H, W, C)`` (e.g. the RGB masks built by the callers in
            this file) is read from its first channel; a single-channel mask
            (PIL mode ``'L'`` / ``'1'``, or a 2-D array) is used as-is.

    Returns:
        ``np.ndarray`` of shape ``(N, 2)`` whose rows are ``(x, y)``, i.e.
        (column, row), in row-major scan order. ``N`` is 0 for an empty mask.

    Raises:
        ValueError: if the mask is neither 2-D nor 3-D.
    """
    mask = np.asarray(mask_pil)
    # A 2-D mask has no channel axis: indexing ``[..., 0]`` on it would
    # silently pick the first *column* instead of the first channel.
    if mask.ndim == 3:
        mask = mask[..., 0]
    if mask.ndim != 2:
        raise ValueError(
            f"expected a 2-D mask or an (H, W, C) mask, got shape {mask.shape}"
        )
    rows, cols = np.nonzero(mask)
    return np.stack((cols, rows), axis=1)
def add_contour(img, mask, color=(1., 1., 1.)):
    """Return a copy of ``img`` with the outer boundary of ``mask`` drawn on it.

    Args:
        img: Image array to draw on. It is copied, never modified in place.
            ``describe`` passes a float image scaled to [0, 1], which is why
            the default colour is given on that scale.
        mask: Array whose non-zero pixels form the region(s) to outline.
        color: Line colour, one value per image channel (default white).

    Returns:
        A new array with the external contours of ``mask`` drawn in
        ``color`` with a line thickness of 8 pixels.
    """
    binary = mask.astype(np.uint8) * 255
    # RETR_EXTERNAL keeps only the outermost boundaries, so holes inside a
    # region are not outlined.
    outlines, _hierarchy = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    canvas = img.copy()
    cv2.drawContours(canvas, outlines, -1, color, thickness=8)
    return canvas
@spaces.GPU(duration=120)
def generate_masks(image, mask_list, mask_raw_list):
    """
    Generate masks from user-drawn annotations on an image.

    Up to 8 pixels of the user's drawing are sampled (deterministically) and
    passed to ``apply_sam`` as point prompts. The resulting mask is appended
    to both state lists, and the editor's drawing layer is cleared.

    Args:
        image: Dictionary containing the image editor state with background and layers
        mask_list: List of generated mask images with labels (appended to in place)
        mask_raw_list: List of raw numpy arrays of masks (appended to in place)

    Returns:
        Tuple containing updated mask_list, image editor state, mask_list, and mask_raw_list

    Raises:
        gr.Error: if the editor does not hold exactly one drawing layer, or
            nothing was drawn on it.
    """
    image['image'] = image['background'].convert('RGB')
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O` and surface to the user as an opaque AssertionError.
    if len(image['layers']) != 1:
        raise gr.Error(f"Expected 1 layer, got {len(image['layers'])}")
    # Every pixel with non-zero alpha in the drawing layer counts as annotated.
    alpha = np.asarray(image['layers'][0])[..., 3]
    mask = Image.fromarray((alpha > 0).astype(np.uint8) * 255).convert('RGB')
    points = extract_points_from_mask(mask)
    if points.shape[0] == 0:
        raise gr.Error("No points selected")
    # A private RandomState seeded with 0 draws exactly the same sample as
    # `np.random.seed(0); np.random.choice(...)`, without reseeding the
    # process-wide NumPy RNG as a side effect.
    rng = np.random.RandomState(0)
    points_selected_indices = rng.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
    coords = [points[points_selected_indices].tolist()]
    mask_np = apply_sam(image['image'], coords)
    mask_raw_list.append(mask_np)
    # Multiply the image by the mask (presumably 0/1 from apply_sam — TODO
    # confirm) so only the segmented pixels remain visible in the gallery.
    mask_image = Image.fromarray((mask_np[:, :, np.newaxis] * np.array(image['image'])).astype(np.uint8))
    mask_list.append((mask_image, f"<region{len(mask_list)}>"))
    # Clear the user's drawing so the next annotation starts fresh.
    image['layers'] = []
    image['composite'] = image['background']
    return mask_list, image, mask_list, mask_raw_list
@spaces.GPU(duration=120)
def generate_masks_video(image, mask_list_video, mask_raw_list_video):
    """
    Generate masks from user-drawn annotations on a video frame.

    Up to 8 pixels of the user's drawing are sampled (deterministically) and
    passed to ``apply_sam`` as point prompts. The resulting mask is appended
    to both state lists, labelled ``<objectN>``, and the editor's drawing
    layer is cleared.

    Args:
        image: Dictionary containing the image editor state with background and layers
        mask_list_video: List of generated mask images with labels for video (appended to in place)
        mask_raw_list_video: List of raw numpy arrays of masks for video (appended to in place)

    Returns:
        Tuple containing updated mask_list_video, image editor state, mask_list_video, and mask_raw_list_video

    Raises:
        gr.Error: if the editor does not hold exactly one drawing layer, or
            nothing was drawn on it.
    """
    image['image'] = image['background'].convert('RGB')
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O` and surface to the user as an opaque AssertionError.
    if len(image['layers']) != 1:
        raise gr.Error(f"Expected 1 layer, got {len(image['layers'])}")
    # Every pixel with non-zero alpha in the drawing layer counts as annotated.
    alpha = np.asarray(image['layers'][0])[..., 3]
    mask = Image.fromarray((alpha > 0).astype(np.uint8) * 255).convert('RGB')
    points = extract_points_from_mask(mask)
    if points.shape[0] == 0:
        raise gr.Error("No points selected")
    # A private RandomState seeded with 0 draws exactly the same sample as
    # `np.random.seed(0); np.random.choice(...)`, without reseeding the
    # process-wide NumPy RNG as a side effect.
    rng = np.random.RandomState(0)
    points_selected_indices = rng.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
    coords = [points[points_selected_indices].tolist()]
    mask_np = apply_sam(image['image'], coords)
    mask_raw_list_video.append(mask_np)
    # Multiply the frame by the mask (presumably 0/1 from apply_sam — TODO
    # confirm) so only the segmented pixels remain visible in the gallery.
    mask_image = Image.fromarray((mask_np[:, :, np.newaxis] * np.array(image['image'])).astype(np.uint8))
    mask_list_video.append((mask_image, f"<object{len(mask_list_video)}>"))
    # Clear the user's drawing so the next annotation starts fresh.
    image['layers'] = []
    image['composite'] = image['background']
    return mask_list_video, image, mask_list_video, mask_raw_list_video
@spaces.GPU(duration=120)
def describe(image, mode, query, masks):
    """
    Generate descriptions or answer questions about regions in an image.

    In "Caption" mode a single region is described: the one drawn in the
    editor (segmented with ``apply_sam``) or, when nothing is drawn, the
    only mask in ``masks``. In "QA" mode every mask in ``masks`` is
    referenced in the prompt, followed by ``query``.

    Args:
        image: Dictionary containing the image editor state
        mode: Either "Caption" or "QA" mode
        query: Question to ask about the image (used in QA mode)
        masks: List of mask arrays for the regions

    Returns:
        Generator. The first yield is (image with contours, "", cleared
        editor state); every later yield is (gr.update(), text generated so
        far, gr.update()).

    Raises:
        gr.Error: if the editor does not hold exactly one drawing layer, or
            no region is available for the chosen mode.
    """
    image['image'] = image['background'].convert('RGB')
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O` and surface to the user as an opaque AssertionError.
    if len(image['layers']) != 1:
        raise gr.Error(f"Expected 1 layer, got {len(image['layers'])}")
    img_np = np.asarray(image['image']).astype(float) / 255.

    if mode == 'Caption':
        # Every pixel with non-zero alpha in the drawing layer counts as annotated.
        alpha = np.asarray(image['layers'][0])[..., 3]
        drawn = Image.fromarray((alpha > 0).astype(np.uint8) * 255).convert('RGB')
        points = extract_points_from_mask(drawn)
        if points.shape[0] > 0:
            # Randomly sample up to 8 points from the drawing as SAM prompts.
            # Follow DAM https://github.com/NVlabs/describe-anything
            # RandomState(0) reproduces `np.random.seed(0); np.random.choice`
            # without reseeding the process-wide RNG.
            rng = np.random.RandomState(0)
            points_selected_indices = rng.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
            coords = [points[points_selected_indices].tolist()]
            mask_np = apply_sam(image['image'], coords)
        elif len(masks) == 1:
            # Nothing drawn: caption the single previously generated mask.
            mask_np = masks[0]
        else:
            # Nothing drawn and zero or several stored masks: there is no
            # unambiguous region to caption.
            raise gr.Error("No points selected")
        masks = [mask_np]
        mask_ids = [0]
        img_with_contour_np = add_contour(img_np, mask_np, color=color_rgb)
    else:
        if len(masks) == 0:
            raise gr.Error("No regions to ask about. Generate at least one mask first.")
        # QA mode shows the plain image, with no contours drawn.
        img_with_contour_np = img_np
        # One id per mask; all 0, presumably the index of the (single) image
        # frame each mask belongs to — TODO confirm against get_model_output.
        mask_ids = [0] * len(masks)
    img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))

    masks = torch.from_numpy(np.stack(masks, axis=0)).to(torch.uint8)
    img = np.asarray(image['image'])

    if mode == "Caption":
        query = '<image>\nPlease describe the <region> in the image in detail.'
    else:
        # Prompt wording is model input and is kept verbatim.
        if len(masks) == 1:
            prefix = "<image>\nThere is 1 region in the image: <region0> <region>. "
        else:
            prefix = f"<image>\nThere is {len(masks)} region in the image: "
            prefix += ", ".join(f"<region{i}><region>" for i in range(len(masks))) + ". "
        query = prefix + query

    # Clear the user's drawing before handing the editor state back.
    image['layers'] = []
    image['composite'] = image['background']
    text = ""
    yield img_with_contour_pil, text, image
    for token in get_model_output(
        [img],
        query,
        model=model,
        tokenizer=tokenizer,
        masks=masks,
        mask_ids=mask_ids,
        modal='image',
        image_downsampling=1,
        streaming=True,
    ):
        text += token
        yield gr.update(), text, gr.update()
def load_first_frame(video_path):
    """
    Load and return the first frame of a video.

    Args:
        video_path: Path to the video file

    Returns:
        PIL Image of the first frame, converted from OpenCV's BGR to RGB

    Raises:
        gr.Error: if no frame can be read from ``video_path``.
    """
    reader = cv2.VideoCapture(video_path)
    got_frame, bgr = reader.read()
    reader.release()
    if got_frame:
        # OpenCV yields BGR channel order; convert before handing to PIL.
        return Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    raise gr.Error("Could not read the video file.")
@spaces.GPU(duration=120)
def describe_video(video_path, mode, query, annotated_frame, masks, mask_list_video):
"""
Generate descriptions or answer questions about regions in a video.
Args:
video_path: Path to the video file
mode: Either "Caption" or "QA" mode
query: Question to ask about the video (used in QA mode)
annotated_frame: Dictionary containing the annotated first frame
masks: List of mask arrays for the regions
mask_list_video: List of mask images with labels
Returns:
Generator yielding frame image, generated text, and updated mask lists
"""
# Create a temporary directory to save extracted video frames
cap = cv2.VideoCapture(video_path)
video_tensor = load_video(video_path, fps=4, max_frames=768, frame_ids=[0])
annotated_frame['image'] = annotated_frame['background'].convert('RGB')
# Process the annotated frame from the image editor
if isinstance(annotated_frame, dict):
# Get the composite image with annotations
frame_img = annotated_frame.get("image", annotated_frame.get("background"))
if frame_img is None:
raise gr.Error("No valid annotation found in the image editor.")
frame_img = frame_img.convert("RGB")
# Get the annotation layer
if "layers" in annotated_frame and len(annotated_frame["layers"]) > 0:
mask = Image.fromarray((np.asarray(annotated_frame["layers"][0])[..., 3] > 0).astype(np.uint8) * 255).convert("RGB")
else:
mask = Image.new("RGB", frame_img.size, 0)
else:
frame_img = annotated_frame.convert("RGB")
mask = Image.new("RGB", frame_img.size, 0)
img_np = np.asarray(annotated_frame['image']).astype(float) / 255.
# Extract points from the annotated mask (using the first channel)
if mode == "Caption":
points = extract_points_from_mask(mask)
np.random.seed(0)
if points.shape[0] == 0:
raise gr.Error("No points were selected in the annotation.")
# Randomly select up to 8 points
# Follow DAM https://github.com/NVlabs/describe-anything
points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
points = points[points_selected_indices]
# print(f"Selected points (to SAM): {points}")
coords = [points.tolist()]
mask_np = apply_sam(annotated_frame['image'], coords)
masks = []
masks.append(mask_np)
mask_ids = [0]
# img_with_contour_np = add_contour(img_np, mask_np, color=color_rgb)
# img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
else:
img_with_contour_np = img_np.copy()
mask_ids = []
for i, mask_np in enumerate(masks):
# img_with_contour_np = add_contour(img_with_contour_np, mask_np, color=color_rgbs[i])
# img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
mask_ids.append(0)
masks = np.stack(masks, axis=0)
masks = torch.from_numpy(masks).to(torch.uint8)
if mode == "Caption":
query = '<video>\nPlease describe the <region> in the video in detail.'
else:
if len(masks)==1:
prefix = "<video>\nThere is 1 object in the video: <object0> <region>. "
else:
prefix = f"<video>\nThere is {len(masks)} objects in the video: "
for i in range(len(masks)):
prefix += f"<object{i}><region>, "
prefix = prefix[:-2]+'. '
query = prefix + query
# Initialize empty text
# text = description_generator
annotated_frame['layers'] = []
annotated_frame['composite'] = annotated_frame['background']
if mode=="Caption":
mask_list_video = []
mask_image = Image.fromarray((mask_np[:,:,np.newaxis] * np.array(annotated_frame['image'])).astype(np.uint8))
mask_list_video.append((mask_image, f"<object{len(mask_list_video)}>"))
text = ""
yield frame_img, text, mask_list_video, mask_list_video