NyxKrage committed (verified)
Commit a35d75f · Parent: eb7b99d

Upload folder using huggingface_hub

Files changed (1): modeling.py (+100 -5)
modeling.py CHANGED
@@ -17,6 +17,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ import math
from typing import Optional, Tuple, Union

import torch
@@ -24,6 +25,7 @@ from torch import nn

from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
+ from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer # type: ignore for some reason transformers doesn't have an __ALL__ in the modeling_layers.py file
@@ -31,11 +33,12 @@ from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
+ from transformers.modeling_rope_utils import dynamic_rope_update
from transformers.modeling_utils import PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging

- from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaRotaryEmbedding, apply_rotary_pos_emb, LlamaMLP
+ from transformers.models.llama.modeling_llama import LlamaRMSNorm, apply_rotary_pos_emb, LlamaMLP

if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask
@@ -46,6 +49,98 @@ from .config import LlamaMlaConfig

logger = logging.get_logger(__name__)

+ def _compute_llama_mla_parameters(
+     config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, head_dim: int = None, **rope_kwargs
+ ) -> tuple["torch.Tensor", float]:
+     """
+     Computes the inverse frequencies for llama 3.1.
+
+     Args:
+         config ([`~transformers.PretrainedConfig`]):
+             The model configuration.
+         device (`torch.device`):
+             The device to use for initialization of the inverse frequencies.
+         seq_len (`int`, *optional*):
+             The current sequence length. Unused for this type of RoPE.
+         rope_kwargs (`Dict`, *optional*):
+             BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
+     Returns:
+         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+         post-processing scaling factor applied to the computed cos/sin.
+     """
+     # Gets the default RoPE parameters
+     if config is not None and len(rope_kwargs) > 0:
+         raise ValueError(
+             "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
+             f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
+         )
+     if len(rope_kwargs) > 0:
+         base = rope_kwargs["base"]
+         dim = rope_kwargs["dim"]
+     elif config is not None:
+         base = config.rope_theta
+         partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+         head_dim = head_dim or getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+         dim = int(head_dim * partial_rotary_factor)
+
+     attention_factor = 1.0 # Unused in this type of RoPE
+
+     # Compute the inverse frequencies
+     inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
+
+     factor = config.rope_scaling["factor"] # `8` in the original implementation
+     low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation
+     high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation
+     old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation
+
+     low_freq_wavelen = old_context_len / low_freq_factor
+     high_freq_wavelen = old_context_len / high_freq_factor
+
+     wavelen = 2 * math.pi / inv_freq
+     # wavelen < high_freq_wavelen: do nothing
+     # wavelen > low_freq_wavelen: divide by factor
+     inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
+     # otherwise: interpolate between the two, using a smooth factor
+     smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+     smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
+     is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
+     inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
+
+     return inv_freq_llama, attention_factor
+
+ class LlamaMlaRotaryEmbedding(nn.Module):
+     def __init__(self, config: LlamaMlaConfig, device=None, head_dim: int = None):
+         super().__init__()
+         # BC: "rope_type" was originally "type"
+         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+             self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+         else:
+             self.rope_type = "default"
+         self.max_seq_len_cached = config.max_position_embeddings
+         self.original_max_seq_len = config.max_position_embeddings
+
+         self.config = config
+         self.rope_init_fn = _compute_llama_mla_parameters
+
+         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, head_dim=head_dim)
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+         self.original_inv_freq = self.inv_freq
+
+     @torch.no_grad()
+     @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+     def forward(self, x, position_ids):
+         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+         position_ids_expanded = position_ids[:, None, :].float()
+
+         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+         with torch.autocast(device_type=device_type, enabled=False): # Force float32
+             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+             emb = torch.cat((freqs, freqs), dim=-1)
+             cos = emb.cos() * self.attention_scaling
+             sin = emb.sin() * self.attention_scaling
+
+         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
class LlamaMlaAttention(nn.Module):
    """Multi-headed Latent attention from 'DeepSeek-V2'"""

@@ -107,7 +202,7 @@ class LlamaMlaAttention(nn.Module):
            bias=config.attention_bias,
        )

-         self.rotary_emb = LlamaRotaryEmbedding(config=config)
+         self.rotary_emb = LlamaMlaRotaryEmbedding(config=config, head_dim=self.qk_rope_head_dim)

        self.softmax_scale = self.q_head_dim ** (-0.5)

@@ -166,7 +261,7 @@ class LlamaMlaAttention(nn.Module):
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+         cos, sin = self.rotary_emb(value_states, position_ids=position_ids)

        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

@@ -224,7 +319,7 @@ class LlamaMlaAttention(nn.Module):
        if not output_attentions:
            attn_weights = None

-         return attn_output, attn_weights, past_key_value
+         return attn_output, attn_weights


class LlamaMlaDecoderLayer(GradientCheckpointingLayer):
@@ -321,7 +416,7 @@ class LlamaMlaModel(LlamaMlaPreTrainedModel):
            [LlamaMlaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-         self.rotary_emb = LlamaRotaryEmbedding(config=config)
+         self.rotary_emb = LlamaMlaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
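
Note on the added `_compute_llama_mla_parameters`: the remapping it applies is the standard Llama 3.1 rule, restated here for reference (notation mine, not from the commit). With inverse frequency f_i and wavelength lambda_i = 2*pi / f_i, high-frequency components are left untouched, low-frequency components are divided by `factor`, and the band in between is smoothly interpolated:

f_i' =
\begin{cases}
  f_i, & \lambda_i < L_{\text{old}} / \text{high\_freq\_factor} \\
  f_i / \text{factor}, & \lambda_i > L_{\text{old}} / \text{low\_freq\_factor} \\
  \left( \frac{1 - s_i}{\text{factor}} + s_i \right) f_i, & \text{otherwise}
\end{cases}
\qquad
s_i = \frac{L_{\text{old}} / \lambda_i - \text{low\_freq\_factor}}{\text{high\_freq\_factor} - \text{low\_freq\_factor}}

where L_old is `original_max_position_embeddings` from `config.rope_scaling`.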
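
For a quick sanity check of the new rotary embedding, here is a minimal, hypothetical smoke test. It is not part of the commit: the `SimpleNamespace` stand-in and its field values are assumptions chosen only to satisfy the attributes that `_compute_llama_mla_parameters` and the constructor read (`rope_theta`, `rope_scaling`, `max_position_embeddings`, `hidden_size`, `num_attention_heads`), and it assumes `LlamaMlaRotaryEmbedding` is already in scope (e.g. the repo loaded with `trust_remote_code=True`) rather than built from the real `LlamaMlaConfig`.

# Hypothetical smoke test for the LlamaMlaRotaryEmbedding introduced above (not part of the commit).
from types import SimpleNamespace

import torch

# Minimal config stand-in: only the fields read by _compute_llama_mla_parameters
# and LlamaMlaRotaryEmbedding.__init__ are provided; the values are illustrative.
cfg = SimpleNamespace(
    rope_theta=500000.0,
    max_position_embeddings=131072,
    hidden_size=4096,
    num_attention_heads=32,
    rope_scaling={
        "rope_type": "llama3",
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
    },
)

# head_dim is passed explicitly, mirroring the attention module's
# LlamaMlaRotaryEmbedding(config=config, head_dim=self.qk_rope_head_dim).
rope = LlamaMlaRotaryEmbedding(config=cfg, head_dim=64)

batch, seq_len = 2, 16
# forward() only uses x for its device and dtype, so any float tensor works here.
x = torch.randn(batch, 32, seq_len, 64)
position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch, -1)

cos, sin = rope(x, position_ids=position_ids)
print(cos.shape, sin.shape)  # torch.Size([2, 16, 64]) torch.Size([2, 16, 64])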