shilinxu committed
Commit b5f73a9 · verified · 1 Parent(s): d990e8e

Upload folder using huggingface_hub

Files changed (3)
  1. config.json +1 -1
  2. configuration_qwen2_vl.py +7 -2
  3. modeling_qwen2_vl.py +237 -73
config.json CHANGED
@@ -16,5 +16,5 @@
   "spatial_patch_size": 14,
   "temporal_patch_size": 2,
   "torch_dtype": "bfloat16",
-  "transformers_version": "4.48.3"
+  "transformers_version": "4.49.0"
 }
configuration_qwen2_vl.py CHANGED
@@ -173,6 +173,11 @@ class Qwen2VLConfig(PretrainedConfig):
         "layers.*.mlp.up_proj": "colwise",
         "layers.*.mlp.down_proj": "rowwise",
     }
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
 
     def __init__(
         self,
@@ -198,9 +203,9 @@ class Qwen2VLConfig(PretrainedConfig):
         **kwargs,
     ):
         if isinstance(vision_config, dict):
-            self.vision_config = Qwen2VLVisionConfig(**vision_config)
+            self.vision_config = self.sub_configs["vision_config"](**vision_config)
         elif vision_config is None:
-            self.vision_config = Qwen2VLVisionConfig()
+            self.vision_config = self.sub_configs["vision_config"]()
 
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
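Note: the `self.sub_configs["vision_config"]` lookup above relies on the `sub_configs` class attribute that newer `PretrainedConfig` subclasses declare. A minimal, self-contained sketch of the pattern; the class bodies and default values below are illustrative assumptions, not copied from this repository:

# Sketch of the sub_configs indirection: resolve the sub-config class by name
# instead of hard-coding it in __init__. Defaults here are made up for the demo.
class Qwen2VLVisionConfig:
    def __init__(self, depth=32, embed_dim=1280, **kwargs):
        self.depth = depth
        self.embed_dim = embed_dim

class Qwen2VLConfig:
    # maps sub-config name -> sub-config class (assumed attribute, as used in the diff)
    sub_configs = {"vision_config": Qwen2VLVisionConfig}

    def __init__(self, vision_config=None, **kwargs):
        if isinstance(vision_config, dict):
            self.vision_config = self.sub_configs["vision_config"](**vision_config)
        elif vision_config is None:
            self.vision_config = self.sub_configs["vision_config"]()

config = Qwen2VLConfig(vision_config={"depth": 16})
print(type(config.vision_config).__name__, config.vision_config.depth)  # Qwen2VLVisionConfig 16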
modeling_qwen2_vl.py CHANGED
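The largest change in this file routes vision RoPE through precomputed `(cos, sin)` position embeddings instead of applying the raw `rotary_pos_emb` table per attention layer. A hedged, standalone sketch of that calling convention, with dummy shapes and local copies of the helpers (not the model's actual module wiring):

# Sketch: precompute cos/sin once from rotary_pos_emb, then rotate q and k jointly.
# Shapes are illustrative assumptions; the helper mirrors the function added below.
import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_vision(q, k, cos, sin):
    orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
    q, k = q.float(), k.float()
    cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)  # broadcast over the heads dim
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed.to(orig_q_dtype), k_embed.to(orig_k_dtype)

seq_len, num_heads, head_dim = 16, 4, 32
rotary_pos_emb = torch.randn(seq_len, head_dim // 2)        # per-position theta values
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)   # (seq_len, head_dim)
position_embeddings = (emb.cos(), emb.sin())                # computed once per forward

q = torch.randn(seq_len, num_heads, head_dim)
k = torch.randn(seq_len, num_heads, head_dim)
q, k = apply_rotary_pos_emb_vision(q, k, *position_embeddings)
print(q.shape, k.shape)  # torch.Size([16, 4, 32]) torch.Size([16, 4, 32])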
@@ -21,7 +21,7 @@
 
 import math
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -41,6 +41,7 @@ from transformers.utils import (
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
+    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
 )
@@ -100,47 +101,20 @@ class Qwen2VLCausalLMOutputWithPast(ModelOutput):
 
 
 class Qwen2VLRotaryEmbedding(nn.Module):
-    def __init__(
-        self,
-        dim=None,
-        max_position_embeddings=2048,
-        base=10000,
-        device=None,
-        scaling_factor=1.0,
-        rope_type="default",
-        config: Optional[Qwen2VLConfig] = None,
-    ):
+    def __init__(self, config: Qwen2VLConfig, device=None):
         super().__init__()
-        # TODO (joao): remove the `if` below, only used for BC
-        self.rope_kwargs = {}
-        if config is None:
-            logger.warning_once(
-                "`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the "
-                "`config` argument. All other arguments will be removed in v4.46"
-            )
-            self.rope_kwargs = {
-                "rope_type": rope_type,
-                "factor": scaling_factor,
-                "dim": dim,
-                "base": base,
-                "max_position_embeddings": max_position_embeddings,
-            }
-            self.rope_type = rope_type
-            self.max_seq_len_cached = max_position_embeddings
-            self.original_max_seq_len = max_position_embeddings
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
         else:
-            # BC: "rope_type" was originally "type"
-            if config.rope_scaling is not None:
-                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-            else:
-                self.rope_type = "default"
-            self.max_seq_len_cached = config.max_position_embeddings
-            self.original_max_seq_len = config.max_position_embeddings
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
 
         self.config = config
         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
 
-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.original_inv_freq = self.inv_freq
 
@@ -240,16 +214,18 @@ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
     return q_embed, k_embed
 
 
-def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
-    orig_dtype = tensor.dtype
-    tensor = tensor.float()
-    cos = freqs.cos()
-    sin = freqs.sin()
-    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
-    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
-    output = (tensor * cos) + (rotate_half(tensor) * sin)
-    output = output.to(orig_dtype)
-    return output
+def apply_rotary_pos_emb_vision(
+    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype = q.dtype
+    orig_k_dtype = k.dtype
+    q, k = q.float(), k.float()
+    cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    q_embed = q_embed.to(orig_q_dtype)
+    k_embed = k_embed.to(orig_k_dtype)
+    return q_embed, k_embed
 
 
 class VisionRotaryEmbedding(nn.Module):
@@ -326,12 +302,27 @@ class VisionAttention(nn.Module):
         self.proj = nn.Linear(dim, dim)
 
     def forward(
-        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
-        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos().float()
+            sin = emb.sin().float()
+        else:
+            cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
 
         attention_mask = torch.full(
             [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
@@ -360,12 +351,27 @@ class VisionFlashAttention2(nn.Module):
         self.proj = nn.Linear(dim, dim)
 
     def forward(
-        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
-        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos().float()
+            sin = emb.sin().float()
+        else:
+            cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
 
         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
         attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
@@ -383,12 +389,27 @@ class VisionSdpaAttention(nn.Module):
         self.proj = nn.Linear(dim, dim)
 
     def forward(
-        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
-        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos().float()
+            sin = emb.sin().float()
+        else:
+            cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
 
         attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
         for i in range(1, len(cu_seqlens)):
@@ -422,9 +443,18 @@ class Qwen2VLVisionBlock(nn.Module):
         )
         self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)
 
-    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: Optional[torch.Tensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
-            self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+            self.norm1(hidden_states),
+            cu_seqlens=cu_seqlens,
+            rotary_pos_emb=rotary_pos_emb,
+            position_embeddings=position_embeddings,
         )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
         return hidden_states
@@ -503,8 +533,6 @@ class Qwen2VLAttention(nn.Module):
         self.head_dim = self.hidden_size // self.num_heads
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-        self.max_position_embeddings = config.max_position_embeddings
-        self.rope_theta = config.rope_theta
         self.is_causal = True
         self.attention_dropout = config.attention_dropout
         self.rope_scaling = config.rope_scaling
@@ -519,11 +547,7 @@ class Qwen2VLAttention(nn.Module):
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
 
-        self.rotary_emb = Qwen2VLRotaryEmbedding(
-            self.head_dim,
-            max_position_embeddings=self.max_position_embeddings,
-            base=self.rope_theta,
-        )
+        self.rotary_emb = Qwen2VLRotaryEmbedding(config=config)
 
     def forward(
         self,
@@ -915,7 +939,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
-    _supports_static_cache = True
+    _supports_static_cache = False  # TODO (joao): fix. torch.compile failing probably due to `cache_positions`
 
     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -993,6 +1017,8 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         hidden_states = self.patch_embed(hidden_states)
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+        position_embeddings = (emb.cos(), emb.sin())
 
         cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
             dim=0,
@@ -1007,10 +1033,10 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
         for blk in self.blocks:
             if self.gradient_checkpointing and self.training:
                 hidden_states = self._gradient_checkpointing_func(
-                    blk.__call__, hidden_states, cu_seqlens, rotary_pos_emb
+                    blk.__call__, hidden_states, cu_seqlens, None, position_embeddings
                 )
             else:
-                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
+                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
 
         return self.merger(hidden_states)
 
@@ -1234,7 +1260,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
         if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
-            and attention_mask.device.type == "cuda"
+            and attention_mask.device.type in ["cuda", "xpu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
@@ -1305,7 +1331,9 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
            if attention_mask.shape[-1] > target_length:
                attention_mask = attention_mask[:, :target_length]
            mask_length = attention_mask.shape[-1]
-            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                causal_mask.device
+            )
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
@@ -1486,7 +1514,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
            )
            image_index, video_index = 0, 0
            for i, input_ids in enumerate(total_input_ids):
-                input_ids = input_ids[attention_mask[i] == 1]
+                input_ids = input_ids[attention_mask[i].to(input_ids.device) == 1]
                image_nums, video_nums = 0, 0
                vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
                vision_tokens = input_ids[vision_start_indices + 1]
@@ -1681,7 +1709,11 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
            # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
            if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
                # calculate RoPE index once per generation in the pre-fill stage only
-                if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
+                if (
+                    (cache_position is not None and cache_position[0] == 0)
+                    or self.rope_deltas is None
+                    or (past_key_values is None or past_key_values.get_seq_length() == 0)
+                ):
                    position_ids, rope_deltas = self.get_rope_index(
                        input_ids, image_grid_thw, video_grid_thw, attention_mask
                    )
@@ -1694,6 +1726,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                    delta = delta.to(position_ids.device)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
 
@@ -1761,8 +1794,17 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
+        #              (we can't check exception 3 while compiling)
+        # Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
+        # generate the first token for each sequence. Later use the generated Input ids for continuation.
        if past_key_values is not None:
-            if inputs_embeds is not None:  # Exception 1
+            if inputs_embeds is not None and input_ids.shape[1] == 0:  # Exception 4
+                inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
+            elif (
+                inputs_embeds is not None  # Exception 1
+                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+            ):
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]
@@ -1772,7 +1814,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
            pixel_values_videos = None
 
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and cache_position[0] == 0:
+        if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
@@ -1812,5 +1854,127 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
        )
        return model_inputs
 
+    def _get_image_nums_and_video_nums(
+        self,
+        input_ids: Optional[torch.LongTensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
+        These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary.
+
+        Returns:
+            image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`)
+            video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
+        """
+        image_token_id = self.config.image_token_id
+        video_token_id = self.config.video_token_id
+        vision_start_token_id = self.config.vision_start_token_id
+
+        vision_start_mask = input_ids == vision_start_token_id
+        vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1)
+        image_mask = input_ids == image_token_id
+        video_mask = input_ids == video_token_id
+        image_nums = torch.sum(vision_first_mask & image_mask, dim=1)
+        video_nums = torch.sum(vision_first_mask & video_mask, dim=1)
+
+        return image_nums, video_nums
+
+    def _expand_inputs_for_generation(
+        self,
+        expand_size: int = 1,
+        is_encoder_decoder: bool = False,
+        input_ids: Optional[torch.LongTensor] = None,
+        **model_kwargs,
+    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
+        # Overwritten -- Support for expanding tensors without a batch size dimension
+        # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
+        # pixel_values.shape[0] is sum(seqlen_images for samples)
+        # image_grid_thw.shape[0] is sum(num_images for samples)
+
+        if expand_size == 1:
+            return input_ids, model_kwargs
+
+        visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
+
+        def _expand_dict_for_generation_visual(dict_to_expand):
+            image_grid_thw = model_kwargs.get("image_grid_thw", None)
+            video_grid_thw = model_kwargs.get("video_grid_thw", None)
+            image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids)
+
+            def _repeat_interleave_samples(x, lengths, repeat_times):
+                samples = torch.split(x, lengths)
+                repeat_args = [repeat_times] + [1] * (x.dim() - 1)
+                result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
+                return result
+
+            for key in dict_to_expand:
+                if key == "pixel_values":
+                    # split images into samples
+                    samples = torch.split(image_grid_thw, list(image_nums))
+                    # compute the sequence length of images for each sample
+                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+                elif key == "image_grid_thw":
+                    # get the num of images for each sample
+                    lengths = list(image_nums)
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+                elif key == "pixel_values_videos":
+                    samples = torch.split(video_grid_thw, list(video_nums))
+                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+                elif key == "video_grid_thw":
+                    lengths = list(video_nums)
+                    dict_to_expand[key] = _repeat_interleave_samples(
+                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                    )
+                elif key == "second_per_grid_ts":
+                    if not isinstance(dict_to_expand[key], list):
+                        raise TypeError(
+                            f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead."
+                        )
+                    tensor = torch.tensor(dict_to_expand[key])
+                    lengths = list(video_nums)
+                    tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size)
+                    dict_to_expand[key] = tensor.tolist()
+            return dict_to_expand
+
+        def _expand_dict_for_generation(dict_to_expand):
+            for key in dict_to_expand:
+                if (
+                    key != "cache_position"
+                    and dict_to_expand[key] is not None
+                    and isinstance(dict_to_expand[key], torch.Tensor)
+                    and key not in visual_keys
+                ):
+                    dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+            return dict_to_expand
+
+        # input_ids is required for expanding visual inputs
+        # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs.
+        if input_ids is not None and input_ids.numel() != 0:
+            model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
+
+        if input_ids is not None:
+            input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+        model_kwargs = _expand_dict_for_generation(model_kwargs)
+
+        if is_encoder_decoder:
+            if model_kwargs.get("encoder_outputs") is None:
+                raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
+            model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
+
+        return input_ids, model_kwargs
+
 
 __all__ = ["Qwen2VLForConditionalGeneration", "Qwen2VLModel", "Qwen2VLPreTrainedModel"]
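For context on the new `_expand_inputs_for_generation` override: visual tensors such as `pixel_values` carry no batch dimension, so they are split per sample and each sample's block is repeated `expand_size` times (as beam search or `num_return_sequences > 1` requires). A standalone sketch of that per-sample repeat logic, using made-up shapes rather than real processor outputs:

# Sketch of the per-sample expansion: group pixel_values rows by image_grid_thw,
# then repeat each sample's block expand_size times. Shapes here are illustrative.
import torch

def repeat_interleave_samples(x, lengths, repeat_times):
    samples = torch.split(x, lengths)
    repeat_args = [repeat_times] + [1] * (x.dim() - 1)
    return torch.cat([s.repeat(*repeat_args) for s in samples], dim=0)

# two samples, one image each: grids of 1x2x2 and 1x2x4 patches (t, h, w)
image_grid_thw = torch.tensor([[1, 2, 2], [1, 2, 4]])
image_nums = [1, 1]  # images per sample, as _get_image_nums_and_video_nums would report
pixel_values = torch.randn(int(image_grid_thw.prod(dim=1).sum()), 8)  # 4 + 8 patch rows

# per-sample patch counts, mirroring the lengths computed in the new method
samples = torch.split(image_grid_thw, image_nums)
lengths = [int(torch.prod(s, dim=1).sum()) for s in samples]

expanded = repeat_interleave_samples(pixel_values, lengths, repeat_times=2)
print(pixel_values.shape, expanded.shape)  # torch.Size([12, 8]) torch.Size([24, 8])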