# Gex_V1/modeling_gex.py
import torch
import torch.nn.functional as F
from typing import List, Optional, Tuple, Type, Union
from functools import partial
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torchvision import transforms
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from torchvision.transforms.functional import InterpolationMode
from transformers import (
Qwen2Config,
Qwen2Model,
Qwen2ForCausalLM,
)
from .configuration_gex import GexConfig
LayerNorm = partial(nn.LayerNorm, eps=1e-6)
class GexImageEvalProcessor:
def __init__(self, image_size=1024, mean=None, std=None):
if mean is None:
mean = (0.48145466, 0.4578275, 0.40821073)
if std is None:
std = (0.26862954, 0.26130258, 0.27577711)
self.normalize = transforms.Normalize(mean, std)
self.transform = transforms.Compose(
[
transforms.Resize(
(image_size, image_size), interpolation=InterpolationMode.BICUBIC
),
transforms.ToTensor(),
self.normalize,
]
)
def __call__(self, item):
return self.transform(item)
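# Usage sketch (illustration only, not part of the upstream code): the processor
# turns a PIL image into a normalized (3, 1024, 1024) tensor for GexVit. The
# file name below is a hypothetical placeholder.
#
#   from PIL import Image
#   processor = GexImageEvalProcessor()
#   pixel_values = processor(Image.open("example.png").convert("RGB")).unsqueeze(0)
#   # pixel_values: (1, 3, 1024, 1024), ready to be passed as `images`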
class LayerNorm2d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.num_channels = num_channels
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.permute(0, 2, 3, 1)
return torch.nn.functional.layer_norm(
x,
normalized_shape=(self.num_channels,),
weight=self.weight,
bias=self.bias,
eps=self.eps,
).permute(0, 3, 1, 2)
class PatchEmbed(nn.Module):
"""
Image to Patch Embedding.
"""
def __init__(
self,
kernel_size: Tuple[int, int] = (16, 16),
stride: Tuple[int, int] = (16, 16),
in_chans: int = 3,
embed_dim: int = 768,
) -> None:
"""
Args:
kernel_size (Tuple): kernel size of the projection layer.
stride (Tuple): stride of the projection layer.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
"""
super().__init__()
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=kernel_size, stride=stride
)
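    # With the default 16x16 kernel/stride and a 1024x1024 input, the projection
    # yields a (B, 64, 64, 768) channels-last grid of patch embeddings.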
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
# B C H W -> B H W C
x = x.permute(0, 2, 3, 1)
return x
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = 64
self.scale = 64**-0.5
self.seq_len = input_size[0] * input_size[1]
self.input_size = input_size
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.proj = nn.Linear(dim, dim)
# self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, self.head_dim))
# self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, self.head_dim))
self.rel_pos_h = nn.Parameter(torch.zeros(input_size[0],input_size[0], self.head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(input_size[1],input_size[1], self.head_dim))
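    # The rel_pos_h / rel_pos_w parameters store per-axis relative position
    # embeddings already gathered into shape (q_size, k_size, head_dim).
    # init_rel_pos appears intended for checkpoints that instead carry the
    # SAM-style (2 * size - 1, head_dim) tables (see the commented-out shapes
    # above): it indexes those tables with the relative coordinate of every
    # (query, key) pair along each axis.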
def init_rel_pos(self):
q_size, k_size = self.input_size
q_coords = torch.arange(q_size)[:, None]
k_coords = torch.arange(k_size)[None, :]
relative_coords = (q_coords - k_coords) + (k_size - 1)
self.rel_pos_h = nn.Parameter(self.rel_pos_h.data[relative_coords.long()])
self.rel_pos_w = nn.Parameter(self.rel_pos_w.data[relative_coords.long()])
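    # Builds a decomposed relative-position bias: rel_h and rel_w are computed
    # per axis from the query grid and broadcast-added into a
    # (B, num_heads, seq_len, seq_len) tensor usable as an SDPA attn_mask.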
def get_attn_bias(self, q: torch.Tensor):
q = q.view(-1, *self.input_size, 64)
rel_h = torch.einsum("bhwc,hkc->bhwk", q, self.rel_pos_h)
rel_w = torch.einsum("bhwc,wkc->bhwk", q, self.rel_pos_w)
return (rel_h.unsqueeze(-1) + rel_w.unsqueeze(-2)).reshape(
-1, self.num_heads, self.seq_len, self.seq_len
)
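    # Attention assumes embed_dim == 768 (hard-coded below): qkv is split into
    # (B, num_heads, seq_len, 64) tensors and attention runs through PyTorch
    # scaled_dot_product_attention with the relative bias as attn_mask.
    # self.scale is not referenced here; SDPA already applies 1/sqrt(head_dim).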
def forward(self, x: torch.Tensor) -> torch.Tensor:
qkv = torch.split(
self.qkv(x).view(-1, self.seq_len, 3 * 768),
768,
dim=2,
)
q, k, v = (
i.unflatten(-1, (self.num_heads, -1)).transpose(1, 2).contiguous()
for i in qkv
)
attn_bias = self.get_attn_bias(q)
attn_output = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_bias, is_causal=False
)
attn_output = attn_output.transpose(1, 2).flatten(-2)
x = self.proj(attn_output)
return x.view(-1, *self.input_size, 768)
class MLP(nn.Module):
def __init__(
self,
):
super().__init__()
self.lin1 = nn.Linear(768, 4 * 768)
self.lin2 = nn.Linear(4 * 768, 768)
self.act = nn.GELU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.lin2(self.act(self.lin1(x)))
class Block(nn.Module):
def __init__(self, idx: int, window_size: int = 14):
super().__init__()
self.idx = idx
self.window_size = window_size
self.norm1 = LayerNorm(768)
self.attn = Attention(
dim=768,
num_heads=12,
input_size=(64, 64) if window_size == 0 else (14, 14),
)
self.norm2 = LayerNorm(768)
self.mlp = MLP()
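    # Windowed blocks pad the 64x64 token grid to 70x70, split it into a 5x5
    # arrangement of 14x14 windows (window_partition), and window_unpartition
    # reverses the reshape and crops back to 64x64.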
@staticmethod
def window_partition(x: torch.Tensor) -> torch.Tensor:
x = F.pad(x, (0, 0, 0, 6, 0, 6))
x = (
x.view(-1, 5, 14, 5, 14, 768)
.permute(0, 1, 3, 2, 4, 5)
.contiguous()
.view(-1, 14, 14, 768)
)
return x
@staticmethod
def window_unpartition(x: torch.Tensor) -> torch.Tensor:
x = (
x.view(-1, 5, 5, 14, 14, 768)
.permute(0, 1, 3, 2, 4, 5)
.contiguous()
.view(-1, 70, 70, 768)
)
return x[:, :64, :64, :].contiguous()
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x
x = self.norm1(x)
if self.window_size > 0:
x = self.window_partition(x)
x = self.attn(x)
if self.window_size > 0:
x = self.window_unpartition(x)
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class GexVit(nn.Module):
    def __init__(self, global_attn_indexes=(2, 5, 8, 11), **kwargs):
super().__init__()
self.global_attn_indexes = global_attn_indexes
self.patch_embed = PatchEmbed()
self.pos_embed = nn.Parameter(torch.zeros(1, 64, 64, 768))
self.blocks = nn.ModuleList(
[
Block(idx=i, window_size=14 if i not in global_attn_indexes else 0)
for i in range(12)
]
)
self.neck = nn.ModuleList(
[
nn.Conv2d(
768,
256,
kernel_size=1,
bias=False,
),
LayerNorm2d(256),
nn.Conv2d(
256,
256,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(256),
]
)
self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
self.net_3 = nn.Conv2d(
512, 1024, kernel_size=3, stride=2, padding=1, bias=False
)
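    # Shape flow for a (B, 3, 1024, 1024) input: patch_embed -> (B, 64, 64, 768),
    # 12 blocks (global attention at global_attn_indexes, 14x14 windows elsewhere),
    # neck -> (B, 256, 64, 64), net_2 -> (B, 512, 32, 32), net_3 -> (B, 1024, 16, 16).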
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
x = x + self.pos_embed
for blk in self.blocks:
x = blk(x)
x = x.permute(0, 3, 1, 2)
for m in self.neck:
x = m(x)
x = self.net_2(x)
x = self.net_3(x)
return x
class GexQwenModel(Qwen2Model):
config_class = GexConfig
def __init__(self, config: Qwen2Config):
super().__init__(config)
self.vit = GexVit()
self.vit.eval()
self.vit_proj = nn.Linear(1024, 1024)
self.vit_proj.eval()
for param in self.vit.parameters():
param.requires_grad = False
for param in self.vit_proj.parameters():
param.requires_grad = False
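    # The vision tower and its projection are frozen; vit_proj maps the 1024-dim
    # ViT features onto the width consumed by the language model (a 1024 -> 1024
    # linear here, which presumes hidden_size == 1024).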
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPast]:
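        # When images are given, the projected ViT features (B, 256, 1024) are
        # used directly as inputs_embeds and input_ids / attention_mask are
        # discarded; otherwise the standard token-embedding path below applies.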
if images is not None:
assert input_ids is None, input_ids
input_ids = None
attention_mask = None
kwargs["is_causal"] = True
with torch.no_grad():
vit_feature = self.vit_proj(
self.vit(images).flatten(2).permute(0, 2, 1)
)
inputs_embeds = vit_feature
# print(input_ids, images)
if inputs_embeds is None and input_ids is not None:
inputs_embeds = self.embed_tokens(input_ids)
return super().forward(
input_ids=None,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
class GexQwenForCausalLM(Qwen2ForCausalLM):
config_class = GexConfig
# supports_gradient_checkpointing = True
def __init__(self, config):
super().__init__(config)
self.model = GexQwenModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
self.has_image = False
self.image_preprocess = GexImageEvalProcessor()
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
images: Optional[torch.FloatTensor] = None,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
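        # has_image is toggled by generate() below: images are consumed only on
        # the first forward call of a generation; later decoding steps fall back
        # to plain input_ids.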
if self.has_image:
input_ids = None
self.has_image = False
else:
images = None
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
images=images,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
    @torch.no_grad()
    def generate(self, *args, **kwargs):
self.has_image = True
res = super().generate(*args, **kwargs)
self.has_image = False
return res
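# Minimal usage sketch (illustration only; the checkpoint path and prompting
# convention below are assumptions, not taken from this repo):
#
#   from PIL import Image
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   model = AutoModelForCausalLM.from_pretrained("path/to/Gex_V1", trust_remote_code=True)
#   tokenizer = AutoTokenizer.from_pretrained("path/to/Gex_V1", trust_remote_code=True)
#   image = model.image_preprocess(Image.open("page.png").convert("RGB")).unsqueeze(0)
#   output_ids = model.generate(images=image, max_new_tokens=512)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))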