# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import math
from abc import ABC, abstractmethod

import einops
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn

from ..constants import IGNORE_INDEX, MODAL_INDEX_MAP, NUM_FRAMES
from ..mm_utils import reshape_images_to_raw_grid
from .encoder import build_vision_encoder
from .projector import build_vision_projector, load_mm_projector
from .region_encoder import build_region_encoder


def spatial_downsampling(features, grid_thws, strides):
    """Spatially downsample flattened patch features, per image/frame, by the given strides."""
    n, c = features.shape

    flatten_grid_thws = torch.cat([grid_thw for batch_grid_thws in grid_thws for grid_thw in batch_grid_thws])
    split_sizes = [grid_thw.prod() for grid_thw in flatten_grid_thws]
    features = torch.split(features, split_sizes)

    flatten_strides = [stride for batch_strides in strides for stride in batch_strides]

    new_features = []
    for feature, grid_thw, stride in zip(features, flatten_grid_thws, flatten_strides):
        # NOTE: adapted for the reshape done in the image processor
        feature = feature.view(grid_thw[0], grid_thw[1] // stride, grid_thw[2] // stride, stride, stride, c).permute(0, 1, 3, 2, 4, 5)
        feature = feature.reshape(grid_thw[0], grid_thw[1], grid_thw[2], c).permute(0, 3, 1, 2)
        # NOTE: the previous model version used align_corners=True
        new_feature = torch.nn.functional.interpolate(feature, (math.ceil(grid_thw[1] / stride), math.ceil(grid_thw[2] / stride)), mode='bilinear')
        # new_feature = nn.functional.avg_pool2d(feature, stride)
        # new_feature = nn.functional.max_pool2d(feature, stride)
        new_features.append(new_feature.permute(0, 2, 3, 1).view(-1, c))
    new_features = torch.cat(new_features)

    return new_features
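
# Illustrative sketch of spatial_downsampling's inputs (sizes are hypothetical, not from any config):
#
#     features : (N, C) patch features flattened over every image/frame in the batch
#     grid_thws: nested per-sample lists of (T, H, W) grids, e.g. [[tensor([[1, 28, 28]])]]
#     strides  : nested lists matching grid_thws, e.g. [[2]]
#
# With those numbers, the single image contributes 1 * 28 * 28 = 784 rows and is resized to
# ceil(28 / 2) * ceil(28 / 2) = 196 rows, i.e. a 4x token reduction at stride 2.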


class Videollama3MetaModel:

    def __init__(self, config):
        super(Videollama3MetaModel, self).__init__(config)
        if hasattr(config, "vision_encoder") or hasattr(config, "mm_vision_encoder"):
            self.vision_encoder = build_vision_encoder(config, delay_load=False)
            self.mm_projector = build_vision_projector(config)
            self.region_encoder = build_region_encoder(config, self.vision_encoder.hidden_size)

    def get_vision_encoder(self):
        vision_encoder = getattr(self, 'vision_encoder', None)
        if type(vision_encoder) is list:
            vision_encoder = vision_encoder[0]
        return vision_encoder

    def get_mm_projector(self):
        return self.mm_projector

    def initialize_vision_modules(self, model_args, fsdp=None):
        vision_encoder = model_args.vision_encoder
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_projector = model_args.pretrain_mm_projector

        self.config.mm_vision_encoder = vision_encoder

        if self.get_vision_encoder() is None:
            vision_encoder = build_vision_encoder(model_args)

            if fsdp is not None and len(fsdp) > 0:
                self.vision_encoder = [vision_encoder]
            else:
                self.vision_encoder = vision_encoder
        else:
            if fsdp is not None and len(fsdp) > 0:
                vision_encoder = self.vision_encoder[0]
            else:
                vision_encoder = self.vision_encoder
            # NOTE: only compatible with delay_load encoder
            # vision_encoder.load_model(vision_encoder.cfg_only)

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
        self.config.mm_hidden_size = vision_encoder.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature

        if getattr(self, 'mm_projector', None) is None:
            self.mm_projector = build_vision_projector(self.config)
        else:
            # In case it is frozen by LoRA
            for p in self.mm_projector.parameters():
                p.requires_grad = True

        if pretrain_mm_projector is not None:
            if os.path.exists(pretrain_mm_projector):
                is_local = True
                if os.path.isdir(pretrain_mm_projector):
                    mm_projector_weights = load_mm_projector(pretrain_mm_projector)
                else:
                    mm_projector_weights = torch.load(pretrain_mm_projector, map_location='cpu')
            else:
                # Support loading projector weights from the remote HuggingFace model hub
                is_local = False
                pretrain_mm_projector = pretrain_mm_projector.replace('mm_projector.bin', '')
                pretrain_mm_projector = pretrain_mm_projector.strip('/').strip('\\').strip()
                mm_projector_weights = load_mm_projector(pretrain_mm_projector)

            def get_w(weights, keyword):
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}

            # self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
            # set strict=False to avoid missing key error regarding bert.embeddings.position_ids
            self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'), strict=False)
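
            # Note on the loader above: get_w assumes LLaVA-style checkpoint keys such as
            # "model.mm_projector.0.weight" (illustrative, actual checkpoints may differ) and strips
            # everything up to and including the first "mm_projector.", so the remaining keys line up
            # with self.mm_projector.state_dict(); strict=False tolerates missing or extra keys.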


class Videollama3MetaForCausalLM(ABC):

    @abstractmethod
    def get_model(self):
        pass

    def num_frames(self):
        if hasattr(self.config, 'num_frames'):
            return self.config.num_frames
        else:
            return NUM_FRAMES

    def spatial_merge_size(self):
        if hasattr(self.config, 'spatial_merge_size'):
            return self.config.spatial_merge_size
        else:
            return 1

    def get_vision_encoder(self):
        return self.get_model().get_vision_encoder()

    def get_mm_projector(self):
        return self.get_model().get_mm_projector()

    def encode_images(self, images, grid_thws, strides):
        """Encode visual inputs (images shape [b c h w]), spatially downsample, and project to the LLM hidden size."""
        images_features = self.get_model().get_vision_encoder()(images, grid_thws=grid_thws, strides=strides)
        # images_features = spatial_downsampling(images_features, grid_thws, stride=self.config.spatial_merge_size)
        mm_features = spatial_downsampling(images_features, grid_thws, strides=strides)
        images_features = self.get_model().mm_projector(mm_features)
        return images_features
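
    # Rough data flow through encode_images (shape names are descriptive assumptions, not
    # identifiers from the code):
    #
    #     flattened patches -> vision_encoder -> (num_patches, vision_hidden)
    #                       -> spatial_downsampling (per-image stride) -> (num_merged_patches, vision_hidden)
    #                       -> mm_projector -> (num_merged_patches, llm_hidden)
    #
    # The resulting rows are later scattered into the image placeholder positions of the token
    # sequence by prepare_inputs_labels_for_multimodal(_with_compression).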

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, attention_mask, past_key_values, labels, images, position_ids=None, masks=None, additional_images=None,
    ):
        if self.config.use_token_compression:
            return self.prepare_inputs_labels_for_multimodal_with_compression(input_ids, attention_mask, past_key_values, labels, images, position_ids, masks, additional_images)

        # # images shape (modal, tensor, flag)
        # vision_encoder = self.get_vision_encoder()
        # # NOTE: text-only situation
        # if vision_encoder is None or images is None or input_ids.shape[1] == 1:
        #     return input_ids, attention_mask, past_key_values, None, labels, position_ids

        # # NOTE: Equivalent to the following code:
        # # images_tensor = [image for modal, image, image_flag, grid_thw in images]
        # # images_flag = [image_flag for modal, image, image_flag, grid_thw in images]
        # # grid_thws = [grid_thw for modal, image, image_flag, grid_thw in images]
        # modals, images, grid_thws = zip(*images)

        # images_flag = []
        # strides = []
        # for modal, grid_thw in zip(modals, grid_thws):
        #     grid_thw = torch.cat(grid_thw)
        #     stride = self.config.spatial_merge_size if modal == "video" else 1
        #     num_patches = grid_thw.prod(dim=-1).sum().div(stride**2).long()
        #     image_flag = torch.full((num_patches, ), 0 if modal == 'text' else 1)
        #     images_flag.append(image_flag)
        #     strides.append([stride] * grid_thw.size(0))

        # images_flag_tensor = torch.cat(images_flag)
        # mm_features = self.encode_images(images, grid_thws, strides)
        # mm_features = mm_features[images_flag_tensor.to(mm_features.device) == 1].to(input_ids.device)

        # additional_images_list = []
        # additional_images_thw = []
        # additional_images_strides = []
        # for i in range(len(additional_images)):
        #     additional_images_list.append(torch.from_numpy(np.array(additional_images[0][0])).to(mm_features.dtype).to(mm_features.device))
        #     additional_images_thw.append(torch.tensor(additional_images[0][1][0]).to(mm_features.device))
        #     additional_images_strides.append([1] * len(additional_images[0][1][0]))

        # image_selected = (input_ids == self.config.image_token_index)
        # audio_selected = (input_ids == MODAL_INDEX_MAP['<audio>'])

        # input_ids[image_selected] = 0
        # input_ids[audio_selected] = 0

        # input_embeds = self.get_model().embed_tokens(input_ids).clone()
        # B, N, C = input_embeds.shape
        # input_embeds = input_embeds.reshape(B * N, C).to(input_ids.device)
        # image_selected = image_selected.reshape(B * N)
        # audio_selected = audio_selected.reshape(B * N)

        # input_embeds[image_selected] = input_embeds[image_selected] * 0.0 + mm_features.reshape(-1, C)

        # # replace region token
        # mask_selected = (input_ids == self.config.region_token_index)
        # if mask_selected.sum() > 0:
        #     additional_images_features = self.get_model().get_vision_encoder()(additional_images_list, grid_thws=[additional_images_thw], strides=additional_images_strides)
        #     reshaped_features = reshape_images_to_raw_grid(additional_images_features, additional_images_thw)
        #     mask_additional_image_features = []
        #     for idx in mask_ids:
        #         mask_additional_image_features.append(reshaped_features[idx])
        #     mask_feats = self.model.region_encoder(mask_additional_image_features, masks)
        #     input_embeds[mask_selected] = input_embeds[mask_selected] * 0.0 + mask_feats

        # input_embeds = input_embeds.reshape(B, N, C)

        # return None, attention_mask, past_key_values, input_embeds, labels, position_ids

    def prepare_inputs_labels_for_multimodal_with_compression(
        self, input_ids, attention_mask, past_key_values, labels, images, position_ids=None, masks=None, additional_images=None,
    ):
        # images shape (modal, tensor, flag)
        vision_encoder = self.get_vision_encoder()
        # NOTE: text-only situation
        if vision_encoder is None or images is None or input_ids.shape[1] == 1:
            return input_ids, attention_mask, past_key_values, None, labels, position_ids

        # NOTE: Equivalent to the following code:
        # images_tensor = [image for modal, image, image_flag, grid_thw in images]
        # images_flag = [image_flag for modal, image, image_flag, grid_thw in images]
        # grid_thws = [grid_thw for modal, image, image_flag, grid_thw in images]
        modals, images, grid_thws = zip(*images)

        images_flag = []
        visual_masks = []
        strides = []
        visual_trunc_masks = []
        for modal, image, grid_thw in zip(modals, images, grid_thws):
            grid_thw = torch.cat(grid_thw)
            stride = self.config.spatial_merge_size if modal == "video" else 1
            num_patches = grid_thw.prod(dim=-1).sum().div(stride**2).long()
            image_flag = torch.full((num_patches, ), 0 if modal == 'text' else 1)
            images_flag.append(image_flag)
            strides.append([stride] * grid_thw.size(0))

            if modal == "image" or (modal == "video" and len(image) == 1):
                visual_masks.append(torch.ones((num_patches,), dtype=torch.bool, device=input_ids.device))
                visual_trunc_masks.append(torch.ones((num_patches,), dtype=torch.bool, device=input_ids.device))
            elif modal == "video":
                # NOTE: video frame compressor
                n, h, w = len(image), grid_thw[0][1], grid_thw[0][2]
                image = torch.stack(image, dim=0).view(n, (h // stride) * (w // stride), -1)

                threshold = 0.1
                min_tokens = 1

                pixel_diff = image[1:] - image[:-1]
                pixel_diff = torch.abs(pixel_diff).mean(dim=-1) * 255
                pixel_diff = torch.cat([torch.full_like(pixel_diff[0:1], threshold + 1), pixel_diff], dim=0)
                # if dist.get_rank() == 0:
                #     print(pixel_diff.shape, image.shape)
                mask = pixel_diff > threshold
                padding_ids = torch.nonzero(mask.sum(dim=1) < min_tokens)[:, 0]
                # mask[padding_ids, torch.randperm(min_tokens)] = 1
                mask[padding_ids, :min_tokens] = 1
                visual_masks.append(mask.flatten())
                visual_trunc_masks.append(torch.ones((num_patches,), dtype=torch.bool, device=input_ids.device))
            elif modal == "text":
                visual_trunc_masks.append(torch.ones((0,), dtype=torch.bool, device=input_ids.device))
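
        # Worked example of the video frame compressor above (using the hard-coded threshold = 0.1
        # and min_tokens = 1): pixel_diff holds the mean absolute change of each spatially-merged
        # token between consecutive frames, rescaled by 255; tokens whose change is <= 0.1 are
        # dropped for that frame. The first frame is padded with threshold + 1 so it is always kept
        # in full, and any frame left with fewer than min_tokens surviving tokens keeps its first
        # min_tokens token positions.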

        images_flag_tensor = torch.cat(images_flag)

        mm_features = self.encode_images(images, grid_thws, strides)
        mm_features = mm_features[images_flag_tensor.to(mm_features.device) == 1]

        additional_images_list = []
        additional_images_thw = []
        additional_images_strides = []
        if additional_images is not None:  # and additional_images[0] is not None
            for i in range(len(additional_images)):
                for img_idx in range(len(additional_images[i][0])):
                    additional_images_list.append([torch.from_numpy(np.array(additional_images[i][0][img_idx])).to(mm_features.dtype).to(mm_features.device)])
                    additional_images_thw.append([torch.tensor(np.array(additional_images[i][1][img_idx])).to(mm_features.device)])
                    additional_images_strides.append([1] * len(additional_images[i][1][img_idx]))
                    # additional_images_list.append(additional_images[i][0])
                    # additional_images_thw.append(additional_images[i][1])
                    # additional_images_strides.append([1] * len(additional_images[i][1]))

        # import pdb
        # pdb.set_trace()

        B, N = input_ids.shape
        C = mm_features.shape[-1]

        assert B == 1, "Only support batch flattening for now"
        input_ids = input_ids.view(B * N)
        image_selected = (input_ids == self.config.image_token_index)
        audio_selected = (input_ids == MODAL_INDEX_MAP['<audio>'])

        if len(visual_masks) > 0:
            # if dist.get_rank() == 0:
            #     print(grid_thws, [x.shape for x in visual_masks])
            visual_masks = torch.cat(visual_masks)
            # print((visual_masks == 1).sum(), (visual_masks == 0).sum())
            mm_features = mm_features[visual_masks]

            # text_masks = torch.zeros_like(input_ids, dtype=torch.bool)
            # text_masks[~image_selected] = True
            text_masks = torch.logical_not(image_selected)
            try:
                text_masks[image_selected] = visual_masks
            except Exception as e:
                assert position_ids is not None, "Position ids must be provided when shapes mismatch"
                print(
                    f'warning: {e}, text_masks[image_selected].shape={text_masks[image_selected].shape},',
                    f'visual_masks.shape={visual_masks.shape}'
                )
                seq_end_indices = torch.nonzero(position_ids.view(B * N) == 0)[:, 0]
                seq_end_indices = seq_end_indices[seq_end_indices > 0]
                seq_end_indices = seq_end_indices.tolist() + [len(input_ids)]
                seq_start_indices = [0] + seq_end_indices[:-1]
                num_visual_tokens = [
                    input_ids[start:end].eq(self.config.image_token_index).sum()
                    for start, end in zip(seq_start_indices, seq_end_indices)
                ]

                for n, mask in zip(num_visual_tokens, visual_trunc_masks):
                    if len(mask) > 0:
                        mask[n:] = False
                visual_trunc_masks = torch.cat(visual_trunc_masks)

                text_masks[image_selected] = visual_masks[visual_trunc_masks]
                mm_features = mm_features[visual_trunc_masks[visual_masks]]
        else:
            text_masks = torch.ones_like(input_ids, dtype=torch.bool)

        input_ids = input_ids[text_masks]
        if attention_mask is not None:
            attention_mask = attention_mask.view(B * N)[text_masks].reshape(1, -1)
        if labels is not None:
            labels = labels.view(B * N)[text_masks].reshape(1, -1)
        if position_ids is not None:
            position_ids = position_ids.view(B * N)[text_masks]
            pos_start = [0] + torch.nonzero(position_ids == 0)[:, 0].tolist()
            pos_end = pos_start[1:] + [len(input_ids)]
            position_ids = torch.cat([torch.arange(end - start, device=input_ids.device) for start, end in zip(pos_start, pos_end)])
            position_ids = position_ids.reshape(1, -1)
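
        # Example of the position id rebuild above (hypothetical values): if the kept tokens had old
        # positions [0, 1, 4, 5, 0, 2, 3] across two packed sequences, each zero marks a sequence
        # start and every segment is renumbered contiguously, yielding [0, 1, 2, 3, 0, 1, 2].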

        image_selected = (input_ids == self.config.image_token_index)
        audio_selected = (input_ids == MODAL_INDEX_MAP['<audio>'])

        input_ids[image_selected] = 0
        input_ids[audio_selected] = 0

        input_embeds = self.get_model().embed_tokens(input_ids).clone()
        input_embeds[image_selected] = input_embeds[image_selected] * 0.0 + mm_features.reshape(-1, C)

        # replace region token
        mask_selected = (input_ids == self.config.region_token_index)
        try:
            if mask_selected.sum() > 0:
                # try:
                #     patches = np.ascontiguousarray(additional_images_list[0][0])
                #     grid_h = additional_images_thw[0][0][0][1]
                #     grid_w = additional_images_thw[0][0][0][2]
                #     patches = patches.reshape(grid_h, grid_w, 3, 14, 14)
                #     from matplotlib import pyplot as plt
                #     plt.imshow(patches[:, :, :, 0, 0])
                #     plt.savefig('7.png')
                #     import pdb
                #     pdb.set_trace()
                #     patches = patches.transpose(2, 0, 3, 1, 4)
                #     reconstructed_image = patches.reshape(3, grid_h * 14, grid_w * 14).transpose(1, 2, 0)
                #     from matplotlib import pyplot as plt
                #     plt.imshow(reconstructed_image)
                #     plt.savefig('7.png')
                #     import pdb
                #     pdb.set_trace()
                additional_images_features = self.get_model().get_vision_encoder()(additional_images_list, grid_thws=additional_images_thw, strides=additional_images_strides)
                reshaped_features = reshape_images_to_raw_grid(additional_images_features, additional_images_thw)
                # mask_additional_image_features = []
                # for idx in mask_ids:
                #     mask_additional_image_features.append(reshaped_features[idx])
                mask_feats = self.model.region_encoder(reshaped_features, masks)
                input_embeds[mask_selected] = input_embeds[mask_selected] * 0.0 + mask_feats
        # except:  # FIXME
        #     print('additional_images_list is empty...')
        except Exception as exp:
            print('error: ', exp)
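
        # Note on the region branch above: when region placeholder tokens (config.region_token_index)
        # are present, the crops in additional_images are re-encoded at stride 1, reshaped back onto
        # their patch grids via reshape_images_to_raw_grid, and aggregated by region_encoder using the
        # provided masks (the exact pooling is region_encoder's contract, assumed here); the pooled
        # features are then written into the region placeholder embedding slots.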

        new_input_embeds = input_embeds.reshape(1, -1, C)

        return None, attention_mask, past_key_values, new_input_embeds, labels, position_ids