import math
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack, rearrange, repeat, unpack
from einops.layers.torch import Rearrange
from timm.layers import DropPath, to_3tuple, trunc_normal_
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    PretrainedConfig,
    PreTrainedModel,
    Qwen2Config,
    Qwen2ForCausalLM,
    Qwen2Model,
)
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast

try:
    import torch.distributed.nn
    from torch import distributed as dist

    has_distributed = True
except ImportError:
    has_distributed = False


class DEC_CLIPConfig(PretrainedConfig):
    model_type = "dec_clip"

    def __init__(
        self,
        language_model_name_or_path: str = "",
        vision_encoder: str = "dcformer",  # "dcformer" or "vit3d"
        local_loss: bool = False,
        gather_loss: bool = True,
        input_size: tuple = (256, 256, 128),
        dim: int = 768,
        depth: int = 12,
        hidden_size: int = 512,
        mlp_depth: int = 2,
        loss_type: str = "nce",
        t_prime: float = np.log(1 / 0.07),
        bias: float = 0.0,
        efficient_loss: bool = False,
        **kwargs,
    ):
        self.language_model_name_or_path = language_model_name_or_path
        self.vision_encoder = vision_encoder
        self.input_size = input_size
        self.dim = dim
        self.depth = depth
        self.hidden_size = hidden_size
        self.mlp_depth = mlp_depth
        self.local_loss = local_loss
        self.gather_loss = gather_loss
        self.loss_type = loss_type
        self.t_prime = t_prime
        self.bias = bias
        self.efficient_loss = efficient_loss
        super().__init__(**kwargs)


class DEC_CLIP(PreTrainedModel):
    config_class = DEC_CLIPConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        if config.vision_encoder == "vit3d":
            self.vision_encoder = Vit3D(
                input_size=config.input_size,
                dim=config.dim,
                depth=config.depth,
            )
        elif config.vision_encoder == "dcformer":
            self.vision_encoder = decomp_small(input_size=config.input_size)
        else:
            raise ValueError(f"Unexpected vision encoder: {config.vision_encoder}")

        self.language_encoder = AutoModel.from_pretrained(
            config.language_model_name_or_path
        )

        # DecompModel exposes per-stage channel widths; Vit3D outputs config.dim features.
        vision_out_dim = getattr(self.vision_encoder, "channels", [config.dim])[-1]
        self.mm_vision_proj = nn.Linear(vision_out_dim, config.hidden_size)
        self.mm_language_proj = nn.Linear(
            self.language_encoder.config.dim, config.hidden_size
        )

        self.efficient_loss = config.efficient_loss
        self.local_loss = config.local_loss
        self.gather_loss = config.gather_loss
        self.loss_type = config.loss_type
        if self.loss_type == "sigmoid":
            self.t_prime = nn.Parameter(torch.tensor(config.t_prime))
            self.bias = nn.Parameter(torch.tensor(config.bias))
        else:
            self.logit_scale = nn.Parameter(torch.ones([]) * config.t_prime)

    def encode_image(self, image):
        image_feats = self.vision_encoder(image)
        if isinstance(image_feats, list):
            image_feats = image_feats[-1]
        image_feats = image_feats.mean(dim=1)
        image_feats = self.mm_vision_proj(image_feats)
        image_feats = F.normalize(image_feats, dim=-1)
        return image_feats

    def encode_text(self, input_id, attention_mask):
        text_feats = self.language_encoder(input_id, attention_mask=attention_mask)[
            "last_hidden_state"
        ]
        text_feats = text_feats[:, 0]
        text_feats = self.mm_language_proj(text_feats)
        text_feats = F.normalize(text_feats, dim=-1)
        return text_feats

    def forward(self, images, input_ids, attention_mask, labels, **kwargs):
        image_features = self.encode_image(images)
        text_features = self.encode_text(input_ids, attention_mask)

        rank = 0
        world_size = 1
        if has_distributed and dist.is_initialized():
            rank = dist.get_rank()
            world_size = dist.get_world_size()

        batch_size = image_features.size(0)
        device = image_features.device

        if self.loss_type == "sigmoid":
            if has_distributed and dist.is_initialized():
                if self.efficient_loss:
                    t = torch.exp(self.t_prime)
                    loss = 0.0
                    for target_rank in range(world_size):
                        if rank == target_rank:
                            target_text_features = text_features
                        else:
                            target_text_features = torch.distributed.nn.broadcast(
                                text_features.requires_grad_(), target_rank
                            )
                        local_logits_per_image = (
                            image_features @ target_text_features.T
                        ) * t + self.bias
                        local_logits_per_text = local_logits_per_image.T
                        if rank == target_rank:
                            local_labels = 2 * torch.eye(
                                batch_size, device=device
                            ) - torch.ones(batch_size, batch_size, device=device)
                        else:
                            local_labels = -torch.ones(
                                batch_size, batch_size, device=device
                            )
                        local_logits = (
                            local_logits_per_image + local_logits_per_text
                        ) / 2.0
                        local_loss = -torch.sum(
                            F.logsigmoid(local_labels * local_logits)
                        ) / (batch_size * world_size)
                        loss += local_loss
                    torch.distributed.nn.all_reduce(loss)
                    torch.cuda.synchronize()
                    # The memory-efficient path never materializes the full
                    # similarity matrix, so only a placeholder is returned.
                    logits = 0.0
                else:
                    t = torch.exp(self.t_prime)
                    all_image_features, all_text_features = gather_features(
                        image_features,
                        text_features,
                        gather_with_grad=True,
                        rank=rank,
                        world_size=world_size,
                    )
                    logits_per_image = (
                        all_image_features @ all_text_features.T
                    ) * t + self.bias
                    logits_per_text = logits_per_image.T
                    batch_size = all_image_features.size(0)
                    labels = 2 * torch.eye(batch_size, device=device) - torch.ones(
                        batch_size, batch_size, device=device
                    )
                    logits = (logits_per_image + logits_per_text) / 2.0
                    loss = -torch.sum(F.logsigmoid(labels * logits)) / batch_size
            else:
                t = torch.exp(self.t_prime)
                logits_per_image = (
                    image_features @ text_features.T
                ) * t + self.bias
                logits_per_text = logits_per_image.T
                labels = 2 * torch.eye(batch_size, device=device) - torch.ones(
                    batch_size, batch_size, device=device
                )
                logits = (logits_per_image + logits_per_text) / 2.0
                loss = -torch.sum(F.logsigmoid(logits * labels)) / batch_size
        else:
            all_image_features, all_text_features = gather_features(
                image_features,
                text_features,
                local_loss=self.local_loss,
                gather_with_grad=True,
                rank=rank,
                world_size=world_size,
            )
            logit_scale = self.logit_scale.exp()
            if self.gather_loss:
                if self.local_loss:
                    logits_per_image = (
                        logit_scale * image_features @ all_text_features.T
                    )
                    logits_per_text = (
                        logit_scale * text_features @ all_image_features.T
                    )
                else:
                    logits_per_image = (
                        logit_scale * all_image_features @ all_text_features.T
                    )
                    logits_per_text = logits_per_image.T
            else:
                logits_per_image = logit_scale * image_features @ text_features.T
                logits_per_text = logit_scale * text_features @ image_features.T

            image_loss = F.cross_entropy(logits_per_image, labels)
            text_loss = F.cross_entropy(logits_per_text, labels)
            loss = (image_loss + text_loss) / 2.0
            logits = (logits_per_image + logits_per_text) / 2.0

        ret = {
            "loss": loss,
            "logits": logits,
        }

        return ret


def gather_features(
    image_features,
    text_features,
    local_loss=False,
    gather_with_grad=True,
    rank=0,
    world_size=1,
):
    assert (
        has_distributed
    ), "torch.distributed did not import correctly, please use a PyTorch version with support."
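    # gather_with_grad uses torch.distributed.nn.all_gather, which keeps the
    # autograd graph across ranks. The plain dist.all_gather path below returns
    # detached copies, so when local_loss is False the local shard is re-inserted
    # to preserve gradients for this rank's features.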
    if not (has_distributed and dist.is_initialized()):
        return image_features, text_features

    if gather_with_grad:
        all_image_features = torch.cat(
            torch.distributed.nn.all_gather(image_features), dim=0
        )
        all_text_features = torch.cat(
            torch.distributed.nn.all_gather(text_features), dim=0
        )
    else:
        gathered_image_features = [
            torch.zeros_like(image_features) for _ in range(world_size)
        ]
        gathered_text_features = [
            torch.zeros_like(text_features) for _ in range(world_size)
        ]
        dist.all_gather(gathered_image_features, image_features)
        dist.all_gather(gathered_text_features, text_features)
        if not local_loss:
            gathered_image_features[rank] = image_features
            gathered_text_features[rank] = text_features
        all_image_features = torch.cat(gathered_image_features, dim=0)
        all_text_features = torch.cat(gathered_text_features, dim=0)
    return all_image_features, all_text_features


AutoConfig.register("dec_clip", DEC_CLIPConfig)
AutoModel.register(DEC_CLIPConfig, DEC_CLIP)


def stem(inp, oup, image_size, downsample=False):
    stride = 2 if downsample else 1
    return nn.Sequential(
        nn.Conv3d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm3d(oup),
        nn.GELU(),
        nn.Conv3d(oup, oup, 3, 1, 1, bias=False),
        nn.BatchNorm3d(oup),
        nn.GELU(),
    )


def DecomposedStem(inp, oup, image_size, kernel_size, downsample=False):
    # Stride-4 decomposed patchify conv followed by three decomposed convs,
    # each with BatchNorm and a GELU activation.
    return nn.Sequential(
        DecompConv3D(inp, oup, 7, 4, 1, act=nn.GELU()),
        DecompConv3D(oup, oup, 3, 1, 1, act=nn.GELU()),
        DecompConv3D(oup, oup, 3, 1, 1, act=nn.GELU()),
        DecompConv3D(oup, oup, 3, 1, 1, act=nn.GELU()),
    )


class DecompConv3D(nn.Module):
    # Approximates a dense k x k x k convolution with the sum of three 1-D (axial)
    # convolutions along H, W and T, each optionally followed by BatchNorm.
    def __init__(
        self, in_dim, out_dim, kernel_size, stride=1, groups=1, norm=True, act=None
    ) -> None:
        super().__init__()
        self.act = act
        self.c1 = nn.Sequential(
            nn.Conv3d(
                in_dim,
                out_dim,
                kernel_size=(kernel_size, 1, 1),
                padding=(kernel_size // 2, 0, 0),
                stride=stride,
                groups=groups,
            ),
            nn.BatchNorm3d(out_dim) if norm else nn.Identity(),
        )
        self.c2 = nn.Sequential(
            nn.Conv3d(
                in_dim,
                out_dim,
                kernel_size=(1, kernel_size, 1),
                padding=(0, kernel_size // 2, 0),
                stride=stride,
                groups=groups,
            ),
            nn.BatchNorm3d(out_dim) if norm else nn.Identity(),
        )
        self.c3 = nn.Sequential(
            nn.Conv3d(
                in_dim,
                out_dim,
                kernel_size=(1, 1, kernel_size),
                padding=(0, 0, kernel_size // 2),
                stride=stride,
                groups=groups,
            ),
            nn.BatchNorm3d(out_dim) if norm else nn.Identity(),
        )

    def forward(self, x):
        x = self.c1(x) + self.c2(x) + self.c3(x)
        if self.act is not None:
            x = self.act(x)
        return x


class ConvPosEnc(nn.Module):
    def __init__(self, dim, k=3, decompose=False):
        super().__init__()
        if decompose:
            self.proj = DecompConv3D(dim, dim, k, groups=dim, norm=None)
        else:
            self.proj = nn.Conv3d(
                dim, dim, to_3tuple(k), to_3tuple(1), to_3tuple(k // 2), groups=dim
            )

    def forward(self, x, size):
        B, N, C = x.shape
        H, W, T = size
        assert N == H * W * T
        feat = rearrange(x, "b (h w t) c -> b c h w t", h=H, w=W, t=T)
        feat = self.proj(feat)
        feat = rearrange(feat, "b c h w t -> b (h w t) c")
        x = x + feat
        return x


class MLP(nn.Module):
    def __init__(self, oup, mlp_dim, dp=0.0):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(oup, mlp_dim),
            nn.GELU(),
            nn.Dropout(dp),
            nn.Linear(mlp_dim, oup),
            nn.Dropout(dp),
        )

    def forward(self, x):
        x = self.mlp(x)
        return x


class ScaleDotProduct(nn.Module):
    def __init__(self, scale):
        super().__init__()
        self.scale = scale
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, qkv):
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        attn = self.softmax(attn)
        x = (attn @ v).transpose(1, 2)
        return x


class DecomposedAttention(nn.Module):
    def __init__(self, oup, head_num):
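        # Decomposed (axial) 3-D attention: instead of full attention over all
        # h*w*t tokens, attention is computed independently within the (h, w),
        # (h, t) and (w, t) planes and the three results are summed, reducing the
        # per-token cost from h*w*t to h*w + h*t + w*t comparisons.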
        super().__init__()
        self.head_num = head_num
        scale = (oup // head_num) ** -0.5  # scale queries by 1/sqrt(head_dim)
        self.sdp = ScaleDotProduct(scale)
        self.qkv = nn.Linear(oup, oup * 3, bias=False)
        self.proj = nn.Linear(oup, oup, bias=False)

    def forward(self, x, size):
        b, n, c = x.shape
        h, w, t = size
        assert n == h * w * t
        B_, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B_, N, 3, self.head_num, C // self.head_num)
            .permute(2, 0, 3, 1, 4)
        )
        x = rearrange(qkv, "k b nh (h w t) c -> k b c nh h w t", h=h, w=w, t=t)
        x1 = rearrange(x, "k b c nh h w t -> k (b t) nh (h w) c")
        x2 = rearrange(x, "k b c nh h w t -> k (b w) nh (h t) c")
        x3 = rearrange(x, "k b c nh h w t -> k (b h) nh (w t) c")
        x1 = self.sdp(x1)
        x2 = self.sdp(x2)
        x3 = self.sdp(x3)
        x1 = rearrange(x1, "(b t) (h w) nh c -> b (h w t) (nh c)", h=h, w=w, t=t)
        x2 = rearrange(x2, "(b w) (h t) nh c -> b (h w t) (nh c)", h=h, w=w, t=t)
        x3 = rearrange(x3, "(b h) (w t) nh c -> b (h w t) (nh c)", h=h, w=w, t=t)
        x = self.proj(x1 + x2 + x3)
        return x


class SelfAttention(nn.Module):
    def __init__(self, oup, head_num):
        super().__init__()
        self.head_num = head_num
        scale = (oup // head_num) ** -0.5  # scale queries by 1/sqrt(head_dim)
        self.sdp = ScaleDotProduct(scale)
        self.qkv = nn.Linear(oup, oup * 3, bias=False)
        self.proj = nn.Linear(oup, oup, bias=False)

    def forward(self, x, size=None):
        B_, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B_, N, 3, self.head_num, C // self.head_num)
            .permute(2, 0, 3, 1, 4)
        )
        x = self.sdp(qkv).reshape(B_, N, C)
        x = self.proj(x)
        return x


class ChannelAttention(nn.Module):
    def __init__(self, oup, head_num):
        super().__init__()
        self.head_num = head_num
        self.scale = (oup // head_num) ** -0.5  # 1/sqrt(head_dim), applied to keys
        self.softmax = nn.Softmax(dim=-1)
        self.qkv = nn.Linear(oup, oup * 3, bias=False)
        self.proj = nn.Linear(oup, oup, bias=False)

    def forward(self, x):
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.head_num, C // self.head_num)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]
        k = k * self.scale
        attention = k.transpose(-1, -2) @ v
        attention = self.softmax(attention)
        x = (attention @ q.transpose(-1, -2)).transpose(-1, -2)
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x


class ChannelBlock(nn.Module):
    def __init__(self, dim, heads=8):
        super().__init__()
        hidden_dim = int(dim * 4)
        self.cpe = nn.ModuleList(
            [
                ConvPosEnc(dim=dim, k=3, decompose=True),
                ConvPosEnc(dim=dim, k=3, decompose=True),
            ]
        )
        self.attn = ChannelAttention(dim, heads)
        self.layer_norm1 = nn.LayerNorm(dim)
        self.mlp1 = MLP(dim, hidden_dim)
        self.layer_norm2 = nn.LayerNorm(dim)

    def forward(self, x, size):
        x = self.cpe[0](x, size)
        _x = self.layer_norm1(x)
        _x = self.attn(_x)
        x = x + _x
        x = self.cpe[1](x, size)
        _x = self.layer_norm2(x)
        _x = self.mlp1(_x)
        x = x + _x
        return x


class SpatialBlock(nn.Module):
    def __init__(self, dim, heads=8):
        super().__init__()
        hidden_dim = int(dim * 4)
        self.cpe = nn.ModuleList(
            [
                ConvPosEnc(dim=dim, k=3, decompose=True),
                ConvPosEnc(dim=dim, k=3, decompose=True),
            ]
        )
        # self.attn = DecomposedAttention(dim, heads)
        self.attn = SelfAttention(dim, heads)
        self.layer_norm1 = nn.LayerNorm(dim)
        self.mlp1 = MLP(dim, hidden_dim)
        self.layer_norm2 = nn.LayerNorm(dim)

    def forward(self, x, size):
        x = self.cpe[0](x, size)
        _x = self.layer_norm1(x)
        _x = self.attn(_x, size)
        x = x + _x
        x = self.cpe[1](x, size)
        _x = self.layer_norm2(x)
        _x = self.mlp1(_x)
        x = x + _x
        return x


class TransformerBlock(nn.Module):
    def __init__(
        self,
        inp,
        oup,
        image_size,
        kernel_size,
        heads=8,
        dim_head=32,
        downsample=False,
        dropout=0.0,
    ):
        super().__init__()
        hidden_dim = int(inp * 4)
        self.ih, self.iw, self.it = image_size
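        # Optional 2x downsampling (stride-2 max-pool + 1x1x1 channel projection)
        # before the volume is flattened to (h*w*t) tokens for spatial attention.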
self.downsample = downsample if self.downsample: self.pool1 = nn.MaxPool3d(3, 2, 1) # self.pool2 = nn.MaxPool3d(3, 2, 1) self.proj = nn.Conv3d(inp, oup, 1, 1, 0, bias=False) self.spatial_attention = SpatialBlock(oup, heads) # self.channel_attention = ChannelBlock(oup, heads) def forward(self, x): if self.downsample: x = self.proj(self.pool1(x)) # if self.downsample: # x = self.proj(self.pool1(x)) + self.attn(self.pool2(x)) # x = x.permute(0, 2, 3, 4, 1) h, w, t = x.shape[2], x.shape[3], x.shape[4] size = (h, w, t) x = rearrange(x, "b c h w t -> b (h w t) c ") x = self.spatial_attention(x, size) # x = self.channel_attention(x, size) x = rearrange(x, "b (h w t) c -> b c h w t", h=h, w=w, t=t) return x class ConvBlock(nn.Module): def __init__( self, inp, oup, image_size, kernel_size=7, downsample=False, expansion=4 ): super().__init__() self.downsample = downsample stride = 1 if self.downsample == False else 2 hidden_dim = int(oup * expansion) drop_path = 0.0 layer_scale_init_value = 1e-6 if self.downsample: self.pool = nn.MaxPool3d(3, 2, 1) self.proj = nn.Conv3d(inp, oup, 1, 1, 0, bias=False) # self.dwconv = nn.Sequential(nn.Conv3d(oup, oup, kernel_size=7, padding=3, groups=oup), nn.BatchNorm3d(oup)) # depthwise conv self.dwconv = DecompConv3D(oup, oup, kernel_size, groups=oup) self.mlp = MLP(oup, hidden_dim) self.scale = ( nn.Parameter(layer_scale_init_value * torch.ones((oup)), requires_grad=True) if layer_scale_init_value > 0 else None ) self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() def forward(self, x): if self.downsample: x = self.proj(self.pool(x)) input = x x = self.dwconv(x) x = x.permute(0, 2, 3, 4, 1) # (N, C, H, W, T) -> (N, H, W, T, C) x = self.mlp(x) if self.scale is not None: x = self.scale * x x = x.permute(0, 4, 1, 2, 3) # (N, H, W, C) -> (N, C, H, W) x = input + self.drop_path(x) return x class Encoder(nn.Module): def __init__( self, input_size, in_channels, num_blocks, channels, kernel_sizes=[7, 7, 7, 7], block_types=["C", "C", "C", "C"], ): super().__init__() self.dims = channels ih, iw, it = input_size block = {"C": ConvBlock, "T": TransformerBlock} i = 4 self.s0 = self._make_layer( DecomposedStem, in_channels, channels[0], num_blocks[0], kernel_sizes[0], (ih // i, iw // i, it // i), ) self.s1 = self._make_layer( block[block_types[0]], channels[0], channels[1], num_blocks[1], kernel_sizes[0], (ih // (i * 2**1), iw // (i * 2**1), it // (i * 2**1)), ) self.s2 = self._make_layer( block[block_types[1]], channels[1], channels[2], num_blocks[2], kernel_sizes[1], (ih // (i * 2**2), iw // (i * 2**2), it // (i * 2**2)), ) self.s3 = self._make_layer( block[block_types[2]], channels[2], channels[3], num_blocks[3], kernel_sizes[2], (ih // (i * 2**3), iw // (i * 2**3), it // (i * 2**3)), ) self.s4 = self._make_layer( block[block_types[3]], channels[3], channels[4], num_blocks[4], kernel_sizes[3], (ih // (i * 2**4), iw // (i * 2**4), it // (i * 2**4)), ) def forward(self, x): hidden_states = [] x = x.permute(0, 1, 3, 4, 2) for i in range(5): if hasattr(self, "s" + str(i)): x = getattr(self, "s" + str(i))(x) hidden_states.append(x) return hidden_states def _make_layer(self, block, inp, oup, depth, kernel_size, image_size): layers = nn.ModuleList([]) for i in range(depth): if i == 0: layers.append(block(inp, oup, image_size, kernel_size, downsample=True)) else: layers.append(block(oup, oup, image_size, kernel_size)) return nn.Sequential(*layers) class DecompModel(nn.Module): def __init__( self, input_size=(512, 512, 256), in_channels=1, num_blocks=[2, 2, 3, 
5, 2], channels=[64, 96, 192, 384, 768], # kernel_sizes=[7, 7, 7, 7], kernel_sizes=[13, 11, 9, 7], block_types=["C", "C", "C", "C"], codebook_size=8192, ): super().__init__() self.channels = channels self.encoder = Encoder( input_size, in_channels, num_blocks, channels, kernel_sizes, block_types ) # self.vq = VectorQuantize(dim = channels[-1], codebook_size = codebook_size, use_cosine_sim = True) def forward(self, video, mask=None, device="cuda"): hidden_states = self.encoder(video) # tokens = rearrange(tokens, "b d h w t -> b t h w d") # shape = tokens.shape # *_, h, w, _ = shape # quantize # tokens, _ = pack([tokens], "b * d") for i in range(len(hidden_states)): hidden_states[i] = rearrange(hidden_states[i], "b d h w t -> b t h w d") hidden_states[i], _ = pack([hidden_states[i]], "b * d") # vq_mask = None # tokens, _, _ = self.vq(tokens, mask = vq_mask) # tokens = rearrange(tokens, 'b (t h w) d -> b t h w d', h = h, w = w) return hidden_states def decomp_nano( input_size=(512, 512, 256), # input_size=(256, 256, 128), ): model = DecompModel( input_size=input_size, num_blocks=[1, 1, 1, 1, 1], channels=[32, 32, 64, 128, 256], ) return model def decomp_naive( input_size=(512, 512, 256), # input_size=(256, 256, 128), ): model = DecompModel( input_size=input_size, num_blocks=[1, 2, 2, 2, 2], # channels = [64, 64, 128, 256, 512] channels=[32, 64, 128, 256, 512], ) return model def decomp_tiny( input_size=(512, 512, 256), ): model = DecompModel( input_size=input_size, num_blocks=[1, 2, 3, 3, 2], # channels = [64, 64, 128, 256, 512] channels=[64, 96, 192, 384, 768], ) return model def decomp_small( input_size=(512, 512, 256), ): model = DecompModel( input_size=input_size, num_blocks=[1, 2, 3, 6, 2], channels=[64, 96, 192, 384, 768], ) return model def decomp_base( input_size=(512, 512, 256), ): model = DecompModel( input_size=input_size, num_blocks=[1, 2, 6, 6, 2], channels=[64, 128, 256, 512, 1024], ) return model def decomp_large( input_size=(512, 512, 256), ): model = DecompModel( input_size=input_size, num_blocks=[1, 2, 6, 12, 2], # channels=[64, 192, 384, 768, 1536], channels=[64, 256, 512, 1024, 2048], ) return model class FeedForward(nn.Module): def __init__(self, dim, hidden_dim, dropout=0.0): super().__init__() self.net = nn.Sequential( nn.LayerNorm(dim), nn.Linear(dim, hidden_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim, dim), nn.Dropout(dropout), ) def forward(self, x): return self.net(x) class Attention(nn.Module): def __init__(self, dim, heads=8, dim_head=64, dropout=0.0): super().__init__() inner_dim = dim_head * heads project_out = not (heads == 1 and dim_head == dim) self.heads = heads self.scale = dim_head**-0.5 self.norm = nn.LayerNorm(dim) self.attend = nn.Softmax(dim=-1) self.dropout = nn.Dropout(dropout) self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) self.to_out = ( nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) if project_out else nn.Identity() ) def forward(self, x): x = self.norm(x) qkv = self.to_qkv(x).chunk(3, dim=-1) q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale attn = self.attend(dots) attn = self.dropout(attn) out = torch.matmul(attn, v) out = rearrange(out, "b h n d -> b n (h d)") return self.to_out(out) class Transformer(nn.Module): def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.0): super().__init__() self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append( nn.ModuleList( [ Attention(dim, heads=heads, 
dim_head=dim_head, dropout=dropout), FeedForward(dim, mlp_dim, dropout=dropout), ] ) ) def forward(self, x): for attn, ff in self.layers: x = attn(x) + x x = ff(x) + x return x class ViTEncoder(nn.Module): def __init__( self, image_size=[512, 512, 256], patch_size=16, dim=512, depth=8, heads=8, mlp_dim=4, channels=1, dim_head=64, dropout=0.0, emb_dropout=0.0, ): super().__init__() h, w, t = image_size[0], image_size[1], image_size[2] self.vit_img_dim = [i // patch_size for i in image_size] num_patches = (h // patch_size) * (w // patch_size) * (t // patch_size) patch_dim = channels * patch_size * patch_size * patch_size self.to_patch_embedding = nn.Sequential( Rearrange( "b c (h p1) (w p2) (t p3) -> b (h w t) (p1 p2 p3 c)", p1=patch_size, p2=patch_size, p3=patch_size, ), nn.LayerNorm(patch_dim), nn.Linear(patch_dim, dim), nn.LayerNorm(dim), ) self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) self.dropout = nn.Dropout(emb_dropout) self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) def forward(self, x): x = x.permute(0, 1, 3, 4, 2) x = self.to_patch_embedding(x) b, n, _ = x.shape cls_tokens = repeat(self.cls_token, "1 1 d -> b 1 d", b=b) x = torch.cat((cls_tokens, x), dim=1) x += self.pos_embedding[:, : (n + 1)] x = self.dropout(x) x = self.transformer(x) x = x[:, 1:, :] x = rearrange( x, "b (x y z) c -> b c x y z", x=self.vit_img_dim[0], y=self.vit_img_dim[1], z=self.vit_img_dim[2], ) return x class Vit3D(nn.Module): def __init__(self, input_size=[512, 512, 256], patch_size=32, dim=512, depth=8): super().__init__() self.encoder = ViTEncoder(input_size, patch_size, dim, depth) # self.vq = VectorQuantize(dim = dim, codebook_size = 8192, use_cosine_sim = True) def forward(self, video, mask=None, device="cuda"): tokens = self.encoder(video) tokens = rearrange(tokens, "b d h w t -> b t h w d") shape = tokens.shape *_, h, w, _ = shape # quantize tokens, _ = pack([tokens], "b * d") # vq_mask = None # tokens, _, _ = self.vq(tokens, mask = vq_mask) # tokens = rearrange(tokens, 'b (t h w) d -> b t h w d', h = h, w = w) return tokens def build_vision_tower(config, **kwargs): return VisionTower(config) class VisionTower(nn.Module): def __init__(self, config): super().__init__() self.config = config self.select_layer = config.vision_select_layer self.select_feature = config.vision_select_feature self.hidden_size = config.dim if config.vision_tower == "vit3d": self.vision_tower = Vit3D( input_size=config.input_size, dim=config.dim, depth=config.depth, ) elif config.vision_tower == "dcformer": self.vision_tower = decomp_small( input_size=config.input_size, ) self.low_input_size = self.vision_tower.channels[-2] self.high_input_size = self.vision_tower.channels[-1] else: raise ValueError(f"Unexpected vision tower: {config.vision_tower}") def forward(self, images): hidden_states = self.vision_tower(images) if self.select_layer == 0: image_features = hidden_states elif self.select_layer < 0: image_features = hidden_states[self.select_layer :] else: raise ValueError(f"Unexpected select layer: {self.select_layer}") if self.select_feature == "patch": image_features = image_features[:, 1:] elif self.select_feature == "cls_patch": image_features = image_features else: raise ValueError(f"Unexpected select feature: {self.select_feature}") return image_features @property def dtype(self): return self.vision_tower.dtype @property def device(self): return self.vision_tower.device def readable_params(num): magnitude = 0 
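    # Formats a parameter count with thousands-based units,
    # e.g. 12_345_678 -> "12.35M".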
while abs(num) >= 1000: magnitude += 1 num /= 1000.0 return "%.2f%s" % (num, ["", "K", "M", "G", "T", "P"][magnitude]) class MLPLayer(nn.Module): def __init__(self, embed_dim, scale=4, *args, **kwargs): super().__init__(*args, **kwargs) self.linear1 = nn.Linear(embed_dim, embed_dim * scale) self.linear2 = nn.Linear(embed_dim * scale, embed_dim) self.act = nn.GELU() def forward(self, x): x = self.linear1(x) x = self.act(x) x = self.linear2(x) return x class MultiHeadSelfAttention(nn.Module): def __init__( self, embed_dim, output_dim, num_heads=8, proj_out_num=32, ): super(MultiHeadSelfAttention, self).__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.head_dim = embed_dim // num_heads self.proj_out_num = proj_out_num self.mlp = MLPLayer(embed_dim) assert ( self.head_dim * num_heads == embed_dim ), "embed_dim must be divisible by num_heads" self.norm1 = nn.LayerNorm(embed_dim) self.q_proj = nn.Linear(embed_dim, embed_dim) self.k_proj = nn.Linear(embed_dim, embed_dim) self.v_proj = nn.Linear(embed_dim, embed_dim) self.out_proj = nn.Linear(embed_dim, embed_dim) self.norm2 = nn.LayerNorm(embed_dim) self.act = nn.GELU() self.out_layer = nn.Linear(embed_dim, output_dim) self.scale = math.sqrt(self.head_dim) def forward(self, x): batch_size, seq_len, _ = x.size() x = self.norm1(x) q = ( self.q_proj(x) .reshape(batch_size, seq_len, self.num_heads, self.head_dim) .transpose(1, 2) ) k = ( self.k_proj(x) .reshape(batch_size, seq_len, self.num_heads, self.head_dim) .transpose(1, 2) ) v = ( self.v_proj(x) .reshape(batch_size, seq_len, self.num_heads, self.head_dim) .transpose(1, 2) ) attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self.scale attn_weights = F.softmax(attn_weights, dim=-1) attn_output = torch.matmul(attn_weights, v) attn_output = attn_output.transpose(1, 2).reshape( batch_size, seq_len, self.embed_dim ) output = self.out_proj(attn_output) output = self.norm2(output) output = self.mlp(output) output = self.act(output) output = self.out_layer(output) return output class MultiLayerPerceptron(nn.Module): def __init__(self, hidden_size, depth): super().__init__() self.mlp = nn.Sequential( nn.Linear(hidden_size, hidden_size), *[ nn.Sequential(nn.GELU(), nn.Linear(hidden_size, hidden_size)) for _ in range(depth - 1) ], ) def forward(self, x): return self.mlp(x) class MultiModalProjector(nn.Module): def __init__(self, input_size, output_size, mlp_depth, proj_out_num=256): super().__init__() self.proj_out_num = proj_out_num self.mm_projector = nn.Sequential( nn.Linear(input_size, output_size), *[ nn.Sequential( nn.GELU(), nn.Linear(output_size, output_size), ) for _ in range(mlp_depth - 1) ], ) def forward(self, x): return self.mm_projector(x) class LowHighHybridMLP(nn.Module): def __init__( self, low_input_size, high_input_size, output_size, mlp_depth, proj_out_num=288 ): super().__init__() self.proj_out_num = proj_out_num self.low_up_mlp = nn.Linear(low_input_size, output_size) self.high_up_mlp = nn.Linear(high_input_size, output_size) modules = [] for _ in range(1, mlp_depth): modules.append(nn.GELU()) modules.append(nn.Linear(output_size, output_size)) self.mm_projector = nn.Sequential(*modules) def forward(self, x): low_x, high_x = x low_x = self.low_up_mlp(low_x) high_x = self.high_up_mlp(high_x) x = torch.cat([low_x, high_x], dim=1) x = self.mm_projector(x) return x class MixerLayer(nn.Module): def __init__(self, input_size, output_size, mlp_depth=2): super().__init__() self.ln1 = nn.LayerNorm(input_size[1]) self.ln2 = nn.LayerNorm(input_size[1]) self.mlp1 = 
MultiModalProjector( input_size=input_size[0], output_size=output_size[0], mlp_depth=mlp_depth ) self.mlp2 = MultiModalProjector( input_size=input_size[1], output_size=output_size[1], mlp_depth=mlp_depth ) def forward(self, x): x = self.ln1(x) x = rearrange(x, "b n d -> b d n") x = self.mlp1(x) x = rearrange(x, "b d n -> b n d") x = self.ln2(x) x = self.mlp2(x) return x class MixerLowHighHybridMLP(nn.Module): def __init__( self, low_input_size: tuple = (256, 384), low_output_size: list = [192, 128], high_input_size: tuple = (32, 768), high_output_size: list = [64, 128], output_dim=3584, depth=2, mlp_depth=2, proj_out_num=256, ): assert ( len(low_output_size) == len(high_output_size) == depth ), "Output size must be same for both low and high input" assert output_dim % (2**depth) == 0, "Output dim must be divisible by 2**depth" super().__init__() self.proj_out_num = proj_out_num self.low_mixer = nn.ModuleList( [ MixerLayer( input_size=( (low_output_size[i - 1], output_dim // (2 ** (depth - i))) if i > 0 else low_input_size ), output_size=( low_output_size[i], output_dim // (2 ** (depth - i - 1)), ), mlp_depth=mlp_depth, ) for i in range(depth) ] ) self.high_mixer = nn.ModuleList( [ MixerLayer( input_size=( (high_output_size[i - 1], output_dim // (2 ** (depth - i))) if i > 0 else high_input_size ), output_size=( high_output_size[i], output_dim // (2 ** (depth - i - 1)), ), mlp_depth=mlp_depth, ) for i in range(depth) ] ) def forward(self, x): low_x, high_x = x for low_layer, high_layer in zip(self.low_mixer, self.high_mixer): low_x = low_layer(low_x) high_x = high_layer(high_x) x = torch.cat([low_x, high_x], dim=1) return x class IdentityMap(nn.Module): def __init__(self): super().__init__() def forward(self, x, *args, **kwargs): return x @property def config(self): return {"mm_projector_type": "identity"} def build_mm_projector(config, delay_load=False, **kwargs): projector_type = getattr(config, "mm_projector_type", "linear") if projector_type == "linear": return nn.Linear(config.mm_hidden_size, config.hidden_size) elif projector_type == "mlp": return MultiModalProjector( input_size=config.mm_hidden_size, output_size=config.hidden_size, mlp_depth=config.mm_mlp_depth, proj_out_num=config.proj_out_num, ) elif projector_type == "low_high_mlp": return LowHighHybridMLP( low_input_size=config.low_input_size, high_input_size=config.high_input_size, output_size=config.hidden_size, mlp_depth=config.mm_mlp_depth, proj_out_num=config.proj_out_num, ) elif projector_type == "mixer": return MixerLowHighHybridMLP( low_input_size=config.low_input_size, low_output_size=config.low_output_size, high_input_size=config.high_input_size, high_output_size=config.high_output_size, output_dim=config.hidden_size, depth=len(config.low_output_size), mlp_depth=config.mm_mlp_depth, proj_out_num=config.proj_out_num, ) elif projector_type == "mhsa": return MultiHeadSelfAttention( embed_dim=config.mm_hidden_size, output_dim=config.hidden_size, num_heads=hasattr(config, "num_heads") and config.num_heads or 8, proj_out_num=config.proj_out_num, ) elif projector_type == "identity": return IdentityMap() else: raise ValueError(f"Unknown projector type: {projector_type}") class VLMMetaModel: def __init__(self, config): super(VLMMetaModel, self).__init__(config) if hasattr(config, "vision_tower"): self.vision_tower = build_vision_tower(config, delay_load=True) self.mm_projector = build_mm_projector(config) def get_vision_tower(self): vision_tower = getattr(self, "vision_tower", None) if type(vision_tower) is list: vision_tower = 
vision_tower[0] return vision_tower def initialize_vision_modules(self, model_args): self.config.input_size = model_args.input_size self.config.patch_size = model_args.patch_size self.config.dim = model_args.dim self.config.depth = model_args.depth self.config.vision_tower = model_args.vision_tower self.config.vision_select_layer = model_args.vision_select_layer self.config.vision_select_feature = model_args.vision_select_feature self.config.mm_projector_type = model_args.mm_projector_type self.config.mm_mlp_depth = model_args.mm_mlp_depth self.config.proj_out_num = model_args.proj_out_num # vision tower if self.get_vision_tower() is None: self.vision_tower = build_vision_tower(self.config) self.vision_tower.requires_grad_(not model_args.freeze_vision_tower) if self.config.vision_tower == "hybrid": self.config.low_input_size = self.vision_tower.low_input_size self.config.high_input_size = self.vision_tower.high_input_size elif self.config.mm_projector_type == "mixer": self.config.low_output_size = model_args.low_output_size self.config.high_output_size = model_args.high_output_size self.config.low_input_size = (256, 384) self.config.high_input_size = (32, 768) if model_args.pretrain_vision_model is not None: vision_model_weights = torch.load( model_args.pretrain_vision_model, map_location="cpu" ) self.vision_tower.vision_tower.load_state_dict( vision_model_weights, strict=True ) if model_args.pretrain_clip_model is not None: clip_model = AutoModel.from_pretrained(model_args.pretrain_clip_model) self.vision_tower.vision_tower = clip_model.vision_encoder self.config.mm_hidden_size = self.vision_tower.hidden_size # mm_projector if getattr(self, "mm_projector", None) is None: self.mm_projector = build_mm_projector(self.config) if model_args.pretrain_mm_mlp_adapter is not None: mm_projector_weights = torch.load( model_args.pretrain_mm_mlp_adapter, map_location="cpu" ) if self.config.mm_projector_type == "mlp": def get_w(weights, keyword): return { f"{keyword}.{k.split(keyword + ".")[2]}": v for k, v in weights.items() if keyword in k } elif self.config.mm_projector_type == "low_high_mlp": def get_w(weights, keyword): result = {} for k, v in weights.items(): if keyword in k: if f"{keyword}.{keyword}" in k: part = k.split(f"{keyword}.{keyword}.")[1] result[f"mm_projector.{part}"] = v elif f"{keyword}." 
in k: part = k.split(f"{keyword}.")[1] result[part] = v return result elif self.config.mm_projector_type == "mixer": def get_w(weights, keyword): result = {} for k, v in weights.items(): if keyword in k: new_key = k.split(".") if len(new_key) > 2: new_key = ".".join(new_key[2:]) result[new_key] = v return result else: def get_w(weights, keyword): result = {} for k, v in weights.items(): if keyword in k: new_key = k.split(".") if len(new_key) > 2: new_key = ".".join(new_key[2:]) result[new_key] = v return result self.mm_projector.load_state_dict( get_w(mm_projector_weights, "mm_projector"), strict=True ) class VLMMetaForCausalLM(ABC): @abstractmethod def get_model(self): pass def get_vision_tower(self): return self.get_model().get_vision_tower() def encode_images(self, images): image_features = self.get_model().get_vision_tower()(images) image_features = self.get_model().mm_projector(image_features) return image_features def prepare_inputs_for_multimodal( self, input_ids, position_ids, attention_mask, past_key_values, labels, images, ): vision_tower = self.get_vision_tower() if vision_tower is None or images is None or input_ids.shape[1] == 1: return ( input_ids, position_ids, attention_mask, past_key_values, None, labels, ) else: image_features = self.encode_images(images) inputs_embeds = self.get_model().embed_tokens(input_ids) inputs_embeds = torch.cat( ( inputs_embeds[:, :1, :], image_features, inputs_embeds[:, (image_features.shape[1] + 1) :, :], ), dim=1, ) return ( None, position_ids, attention_mask, past_key_values, inputs_embeds, labels, ) def initialize_vision_tokenizer(self, model_args, tokenizer): num_new_tokens = model_args.num_new_tokens self.resize_token_embeddings(len(tokenizer)) if num_new_tokens > 0: input_embeddings = self.get_input_embeddings().weight.data output_embeddings = self.get_output_embeddings().weight.data input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True ) output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( dim=0, keepdim=True ) input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg if model_args.tune_mm_mlp_adapter: for p in self.get_input_embeddings().parameters(): p.requires_grad = True for p in self.get_output_embeddings().parameters(): p.requires_grad = False else: for p in self.get_input_embeddings().parameters(): p.requires_grad = True for p in self.get_output_embeddings().parameters(): p.requires_grad = True if model_args.pretrain_mm_mlp_adapter: mm_projector_weights = torch.load( model_args.pretrain_mm_mlp_adapter, map_location="cpu" ) embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"] if input_embeddings.shape == embed_tokens_weight.shape: input_embeddings = embed_tokens_weight elif embed_tokens_weight.shape[0] == num_new_tokens: input_embeddings[-num_new_tokens:] = embed_tokens_weight else: raise ValueError( f"Unexpected embed_tokens_weight shape. " f"Pretrained: {embed_tokens_weight.shape}. " f"Current: {input_embeddings.shape}. " f"Number of new tokens: {num_new_tokens}." 
) class VLMQwenConfig(Qwen2Config): model_type = "vlm_qwen" class VLMQwenModel(VLMMetaModel, Qwen2Model): config_class = VLMQwenConfig def __init__(self, config: Qwen2Config): super(VLMQwenModel, self).__init__(config) class VLMQwenForCausalLM(Qwen2ForCausalLM, VLMMetaForCausalLM): config_class = VLMQwenConfig def __init__(self, config): super(Qwen2ForCausalLM, self).__init__(config) self.model = VLMQwenModel(config) self.pretraining_tp = getattr(config, "pretraining_tp", None) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.post_init() def get_model(self): return self.model def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, image_sizes: Optional[List[List[int]]] = None, return_dict: Optional[bool] = None, num_logits_to_keep: Optional[int] = None, **kwargs: Any, ) -> Union[Tuple, CausalLMOutputWithPast]: if inputs_embeds is None: ( input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels, ) = self.prepare_inputs_for_multimodal( input_ids, position_ids, attention_mask, past_key_values, labels, images ) return super().forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) @torch.no_grad() def generate( self, images: Optional[torch.Tensor] = None, inputs: Optional[torch.Tensor] = None, **kwargs, ) -> Union[GenerateOutput, torch.LongTensor, Any]: position_ids = kwargs.pop("position_ids", None) attention_mask = kwargs.pop("attention_mask", None) if "inputs_embeds" in kwargs: raise NotImplementedError("`inputs_embeds` is not supported") if images is not None: (inputs, position_ids, attention_mask, _, inputs_embeds, _) = ( self.prepare_inputs_for_multimodal( inputs, position_ids, attention_mask, None, None, images, ) ) else: inputs_embeds = self.get_model().embed_tokens(inputs) output_ids = super().generate(inputs_embeds=inputs_embeds, **kwargs) return output_ids def prepare_inputs_for_generation( self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs ): images = kwargs.pop("images", None) inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs, ) if images is not None: inputs["images"] = images return inputs AutoConfig.register("vlm_qwen", VLMQwenConfig) AutoModelForCausalLM.register(VLMQwenConfig, VLMQwenForCausalLM)
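

# -----------------------------------------------------------------------------
# Minimal shape smoke test for the decomposed 3-D encoder. This is an
# illustrative sketch, not part of the original training pipeline: the dummy
# volume size (1, 1, 32, 64, 64) and the decomp_nano variant are arbitrary
# choices made purely to keep the check small and CPU-friendly.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    encoder = decomp_nano(input_size=(64, 64, 32))
    encoder.eval()  # avoid BatchNorm batch-statistics errors on a single sample
    dummy_volume = torch.randn(1, 1, 32, 64, 64)
    with torch.no_grad():
        stage_features = encoder(dummy_volume)
    # DecompModel returns one flattened token tensor (B, N_stage, C_stage) per stage.
    for idx, feats in enumerate(stage_features):
        print(f"stage {idx}: {tuple(feats.shape)}")
    print("parameters:", readable_params(sum(p.numel() for p in encoder.parameters())))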