diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..f2c74859d7e120622a212127ea3cbcb5dfb22caf
Binary files /dev/null and b/.DS_Store differ
diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..a26ede40ac4da2c448114b873bf8524cf9b5d21d 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+demo/videos/ filter=lfs diff=lfs merge=lfs -text
+demo/videos/3.mp4 filter=lfs diff=lfs merge=lfs -text
+demo/videos/4.mp4 filter=lfs diff=lfs merge=lfs -text
+demo/videos/1.mp4 filter=lfs diff=lfs merge=lfs -text
+demo/videos/2.mp4 filter=lfs diff=lfs merge=lfs -text
+demo/images/4.jpg filter=lfs diff=lfs merge=lfs -text
+demo/images/5.jpg filter=lfs diff=lfs merge=lfs -text
+demo/images/6.jpg filter=lfs diff=lfs merge=lfs -text
+demo/images/8.jpg filter=lfs diff=lfs merge=lfs -text
+demo/images/1.jpg filter=lfs diff=lfs merge=lfs -text
+demo/images/2.jpg filter=lfs diff=lfs merge=lfs -text
+demo/images/3.jpg filter=lfs diff=lfs merge=lfs -text
+demo/images/7.jpg filter=lfs diff=lfs merge=lfs -text
+demo/images/LICENSE filter=lfs diff=lfs merge=lfs -text
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ceffc0625f1ea00cb339e68cae55cd95a2e4dac
--- /dev/null
+++ b/app.py
@@ -0,0 +1,562 @@
+import gradio as gr
+import numpy as np
+import torch
+from transformers import SamModel, SamProcessor
+from PIL import Image
+import os
+import cv2
+import argparse
+import sys
+# This is for making model initialization faster and has no effect since we are loading the weights
+sys.path.append('./')
+from videollama3 import disable_torch_init, model_init, mm_infer, get_model_output
+from videollama3.mm_utils import load_images
+from videollama3.mm_utils import load_video
+
+
+color_rgb = (1.0, 1.0, 1.0)
+color_rgbs = [
+    (1.0, 1.0, 1.0),
+    (1.0, 0.0, 0.0),
+    (0.0, 1.0, 1.0),
+    (0.0, 1.0, 0.0),
+    (0.0, 0.0, 1.0),
+    (1.0, 0.0, 1.0),
+    ]
+
+mask_list = []
+mask_raw_list = []
+mask_list_video = []
+mask_raw_list_video = []
+
+def extract_first_frame_from_video(video):
+    cap = cv2.VideoCapture(video)
+    success, frame = cap.read()
+    cap.release()
+    if success:
+        return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+    return None
+
+def extract_points_from_mask(mask_pil):
+    mask = np.asarray(mask_pil)[..., 0]
+    coords = np.nonzero(mask)
+    coords = np.stack((coords[1], coords[0]), axis=1)
+
+    return coords
+
+def add_contour(img, mask, color=(1., 1., 1.)):
+    img = img.copy()
+
+    mask = mask.astype(np.uint8) * 255
+    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+    cv2.drawContours(img, contours, -1, color, thickness=8)
+
+    return img
+
+def generate_masks(image):
+    global mask_list
+    global mask_raw_list
+    image['image'] = image['background'].convert('RGB')
+    # del image['background'], image['composite']
+    assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"
+
+    mask = Image.fromarray((np.asarray(image['layers'][0])[..., 3] > 0).astype(np.uint8) * 255).convert('RGB')
+    points = extract_points_from_mask(mask)
+    np.random.seed(0)
+    if points.shape[0] == 0:
+        raise gr.Error("No points selected")
+
+    points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
+    points = points[points_selected_indices]
+    coords = [points.tolist()]
+    mask_np = apply_sam(image['image'], coords)
+
+    mask_raw_list.append(mask_np)
+    mask_image = Image.fromarray((mask_np[:,:,np.newaxis] * np.array(image['image'])).astype(np.uint8))
+
+    mask_list.append((mask_image, f""))
+    # Return a list containing the mask image.
+    image['layers'] = []
+    image['composite'] = image['background']
+    return mask_list, image
+
+
+def generate_masks_video(image):
+    global mask_list_video
+    global mask_raw_list_video
+    image['image'] = image['background'].convert('RGB')
+    # del image['background'], image['composite']
+    assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"
+
+    mask = Image.fromarray((np.asarray(image['layers'][0])[..., 3] > 0).astype(np.uint8) * 255).convert('RGB')
+    points = extract_points_from_mask(mask)
+    np.random.seed(0)
+    if points.shape[0] == 0:
+        raise gr.Error("No points selected")
+
+    points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
+    points = points[points_selected_indices]
+    coords = [points.tolist()]
+    mask_np = apply_sam(image['image'], coords)
+
+    mask_raw_list_video.append(mask_np)
+    mask_image = Image.fromarray((mask_np[:,:,np.newaxis] * np.array(image['image'])).astype(np.uint8))
+
+    mask_list_video.append((mask_image, f""))
+    # Return a list containing the mask image.
+    image['layers'] = []
+    image['composite'] = image['background']
+    return mask_list_video, image
+
+
+def describe(image, mode, query, masks):
+    # Create an image object from the uploaded image
+    # print(image.keys())
+
+    image['image'] = image['background'].convert('RGB')
+    # del image['background'], image['composite']
+    assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"
+
+    # Handle both hex and rgba color formats
+
+    img_np = np.asarray(image['image']).astype(float) / 255.
+    if mode=='Caption':
+        mask = Image.fromarray((np.asarray(image['layers'][0])[..., 3] > 0).astype(np.uint8) * 255).convert('RGB')
+
+        points = extract_points_from_mask(mask)
+
+        np.random.seed(0)
+
+        if points.shape[0] == 0:
+            if len(masks)>1:
+                raise gr.Error("No points selected")
+
+        else:
+            # Randomly sample 8 points from the mask
+            # Follow DAM https://github.com/NVlabs/describe-anything
+            points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
+            points = points[points_selected_indices]
+
+        coords = [points.tolist()]
+
+        mask_np = apply_sam(image['image'], coords)
+
+        masks = []
+        masks.append(mask_np)
+        mask_ids = [0]
+
+        img_with_contour_np = add_contour(img_np, mask_np, color=color_rgb)
+        img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
+    else:
+        masks = mask_raw_list
+        img_with_contour_np = img_np.copy()
+
+        mask_ids = []
+        for i, mask_np in enumerate(masks):
+            img_with_contour_np = add_contour(img_with_contour_np, mask_np, color=color_rgbs[i])
+            img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
+            mask_ids.append(0)
+
+    masks = np.stack(masks, axis=0)
+    masks = torch.from_numpy(masks).to(torch.uint8)
+
+
+    img = np.asarray(image['image'])
+
+
+    if mode == "Caption":
+        query = '<image>\nPlease describe the <region> in the image in detail.'
+    else:
+        if len(masks)==1:
+            prefix = "<image>\nThere is 1 region in the image: <region>. "
+        else:
+            prefix = f"<image>\nThere are {len(masks)} regions in the image: "
+            for i in range(len(masks)):
+                prefix += f"<region>, "
+            prefix = prefix[:-2]+'. '
+        query = prefix + query
+    # print(query)
+
+    image['layers'] = []
+    image['composite'] = image['background']
+
+    text = ""
+    yield img_with_contour_pil, text, image
+
+    for token in get_model_output(
+        [img],
+        query,
+        model=model,
+        tokenizer=tokenizer,
+        masks=masks,
+        mask_ids=mask_ids,
+        modal='image',
+        image_downsampling=1,
+        streaming=True,
+    ):
+        text += token
+        yield gr.update(), text, gr.update()
+
+
+def load_first_frame(video_path):
+    cap = cv2.VideoCapture(video_path)
+    ret, frame = cap.read()
+    cap.release()
+    if not ret:
+        raise gr.Error("Could not read the video file.")
+    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+    image = Image.fromarray(frame)
+    return image
+
+def describe_video(video_path, mode, query, annotated_frame, masks):
+    global mask_list_video
+    # Create a temporary directory to save extracted video frames
+    cap = cv2.VideoCapture(video_path)
+
+    video_tensor = load_video(video_path, fps=4, max_frames=768, frame_ids=[0])
+
+    annotated_frame['image'] = annotated_frame['background'].convert('RGB')
+
+    # Process the annotated frame from the image editor
+    if isinstance(annotated_frame, dict):
+        # Get the composite image with annotations
+        frame_img = annotated_frame.get("image", annotated_frame.get("background"))
+        if frame_img is None:
+            raise gr.Error("No valid annotation found in the image editor.")
+        frame_img = frame_img.convert("RGB")
+
+        # Get the annotation layer
+        if "layers" in annotated_frame and len(annotated_frame["layers"]) > 0:
+            mask = Image.fromarray((np.asarray(annotated_frame["layers"][0])[..., 3] > 0).astype(np.uint8) * 255).convert("RGB")
+        else:
+            mask = Image.new("RGB", frame_img.size, 0)
+    else:
+        frame_img = annotated_frame.convert("RGB")
+        mask = Image.new("RGB", frame_img.size, 0)
+
+    img_np = np.asarray(annotated_frame['image']).astype(float) / 255.
+    # Extract points from the annotated mask (using the first channel)
+    if mode == "Caption":
+        points = extract_points_from_mask(mask)
+        np.random.seed(0)
+        if points.shape[0] == 0:
+            raise gr.Error("No points were selected in the annotation.")
+        # Randomly select up to 8 points
+        # Follow DAM https://github.com/NVlabs/describe-anything
+        points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
+        points = points[points_selected_indices]
+
+        # print(f"Selected points (to SAM): {points}")
+
+        coords = [points.tolist()]
+
+        mask_np = apply_sam(annotated_frame['image'], coords)
+
+        masks = []
+        masks.append(mask_np)
+        mask_ids = [0]
+
+        # img_with_contour_np = add_contour(img_np, mask_np, color=color_rgb)
+        # img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
+
+    else:
+        masks = mask_raw_list_video
+        img_with_contour_np = img_np.copy()
+
+        mask_ids = []
+        for i, mask_np in enumerate(masks):
+            # img_with_contour_np = add_contour(img_with_contour_np, mask_np, color=color_rgbs[i])
+            # img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
+            mask_ids.append(0)
+
+    masks = np.stack(masks, axis=0)
+    masks = torch.from_numpy(masks).to(torch.uint8)
+
+
+    if mode == "Caption":
+        query = '
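Note for reviewers: `apply_sam(image, coords)` is used throughout the code above, but its definition lies in the part of `app.py` not shown in this excerpt. For reference, below is a minimal sketch of what such a helper could look like with the `SamModel`/`SamProcessor` classes already imported at the top of the file. The checkpoint name, device handling, and best-of-three mask selection are assumptions for illustration, not necessarily what the actual implementation does.

```python
import torch
from transformers import SamModel, SamProcessor

# Assumed globals; the real app.py may create these elsewhere (e.g. at startup).
device = "cuda" if torch.cuda.is_available() else "cpu"
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)  # assumed checkpoint
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")


def apply_sam(image, input_points):
    # `input_points` is the nested list built above ([[ [x, y], ... ]]):
    # one group of point prompts for a single object in a single image.
    inputs = sam_processor(image, input_points=input_points, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = sam_model(**inputs)

    # Rescale the low-resolution mask predictions back to the original image size.
    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )[0][0]                            # (3, H, W) candidate masks for this prompt
    scores = outputs.iou_scores[0, 0]  # predicted IoU for each candidate

    # Keep the highest-scoring candidate as a binary (H, W) mask.
    return masks[scores.argmax()].numpy()
```

The nested `coords = [points.tolist()]` built in `generate_masks` matches the `input_points` layout the processor expects for a single image with one prompted object, and the returned binary mask is what the callers above multiply against the image and stack into the `masks` tensor.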