Upload folder using huggingface_hub
modeling.py  CHANGED  (+100, -5)
@@ -17,6 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
 from typing import Optional, Tuple, Union
 
 import torch
@@ -24,6 +25,7 @@ from torch import nn
 
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.generation import GenerationMixin
+from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 from transformers.modeling_layers import GradientCheckpointingLayer  # type: ignore for some reason transformers doesn't have an __ALL__ in the modeling_layers.py file
@@ -31,11 +33,12 @@ from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
 )
+from transformers.modeling_rope_utils import dynamic_rope_update
 from transformers.modeling_utils import PreTrainedModel
 from transformers.processing_utils import Unpack
 from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
 
-from transformers.models.llama.modeling_llama import LlamaRMSNorm,
+from transformers.models.llama.modeling_llama import LlamaRMSNorm, apply_rotary_pos_emb, LlamaMLP
 
 if is_torch_flex_attn_available():
     from torch.nn.attention.flex_attention import BlockMask
@@ -46,6 +49,98 @@ from .config import LlamaMlaConfig
 
 logger = logging.get_logger(__name__)
 
+def _compute_llama_mla_parameters(
+    config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, head_dim: int = None, **rope_kwargs
+) -> tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies for llama 3.1.
+
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+        rope_kwargs (`Dict`, *optional*):
+            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin.
+    """
+    # Gets the default RoPE parameters
+    if config is not None and len(rope_kwargs) > 0:
+        raise ValueError(
+            "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
+            f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
+        )
+    if len(rope_kwargs) > 0:
+        base = rope_kwargs["base"]
+        dim = rope_kwargs["dim"]
+    elif config is not None:
+        base = config.rope_theta
+        partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+        head_dim = head_dim or getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+        dim = int(head_dim * partial_rotary_factor)
+
+    attention_factor = 1.0  # Unused in this type of RoPE
+
+    # Compute the inverse frequencies
+    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
+
+    factor = config.rope_scaling["factor"]  # `8` in the original implementation
+    low_freq_factor = config.rope_scaling["low_freq_factor"]  # `1` in the original implementation
+    high_freq_factor = config.rope_scaling["high_freq_factor"]  # `4` in the original implementation
+    old_context_len = config.rope_scaling["original_max_position_embeddings"]  # `8192` in the original implementation
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+
+    wavelen = 2 * math.pi / inv_freq
+    # wavelen < high_freq_wavelen: do nothing
+    # wavelen > low_freq_wavelen: divide by factor
+    inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
+    # otherwise: interpolate between the two, using a smooth factor
+    smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+    smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
+    is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
+    inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
+
+    return inv_freq_llama, attention_factor
+
+class LlamaMlaRotaryEmbedding(nn.Module):
+    def __init__(self, config: LlamaMlaConfig, device=None, head_dim: int = None):
+        super().__init__()
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = _compute_llama_mla_parameters
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, head_dim=head_dim)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+    @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+    def forward(self, x, position_ids):
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
 class LlamaMlaAttention(nn.Module):
     """Multi-headed Latent attention from 'DeepSeek-V2'"""
 
@@ -107,7 +202,7 @@ class LlamaMlaAttention(nn.Module):
             bias=config.attention_bias,
         )
 
-        self.rotary_emb =
+        self.rotary_emb = LlamaMlaRotaryEmbedding(config=config, head_dim=self.qk_rope_head_dim)
 
         self.softmax_scale = self.q_head_dim ** (-0.5)
 
@@ -166,7 +261,7 @@ class LlamaMlaAttention(nn.Module):
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(value_states,
+        cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
 
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
 
@@ -224,7 +319,7 @@ class LlamaMlaAttention(nn.Module):
        if not output_attentions:
            attn_weights = None
 
-        return attn_output, attn_weights
+        return attn_output, attn_weights
 
 
 class LlamaMlaDecoderLayer(GradientCheckpointingLayer):
@@ -321,7 +416,7 @@ class LlamaMlaModel(LlamaMlaPreTrainedModel):
            [LlamaMlaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.rotary_emb =
+        self.rotary_emb = LlamaMlaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
 
        # Initialize weights and apply final processing