import spaces
import gradio as gr
import numpy as np
import torch
from transformers import SamModel, SamProcessor
from PIL import Image
import os
import cv2
import argparse
import sys

sys.path.append('./')
# disable_torch_init only speeds up model initialization; it has no effect on
# outputs, since the pretrained weights are loaded afterwards.
from videollama3 import disable_torch_init, model_init, mm_infer, get_model_output
from videollama3.mm_utils import load_images
from videollama3.mm_utils import load_video

color_rgb = (1.0, 1.0, 1.0)
color_rgbs = [
    (1.0, 1.0, 1.0),
    (1.0, 0.0, 0.0),
    (0.0, 1.0, 1.0),
    (0.0, 1.0, 0.0),
    (0.0, 0.0, 1.0),
    (1.0, 0.0, 1.0),
]


def extract_first_frame_from_video(video):
    cap = cv2.VideoCapture(video)
    success, frame = cap.read()
    cap.release()
    if success:
        return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    return None


def extract_points_from_mask(mask_pil):
    # np.nonzero returns (rows, cols); stack as (x, y) pixel coordinates.
    mask = np.asarray(mask_pil)[..., 0]
    coords = np.nonzero(mask)
    coords = np.stack((coords[1], coords[0]), axis=1)
    return coords


def add_contour(img, mask, color=(1., 1., 1.)):
    img = img.copy()
    mask = mask.astype(np.uint8) * 255
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(img, contours, -1, color, thickness=8)
    return img
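# `apply_sam` is called by the handlers below but its definition is not shown
# in this section. A minimal sketch of the usual point-prompted SAM pattern,
# assuming module-level `sam_model` / `sam_processor` variables (hypothetical
# names, as is the checkpoint id):
#
#   sam_model = SamModel.from_pretrained("facebook/sam-vit-huge")
#   sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
#
#   def apply_sam(image, input_points):
#       # Prompt SAM with the sampled foreground points and keep the
#       # highest-scoring of its candidate masks.
#       inputs = sam_processor(image, input_points=input_points, return_tensors="pt").to(sam_model.device)
#       with torch.no_grad():
#           outputs = sam_model(**inputs)
#       masks = sam_processor.image_processor.post_process_masks(
#           outputs.pred_masks.cpu(),
#           inputs["original_sizes"].cpu(),
#           inputs["reshaped_input_sizes"].cpu(),
#       )[0][0]
#       best = outputs.iou_scores[0, 0].argmax()
#       return masks[best].numpy()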
@spaces.GPU(duration=120)
def generate_masks(image, mask_list, mask_raw_list):
    """
    Generates segmentation masks for selected regions in an image using SAM.

    Args:
        image (dict): A dictionary containing image data, typically from a Gradio
            ImageEditor, with 'background' (PIL Image) and 'layers' (list of PIL
            Image layers).
        mask_list (list): A list to accumulate (mask_image, label) tuples for
            display in a gallery.
        mask_raw_list (list): A list to accumulate raw NumPy mask arrays.

    Returns:
        tuple: A tuple containing:
            - mask_list (list): Updated list of mask images for display.
            - image (dict): Updated image dictionary with layers cleared.
            - mask_list (list): Redundant return of mask_list (for Gradio update).
            - mask_raw_list (list): Updated list of raw mask arrays.
    """
    image['image'] = image['background'].convert('RGB')
    # del image['background'], image['composite']
    assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"

    mask = Image.fromarray((np.asarray(image['layers'][0])[..., 3] > 0).astype(np.uint8) * 255).convert('RGB')
    points = extract_points_from_mask(mask)
    np.random.seed(0)
    if points.shape[0] == 0:
        raise gr.Error("No points selected")
    points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
    points = points[points_selected_indices]

    coords = [points.tolist()]
    mask_np = apply_sam(image['image'], coords)
    mask_raw_list.append(mask_np)
    mask_image = Image.fromarray((mask_np[:, :, np.newaxis] * np.array(image['image'])).astype(np.uint8))
    mask_list.append((mask_image, ""))  # Return a list containing the mask image.

    image['layers'] = []
    image['composite'] = image['background']
    return mask_list, image, mask_list, mask_raw_list


@spaces.GPU(duration=120)
def generate_masks_video(image, mask_list_video, mask_raw_list_video):
    """
    Generates segmentation masks for selected regions in the first frame of a
    video using SAM.

    Args:
        image (dict): A dictionary containing image data (first frame of video),
            typically from a Gradio ImageEditor, with 'background' (PIL Image)
            and 'layers' (list of PIL Image layers).
        mask_list_video (list): A list to accumulate (mask_image, label) tuples
            for display.
        mask_raw_list_video (list): A list to accumulate raw NumPy mask arrays
            for video processing.

    Returns:
        tuple: A tuple containing:
            - mask_list_video (list): Updated list of mask images for display.
            - image (dict): Updated image dictionary with layers cleared.
            - mask_list_video (list): Redundant return of mask_list_video (for Gradio update).
            - mask_raw_list_video (list): Updated list of raw mask arrays.
    """
    image['image'] = image['background'].convert('RGB')
    # del image['background'], image['composite']
    assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"

    mask = Image.fromarray((np.asarray(image['layers'][0])[..., 3] > 0).astype(np.uint8) * 255).convert('RGB')
    points = extract_points_from_mask(mask)
    np.random.seed(0)
    if points.shape[0] == 0:
        raise gr.Error("No points selected")
    points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
    points = points[points_selected_indices]

    coords = [points.tolist()]
    mask_np = apply_sam(image['image'], coords)
    mask_raw_list_video.append(mask_np)
    mask_image = Image.fromarray((mask_np[:, :, np.newaxis] * np.array(image['image'])).astype(np.uint8))
    mask_list_video.append((mask_image, ""))  # Return a list containing the mask image.

    image['layers'] = []
    image['composite'] = image['background']
    return mask_list_video, image, mask_list_video, mask_raw_list_video
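# A minimal sketch of how the two handlers above are typically wired into a
# Blocks UI. The component names here are assumptions for illustration, not
# part of this file:
#
#   with gr.Blocks() as demo:
#       image_editor = gr.ImageEditor(type="pil")
#       mask_gallery = gr.Gallery(label="Masks")
#       mask_list = gr.State([])
#       mask_raw_list = gr.State([])
#       gr.Button("Generate Mask").click(
#           generate_masks,
#           inputs=[image_editor, mask_list, mask_raw_list],
#           outputs=[mask_gallery, image_editor, mask_list, mask_raw_list],
#       )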
@spaces.GPU(duration=120)
def describe(image, mode, query, masks):
    """
    Describes an image based on selected regions or answers a question about them.

    Args:
        image (dict): A dictionary containing image data, typically from a Gradio
            ImageEditor, with 'background' (PIL Image) and 'layers' (list of PIL
            Image layers).
        mode (str): The operational mode, either "Caption" (to describe a selected
            region) or "QA" (to answer a question about one or more regions).
        query (str): The question to ask in "QA" mode. Ignored in "Caption" mode.
        masks (list): A list of raw NumPy mask arrays representing previously
            generated masks.

    Yields:
        tuple: An image with contours and the generated text description/answer,
            or updates for Gradio components during streaming.
    """
    # Create an image object from the uploaded image
    # print(image.keys())
    image['image'] = image['background'].convert('RGB')
    # del image['background'], image['composite']
    assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"

    img_np = np.asarray(image['image']).astype(float) / 255.

    if mode == 'Caption':
        mask = Image.fromarray((np.asarray(image['layers'][0])[..., 3] > 0).astype(np.uint8) * 255).convert('RGB')
        points = extract_points_from_mask(mask)
        np.random.seed(0)
        if points.shape[0] == 0:
            if len(masks) > 1:
                raise gr.Error("No points selected")
        else:
            # Randomly sample 8 points from the mask
            # Follow DAM https://github.com/NVlabs/describe-anything
            points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
            points = points[points_selected_indices]
            coords = [points.tolist()]
            mask_np = apply_sam(image['image'], coords)
            masks = []
            masks.append(mask_np)
            mask_ids = [0]
            img_with_contour_np = add_contour(img_np, mask_np, color=color_rgb)
            img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
    else:
        img_with_contour_np = img_np.copy()
        mask_ids = []
        for i, mask_np in enumerate(masks):
            # img_with_contour_np = add_contour(img_with_contour_np, mask_np, color=color_rgbs[i])
            # img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
            img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
            mask_ids.append(0)

    masks = np.stack(masks, axis=0)
    masks = torch.from_numpy(masks).to(torch.uint8)

    img = np.asarray(image['image'])
    if mode == "Caption":
        query = '<image>\nPlease describe the <region> in the image in detail.'
    else:
        if len(masks) == 1:
            prefix = "<image>\nThere is 1 region in the image: <region>. "
        else:
            prefix = f"<image>\nThere are {len(masks)} regions in the image: "
            for i in range(len(masks)):
                prefix += "<region>, "
            prefix = prefix[:-2] + '. '
        query = prefix + query
    # print(query)

    image['layers'] = []
    image['composite'] = image['background']
    text = ""
    yield img_with_contour_pil, text, image

    for token in get_model_output(
        [img],
        query,
        model=model,
        tokenizer=tokenizer,
        masks=masks,
        mask_ids=mask_ids,
        modal='image',
        image_downsampling=1,
        streaming=True,
    ):
        text += token
        yield gr.update(), text, gr.update()


def load_first_frame(video_path):
    """
    Loads the first frame of a given video file.

    Args:
        video_path (str): The file path to the video.

    Returns:
        PIL.Image.Image: The first frame of the video as a PIL Image.

    Raises:
        gr.Error: If the video file cannot be read.
    """
    cap = cv2.VideoCapture(video_path)
    ret, frame = cap.read()
    cap.release()
    if not ret:
        raise gr.Error("Could not read the video file.")
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    image = Image.fromarray(frame)
    return image
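# Because `describe` is a generator, Gradio streams each yielded token to the
# UI when it is registered as an event handler. A minimal hookup sketch
# (component names are assumptions, as above):
#
#   describe_btn.click(
#       describe,
#       inputs=[image_editor, mode_radio, query_box, mask_raw_list],
#       outputs=[annotated_image, answer_box, image_editor],
#   )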
""" # Create a temporary directory to save extracted video frames cap = cv2.VideoCapture(video_path) video_tensor = load_video(video_path, fps=4, max_frames=768, frame_ids=[0]) annotated_frame['image'] = annotated_frame['background'].convert('RGB') # Process the annotated frame from the image editor if isinstance(annotated_frame, dict): # Get the composite image with annotations frame_img = annotated_frame.get("image", annotated_frame.get("background")) if frame_img is None: raise gr.Error("No valid annotation found in the image editor.") frame_img = frame_img.convert("RGB") # Get the annotation layer if "layers" in annotated_frame and len(annotated_frame["layers"]) > 0: mask = Image.fromarray((np.asarray(annotated_frame["layers"][0])[..., 3] > 0).astype(np.uint8) * 255).convert("RGB") else: mask = Image.new("RGB", frame_img.size, 0) else: frame_img = annotated_frame.convert("RGB") mask = Image.new("RGB", frame_img.size, 0) img_np = np.asarray(annotated_frame['image']).astype(float) / 255. # Extract points from the annotated mask (using the first channel) if mode == "Caption": points = extract_points_from_mask(mask) np.random.seed(0) if points.shape[0] == 0: raise gr.Error("No points were selected in the annotation.") # Randomly select up to 8 points # Follow DAM https://github.com/NVlabs/describe-anything points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False) points = points[points_selected_indices] # print(f"Selected points (to SAM): {points}") coords = [points.tolist()] mask_np = apply_sam(annotated_frame['image'], coords) masks = [] masks.append(mask_np) mask_ids = [0] # img_with_contour_np = add_contour(img_np, mask_np, color=color_rgb) # img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8)) else: img_with_contour_np = img_np.copy() mask_ids = [] for i, mask_np in enumerate(masks): # img_with_contour_np = add_contour(img_with_contour_np, mask_np, color=color_rgbs[i]) # img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8)) mask_ids.append(0) masks = np.stack(masks, axis=0) masks = torch.from_numpy(masks).to(torch.uint8) if mode == "Caption": query = '