############################################################## # copy from cognitron_vl/constants.py ############################################################## import logging logger = logging.getLogger(__name__) if True: IMG_TAG_TOKEN = "<image>" VID_TAG_TOKEN = "<video>" AUD_TAG_TOKEN = "<audio>" IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>' IMG_START_TOKEN = '<img>' IMG_END_TOKEN = '</img>' VID_CONTEXT_TOKEN = '<VID_CONTEXT>' VID_START_TOKEN = '<vid>' VID_END_TOKEN = '</vid>' PATCH_CONTEXT_TOKEN = '<PATCH_CONTEXT>' PATCH_START_TOKEN = '<patch>' PATCH_END_TOKEN = '</patch>' AUD_START_TOKEN = '<|begin_of_audio|>' AUD_END_TOKEN = '<|end_of_audio|>' QUAD_START_TOKEN = '<quad>' QUAD_END_TOKEN = '</quad>' REF_START_TOKEN = '<ref>' REF_END_TOKEN = '</ref>' BOX_START_TOKEN = '<box>' BOX_END_TOKEN = '</box>' if False: IMG_TAG_TOKEN = "<|image|>" VID_TAG_TOKEN = "<|video|>" AUD_TAG_TOKEN = "<|audio|>" IMG_CONTEXT_TOKEN = '<|context_of_image|>' IMG_START_TOKEN = '<|begin_of_image|>' IMG_END_TOKEN = '<|end_of_image|>' VID_CONTEXT_TOKEN = '<|context_of_video|>' VID_START_TOKEN = '<|begin_of_video|>' VID_END_TOKEN = '<|end_of_video|>' PATCH_CONTEXT_TOKEN = '<|context_of_patch|>' PATCH_START_TOKEN = '<|begin_of_patch|>' PATCH_END_TOKEN = '<|end_of_patch|>' AUD_START_TOKEN = '<|begin_of_audio|>' AUD_END_TOKEN = '<|end_of_audio|>' QUAD_START_TOKEN = '<|begin_of_quad|>' QUAD_END_TOKEN = '<|end_of_quad|>' REF_START_TOKEN = '<|begin_of_ref|>' REF_END_TOKEN = '<|end_of_ref|>' BOX_START_TOKEN = '<|begin_of_box|>' BOX_END_TOKEN = '<|end_of_box|>' logger.info(f"IMG_TAG_TOKEN {IMG_TAG_TOKEN}") logger.info(f"VID_TAG_TOKEN {VID_TAG_TOKEN}") logger.info(f"AUD_TAG_TOKEN {AUD_TAG_TOKEN}") logger.info(f"IMG_CONTEXT_TOKEN {IMG_CONTEXT_TOKEN}") logger.info(f"IMG_START_TOKEN {IMG_START_TOKEN}") logger.info(f"IMG_END_TOKEN {IMG_END_TOKEN}") logger.info(f"VID_CONTEXT_TOKEN {VID_CONTEXT_TOKEN}") logger.info(f"VID_START_TOKEN {VID_START_TOKEN}") logger.info(f"VID_END_TOKEN {VID_END_TOKEN}") logger.info(f"PATCH_CONTEXT_TOKEN {PATCH_CONTEXT_TOKEN}") logger.info(f"PATCH_START_TOKEN {PATCH_START_TOKEN}") logger.info(f"PATCH_END_TOKEN {PATCH_END_TOKEN}") logger.info(f"AUD_START_TOKEN {AUD_START_TOKEN}") logger.info(f"AUD_END_TOKEN {AUD_END_TOKEN}") # IMAGENET_MEAN = (0.485, 0.456, 0.406) # IMAGENET_STD = (0.229, 0.224, 0.225) # CLIP_MEAN = (0.4814546, 0.4578275, 0.40821073) # CLIP_STD = (0.2686295, 0.2613025, 0.2757711) # SIGLIP_MEAN = (0.5, 0.5, 0.5) # SIGLIP_STD = (0.5, 0.5, 0.5) IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406] IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225] IMAGENET_STANDARD_MEAN = [0.5, 0.5, 0.5] IMAGENET_STANDARD_STD = [0.5, 0.5, 0.5] OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073] OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711] # Model Constants IGNORE_INDEX = -100 IMAGE_TOKEN_INDEX = -200 DEFAULT_IMAGE_TOKEN = IMG_CONTEXT_TOKEN DEFAULT_IMAGE_PATCH_TOKEN = PATCH_CONTEXT_TOKEN DEFAULT_IM_START_TOKEN = IMG_START_TOKEN DEFAULT_IM_END_TOKEN = IMG_END_TOKEN ############################################################## ############################################################## # copy from cognitron_vl/data/processor/image_processor.py ############################################################## import math import os import cv2 import natsort import numpy as np import torch from PIL import Image import decord # from cognitron_vl.constants import ( # IMAGENET_DEFAULT_MEAN, # IMAGENET_DEFAULT_STD, # IMAGENET_STANDARD_MEAN, # IMAGENET_STANDARD_STD, # OPENAI_CLIP_MEAN, # OPENAI_CLIP_STD, # ) class ImageProcessor: def __init__( self, process_type, image_size=448, normalize_type="imagenet", min_patch_grid=1, max_patch_grid=6, ): self.process_type = process_type self.image_size = image_size if normalize_type == "imagenet": MEAN, STD = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD elif normalize_type == "clip": MEAN, STD = OPENAI_CLIP_MEAN, OPENAI_CLIP_STD elif normalize_type == "siglip": MEAN, STD = IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD else: raise NotImplementedError self.mean = MEAN self.std = STD self.patch_size = image_size self.min_patch_grid = min_patch_grid self.max_patch_grid = max_patch_grid if self.process_type == "anyres": self.grid_pinpoints = [ (i, j) for i in range(min_patch_grid, max_patch_grid + 1) for j in range(min_patch_grid, max_patch_grid + 1) ] self.possible_resolutions = [ [dim * self.patch_size for dim in pair] for pair in self.grid_pinpoints ] print(f"grid_pinpoints {self.grid_pinpoints}") print(f"possible_resolutions {self.possible_resolutions}") if self.process_type == "dynamic": max_num = self.max_patch_grid min_num = self.min_patch_grid # calculate the existing image aspect ratio target_ratios = set( (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num ) self.target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) self.possible_resolutions = [ [dim * self.patch_size for dim in pair] for pair in self.target_ratios ] print(f"target_ratios {self.target_ratios}") print(f"possible_resolutions {self.possible_resolutions}") def get_frame_paths(self, frame_root, num_frames=8): os.makedirs(frame_root, exist_ok=True) self.frame_tmpl = "frame-{}-of-{}.jpg" return [ os.path.join(frame_root, self.frame_tmpl.format(i, num_frames)) for i in range(1, num_frames + 1) ] def save_video_frames(self, vid_path, max_fps=1, num_frames=8): vid = decord.VideoReader(vid_path, num_threads=1) step_size = len(vid) / (num_frames + 1) # step_size = max(1, step_size) fps = vid.get_avg_fps() step_size = max(fps / max_fps, step_size) # indices = [int(i * step_size) for i in range(1, num_frames + 1)] indices = [int(i * step_size) for i in range(0, num_frames)] indices = [i for i in indices if i < len(vid)] num_frames = len(indices) frame_paths = self.get_frame_paths(vid_path + ".saved_frames", num_frames) flag = np.all([os.path.exists(p) for p in frame_paths]) if flag: return frame_paths images = [vid[i].asnumpy() for i in indices] images = [Image.fromarray(arr) for arr in images] for im, pth in zip(images, frame_paths): # if not os.path.exists(pth): # im.save(pth) im.save(pth) # print(f"save_video_frames vid_path {vid_path} fps {fps} len(vid) {len(vid)} frame_paths {frame_paths}") return frame_paths def get_video_frames(self, vid_path, max_fps=1, num_frames=8): vid = decord.VideoReader(vid_path, num_threads=1) step_size = len(vid) / (num_frames + 1) # step_size = max(1, step_size) fps = vid.get_avg_fps() step_size = max(fps / max_fps, step_size) # indices = [int(i * step_size) for i in range(1, num_frames + 1)] indices = [int(i * step_size) for i in range(0, num_frames)] indices = [i for i in indices if i < len(vid)] images = [vid[i].asnumpy() for i in indices] images = [Image.fromarray(arr) for arr in images] # print(f"save_video_frames vid_path {vid_path} fps {fps} len(vid) {len(vid)} frame_paths {frame_paths}") return images def process_video(self, video_file_or_dir, max_num_frame=8, max_fps=1): if os.path.isdir(video_file_or_dir): all_filepath = [] for root, dirs, files in os.walk(video_file_or_dir): for filename in files: if ( filename.endswith("png") or filename.endswith("jpeg") or filename.endswith("jpg") ): filepath = os.path.join(root, filename) all_filepath.append(filepath) if len(all_filepath) == 0: return None # all_filepath.sort() all_filepath = natsort.natsorted(all_filepath) total_frame = len(all_filepath) if "ShareGPTVideo" in video_file_or_dir: fps = 2 else: fps = 1 target_frame = int(min(total_frame / fps * max_fps, max_num_frame)) index = [int(1.0 * total_frame / target_frame) * x for x in range(target_frame)] selected_filepath = [all_filepath[x] for x in index] img_or_path_list = selected_filepath # print(f"process_video {img_or_path_list}") elif os.path.isfile(video_file_or_dir): # frame_paths = self.save_video_frames( # video_file_or_dir, num_frames=max_num_frame, max_fps=max_fps # ) # img_or_path_list = frame_paths img_or_path_list = self.get_video_frames( video_file_or_dir, num_frames=max_num_frame, max_fps=max_fps ) else: # print(f"FileNotFoundError {video_file_or_dir}") raise NotImplementedError return self.process_images(img_or_path_list), img_or_path_list def process_images(self, img_or_path_list): if isinstance(img_or_path_list[0], str): images = [Image.open(x).convert("RGB") for x in img_or_path_list] elif isinstance(img_or_path_list[0], Image.Image): images = [x.convert("RGB") for x in img_or_path_list] else: images = img_or_path_list def expand2square(pil_img, background_color): width, height = pil_img.size if width == height: return pil_img elif width > height: result = Image.new(pil_img.mode, (width, width), background_color) result.paste(pil_img, (0, (width - height) // 2)) return result else: result = Image.new(pil_img.mode, (height, height), background_color) result.paste(pil_img, ((height - width) // 2, 0)) return result image_tensor = torch.ones([len(images), 3, self.image_size, self.image_size]) for i, image in enumerate(images): image = expand2square(image, tuple(int(x * 255) for x in self.mean)) image = image.resize( (self.image_size, self.image_size), resample=Image.Resampling.BICUBIC ) image = np.array(image, dtype=np.float32) image = image * 1.0 / 255.0 mean = np.array(self.mean, dtype=image.dtype) std = np.array(self.std, dtype=image.dtype) image = (image - mean) / std image = torch.tensor(image, dtype=torch.float32) image = image.permute(2, 0, 1) image_tensor[i] = image return image_tensor def process_images_with_subpatch(self, img_or_path): if self.process_type == "anyres": return self.process_anyres(img_or_path) if self.process_type == "dynamic": return self.process_dynamic(img_or_path) if isinstance(img_or_path, str): image = Image.open(img_or_path).convert("RGB") elif isinstance(img_or_path, Image.Image): image = img_or_path.convert("RGB") else: image = img_or_path return self.process_images([images]) def process_anyres(self, img_or_path): if isinstance(img_or_path, str): image = Image.open(img_or_path).convert("RGB") elif isinstance(img_or_path, Image.Image): image = img_or_path.convert("RGB") else: image = img_or_path best_resolution = select_best_resolution(image.size, self.possible_resolutions) image_padded = resize_and_pad_image(image, best_resolution) patches = divide_to_patches(image_padded, self.patch_size) if best_resolution == (self.patch_size, self.patch_size): image_patches = [image] else: image_patches = [image] + patches image_patches = self.process_images(image_patches) # print(f"image {image.size} best_resolution {best_resolution} image_padded {image_padded.size} patches {len(patches)} image_patches {image_patches.size()}") return image_patches, best_resolution def process_dynamic(self, img_or_path): if isinstance(img_or_path, str): image = Image.open(img_or_path).convert("RGB") elif isinstance(img_or_path, Image.Image): image = img_or_path.convert("RGB") else: image = img_or_path image_patches, best_resolution = dynamic_preprocess( image, min_num=self.min_patch_grid, max_num=self.max_patch_grid, image_size=self.patch_size, use_thumbnail=True, ) image_patches = self.process_images(image_patches) # print(f"image {image.size} best_resolution {best_resolution} image_padded {image_padded.size} patches {len(patches)} image_patches {image_patches.size()}") return image_patches, best_resolution def select_best_resolution(original_size, possible_resolutions): """ Selects the best resolution from a list of possible resolutions based on the original size. Args: original_size (tuple): The original size of the image in the format (width, height). possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. Returns: tuple: The best fit resolution in the format (width, height). """ original_width, original_height = original_size best_fit = None max_effective_resolution = 0 min_wasted_resolution = float("inf") for width, height in possible_resolutions: # Calculate the downscaled size to keep the aspect ratio scale = min(width / original_width, height / original_height) downscaled_width, downscaled_height = int(original_width * scale), int( original_height * scale ) # Calculate effective and wasted resolutions effective_resolution = min( downscaled_width * downscaled_height, original_width * original_height ) wasted_resolution = (width * height) - effective_resolution if effective_resolution > max_effective_resolution or ( effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution ): max_effective_resolution = effective_resolution min_wasted_resolution = wasted_resolution best_fit = (width, height) return best_fit def resize_and_pad_image(image, target_resolution): """ Resize and pad an image to a target resolution while maintaining aspect ratio. Args: image (PIL.Image.Image): The input image. target_resolution (tuple): The target resolution (width, height) of the image. Returns: PIL.Image.Image: The resized and padded image. """ original_width, original_height = image.size target_width, target_height = target_resolution # Determine which dimension (width or height) to fill scale_w = target_width / original_width scale_h = target_height / original_height if scale_w < scale_h: # Width will be filled completely new_width = target_width new_height = min(math.ceil(original_height * scale_w), target_height) else: # Height will be filled completely new_height = target_height new_width = min(math.ceil(original_width * scale_h), target_width) # Resize the image resized_image = image.resize((new_width, new_height)) # Create a new image with the target size and paste the resized image onto it new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0)) paste_x = (target_width - new_width) // 2 paste_y = (target_height - new_height) // 2 new_image.paste(resized_image, (paste_x, paste_y)) return new_image def divide_to_patches(image, patch_size): """ Divides an image into patches of a specified size. Args: image (PIL.Image.Image): The input image. patch_size (int): The size of each patch. Returns: list: A list of PIL.Image.Image objects representing the patches. """ patches = [] width, height = image.size for i in range(0, height, patch_size): for j in range(0, width, patch_size): box = (j, i, j + patch_size, i + patch_size) patch = image.crop(box) patches.append(patch) return patches def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): best_ratio_diff = float("inf") best_ratio = (1, 1) area = width * height for ratio in target_ratios: target_aspect_ratio = ratio[0] / ratio[1] ratio_diff = abs(aspect_ratio - target_aspect_ratio) if ratio_diff < best_ratio_diff: best_ratio_diff = ratio_diff best_ratio = ratio elif ratio_diff == best_ratio_diff: if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: best_ratio = ratio # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}') return best_ratio def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False): orig_width, orig_height = image.size aspect_ratio = orig_width / orig_height # calculate the existing image aspect ratio target_ratios = set( (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num ) target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) # find the closest aspect ratio to the target target_aspect_ratio = find_closest_aspect_ratio( aspect_ratio, target_ratios, orig_width, orig_height, image_size ) # calculate the target width and height target_width = image_size * target_aspect_ratio[0] target_height = image_size * target_aspect_ratio[1] blocks = target_aspect_ratio[0] * target_aspect_ratio[1] # resize the image resized_img = image.resize((target_width, target_height)) processed_images = [] for i in range(blocks): box = ( (i % (target_width // image_size)) * image_size, (i // (target_width // image_size)) * image_size, ((i % (target_width // image_size)) + 1) * image_size, ((i // (target_width // image_size)) + 1) * image_size, ) # split the image split_img = resized_img.crop(box) processed_images.append(split_img) assert len(processed_images) == blocks if use_thumbnail and len(processed_images) != 1: thumbnail_img = image.resize((image_size, image_size)) # processed_images.append(thumbnail_img) processed_images = [ thumbnail_img, ] + processed_images return processed_images, (target_width, target_height) ############################################################## ############################################################## # modify from long_vita_megatron/tasks/inference/module.py ############################################################## def get_external_inputs(tokens, image_list=None, image_path_list=None, video_path_list=None): print(f"get_external_inputs tokens {tokens.size()}") tokens = tokens.tolist() image_token_length = 256 max_num_frame = 4096 max_fps = 1 # from cognitron_vl.constants import IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN, VID_START_TOKEN, VID_END_TOKEN, VID_CONTEXT_TOKEN, PATCH_START_TOKEN, PATCH_END_TOKEN, PATCH_CONTEXT_TOKEN, IMG_TAG_TOKEN, VID_TAG_TOKEN image_tag = "<image>" video_tag = "<video>" IMG_CONTEXT_ID = tokenizer(IMG_CONTEXT_TOKEN, add_special_tokens=False).input_ids IMG_START_ID = tokenizer(IMG_START_TOKEN, add_special_tokens=False).input_ids IMG_END_ID = tokenizer(IMG_END_TOKEN, add_special_tokens=False).input_ids VID_CONTEXT_ID = tokenizer(VID_CONTEXT_TOKEN, add_special_tokens=False).input_ids VID_START_ID = tokenizer(VID_START_TOKEN, add_special_tokens=False).input_ids VID_END_ID = tokenizer(VID_END_TOKEN, add_special_tokens=False).input_ids PATCH_CONTEXT_ID = tokenizer(PATCH_CONTEXT_TOKEN, add_special_tokens=False).input_ids PATCH_START_ID = tokenizer(PATCH_START_TOKEN, add_special_tokens=False).input_ids PATCH_END_ID = tokenizer(PATCH_END_TOKEN, add_special_tokens=False).input_ids IMG_TAG_ID = tokenizer(IMG_TAG_TOKEN, add_special_tokens=False).input_ids VID_TAG_ID = tokenizer(VID_TAG_TOKEN, add_special_tokens=False).input_ids assert len(IMG_CONTEXT_ID) == 1 assert len(IMG_START_ID) == 1 assert len(IMG_END_ID) == 1 assert len(VID_CONTEXT_ID) == 1 assert len(VID_START_ID) == 1 assert len(VID_END_ID) == 1 assert len(PATCH_CONTEXT_ID) == 1 assert len(PATCH_START_ID) == 1 assert len(PATCH_END_ID) == 1 IMG_CONTEXT_ID = IMG_CONTEXT_ID[0] IMG_START_ID = IMG_START_ID[0] IMG_END_ID = IMG_END_ID[0] VID_CONTEXT_ID = VID_CONTEXT_ID[0] VID_START_ID = VID_START_ID[0] VID_END_ID = VID_END_ID[0] PATCH_CONTEXT_ID = PATCH_CONTEXT_ID[0] PATCH_START_ID = PATCH_START_ID[0] PATCH_END_ID = PATCH_END_ID[0] IMG_TAG_ID = IMG_TAG_ID[0] VID_TAG_ID = VID_TAG_ID[0] nl_tokens = tokenizer("\n", add_special_tokens=False).input_ids image_indices = [] images = [] # ---------------------------------------------------------------- # image for batch_idx, input_ids in enumerate(tokens): # img_positions = [i for i, x in enumerate(input_ids) if x == IMG_CONTEXT_ID] img_positions = [i for i, x in enumerate(input_ids) if x == IMG_TAG_ID] if len(img_positions) == 0: continue if image_path_list is not None: assert len(img_positions) == len(image_path_list), f"{img_positions} {image_path_list} {IMG_CONTEXT_TOKEN} {IMG_CONTEXT_ID} {tokens}" if image_list is not None: assert len(img_positions) == len(image_list), f"{img_positions} {image_list} {IMG_CONTEXT_TOKEN} {IMG_CONTEXT_ID} {tokens}" new_input_ids = [] st = 0 for img_idx, img_pos in enumerate(img_positions): if image_path_list is not None: image_patches, (best_width, best_height) = image_processor.process_images_with_subpatch(image_path_list[img_idx]) if image_list is not None: image_patches, (best_width, best_height) = image_processor.process_images_with_subpatch(image_list[img_idx]) images.append(image_patches) print(f"get_external_inputs best_width {best_width} best_height {best_height}") new_input_ids += input_ids[st:img_pos] new_input_ids += [IMG_START_ID] image_indice_b = torch.zeros( 1, image_token_length, dtype=torch.int64 ) # This will change in collate_fn image_indice_s = ( torch.arange(len(new_input_ids), len(new_input_ids) + image_token_length) .unsqueeze(0) .repeat(1, 1) ) image_indice_b_s = torch.stack( [image_indice_b, image_indice_s], dim=0 ) # 2, num_image, image_length image_indices.append(image_indice_b_s) new_input_ids += [IMG_CONTEXT_ID] * image_token_length new_input_ids += [IMG_END_ID] if len(image_patches) > 1: for i in range(0, best_height, image_processor.patch_size): new_input_ids += nl_tokens for j in range(0, best_width, image_processor.patch_size): new_input_ids += [PATCH_START_ID] image_indice_b = torch.zeros( 1, image_token_length, dtype=torch.int64 ) # This will change in collate_fn image_indice_s = ( torch.arange(len(new_input_ids), len(new_input_ids) + image_token_length) .unsqueeze(0) .repeat(1, 1) ) image_indice_b_s = torch.stack( [image_indice_b, image_indice_s], dim=0 ) # 2, num_image, image_length image_indices.append(image_indice_b_s) new_input_ids += [PATCH_CONTEXT_ID] * image_token_length new_input_ids += [PATCH_END_ID] # print(f"get_external_dict i {i} j {j} new_input_ids {len(new_input_ids)}") st = img_pos + 1 new_input_ids += input_ids[st:] input_ids = new_input_ids tokens[batch_idx] = input_ids # ---------------------------------------------------------------- # video for batch_idx, input_ids in enumerate(tokens): # vid_positions = [i for i, x in enumerate(input_ids) if x == VID_CONTEXT_ID] vid_positions = [i for i, x in enumerate(input_ids) if x == VID_TAG_ID] if len(vid_positions) == 0: continue if video_path_list is not None: assert len(vid_positions) == len(video_path_list), f"{vid_positions} {video_path_list} {VID_CONTEXT_TOKEN} {VID_CONTEXT_ID} {tokens}" if image_path_list is not None: assert len(vid_positions) == len(image_path_list), f"{vid_positions} {image_path_list} {VID_CONTEXT_TOKEN} {VID_CONTEXT_ID} {tokens}" if image_list is not None: assert len(vid_positions) == len(image_list), f"{vid_positions} {image_list} {VID_CONTEXT_TOKEN} {VID_CONTEXT_ID} {tokens}" new_input_ids = [] st = 0 for vid_idx, vid_pos in enumerate(vid_positions): if video_path_list is not None: video_frames, _ = image_processor.process_video(video_path_list[vid_idx], max_num_frame, max_fps) if image_path_list is not None: video_frames = image_processor.process_images([image_path_list[vid_idx]]) if image_list is not None: video_frames = image_processor.process_images([image_list[vid_idx]]) images.append(video_frames) new_input_ids += input_ids[st:vid_pos] for _ in video_frames: new_input_ids += [VID_START_ID] image_indice_b = torch.zeros( 1, image_token_length, dtype=torch.int64 ) # This will change in collate_fn image_indice_s = ( torch.arange(len(new_input_ids), len(new_input_ids) + image_token_length) .unsqueeze(0) .repeat(1, 1) ) image_indice_b_s = torch.stack( [image_indice_b, image_indice_s], dim=0 ) # 2, num_image, image_length image_indices.append(image_indice_b_s) new_input_ids += [VID_CONTEXT_ID] * image_token_length new_input_ids += [VID_END_ID] st = vid_pos + 1 new_input_ids += input_ids[st:] input_ids = new_input_ids tokens[batch_idx] = input_ids if len(images) > 0: images = torch.cat(images, dim=0) image_indices = torch.cat(image_indices, dim=1) image_indices = image_indices.contiguous().to(torch.cuda.current_device()) if True: images = torch.tensor(images, dtype=torch.bfloat16).contiguous().to(torch.cuda.current_device()) else: images = torch.tensor(images, dtype=torch.float16).contiguous().to(torch.cuda.current_device()) print(f"get_external_inputs images {images.size()}") print(f"get_external_inputs image_indices {image_indices.size()}") else: images = None image_indices = None print(f"get_external_inputs images {images}") print(f"get_external_inputs image_indices {image_indices}") tokens = torch.tensor(tokens, dtype=torch.long, device='cuda') print(f"get_external_inputs tokens {tokens.size()}") return tokens, images, image_indices ############################################################## from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.generation import GenerationConfig import torch import importlib if importlib.util.find_spec("torch_npu") is not None: print("Loading torch_npu") import torch_npu from torch_npu.contrib import transfer_to_npu # torch.npu.set_compile_mode(jit_compile=True) import sys import os import natsort import gradio as gr import spaces torch.manual_seed(1234) model_name_or_path = "VITA-MLLM/Long-VITA-128K_HF" device_map = "auto" # device_map = "npu:0" # torch_dtype=torch.float16 torch_dtype=torch.bfloat16 # torch_dtype=torch.float32 tokenizer = AutoTokenizer.from_pretrained( model_name_or_path, trust_remote_code=True ) print("tokenizer", tokenizer) model = AutoModelForCausalLM.from_pretrained( model_name_or_path, trust_remote_code=True, device_map=device_map, torch_dtype=torch_dtype, attn_implementation="flash_attention_2", ).eval() # print("model", model) model.generation_config = GenerationConfig.from_pretrained(model_name_or_path, trust_remote_code=True) model.generation_config.max_new_tokens = 1024 model.generation_config.chat_format = "chatml" model.generation_config.max_window_size = 1310720 model.generation_config.do_sample = False model.generation_config.use_cache = True model.generation_config.pad_token_id = tokenizer.pad_token_id # from cognitron_vl.data.processor.image_processor import ImageProcessor image_processor = ImageProcessor( process_type="dynamic", image_size=448, normalize_type="imagenet", min_patch_grid=1, max_patch_grid=12, ) @spaces.GPU(duration=120) def inference_model(messages, image_path_list, video_path_list): default_system_message = [ { "role": "system", "content": "You are a helpful AI assistant.", } ] messages = default_system_message + messages inputs = tokenizer.apply_chat_template( messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", ) # .to("cuda") print("input", tokenizer.decode(inputs[0], skip_special_tokens=False), flush=True) inputs, images, image_indices = get_external_inputs(inputs, image_path_list=image_path_list, video_path_list=video_path_list) # inputs = inputs.to("cuda") # images = images.to("cuda") # image_indices = image_indices.to("cuda") outputs = model.generate(inputs=inputs, images=images, image_indices=image_indices) # output = tokenizer.decode(outputs[0], skip_special_tokens=False) output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True) print(f"output {output}", flush=True) return output import time import filetype font_size = "2.5em" html = f""" <p align="center" style="font-size: {font_size}; line-height: 1;"> <span style="display: inline-block; vertical-align: middle;">{model_name_or_path.split('/')[-1]}</span> </p> <center> <font size=3> <b>Long-VITA</b> has been fully open-sourced on <a href='https://huggingface.co/VITA-MLLM'>π Huggingface</a> and <a href='https://github.com/VITA-MLLM/Long-VITA'>π GitHub</a>. If you find Long-VITA useful, a likeβ€οΈ or a starπ would be appreciated. </font> </center> """ def add_message(history, message): for x in message["files"]: history.append({"role": "user", "content": {"path": x}}) if message["text"] is not None: history.append({"role": "user", "content": message["text"]}) return history, gr.MultimodalTextbox(value=None, interactive=False) def bot(history: list): print("#" * 100) messages = [] image_path_list = [] video_path_list = [] for message in history: # print(f"message {message}") role = message["role"] content = message["content"] if isinstance(content, str): if len(messages) == 0 or messages[-1]["role"] != role: messages.append( { "role": role, "content": "", } ) messages[-1]["content"] = messages[-1]["content"] + content else: for filepath in content: if filetype.is_image(filepath): # print(f"{filepath} is a valid image...") if len(messages) == 0 or messages[-1]["role"] != role: messages.append( { "role": role, "content": "", } ) messages[-1]["content"] = "<image>" + messages[-1]["content"] image_path_list.append(filepath) elif filetype.is_video(filepath): # print(f"{filepath} is a valid video...") if len(messages) == 0 or messages[-1]["role"] != role: messages.append( { "role": role, "content": "", } ) messages[-1]["content"] = "<video>" + messages[-1]["content"] video_path_list.append(filepath) print(f"messages {messages}") print(f"image_path_list {image_path_list}") print(f"video_path_list {video_path_list}") if len(image_path_list) == 0: image_path_list = None if len(video_path_list) == 0: video_path_list = None output = inference_model(messages, image_path_list, video_path_list) history.append({"role": "assistant", "content": output}) return history with gr.Blocks(title=model_name_or_path.split('/')[-1] + "π₯ππ₯", theme=gr.themes.Ocean()) as demo: gr.HTML(html) with gr.Row(): chatbot = gr.Chatbot(type="messages", elem_id="chatbot", bubble_full_width=False, height=600) with gr.Row(): chat_input = gr.MultimodalTextbox( interactive=True, file_count="multiple", file_types=['image', 'video'], placeholder="Enter message or upload file...", show_label=False, # sources=["microphone", "upload"], sources=["upload"], ) chat_msg = chat_input.submit( add_message, [chatbot, chat_input], [chatbot, chat_input] ) bot_msg = chat_msg.then(bot, chatbot, chatbot, api_name="bot_response") bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input]) demo.launch()