import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from typing import List, Optional, Tuple, Type, Union

from torch.nn import CrossEntropyLoss
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transformers import (
    Qwen2Config,
    Qwen2Model,
    Qwen2ForCausalLM,
)
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)

from .configuration_gex import GexConfig

LayerNorm = partial(nn.LayerNorm, eps=1e-6)


class GexImageEvalProcessor:
    """Resize to a square `image_size`, convert to tensor, and normalize (CLIP statistics)."""

    def __init__(self, image_size=1024, mean=None, std=None):
        if mean is None:
            mean = (0.48145466, 0.4578275, 0.40821073)
        if std is None:
            std = (0.26862954, 0.26130258, 0.27577711)

        self.normalize = transforms.Normalize(mean, std)

        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (image_size, image_size), interpolation=InterpolationMode.BICUBIC
                ),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

    def __call__(self, item):
        return self.transform(item)


class LayerNorm2d(nn.Module):
    """LayerNorm applied over the channel dimension of a (B, C, H, W) tensor."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.num_channels = num_channels
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 2, 3, 1)
        return torch.nn.functional.layer_norm(
            x,
            normalized_shape=(self.num_channels,),
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        ).permute(0, 3, 1, 2)


class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
""" super().__init__() self.proj = nn.Conv2d( in_chans, embed_dim, kernel_size=kernel_size, stride=stride ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj(x) # B C H W -> B H W C x = x.permute(0, 2, 3, 1) return x class Attention(nn.Module): def __init__( self, dim: int, num_heads: int = 8, input_size: Optional[Tuple[int, int]] = None, ) -> None: super().__init__() self.num_heads = num_heads self.head_dim = 64 self.scale = 64**-0.5 self.seq_len = input_size[0] * input_size[1] self.input_size = input_size self.qkv = nn.Linear(dim, dim * 3, bias=True) self.proj = nn.Linear(dim, dim) # self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, self.head_dim)) # self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, self.head_dim)) self.rel_pos_h = nn.Parameter(torch.zeros(input_size[0],input_size[0], self.head_dim)) self.rel_pos_w = nn.Parameter(torch.zeros(input_size[1],input_size[1], self.head_dim)) def init_rel_pos(self): q_size, k_size = self.input_size q_coords = torch.arange(q_size)[:, None] k_coords = torch.arange(k_size)[None, :] relative_coords = (q_coords - k_coords) + (k_size - 1) self.rel_pos_h = nn.Parameter(self.rel_pos_h.data[relative_coords.long()]) self.rel_pos_w = nn.Parameter(self.rel_pos_w.data[relative_coords.long()]) def get_attn_bias(self, q: torch.Tensor): q = q.view(-1, *self.input_size, 64) rel_h = torch.einsum("bhwc,hkc->bhwk", q, self.rel_pos_h) rel_w = torch.einsum("bhwc,wkc->bhwk", q, self.rel_pos_w) return (rel_h.unsqueeze(-1) + rel_w.unsqueeze(-2)).reshape( -1, self.num_heads, self.seq_len, self.seq_len ) def forward(self, x: torch.Tensor) -> torch.Tensor: qkv = torch.split( self.qkv(x).view(-1, self.seq_len, 3 * 768), 768, dim=2, ) q, k, v = ( i.unflatten(-1, (self.num_heads, -1)).transpose(1, 2).contiguous() for i in qkv ) attn_bias = self.get_attn_bias(q) attn_output = torch.nn.functional.scaled_dot_product_attention( q, k, v, attn_mask=attn_bias, is_causal=False ) attn_output = attn_output.transpose(1, 2).flatten(-2) x = self.proj(attn_output) return x.view(-1, *self.input_size, 768) class MLP(nn.Module): def __init__( self, ): super().__init__() self.lin1 = nn.Linear(768, 4 * 768) self.lin2 = nn.Linear(4 * 768, 768) self.act = nn.GELU() def forward(self, x: torch.Tensor) -> torch.Tensor: return self.lin2(self.act(self.lin1(x))) class Block(nn.Module): def __init__(self, idx: int, window_size: int = 14): super().__init__() self.idx = idx self.window_size = window_size self.norm1 = LayerNorm(768) self.attn = Attention( dim=768, num_heads=12, input_size=(64, 64) if window_size == 0 else (14, 14), ) self.norm2 = LayerNorm(768) self.mlp = MLP() @staticmethod def window_partition(x: torch.Tensor) -> torch.Tensor: x = F.pad(x, (0, 0, 0, 6, 0, 6)) x = ( x.view(-1, 5, 14, 5, 14, 768) .permute(0, 1, 3, 2, 4, 5) .contiguous() .view(-1, 14, 14, 768) ) return x @staticmethod def window_unpartition(x: torch.Tensor) -> torch.Tensor: x = ( x.view(-1, 5, 5, 14, 14, 768) .permute(0, 1, 3, 2, 4, 5) .contiguous() .view(-1, 70, 70, 768) ) return x[:, :64, :64, :].contiguous() def forward(self, x: torch.Tensor) -> torch.Tensor: shortcut = x x = self.norm1(x) if self.window_size > 0: x = self.window_partition(x) x = self.attn(x) if self.window_size > 0: x = self.window_unpartition(x) x = shortcut + x x = x + self.mlp(self.norm2(x)) return x class GexVit(nn.Module): def __init__(self, global_attn_indexes=[2, 5, 8, 11], **kwargs): super().__init__() self.global_attn_indexes = global_attn_indexes self.patch_embed = PatchEmbed() self.pos_embed = 
        self.pos_embed = nn.Parameter(torch.zeros(1, 64, 64, 768))

        self.blocks = nn.ModuleList(
            [
                Block(idx=i, window_size=14 if i not in global_attn_indexes else 0)
                for i in range(12)
            ]
        )

        # Project the 64x64x768 feature map to 256 channels, then downsample
        # twice (net_2, net_3) to a 16x16 map with 1024 channels.
        self.neck = nn.ModuleList(
            [
                nn.Conv2d(
                    768,
                    256,
                    kernel_size=1,
                    bias=False,
                ),
                LayerNorm2d(256),
                nn.Conv2d(
                    256,
                    256,
                    kernel_size=3,
                    padding=1,
                    bias=False,
                ),
                LayerNorm2d(256),
            ]
        )
        self.net_2 = nn.Conv2d(
            256, 512, kernel_size=3, stride=2, padding=1, bias=False
        )
        self.net_3 = nn.Conv2d(
            512, 1024, kernel_size=3, stride=2, padding=1, bias=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        x = x + self.pos_embed
        for blk in self.blocks:
            x = blk(x)
        x = x.permute(0, 3, 1, 2)
        for m in self.neck:
            x = m(x)
        x = self.net_2(x)
        x = self.net_3(x)
        return x


class GexQwenModel(Qwen2Model):
    config_class = GexConfig

    def __init__(self, config: Qwen2Config):
        super().__init__(config)
        # Frozen vision tower and projection; only the language model is trained.
        self.vit = GexVit()
        self.vit.eval()
        self.vit_proj = nn.Linear(1024, 1024)
        self.vit_proj.eval()

        for param in self.vit.parameters():
            param.requires_grad = False
        for param in self.vit_proj.parameters():
            param.requires_grad = False

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        if images is not None:
            assert input_ids is None, input_ids
            input_ids = None
            attention_mask = None
            kwargs["is_causal"] = True
            # Use the projected ViT features as the full input sequence:
            # (B, 1024, 16, 16) -> flatten/permute -> (B, 256, 1024) visual tokens.
            with torch.no_grad():
                vit_feature = self.vit_proj(
                    self.vit(images).flatten(2).permute(0, 2, 1)
                )
            inputs_embeds = vit_feature

        if inputs_embeds is None and input_ids is not None:
            inputs_embeds = self.embed_tokens(input_ids)

        return super().forward(
            input_ids=None,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )


class GexQwenForCausalLM(Qwen2ForCausalLM):
    config_class = GexConfig
    # supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        self.model = GexQwenModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

        self.has_image = False
        self.image_preprocess = GexImageEvalProcessor()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        images: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # `generate` sets `has_image` so that only the first forward pass of a
        # generation consumes `images` (as the full embedding sequence); later
        # decoding steps fall back to ordinary token embeddings.
        if self.has_image:
            input_ids = None
            self.has_image = False
        else:
            images = None

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            images=images,
            **kwargs,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = (
            slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        )
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @torch.no_grad()
    def generate(self, *args, **kwargs):
        # Flag the next forward pass as the image-consuming one.
        self.has_image = True
        res = super().generate(*args, **kwargs)
        self.has_image = False
        return res
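

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the model definition). The vision
# tower shapes below follow directly from the code above; the checkpoint path
# and the exact `generate(...)` calling convention are assumptions and may
# differ for the released weights.
#
#   import torch
#   from PIL import Image
#
#   # Vision tower alone: a 1024x1024 image yields a 16x16x1024 feature map,
#   # i.e. 256 visual tokens of width 1024 after flatten/permute.
#   vit = GexVit()
#   feats = vit(torch.randn(1, 3, 1024, 1024))     # (1, 1024, 16, 16)
#   tokens = feats.flatten(2).permute(0, 2, 1)     # (1, 256, 1024)
#
#   # Full model (hypothetical checkpoint directory "path/to/gex-checkpoint"):
#   model = GexQwenForCausalLM.from_pretrained("path/to/gex-checkpoint").eval()
#   image = Image.open("example.png").convert("RGB")
#   pixel_values = model.image_preprocess(image).unsqueeze(0)
#   output_ids = model.generate(images=pixel_values, max_new_tokens=512)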