# NOTE: scraped from a Hugging Face Spaces page (status banner: "Running on Zero").
# coding=utf-8 | |
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. | |
# | |
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX | |
# and OPT implementations in this library. It has been modified from its | |
# original forms to accommodate minor architectural differences compared | |
# to GPT-NeoX and OPT used by the Meta AI team that trained the model. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""PyTorch Qwen2-VL model.""" | |
import math | |
from dataclasses import dataclass | |
from typing import Any, Dict, List, Optional, Tuple, Union | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.utils.checkpoint | |
from torch.nn import CrossEntropyLoss, LayerNorm | |
from transformers.activations import ACT2FN | |
from transformers.cache_utils import Cache, StaticCache | |
from transformers.modeling_attn_mask_utils import AttentionMaskConverter | |
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput | |
from transformers.modeling_utils import PreTrainedModel | |
from transformers.utils import (add_start_docstrings, | |
add_start_docstrings_to_model_forward, | |
is_flash_attn_2_available, | |
is_flash_attn_greater_or_equal_2_10, logging, | |
replace_return_docstrings) | |
from .configuration_qwen2vl_encoder import Qwen2VLVisionConfig | |
# flash-attn is optional. Bind the varlen kernel name to None first and
# overwrite it only when transformers reports FlashAttention-2 as available;
# VisionFlashAttention2 calls this name.
flash_attn_varlen_func = None
if is_flash_attn_2_available():
    from flash_attn import flash_attn_varlen_func
    from transformers.modeling_flash_attention_utils import _flash_attention_forward  # noqa: F401

logger = logging.get_logger(__name__)
# Copied from transformers.models.llama.modeling_llama.rotate_half | |
def rotate_half(x):
    """Rotate half the hidden dims of the input.

    The last dimension is split at ``x.shape[-1] // 2`` into a leading part
    ``a`` and a trailing part ``b``. The result is ``concat(-b, a)`` along the
    last dimension.
    """
    split = x.shape[-1] // 2
    leading = x[..., :split]
    trailing = x[..., split:]
    return torch.cat((trailing.neg(), leading), dim=-1)
