# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Qwen2-VL model."""

import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import LayerNorm

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    auto_docstring,
    can_return_tuple,
    is_torch_flex_attn_available,
    is_torchdynamo_compiling,
    logging,
)

from .configuration_qwen2_vl import Qwen2VLVisionConfig


if is_flash_attn_available():
    from transformers.modeling_flash_attention_utils import _flash_attention_forward, flash_attn_varlen_func

if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask

    from transformers.integrations.flex_attention import make_flex_block_causal_mask


logger = logging.get_logger(__name__)


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_vision(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to ``q`` and ``k``.

    The rotation is computed in float32 for numerical stability, and each result is cast back to the
    dtype of its input. ``cos``/``sin`` get a singleton axis inserted at -2 so they broadcast over the
    heads axis. In this file q/k are ``(seq, num_heads, head_dim)`` and cos/sin are ``(seq, head_dim)``.

    Returns:
        ``(q_embed, k_embed)`` with the same shapes and dtypes as ``q`` and ``k``.
    """
    orig_q_dtype = q.dtype
    orig_k_dtype = k.dtype
    q, k = q.float(), k.float()
    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    q_embed = q_embed.to(orig_q_dtype)
    k_embed = k_embed.to(orig_k_dtype)
    return q_embed, k_embed


def _vision_rope_cos_sin(
    rotary_pos_emb: Optional[torch.Tensor],
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Resolve the ``(cos, sin)`` pair consumed by every vision attention backend.

    Prefers the externally computed ``position_embeddings``. Otherwise it falls back to the deprecated
    ``rotary_pos_emb`` (raw RoPE angles), emitting a one-time deprecation warning. Shared by all three
    attention classes so the fallback cannot drift between backends.
    """
    if position_embeddings is not None:
        cos, sin = position_embeddings
        return cos, sin
    logger.warning_once(
        "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
        "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
        "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
        "removed and `position_embeddings` will be mandatory."
    )
    # Duplicate the angles so they cover the full head_dim expected by rotate_half.
    emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
    return emb.cos(), emb.sin()


class VisionRotaryEmbedding(nn.Module):
    """Produce RoPE angle tables of shape ``(seqlen, ceil(dim / 2))`` for integer positions ``0..seqlen-1``."""

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        # Non-persistent: recomputed at construction, never stored in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(seq, self.inv_freq)
        return freqs


class PatchEmbed(nn.Module):
    """Embed flattened patches with a Conv3d whose kernel equals its stride (one output vector per patch)."""

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        kernel_size = [temporal_patch_size, patch_size, patch_size]
        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return a ``(num_patches, embed_dim)`` tensor.

        The input is reinterpreted as ``(-1, in_channels, temporal_patch_size, patch_size, patch_size)``
        and cast to the conv weight dtype.
        """
        target_dtype = self.proj.weight.dtype
        hidden_states = hidden_states.view(
            -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
        )
        hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
        return hidden_states


class PatchMerger(nn.Module):
    """Merge each group of ``spatial_merge_size**2`` consecutive tokens into a single ``dim``-sized token.

    Each ``context_dim`` token is LayerNorm-ed. Groups are then concatenated via ``view`` and projected
    through a two-layer GELU MLP.
    """

    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.ln_q = LayerNorm(context_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
        return x


class VisionMlp(nn.Module):
    """Two-layer feed-forward network: ``fc2(act(fc1(x)))``."""

    def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = ACT2FN[hidden_act]
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))


class VisionAttention(nn.Module):
    """Eager multi-head self-attention over packed sequences.

    Tokens only attend within the same ``cu_seqlens`` segment: an additive mask holds 0 inside each
    segment's block on the diagonal and ``finfo.min`` everywhere else.
    """

    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        # Each of q, k, v is (seq, num_heads, head_dim).
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        cos, sin = _vision_rope_cos_sin(rotary_pos_emb, position_embeddings)
        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

        attention_mask = torch.full(
            [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
        )
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)
        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
        attn_weights = attn_weights + attention_mask
        # Softmax in float32 for stability, then back to the activation dtype.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(0, 1)
        attn_output = attn_output.reshape(seq_length, -1)
        attn_output = self.proj(attn_output)
        return attn_output


class VisionFlashAttention2(nn.Module):
    """Self-attention over packed sequences using ``flash_attn_varlen_func``.

    ``cu_seqlens`` delimits the independent segments, so no explicit mask is built.
    """

    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        cos, sin = _vision_rope_cos_sin(rotary_pos_emb, position_embeddings)
        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

        # Longest segment length, needed by the varlen kernel (host sync via .item()).
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
            seq_length, -1
        )
        attn_output = self.proj(attn_output)
        return attn_output


class VisionSdpaAttention(nn.Module):
    """Self-attention over packed sequences using ``F.scaled_dot_product_attention``.

    A boolean mask is True only inside each ``cu_seqlens`` segment's block on the diagonal.
    """

    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        cos, sin = _vision_rope_cos_sin(rotary_pos_emb, position_embeddings)
        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

        attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)
        # Add a batch axis of 1: SDPA expects (batch, heads, seq, head_dim).
        attn_output = F.scaled_dot_product_attention(
            q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attention_mask, dropout_p=0.0
        )
        attn_output = attn_output.squeeze(0).transpose(0, 1)
        attn_output = attn_output.reshape(seq_length, -1)
        attn_output = self.proj(attn_output)
        return attn_output


# Maps `config._attn_implementation` to the vision attention class used by each block.
QWEN2_VL_VISION_ATTENTION_CLASSES = {
    "eager": VisionAttention,
    "flash_attention_2": VisionFlashAttention2,
    "sdpa": VisionSdpaAttention,
}


class Qwen2VLVisionBlock(nn.Module):
    """Pre-norm transformer block: ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``."""

    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
        super().__init__()
        self.norm1 = LayerNorm(config.embed_dim, eps=1e-6)
        self.norm2 = LayerNorm(config.embed_dim, eps=1e-6)
        mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)

        self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation](
            config.embed_dim, num_heads=config.num_heads
        )
        self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
        )
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        return hidden_states


@auto_docstring
class Qwen2VLPreTrainedModel(PreTrainedModel):
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_static_cache = False  # TODO (joao): fix. torch.compile failing probably due to `cache_positions`

    def _init_weights(self, module):
        """Initialize weights for one submodule.

        Linear/Conv3d/Embedding weights are drawn from N(0, initializer_range) and biases are zeroed.
        LayerNorm is reset to identity (weight 1, bias 0).
        """
        std = self.config.get_text_config().initializer_range
        if isinstance(module, (nn.Linear, nn.Conv3d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # weight/bias are None for LayerNorm(elementwise_affine=False) or LayerNorm(bias=False).
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()


@auto_docstring
class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
    config_class = Qwen2VLVisionConfig
    _no_split_modules = ["Qwen2VLVisionBlock"]

    def __init__(self, config) -> None:
        super().__init__(config)
        self.spatial_merge_size = config.spatial_merge_size

        self.patch_embed = PatchEmbed(
            patch_size=config.patch_size,
            temporal_patch_size=config.temporal_patch_size,
            in_channels=config.in_channels,
            embed_dim=config.embed_dim,
        )

        head_dim = config.embed_dim // config.num_heads
        # Half of head_dim is used for row positions and half for column positions (see rot_pos_emb).
        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList(
            [Qwen2VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)]
        )
        self.merger = PatchMerger(
            dim=config.hidden_size, context_dim=config.embed_dim, spatial_merge_size=config.spatial_merge_size
        )
        self.gradient_checkpointing = False

    def get_dtype(self) -> torch.dtype:
        return self.blocks[0].mlp.fc2.weight.dtype

    def get_device(self) -> torch.device:
        return self.blocks[0].mlp.fc2.weight.device

    def rot_pos_emb(self, grid_thw):
        """Build 2D rotary angles for every patch described by ``grid_thw``.

        For each ``[t, h, w]`` row, (row, col) position ids are reordered so that each
        ``spatial_merge_size x spatial_merge_size`` window is contiguous (matching ``PatchMerger``'s grouping).
        They are then repeated ``t`` times. ``h`` and ``w`` must be divisible by ``spatial_merge_size`` or the
        reshape fails. Returns ``(sum(t*h*w), 2 * rotary_dim)``, with the row angles followed by the column angles.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        # One table large enough for the biggest h or w; rows/cols then index into it.
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    @auto_docstring
    def forward(self, pixel_values: torch.Tensor, image_grid_thw: torch.Tensor) -> torch.Tensor:
        r"""
        pixel_values (`torch.Tensor`):
            Flattened patch values. They are reinterpreted as
            `(-1, in_channels, temporal_patch_size, patch_size, patch_size)` by the patch embedding and are
            expected to contain `sum(t * h * w)` patches in total over `image_grid_thw`.
        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`):
            The temporal, height and width dimensions of feature shape for each image. Each row contains [t, h, w] values.

        Returns:
            `torch.Tensor`: merged visual tokens of size `config.hidden_size`, one per
            `spatial_merge_size**2` patches.
        """
        hidden_states = self.patch_embed(pixel_values)
        rotary_pos_emb = self.rot_pos_emb(image_grid_thw)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # One attention segment per frame: each of the t frames of an image contributes h*w patches.
        cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum(
            dim=0,
            # Select dtype based on the following factors:
            #  - FA2 requires that cu_seqlens_q must have dtype int32
            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852 for more information
            dtype=image_grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        # Prepend 0 so segment i spans cu_seqlens[i]:cu_seqlens[i + 1].
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for blk in self.blocks:
            if self.gradient_checkpointing and self.training:
                # Positional args: hidden_states, cu_seqlens, rotary_pos_emb (unused -> None), position_embeddings.
                hidden_states = self._gradient_checkpointing_func(
                    blk.__call__, hidden_states, cu_seqlens, None, position_embeddings
                )
            else:
                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)

        return self.merger(hidden_states)