NVILA-15B-hf / modeling_nvila.py
Ligeng-Zhu's picture
Upload files with `vila-upload`.
bf6def8 verified
import contextlib
import math
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from transformers import Qwen2ForCausalLM, SiglipVisionModel
from transformers.cache_utils import Cache
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from .configuration_nvila import NVILAConfig
# Channel width of the vision features entering NVILAMultiModalProjector.
# merge_features_for_dynamic_s2 concatenates one feature map per S2 scale along
# channels (three scales are used in _encode_vision), hence 3 x 1152.
# Presumably 1152 is the vision tower's hidden size -- confirm against
# config.vision_config.hidden_size before changing this.
MM_HIDDEN_SIZE = 3 * 1152
class NVILAMultiModalProjectorDownsampleBlock(nn.Module):
    """Fold every 2x2 neighbourhood of a square token grid into a single token.

    The input is unpacked as ``(batch, seq_len, hidden)`` and ``seq_len`` must be
    a perfect square (the reshape to a square grid fails otherwise). A grid with
    an odd side is zero-padded by one row at the bottom and one column at the
    right first. The output is ``(batch, ceil(side / 2) ** 2, 4 * hidden)``, with
    the four neighbours laid out row-major inside each new token.
    """

    def forward(self, x: Tensor) -> Tensor:
        batch, tokens, dim = x.shape
        side = math.isqrt(tokens)
        grid = x.reshape(batch, side, side, dim)

        remainder = side % 2
        if remainder:
            # Pad (hidden: none, width: +1 at end, height: +1 at end) to an even side.
            grid = F.pad(grid, (0, 0, 0, remainder, 0, remainder))
            side += remainder

        half = side // 2
        # (b, row_blk, row_off, col_blk, col_off, d) -> (b, row_blk, col_blk, row_off, col_off, d)
        blocks = grid.reshape(batch, half, 2, half, 2, dim).permute(0, 1, 3, 2, 4, 5)
        return blocks.contiguous().reshape(batch, -1, 4 * dim)
class NVILAMultiModalProjector(nn.Module):
    """Map concatenated multi-scale vision features into the LLM embedding space.

    Stages, in order: 2x2 token downsampling (4x channel width), LayerNorm,
    Linear to the text hidden size, GELU, Linear. The stage order and the
    ``layers`` container must stay as-is: ``nn.Sequential`` derives parameter
    names (``layers.1``, ``layers.2``, ``layers.4``) from the positions.
    """

    def __init__(self, config: NVILAConfig):
        super().__init__()
        merged_width = 4 * MM_HIDDEN_SIZE  # the downsample block folds 4 tokens into one
        llm_width = config.text_config.hidden_size
        stages = [
            NVILAMultiModalProjectorDownsampleBlock(),
            nn.LayerNorm(merged_width),
            nn.Linear(merged_width, llm_width),
            nn.GELU(),
            nn.Linear(llm_width, llm_width),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x: Tensor) -> Tensor:
        return self.layers(x)
class NVILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
    """NVILA vision-language model: SigLIP vision tower + MLP projector + Qwen2 LM.

    Placeholder tokens (``config.image_token_id`` / ``config.video_token_id``) in
    ``input_ids`` are overwritten with projected vision features before the
    sequence is passed to the language model.
    """

    config_class = NVILAConfig
    base_model_prefix: str = "llm"
    _auto_class = "AutoModel"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config: NVILAConfig):
        super().__init__(config)
        self.config: NVILAConfig

        @contextlib.contextmanager
        def default_torch_dtype(dtype):
            # Temporarily switch torch's default float dtype so the submodules'
            # parameters are created directly in ``dtype``.
            original_dtype = torch.get_default_dtype()
            torch.set_default_dtype(dtype)
            try:
                yield
            finally:
                torch.set_default_dtype(original_dtype)

        # torch.set_default_dtype(None) raises, and a config without a recorded
        # dtype has torch_dtype=None, so only override when a dtype is given.
        dtype_scope = (
            contextlib.nullcontext()
            if config.torch_dtype is None
            else default_torch_dtype(config.torch_dtype)
        )
        with dtype_scope:
            self.vision_tower = SiglipVisionModel(config.vision_config)
            self.mm_projector = NVILAMultiModalProjector(config)
            self.llm = Qwen2ForCausalLM(config.text_config)

        self.post_init()

    def forward(
        self,
        *,
        block_sizes: list[tuple[int, int]] | None = None,
        input_ids: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
        pixel_values: Tensor | None = None,
        pixel_values_videos: Tensor | None = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the LLM, splicing vision features in when media tokens are present.

        Args:
            block_sizes: per-image ``(rows, cols)`` tile grid of the last S2 scale;
                ``None`` treats every image as a single tile.
            input_ids: token ids. Exactly one of ``input_ids`` / ``inputs_embeds``
                must be given.
            inputs_embeds: precomputed embeddings; when used, pixel inputs are ignored.
            pixel_values: image tiles replacing ``config.image_token_id`` positions.
            pixel_values_videos: frames replacing ``config.video_token_id`` positions.
            **kwargs: forwarded unchanged to the underlying ``Qwen2ForCausalLM``.

        Returns:
            Whatever ``self.llm`` returns.
        """
        assert (input_ids is None) != (
            inputs_embeds is None
        ), "Exactly one of `input_ids` or `inputs_embeds` must be specified."

        if input_ids is not None:
            media_token_ids = torch.tensor(
                [self.config.image_token_id, self.config.video_token_id],
                device=input_ids.device,
            )
            if torch.isin(input_ids, media_token_ids).any():  # Prefill
                inputs_embeds = self._embed(
                    block_sizes=block_sizes,
                    input_ids=input_ids,
                    pixel_values=pixel_values,
                    pixel_values_videos=pixel_values_videos,
                )
                input_ids = None

        return self.llm(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

    def _embed(
        self,
        *,
        block_sizes: list[tuple[int, int]] | None,
        input_ids: Tensor,
        pixel_values: Tensor | None,
        pixel_values_videos: Tensor | None,
    ) -> Tensor:
        """Embed ``input_ids`` and overwrite media-token positions with vision features.

        Raises whatever the boolean-mask assignment raises when the number of
        media tokens differs from the number of produced vision tokens.
        """
        inputs_embeds: Tensor = self.llm.model.embed_tokens(input_ids)
        # NOTE(review): the same `block_sizes` is used for images and videos; if
        # both are passed with image-specific block sizes the video branch will
        # likely trip the tile-count assertion -- confirm intended usage.
        for media_pixel_values, media_token_id in (
            (pixel_values, self.config.image_token_id),
            (pixel_values_videos, self.config.video_token_id),
        ):
            if media_pixel_values is None:
                continue
            # Images can yield different token counts under dynamic S2 (different
            # block sizes), so concatenate the per-image tensors instead of stacking.
            vision_features = torch.cat(
                self._encode_vision_per_image(media_pixel_values, block_sizes=block_sizes),
                dim=0,
            )
            inputs_embeds[input_ids == media_token_id] = vision_features
        return inputs_embeds

    def _encode_vision_per_image(
        self,
        pixel_values: Tensor,
        *,
        block_sizes: list[tuple[int, int]] | None = None,
    ) -> list[Tensor]:
        """Encode pixel tiles and return one ``(tokens, hidden)`` tensor per image.

        Uses the vision tower's penultimate hidden state, merges tiles across the
        S2 scales, re-splits the merged maps into tiles for the projector, and
        stitches the projected tiles back together per image.
        """
        vision_tower_output: BaseModelOutputWithPooling = self.vision_tower(
            pixel_values.to(device=self.vision_tower.device, dtype=self.vision_tower.dtype),
            output_hidden_states=True,
        )
        assert vision_tower_output.hidden_states is not None
        vision_features: Tensor = vision_tower_output.hidden_states[-2]

        vision_features_list, block_sizes = merge_features_for_dynamic_s2(
            vision_features,
            block_sizes=block_sizes if block_sizes is not None else [None] * vision_features.shape[0],
            resize_output_to_scale_idx=-1,
            scales=[448, 896, 1344],
        )
        vision_features_list = [
            split_chessboard(x, block_size[0], block_size[1])
            for x, block_size in zip(vision_features_list, block_sizes)
        ]
        vision_features = torch.cat([einops.rearrange(x, "b c h w -> b (h w) c") for x in vision_features_list])
        vision_features = self.mm_projector(vision_features.to(self.device, self.dtype))

        vision_features_list = list(
            vision_features.split([block_size[0] * block_size[1] for block_size in block_sizes], dim=0)
        )
        vision_features_list = [
            merge_chessboard(x, block_size[0], block_size[1])
            for x, block_size in zip(vision_features_list, block_sizes)
        ]
        return [einops.rearrange(x, "1 c h w -> (h w) c") for x in vision_features_list]

    def _encode_vision(
        self,
        pixel_values: Tensor,
        *,
        block_sizes: list[tuple[int, int]] | None = None,
    ) -> Tensor:
        """Encode images into a stacked ``(num_images, tokens, hidden)`` tensor.

        Requires every image to produce the same token count (``torch.stack``);
        use ``_encode_vision_per_image`` when block sizes differ between images.
        """
        return torch.stack(self._encode_vision_per_image(pixel_values, block_sizes=block_sizes))
# NOTE: The following functions are directly copied from VILA codebase.
def merge_chessboard(x, num_split_h, num_split_w):
    """Reassemble a grid of tiles, stacked on the batch dimension, into whole maps.

    Inverse of ``split_chessboard``: tile ``(i, j)`` of the grid occupies batch
    rows ``[(i * num_split_w + j) * b, (i * num_split_w + j + 1) * b)``.

    Args:
        x: tiles, either ``b * n * c`` token sequences (``n`` must be a perfect
            square and is unfolded row-major into an ``h * w`` map) or
            ``b * c * h * w`` spatial maps.
        num_split_h: number of tile rows.
        num_split_w: number of tile columns.

    Returns:
        ``(b / (num_split_h * num_split_w)) * c * (num_split_h * h) * (num_split_w * w)``.

    Raises:
        ValueError: if ``n`` is not a perfect square, or the batch size is not
            a multiple of ``num_split_h * num_split_w``.
    """
    batch = x.shape[0]
    if x.dim() == 3:
        num_tokens, channels = x.shape[1], x.shape[2]
        side = math.isqrt(num_tokens)
        if side * side != num_tokens:
            raise ValueError(f"Expected a square number of tokens per tile, got {num_tokens}.")
        # "b (h w) c -> b c h w"
        x = x.reshape(batch, side, side, channels).permute(0, 3, 1, 2)

    num_tiles = num_split_h * num_split_w
    if batch % num_tiles != 0:
        raise ValueError(
            f"Batch size {batch} is not a multiple of the tile count {num_split_h} x {num_split_w}."
        )
    b = batch // num_tiles

    rows = [
        torch.cat(
            [x[(i * num_split_w + j) * b : (i * num_split_w + j + 1) * b] for j in range(num_split_w)],
            dim=-1,
        )
        for i in range(num_split_h)
    ]
    return torch.cat(rows, dim=-2)
def merge_features_for_dynamic_s2(image_features, block_sizes, *, scales, resize_output_to_scale_idx):
    """Regroup per-tile vision features into one multi-scale feature map per image.

    ``image_features`` holds the tiles of all images back to back along dim 0,
    each tile being a token sequence. For each entry of ``block_sizes``:

    * ``None``: the image used a single tile. Its map is repeated
      ``len(scales)`` times along channels and its block size becomes ``(1, 1)``.
    * ``(rows, cols)``: the image used ``(s // scales[0]) ** 2`` tiles for each
      scale ``s`` in ``scales[:-1]``, followed by ``rows * cols`` tiles for the
      last scale. Each scale is stitched with ``merge_chessboard``, area-resized
      (in float32) to the spatial size of the map at
      ``resize_output_to_scale_idx``, and all scales are concatenated along
      channels.

    Returns:
        ``(maps, new_block_sizes)``: a list of ``1 * C' * H * W`` maps and the
        tile grid each map corresponds to.

    Raises:
        AssertionError: if the tiles consumed do not add up to ``len(image_features)``.
    """
    merged_maps = []
    grid_sizes = []
    block_cnt = 0
    for grid in block_sizes:
        if grid is None:
            tile = image_features[block_cnt : block_cnt + 1]
            tile = einops.rearrange(tile, "1 (h w) c -> 1 c h w", h=math.isqrt(tile.shape[1]))
            merged_maps.append(tile.repeat(1, len(scales), 1, 1))
            grid_sizes.append((1, 1))
            block_cnt += 1
            continue

        # Square tile grids for every scale except the last ...
        per_scale = []
        for scale in scales[:-1]:
            side = scale // scales[0]
            n_tiles = side**2
            per_scale.append(
                merge_chessboard(
                    image_features[block_cnt : block_cnt + n_tiles],
                    num_split_h=side,
                    num_split_w=side,
                )
            )  # 1 * C * H * W
            block_cnt += n_tiles

        # ... and the caller-provided (possibly non-square) grid for the last one.
        n_tiles = grid[0] * grid[1]
        per_scale.append(
            merge_chessboard(
                image_features[block_cnt : block_cnt + n_tiles],
                num_split_h=grid[0],
                num_split_w=grid[1],
            )
        )  # 1 * C * H * W
        block_cnt += n_tiles

        target_hw = per_scale[resize_output_to_scale_idx].shape[-2:]
        resized = [
            F.interpolate(fmap.to(torch.float32), size=target_hw, mode="area").to(fmap.dtype)
            for fmap in per_scale
        ]
        merged_maps.append(torch.cat(resized, dim=1))

        if resize_output_to_scale_idx in (len(scales) - 1, -1):
            grid_sizes.append(grid)
        else:
            side = scales[resize_output_to_scale_idx] // scales[0]
            grid_sizes.append((side, side))

    assert block_cnt == len(
        image_features
    ), f"The number of blocks ({block_cnt}) does not match length of image_features ({len(image_features)})!"
    return merged_maps, grid_sizes
def split_chessboard(x, num_split_h, num_split_w):
    """Cut each map into a ``num_split_h`` x ``num_split_w`` grid of equal tiles.

    x: b * c * h * w (``h`` and ``w`` must be divisible by the split counts)
    out: (num_split_h * num_split_w * b) * c * (h / num_split_h) * (w / num_split_w)

    Tiles are concatenated on the batch dimension in row-major tile order: tile
    (0, 0) for all ``b`` inputs, then tile (0, 1), and so on.
    """
    _, _, height, width = x.shape
    assert height % num_split_h == 0 and width % num_split_w == 0
    tile_h = height // num_split_h
    tile_w = width // num_split_w

    tiles = []
    for row in range(num_split_h):
        band = x.narrow(2, row * tile_h, tile_h)
        for col in range(num_split_w):
            tiles.append(band.narrow(3, col * tile_w, tile_w))
    return torch.cat(tiles, dim=0)