def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Apply vision rotary position embedding to ``tensor``.

    The rotation is computed in float32 and the result is cast back to the
    dtype of ``tensor``.

    Args:
        tensor: Values to rotate. The attention modules in this file pass
            ``q``/``k`` after ``unsqueeze(0)``, presumably shaped
            ``(1, seq, heads, head_dim)``. TODO confirm for other callers.
        freqs: Rotation angles. ``cos``/``sin`` of these are tiled twice along
            the last dimension. Assumes ``freqs`` is 2-D ``(seq, d)`` with
            ``2 * d == head_dim``; verify against the caller.

    Returns:
        ``tensor * cos + rotate_half(tensor) * sin`` in the input dtype.
    """
    input_dtype = tensor.dtype
    x = tensor.float()

    def _broadcastable(trig: torch.Tensor) -> torch.Tensor:
        # (seq, d) -> (seq, 1, d) -> (seq, 1, 2d) -> (1, seq, 1, 2d)
        return trig.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()

    cos = _broadcastable(freqs.cos())
    sin = _broadcastable(freqs.sin())
    rotated = (x * cos) + (rotate_half(x) * sin)
    return rotated.to(input_dtype)
class VisionRotaryEmbedding(nn.Module):
    """Table of rotary angles ``position * inv_freq`` for integer positions.

    ``inv_freq[j] = 1 / theta ** (2j / dim)`` for ``j`` in ``range(ceil(dim / 2))``.
    It is registered as a non-persistent buffer, so it moves with the module
    but is excluded from the state dict.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        self.register_buffer("inv_freq", 1.0 / (theta ** exponents), persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return a ``(seqlen, len(inv_freq))`` tensor with entry ``[p, j] = p * inv_freq[j]``."""
        positions = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        # Column vector times row vector: an outer product via broadcasting.
        return positions[:, None] * self.inv_freq[None, :]
class PatchEmbed(nn.Module):
    """Project flattened spatio-temporal patches to ``embed_dim`` with a Conv3d.

    The kernel size equals the stride, so each
    ``(in_channels, temporal_patch_size, patch_size, patch_size)`` patch is
    mapped independently to one embedding vector. The convolution has no bias.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        kernel = [temporal_patch_size, patch_size, patch_size]
        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel, stride=kernel, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Embed patches.

        ``hidden_states`` is viewed as
        ``(-1, in_channels, temporal_patch_size, patch_size, patch_size)`` and
        cast to the convolution weight dtype. The result is
        ``(num_patches, embed_dim)``.
        """
        patches = hidden_states.view(
            -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
        ).to(dtype=self.proj.weight.dtype)
        return self.proj(patches).view(-1, self.embed_dim)
class PatchMerger(nn.Module):
    """Group ``spatial_merge_size ** 2`` consecutive tokens and project them to ``dim``.

    Tokens are layer-normalised over ``context_dim``. Every
    ``spatial_merge_size ** 2`` consecutive rows are then concatenated (via
    ``view``) into one vector of width ``context_dim * spatial_merge_size ** 2``,
    which a two-layer GELU MLP maps to ``dim``.
    """

    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
        super().__init__()
        merged_width = context_dim * spatial_merge_size ** 2
        self.hidden_size = merged_width
        self.ln_q = LayerNorm(context_dim, eps=1e-6)
        # nn.Sequential names its parameters by position (mlp.0.*, mlp.2.*),
        # so keep this layer order stable.
        self.mlp = nn.Sequential(
            nn.Linear(merged_width, merged_width),
            nn.GELU(),
            nn.Linear(merged_width, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = self.ln_q(x)
        grouped = normed.view(-1, self.hidden_size)
        return self.mlp(grouped)
class VisionMlp(nn.Module):
    """Two-layer feed-forward network computing ``fc2(act(fc1(x)))``.

    ``fc1`` widens ``dim`` to ``hidden_dim``, the activation is looked up by
    name in ``ACT2FN``, and ``fc2`` projects back to ``dim``.
    """

    def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = ACT2FN[hidden_act]
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x) -> torch.Tensor:
        widened = self.fc1(x)
        activated = self.act(widened)
        return self.fc2(activated)
class VisionAttention(nn.Module):
    """Eager (pure PyTorch) multi-head self-attention over packed vision tokens.

    Tokens from several images are concatenated along dim 0 of
    ``hidden_states``. ``cu_seqlens`` marks the segment boundaries, and an
    additive block-diagonal mask keeps each token attending only within its
    own segment.
    """

    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend within each segment.

        Args:
            hidden_states: ``(seq_length, dim)`` packed tokens.
            cu_seqlens: 1-D cumulative boundaries. Positions
                ``cu_seqlens[i-1]:cu_seqlens[i]`` form segment ``i``.
            rotary_pos_emb: Angles passed to ``apply_rotary_pos_emb_vision``
                for ``q`` and ``k``. ``None`` (the declared default) skips the
                rotary step; previously it crashed inside ``freqs.cos()``.

        Returns:
            ``(seq_length, dim)`` output after the output projection.
        """
        seq_length = hidden_states.shape[0]
        # (seq, 3*dim) -> (3, seq, heads, head_dim) -> three (seq, heads, head_dim) tensors
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        if rotary_pos_emb is not None:
            q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
            k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

        # Additive mask: dtype-min everywhere except the per-segment diagonal blocks (0).
        attention_mask = torch.full(
            [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
        )
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

        # (seq, heads, head_dim) -> (heads, seq, head_dim)
        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)
        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
        attn_weights = attn_weights + attention_mask
        # Softmax in float32 for stability, then back to the working dtype.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(0, 1).reshape(seq_length, -1)
        return self.proj(attn_output)
class VisionFlashAttention2(nn.Module):
    """Multi-head self-attention over packed vision tokens via flash-attn's varlen kernel.

    ``cu_seqlens`` is passed to ``flash_attn_varlen_func`` as both the query
    and the key boundaries, so attention is restricted to each segment without
    building a dense mask.
    """

    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend within each segment using flash-attn.

        Args:
            hidden_states: ``(seq_length, dim)`` packed tokens.
            cu_seqlens: 1-D cumulative segment boundaries, forwarded to
                flash-attn unchanged.
            rotary_pos_emb: Angles passed to ``apply_rotary_pos_emb_vision``
                for ``q`` and ``k``. ``None`` skips the rotary step.

        Returns:
            ``(seq_length, dim)`` output after the output projection.

        Raises:
            ImportError: flash-attn is unavailable (``flash_attn_varlen_func``
                was bound to ``None`` at import time).
        """
        if flash_attn_varlen_func is None:
            # Without this check the call below fails with an opaque
            # "'NoneType' object is not callable".
            raise ImportError(
                "VisionFlashAttention2 requires `flash_attn_varlen_func` from the `flash_attn` "
                "package, which is not available in this environment."
            )
        seq_length = hidden_states.shape[0]
        # q, k, v: (seq, heads, head_dim) each
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        if rotary_pos_emb is not None:
            q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
            k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
        # Longest single segment, converted to a Python number with .item().
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
            seq_length, -1
        )
        return self.proj(attn_output)
class VisionSdpaAttention(nn.Module):
    """Multi-head self-attention over packed vision tokens via ``F.scaled_dot_product_attention``.

    A boolean block-diagonal mask built from ``cu_seqlens`` (``True`` = may
    attend) keeps each token attending only within its own segment.
    """

    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend within each segment.

        Args:
            hidden_states: ``(seq_length, dim)`` packed tokens.
            cu_seqlens: 1-D cumulative boundaries. Positions
                ``cu_seqlens[i-1]:cu_seqlens[i]`` form segment ``i``.
            rotary_pos_emb: Angles passed to ``apply_rotary_pos_emb_vision``
                for ``q`` and ``k``. ``None`` (the declared default) skips the
                rotary step; previously it crashed inside ``freqs.cos()``.

        Returns:
            ``(seq_length, dim)`` output after the output projection.
        """
        seq_length = hidden_states.shape[0]
        # q, k, v: (seq, heads, head_dim) each
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        if rotary_pos_emb is not None:
            q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
            k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

        attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True

        # (seq, heads, head_dim) -> (heads, seq, head_dim); the mask broadcasts over heads.
        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)
        attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
        attn_output = attn_output.transpose(0, 1).reshape(seq_length, -1)
        return self.proj(attn_output)
