|
|
import math |
|
|
from abc import ABC, abstractmethod |
|
|
from typing import Any, List, Optional, Tuple, Union |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from einops import pack, rearrange, repeat, unpack |
|
|
from einops.layers.torch import Rearrange |
|
|
from timm.layers import DropPath, to_3tuple, trunc_normal_ |
|
|
from transformers import ( |
|
|
AutoConfig, |
|
|
AutoModel, |
|
|
AutoModelForCausalLM, |
|
|
PretrainedConfig, |
|
|
PreTrainedModel, |
|
|
Qwen2Config, |
|
|
Qwen2ForCausalLM, |
|
|
Qwen2Model, |
|
|
) |
|
|
from transformers.generation.utils import GenerateOutput |
|
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
|
|
|
try: |
|
|
import torch.distributed.nn |
|
|
from torch import distributed as dist |
|
|
|
|
|
has_distributed = True |
|
|
except ImportError: |
|
|
has_distributed = False |
|
|
|
|
|
|
|
|
class DEC_CLIPConfig(PretrainedConfig): |
|
|
model_type = "dec_clip" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
        language_model_name_or_path: str = "",
        vision_encoder: str = "vit3d",  # backbone selector read by DEC_CLIP ("vit3d" or "dcformer"); default is an assumption
|
|
local_loss: bool = False, |
|
|
gather_loss: bool = True, |
|
|
input_size: tuple = (256, 256, 128), |
|
|
dim: int = 768, |
|
|
depth: int = 12, |
|
|
hidden_size: int = 512, |
|
|
mlp_depth: int = 2, |
|
|
loss_type: str = "nce", |
|
|
t_prime: float = np.log(1 / 0.07), |
|
|
bias: float = 0.0, |
|
|
efficient_loss: bool = False, |
|
|
**kwargs, |
|
|
): |
|
|
        self.language_model_name_or_path = language_model_name_or_path
        self.vision_encoder = vision_encoder
|
|
self.input_size = input_size |
|
|
self.dim = dim |
|
|
self.depth = depth |
|
|
self.hidden_size = hidden_size |
|
|
self.mlp_depth = mlp_depth |
|
|
self.local_loss = local_loss |
|
|
self.gather_loss = gather_loss |
|
|
self.loss_type = loss_type |
|
|
self.t_prime = t_prime |
|
|
self.bias = bias |
|
|
self.efficient_loss = efficient_loss |
|
|
super().__init__(**kwargs) |
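
# Configuration sketch (illustrative, with assumed values — adjust paths and sizes
# to your own checkpoints):
#
#   cfg = DEC_CLIPConfig(
#       language_model_name_or_path="/path/to/text/encoder",  # hypothetical local path
#       vision_encoder="vit3d",             # or "dcformer"
#       input_size=(256, 256, 128),
#       hidden_size=512,
#       loss_type="sigmoid",                # "sigmoid" (SigLIP-style) or "nce"
#   )
#
# t_prime defaults to log(1 / 0.07), i.e. the usual CLIP temperature initialisation;
# with loss_type="sigmoid" it is learned as a log-temperature together with the bias.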
|
|
|
|
|
|
|
|
class DEC_CLIP(PreTrainedModel): |
|
|
config_class = DEC_CLIPConfig |
|
|
|
|
|
def __init__(self, config): |
|
|
super().__init__(config) |
|
|
|
|
|
self.config = config |
|
|
|
|
|
if config.vision_encoder == "vit3d": |
|
|
self.vision_encoder = Vit3D( |
|
|
input_size=config.input_size, |
|
|
dim=config.dim, |
|
|
depth=config.depth, |
|
|
) |
|
|
elif config.vision_encoder == "dcformer": |
|
|
self.vision_encoder = decomp_small(input_size=config.input_size) |
|
|
else: |
|
|
raise ValueError(f"Unexpected vision encoder: {config.vision_encoder}") |
|
|
|
|
|
self.language_encoder = AutoModel.from_pretrained( |
|
|
config.language_model_name_or_path |
|
|
) |
|
|
|
|
|
self.mm_vision_proj = nn.Linear( |
|
|
self.vision_encoder.channels[-1], config.hidden_size |
|
|
) |
|
|
self.mm_language_proj = nn.Linear( |
|
|
self.language_encoder.config.dim, config.hidden_size |
|
|
) |
|
|
|
|
|
self.efficient_loss = config.efficient_loss |
|
|
self.local_loss = config.local_loss |
|
|
self.gather_loss = config.gather_loss |
|
|
self.loss_type = config.loss_type |
|
|
|
|
|
if self.loss_type == "sigmoid": |
|
|
self.t_prime = nn.Parameter(torch.tensor(config.t_prime)) |
|
|
self.bias = nn.Parameter(torch.tensor(config.bias)) |
|
|
else: |
|
|
self.logit_scale = nn.Parameter(torch.ones([]) * config.t_prime) |
|
|
|
|
|
def encode_image(self, image): |
|
|
image_feats = self.vision_encoder(image) |
|
|
if isinstance(image_feats, list): |
|
|
image_feats = image_feats[-1] |
|
|
image_feats = image_feats.mean(dim=1) |
|
|
image_feats = self.mm_vision_proj(image_feats) |
|
|
image_feats = F.normalize(image_feats, dim=-1) |
|
|
|
|
|
return image_feats |
|
|
|
|
|
def encode_text(self, input_id, attention_mask): |
|
|
text_feats = self.language_encoder(input_id, attention_mask=attention_mask)[ |
|
|
"last_hidden_state" |
|
|
] |
|
|
text_feats = text_feats[:, 0] |
|
|
text_feats = self.mm_language_proj(text_feats) |
|
|
text_feats = F.normalize(text_feats, dim=-1) |
|
|
|
|
|
return text_feats |
|
|
|
|
|
def forward(self, images, input_ids, attention_mask, labels, **kwargs): |
|
|
image_features = self.encode_image(images) |
|
|
text_features = self.encode_text(input_ids, attention_mask) |
|
|
|
|
|
rank = 0 |
|
|
world_size = 1 |
|
|
if has_distributed and dist.is_initialized(): |
|
|
rank = dist.get_rank() |
|
|
world_size = dist.get_world_size() |
|
|
|
|
|
batch_size = image_features.size(0) |
|
|
device = image_features.device |
|
|
if self.loss_type == "sigmoid": |
|
|
if has_distributed and dist.is_initialized(): |
|
|
if self.efficient_loss: |
|
|
t = torch.exp(self.t_prime) |
|
|
loss = 0.0 |
|
|
|
|
|
for target_rank in range(world_size): |
|
|
if rank == target_rank: |
|
|
target_text_features = text_features |
|
|
else: |
|
|
target_text_features = torch.distributed.nn.broadcast( |
|
|
text_features.requires_grad_(), target_rank |
|
|
) |
|
|
|
|
|
local_logits_per_image = ( |
|
|
image_features @ target_text_features.T |
|
|
) * t + self.bias |
|
|
local_logits_per_text = local_logits_per_image.T |
|
|
|
|
|
if rank == target_rank: |
|
|
local_labels = 2 * torch.eye( |
|
|
batch_size, device=device |
|
|
) - torch.ones(batch_size, batch_size, device=device) |
|
|
else: |
|
|
local_labels = -torch.ones( |
|
|
batch_size, batch_size, device=device |
|
|
) |
|
|
|
|
|
local_logits = ( |
|
|
local_logits_per_image + local_logits_per_text |
|
|
) / 2.0 |
|
|
local_loss = -torch.sum( |
|
|
F.logsigmoid(local_labels * local_logits) |
|
|
) / (batch_size * world_size) |
|
|
|
|
|
loss += local_loss |
|
|
|
|
|
torch.distributed.nn.all_reduce(loss) |
|
|
torch.cuda.synchronize() |
|
|
|
|
|
if self.training: |
|
|
logits = 0 |
|
|
else: |
|
|
t = torch.exp(self.t_prime) |
|
|
|
|
|
all_image_features, all_text_features = gather_features( |
|
|
image_features, |
|
|
text_features, |
|
|
gather_with_grad=True, |
|
|
rank=rank, |
|
|
world_size=world_size, |
|
|
) |
|
|
|
|
|
logits_per_image = ( |
|
|
all_image_features @ all_text_features.T |
|
|
) * t + self.bias |
|
|
logits_per_text = logits_per_image.T |
|
|
batch_size = all_image_features.size(0) |
|
|
|
|
|
                    labels = 2 * torch.eye(
                        batch_size, device=image_features.device
                    ) - torch.ones(batch_size, batch_size, device=image_features.device)
|
|
|
|
|
logits = (logits_per_image + logits_per_text) / 2.0 |
|
|
loss = -torch.sum(F.logsigmoid(labels * logits)) / batch_size |
|
|
|
|
|
else: |
|
|
                logits_per_image = (
                    image_features @ text_features.T
                ) * torch.exp(self.t_prime) + self.bias  # exp() keeps the single-process path consistent with the distributed branches
|
|
logits_per_text = logits_per_image.T |
|
|
|
|
|
labels = 2 * torch.eye(batch_size, device=device) - torch.ones( |
|
|
batch_size, batch_size, device=device |
|
|
) |
|
|
|
|
|
logits = (logits_per_image + logits_per_text) / 2.0 |
|
|
                loss = -torch.sum(F.logsigmoid(labels * logits)) / batch_size  # normalise like the distributed paths
|
|
else: |
|
|
all_image_features, all_text_features = gather_features( |
|
|
image_features, |
|
|
text_features, |
|
|
local_loss=self.local_loss, |
|
|
gather_with_grad=True, |
|
|
rank=rank, |
|
|
world_size=world_size, |
|
|
) |
|
|
|
|
|
if self.gather_loss: |
|
|
if self.local_loss: |
|
|
logits_per_image = ( |
|
|
self.logit_scale * image_features @ all_text_features.T |
|
|
) |
|
|
logits_per_text = ( |
|
|
self.logit_scale * text_features @ all_image_features.T |
|
|
) |
|
|
else: |
|
|
logits_per_image = ( |
|
|
self.logit_scale * all_image_features @ all_text_features.T |
|
|
) |
|
|
logits_per_text = logits_per_image.T |
|
|
else: |
|
|
logits_per_image = self.logit_scale * image_features @ text_features.T |
|
|
logits_per_text = self.logit_scale * text_features @ image_features.T |
|
|
|
|
|
            # NOTE: for the InfoNCE path, `labels` are expected to be the contrastive
            # targets (typically torch.arange(batch_size), offset by rank * batch_size
            # when local_loss is enabled), not language-modelling labels.
            image_loss = F.cross_entropy(logits_per_image, labels)
|
|
text_loss = F.cross_entropy(logits_per_text, labels) |
|
|
|
|
|
loss = (image_loss + text_loss) / 2.0 |
|
|
            logits = (logits_per_image + logits_per_text) / 2.0
|
|
|
|
|
ret = { |
|
|
"loss": loss, |
|
|
"logits": logits, |
|
|
} |
|
|
|
|
|
return ret |
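
# Loss summary (as implemented above):
#   * loss_type == "sigmoid": a SigLIP-style pairwise loss.  With logits
#     Z = t * (X @ Y.T) + b, where t = exp(t_prime), the code symmetrises Z with its
#     transpose and minimises  -sum_{i,j} log sigmoid(z_ij * Z_ij)  divided by the
#     (global) batch size, with z_ij = +1 on the diagonal (matched image/text pairs)
#     and -1 everywhere else.
#   * loss_type == "nce": the standard symmetric InfoNCE / CLIP objective — cross
#     entropy over the (optionally gathered) similarity matrix, averaged over the
#     image->text and text->image directions.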
|
|
|
|
|
|
|
|
def gather_features( |
|
|
image_features, |
|
|
text_features, |
|
|
local_loss=False, |
|
|
gather_with_grad=True, |
|
|
rank=0, |
|
|
world_size=1, |
|
|
): |
|
|
assert ( |
|
|
has_distributed |
|
|
), "torch.distributed did not import correctly, please use a PyTorch version with support." |
|
|
|
|
|
if not (has_distributed and dist.is_initialized()): |
|
|
return image_features, text_features |
|
|
|
|
|
if gather_with_grad: |
|
|
all_image_features = torch.cat( |
|
|
torch.distributed.nn.all_gather(image_features), dim=0 |
|
|
) |
|
|
all_text_features = torch.cat( |
|
|
torch.distributed.nn.all_gather(text_features), dim=0 |
|
|
) |
|
|
else: |
|
|
gathered_image_features = [ |
|
|
torch.zeros_like(image_features) for _ in range(world_size) |
|
|
] |
|
|
gathered_text_features = [ |
|
|
torch.zeros_like(text_features) for _ in range(world_size) |
|
|
] |
|
|
dist.all_gather(gathered_image_features, image_features) |
|
|
dist.all_gather(gathered_text_features, text_features) |
|
|
if not local_loss: |
|
|
gathered_image_features[rank] = image_features |
|
|
gathered_text_features[rank] = text_features |
|
|
all_image_features = torch.cat(gathered_image_features, dim=0) |
|
|
all_text_features = torch.cat(gathered_text_features, dim=0) |
|
|
|
|
|
return all_image_features, all_text_features |
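
# When torch.distributed is initialised, gather_with_grad=True uses the differentiable
# torch.distributed.nn.all_gather so gradients flow back to every rank's features;
# the plain dist.all_gather path gathers detached copies and, unless local_loss is set,
# re-inserts the local rank's tensors so that at least the local features keep grads.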
|
|
|
|
|
|
|
|
AutoConfig.register("dec_clip", DEC_CLIPConfig) |
|
|
AutoModel.register(DEC_CLIPConfig, DEC_CLIP) |
|
|
|
|
|
|
|
|
def stem(inp, oup, image_size, downsample=False): |
|
|
    stride = 2 if downsample else 1
|
|
return nn.Sequential( |
|
|
nn.Conv3d(inp, oup, 3, stride, 1, bias=False), |
|
|
nn.BatchNorm3d(oup), |
|
|
nn.GELU(), |
|
|
nn.Conv3d(oup, oup, 3, 1, 1, bias=False), |
|
|
nn.BatchNorm3d(oup), |
|
|
nn.GELU(), |
|
|
) |
|
|
|
|
|
|
|
|
def DecomposedStem(inp, oup, image_size, kernel_size, downsample=False): |
|
|
    return nn.Sequential(
        # NOTE: pass the activation as `act=`; supplying nn.GELU() positionally
        # would land on the `norm` argument and the activation would be dropped.
        DecompConv3D(inp, oup, 7, 4, act=nn.GELU()),
        DecompConv3D(oup, oup, 3, 1, act=nn.GELU()),
        DecompConv3D(oup, oup, 3, 1, act=nn.GELU()),
        DecompConv3D(oup, oup, 3, 1, act=nn.GELU()),
    )
|
|
|
|
|
|
|
|
class DecompConv3D(nn.Module): |
|
|
def __init__( |
|
|
self, in_dim, out_dim, kernel_size, stride=1, groups=1, norm=True, act=None |
|
|
) -> None: |
|
|
super().__init__() |
|
|
self.act = act |
|
|
|
|
|
self.c1 = nn.Sequential( |
|
|
nn.Conv3d( |
|
|
in_dim, |
|
|
out_dim, |
|
|
kernel_size=(kernel_size, 1, 1), |
|
|
padding=(kernel_size // 2, 0, 0), |
|
|
stride=stride, |
|
|
groups=groups, |
|
|
), |
|
|
nn.BatchNorm3d(out_dim) if norm else nn.Identity(), |
|
|
) |
|
|
self.c2 = nn.Sequential( |
|
|
nn.Conv3d( |
|
|
in_dim, |
|
|
out_dim, |
|
|
kernel_size=(1, kernel_size, 1), |
|
|
padding=(0, kernel_size // 2, 0), |
|
|
stride=stride, |
|
|
groups=groups, |
|
|
), |
|
|
nn.BatchNorm3d(out_dim) if norm else nn.Identity(), |
|
|
) |
|
|
self.c3 = nn.Sequential( |
|
|
nn.Conv3d( |
|
|
in_dim, |
|
|
out_dim, |
|
|
kernel_size=(1, 1, kernel_size), |
|
|
padding=(0, 0, kernel_size // 2), |
|
|
stride=stride, |
|
|
groups=groups, |
|
|
), |
|
|
nn.BatchNorm3d(out_dim) if norm else nn.Identity(), |
|
|
) |
|
|
|
|
|
def forward(self, x): |
|
|
x = self.c1(x) + self.c2(x) + self.c3(x) |
|
|
if self.act is not None: |
|
|
x = self.act(x) |
|
|
return x |
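
# DecompConv3D approximates a dense k x k x k convolution with the sum of three
# axis-aligned convolutions (k x 1 x 1, 1 x k x 1 and 1 x 1 x k), each followed by an
# optional BatchNorm.  The weight count drops from ~k^3 * Cin * Cout to
# ~3k * Cin * Cout, which is what makes the large kernels (up to 13) used in the
# convolutional stages affordable.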
|
|
|
|
|
|
|
|
class ConvPosEnc(nn.Module): |
|
|
|
|
|
def __init__(self, dim, k=3, decompose=False): |
|
|
super().__init__() |
|
|
if decompose: |
|
|
self.proj = DecompConv3D(dim, dim, k, groups=dim, norm=None) |
|
|
else: |
|
|
self.proj = nn.Conv3d( |
|
|
dim, dim, to_3tuple(k), to_3tuple(1), to_3tuple(k // 2), groups=dim |
|
|
) |
|
|
|
|
|
def forward(self, x, size): |
|
|
B, N, C = x.shape |
|
|
H, W, T = size |
|
|
assert N == H * W * T |
|
|
feat = rearrange(x, "b (h w t) c -> b c h w t", h=H, w=W, t=T) |
|
|
feat = self.proj(feat) |
|
|
feat = rearrange(feat, "b c h w t -> b (h w t) c ") |
|
|
x = x + feat |
|
|
return x |
|
|
|
|
|
|
|
|
class MLP(nn.Module): |
|
|
def __init__(self, oup, mlp_dim, dp=0.0): |
|
|
super().__init__() |
|
|
|
|
|
self.mlp = nn.Sequential( |
|
|
nn.Linear(oup, mlp_dim), |
|
|
nn.GELU(), |
|
|
nn.Dropout(dp), |
|
|
nn.Linear(mlp_dim, oup), |
|
|
nn.Dropout(dp), |
|
|
) |
|
|
|
|
|
def forward(self, x): |
|
|
x = self.mlp(x) |
|
|
return x |
|
|
|
|
|
|
|
|
class ScaleDotProduct(nn.Module): |
|
|
def __init__(self, scale): |
|
|
super().__init__() |
|
|
self.scale = scale |
|
|
self.softmax = nn.Softmax(dim=-1) |
|
|
|
|
|
def forward(self, qkv): |
|
|
q, k, v = qkv[0], qkv[1], qkv[2] |
|
|
|
|
|
q = q * self.scale |
|
|
attn = q @ k.transpose(-2, -1) |
|
|
attn = self.softmax(attn) |
|
|
|
|
|
x = (attn @ v).transpose(1, 2) |
|
|
return x |
|
|
|
|
|
|
|
|
class DecomposedAttention(nn.Module): |
|
|
def __init__(self, oup, head_num): |
|
|
super().__init__() |
|
|
|
|
|
self.head_num = head_num |
|
|
        scale = (oup // head_num) ** -0.5  # 1 / sqrt(head_dim), standard attention scaling
|
|
self.sdp = ScaleDotProduct(scale) |
|
|
self.qkv = nn.Linear(oup, oup * 3, bias=False) |
|
|
self.proj = nn.Linear(oup, oup, bias=False) |
|
|
|
|
|
def forward(self, x, size): |
|
|
b, n, c = x.shape |
|
|
h, w, t = size |
|
|
assert n == h * w * t |
|
|
B_, N, C = x.shape |
|
|
qkv = ( |
|
|
self.qkv(x) |
|
|
.reshape(B_, N, 3, self.head_num, C // self.head_num) |
|
|
.permute(2, 0, 3, 1, 4) |
|
|
) |
|
|
|
|
|
x = rearrange(qkv, "k b nh (h w t) c -> k b c nh h w t", h=h, w=w, t=t) |
|
|
|
|
|
x1 = rearrange(x, "k b c nh h w t -> k (b t) nh (h w) c") |
|
|
x2 = rearrange(x, "k b c nh h w t -> k (b w) nh (h t) c") |
|
|
x3 = rearrange(x, "k b c nh h w t -> k (b h) nh (w t) c") |
|
|
|
|
|
x1 = self.sdp(x1) |
|
|
x2 = self.sdp(x2) |
|
|
x3 = self.sdp(x3) |
|
|
|
|
|
x1 = rearrange(x1, "(b t) (h w) nh c -> b (h w t) (nh c)", h=h, w=w, t=t) |
|
|
x2 = rearrange(x2, "(b w) (h t) nh c -> b (h w t) (nh c)", h=h, w=w, t=t) |
|
|
x3 = rearrange(x3, "(b h) (w t) nh c -> b (h w t) (nh c)", h=h, w=w, t=t) |
|
|
x = self.proj(x1 + x2 + x3) |
|
|
|
|
|
return x |
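
# DecomposedAttention never attends over the full H*W*T token set.  The multi-head
# q/k/v tensors are re-grouped three times so that softmax attention runs within
# (h, w) planes for every depth slice, (h, t) planes for every width slice and
# (w, t) planes for every height slice; the three results are summed and projected.
# The quadratic cost is therefore over plane sizes rather than the whole volume.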
|
|
|
|
|
|
|
|
class SelfAttention(nn.Module): |
|
|
def __init__(self, oup, head_num): |
|
|
super().__init__() |
|
|
|
|
|
self.head_num = head_num |
|
|
        scale = (oup // head_num) ** -0.5  # 1 / sqrt(head_dim), standard attention scaling
|
|
self.sdp = ScaleDotProduct(scale) |
|
|
self.qkv = nn.Linear(oup, oup * 3, bias=False) |
|
|
self.proj = nn.Linear(oup, oup, bias=False) |
|
|
|
|
|
def forward(self, x, size=None): |
|
|
B_, N, C = x.shape |
|
|
qkv = ( |
|
|
self.qkv(x) |
|
|
.reshape(B_, N, 3, self.head_num, C // self.head_num) |
|
|
.permute(2, 0, 3, 1, 4) |
|
|
) |
|
|
|
|
|
x = self.sdp(qkv).reshape(B_, N, C) |
|
|
x = self.proj(x) |
|
|
return x |
|
|
|
|
|
|
|
|
class ChannelAttention(nn.Module): |
|
|
def __init__(self, oup, head_num): |
|
|
super().__init__() |
|
|
|
|
|
self.head_num = head_num |
|
|
        self.scale = (oup // head_num) ** -0.5  # 1 / sqrt(head_dim)
|
|
self.softmax = nn.Softmax(dim=-1) |
|
|
|
|
|
self.qkv = nn.Linear(oup, oup * 3, bias=False) |
|
|
self.proj = nn.Linear(oup, oup, bias=False) |
|
|
|
|
|
def forward(self, x): |
|
|
B, N, C = x.shape |
|
|
|
|
|
qkv = ( |
|
|
self.qkv(x) |
|
|
.reshape(B, N, 3, self.head_num, C // self.head_num) |
|
|
.permute(2, 0, 3, 1, 4) |
|
|
) |
|
|
q, k, v = qkv[0], qkv[1], qkv[2] |
|
|
|
|
|
k = k * self.scale |
|
|
attention = k.transpose(-1, -2) @ v |
|
|
attention = self.softmax(attention) |
|
|
x = (attention @ q.transpose(-1, -2)).transpose(-1, -2) |
|
|
x = x.transpose(1, 2).reshape(B, N, C) |
|
|
x = self.proj(x) |
|
|
return x |
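
# ChannelAttention is a "transposed" attention in the spirit of DaViT/XCA: the
# (head_dim x head_dim) attention map is built from k^T @ v, so its cost grows
# linearly with the number of tokens, and it mixes information across channels
# rather than across spatial positions.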
|
|
|
|
|
|
|
|
class ChannelBlock(nn.Module): |
|
|
def __init__(self, dim, heads=8): |
|
|
super().__init__() |
|
|
hidden_dim = int(dim * 4) |
|
|
|
|
|
self.cpe = nn.ModuleList( |
|
|
[ |
|
|
ConvPosEnc(dim=dim, k=3, decompose=True), |
|
|
ConvPosEnc(dim=dim, k=3, decompose=True), |
|
|
] |
|
|
) |
|
|
|
|
|
self.attn = ChannelAttention(dim, heads) |
|
|
|
|
|
self.layer_norm1 = nn.LayerNorm(dim) |
|
|
|
|
|
self.mlp1 = MLP(dim, hidden_dim) |
|
|
self.layer_norm2 = nn.LayerNorm(dim) |
|
|
|
|
|
def forward(self, x, size): |
|
|
x = self.cpe[0](x, size) |
|
|
_x = self.layer_norm1(x) |
|
|
|
|
|
_x = self.attn(_x) |
|
|
x = x + _x |
|
|
|
|
|
x = self.cpe[1](x, size) |
|
|
_x = self.layer_norm2(x) |
|
|
_x = self.mlp1(_x) |
|
|
x = x + _x |
|
|
return x |
|
|
|
|
|
|
|
|
class SpatialBlock(nn.Module): |
|
|
def __init__(self, dim, heads=8): |
|
|
super().__init__() |
|
|
hidden_dim = int(dim * 4) |
|
|
|
|
|
self.cpe = nn.ModuleList( |
|
|
[ |
|
|
ConvPosEnc(dim=dim, k=3, decompose=True), |
|
|
ConvPosEnc(dim=dim, k=3, decompose=True), |
|
|
] |
|
|
) |
|
|
|
|
|
|
|
|
self.attn = SelfAttention(dim, heads) |
|
|
|
|
|
self.layer_norm1 = nn.LayerNorm(dim) |
|
|
|
|
|
self.mlp1 = MLP(dim, hidden_dim) |
|
|
self.layer_norm2 = nn.LayerNorm(dim) |
|
|
|
|
|
def forward(self, x, size): |
|
|
x = self.cpe[0](x, size) |
|
|
_x = self.layer_norm1(x) |
|
|
|
|
|
_x = self.attn(_x, size) |
|
|
x = x + _x |
|
|
|
|
|
x = self.cpe[1](x, size) |
|
|
_x = self.layer_norm2(x) |
|
|
_x = self.mlp1(_x) |
|
|
x = x + _x |
|
|
return x |
|
|
|
|
|
|
|
|
class TransformerBlock(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
inp, |
|
|
oup, |
|
|
image_size, |
|
|
kernel_size, |
|
|
heads=8, |
|
|
dim_head=32, |
|
|
downsample=False, |
|
|
dropout=0.0, |
|
|
): |
|
|
super().__init__() |
|
|
hidden_dim = int(inp * 4) |
|
|
|
|
|
self.ih, self.iw, self.it = image_size |
|
|
self.downsample = downsample |
|
|
|
|
|
if self.downsample: |
|
|
self.pool1 = nn.MaxPool3d(3, 2, 1) |
|
|
|
|
|
self.proj = nn.Conv3d(inp, oup, 1, 1, 0, bias=False) |
|
|
|
|
|
self.spatial_attention = SpatialBlock(oup, heads) |
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
if self.downsample: |
|
|
x = self.proj(self.pool1(x)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
h, w, t = x.shape[2], x.shape[3], x.shape[4] |
|
|
size = (h, w, t) |
|
|
x = rearrange(x, "b c h w t -> b (h w t) c ") |
|
|
|
|
|
x = self.spatial_attention(x, size) |
|
|
|
|
|
|
|
|
x = rearrange(x, "b (h w t) c -> b c h w t", h=h, w=w, t=t) |
|
|
|
|
|
return x |
|
|
|
|
|
|
|
|
class ConvBlock(nn.Module): |
|
|
def __init__( |
|
|
self, inp, oup, image_size, kernel_size=7, downsample=False, expansion=4 |
|
|
): |
|
|
super().__init__() |
|
|
self.downsample = downsample |
|
|
        stride = 2 if self.downsample else 1
|
|
hidden_dim = int(oup * expansion) |
|
|
drop_path = 0.0 |
|
|
layer_scale_init_value = 1e-6 |
|
|
if self.downsample: |
|
|
self.pool = nn.MaxPool3d(3, 2, 1) |
|
|
self.proj = nn.Conv3d(inp, oup, 1, 1, 0, bias=False) |
|
|
|
|
|
|
|
|
self.dwconv = DecompConv3D(oup, oup, kernel_size, groups=oup) |
|
|
self.mlp = MLP(oup, hidden_dim) |
|
|
|
|
|
self.scale = ( |
|
|
nn.Parameter(layer_scale_init_value * torch.ones((oup)), requires_grad=True) |
|
|
if layer_scale_init_value > 0 |
|
|
else None |
|
|
) |
|
|
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() |
|
|
|
|
|
def forward(self, x): |
|
|
if self.downsample: |
|
|
x = self.proj(self.pool(x)) |
|
|
input = x |
|
|
x = self.dwconv(x) |
|
|
x = x.permute(0, 2, 3, 4, 1) |
|
|
|
|
|
x = self.mlp(x) |
|
|
|
|
|
if self.scale is not None: |
|
|
x = self.scale * x |
|
|
x = x.permute(0, 4, 1, 2, 3) |
|
|
|
|
|
x = input + self.drop_path(x) |
|
|
return x |
|
|
|
|
|
|
|
|
class Encoder(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
input_size, |
|
|
in_channels, |
|
|
num_blocks, |
|
|
channels, |
|
|
kernel_sizes=[7, 7, 7, 7], |
|
|
block_types=["C", "C", "C", "C"], |
|
|
): |
|
|
super().__init__() |
|
|
self.dims = channels |
|
|
ih, iw, it = input_size |
|
|
block = {"C": ConvBlock, "T": TransformerBlock} |
|
|
i = 4 |
|
|
self.s0 = self._make_layer( |
|
|
DecomposedStem, |
|
|
in_channels, |
|
|
channels[0], |
|
|
num_blocks[0], |
|
|
kernel_sizes[0], |
|
|
(ih // i, iw // i, it // i), |
|
|
) |
|
|
self.s1 = self._make_layer( |
|
|
block[block_types[0]], |
|
|
channels[0], |
|
|
channels[1], |
|
|
num_blocks[1], |
|
|
kernel_sizes[0], |
|
|
(ih // (i * 2**1), iw // (i * 2**1), it // (i * 2**1)), |
|
|
) |
|
|
self.s2 = self._make_layer( |
|
|
block[block_types[1]], |
|
|
channels[1], |
|
|
channels[2], |
|
|
num_blocks[2], |
|
|
kernel_sizes[1], |
|
|
(ih // (i * 2**2), iw // (i * 2**2), it // (i * 2**2)), |
|
|
) |
|
|
self.s3 = self._make_layer( |
|
|
block[block_types[2]], |
|
|
channels[2], |
|
|
channels[3], |
|
|
num_blocks[3], |
|
|
kernel_sizes[2], |
|
|
(ih // (i * 2**3), iw // (i * 2**3), it // (i * 2**3)), |
|
|
) |
|
|
self.s4 = self._make_layer( |
|
|
block[block_types[3]], |
|
|
channels[3], |
|
|
channels[4], |
|
|
num_blocks[4], |
|
|
kernel_sizes[3], |
|
|
(ih // (i * 2**4), iw // (i * 2**4), it // (i * 2**4)), |
|
|
) |
|
|
|
|
|
def forward(self, x): |
|
|
hidden_states = [] |
|
|
|
|
|
x = x.permute(0, 1, 3, 4, 2) |
|
|
|
|
|
for i in range(5): |
|
|
if hasattr(self, "s" + str(i)): |
|
|
x = getattr(self, "s" + str(i))(x) |
|
|
hidden_states.append(x) |
|
|
|
|
|
return hidden_states |
|
|
|
|
|
def _make_layer(self, block, inp, oup, depth, kernel_size, image_size): |
|
|
layers = nn.ModuleList([]) |
|
|
for i in range(depth): |
|
|
if i == 0: |
|
|
layers.append(block(inp, oup, image_size, kernel_size, downsample=True)) |
|
|
else: |
|
|
layers.append(block(oup, oup, image_size, kernel_size)) |
|
|
return nn.Sequential(*layers) |
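
# Stage layout: s0 is the decomposed stem (stride 4), and each of s1..s4 starts with a
# stride-2 MaxPool plus a 1x1x1 projection, so the overall spatial reduction is 64x and
# the widths follow `channels`.  The image_size tuples passed to _make_layer are
# bookkeeping only; the blocks derive spatial sizes from the input at runtime.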
|
|
|
|
|
|
|
|
class DecompModel(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
input_size=(512, 512, 256), |
|
|
in_channels=1, |
|
|
num_blocks=[2, 2, 3, 5, 2], |
|
|
channels=[64, 96, 192, 384, 768], |
|
|
|
|
|
kernel_sizes=[13, 11, 9, 7], |
|
|
block_types=["C", "C", "C", "C"], |
|
|
codebook_size=8192, |
|
|
): |
|
|
super().__init__() |
|
|
self.channels = channels |
|
|
self.encoder = Encoder( |
|
|
input_size, in_channels, num_blocks, channels, kernel_sizes, block_types |
|
|
) |
|
|
|
|
|
|
|
|
def forward(self, video, mask=None, device="cuda"): |
|
|
hidden_states = self.encoder(video) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for i in range(len(hidden_states)): |
|
|
hidden_states[i] = rearrange(hidden_states[i], "b d h w t -> b t h w d") |
|
|
hidden_states[i], _ = pack([hidden_states[i]], "b * d") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return hidden_states |
|
|
|
|
|
|
|
|
def decomp_nano( |
|
|
input_size=(512, 512, 256), |
|
|
|
|
|
): |
|
|
|
|
|
model = DecompModel( |
|
|
input_size=input_size, |
|
|
num_blocks=[1, 1, 1, 1, 1], |
|
|
channels=[32, 32, 64, 128, 256], |
|
|
) |
|
|
return model |
|
|
|
|
|
|
|
|
def decomp_naive( |
|
|
input_size=(512, 512, 256), |
|
|
|
|
|
): |
|
|
|
|
|
model = DecompModel( |
|
|
input_size=input_size, |
|
|
num_blocks=[1, 2, 2, 2, 2], |
|
|
|
|
|
channels=[32, 64, 128, 256, 512], |
|
|
) |
|
|
return model |
|
|
|
|
|
|
|
|
def decomp_tiny( |
|
|
input_size=(512, 512, 256), |
|
|
): |
|
|
|
|
|
model = DecompModel( |
|
|
input_size=input_size, |
|
|
num_blocks=[1, 2, 3, 3, 2], |
|
|
|
|
|
channels=[64, 96, 192, 384, 768], |
|
|
) |
|
|
return model |
|
|
|
|
|
|
|
|
def decomp_small( |
|
|
input_size=(512, 512, 256), |
|
|
): |
|
|
|
|
|
model = DecompModel( |
|
|
input_size=input_size, |
|
|
num_blocks=[1, 2, 3, 6, 2], |
|
|
channels=[64, 96, 192, 384, 768], |
|
|
) |
|
|
return model |
|
|
|
|
|
|
|
|
def decomp_base( |
|
|
input_size=(512, 512, 256), |
|
|
): |
|
|
|
|
|
model = DecompModel( |
|
|
input_size=input_size, |
|
|
num_blocks=[1, 2, 6, 6, 2], |
|
|
channels=[64, 128, 256, 512, 1024], |
|
|
) |
|
|
return model |
|
|
|
|
|
|
|
|
def decomp_large( |
|
|
input_size=(512, 512, 256), |
|
|
): |
|
|
|
|
|
model = DecompModel( |
|
|
input_size=input_size, |
|
|
num_blocks=[1, 2, 6, 12, 2], |
|
|
|
|
|
channels=[64, 256, 512, 1024, 2048], |
|
|
) |
|
|
return model |
|
|
|
|
|
|
|
|
class FeedForward(nn.Module): |
|
|
def __init__(self, dim, hidden_dim, dropout=0.0): |
|
|
super().__init__() |
|
|
self.net = nn.Sequential( |
|
|
nn.LayerNorm(dim), |
|
|
nn.Linear(dim, hidden_dim), |
|
|
nn.GELU(), |
|
|
nn.Dropout(dropout), |
|
|
nn.Linear(hidden_dim, dim), |
|
|
nn.Dropout(dropout), |
|
|
) |
|
|
|
|
|
def forward(self, x): |
|
|
return self.net(x) |
|
|
|
|
|
|
|
|
class Attention(nn.Module): |
|
|
def __init__(self, dim, heads=8, dim_head=64, dropout=0.0): |
|
|
super().__init__() |
|
|
inner_dim = dim_head * heads |
|
|
project_out = not (heads == 1 and dim_head == dim) |
|
|
|
|
|
self.heads = heads |
|
|
self.scale = dim_head**-0.5 |
|
|
|
|
|
self.norm = nn.LayerNorm(dim) |
|
|
self.attend = nn.Softmax(dim=-1) |
|
|
self.dropout = nn.Dropout(dropout) |
|
|
|
|
|
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) |
|
|
|
|
|
self.to_out = ( |
|
|
nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) |
|
|
if project_out |
|
|
else nn.Identity() |
|
|
) |
|
|
|
|
|
def forward(self, x): |
|
|
x = self.norm(x) |
|
|
qkv = self.to_qkv(x).chunk(3, dim=-1) |
|
|
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) |
|
|
|
|
|
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale |
|
|
|
|
|
attn = self.attend(dots) |
|
|
attn = self.dropout(attn) |
|
|
|
|
|
out = torch.matmul(attn, v) |
|
|
out = rearrange(out, "b h n d -> b n (h d)") |
|
|
return self.to_out(out) |
|
|
|
|
|
|
|
|
class Transformer(nn.Module): |
|
|
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.0): |
|
|
super().__init__() |
|
|
self.layers = nn.ModuleList([]) |
|
|
for _ in range(depth): |
|
|
self.layers.append( |
|
|
nn.ModuleList( |
|
|
[ |
|
|
Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout), |
|
|
FeedForward(dim, mlp_dim, dropout=dropout), |
|
|
] |
|
|
) |
|
|
) |
|
|
|
|
|
def forward(self, x): |
|
|
for attn, ff in self.layers: |
|
|
x = attn(x) + x |
|
|
x = ff(x) + x |
|
|
return x |
|
|
|
|
|
|
|
|
class ViTEncoder(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
image_size=[512, 512, 256], |
|
|
patch_size=16, |
|
|
dim=512, |
|
|
depth=8, |
|
|
heads=8, |
|
|
mlp_dim=4, |
|
|
channels=1, |
|
|
dim_head=64, |
|
|
dropout=0.0, |
|
|
emb_dropout=0.0, |
|
|
): |
|
|
super().__init__() |
|
|
h, w, t = image_size[0], image_size[1], image_size[2] |
|
|
|
|
|
self.vit_img_dim = [i // patch_size for i in image_size] |
|
|
num_patches = (h // patch_size) * (w // patch_size) * (t // patch_size) |
|
|
|
|
|
patch_dim = channels * patch_size * patch_size * patch_size |
|
|
|
|
|
self.to_patch_embedding = nn.Sequential( |
|
|
Rearrange( |
|
|
"b c (h p1) (w p2) (t p3) -> b (h w t) (p1 p2 p3 c)", |
|
|
p1=patch_size, |
|
|
p2=patch_size, |
|
|
p3=patch_size, |
|
|
), |
|
|
nn.LayerNorm(patch_dim), |
|
|
nn.Linear(patch_dim, dim), |
|
|
nn.LayerNorm(dim), |
|
|
) |
|
|
|
|
|
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) |
|
|
self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) |
|
|
self.dropout = nn.Dropout(emb_dropout) |
|
|
|
|
|
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) |
|
|
|
|
|
def forward(self, x): |
|
|
x = x.permute(0, 1, 3, 4, 2) |
|
|
|
|
|
x = self.to_patch_embedding(x) |
|
|
b, n, _ = x.shape |
|
|
|
|
|
cls_tokens = repeat(self.cls_token, "1 1 d -> b 1 d", b=b) |
|
|
x = torch.cat((cls_tokens, x), dim=1) |
|
|
x += self.pos_embedding[:, : (n + 1)] |
|
|
x = self.dropout(x) |
|
|
|
|
|
x = self.transformer(x) |
|
|
x = x[:, 1:, :] |
|
|
x = rearrange( |
|
|
x, |
|
|
"b (x y z) c -> b c x y z", |
|
|
x=self.vit_img_dim[0], |
|
|
y=self.vit_img_dim[1], |
|
|
z=self.vit_img_dim[2], |
|
|
) |
|
|
|
|
|
return x |
|
|
|
|
|
|
|
|
class Vit3D(nn.Module): |
|
|
def __init__(self, input_size=[512, 512, 256], patch_size=32, dim=512, depth=8): |
|
|
super().__init__() |
|
|
|
|
|
self.encoder = ViTEncoder(input_size, patch_size, dim, depth) |
|
|
|
|
|
|
|
|
|
|
|
def forward(self, video, mask=None, device="cuda"): |
|
|
tokens = self.encoder(video) |
|
|
tokens = rearrange(tokens, "b d h w t -> b t h w d") |
|
|
shape = tokens.shape |
|
|
*_, h, w, _ = shape |
|
|
|
|
|
tokens, _ = pack([tokens], "b * d") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return tokens |
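
# With the DEC_CLIPConfig default input_size of (256, 256, 128) and the default
# patch_size of 32, the ViT encoder produces 8 * 8 * 4 = 256 patch tokens of width
# `dim`; the CLS token is discarded inside ViTEncoder before the tokens are flattened
# here to (batch, tokens, dim).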
|
|
|
|
|
|
|
|
def build_vision_tower(config, **kwargs): |
|
|
return VisionTower(config) |
|
|
|
|
|
|
|
|
class VisionTower(nn.Module): |
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
|
|
|
self.config = config |
|
|
self.select_layer = config.vision_select_layer |
|
|
self.select_feature = config.vision_select_feature |
|
|
self.hidden_size = config.dim |
|
|
|
|
|
if config.vision_tower == "vit3d": |
|
|
self.vision_tower = Vit3D( |
|
|
input_size=config.input_size, |
|
|
dim=config.dim, |
|
|
depth=config.depth, |
|
|
) |
|
|
elif config.vision_tower == "dcformer": |
|
|
self.vision_tower = decomp_small( |
|
|
input_size=config.input_size, |
|
|
) |
|
|
self.low_input_size = self.vision_tower.channels[-2] |
|
|
self.high_input_size = self.vision_tower.channels[-1] |
|
|
else: |
|
|
raise ValueError(f"Unexpected vision tower: {config.vision_tower}") |
|
|
|
|
|
def forward(self, images): |
|
|
hidden_states = self.vision_tower(images) |
|
|
if self.select_layer == 0: |
|
|
image_features = hidden_states |
|
|
elif self.select_layer < 0: |
|
|
image_features = hidden_states[self.select_layer :] |
|
|
else: |
|
|
raise ValueError(f"Unexpected select layer: {self.select_layer}") |
|
|
|
|
|
if self.select_feature == "patch": |
|
|
image_features = image_features[:, 1:] |
|
|
elif self.select_feature == "cls_patch": |
|
|
image_features = image_features |
|
|
else: |
|
|
raise ValueError(f"Unexpected select feature: {self.select_feature}") |
|
|
|
|
|
return image_features |
|
|
|
|
|
@property |
|
|
def dtype(self): |
|
|
return self.vision_tower.dtype |
|
|
|
|
|
@property |
|
|
def device(self): |
|
|
return self.vision_tower.device |
|
|
|
|
|
|
|
|
def readable_params(num): |
|
|
magnitude = 0 |
|
|
while abs(num) >= 1000: |
|
|
magnitude += 1 |
|
|
num /= 1000.0 |
|
|
return "%.2f%s" % (num, ["", "K", "M", "G", "T", "P"][magnitude]) |
|
|
|
|
|
|
|
|
class MLPLayer(nn.Module): |
|
|
def __init__(self, embed_dim, scale=4, *args, **kwargs): |
|
|
super().__init__(*args, **kwargs) |
|
|
self.linear1 = nn.Linear(embed_dim, embed_dim * scale) |
|
|
self.linear2 = nn.Linear(embed_dim * scale, embed_dim) |
|
|
self.act = nn.GELU() |
|
|
|
|
|
def forward(self, x): |
|
|
x = self.linear1(x) |
|
|
x = self.act(x) |
|
|
x = self.linear2(x) |
|
|
return x |
|
|
|
|
|
|
|
|
class MultiHeadSelfAttention(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
embed_dim, |
|
|
output_dim, |
|
|
num_heads=8, |
|
|
proj_out_num=32, |
|
|
): |
|
|
super(MultiHeadSelfAttention, self).__init__() |
|
|
self.embed_dim = embed_dim |
|
|
self.num_heads = num_heads |
|
|
self.head_dim = embed_dim // num_heads |
|
|
self.proj_out_num = proj_out_num |
|
|
self.mlp = MLPLayer(embed_dim) |
|
|
|
|
|
assert ( |
|
|
self.head_dim * num_heads == embed_dim |
|
|
), "embed_dim must be divisible by num_heads" |
|
|
|
|
|
self.norm1 = nn.LayerNorm(embed_dim) |
|
|
|
|
|
self.q_proj = nn.Linear(embed_dim, embed_dim) |
|
|
self.k_proj = nn.Linear(embed_dim, embed_dim) |
|
|
self.v_proj = nn.Linear(embed_dim, embed_dim) |
|
|
|
|
|
self.out_proj = nn.Linear(embed_dim, embed_dim) |
|
|
|
|
|
self.norm2 = nn.LayerNorm(embed_dim) |
|
|
self.act = nn.GELU() |
|
|
self.out_layer = nn.Linear(embed_dim, output_dim) |
|
|
|
|
|
self.scale = math.sqrt(self.head_dim) |
|
|
|
|
|
def forward(self, x): |
|
|
batch_size, seq_len, _ = x.size() |
|
|
|
|
|
x = self.norm1(x) |
|
|
|
|
|
q = ( |
|
|
self.q_proj(x) |
|
|
.reshape(batch_size, seq_len, self.num_heads, self.head_dim) |
|
|
.transpose(1, 2) |
|
|
) |
|
|
k = ( |
|
|
self.k_proj(x) |
|
|
.reshape(batch_size, seq_len, self.num_heads, self.head_dim) |
|
|
.transpose(1, 2) |
|
|
) |
|
|
v = ( |
|
|
self.v_proj(x) |
|
|
.reshape(batch_size, seq_len, self.num_heads, self.head_dim) |
|
|
.transpose(1, 2) |
|
|
) |
|
|
|
|
|
attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self.scale |
|
|
|
|
|
attn_weights = F.softmax(attn_weights, dim=-1) |
|
|
|
|
|
attn_output = torch.matmul(attn_weights, v) |
|
|
|
|
|
attn_output = attn_output.transpose(1, 2).reshape( |
|
|
batch_size, seq_len, self.embed_dim |
|
|
) |
|
|
output = self.out_proj(attn_output) |
|
|
|
|
|
output = self.norm2(output) |
|
|
|
|
|
output = self.mlp(output) |
|
|
|
|
|
output = self.act(output) |
|
|
output = self.out_layer(output) |
|
|
|
|
|
return output |
|
|
|
|
|
|
|
|
class MultiLayerPerceptron(nn.Module): |
|
|
def __init__(self, hidden_size, depth): |
|
|
super().__init__() |
|
|
self.mlp = nn.Sequential( |
|
|
nn.Linear(hidden_size, hidden_size), |
|
|
*[ |
|
|
nn.Sequential(nn.GELU(), nn.Linear(hidden_size, hidden_size)) |
|
|
for _ in range(depth - 1) |
|
|
], |
|
|
) |
|
|
|
|
|
def forward(self, x): |
|
|
return self.mlp(x) |
|
|
|
|
|
|
|
|
class MultiModalProjector(nn.Module): |
|
|
def __init__(self, input_size, output_size, mlp_depth, proj_out_num=256): |
|
|
super().__init__() |
|
|
self.proj_out_num = proj_out_num |
|
|
self.mm_projector = nn.Sequential( |
|
|
nn.Linear(input_size, output_size), |
|
|
*[ |
|
|
nn.Sequential( |
|
|
nn.GELU(), |
|
|
nn.Linear(output_size, output_size), |
|
|
) |
|
|
for _ in range(mlp_depth - 1) |
|
|
], |
|
|
) |
|
|
|
|
|
def forward(self, x): |
|
|
return self.mm_projector(x) |
|
|
|
|
|
|
|
|
class LowHighHybridMLP(nn.Module): |
|
|
def __init__( |
|
|
self, low_input_size, high_input_size, output_size, mlp_depth, proj_out_num=288 |
|
|
): |
|
|
super().__init__() |
|
|
self.proj_out_num = proj_out_num |
|
|
self.low_up_mlp = nn.Linear(low_input_size, output_size) |
|
|
self.high_up_mlp = nn.Linear(high_input_size, output_size) |
|
|
modules = [] |
|
|
for _ in range(1, mlp_depth): |
|
|
modules.append(nn.GELU()) |
|
|
modules.append(nn.Linear(output_size, output_size)) |
|
|
self.mm_projector = nn.Sequential(*modules) |
|
|
|
|
|
def forward(self, x): |
|
|
low_x, high_x = x |
|
|
|
|
|
low_x = self.low_up_mlp(low_x) |
|
|
high_x = self.high_up_mlp(high_x) |
|
|
x = torch.cat([low_x, high_x], dim=1) |
|
|
|
|
|
x = self.mm_projector(x) |
|
|
|
|
|
return x |
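
# With the dcformer backbone on a 256 x 256 x 128 volume (an assumed configuration),
# low_x is the stage-4 map flattened to 256 tokens x 384 channels and high_x the
# stage-5 map flattened to 32 tokens x 768 channels; both are projected to the LLM
# hidden size and concatenated along the token axis, giving the 288 tokens implied
# by the default proj_out_num.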
|
|
|
|
|
|
|
|
class MixerLayer(nn.Module): |
|
|
def __init__(self, input_size, output_size, mlp_depth=2): |
|
|
super().__init__() |
|
|
self.ln1 = nn.LayerNorm(input_size[1]) |
|
|
self.ln2 = nn.LayerNorm(input_size[1]) |
|
|
|
|
|
self.mlp1 = MultiModalProjector( |
|
|
input_size=input_size[0], output_size=output_size[0], mlp_depth=mlp_depth |
|
|
) |
|
|
self.mlp2 = MultiModalProjector( |
|
|
input_size=input_size[1], output_size=output_size[1], mlp_depth=mlp_depth |
|
|
) |
|
|
|
|
|
def forward(self, x): |
|
|
x = self.ln1(x) |
|
|
x = rearrange(x, "b n d -> b d n") |
|
|
x = self.mlp1(x) |
|
|
x = rearrange(x, "b d n -> b n d") |
|
|
x = self.ln2(x) |
|
|
x = self.mlp2(x) |
|
|
|
|
|
return x |
|
|
|
|
|
|
|
|
class MixerLowHighHybridMLP(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
low_input_size: tuple = (256, 384), |
|
|
low_output_size: list = [192, 128], |
|
|
high_input_size: tuple = (32, 768), |
|
|
high_output_size: list = [64, 128], |
|
|
output_dim=3584, |
|
|
depth=2, |
|
|
mlp_depth=2, |
|
|
proj_out_num=256, |
|
|
): |
|
|
assert ( |
|
|
len(low_output_size) == len(high_output_size) == depth |
|
|
), "Output size must be same for both low and high input" |
|
|
assert output_dim % (2**depth) == 0, "Output dim must be divisible by 2**depth" |
|
|
|
|
|
super().__init__() |
|
|
|
|
|
self.proj_out_num = proj_out_num |
|
|
|
|
|
self.low_mixer = nn.ModuleList( |
|
|
[ |
|
|
MixerLayer( |
|
|
input_size=( |
|
|
(low_output_size[i - 1], output_dim // (2 ** (depth - i))) |
|
|
if i > 0 |
|
|
else low_input_size |
|
|
), |
|
|
output_size=( |
|
|
low_output_size[i], |
|
|
output_dim // (2 ** (depth - i - 1)), |
|
|
), |
|
|
mlp_depth=mlp_depth, |
|
|
) |
|
|
for i in range(depth) |
|
|
] |
|
|
) |
|
|
self.high_mixer = nn.ModuleList( |
|
|
[ |
|
|
MixerLayer( |
|
|
input_size=( |
|
|
(high_output_size[i - 1], output_dim // (2 ** (depth - i))) |
|
|
if i > 0 |
|
|
else high_input_size |
|
|
), |
|
|
output_size=( |
|
|
high_output_size[i], |
|
|
output_dim // (2 ** (depth - i - 1)), |
|
|
), |
|
|
mlp_depth=mlp_depth, |
|
|
) |
|
|
for i in range(depth) |
|
|
] |
|
|
) |
|
|
|
|
|
def forward(self, x): |
|
|
low_x, high_x = x |
|
|
for low_layer, high_layer in zip(self.low_mixer, self.high_mixer): |
|
|
low_x = low_layer(low_x) |
|
|
high_x = high_layer(high_x) |
|
|
x = torch.cat([low_x, high_x], dim=1) |
|
|
|
|
|
return x |
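
# Shape trace with the defaults (output_dim=3584, depth=2):
#   low  : (256, 384) -> (192, 1792) -> (128, 3584)
#   high : (32, 768)  -> (64, 1792)  -> (128, 3584)
# so the concatenated output is 256 tokens of width 3584, matching proj_out_num=256
# (3584 being, for example, the hidden size of Qwen2-7B).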
|
|
|
|
|
|
|
|
class IdentityMap(nn.Module): |
|
|
def __init__(self): |
|
|
super().__init__() |
|
|
|
|
|
def forward(self, x, *args, **kwargs): |
|
|
return x |
|
|
|
|
|
@property |
|
|
def config(self): |
|
|
return {"mm_projector_type": "identity"} |
|
|
|
|
|
|
|
|
def build_mm_projector(config, delay_load=False, **kwargs): |
|
|
projector_type = getattr(config, "mm_projector_type", "linear") |
|
|
|
|
|
if projector_type == "linear": |
|
|
return nn.Linear(config.mm_hidden_size, config.hidden_size) |
|
|
elif projector_type == "mlp": |
|
|
return MultiModalProjector( |
|
|
input_size=config.mm_hidden_size, |
|
|
output_size=config.hidden_size, |
|
|
mlp_depth=config.mm_mlp_depth, |
|
|
proj_out_num=config.proj_out_num, |
|
|
) |
|
|
elif projector_type == "low_high_mlp": |
|
|
return LowHighHybridMLP( |
|
|
low_input_size=config.low_input_size, |
|
|
high_input_size=config.high_input_size, |
|
|
output_size=config.hidden_size, |
|
|
mlp_depth=config.mm_mlp_depth, |
|
|
proj_out_num=config.proj_out_num, |
|
|
) |
|
|
elif projector_type == "mixer": |
|
|
return MixerLowHighHybridMLP( |
|
|
low_input_size=config.low_input_size, |
|
|
low_output_size=config.low_output_size, |
|
|
high_input_size=config.high_input_size, |
|
|
high_output_size=config.high_output_size, |
|
|
output_dim=config.hidden_size, |
|
|
depth=len(config.low_output_size), |
|
|
mlp_depth=config.mm_mlp_depth, |
|
|
proj_out_num=config.proj_out_num, |
|
|
) |
|
|
elif projector_type == "mhsa": |
|
|
return MultiHeadSelfAttention( |
|
|
embed_dim=config.mm_hidden_size, |
|
|
output_dim=config.hidden_size, |
|
|
            num_heads=getattr(config, "num_heads", 8),
|
|
proj_out_num=config.proj_out_num, |
|
|
) |
|
|
elif projector_type == "identity": |
|
|
return IdentityMap() |
|
|
else: |
|
|
raise ValueError(f"Unknown projector type: {projector_type}") |
|
|
|
|
|
|
|
|
class VLMMetaModel: |
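
    # Mix-in intended to sit in front of a transformers PreTrainedModel subclass
    # (see VLMQwenModel below): super().__init__(config) is forwarded to that model,
    # and the vision tower / multimodal projector are attached once the config
    # carries the relevant fields.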
|
|
|
|
|
def __init__(self, config): |
|
|
super(VLMMetaModel, self).__init__(config) |
|
|
|
|
|
if hasattr(config, "vision_tower"): |
|
|
self.vision_tower = build_vision_tower(config, delay_load=True) |
|
|
self.mm_projector = build_mm_projector(config) |
|
|
|
|
|
def get_vision_tower(self): |
|
|
vision_tower = getattr(self, "vision_tower", None) |
|
|
if type(vision_tower) is list: |
|
|
vision_tower = vision_tower[0] |
|
|
return vision_tower |
|
|
|
|
|
def initialize_vision_modules(self, model_args): |
|
|
self.config.input_size = model_args.input_size |
|
|
self.config.patch_size = model_args.patch_size |
|
|
self.config.dim = model_args.dim |
|
|
self.config.depth = model_args.depth |
|
|
|
|
|
self.config.vision_tower = model_args.vision_tower |
|
|
self.config.vision_select_layer = model_args.vision_select_layer |
|
|
self.config.vision_select_feature = model_args.vision_select_feature |
|
|
|
|
|
self.config.mm_projector_type = model_args.mm_projector_type |
|
|
self.config.mm_mlp_depth = model_args.mm_mlp_depth |
|
|
self.config.proj_out_num = model_args.proj_out_num |
|
|
|
|
|
|
|
|
if self.get_vision_tower() is None: |
|
|
self.vision_tower = build_vision_tower(self.config) |
|
|
self.vision_tower.requires_grad_(not model_args.freeze_vision_tower) |
|
|
|
|
|
if self.config.vision_tower == "hybrid": |
|
|
self.config.low_input_size = self.vision_tower.low_input_size |
|
|
self.config.high_input_size = self.vision_tower.high_input_size |
|
|
elif self.config.mm_projector_type == "mixer": |
|
|
self.config.low_output_size = model_args.low_output_size |
|
|
self.config.high_output_size = model_args.high_output_size |
|
|
self.config.low_input_size = (256, 384) |
|
|
self.config.high_input_size = (32, 768) |
|
|
|
|
|
if model_args.pretrain_vision_model is not None: |
|
|
vision_model_weights = torch.load( |
|
|
model_args.pretrain_vision_model, map_location="cpu" |
|
|
) |
|
|
self.vision_tower.vision_tower.load_state_dict( |
|
|
vision_model_weights, strict=True |
|
|
) |
|
|
|
|
|
if model_args.pretrain_clip_model is not None: |
|
|
clip_model = AutoModel.from_pretrained(model_args.pretrain_clip_model) |
|
|
self.vision_tower.vision_tower = clip_model.vision_encoder |
|
|
|
|
|
self.config.mm_hidden_size = self.vision_tower.hidden_size |
|
|
|
|
|
|
|
|
if getattr(self, "mm_projector", None) is None: |
|
|
self.mm_projector = build_mm_projector(self.config) |
|
|
|
|
|
if model_args.pretrain_mm_mlp_adapter is not None: |
|
|
mm_projector_weights = torch.load( |
|
|
model_args.pretrain_mm_mlp_adapter, map_location="cpu" |
|
|
) |
|
|
|
|
|
if self.config.mm_projector_type == "mlp": |
|
|
|
|
|
def get_w(weights, keyword): |
|
|
                    # single-quote the split key so this also parses on Python < 3.12
                    return {
                        f"{keyword}.{k.split(keyword + '.')[2]}": v
                        for k, v in weights.items()
                        if keyword in k
                    }
|
|
|
|
|
elif self.config.mm_projector_type == "low_high_mlp": |
|
|
|
|
|
def get_w(weights, keyword): |
|
|
result = {} |
|
|
for k, v in weights.items(): |
|
|
if keyword in k: |
|
|
if f"{keyword}.{keyword}" in k: |
|
|
part = k.split(f"{keyword}.{keyword}.")[1] |
|
|
result[f"mm_projector.{part}"] = v |
|
|
elif f"{keyword}." in k: |
|
|
part = k.split(f"{keyword}.")[1] |
|
|
result[part] = v |
|
|
return result |
|
|
|
|
|
elif self.config.mm_projector_type == "mixer": |
|
|
|
|
|
def get_w(weights, keyword): |
|
|
result = {} |
|
|
for k, v in weights.items(): |
|
|
if keyword in k: |
|
|
new_key = k.split(".") |
|
|
if len(new_key) > 2: |
|
|
new_key = ".".join(new_key[2:]) |
|
|
result[new_key] = v |
|
|
return result |
|
|
|
|
|
else: |
|
|
|
|
|
def get_w(weights, keyword): |
|
|
result = {} |
|
|
for k, v in weights.items(): |
|
|
if keyword in k: |
|
|
new_key = k.split(".") |
|
|
if len(new_key) > 2: |
|
|
new_key = ".".join(new_key[2:]) |
|
|
result[new_key] = v |
|
|
return result |
|
|
|
|
|
self.mm_projector.load_state_dict( |
|
|
get_w(mm_projector_weights, "mm_projector"), strict=True |
|
|
) |
|
|
|
|
|
|
|
|
class VLMMetaForCausalLM(ABC): |
|
|
@abstractmethod |
|
|
def get_model(self): |
|
|
pass |
|
|
|
|
|
def get_vision_tower(self): |
|
|
return self.get_model().get_vision_tower() |
|
|
|
|
|
def encode_images(self, images): |
|
|
image_features = self.get_model().get_vision_tower()(images) |
|
|
image_features = self.get_model().mm_projector(image_features) |
|
|
return image_features |
|
|
|
|
|
def prepare_inputs_for_multimodal( |
|
|
self, |
|
|
input_ids, |
|
|
position_ids, |
|
|
attention_mask, |
|
|
past_key_values, |
|
|
labels, |
|
|
images, |
|
|
): |
|
|
vision_tower = self.get_vision_tower() |
|
|
if vision_tower is None or images is None or input_ids.shape[1] == 1: |
|
|
return ( |
|
|
input_ids, |
|
|
position_ids, |
|
|
attention_mask, |
|
|
past_key_values, |
|
|
None, |
|
|
labels, |
|
|
) |
|
|
else: |
|
|
image_features = self.encode_images(images) |
|
|
inputs_embeds = self.get_model().embed_tokens(input_ids) |
|
|
inputs_embeds = torch.cat( |
|
|
( |
|
|
inputs_embeds[:, :1, :], |
|
|
image_features, |
|
|
inputs_embeds[:, (image_features.shape[1] + 1) :, :], |
|
|
), |
|
|
dim=1, |
|
|
) |
|
|
return ( |
|
|
None, |
|
|
position_ids, |
|
|
attention_mask, |
|
|
past_key_values, |
|
|
inputs_embeds, |
|
|
labels, |
|
|
) |
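
    # The splice above assumes the prompt reserves exactly image_features.shape[1]
    # placeholder positions immediately after the first (BOS) token: their embeddings
    # are overwritten by the projected image tokens and input_ids is returned as None,
    # so the language model consumes inputs_embeds directly.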
|
|
|
|
|
def initialize_vision_tokenizer(self, model_args, tokenizer): |
|
|
num_new_tokens = model_args.num_new_tokens |
|
|
|
|
|
self.resize_token_embeddings(len(tokenizer)) |
|
|
|
|
|
if num_new_tokens > 0: |
|
|
input_embeddings = self.get_input_embeddings().weight.data |
|
|
output_embeddings = self.get_output_embeddings().weight.data |
|
|
|
|
|
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( |
|
|
dim=0, keepdim=True |
|
|
) |
|
|
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( |
|
|
dim=0, keepdim=True |
|
|
) |
|
|
|
|
|
input_embeddings[-num_new_tokens:] = input_embeddings_avg |
|
|
output_embeddings[-num_new_tokens:] = output_embeddings_avg |
|
|
|
|
|
if model_args.tune_mm_mlp_adapter: |
|
|
for p in self.get_input_embeddings().parameters(): |
|
|
p.requires_grad = True |
|
|
for p in self.get_output_embeddings().parameters(): |
|
|
p.requires_grad = False |
|
|
else: |
|
|
for p in self.get_input_embeddings().parameters(): |
|
|
p.requires_grad = True |
|
|
for p in self.get_output_embeddings().parameters(): |
|
|
p.requires_grad = True |
|
|
|
|
|
if model_args.pretrain_mm_mlp_adapter: |
|
|
mm_projector_weights = torch.load( |
|
|
model_args.pretrain_mm_mlp_adapter, map_location="cpu" |
|
|
) |
|
|
|
|
|
embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"] |
|
|
|
|
|
            if input_embeddings.shape == embed_tokens_weight.shape:
                # Copy in place so the model's embedding matrix is actually updated
                # (rebinding the local name would leave the weights untouched).
                input_embeddings.copy_(embed_tokens_weight)
            elif embed_tokens_weight.shape[0] == num_new_tokens:
                input_embeddings[-num_new_tokens:] = embed_tokens_weight
|
|
else: |
|
|
raise ValueError( |
|
|
f"Unexpected embed_tokens_weight shape. " |
|
|
f"Pretrained: {embed_tokens_weight.shape}. " |
|
|
f"Current: {input_embeddings.shape}. " |
|
|
f"Number of new tokens: {num_new_tokens}." |
|
|
) |
|
|
|
|
|
|
|
|
class VLMQwenConfig(Qwen2Config): |
|
|
model_type = "vlm_qwen" |
|
|
|
|
|
|
|
|
class VLMQwenModel(VLMMetaModel, Qwen2Model): |
|
|
config_class = VLMQwenConfig |
|
|
|
|
|
def __init__(self, config: Qwen2Config): |
|
|
super(VLMQwenModel, self).__init__(config) |
|
|
|
|
|
|
|
|
class VLMQwenForCausalLM(Qwen2ForCausalLM, VLMMetaForCausalLM): |
|
|
config_class = VLMQwenConfig |
|
|
|
|
|
def __init__(self, config): |
|
|
super(Qwen2ForCausalLM, self).__init__(config) |
|
|
self.model = VLMQwenModel(config) |
|
|
self.pretraining_tp = getattr(config, "pretraining_tp", None) |
|
|
self.vocab_size = config.vocab_size |
|
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) |
|
|
self.post_init() |
|
|
|
|
|
def get_model(self): |
|
|
return self.model |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_ids: torch.LongTensor = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.LongTensor] = None, |
|
|
past_key_values: Optional[List[torch.FloatTensor]] = None, |
|
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
|
labels: Optional[torch.LongTensor] = None, |
|
|
use_cache: Optional[bool] = None, |
|
|
output_attentions: Optional[bool] = None, |
|
|
output_hidden_states: Optional[bool] = None, |
|
|
images: Optional[torch.FloatTensor] = None, |
|
|
image_sizes: Optional[List[List[int]]] = None, |
|
|
return_dict: Optional[bool] = None, |
|
|
num_logits_to_keep: Optional[int] = None, |
|
|
**kwargs: Any, |
|
|
) -> Union[Tuple, CausalLMOutputWithPast]: |
|
|
|
|
|
if inputs_embeds is None: |
|
|
( |
|
|
input_ids, |
|
|
position_ids, |
|
|
attention_mask, |
|
|
past_key_values, |
|
|
inputs_embeds, |
|
|
labels, |
|
|
) = self.prepare_inputs_for_multimodal( |
|
|
input_ids, position_ids, attention_mask, past_key_values, labels, images |
|
|
) |
|
|
|
|
|
return super().forward( |
|
|
input_ids=input_ids, |
|
|
attention_mask=attention_mask, |
|
|
position_ids=position_ids, |
|
|
past_key_values=past_key_values, |
|
|
inputs_embeds=inputs_embeds, |
|
|
labels=labels, |
|
|
use_cache=use_cache, |
|
|
output_attentions=output_attentions, |
|
|
output_hidden_states=output_hidden_states, |
|
|
return_dict=return_dict, |
|
|
) |
|
|
|
|
|
@torch.no_grad() |
|
|
def generate( |
|
|
self, |
|
|
images: Optional[torch.Tensor] = None, |
|
|
inputs: Optional[torch.Tensor] = None, |
|
|
**kwargs, |
|
|
) -> Union[GenerateOutput, torch.LongTensor, Any]: |
|
|
position_ids = kwargs.pop("position_ids", None) |
|
|
attention_mask = kwargs.pop("attention_mask", None) |
|
|
if "inputs_embeds" in kwargs: |
|
|
raise NotImplementedError("`inputs_embeds` is not supported") |
|
|
|
|
|
if images is not None: |
|
|
(inputs, position_ids, attention_mask, _, inputs_embeds, _) = ( |
|
|
self.prepare_inputs_for_multimodal( |
|
|
inputs, |
|
|
position_ids, |
|
|
attention_mask, |
|
|
None, |
|
|
None, |
|
|
images, |
|
|
) |
|
|
) |
|
|
else: |
|
|
inputs_embeds = self.get_model().embed_tokens(inputs) |
|
|
|
|
|
output_ids = super().generate(inputs_embeds=inputs_embeds, **kwargs) |
|
|
return output_ids |
|
|
|
|
|
def prepare_inputs_for_generation( |
|
|
self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs |
|
|
): |
|
|
images = kwargs.pop("images", None) |
|
|
inputs = super().prepare_inputs_for_generation( |
|
|
input_ids, |
|
|
past_key_values=past_key_values, |
|
|
inputs_embeds=inputs_embeds, |
|
|
**kwargs, |
|
|
) |
|
|
if images is not None: |
|
|
inputs["images"] = images |
|
|
return inputs |
|
|
|
|
|
|
|
|
AutoConfig.register("vlm_qwen", VLMQwenConfig) |
|
|
AutoModelForCausalLM.register(VLMQwenConfig, VLMQwenForCausalLM) |
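

# ---------------------------------------------------------------------------
# Minimal smoke test — an illustrative sketch, not part of the training pipeline.
# It exercises only the DCFormer backbone (no pretrained checkpoints needed) and
# assumes a volume whose sides are divisible by 64 (stem stride 4 followed by
# four stride-2 stages).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = decomp_nano(input_size=(64, 64, 64))
    model.eval()
    with torch.no_grad():
        volume = torch.randn(1, 1, 64, 64, 64)  # (batch, channel, D, H, W)
        stage_features = model(volume)
    for stage, feats in enumerate(stage_features):
        # DecompModel.forward flattens each stage to (batch, tokens, channels),
        # e.g. (1, 4096, 32) for stage 0 down to (1, 1, 256) for stage 4 here.
        print(f"stage {stage}: {tuple(feats.shape)}")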
|
|
|