import torch
import torch.nn as nn
import torch.nn.functional as F

from functools import partial
from typing import List, Optional, Tuple, Union

from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from transformers import (
    Qwen2Config,
    Qwen2Model,
    Qwen2ForCausalLM,
)
from transformers.cache_utils import Cache
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)

from .configuration_gex import GexConfig


# LayerNorm variant used inside the ViT blocks.
LayerNorm = partial(nn.LayerNorm, eps=1e-6)


class GexImageEvalProcessor:
    """Resize to a fixed square, convert to tensor, and normalize (CLIP mean/std by default)."""

    def __init__(self, image_size=1024, mean=None, std=None):
        if mean is None:
            mean = (0.48145466, 0.4578275, 0.40821073)
        if std is None:
            std = (0.26862954, 0.26130258, 0.27577711)

        self.normalize = transforms.Normalize(mean, std)

        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (image_size, image_size), interpolation=InterpolationMode.BICUBIC
                ),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def __call__(self, item):
        return self.transform(item)
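

# Illustrative usage of the processor (a sketch; the file name and variable names
# are assumptions, not part of this module):
#
#   from PIL import Image
#   processor = GexImageEvalProcessor()
#   pixel_values = processor(Image.open("example.jpg").convert("RGB")).unsqueeze(0)
#   # pixel_values: float tensor of shape (1, 3, 1024, 1024)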


class LayerNorm2d(nn.Module):
    """LayerNorm over the channel dimension of an NCHW tensor."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.num_channels = num_channels
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Move channels last, normalize, then restore NCHW.
        x = x.permute(0, 2, 3, 1)
        return F.layer_norm(
            x,
            normalized_shape=(self.num_channels,),
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        ).permute(0, 3, 1, 2)


class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x


class Attention(nn.Module):
    """Multi-head attention with decomposed (per-axis) relative position bias."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = 64
        self.scale = 64**-0.5  # unused at runtime; SDPA applies 1/sqrt(head_dim) itself
        self.seq_len = input_size[0] * input_size[1]
        self.input_size = input_size

        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

        # Per-axis relative-position tables (height and width).
        self.rel_pos_h = nn.Parameter(torch.zeros(input_size[0], input_size[0], self.head_dim))
        self.rel_pos_w = nn.Parameter(torch.zeros(input_size[1], input_size[1], self.head_dim))

    def init_rel_pos(self):
        # Re-index the relative-position tables from offset indexing into explicit
        # (query pos, key pos) lookups so get_attn_bias can use them directly.
        # This assumes the tables currently hold offset-indexed entries of length
        # 2*size - 1 (e.g. as loaded from a pretrained checkpoint).
        q_size, k_size = self.input_size
        q_coords = torch.arange(q_size)[:, None]
        k_coords = torch.arange(k_size)[None, :]
        relative_coords = (q_coords - k_coords) + (k_size - 1)

        self.rel_pos_h = nn.Parameter(self.rel_pos_h.data[relative_coords.long()])
        self.rel_pos_w = nn.Parameter(self.rel_pos_w.data[relative_coords.long()])

    def get_attn_bias(self, q: torch.Tensor):
        q = q.view(-1, *self.input_size, 64)

        rel_h = torch.einsum("bhwc,hkc->bhwk", q, self.rel_pos_h)
        rel_w = torch.einsum("bhwc,wkc->bhwk", q, self.rel_pos_w)

        return (rel_h.unsqueeze(-1) + rel_w.unsqueeze(-2)).reshape(
            -1, self.num_heads, self.seq_len, self.seq_len
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Project to q, k, v: three tensors of shape (B, seq_len, 768).
        qkv = torch.split(
            self.qkv(x).view(-1, self.seq_len, 3 * 768),
            768,
            dim=2,
        )

        q, k, v = (
            i.unflatten(-1, (self.num_heads, -1)).transpose(1, 2).contiguous()
            for i in qkv
        )

        attn_bias = self.get_attn_bias(q)

        attn_output = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_bias, is_causal=False
        )
        attn_output = attn_output.transpose(1, 2).flatten(-2)

        x = self.proj(attn_output)

        return x.view(-1, *self.input_size, 768)
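

# Shape sketch (illustrative): with a 14x14 window, seq_len = 196, so a windowed
# block runs Attention.forward on (B * num_windows, 14, 14, 768) flattened to
# (B * num_windows, 196, 768), and get_attn_bias returns a bias of shape
# (B * num_windows, 12, 196, 196). Global blocks use the full 64x64 grid instead.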


class MLP(nn.Module):
    """Standard transformer MLP: 768 -> 3072 -> 768 with GELU."""

    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(768, 4 * 768)
        self.lin2 = nn.Linear(4 * 768, 768)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin2(self.act(self.lin1(x)))


class Block(nn.Module):
    """ViT block: pre-norm attention (windowed or global) followed by a pre-norm MLP."""

    def __init__(self, idx: int, window_size: int = 14):
        super().__init__()

        self.idx = idx
        self.window_size = window_size

        self.norm1 = LayerNorm(768)

        self.attn = Attention(
            dim=768,
            num_heads=12,
            input_size=(64, 64) if window_size == 0 else (14, 14),
        )

        self.norm2 = LayerNorm(768)
        self.mlp = MLP()

    @staticmethod
    def window_partition(x: torch.Tensor) -> torch.Tensor:
        # Pad the 64x64 token grid to 70x70 and split it into 5x5 windows of 14x14.
        x = F.pad(x, (0, 0, 0, 6, 0, 6))
        x = (
            x.view(-1, 5, 14, 5, 14, 768)
            .permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(-1, 14, 14, 768)
        )
        return x

    @staticmethod
    def window_unpartition(x: torch.Tensor) -> torch.Tensor:
        # Reassemble the 5x5 windows into a 70x70 grid and crop the padding back off.
        x = (
            x.view(-1, 5, 5, 14, 14, 768)
            .permute(0, 1, 3, 2, 4, 5)
            .contiguous()
            .view(-1, 70, 70, 768)
        )
        return x[:, :64, :64, :].contiguous()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        x = self.norm1(x)
        if self.window_size > 0:
            x = self.window_partition(x)

        x = self.attn(x)

        if self.window_size > 0:
            x = self.window_unpartition(x)

        x = shortcut + x
        x = x + self.mlp(self.norm2(x))

        return x
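

# Round-trip sketch (illustrative): for x of shape (B, 64, 64, 768),
# Block.window_unpartition(Block.window_partition(x)) returns x unchanged,
# since the 6-token padding added before windowing is cropped off afterwards.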


class GexVit(nn.Module):
    """SAM-style ViT-B image encoder; for 1024x1024 inputs it returns a (B, 1024, 16, 16) feature map."""

    def __init__(self, global_attn_indexes=(2, 5, 8, 11), **kwargs):
        super().__init__()
        self.global_attn_indexes = global_attn_indexes
        self.patch_embed = PatchEmbed()

        # Learned absolute position embedding for the 64x64 patch grid.
        self.pos_embed = nn.Parameter(torch.zeros(1, 64, 64, 768))

        # 12 blocks; the ones listed in global_attn_indexes attend globally,
        # the rest use 14x14 windowed attention.
        self.blocks = nn.ModuleList(
            [
                Block(idx=i, window_size=14 if i not in global_attn_indexes else 0)
                for i in range(12)
            ]
        )

        self.neck = nn.ModuleList(
            [
                nn.Conv2d(
                    768,
                    256,
                    kernel_size=1,
                    bias=False,
                ),
                LayerNorm2d(256),
                nn.Conv2d(
                    256,
                    256,
                    kernel_size=3,
                    padding=1,
                    bias=False,
                ),
                LayerNorm2d(256),
            ]
        )

        # Two strided convolutions downsample 64x64 -> 32x32 -> 16x16 while
        # widening channels 256 -> 512 -> 1024.
        self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
        self.net_3 = nn.Conv2d(
            512, 1024, kernel_size=3, stride=2, padding=1, bias=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        x = x + self.pos_embed

        for blk in self.blocks:
            x = blk(x)

        # B H W C -> B C H W for the convolutional neck.
        x = x.permute(0, 3, 1, 2)

        for m in self.neck:
            x = m(x)

        x = self.net_2(x)
        x = self.net_3(x)

        return x
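

# Feature-map sketch (illustrative): a 1024x1024 image becomes a 64x64 grid of
# 768-d patch tokens, the neck maps it to (B, 256, 64, 64), and net_2/net_3
# downsample it to (B, 1024, 16, 16); GexQwenModel later flattens this into
# 256 visual tokens of width 1024.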


class GexQwenModel(Qwen2Model):
    config_class = GexConfig

    def __init__(self, config: Qwen2Config):
        super().__init__(config)
        # Frozen vision tower: ViT encoder plus a linear projection into the
        # language model's embedding space.
        self.vit = GexVit()
        self.vit.eval()
        self.vit_proj = nn.Linear(1024, 1024)
        self.vit_proj.eval()

        for param in self.vit.parameters():
            param.requires_grad = False
        for param in self.vit_proj.parameters():
            param.requires_grad = False

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        if images is not None:
            # Image pass: the projected visual tokens replace input_ids entirely.
            assert input_ids is None, input_ids
            input_ids = None
            attention_mask = None
            kwargs["is_causal"] = True
            with torch.no_grad():
                vit_feature = self.vit_proj(
                    self.vit(images).flatten(2).permute(0, 2, 1)
                )
            inputs_embeds = vit_feature

        if inputs_embeds is None and input_ids is not None:
            inputs_embeds = self.embed_tokens(input_ids)

        return super().forward(
            input_ids=None,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )
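

# Calling convention (illustrative): pass either token ids or images, not both.
# With images of shape (B, 3, 1024, 1024), the frozen vision tower produces
# inputs_embeds of shape (B, 256, 1024) and the Qwen2 backbone runs on those
# embeddings; with input_ids only, the model behaves like a plain Qwen2Model.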


class GexQwenForCausalLM(Qwen2ForCausalLM):
    config_class = GexConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = GexQwenModel(config)

        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

        # `has_image` is set by generate() so that only the first forward pass of a
        # generation consumes the image; subsequent decoding steps use token ids.
        self.has_image = False
        self.image_preprocess = GexImageEvalProcessor()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        images: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.has_image:
            # First step of generation: feed the image, drop the token ids.
            input_ids = None
            self.has_image = False
        else:
            # Later steps (or direct text-only calls): ignore any image passed along.
            images = None

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            images=images,
            **kwargs,
        )

        hidden_states = outputs[0]

        # Only compute logits for the requested suffix of positions.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @torch.no_grad()
    def generate(self, *args, **kwargs):
        self.has_image = True
        res = super().generate(*args, **kwargs)
        self.has_image = False
        return res
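

# End-to-end sketch (illustrative; the checkpoint path, prompt ids, and generation
# arguments are assumptions, not defined by this module):
#
#   from PIL import Image
#
#   model = GexQwenForCausalLM.from_pretrained("path/to/gex-checkpoint").eval()
#   pixel_values = model.image_preprocess(
#       Image.open("page.png").convert("RGB")
#   ).unsqueeze(0)
#   output_ids = model.generate(
#       input_ids=prompt_ids,   # placeholder ids supplied by the inference harness
#       images=pixel_values,
#       max_new_tokens=512,
#   )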