Upload folder using huggingface_hub
- config.json +1 -1
- configuration_qwen2_vl.py +7 -2
- modeling_qwen2_vl.py +237 -73
config.json
CHANGED
@@ -16,5 +16,5 @@
   "spatial_patch_size": 14,
   "temporal_patch_size": 2,
   "torch_dtype": "bfloat16",
-  "transformers_version": "4.
+  "transformers_version": "4.49.0"
 }
configuration_qwen2_vl.py
CHANGED
@@ -173,6 +173,11 @@ class Qwen2VLConfig(PretrainedConfig):
         "layers.*.mlp.up_proj": "colwise",
         "layers.*.mlp.down_proj": "rowwise",
     }
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }

     def __init__(
         self,
@@ -198,9 +203,9 @@ class Qwen2VLConfig(PretrainedConfig):
         **kwargs,
     ):
         if isinstance(vision_config, dict):
-            self.vision_config =
+            self.vision_config = self.sub_configs["vision_config"](**vision_config)
         elif vision_config is None:
-            self.vision_config =
+            self.vision_config = self.sub_configs["vision_config"]()

         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
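Note: both `sub_configs["vision_config"]` branches above resolve to the same nested config class. A minimal sketch of the resulting behaviour, assuming a transformers install that ships `Qwen2VLConfig` (config.json above records 4.49.0); the `depth` override is purely illustrative.

```python
from transformers import Qwen2VLConfig

# vision_config omitted: the vision sub-config is instantiated with its own defaults.
cfg_default = Qwen2VLConfig()
print(type(cfg_default.vision_config).__name__)

# vision_config passed as a dict: the same sub-config class is built from the given keys.
cfg_custom = Qwen2VLConfig(vision_config={"depth": 2})
print(cfg_custom.vision_config.depth)  # 2
```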
modeling_qwen2_vl.py
CHANGED
@@ -21,7 +21,7 @@

 import math
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -41,6 +41,7 @@ from transformers.utils import (
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
+    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
 )
@@ -100,47 +101,20 @@ class Qwen2VLCausalLMOutputWithPast(ModelOutput):


 class Qwen2VLRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        dim=None,
-        max_position_embeddings=2048,
-        base=10000,
-        device=None,
-        scaling_factor=1.0,
-        rope_type="default",
-        config: Optional[Qwen2VLConfig] = None,
-    ):
+    def __init__(self, config: Qwen2VLConfig, device=None):
         super().__init__()
-        #
-
-
-            logger.warning_once(
-                "`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the "
-                "`config` argument. All other arguments will be removed in v4.46"
-            )
-            self.rope_kwargs = {
-                "rope_type": rope_type,
-                "factor": scaling_factor,
-                "dim": dim,
-                "base": base,
-                "max_position_embeddings": max_position_embeddings,
-            }
-            self.rope_type = rope_type
-            self.max_seq_len_cached = max_position_embeddings
-            self.original_max_seq_len = max_position_embeddings
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
         else:
-
-
-
-            else:
-                self.rope_type = "default"
-            self.max_seq_len_cached = config.max_position_embeddings
-            self.original_max_seq_len = config.max_position_embeddings
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings

         self.config = config
         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq

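Note: the rotary-embedding hunk above drops the per-layer `dim`/`base`/`max_position_embeddings` arguments and lets `ROPE_INIT_FUNCTIONS[self.rope_type]` derive everything from the config. A rough, self-contained sketch of what the `"default"` initializer computes; `rope_theta` and `head_dim` below are illustrative stand-ins, not values read from this checkpoint.

```python
import torch

rope_theta = 1000000.0  # stand-in for config.rope_theta
head_dim = 128          # stand-in for config.hidden_size // config.num_attention_heads

# One inverse frequency per pair of channels -- the quantity stored in the `inv_freq` buffer.
inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
attention_scaling = 1.0  # the "default" rope type applies no extra attention scaling

print(inv_freq.shape)  # torch.Size([64])
```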
@@ -240,16 +214,18 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
     return q_embed, k_embed


-def apply_rotary_pos_emb_vision(
-
-
-
-
-
-    sin =
-
-
-
+def apply_rotary_pos_emb_vision(
+    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype = q.dtype
+    orig_k_dtype = k.dtype
+    q, k = q.float(), k.float()
+    cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    q_embed = q_embed.to(orig_q_dtype)
+    k_embed = k_embed.to(orig_k_dtype)
+    return q_embed, k_embed


 class VisionRotaryEmbedding(nn.Module):
@@ -326,12 +302,27 @@ class VisionAttention(nn.Module):
         self.proj = nn.Linear(dim, dim)

     def forward(
-        self,
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-
-
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos().float()
+            sin = emb.sin().float()
+        else:
+            cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

         attention_mask = torch.full(
             [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
@@ -360,12 +351,27 @@ class VisionFlashAttention2(nn.Module):
         self.proj = nn.Linear(dim, dim)

     def forward(
-        self,
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-
-
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos().float()
+            sin = emb.sin().float()
+        else:
+            cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
@@ -383,12 +389,27 @@ class VisionSdpaAttention(nn.Module):
         self.proj = nn.Linear(dim, dim)

     def forward(
-        self,
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-
-
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos().float()
+            sin = emb.sin().float()
+        else:
+            cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

         attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
         for i in range(1, len(cu_seqlens)):
@@ -422,9 +443,18 @@ class Qwen2VLVisionBlock(nn.Module):
         )
         self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)

-    def forward(
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
-            self.norm1(hidden_states),
+            self.norm1(hidden_states),
+            cu_seqlens=cu_seqlens,
+            rotary_pos_emb=rotary_pos_emb,
+            position_embeddings=position_embeddings,
         )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
         return hidden_states
@@ -503,8 +533,6 @@ class Qwen2VLAttention(nn.Module):
         self.head_dim = self.hidden_size // self.num_heads
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-        self.max_position_embeddings = config.max_position_embeddings
-        self.rope_theta = config.rope_theta
         self.is_causal = True
         self.attention_dropout = config.attention_dropout
         self.rope_scaling = config.rope_scaling
@@ -519,11 +547,7 @@ class Qwen2VLAttention(nn.Module):
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

-        self.rotary_emb = Qwen2VLRotaryEmbedding(
-            self.head_dim,
-            max_position_embeddings=self.max_position_embeddings,
-            base=self.rope_theta,
-        )
+        self.rotary_emb = Qwen2VLRotaryEmbedding(config=config)

     def forward(
         self,
@@ -915,7 +939,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
-    _supports_static_cache =
+    _supports_static_cache = False  # TODO (joao): fix. torch.compile failing probably due to `cache_positions`

     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -993,6 +1017,8 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         hidden_states = self.patch_embed(hidden_states)
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        position_embeddings = (emb.cos(), emb.sin())

         cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
             dim=0,
@@ -1007,10 +1033,10 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
         for blk in self.blocks:
             if self.gradient_checkpointing and self.training:
                 hidden_states = self._gradient_checkpointing_func(
-                    blk.__call__, hidden_states, cu_seqlens,
+                    blk.__call__, hidden_states, cu_seqlens, None, position_embeddings
                 )
             else:
-                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens,
+                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)

         return self.merger(hidden_states)

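Note: the vision-side hunks above move the cos/sin computation out of the attention classes; the vision tower now builds `position_embeddings` once per forward and hands it to every block. A self-contained sketch of that path, mirroring `rotate_half` and the rewritten `apply_rotary_pos_emb_vision` from this file; the tensor sizes are dummy values for illustration only.

```python
import torch


def rotate_half(x):
    # Rotate half the hidden dims of the input (same helper the modeling file relies on).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_vision(q, k, cos, sin):
    # Mirrors the rewritten function above: upcast to float for the rotation, cast back after.
    orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
    q, k = q.float(), k.float()
    cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)  # broadcast over the heads dimension
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed.to(orig_q_dtype), k_embed.to(orig_k_dtype)


seq_len, num_heads, head_dim = 16, 4, 32  # dummy sizes
q = torch.randn(seq_len, num_heads, head_dim, dtype=torch.bfloat16)
k = torch.randn(seq_len, num_heads, head_dim, dtype=torch.bfloat16)

# Stand-in for `self.rot_pos_emb(grid_thw)`: per-position angles of size head_dim // 2.
rotary_pos_emb = torch.randn(seq_len, head_dim // 2)

# What the vision tower's forward now precomputes and passes to every block.
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())

q_rot, k_rot = apply_rotary_pos_emb_vision(q, k, *position_embeddings)
print(q_rot.shape, q_rot.dtype)  # torch.Size([16, 4, 32]) torch.bfloat16
```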
|
@@ -1234,7 +1260,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
         if (
             self.config._attn_implementation == "sdpa"
             and attention_mask is not None
-            and attention_mask.device.type
+            and attention_mask.device.type in ["cuda", "xpu"]
             and not output_attentions
         ):
             # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
@@ -1305,7 +1331,9 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
             if attention_mask.shape[-1] > target_length:
                 attention_mask = attention_mask[:, :target_length]
             mask_length = attention_mask.shape[-1]
-            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                causal_mask.device
+            )
             padding_mask = padding_mask == 0
             causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                 padding_mask, min_dtype
@@ -1486,7 +1514,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
             )
             image_index, video_index = 0, 0
             for i, input_ids in enumerate(total_input_ids):
-                input_ids = input_ids[attention_mask[i] == 1]
+                input_ids = input_ids[attention_mask[i].to(input_ids.device) == 1]
                 image_nums, video_nums = 0, 0
                 vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
                 vision_tokens = input_ids[vision_start_indices + 1]
@@ -1681,7 +1709,11 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
         # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
         if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
             # calculate RoPE index once per generation in the pre-fill stage only
-            if (
+            if (
+                (cache_position is not None and cache_position[0] == 0)
+                or self.rope_deltas is None
+                or (past_key_values is None or past_key_values.get_seq_length() == 0)
+            ):
                 position_ids, rope_deltas = self.get_rope_index(
                     input_ids, image_grid_thw, video_grid_thw, attention_mask
                 )
@@ -1694,6 +1726,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
                 position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                 if cache_position is not None:  # otherwise `deltas` is an int `0`
                     delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                    delta = delta.to(position_ids.device)
                 position_ids = position_ids.add(delta)
                 position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

@@ -1761,8 +1794,17 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
         # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
         # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
+        # (we can't check exception 3 while compiling)
+        # Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
+        # generate the first token for each sequence. Later use the generated Input ids for continuation.
         if past_key_values is not None:
-            if inputs_embeds is not None:  # Exception
+            if inputs_embeds is not None and input_ids.shape[1] == 0:  # Exception 4
+                inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
+            elif (
+                inputs_embeds is not None  # Exception 1
+                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+            ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                 input_ids = input_ids[:, cache_position]
@@ -1772,7 +1814,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
             pixel_values_videos = None

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and cache_position
+        if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
             model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
         else:
             model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
@@ -1812,5 +1854,127 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
         )
         return model_inputs

+    def _get_image_nums_and_video_nums(
+        self,
+        input_ids: Optional[torch.LongTensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
+        These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary.
+
+        Returns:
+            image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`)
+            video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
+        """
+        image_token_id = self.config.image_token_id
+        video_token_id = self.config.video_token_id
+        vision_start_token_id = self.config.vision_start_token_id
+
+        vision_start_mask = input_ids == vision_start_token_id
+        vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
+        image_mask = input_ids == image_token_id
+        video_mask = input_ids == video_token_id
+        image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
+        video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
+
+        return image_nums, video_nums
+
+    def _expand_inputs_for_generation(
+        self,
+        expand_size: int = 1,
+        is_encoder_decoder: bool = False,
+        input_ids: Optional[torch.LongTensor] = None,
+        **model_kwargs,
+    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
+        # Overwritten -- Support for expanding tensors without a batch size dimension
+        # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
+        # pixel_values.shape[0] is sum(seqlen_images for samples)
+        # image_grid_thw.shape[0] is sum(num_images for samples)
+
+        if expand_size == 1:
+            return input_ids, model_kwargs
+
+        visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
+
+        def _expand_dict_for_generation_visual(dict_to_expand):
+            image_grid_thw = model_kwargs.get("image_grid_thw", None)
+            video_grid_thw = model_kwargs.get("video_grid_thw", None)
+            image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids)
+
+            def _repeat_interleave_samples(x, lengths, repeat_times):
+                samples = torch.split(x, lengths)
+                repeat_args = [repeat_times] + [1] * (x.dim() - 1)
+                result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
+                return result
+
+            for key in dict_to_expand:
+                if key == "pixel_values":
+                    # split images into samples
+                    samples = torch.split(image_grid_thw, list(image_nums))
+                    # compute the sequence length of images for each sample
+                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+                elif key == "image_grid_thw":
+                    # get the num of images for each sample
+                    lengths = list(image_nums)
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+                elif key == "pixel_values_videos":
+                    samples = torch.split(video_grid_thw, list(video_nums))
+                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+                elif key == "video_grid_thw":
+                    lengths = list(video_nums)
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+                elif key == "second_per_grid_ts":
+                    if not isinstance(dict_to_expand[key], list):
+                        raise TypeError(
+                            f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead."
+                        )
+                    tensor = torch.tensor(dict_to_expand[key])
+                    lengths = list(video_nums)
+                    tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size)
+                    dict_to_expand[key] = tensor.tolist()
+            return dict_to_expand
+
+        def _expand_dict_for_generation(dict_to_expand):
+            for key in dict_to_expand:
+                if (
+                    key != "cache_position"
+                    and dict_to_expand[key] is not None
+                    and isinstance(dict_to_expand[key], torch.Tensor)
+                    and key not in visual_keys
+                ):
+                    dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+            return dict_to_expand
+
+        # input_ids is required for expanding visual inputs
+        # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs.
+        if input_ids is not None and input_ids.numel() != 0:
+            model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
+
+        if input_ids is not None:
+            input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+        model_kwargs = _expand_dict_for_generation(model_kwargs)
+
+        if is_encoder_decoder:
+            if model_kwargs.get("encoder_outputs") is None:
+                raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
+            model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
+
+        return input_ids, model_kwargs
+

 __all__ = ["Qwen2VLForConditionalGeneration", "Qwen2VLModel", "Qwen2VLPreTrainedModel"]
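Note: the new `_expand_inputs_for_generation` override exists because visual inputs are packed along the sequence dimension rather than the batch dimension, so `generate(..., num_return_sequences=N)` cannot simply `repeat_interleave` them on dim 0. A small self-contained sketch of the per-sample repetition it performs, using dummy `image_grid_thw` / `pixel_values` shapes (illustrative only).

```python
import torch


def repeat_interleave_samples(x, lengths, repeat_times):
    # Same idea as the nested `_repeat_interleave_samples` above: split the packed tensor
    # into per-sample chunks, repeat each chunk, then re-concatenate.
    samples = torch.split(x, lengths)
    repeat_args = [repeat_times] + [1] * (x.dim() - 1)
    return torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)


# Two samples with one image each; each grid (t, h, w) decides how many patch rows the
# image contributes to the packed `pixel_values` tensor.
image_grid_thw = torch.tensor([[1, 4, 4], [1, 2, 2]])
image_nums = [1, 1]                       # images per sample (what _get_image_nums_and_video_nums yields)
pixel_values = torch.randn(16 + 4, 1176)  # packed along dim 0: 4*4 + 2*2 patch rows

expand_size = 2  # e.g. num_return_sequences=2

# image_grid_thw is repeated per image count of each sample ...
expanded_grid = repeat_interleave_samples(image_grid_thw, lengths=image_nums, repeat_times=expand_size)
# ... while pixel_values is repeated per total patch length of each sample.
lengths = [torch.prod(s, dim=1).sum().item() for s in torch.split(image_grid_thw, image_nums)]
expanded_pixels = repeat_interleave_samples(pixel_values, lengths=lengths, repeat_times=expand_size)

print(expanded_grid.shape)    # torch.Size([4, 3])
print(expanded_pixels.shape)  # torch.Size([40, 1176])
```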