# Med3DVLM-Qwen-2.5-7B / modeling.py
import math
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack, rearrange, repeat, unpack
from einops.layers.torch import Rearrange
from timm.layers import DropPath, to_3tuple, trunc_normal_
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
PretrainedConfig,
PreTrainedModel,
Qwen2Config,
Qwen2ForCausalLM,
Qwen2Model,
)
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast
try:
import torch.distributed.nn
from torch import distributed as dist
has_distributed = True
except ImportError:
has_distributed = False
class DEC_CLIPConfig(PretrainedConfig):
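    """Configuration for the 3D image-text contrastive (CLIP-style) model.

    `loss_type` selects between an InfoNCE ("nce") and a SigLIP-style sigmoid
    objective; `t_prime` and `bias` initialize the sigmoid temperature and bias.
    The vision backbone name (`vision_encoder`, "vit3d" or "dcformer") is not
    declared here and is expected to arrive through **kwargs, since DEC_CLIP
    reads `config.vision_encoder`.
    """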
model_type = "dec_clip"
def __init__(
self,
language_model_name_or_path: str = "",
local_loss: bool = False,
gather_loss: bool = True,
input_size: tuple = (256, 256, 128),
dim: int = 768,
depth: int = 12,
hidden_size: int = 512,
mlp_depth: int = 2,
loss_type: str = "nce",
t_prime: float = np.log(1 / 0.07),
bias: float = 0.0,
efficient_loss: bool = False,
**kwargs,
):
self.language_model_name_or_path = language_model_name_or_path
self.input_size = input_size
self.dim = dim
self.depth = depth
self.hidden_size = hidden_size
self.mlp_depth = mlp_depth
self.local_loss = local_loss
self.gather_loss = gather_loss
self.loss_type = loss_type
self.t_prime = t_prime
self.bias = bias
self.efficient_loss = efficient_loss
super().__init__(**kwargs)
class DEC_CLIP(PreTrainedModel):
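    """Dual-encoder contrastive model: a 3D vision encoder paired with a text encoder.

    Image and text features are projected into a shared `hidden_size` space and
    L2-normalized; `encode_image` mean-pools the vision tokens. `forward` computes
    either an InfoNCE loss or a SigLIP-style sigmoid loss, optionally gathering
    features across distributed ranks.
    """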
config_class = DEC_CLIPConfig
def __init__(self, config):
super().__init__(config)
self.config = config
if config.vision_encoder == "vit3d":
self.vision_encoder = Vit3D(
input_size=config.input_size,
dim=config.dim,
depth=config.depth,
)
elif config.vision_encoder == "dcformer":
self.vision_encoder = decomp_small(input_size=config.input_size)
else:
raise ValueError(f"Unexpected vision encoder: {config.vision_encoder}")
self.language_encoder = AutoModel.from_pretrained(
config.language_model_name_or_path
)
self.mm_vision_proj = nn.Linear(
self.vision_encoder.channels[-1], config.hidden_size
)
self.mm_language_proj = nn.Linear(
self.language_encoder.config.dim, config.hidden_size
)
self.efficient_loss = config.efficient_loss
self.local_loss = config.local_loss
self.gather_loss = config.gather_loss
self.loss_type = config.loss_type
if self.loss_type == "sigmoid":
self.t_prime = nn.Parameter(torch.tensor(config.t_prime))
self.bias = nn.Parameter(torch.tensor(config.bias))
else:
self.logit_scale = nn.Parameter(torch.ones([]) * config.t_prime)
def encode_image(self, image):
image_feats = self.vision_encoder(image)
if isinstance(image_feats, list):
image_feats = image_feats[-1]
image_feats = image_feats.mean(dim=1)
image_feats = self.mm_vision_proj(image_feats)
image_feats = F.normalize(image_feats, dim=-1)
return image_feats
def encode_text(self, input_id, attention_mask):
text_feats = self.language_encoder(input_id, attention_mask=attention_mask)[
"last_hidden_state"
]
text_feats = text_feats[:, 0]
text_feats = self.mm_language_proj(text_feats)
text_feats = F.normalize(text_feats, dim=-1)
return text_feats
def forward(self, images, input_ids, attention_mask, labels, **kwargs):
image_features = self.encode_image(images)
text_features = self.encode_text(input_ids, attention_mask)
rank = 0
world_size = 1
if has_distributed and dist.is_initialized():
rank = dist.get_rank()
world_size = dist.get_world_size()
batch_size = image_features.size(0)
device = image_features.device
if self.loss_type == "sigmoid":
if has_distributed and dist.is_initialized():
if self.efficient_loss:
t = torch.exp(self.t_prime)
loss = 0.0
for target_rank in range(world_size):
if rank == target_rank:
target_text_features = text_features
else:
target_text_features = torch.distributed.nn.broadcast(
text_features.requires_grad_(), target_rank
)
local_logits_per_image = (
image_features @ target_text_features.T
) * t + self.bias
local_logits_per_text = local_logits_per_image.T
if rank == target_rank:
local_labels = 2 * torch.eye(
batch_size, device=device
) - torch.ones(batch_size, batch_size, device=device)
else:
local_labels = -torch.ones(
batch_size, batch_size, device=device
)
local_logits = (
local_logits_per_image + local_logits_per_text
) / 2.0
local_loss = -torch.sum(
F.logsigmoid(local_labels * local_logits)
) / (batch_size * world_size)
loss += local_loss
torch.distributed.nn.all_reduce(loss)
torch.cuda.synchronize()
                    # the memory-efficient path does not materialize a full
                    # logit matrix, so only a placeholder is returned
                    logits = 0
else:
t = torch.exp(self.t_prime)
all_image_features, all_text_features = gather_features(
image_features,
text_features,
gather_with_grad=True,
rank=rank,
world_size=world_size,
)
logits_per_image = (
all_image_features @ all_text_features.T
) * t + self.bias
logits_per_text = logits_per_image.T
batch_size = all_image_features.size(0)
                    labels = 2 * torch.eye(
                        batch_size, device=image_features.device
                    ) - torch.ones(
                        batch_size, batch_size, device=image_features.device
                    )
logits = (logits_per_image + logits_per_text) / 2.0
loss = -torch.sum(F.logsigmoid(labels * logits)) / batch_size
else:
                # non-distributed SigLIP-style loss: temperature is exp(t_prime),
                # averaged over the batch
                t = torch.exp(self.t_prime)
                logits_per_image = (image_features @ text_features.T) * t + self.bias
                logits_per_text = logits_per_image.T
                labels = 2 * torch.eye(batch_size, device=device) - torch.ones(
                    batch_size, batch_size, device=device
                )
                logits = (logits_per_image + logits_per_text) / 2.0
                loss = -torch.sum(F.logsigmoid(labels * logits)) / batch_size
else:
all_image_features, all_text_features = gather_features(
image_features,
text_features,
local_loss=self.local_loss,
gather_with_grad=True,
rank=rank,
world_size=world_size,
)
if self.gather_loss:
if self.local_loss:
logits_per_image = (
self.logit_scale * image_features @ all_text_features.T
)
logits_per_text = (
self.logit_scale * text_features @ all_image_features.T
)
else:
logits_per_image = (
self.logit_scale * all_image_features @ all_text_features.T
)
logits_per_text = logits_per_image.T
else:
logits_per_image = self.logit_scale * image_features @ text_features.T
logits_per_text = self.logit_scale * text_features @ image_features.T
image_loss = F.cross_entropy(logits_per_image, labels)
text_loss = F.cross_entropy(logits_per_text, labels)
loss = (image_loss + text_loss) / 2.0
            logits = (logits_per_image + logits_per_text) / 2.0
ret = {
"loss": loss,
"logits": logits,
}
return ret
def gather_features(
image_features,
text_features,
local_loss=False,
gather_with_grad=True,
rank=0,
world_size=1,
):
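    """Gather image/text features from all ranks for the contrastive loss.

    Returns the local features unchanged when torch.distributed is unavailable or
    not initialized. With `gather_with_grad`, gradients flow through the gathered
    tensors; otherwise `dist.all_gather` is used and this rank's own
    gradient-carrying features are re-inserted unless `local_loss` is set.
    """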
    if not (has_distributed and dist.is_initialized()):
        # fall back to the local features when torch.distributed is unavailable
        # or not initialized
        return image_features, text_features
if gather_with_grad:
all_image_features = torch.cat(
torch.distributed.nn.all_gather(image_features), dim=0
)
all_text_features = torch.cat(
torch.distributed.nn.all_gather(text_features), dim=0
)
else:
gathered_image_features = [
torch.zeros_like(image_features) for _ in range(world_size)
]
gathered_text_features = [
torch.zeros_like(text_features) for _ in range(world_size)
]
dist.all_gather(gathered_image_features, image_features)
dist.all_gather(gathered_text_features, text_features)
if not local_loss:
gathered_image_features[rank] = image_features
gathered_text_features[rank] = text_features
all_image_features = torch.cat(gathered_image_features, dim=0)
all_text_features = torch.cat(gathered_text_features, dim=0)
return all_image_features, all_text_features
AutoConfig.register("dec_clip", DEC_CLIPConfig)
AutoModel.register(DEC_CLIPConfig, DEC_CLIP)
def stem(inp, oup, image_size, downsample=False):
    stride = 2 if downsample else 1
return nn.Sequential(
nn.Conv3d(inp, oup, 3, stride, 1, bias=False),
nn.BatchNorm3d(oup),
nn.GELU(),
nn.Conv3d(oup, oup, 3, 1, 1, bias=False),
nn.BatchNorm3d(oup),
nn.GELU(),
)
def DecomposedStem(inp, oup, image_size, kernel_size, downsample=False):
return nn.Sequential(
DecompConv3D(inp, oup, 7, 4, 1, nn.GELU()),
DecompConv3D(oup, oup, 3, 1, 1, nn.GELU()),
DecompConv3D(oup, oup, 3, 1, 1, nn.GELU()),
DecompConv3D(oup, oup, 3, 1, 1, nn.GELU()),
)
class DecompConv3D(nn.Module):
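    """Decomposed 3D convolution: the sum of three 1D convolutions, one along each
    spatial axis, each optionally followed by BatchNorm, with an optional
    activation applied to the summed output."""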
def __init__(
self, in_dim, out_dim, kernel_size, stride=1, groups=1, norm=True, act=None
) -> None:
super().__init__()
self.act = act
self.c1 = nn.Sequential(
nn.Conv3d(
in_dim,
out_dim,
kernel_size=(kernel_size, 1, 1),
padding=(kernel_size // 2, 0, 0),
stride=stride,
groups=groups,
),
nn.BatchNorm3d(out_dim) if norm else nn.Identity(),
)
self.c2 = nn.Sequential(
nn.Conv3d(
in_dim,
out_dim,
kernel_size=(1, kernel_size, 1),
padding=(0, kernel_size // 2, 0),
stride=stride,
groups=groups,
),
nn.BatchNorm3d(out_dim) if norm else nn.Identity(),
)
self.c3 = nn.Sequential(
nn.Conv3d(
in_dim,
out_dim,
kernel_size=(1, 1, kernel_size),
padding=(0, 0, kernel_size // 2),
stride=stride,
groups=groups,
),
nn.BatchNorm3d(out_dim) if norm else nn.Identity(),
)
def forward(self, x):
x = self.c1(x) + self.c2(x) + self.c3(x)
if self.act is not None:
x = self.act(x)
return x
class ConvPosEnc(nn.Module):
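    """Convolutional position encoding: a depthwise 3D convolution (optionally
    decomposed into three 1D convolutions) over the token grid, added back to the
    input as a residual."""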
def __init__(self, dim, k=3, decompose=False):
super().__init__()
if decompose:
self.proj = DecompConv3D(dim, dim, k, groups=dim, norm=None)
else:
self.proj = nn.Conv3d(
dim, dim, to_3tuple(k), to_3tuple(1), to_3tuple(k // 2), groups=dim
)
def forward(self, x, size):
B, N, C = x.shape
H, W, T = size
assert N == H * W * T
feat = rearrange(x, "b (h w t) c -> b c h w t", h=H, w=W, t=T)
feat = self.proj(feat)
feat = rearrange(feat, "b c h w t -> b (h w t) c ")
x = x + feat
return x
class MLP(nn.Module):
def __init__(self, oup, mlp_dim, dp=0.0):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(oup, mlp_dim),
nn.GELU(),
nn.Dropout(dp),
nn.Linear(mlp_dim, oup),
nn.Dropout(dp),
)
def forward(self, x):
x = self.mlp(x)
return x
class ScaleDotProduct(nn.Module):
def __init__(self, scale):
super().__init__()
self.scale = scale
self.softmax = nn.Softmax(dim=-1)
def forward(self, qkv):
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = self.softmax(attn)
x = (attn @ v).transpose(1, 2)
return x
class DecomposedAttention(nn.Module):
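    """Decomposed self-attention over a 3D token grid: attention is computed
    separately within H-W, H-T, and W-T plane groupings and the three results are
    summed before the output projection."""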
def __init__(self, oup, head_num):
super().__init__()
self.head_num = head_num
scale = (oup // head_num) ** (1 / 2)
self.sdp = ScaleDotProduct(scale)
self.qkv = nn.Linear(oup, oup * 3, bias=False)
self.proj = nn.Linear(oup, oup, bias=False)
def forward(self, x, size):
b, n, c = x.shape
h, w, t = size
assert n == h * w * t
B_, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B_, N, 3, self.head_num, C // self.head_num)
.permute(2, 0, 3, 1, 4)
)
x = rearrange(qkv, "k b nh (h w t) c -> k b c nh h w t", h=h, w=w, t=t)
x1 = rearrange(x, "k b c nh h w t -> k (b t) nh (h w) c")
x2 = rearrange(x, "k b c nh h w t -> k (b w) nh (h t) c")
x3 = rearrange(x, "k b c nh h w t -> k (b h) nh (w t) c")
x1 = self.sdp(x1)
x2 = self.sdp(x2)
x3 = self.sdp(x3)
x1 = rearrange(x1, "(b t) (h w) nh c -> b (h w t) (nh c)", h=h, w=w, t=t)
x2 = rearrange(x2, "(b w) (h t) nh c -> b (h w t) (nh c)", h=h, w=w, t=t)
x3 = rearrange(x3, "(b h) (w t) nh c -> b (h w t) (nh c)", h=h, w=w, t=t)
x = self.proj(x1 + x2 + x3)
return x
class SelfAttention(nn.Module):
def __init__(self, oup, head_num):
super().__init__()
self.head_num = head_num
scale = (oup // head_num) ** (1 / 2)
self.sdp = ScaleDotProduct(scale)
self.qkv = nn.Linear(oup, oup * 3, bias=False)
self.proj = nn.Linear(oup, oup, bias=False)
def forward(self, x, size=None):
B_, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B_, N, 3, self.head_num, C // self.head_num)
.permute(2, 0, 3, 1, 4)
)
x = self.sdp(qkv).reshape(B_, N, C)
x = self.proj(x)
return x
class ChannelAttention(nn.Module):
def __init__(self, oup, head_num):
super().__init__()
self.head_num = head_num
self.scale = (oup // head_num) ** (1 / 2)
self.softmax = nn.Softmax(dim=-1)
self.qkv = nn.Linear(oup, oup * 3, bias=False)
self.proj = nn.Linear(oup, oup, bias=False)
def forward(self, x):
B, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B, N, 3, self.head_num, C // self.head_num)
.permute(2, 0, 3, 1, 4)
)
q, k, v = qkv[0], qkv[1], qkv[2]
k = k * self.scale
attention = k.transpose(-1, -2) @ v
attention = self.softmax(attention)
x = (attention @ q.transpose(-1, -2)).transpose(-1, -2)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x
class ChannelBlock(nn.Module):
def __init__(self, dim, heads=8):
super().__init__()
hidden_dim = int(dim * 4)
self.cpe = nn.ModuleList(
[
ConvPosEnc(dim=dim, k=3, decompose=True),
ConvPosEnc(dim=dim, k=3, decompose=True),
]
)
self.attn = ChannelAttention(dim, heads)
self.layer_norm1 = nn.LayerNorm(dim)
self.mlp1 = MLP(dim, hidden_dim)
self.layer_norm2 = nn.LayerNorm(dim)
def forward(self, x, size):
x = self.cpe[0](x, size)
_x = self.layer_norm1(x)
_x = self.attn(_x)
x = x + _x
x = self.cpe[1](x, size)
_x = self.layer_norm2(x)
_x = self.mlp1(_x)
x = x + _x
return x
class SpatialBlock(nn.Module):
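    """Transformer block with a convolutional position encoding before both the
    self-attention and MLP sub-layers, each wrapped in a residual connection."""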
def __init__(self, dim, heads=8):
super().__init__()
hidden_dim = int(dim * 4)
self.cpe = nn.ModuleList(
[
ConvPosEnc(dim=dim, k=3, decompose=True),
ConvPosEnc(dim=dim, k=3, decompose=True),
]
)
# self.attn = DecomposedAttention(dim, heads)
self.attn = SelfAttention(dim, heads)
self.layer_norm1 = nn.LayerNorm(dim)
self.mlp1 = MLP(dim, hidden_dim)
self.layer_norm2 = nn.LayerNorm(dim)
def forward(self, x, size):
x = self.cpe[0](x, size)
_x = self.layer_norm1(x)
_x = self.attn(_x, size)
x = x + _x
x = self.cpe[1](x, size)
_x = self.layer_norm2(x)
_x = self.mlp1(_x)
x = x + _x
return x
class TransformerBlock(nn.Module):
def __init__(
self,
inp,
oup,
image_size,
kernel_size,
heads=8,
dim_head=32,
downsample=False,
dropout=0.0,
):
super().__init__()
hidden_dim = int(inp * 4)
self.ih, self.iw, self.it = image_size
self.downsample = downsample
if self.downsample:
self.pool1 = nn.MaxPool3d(3, 2, 1)
# self.pool2 = nn.MaxPool3d(3, 2, 1)
self.proj = nn.Conv3d(inp, oup, 1, 1, 0, bias=False)
self.spatial_attention = SpatialBlock(oup, heads)
# self.channel_attention = ChannelBlock(oup, heads)
def forward(self, x):
if self.downsample:
x = self.proj(self.pool1(x))
# if self.downsample:
# x = self.proj(self.pool1(x)) + self.attn(self.pool2(x))
# x = x.permute(0, 2, 3, 4, 1)
h, w, t = x.shape[2], x.shape[3], x.shape[4]
size = (h, w, t)
x = rearrange(x, "b c h w t -> b (h w t) c ")
x = self.spatial_attention(x, size)
# x = self.channel_attention(x, size)
x = rearrange(x, "b (h w t) c -> b c h w t", h=h, w=w, t=t)
return x
class ConvBlock(nn.Module):
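    """ConvNeXt-style 3D block: optional downsampling (stride-2 max-pool plus a
    1x1x1 projection), a decomposed depthwise convolution, a channel MLP with
    layer scale, and a residual connection (drop path is defined but disabled at
    0.0)."""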
def __init__(
self, inp, oup, image_size, kernel_size=7, downsample=False, expansion=4
):
super().__init__()
self.downsample = downsample
        stride = 2 if self.downsample else 1
hidden_dim = int(oup * expansion)
drop_path = 0.0
layer_scale_init_value = 1e-6
if self.downsample:
self.pool = nn.MaxPool3d(3, 2, 1)
self.proj = nn.Conv3d(inp, oup, 1, 1, 0, bias=False)
# self.dwconv = nn.Sequential(nn.Conv3d(oup, oup, kernel_size=7, padding=3, groups=oup), nn.BatchNorm3d(oup)) # depthwise conv
self.dwconv = DecompConv3D(oup, oup, kernel_size, groups=oup)
self.mlp = MLP(oup, hidden_dim)
self.scale = (
nn.Parameter(layer_scale_init_value * torch.ones((oup)), requires_grad=True)
if layer_scale_init_value > 0
else None
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
def forward(self, x):
if self.downsample:
x = self.proj(self.pool(x))
input = x
x = self.dwconv(x)
x = x.permute(0, 2, 3, 4, 1) # (N, C, H, W, T) -> (N, H, W, T, C)
x = self.mlp(x)
if self.scale is not None:
x = self.scale * x
        x = x.permute(0, 4, 1, 2, 3)  # (N, H, W, T, C) -> (N, C, H, W, T)
x = input + self.drop_path(x)
return x
class Encoder(nn.Module):
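    """Five-stage hierarchical 3D encoder: a stride-4 decomposed stem followed by
    four stages of conv or transformer blocks, each downsampling by 2. `forward`
    moves the third input dimension to the end (B, C, D, H, W -> B, C, H, W, D)
    and returns the hidden states of all five stages."""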
def __init__(
self,
input_size,
in_channels,
num_blocks,
channels,
kernel_sizes=[7, 7, 7, 7],
block_types=["C", "C", "C", "C"],
):
super().__init__()
self.dims = channels
ih, iw, it = input_size
block = {"C": ConvBlock, "T": TransformerBlock}
i = 4
self.s0 = self._make_layer(
DecomposedStem,
in_channels,
channels[0],
num_blocks[0],
kernel_sizes[0],
(ih // i, iw // i, it // i),
)
self.s1 = self._make_layer(
block[block_types[0]],
channels[0],
channels[1],
num_blocks[1],
kernel_sizes[0],
(ih // (i * 2**1), iw // (i * 2**1), it // (i * 2**1)),
)
self.s2 = self._make_layer(
block[block_types[1]],
channels[1],
channels[2],
num_blocks[2],
kernel_sizes[1],
(ih // (i * 2**2), iw // (i * 2**2), it // (i * 2**2)),
)
self.s3 = self._make_layer(
block[block_types[2]],
channels[2],
channels[3],
num_blocks[3],
kernel_sizes[2],
(ih // (i * 2**3), iw // (i * 2**3), it // (i * 2**3)),
)
self.s4 = self._make_layer(
block[block_types[3]],
channels[3],
channels[4],
num_blocks[4],
kernel_sizes[3],
(ih // (i * 2**4), iw // (i * 2**4), it // (i * 2**4)),
)
def forward(self, x):
hidden_states = []
x = x.permute(0, 1, 3, 4, 2)
for i in range(5):
if hasattr(self, "s" + str(i)):
x = getattr(self, "s" + str(i))(x)
hidden_states.append(x)
return hidden_states
def _make_layer(self, block, inp, oup, depth, kernel_size, image_size):
layers = nn.ModuleList([])
for i in range(depth):
if i == 0:
layers.append(block(inp, oup, image_size, kernel_size, downsample=True))
else:
layers.append(block(oup, oup, image_size, kernel_size))
return nn.Sequential(*layers)
class DecompModel(nn.Module):
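    """DCFormer-style 3D backbone: wraps `Encoder` and flattens each stage's
    feature map into a (B, N, C) token sequence before returning all hidden
    states."""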
def __init__(
self,
input_size=(512, 512, 256),
in_channels=1,
num_blocks=[2, 2, 3, 5, 2],
channels=[64, 96, 192, 384, 768],
# kernel_sizes=[7, 7, 7, 7],
kernel_sizes=[13, 11, 9, 7],
block_types=["C", "C", "C", "C"],
codebook_size=8192,
):
super().__init__()
self.channels = channels
self.encoder = Encoder(
input_size, in_channels, num_blocks, channels, kernel_sizes, block_types
)
# self.vq = VectorQuantize(dim = channels[-1], codebook_size = codebook_size, use_cosine_sim = True)
def forward(self, video, mask=None, device="cuda"):
hidden_states = self.encoder(video)
# tokens = rearrange(tokens, "b d h w t -> b t h w d")
# shape = tokens.shape
# *_, h, w, _ = shape
# quantize
# tokens, _ = pack([tokens], "b * d")
for i in range(len(hidden_states)):
hidden_states[i] = rearrange(hidden_states[i], "b d h w t -> b t h w d")
hidden_states[i], _ = pack([hidden_states[i]], "b * d")
# vq_mask = None
# tokens, _, _ = self.vq(tokens, mask = vq_mask)
# tokens = rearrange(tokens, 'b (t h w) d -> b t h w d', h = h, w = w)
return hidden_states
def decomp_nano(
input_size=(512, 512, 256),
# input_size=(256, 256, 128),
):
model = DecompModel(
input_size=input_size,
num_blocks=[1, 1, 1, 1, 1],
channels=[32, 32, 64, 128, 256],
)
return model
def decomp_naive(
input_size=(512, 512, 256),
# input_size=(256, 256, 128),
):
model = DecompModel(
input_size=input_size,
num_blocks=[1, 2, 2, 2, 2],
# channels = [64, 64, 128, 256, 512]
channels=[32, 64, 128, 256, 512],
)
return model
def decomp_tiny(
input_size=(512, 512, 256),
):
model = DecompModel(
input_size=input_size,
num_blocks=[1, 2, 3, 3, 2],
# channels = [64, 64, 128, 256, 512]
channels=[64, 96, 192, 384, 768],
)
return model
def decomp_small(
input_size=(512, 512, 256),
):
model = DecompModel(
input_size=input_size,
num_blocks=[1, 2, 3, 6, 2],
channels=[64, 96, 192, 384, 768],
)
return model
def decomp_base(
input_size=(512, 512, 256),
):
model = DecompModel(
input_size=input_size,
num_blocks=[1, 2, 6, 6, 2],
channels=[64, 128, 256, 512, 1024],
)
return model
def decomp_large(
input_size=(512, 512, 256),
):
model = DecompModel(
input_size=input_size,
num_blocks=[1, 2, 6, 12, 2],
# channels=[64, 192, 384, 768, 1536],
channels=[64, 256, 512, 1024, 2048],
)
return model
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.0):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head**-0.5
self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim=-1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = (
nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
if project_out
else nn.Identity()
)
def forward(self, x):
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.0):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout),
FeedForward(dim, mlp_dim, dropout=dropout),
]
)
)
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class ViTEncoder(nn.Module):
def __init__(
self,
image_size=[512, 512, 256],
patch_size=16,
dim=512,
depth=8,
heads=8,
mlp_dim=4,
channels=1,
dim_head=64,
dropout=0.0,
emb_dropout=0.0,
):
super().__init__()
h, w, t = image_size[0], image_size[1], image_size[2]
self.vit_img_dim = [i // patch_size for i in image_size]
num_patches = (h // patch_size) * (w // patch_size) * (t // patch_size)
patch_dim = channels * patch_size * patch_size * patch_size
self.to_patch_embedding = nn.Sequential(
Rearrange(
"b c (h p1) (w p2) (t p3) -> b (h w t) (p1 p2 p3 c)",
p1=patch_size,
p2=patch_size,
p3=patch_size,
),
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
def forward(self, x):
x = x.permute(0, 1, 3, 4, 2)
x = self.to_patch_embedding(x)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, "1 1 d -> b 1 d", b=b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, : (n + 1)]
x = self.dropout(x)
x = self.transformer(x)
x = x[:, 1:, :]
x = rearrange(
x,
"b (x y z) c -> b c x y z",
x=self.vit_img_dim[0],
y=self.vit_img_dim[1],
z=self.vit_img_dim[2],
)
return x
class Vit3D(nn.Module):
def __init__(self, input_size=[512, 512, 256], patch_size=32, dim=512, depth=8):
super().__init__()
self.encoder = ViTEncoder(input_size, patch_size, dim, depth)
# self.vq = VectorQuantize(dim = dim, codebook_size = 8192, use_cosine_sim = True)
def forward(self, video, mask=None, device="cuda"):
tokens = self.encoder(video)
tokens = rearrange(tokens, "b d h w t -> b t h w d")
shape = tokens.shape
*_, h, w, _ = shape
# quantize
tokens, _ = pack([tokens], "b * d")
# vq_mask = None
# tokens, _, _ = self.vq(tokens, mask = vq_mask)
# tokens = rearrange(tokens, 'b (t h w) d -> b t h w d', h = h, w = w)
return tokens
def build_vision_tower(config, **kwargs):
return VisionTower(config)
class VisionTower(nn.Module):
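    """Wraps the 3D vision backbone (Vit3D or the DCFormer `decomp_small`) and
    selects which hidden states (`vision_select_layer`) and which tokens
    (`vision_select_feature`) are handed to the multimodal projector."""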
def __init__(self, config):
super().__init__()
self.config = config
self.select_layer = config.vision_select_layer
self.select_feature = config.vision_select_feature
self.hidden_size = config.dim
if config.vision_tower == "vit3d":
self.vision_tower = Vit3D(
input_size=config.input_size,
dim=config.dim,
depth=config.depth,
)
elif config.vision_tower == "dcformer":
self.vision_tower = decomp_small(
input_size=config.input_size,
)
self.low_input_size = self.vision_tower.channels[-2]
self.high_input_size = self.vision_tower.channels[-1]
else:
raise ValueError(f"Unexpected vision tower: {config.vision_tower}")
def forward(self, images):
hidden_states = self.vision_tower(images)
if self.select_layer == 0:
image_features = hidden_states
elif self.select_layer < 0:
image_features = hidden_states[self.select_layer :]
else:
raise ValueError(f"Unexpected select layer: {self.select_layer}")
if self.select_feature == "patch":
image_features = image_features[:, 1:]
elif self.select_feature == "cls_patch":
image_features = image_features
else:
raise ValueError(f"Unexpected select feature: {self.select_feature}")
return image_features
@property
def dtype(self):
return self.vision_tower.dtype
@property
def device(self):
return self.vision_tower.device
def readable_params(num):
magnitude = 0
while abs(num) >= 1000:
magnitude += 1
num /= 1000.0
return "%.2f%s" % (num, ["", "K", "M", "G", "T", "P"][magnitude])
class MLPLayer(nn.Module):
def __init__(self, embed_dim, scale=4, *args, **kwargs):
super().__init__(*args, **kwargs)
self.linear1 = nn.Linear(embed_dim, embed_dim * scale)
self.linear2 = nn.Linear(embed_dim * scale, embed_dim)
self.act = nn.GELU()
def forward(self, x):
x = self.linear1(x)
x = self.act(x)
x = self.linear2(x)
return x
class MultiHeadSelfAttention(nn.Module):
def __init__(
self,
embed_dim,
output_dim,
num_heads=8,
proj_out_num=32,
):
super(MultiHeadSelfAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.proj_out_num = proj_out_num
self.mlp = MLPLayer(embed_dim)
assert (
self.head_dim * num_heads == embed_dim
), "embed_dim must be divisible by num_heads"
self.norm1 = nn.LayerNorm(embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.act = nn.GELU()
self.out_layer = nn.Linear(embed_dim, output_dim)
self.scale = math.sqrt(self.head_dim)
def forward(self, x):
batch_size, seq_len, _ = x.size()
x = self.norm1(x)
q = (
self.q_proj(x)
.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
k = (
self.k_proj(x)
.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
v = (
self.v_proj(x)
.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self.scale
attn_weights = F.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, v)
attn_output = attn_output.transpose(1, 2).reshape(
batch_size, seq_len, self.embed_dim
)
output = self.out_proj(attn_output)
output = self.norm2(output)
output = self.mlp(output)
output = self.act(output)
output = self.out_layer(output)
return output
class MultiLayerPerceptron(nn.Module):
def __init__(self, hidden_size, depth):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(hidden_size, hidden_size),
*[
nn.Sequential(nn.GELU(), nn.Linear(hidden_size, hidden_size))
for _ in range(depth - 1)
],
)
def forward(self, x):
return self.mlp(x)
class MultiModalProjector(nn.Module):
def __init__(self, input_size, output_size, mlp_depth, proj_out_num=256):
super().__init__()
self.proj_out_num = proj_out_num
self.mm_projector = nn.Sequential(
nn.Linear(input_size, output_size),
*[
nn.Sequential(
nn.GELU(),
nn.Linear(output_size, output_size),
)
for _ in range(mlp_depth - 1)
],
)
def forward(self, x):
return self.mm_projector(x)
class LowHighHybridMLP(nn.Module):
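    """Projector that maps low-level and high-level vision features to the LLM
    hidden size with separate linear layers, concatenates them along the token
    dimension, and passes the result through additional GELU + Linear layers."""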
def __init__(
self, low_input_size, high_input_size, output_size, mlp_depth, proj_out_num=288
):
super().__init__()
self.proj_out_num = proj_out_num
self.low_up_mlp = nn.Linear(low_input_size, output_size)
self.high_up_mlp = nn.Linear(high_input_size, output_size)
modules = []
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(output_size, output_size))
self.mm_projector = nn.Sequential(*modules)
def forward(self, x):
low_x, high_x = x
low_x = self.low_up_mlp(low_x)
high_x = self.high_up_mlp(high_x)
x = torch.cat([low_x, high_x], dim=1)
x = self.mm_projector(x)
return x
class MixerLayer(nn.Module):
def __init__(self, input_size, output_size, mlp_depth=2):
super().__init__()
self.ln1 = nn.LayerNorm(input_size[1])
self.ln2 = nn.LayerNorm(input_size[1])
self.mlp1 = MultiModalProjector(
input_size=input_size[0], output_size=output_size[0], mlp_depth=mlp_depth
)
self.mlp2 = MultiModalProjector(
input_size=input_size[1], output_size=output_size[1], mlp_depth=mlp_depth
)
def forward(self, x):
x = self.ln1(x)
x = rearrange(x, "b n d -> b d n")
x = self.mlp1(x)
x = rearrange(x, "b d n -> b n d")
x = self.ln2(x)
x = self.mlp2(x)
return x
class MixerLowHighHybridMLP(nn.Module):
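    """MLP-Mixer-style projector: two stacks of MixerLayer blocks (token and
    channel MLPs) progressively resize the low- and high-level token sequences up
    to `output_dim` channels, and the two streams are concatenated along the
    token dimension."""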
def __init__(
self,
low_input_size: tuple = (256, 384),
low_output_size: list = [192, 128],
high_input_size: tuple = (32, 768),
high_output_size: list = [64, 128],
output_dim=3584,
depth=2,
mlp_depth=2,
proj_out_num=256,
):
        assert (
            len(low_output_size) == len(high_output_size) == depth
        ), "low_output_size and high_output_size must each have `depth` entries"
assert output_dim % (2**depth) == 0, "Output dim must be divisible by 2**depth"
super().__init__()
self.proj_out_num = proj_out_num
self.low_mixer = nn.ModuleList(
[
MixerLayer(
input_size=(
(low_output_size[i - 1], output_dim // (2 ** (depth - i)))
if i > 0
else low_input_size
),
output_size=(
low_output_size[i],
output_dim // (2 ** (depth - i - 1)),
),
mlp_depth=mlp_depth,
)
for i in range(depth)
]
)
self.high_mixer = nn.ModuleList(
[
MixerLayer(
input_size=(
(high_output_size[i - 1], output_dim // (2 ** (depth - i)))
if i > 0
else high_input_size
),
output_size=(
high_output_size[i],
output_dim // (2 ** (depth - i - 1)),
),
mlp_depth=mlp_depth,
)
for i in range(depth)
]
)
def forward(self, x):
low_x, high_x = x
for low_layer, high_layer in zip(self.low_mixer, self.high_mixer):
low_x = low_layer(low_x)
high_x = high_layer(high_x)
x = torch.cat([low_x, high_x], dim=1)
return x
class IdentityMap(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, *args, **kwargs):
return x
@property
def config(self):
return {"mm_projector_type": "identity"}
def build_mm_projector(config, delay_load=False, **kwargs):
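    """Build the multimodal projector selected by `config.mm_projector_type`
    ("linear", "mlp", "low_high_mlp", "mixer", "mhsa", or "identity")."""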
projector_type = getattr(config, "mm_projector_type", "linear")
if projector_type == "linear":
return nn.Linear(config.mm_hidden_size, config.hidden_size)
elif projector_type == "mlp":
return MultiModalProjector(
input_size=config.mm_hidden_size,
output_size=config.hidden_size,
mlp_depth=config.mm_mlp_depth,
proj_out_num=config.proj_out_num,
)
elif projector_type == "low_high_mlp":
return LowHighHybridMLP(
low_input_size=config.low_input_size,
high_input_size=config.high_input_size,
output_size=config.hidden_size,
mlp_depth=config.mm_mlp_depth,
proj_out_num=config.proj_out_num,
)
elif projector_type == "mixer":
return MixerLowHighHybridMLP(
low_input_size=config.low_input_size,
low_output_size=config.low_output_size,
high_input_size=config.high_input_size,
high_output_size=config.high_output_size,
output_dim=config.hidden_size,
depth=len(config.low_output_size),
mlp_depth=config.mm_mlp_depth,
proj_out_num=config.proj_out_num,
)
elif projector_type == "mhsa":
return MultiHeadSelfAttention(
embed_dim=config.mm_hidden_size,
output_dim=config.hidden_size,
            num_heads=getattr(config, "num_heads", None) or 8,
proj_out_num=config.proj_out_num,
)
elif projector_type == "identity":
return IdentityMap()
else:
raise ValueError(f"Unknown projector type: {projector_type}")
class VLMMetaModel:
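    """Mixin that attaches a vision tower and a multimodal projector to the
    language model and handles building them and loading pretrained weights."""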
def __init__(self, config):
super(VLMMetaModel, self).__init__(config)
if hasattr(config, "vision_tower"):
self.vision_tower = build_vision_tower(config, delay_load=True)
self.mm_projector = build_mm_projector(config)
def get_vision_tower(self):
vision_tower = getattr(self, "vision_tower", None)
if type(vision_tower) is list:
vision_tower = vision_tower[0]
return vision_tower
def initialize_vision_modules(self, model_args):
self.config.input_size = model_args.input_size
self.config.patch_size = model_args.patch_size
self.config.dim = model_args.dim
self.config.depth = model_args.depth
self.config.vision_tower = model_args.vision_tower
self.config.vision_select_layer = model_args.vision_select_layer
self.config.vision_select_feature = model_args.vision_select_feature
self.config.mm_projector_type = model_args.mm_projector_type
self.config.mm_mlp_depth = model_args.mm_mlp_depth
self.config.proj_out_num = model_args.proj_out_num
# vision tower
if self.get_vision_tower() is None:
self.vision_tower = build_vision_tower(self.config)
self.vision_tower.requires_grad_(not model_args.freeze_vision_tower)
if self.config.vision_tower == "hybrid":
self.config.low_input_size = self.vision_tower.low_input_size
self.config.high_input_size = self.vision_tower.high_input_size
elif self.config.mm_projector_type == "mixer":
self.config.low_output_size = model_args.low_output_size
self.config.high_output_size = model_args.high_output_size
self.config.low_input_size = (256, 384)
self.config.high_input_size = (32, 768)
if model_args.pretrain_vision_model is not None:
vision_model_weights = torch.load(
model_args.pretrain_vision_model, map_location="cpu"
)
self.vision_tower.vision_tower.load_state_dict(
vision_model_weights, strict=True
)
if model_args.pretrain_clip_model is not None:
clip_model = AutoModel.from_pretrained(model_args.pretrain_clip_model)
self.vision_tower.vision_tower = clip_model.vision_encoder
self.config.mm_hidden_size = self.vision_tower.hidden_size
# mm_projector
if getattr(self, "mm_projector", None) is None:
self.mm_projector = build_mm_projector(self.config)
if model_args.pretrain_mm_mlp_adapter is not None:
mm_projector_weights = torch.load(
model_args.pretrain_mm_mlp_adapter, map_location="cpu"
)
if self.config.mm_projector_type == "mlp":
                def get_w(weights, keyword):
                    return {
                        f"{keyword}.{k.split(keyword + '.')[2]}": v
                        for k, v in weights.items()
                        if keyword in k
                    }
elif self.config.mm_projector_type == "low_high_mlp":
def get_w(weights, keyword):
result = {}
for k, v in weights.items():
if keyword in k:
if f"{keyword}.{keyword}" in k:
part = k.split(f"{keyword}.{keyword}.")[1]
result[f"mm_projector.{part}"] = v
elif f"{keyword}." in k:
part = k.split(f"{keyword}.")[1]
result[part] = v
return result
elif self.config.mm_projector_type == "mixer":
def get_w(weights, keyword):
result = {}
for k, v in weights.items():
if keyword in k:
new_key = k.split(".")
if len(new_key) > 2:
new_key = ".".join(new_key[2:])
result[new_key] = v
return result
else:
def get_w(weights, keyword):
result = {}
for k, v in weights.items():
if keyword in k:
new_key = k.split(".")
if len(new_key) > 2:
new_key = ".".join(new_key[2:])
result[new_key] = v
return result
self.mm_projector.load_state_dict(
get_w(mm_projector_weights, "mm_projector"), strict=True
)
class VLMMetaForCausalLM(ABC):
@abstractmethod
def get_model(self):
pass
def get_vision_tower(self):
return self.get_model().get_vision_tower()
def encode_images(self, images):
image_features = self.get_model().get_vision_tower()(images)
image_features = self.get_model().mm_projector(image_features)
return image_features
def prepare_inputs_for_multimodal(
self,
input_ids,
position_ids,
attention_mask,
past_key_values,
labels,
images,
):
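        """Splice projected image features into the token embeddings.

        A pass-through when there is no vision tower, no images, or only a single
        input token (incremental decoding). Otherwise the prompt is expected to
        reserve placeholder positions directly after the first token: embedding
        positions [1, 1 + num_image_tokens) are replaced by the projected image
        features, and `input_ids` is returned as None so the base model consumes
        `inputs_embeds` instead.
        """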
vision_tower = self.get_vision_tower()
if vision_tower is None or images is None or input_ids.shape[1] == 1:
return (
input_ids,
position_ids,
attention_mask,
past_key_values,
None,
labels,
)
else:
image_features = self.encode_images(images)
inputs_embeds = self.get_model().embed_tokens(input_ids)
inputs_embeds = torch.cat(
(
inputs_embeds[:, :1, :],
image_features,
inputs_embeds[:, (image_features.shape[1] + 1) :, :],
),
dim=1,
)
return (
None,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels,
)
def initialize_vision_tokenizer(self, model_args, tokenizer):
num_new_tokens = model_args.num_new_tokens
self.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:
input_embeddings = self.get_input_embeddings().weight.data
output_embeddings = self.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True
)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True
)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
if model_args.tune_mm_mlp_adapter:
for p in self.get_input_embeddings().parameters():
p.requires_grad = True
for p in self.get_output_embeddings().parameters():
p.requires_grad = False
else:
for p in self.get_input_embeddings().parameters():
p.requires_grad = True
for p in self.get_output_embeddings().parameters():
p.requires_grad = True
if model_args.pretrain_mm_mlp_adapter:
mm_projector_weights = torch.load(
model_args.pretrain_mm_mlp_adapter, map_location="cpu"
)
embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
if input_embeddings.shape == embed_tokens_weight.shape:
input_embeddings = embed_tokens_weight
elif embed_tokens_weight.shape[0] == num_new_tokens:
input_embeddings[-num_new_tokens:] = embed_tokens_weight
else:
raise ValueError(
f"Unexpected embed_tokens_weight shape. "
f"Pretrained: {embed_tokens_weight.shape}. "
f"Current: {input_embeddings.shape}. "
f"Number of new tokens: {num_new_tokens}."
)
class VLMQwenConfig(Qwen2Config):
model_type = "vlm_qwen"
class VLMQwenModel(VLMMetaModel, Qwen2Model):
config_class = VLMQwenConfig
def __init__(self, config: Qwen2Config):
super(VLMQwenModel, self).__init__(config)
class VLMQwenForCausalLM(Qwen2ForCausalLM, VLMMetaForCausalLM):
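    """Qwen2 causal LM with the multimodal plumbing from VLMMetaForCausalLM:
    `forward` and `generate` accept a 5D `images` tensor whose projected features
    are spliced into the input embeddings before the language model runs."""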
config_class = VLMQwenConfig
def __init__(self, config):
super(Qwen2ForCausalLM, self).__init__(config)
self.model = VLMQwenModel(config)
self.pretraining_tp = getattr(config, "pretraining_tp", None)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def get_model(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
image_sizes: Optional[List[List[int]]] = None,
return_dict: Optional[bool] = None,
num_logits_to_keep: Optional[int] = None,
**kwargs: Any,
) -> Union[Tuple, CausalLMOutputWithPast]:
if inputs_embeds is None:
(
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels,
) = self.prepare_inputs_for_multimodal(
input_ids, position_ids, attention_mask, past_key_values, labels, images
)
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
@torch.no_grad()
def generate(
self,
images: Optional[torch.Tensor] = None,
inputs: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor, Any]:
position_ids = kwargs.pop("position_ids", None)
attention_mask = kwargs.pop("attention_mask", None)
if "inputs_embeds" in kwargs:
raise NotImplementedError("`inputs_embeds` is not supported")
if images is not None:
(inputs, position_ids, attention_mask, _, inputs_embeds, _) = (
self.prepare_inputs_for_multimodal(
inputs,
position_ids,
attention_mask,
None,
None,
images,
)
)
else:
inputs_embeds = self.get_model().embed_tokens(inputs)
output_ids = super().generate(inputs_embeds=inputs_embeds, **kwargs)
return output_ids
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
):
images = kwargs.pop("images", None)
inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
**kwargs,
)
if images is not None:
inputs["images"] = images
return inputs
AutoConfig.register("vlm_qwen", VLMQwenConfig)
AutoModelForCausalLM.register(VLMQwenConfig, VLMQwenForCausalLM)
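# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# The repo id, volume shape, placeholder token string, token count, and prompt
# template below are assumptions; the only facts taken from this file are that
# "vlm_qwen" is registered with AutoModelForCausalLM and that `generate()`
# accepts a 5D `images` tensor whose projected features replace the placeholder
# positions right after the first prompt token (see
# `prepare_inputs_for_multimodal`).
#
#   import torch
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   repo = "MagicXin/Med3DVLM-Qwen-2.5-7B"  # assumed repo id
#   model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
#   tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
#
#   volume = torch.zeros(1, 1, 128, 256, 256)  # (B, C, D, H, W); shape assumed
#   image_tokens = "<im_patch>" * 256  # hypothetical placeholder; count should
#                                      # match config.proj_out_num
#   prompt = image_tokens + "Describe the findings in this CT scan."
#   input_ids = tokenizer(prompt, return_tensors="pt").input_ids
#   output_ids = model.generate(images=volume, inputs=input_ids, max_new_tokens=128)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
# ---------------------------------------------------------------------------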