import spaces
import gradio as gr
import numpy as np
import torch
from transformers import SamModel, SamProcessor
from PIL import Image
import os
import cv2
import argparse
import sys

# Make the local videollama3 package importable.
sys.path.append('./')
# disable_torch_init skips PyTorch's default weight initialization; it only speeds up
# model construction and has no effect on the result, since the weights are loaded from a checkpoint.
from videollama3 import disable_torch_init, model_init, mm_infer, get_model_output
from videollama3.mm_utils import load_images, load_video
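
# NOTE: `model` and `tokenizer` are consumed by get_model_output() in describe() and
# describe_video() below but are not created in this excerpt. Presumably the Space
# initializes them elsewhere via videollama3's model_init(), roughly:
#
#     disable_torch_init()
#     model, processor, tokenizer = model_init(MODEL_PATH)  # MODEL_PATH is a placeholder here
#
# The exact call and return signature should be checked against the videollama3
# package bundled with this Space; the snippet above is an assumption.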

# Contour colors in normalized RGB (floats in [0, 1]), as expected by add_contour().
color_rgb = (1.0, 1.0, 1.0)
color_rgbs = [
    (1.0, 1.0, 1.0),
    (1.0, 0.0, 0.0),
    (0.0, 1.0, 1.0),
    (0.0, 1.0, 0.0),
    (0.0, 0.0, 1.0),
    (1.0, 0.0, 1.0),
]
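
# ------------------------------------------------------------------------------
# `apply_sam` is called by the mask-generation callbacks below but is not defined
# in this excerpt. The implementation here is a minimal sketch, assuming the
# Hugging Face transformers SAM API (SamModel / SamProcessor) and following the
# DAM demo (https://github.com/NVlabs/describe-anything); the checkpoint name and
# device handling are assumptions, not necessarily what this Space actually uses.
sam_device = "cuda" if torch.cuda.is_available() else "cpu"
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(sam_device).eval()
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

def apply_sam(image, input_points):
    """Segment the region indicated by `input_points` (a list containing one list of
    [x, y] prompts) and return the highest-scoring mask as an HxW numpy array."""
    inputs = sam_processor(image, input_points=input_points, return_tensors="pt").to(sam_device)
    with torch.no_grad():
        outputs = sam_model(**inputs)
    # Upscale the predicted low-resolution masks back to the original image size.
    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
    )[0][0]
    scores = outputs.iou_scores[0, 0]
    # Keep the mask with the best predicted IoU.
    return masks[scores.argmax()].numpy()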

def extract_first_frame_from_video(video):
    """Return the first frame of `video` as a PIL Image, or None if it cannot be read."""
    cap = cv2.VideoCapture(video)
    success, frame = cap.read()
    cap.release()
    if success:
        return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    return None

def extract_points_from_mask(mask_pil):
    """Return the (x, y) coordinates of all non-zero pixels in the mask."""
    mask = np.asarray(mask_pil)[..., 0]
    coords = np.nonzero(mask)
    # np.nonzero returns (rows, cols); swap to (x, y) point order.
    coords = np.stack((coords[1], coords[0]), axis=1)
    return coords

def add_contour(img, mask, color=(1., 1., 1.)):
    """Draw the outline of `mask` onto a copy of `img` (float RGB in [0, 1]) and return it."""
    img = img.copy()
    mask = mask.astype(np.uint8) * 255
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(img, contours, -1, color, thickness=8)
    return img

def generate_masks(image, mask_list, mask_raw_list):
    """
    Generate masks from user-drawn annotations on an image.

    Args:
        image: Dictionary containing the image editor state with background and layers
        mask_list: List of generated mask images with labels
        mask_raw_list: List of raw numpy arrays of masks

    Returns:
        Tuple containing the updated mask gallery, image editor state, mask_list, and mask_raw_list
    """
    image['image'] = image['background'].convert('RGB')
    # del image['background'], image['composite']
    assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"

    # The editor layer is RGBA; any painted pixel has alpha > 0.
    mask = Image.fromarray((np.asarray(image['layers'][0])[..., 3] > 0).astype(np.uint8) * 255).convert('RGB')
    points = extract_points_from_mask(mask)
    np.random.seed(0)
    if points.shape[0] == 0:
        raise gr.Error("No points selected")
    # Randomly sample up to 8 points from the drawn region as SAM prompts.
    points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
    points = points[points_selected_indices]
    coords = [points.tolist()]
    mask_np = apply_sam(image['image'], coords)
    mask_raw_list.append(mask_np)

    # Store the masked image together with its region label.
    mask_image = Image.fromarray((mask_np[:, :, np.newaxis] * np.array(image['image'])).astype(np.uint8))
    mask_list.append((mask_image, f"<region{len(mask_list)}>"))

    # Clear the annotation layer so the editor is ready for the next region.
    image['layers'] = []
    image['composite'] = image['background']
    return mask_list, image, mask_list, mask_raw_list
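
# The Gradio UI is constructed later in this Space (not shown in this excerpt).
# A typical hookup for the callback above would look roughly like the following;
# the component names (generate_btn, image_editor, mask_gallery, and the two
# gr.State objects) are hypothetical:
#
#     generate_btn.click(
#         generate_masks,
#         inputs=[image_editor, mask_list_state, mask_raw_list_state],
#         outputs=[mask_gallery, image_editor, mask_list_state, mask_raw_list_state],
#     )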

def generate_masks_video(image, mask_list_video, mask_raw_list_video):
    """
    Generate masks from user-drawn annotations on a video frame.

    Args:
        image: Dictionary containing the image editor state with background and layers
        mask_list_video: List of generated mask images with labels for video
        mask_raw_list_video: List of raw numpy arrays of masks for video

    Returns:
        Tuple containing the updated mask gallery, image editor state, mask_list_video, and mask_raw_list_video
    """
    image['image'] = image['background'].convert('RGB')
    # del image['background'], image['composite']
    assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"

    # The editor layer is RGBA; any painted pixel has alpha > 0.
    mask = Image.fromarray((np.asarray(image['layers'][0])[..., 3] > 0).astype(np.uint8) * 255).convert('RGB')
    points = extract_points_from_mask(mask)
    np.random.seed(0)
    if points.shape[0] == 0:
        raise gr.Error("No points selected")
    # Randomly sample up to 8 points from the drawn region as SAM prompts.
    points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
    points = points[points_selected_indices]
    coords = [points.tolist()]
    mask_np = apply_sam(image['image'], coords)
    mask_raw_list_video.append(mask_np)

    # Store the masked frame together with its object label.
    mask_image = Image.fromarray((mask_np[:, :, np.newaxis] * np.array(image['image'])).astype(np.uint8))
    mask_list_video.append((mask_image, f"<object{len(mask_list_video)}>"))

    # Clear the annotation layer so the editor is ready for the next object.
    image['layers'] = []
    image['composite'] = image['background']
    return mask_list_video, image, mask_list_video, mask_raw_list_video

def describe(image, mode, query, masks):
    """
    Generate descriptions or answer questions about regions in an image.

    Args:
        image: Dictionary containing the image editor state
        mode: Either "Caption" or "QA" mode
        query: Question to ask about the image (used in QA mode)
        masks: List of mask arrays for the regions

    Returns:
        Generator yielding the image with contours, the generated text, and the updated image state
    """
    # Create an RGB image from the uploaded background.
    # print(image.keys())
    image['image'] = image['background'].convert('RGB')
    # del image['background'], image['composite']
    assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"

    # Normalize the image to float RGB in [0, 1].
    img_np = np.asarray(image['image']).astype(float) / 255.

    if mode == 'Caption':
        mask = Image.fromarray((np.asarray(image['layers'][0])[..., 3] > 0).astype(np.uint8) * 255).convert('RGB')
        points = extract_points_from_mask(mask)
        np.random.seed(0)
        if points.shape[0] == 0:
            raise gr.Error("No points selected")
        # Randomly sample up to 8 points from the mask
        # Follow DAM https://github.com/NVlabs/describe-anything
        points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
        points = points[points_selected_indices]
        coords = [points.tolist()]
        mask_np = apply_sam(image['image'], coords)
        masks = []
        masks.append(mask_np)
        mask_ids = [0]
        img_with_contour_np = add_contour(img_np, mask_np, color=color_rgb)
        img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
    else:
        img_with_contour_np = img_np.copy()
        mask_ids = []
        for i, mask_np in enumerate(masks):
            # img_with_contour_np = add_contour(img_with_contour_np, mask_np, color=color_rgbs[i])
            mask_ids.append(0)
        img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))

    # Stack per-region masks into an (N, H, W) uint8 tensor for the model.
    masks = np.stack(masks, axis=0)
    masks = torch.from_numpy(masks).to(torch.uint8)

    img = np.asarray(image['image'])
if mode == "Caption": | |
query = '<image>\nPlease describe the <region> in the image in detail.' | |
else: | |
if len(masks)==1: | |
prefix = "<image>\nThere is 1 region in the image: <region0> <region>. " | |
else: | |
prefix = f"<image>\nThere is {len(masks)} region in the image: " | |
for i in range(len(masks)): | |
prefix += f"<region{i}><region>, " | |
prefix = prefix[:-2]+'. ' | |
query = prefix + query | |
    # print(query)

    # Reset the editor state so the annotation layer is cleared in the UI.
    image['layers'] = []
    image['composite'] = image['background']

    text = ""
    yield img_with_contour_pil, text, image

    # Stream tokens from the model; gr.update() leaves the other outputs unchanged.
    for token in get_model_output(
        [img],
        query,
        model=model,
        tokenizer=tokenizer,
        masks=masks,
        mask_ids=mask_ids,
        modal='image',
        image_downsampling=1,
        streaming=True,
    ):
        text += token
        yield gr.update(), text, gr.update()
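
# Because describe() is a generator, Gradio streams each yielded (image, text, state)
# triple to the UI as it is produced. A typical hookup (component names are
# hypothetical; the real wiring is defined later in this Space) would be:
#
#     describe_btn.click(
#         describe,
#         inputs=[image_editor, mode_radio, query_box, mask_raw_list_state],
#         outputs=[output_image, output_text, image_editor],
#     )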

def load_first_frame(video_path):
    """
    Load and return the first frame of a video.

    Args:
        video_path: Path to the video file

    Returns:
        PIL Image of the first frame
    """
    cap = cv2.VideoCapture(video_path)
    ret, frame = cap.read()
    cap.release()
    if not ret:
        raise gr.Error("Could not read the video file.")
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    image = Image.fromarray(frame)
    return image

def describe_video(video_path, mode, query, annotated_frame, masks, mask_list_video):
    """
    Generate descriptions or answer questions about regions in a video.

    Args:
        video_path: Path to the video file
        mode: Either "Caption" or "QA" mode
        query: Question to ask about the video (used in QA mode)
        annotated_frame: Dictionary containing the annotated first frame
        masks: List of mask arrays for the regions
        mask_list_video: List of mask images with labels

    Returns:
        Generator yielding the frame image, the generated text, and the updated mask lists
    """
    # Load the video frames that will be passed to the model.
    cap = cv2.VideoCapture(video_path)
    video_tensor = load_video(video_path, fps=4, max_frames=768, frame_ids=[0])

    annotated_frame['image'] = annotated_frame['background'].convert('RGB')

    # Process the annotated frame from the image editor
    if isinstance(annotated_frame, dict):
        # Get the composite image with annotations
        frame_img = annotated_frame.get("image", annotated_frame.get("background"))
        if frame_img is None:
            raise gr.Error("No valid annotation found in the image editor.")
        frame_img = frame_img.convert("RGB")
        # Get the annotation layer
        if "layers" in annotated_frame and len(annotated_frame["layers"]) > 0:
            mask = Image.fromarray((np.asarray(annotated_frame["layers"][0])[..., 3] > 0).astype(np.uint8) * 255).convert("RGB")
        else:
            mask = Image.new("RGB", frame_img.size, 0)
    else:
        frame_img = annotated_frame.convert("RGB")
        mask = Image.new("RGB", frame_img.size, 0)

    img_np = np.asarray(annotated_frame['image']).astype(float) / 255.

    # Extract points from the annotated mask (using the first channel)
    if mode == "Caption":
        points = extract_points_from_mask(mask)
        np.random.seed(0)
        if points.shape[0] == 0:
            raise gr.Error("No points were selected in the annotation.")
        # Randomly select up to 8 points
        # Follow DAM https://github.com/NVlabs/describe-anything
        points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
        points = points[points_selected_indices]
        # print(f"Selected points (to SAM): {points}")
        coords = [points.tolist()]
        mask_np = apply_sam(annotated_frame['image'], coords)
        masks = []
        masks.append(mask_np)
        mask_ids = [0]
        # img_with_contour_np = add_contour(img_np, mask_np, color=color_rgb)
        # img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
    else:
        img_with_contour_np = img_np.copy()
        mask_ids = []
        for i, mask_np in enumerate(masks):
            # img_with_contour_np = add_contour(img_with_contour_np, mask_np, color=color_rgbs[i])
            # img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
            mask_ids.append(0)

    masks = np.stack(masks, axis=0)
    masks = torch.from_numpy(masks).to(torch.uint8)
if mode == "Caption": | |
query = '<video>\nPlease describe the <region> in the video in detail.' | |
else: | |
if len(masks)==1: | |
prefix = "<video>\nThere is 1 object in the video: <object0> <region>. " | |
else: | |
prefix = f"<video>\nThere is {len(masks)} objects in the video: " | |
for i in range(len(masks)): | |
prefix += f"<object{i}><region>, " | |
prefix = prefix[:-2]+'. ' | |
query = prefix + query | |
# Initialize empty text | |
# text = description_generator | |
annotated_frame['layers'] = [] | |
annotated_frame['composite'] = annotated_frame['background'] | |
if mode=="Caption": | |
mask_list_video = [] | |
mask_image = Image.fromarray((mask_np[:,:,np.newaxis] * np.array(annotated_frame['image'])).astype(np.uint8)) | |
mask_list_video.append((mask_image, f"<object{len(mask_list_video)}>")) | |
text = "" | |
yield frame_img, text, mask_list_video, mask_list_video |