NyxKrage committed
Commit eb7b99d · verified · 1 Parent(s): 2084f91

Upload folder using huggingface_hub

Files changed (5)
  1. config.json +45 -0
  2. config.py +26 -0
  3. generation_config.json +6 -0
  4. model.safetensors +3 -0
  5. modeling.py +671 -0
config.json ADDED
@@ -0,0 +1,45 @@
{
  "architectures": [
    "LlamaMlaForCausalLM"
  ],
  "auto_map": {
    "AutoModelForCausalLM": "modeling.LlamaMlaForCausalLM",
    "AutoModel": "modeling.LlamaMlaModel",
    "AutoConfig": "config.LlamaMlaConfig"
  },
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "head_dim": 64,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "kv_lora_rank": 512,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "model_type": "llama_mla",
  "num_attention_heads": 32,
  "num_hidden_layers": 16,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "q_lora_rank": 1536,
  "qk_nope_head_dim": 64,
  "qk_rope_head_dim": 32,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 32.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.52.0.dev0",
  "use_cache": true,
  "v_head_dim": 64,
  "vocab_size": 128256
}
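
Because config.json registers the custom classes through `auto_map`, the checkpoint can be loaded with the stock Auto classes once remote code is trusted. A minimal loading sketch, assuming the files in this commit live locally or under a Hub repo; `repo_id` below is a placeholder, not the actual repo name:

    from transformers import AutoConfig, AutoModelForCausalLM
    import torch

    repo_id = "path/to/this/checkpoint"  # placeholder: substitute the real local path or Hub repo id

    # trust_remote_code=True lets the auto_map entries above resolve to
    # config.LlamaMlaConfig and modeling.LlamaMlaForCausalLM shipped in this repo.
    config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        torch_dtype=torch.bfloat16,  # matches "torch_dtype": "bfloat16" above
        trust_remote_code=True,
    )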
config.py ADDED
@@ -0,0 +1,26 @@
from transformers.models.llama.configuration_llama import LlamaConfig

class LlamaMlaConfig(LlamaConfig):
    model_type = "llama_mla"
    base_model_pp_plan = None
    base_model_tp_plan = None

    def __init__(
        self,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim

__all__ = [
    "LlamaMlaConfig",
]
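
Note that the defaults above differ from the values actually stored in config.json, so the checkpoint's geometry comes from config.json. As a quick sanity check of the projection shapes those fields imply (a sketch; the variable names are just local shorthand for the config keys):

    # Projection shapes implied by config.json
    hidden_size, num_heads = 2048, 32
    q_lora_rank, kv_lora_rank = 1536, 512
    qk_nope, qk_rope, v_dim = 64, 32, 64

    q_head_dim = qk_nope + qk_rope                             # 96 dims per query head
    print(q_lora_rank, "->", num_heads * q_head_dim)           # q_b_proj: 1536 -> 3072
    print(hidden_size, "->", kv_lora_rank + qk_rope)           # kv_a_proj_with_mqa: 2048 -> 544
    print(kv_lora_rank, "->", num_heads * (qk_nope + v_dim))   # kv_b_proj: 512 -> 4096
    print(num_heads * v_dim, "->", hidden_size)                # o_proj: 2048 -> 2048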
generation_config.json ADDED
@@ -0,0 +1,6 @@
{
  "_from_model_config": true,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "transformers_version": "4.52.0.dev0"
}
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4d5f05b34b859a0faf9c69dace2faabde01155c68e5713ba3119b9e46709feb7
size 2624809232
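
The size of the LFS object lines up with the parameter count the configs imply. A rough back-of-the-envelope check (a sketch assuming bfloat16 weights and that the tied lm_head is not stored separately; the small remainder versus the file size is roughly the safetensors header):

    vocab, hidden, inter, layers = 128256, 2048, 8192, 16
    q_lora, kv_lora, qk_nope, qk_rope, v_dim, heads = 1536, 512, 64, 32, 64, 32

    embed = vocab * hidden                                 # 262,668,288 (tied with lm_head)
    attn = (
        hidden * q_lora + q_lora                           # q_a_proj + q_a_layernorm
        + q_lora * heads * (qk_nope + qk_rope)             # q_b_proj
        + hidden * (kv_lora + qk_rope) + kv_lora           # kv_a_proj_with_mqa + kv_a_layernorm
        + kv_lora * heads * (qk_nope + v_dim)              # kv_b_proj
        + heads * v_dim * hidden                           # o_proj
    )
    mlp = 3 * hidden * inter                               # gate/up/down projections
    norms = 2 * hidden                                     # per-layer RMSNorms
    total = embed + layers * (attn + mlp + norms) + hidden # + final norm
    print(total)      # ~1.31e9 parameters
    print(total * 2)  # ~2,624,786,432 bytes in bf16, close to the 2,624,809,232-byte file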
modeling.py ADDED
@@ -0,0 +1,671 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union

import torch
from torch import nn

from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer  # type: ignore  # for some reason transformers doesn't define __all__ in modeling_layers.py
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging

from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaRotaryEmbedding, apply_rotary_pos_emb, LlamaMLP

if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask

    from transformers.integrations.flex_attention import make_flex_block_causal_mask

from .config import LlamaMlaConfig

logger = logging.get_logger(__name__)

class LlamaMlaAttention(nn.Module):
    """Multi-headed latent attention (MLA) from DeepSeek-V2."""

    def __init__(self, config: LlamaMlaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim

        self.is_causal = True

        if self.q_lora_rank is None:
            self.q_proj = nn.Linear(
                self.hidden_size, self.num_heads * self.q_head_dim, bias=False
            )
        else:
            self.q_a_proj = nn.Linear(
                self.hidden_size, config.q_lora_rank, bias=config.attention_bias
            )
            self.q_a_layernorm = LlamaRMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(
                config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
            )

        self.kv_a_proj_with_mqa = nn.Linear(
            self.hidden_size,
            config.kv_lora_rank + config.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = LlamaRMSNorm(config.kv_lora_rank)
        self.kv_b_proj = nn.Linear(
            config.kv_lora_rank,
            self.num_heads
            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )

        self.rotary_emb = LlamaRotaryEmbedding(config=config)

        self.softmax_scale = self.q_head_dim ** (-0.5)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        if "padding_mask" in kwargs:
            logger.warning_once(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead."
            )
        bsz, q_len, _ = hidden_states.size()

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (
            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            .transpose(1, 2)
        )

        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )
        kv_seq_len = value_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        # LlamaRotaryEmbedding expects position ids rather than a sequence length
        cos, sin = self.rotary_emb(value_states, position_ids)

        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        attn_weights = (
            torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
        )

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )
        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        # the cache is updated in place, so only the output and (optional) weights are returned
        return attn_output, attn_weights

class LlamaMlaDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: LlamaMlaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = LlamaMlaAttention(config=config, layer_idx=layer_idx)

        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs

@auto_docstring
class LlamaMlaPreTrainedModel(PreTrainedModel):
    config_class = LlamaMlaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaMlaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, LlamaRMSNorm):
            module.weight.data.fill_(1.0)

@auto_docstring
class LlamaMlaModel(LlamaMlaPreTrainedModel):
    def __init__(self, config: LlamaMlaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [LlamaMlaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **flash_attn_kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None
        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            return attention_mask

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype = input_tensor.dtype
        sequence_length = input_tensor.shape[1]
        if using_compilable_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask

class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


@auto_docstring
class LlamaMlaForCausalLM(LlamaMlaPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaMlaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

__all__ = [
    "LlamaMlaForCausalLM",
    "LlamaMlaModel",
    "LlamaMlaPreTrainedModel",
]
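
One thing worth noting about this eager implementation: `past_key_value.update` receives the expanded per-head `key_states` (q_head_dim per head) and `value_states` (v_head_dim per head), not the compressed latent, so the cache savings MLA offers in principle are not realized by this code path. A rough per-token, per-layer comparison using the config.json values (a sketch; counts are in bf16 elements, and the "latent-only" figure describes a DeepSeek-V2-style cache rather than anything in this file):

    heads, kv_heads, head_dim = 32, 8, 64
    qk_nope, qk_rope, v_dim, kv_lora = 64, 32, 64, 512
    q_head_dim = qk_nope + qk_rope

    expanded = heads * (q_head_dim + v_dim)  # what this code caches: 32 * 160 = 5120 elements
    latent = kv_lora + qk_rope               # a latent-only MLA cache would need: 544 elements
    gqa = 2 * kv_heads * head_dim            # a standard GQA Llama cache, for comparison: 1024 elements
    print(expanded, latent, gqa)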