import spaces
import gradio as gr
import numpy as np
import torch
from transformers import SamModel, SamProcessor
from PIL import Image
import os
import cv2
import argparse
import sys

# This is for making model initialization faster and has no effect since we are loading the weights
sys.path.append('./')
from videollama3 import disable_torch_init, model_init, mm_infer, get_model_output
from videollama3.mm_utils import load_images, load_video

color_rgb = (1.0, 1.0, 1.0)
color_rgbs = [
    (1.0, 1.0, 1.0),
    (1.0, 0.0, 0.0),
    (0.0, 1.0, 1.0),
    (0.0, 1.0, 0.0),
    (0.0, 0.0, 1.0),
    (1.0, 0.0, 1.0),
]


def extract_first_frame_from_video(video):
    cap = cv2.VideoCapture(video)
    success, frame = cap.read()
    cap.release()
    if success:
        return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    return None


def extract_points_from_mask(mask_pil):
    mask = np.asarray(mask_pil)[..., 0]
    coords = np.nonzero(mask)
    # np.nonzero returns (rows, cols); stack them as (x, y) point pairs for SAM prompts
    coords = np.stack((coords[1], coords[0]), axis=1)
    return coords


def add_contour(img, mask, color=(1., 1., 1.)):
    img = img.copy()
    mask = mask.astype(np.uint8) * 255
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(img, contours, -1, color, thickness=8)
    return img


@spaces.GPU(duration=120)
def generate_masks(image, mask_list, mask_raw_list):
    image['image'] = image['background'].convert('RGB')
    # del image['background'], image['composite']
    assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"

    mask = Image.fromarray((np.asarray(image['layers'][0])[..., 3] > 0).astype(np.uint8) * 255).convert('RGB')
    points = extract_points_from_mask(mask)

    np.random.seed(0)

    if points.shape[0] == 0:
        raise gr.Error("No points selected")

    # Sample up to 8 of the drawn points and let SAM turn them into a region mask
    points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
    points = points[points_selected_indices]
    coords = [points.tolist()]
    mask_np = apply_sam(image['image'], coords)

    mask_raw_list.append(mask_np)
    mask_image = Image.fromarray((mask_np[:, :, np.newaxis] * np.array(image['image'])).astype(np.uint8))

    mask_list.append((mask_image, f""))
    # Return a list containing the mask image.
    image['layers'] = []
    image['composite'] = image['background']
    return mask_list, image, mask_list, mask_raw_list
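# `apply_sam` is called by the handlers in this file but defined elsewhere in the
# original source. The version below is only a minimal sketch of what it is assumed
# to do, using the SamModel/SamProcessor imported above: prompt SAM with the sampled
# click points and return one binary (H, W) mask as a numpy array. The
# "facebook/sam-vit-huge" checkpoint, the device handling, and the best-mask
# selection are assumptions, not the original implementation.
sam_device = "cuda" if torch.cuda.is_available() else "cpu"
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(sam_device).eval()
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")


def apply_sam(image, input_points):
    # `input_points` is [[[x, y], ...]]: one group of clicked points for one image,
    # matching the `coords = [points.tolist()]` built by the callers.
    inputs = sam_processor(image, input_points=input_points, return_tensors="pt").to(sam_device)
    with torch.no_grad():
        outputs = sam_model(**inputs)
    # Upscale the low-resolution predictions back to the original image size.
    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
    )[0]
    scores = outputs.iou_scores.cpu()
    # Keep the highest-scoring of the candidate masks for the prompted region.
    best = scores[0, 0].argmax().item()
    return masks[0, best].numpy().astype(np.uint8)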
@spaces.GPU(duration=120)
def generate_masks_video(image, mask_list_video, mask_raw_list_video):
    image['image'] = image['background'].convert('RGB')
    # del image['background'], image['composite']
    assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"

    mask = Image.fromarray((np.asarray(image['layers'][0])[..., 3] > 0).astype(np.uint8) * 255).convert('RGB')
    points = extract_points_from_mask(mask)

    np.random.seed(0)

    if points.shape[0] == 0:
        raise gr.Error("No points selected")

    # Sample up to 8 of the drawn points and let SAM turn them into a region mask
    points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
    points = points[points_selected_indices]
    coords = [points.tolist()]
    mask_np = apply_sam(image['image'], coords)

    mask_raw_list_video.append(mask_np)
    mask_image = Image.fromarray((mask_np[:, :, np.newaxis] * np.array(image['image'])).astype(np.uint8))

    mask_list_video.append((mask_image, f""))
    # Return a list containing the mask image.
    image['layers'] = []
    image['composite'] = image['background']
    return mask_list_video, image, mask_list_video, mask_raw_list_video


@spaces.GPU(duration=120)
def describe(image, mode, query, masks):
    # Create an image object from the uploaded image
    # print(image.keys())
    image['image'] = image['background'].convert('RGB')
    # del image['background'], image['composite']
    assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"

    # Handle both hex and rgba color formats
    img_np = np.asarray(image['image']).astype(float) / 255.

    if mode == 'Caption':
        mask = Image.fromarray((np.asarray(image['layers'][0])[..., 3] > 0).astype(np.uint8) * 255).convert('RGB')
        points = extract_points_from_mask(mask)

        np.random.seed(0)

        if points.shape[0] == 0:
            if len(masks) > 1:
                raise gr.Error("No points selected")
        else:
            # Randomly sample 8 points from the mask
            # Follow DAM https://github.com/NVlabs/describe-anything
            points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
            points = points[points_selected_indices]

        coords = [points.tolist()]
        mask_np = apply_sam(image['image'], coords)
        masks = []
        masks.append(mask_np)
        mask_ids = [0]

        img_with_contour_np = add_contour(img_np, mask_np, color=color_rgb)
        img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
    else:
        img_with_contour_np = img_np.copy()
        mask_ids = []
        for i, mask_np in enumerate(masks):
            # img_with_contour_np = add_contour(img_with_contour_np, mask_np, color=color_rgbs[i])
            # img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
            img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
            mask_ids.append(0)

    masks = np.stack(masks, axis=0)
    masks = torch.from_numpy(masks).to(torch.uint8)

    img = np.asarray(image['image'])

    if mode == "Caption":
        query = '\nPlease describe the in the image in detail.'
    else:
        if len(masks) == 1:
            prefix = "\nThere is 1 region in the image: . "
        else:
            prefix = f"\nThere is {len(masks)} region in the image: "
            for i in range(len(masks)):
                prefix += f", "
            prefix = prefix[:-2] + '. '
        query = prefix + query

    # print(query)
    image['layers'] = []
    image['composite'] = image['background']

    text = ""
    yield img_with_contour_pil, text, image

    for token in get_model_output(
        [img],
        query,
        model=model,
        tokenizer=tokenizer,
        masks=masks,
        mask_ids=mask_ids,
        modal='image',
        image_downsampling=1,
        streaming=True,
    ):
        text += token
        yield gr.update(), text, gr.update()


def load_first_frame(video_path):
    cap = cv2.VideoCapture(video_path)
    ret, frame = cap.read()
    cap.release()
    if not ret:
        raise gr.Error("Could not read the video file.")
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    image = Image.fromarray(frame)
    return image


@spaces.GPU(duration=120)
def describe_video(video_path, mode, query, annotated_frame, masks, mask_list_video):
    # Create a temporary directory to save extracted video frames
    cap = cv2.VideoCapture(video_path)

    video_tensor = load_video(video_path, fps=4, max_frames=768, frame_ids=[0])

    annotated_frame['image'] = annotated_frame['background'].convert('RGB')

    # Process the annotated frame from the image editor
    if isinstance(annotated_frame, dict):
        # Get the composite image with annotations
        frame_img = annotated_frame.get("image", annotated_frame.get("background"))
        if frame_img is None:
            raise gr.Error("No valid annotation found in the image editor.")
        frame_img = frame_img.convert("RGB")
        # Get the annotation layer
        if "layers" in annotated_frame and len(annotated_frame["layers"]) > 0:
            mask = Image.fromarray((np.asarray(annotated_frame["layers"][0])[..., 3] > 0).astype(np.uint8) * 255).convert("RGB")
        else:
            mask = Image.new("RGB", frame_img.size, 0)
    else:
        frame_img = annotated_frame.convert("RGB")
        mask = Image.new("RGB", frame_img.size, 0)

    img_np = np.asarray(annotated_frame['image']).astype(float) / 255.

    # Extract points from the annotated mask (using the first channel)
    if mode == "Caption":
        points = extract_points_from_mask(mask)

        np.random.seed(0)

        if points.shape[0] == 0:
            raise gr.Error("No points were selected in the annotation.")

        # Randomly select up to 8 points
        # Follow DAM https://github.com/NVlabs/describe-anything
        points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
        points = points[points_selected_indices]
        # print(f"Selected points (to SAM): {points}")

        coords = [points.tolist()]
        mask_np = apply_sam(annotated_frame['image'], coords)
        masks = []
        masks.append(mask_np)
        mask_ids = [0]
        # img_with_contour_np = add_contour(img_np, mask_np, color=color_rgb)
        # img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
    else:
        img_with_contour_np = img_np.copy()
        mask_ids = []
        for i, mask_np in enumerate(masks):
            # img_with_contour_np = add_contour(img_with_contour_np, mask_np, color=color_rgbs[i])
            # img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
            mask_ids.append(0)

    masks = np.stack(masks, axis=0)
    masks = torch.from_numpy(masks).to(torch.uint8)

    if mode == "Caption":
        query = '