# Attention backend registry: Qwen2VLVisionBlock looks up its attention module
# here by the implementation name it receives (the model passes
# ``config._attn_implementation``).
QWEN2_VL_VISION_ATTENTION_CLASSES = dict(
    eager=VisionAttention,
    flash_attention_2=VisionFlashAttention2,
    sdpa=VisionSdpaAttention,
)
class Qwen2VLVisionBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-layers, each with a residual add.

    The attention backend is chosen by name from
    ``QWEN2_VL_VISION_ATTENTION_CLASSES``. The MLP width is
    ``int(embed_dim * mlp_ratio)``.
    """

    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
        super().__init__()
        width = config.embed_dim
        self.norm1 = LayerNorm(width, eps=1e-6)
        self.norm2 = LayerNorm(width, eps=1e-6)
        mlp_width = int(width * config.mlp_ratio)
        attention_cls = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation]
        self.attn = attention_cls(width, num_heads=config.num_heads)
        self.mlp = VisionMlp(dim=width, hidden_dim=mlp_width, hidden_act=config.hidden_act)

    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
        attended = self.attn(self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
        hidden_states = hidden_states + attended
        return hidden_states + self.mlp(self.norm2(hidden_states))
class Qwen2VLPreTrainedModel(PreTrainedModel):
    """Base class binding ``Qwen2VLVisionConfig`` to the HF ``PreTrainedModel`` machinery.

    It declares the capability flags read by the HF base class and the
    per-module weight initialisation.
    """

    config_class = Qwen2VLVisionConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen2VLVisionBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_static_cache = True

    def _init_weights(self, module):
        """Initialise one module's weights.

        Linear/Conv3d/Embedding weights are drawn from
        ``normal(0, config.initializer_range)``. Biases and the padding-row
        embedding are zeroed. Other module types are left untouched.
        """
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Conv3d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
            return
        if not isinstance(module, nn.Embedding):
            return
        module.weight.data.normal_(mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
    """Vision transformer encoder: patch embedding, 2-D rotary positions, and a stack of blocks.

    ``spatial_merge_size`` is stored, but no PatchMerger is applied in this
    class. ``forward`` returns one row per patch.
    """

    config_class = Qwen2VLVisionConfig
    _no_split_modules = ["Qwen2VLVisionBlock"]

    def __init__(self, config) -> None:
        super().__init__(config)
        self.spatial_merge_size = config.spatial_merge_size
        self.gradient_checkpointing = False
        self.patch_embed = PatchEmbed(
            patch_size=config.patch_size,
            temporal_patch_size=config.temporal_patch_size,
            in_channels=config.in_channels,
            embed_dim=config.embed_dim,
        )
        head_dim = config.embed_dim // config.num_heads
        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
        self.blocks = nn.ModuleList(
            Qwen2VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)
        )
        # The PatchMerger (self.merger) is intentionally not created here.

    def get_dtype(self) -> torch.dtype:
        """Dtype of the encoder's parameters, read from the first block's ``mlp.fc2`` weight."""
        reference_weight = self.blocks[0].mlp.fc2.weight
        return reference_weight.dtype

    def get_device(self) -> torch.device:
        """Device of the encoder's parameters, read from the first block's ``mlp.fc2`` weight."""
        reference_weight = self.blocks[0].mlp.fc2.weight
        return reference_weight.device

    def rot_pos_emb(self, grid_thw, strides):
        """Build per-patch rotary angles for a batch of images/videos.

        Args:
            grid_thw: Per-image ``(t, h, w)`` patch-grid sizes. It is indexed as
                ``grid_thw[:, 1:]``, so it is a 2-D tensor with 3 columns.
            strides: Per-image window size. ``h`` and ``w`` must be divisible
                by it for the reshapes below.

        Returns:
            One row per patch (``t * h * w`` rows per image, images in order).
            Each row is the row-position angles followed by the column-position
            angles from ``self.rotary_pos_emb``. Within each frame, patches are
            enumerated window by window (each ``stride x stride`` window is
            contiguous), and the frame order is repeated ``t`` times.
        """

        def window_major(ids, h, w, stride):
            # (h, w) -> (h/s, s, w/s, s) -> (h/s, w/s, s, s) -> flat: window-contiguous order.
            return ids.reshape(h // stride, stride, w // stride, stride).permute(0, 2, 1, 3).flatten()

        per_image = []
        for (t, h, w), stride in zip(grid_thw, strides):
            rows = torch.arange(h).unsqueeze(1).expand(-1, w)
            cols = torch.arange(w).unsqueeze(0).expand(h, -1)
            coords = torch.stack([window_major(rows, h, w, stride), window_major(cols, h, w, stride)], dim=-1)
            per_image.append(coords.repeat(t, 1))
        pos_ids = torch.cat(per_image, dim=0)

        angle_table = self.rotary_pos_emb(grid_thw[:, 1:].max())
        return angle_table[pos_ids].flatten(1)

    def forward(self, hidden_states, grid_thws, strides) -> torch.Tensor:
        """Encode patches.

        ``grid_thws`` is expected to be already concatenated to
        ``[num_images, 3]`` (done by the caller in ``encoder.py``). Rotary
        angles are built in a single batched call. A per-image loop here
        previously triggered the DeepSpeed error "RuntimeError: disagreement
        between rank0 and rankx".
        """
        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thws, strides)

        # One segment per frame: h*w patches, repeated t times per image; prefix 0.
        patches_per_frame = grid_thws[:, 1] * grid_thws[:, 2]
        cu_seqlens = torch.repeat_interleave(patches_per_frame, grid_thws[:, 0]).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        use_checkpointing = self.gradient_checkpointing and self.training
        for block in self.blocks:
            if use_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    block.__call__, hidden_states, cu_seqlens, rotary_pos_emb
                )
            else:
                hidden_states = block(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
        return hidden_states