import contextlib
import math

import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from transformers import Qwen2ForCausalLM, SiglipVisionModel
from transformers.cache_utils import Cache
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel

from .configuration_nvila import NVILAConfig

# Per-token channel count of the multi-scale vision features fed to the multimodal projector.
MM_HIDDEN_SIZE = 3456


class NVILAMultiModalProjectorDownsampleBlock(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        batch_size, sequence_length, hidden_size = x.shape

        # The flattened token sequence is assumed to be a square feature map.
        feat_size = math.isqrt(sequence_length)

        features = x.reshape(batch_size, feat_size, feat_size, hidden_size)

        # Pad to an even spatial size so the map can be grouped into 2x2 cells.
        pad_after = feat_size % 2
        if pad_after > 0:
            features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after))
            feat_size = feat_size + pad_after

        # Space-to-depth: fold each 2x2 spatial cell into the channel dimension,
        # halving the spatial resolution and quadrupling the hidden size.
        features = features.reshape(batch_size, feat_size // 2, 2, feat_size // 2, 2, hidden_size)
        features = features.permute(0, 1, 3, 2, 4, 5).contiguous()
        features = features.reshape(batch_size, -1, 4 * hidden_size)

        return features
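

# NVILAMultiModalProjectorDownsampleBlock shape sketch (illustrative): a 32x32 token grid maps
# (B, 1024, 3456) -> (B, 256, 13824); an odd 27x27 grid is first padded to 28x28, so
# (B, 729, 3456) -> (B, 196, 13824).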


class NVILAMultiModalProjector(nn.Module):
    def __init__(self, config: NVILAConfig):
        super().__init__()

        self.layers = nn.Sequential(
            NVILAMultiModalProjectorDownsampleBlock(),
            nn.LayerNorm(MM_HIDDEN_SIZE * 4),
            nn.Linear(MM_HIDDEN_SIZE * 4, config.text_config.hidden_size),
            nn.GELU(),
            nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.layers(x)
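

# NVILAMultiModalProjector sketch (illustrative): after the 2x2 downsample, each token carries
# 4 * MM_HIDDEN_SIZE channels, which the MLP maps to config.text_config.hidden_size so the
# vision tokens can be spliced into the language model's embedding sequence.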


class NVILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
    config_class = NVILAConfig
    base_model_prefix: str = "llm"
    _auto_class = "AutoModel"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config: NVILAConfig):
        super().__init__(config)

        self.config: NVILAConfig

        @contextlib.contextmanager
        def default_torch_dtype(dtype):
            original_dtype = torch.get_default_dtype()
            torch.set_default_dtype(dtype)
            try:
                yield
            finally:
                torch.set_default_dtype(original_dtype)

        # Build the submodules directly in the configured dtype.
        with default_torch_dtype(config.torch_dtype):
            self.vision_tower = SiglipVisionModel(config.vision_config)
            self.mm_projector = NVILAMultiModalProjector(config)
            self.llm = Qwen2ForCausalLM(config.text_config)

        self.post_init()

    def forward(
        self,
        *,
        block_sizes: list[tuple[int, int]] | None = None,
        input_ids: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
        pixel_values: Tensor | None = None,
        pixel_values_videos: Tensor | None = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        assert (input_ids is None) != (
            inputs_embeds is None
        ), "Exactly one of `input_ids` or `inputs_embeds` must be specified."

        # If the prompt contains image or video placeholder tokens, splice the projected vision
        # features into the token embeddings and pass embeddings to the language model instead.
        if input_ids is not None and torch.isin(
            input_ids,
            torch.tensor(
                [self.config.image_token_id, self.config.video_token_id],
                device=input_ids.device,
            ),
        ).any():
            inputs_embeds = self._embed(
                block_sizes=block_sizes,
                input_ids=input_ids,
                pixel_values=pixel_values,
                pixel_values_videos=pixel_values_videos,
            )
            input_ids = None

        outputs = self.llm(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        return outputs
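
    # Call sketch (illustrative; real inputs would come from a matching multimodal processor):
    #   out = model(input_ids=input_ids, pixel_values=pixel_values)  # hypothetical tensors
    #   logits = out.logits
    # Text generation is provided by the inherited GenerationMixin.generate().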

    def _embed(
        self,
        *,
        block_sizes: list[tuple[int, int]] | None,
        input_ids: Tensor,
        pixel_values: Tensor | None,
        pixel_values_videos: Tensor | None,
    ) -> Tensor:
        inputs_embeds: Tensor = self.llm.model.embed_tokens(input_ids)

        for media_pixel_values, media_token_id in [
            (pixel_values, self.config.image_token_id),
            (pixel_values_videos, self.config.video_token_id),
        ]:
            if media_pixel_values is None:
                continue

            vision_features = self._encode_vision(
                media_pixel_values,
                block_sizes=block_sizes,
            )
            vision_features = einops.rearrange(vision_features, "n p d -> (n p) d")

            # Replace each image/video placeholder embedding with one projected vision token.
            inputs_embeds[input_ids == media_token_id] = vision_features

        return inputs_embeds
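
    # Note: the masked assignment in `_embed` replaces placeholder embeddings one-to-one, so the
    # prompt must contain exactly as many image/video placeholder tokens as projected vision
    # tokens; otherwise the assignment raises a shape-mismatch error.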

    def _encode_vision(
        self,
        pixel_values: Tensor,
        *,
        block_sizes: list[tuple[int, int]] | None = None,
    ) -> Tensor:
        vision_tower_output: BaseModelOutputWithPooling = self.vision_tower(
            pixel_values.to(device=self.vision_tower.device, dtype=self.vision_tower.dtype),
            output_hidden_states=True,
        )
        assert vision_tower_output.hidden_states is not None

        # Use the penultimate hidden state of the vision tower as the visual features.
        vision_features: Tensor = vision_tower_output.hidden_states[-2]

        # Merge the per-scale crops of each image into multi-scale feature maps.
        vision_features_list, block_sizes = merge_features_for_dynamic_s2(
            vision_features,
            block_sizes=block_sizes if block_sizes is not None else [None] * vision_features.shape[0],
            resize_output_to_scale_idx=-1,
            scales=[448, 896, 1344],
        )

        # Re-tile each multi-scale map so the projector processes fixed-size blocks.
        vision_features_list = [
            split_chessboard(x, block_size[0], block_size[1])
            for x, block_size in zip(vision_features_list, block_sizes)
        ]

        vision_features = torch.cat([einops.rearrange(x, "b c h w -> b (h w) c") for x in vision_features_list])

        vision_features = self.mm_projector(vision_features.to(self.device, self.dtype))

        # Stitch the projected tiles back into one map per image, then flatten to token sequences.
        vision_features_list = list(
            vision_features.split([block_size[0] * block_size[1] for block_size in block_sizes], dim=0)
        )
        vision_features_list = [
            merge_chessboard(x, block_size[0], block_size[1])
            for x, block_size in zip(vision_features_list, block_sizes)
        ]

        vision_features = torch.stack([einops.rearrange(x, "1 c h w -> (h w) c") for x in vision_features_list])

        return vision_features
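

# The helper functions below implement the multi-scale ("dynamic S2") feature handling used by
# `_encode_vision`: `split_chessboard` and `merge_chessboard` convert between a batch of tiles and
# a single stitched feature map, and `merge_features_for_dynamic_s2` resizes the per-scale maps to
# a common resolution and concatenates them along the channel dimension.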


def merge_chessboard(x, num_split_h, num_split_w):
    """
    x: (B, N, C) or (B, C, H, W)
    out: (B // (num_split_h * num_split_w), C, H * num_split_h, W * num_split_w)

    Assuming x contains num_split_h * num_split_w sub-squares concatenated along the batch
    dimension, merge the sub-squares back into the original whole feature map.
    """
    B = x.shape[0]
    if x.dim() == 3:
        N = x.shape[1]
        x = einops.rearrange(x, "b (h w) c -> b c h w", h=math.isqrt(N), w=math.isqrt(N))

    assert B % (num_split_h * num_split_w) == 0
    b = B // (num_split_h * num_split_w)

    # Tiles are ordered row-major: index (i * num_split_w + j) holds row i, column j.
    x_merge = torch.cat(
        [
            torch.cat(
                [x[(i * num_split_w + j) * b : (i * num_split_w + j + 1) * b] for j in range(num_split_w)], dim=-1
            )
            for i in range(num_split_h)
        ],
        dim=-2,
    )

    return x_merge
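

# merge_chessboard example (illustrative): a (4, 196, C) tensor holding four tiles, merged with
# num_split_h = num_split_w = 2, is reshaped to (4, C, 14, 14) and stitched into a single
# (1, C, 28, 28) feature map.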


def merge_features_for_dynamic_s2(image_features, block_sizes, *, scales, resize_output_to_scale_idx):
    image_features_each_image = []
    new_block_sizes = []
    block_cnt = 0
    for block_size_each_image in block_sizes:
        if block_size_each_image is None:
            # Single-block input: tile the channels so the channel count matches the
            # multi-scale (concatenated) case.
            cur_features = image_features[block_cnt : block_cnt + 1]
            cur_features = einops.rearrange(cur_features, "1 (h w) c -> 1 c h w", h=math.isqrt(cur_features.shape[1]))
            cur_features = cur_features.repeat(1, len(scales), 1, 1)
            image_features_each_image.append(cur_features)
            new_block_sizes.append((1, 1))
            block_cnt += 1
        else:
            # Merge the blocks of every scale but the last into one feature map per scale.
            cur_features_each_scale = []
            for scale in scales[:-1]:
                num_blocks_this_scale = (scale // scales[0]) ** 2
                cur_features_each_scale.append(
                    merge_chessboard(
                        image_features[block_cnt : block_cnt + num_blocks_this_scale],
                        num_split_h=scale // scales[0],
                        num_split_w=scale // scales[0],
                    )
                )
                block_cnt += num_blocks_this_scale
            # The last (largest) scale uses the image-specific block layout.
            num_blocks_last_scale = block_size_each_image[0] * block_size_each_image[1]
            cur_features_each_scale.append(
                merge_chessboard(
                    image_features[block_cnt : block_cnt + num_blocks_last_scale],
                    num_split_h=block_size_each_image[0],
                    num_split_w=block_size_each_image[1],
                )
            )
            block_cnt += num_blocks_last_scale

            # Resize every scale to the chosen output resolution and concatenate along channels.
            output_size = cur_features_each_scale[resize_output_to_scale_idx].shape[-2:]
            cur_features = torch.cat(
                [
                    F.interpolate(cur_features_each_scale[i].to(torch.float32), size=output_size, mode="area").to(
                        cur_features_each_scale[i].dtype
                    )
                    for i in range(len(cur_features_each_scale))
                ],
                dim=1,
            )

            image_features_each_image.append(cur_features)

            if resize_output_to_scale_idx == len(scales) - 1 or resize_output_to_scale_idx == -1:
                new_block_sizes.append(block_size_each_image)
            else:
                new_block_sizes.append(
                    (
                        scales[resize_output_to_scale_idx] // scales[0],
                        scales[resize_output_to_scale_idx] // scales[0],
                    )
                )

    assert block_cnt == len(
        image_features
    ), f"The number of blocks ({block_cnt}) does not match the length of image_features ({len(image_features)})!"

    return image_features_each_image, new_block_sizes
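

# Block accounting example (illustrative): with scales = [448, 896, 1344] and a (3, 3) block size
# for the largest scale, one image contributes 1 + 4 + 9 = 14 rows of `image_features`; with
# resize_output_to_scale_idx=-1 (as in `_encode_vision`), every scale is resized to the largest
# scale's resolution before the channel-wise concatenation.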


def split_chessboard(x, num_split_h, num_split_w):
    """
    x: (B, C, H, W)
    out: (B * num_split_h * num_split_w, C, H // num_split_h, W // num_split_w)

    Divide x into num_split_h * num_split_w sub-squares and concatenate the sub-squares
    along the batch dimension.
    """
    B, C, H, W = x.shape
    assert H % num_split_h == 0 and W % num_split_w == 0
    h, w = H // num_split_h, W // num_split_w
    # Tiles are emitted row-major, matching the ordering expected by `merge_chessboard`.
    x_split = torch.cat(
        [x[:, :, i * h : (i + 1) * h, j * w : (j + 1) * w] for i in range(num_split_h) for j in range(num_split_w)],
        dim=0,
    )
    return x_split
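

# split_chessboard example (illustrative): a (1, C, 28, 28) map split with
# num_split_h = num_split_w = 2 yields (4, C, 14, 14); `merge_chessboard` with the same split
# sizes reverses the operation.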