appledora committed on
Commit 036d5e2 · verified · 1 Parent(s): 42c58e8

Upload modeling_recast_llama.py with huggingface_hub

Files changed (1)
  1. modeling_recast_llama.py +843 -0
modeling_recast_llama.py ADDED
@@ -0,0 +1,843 @@
1
+ # filename: modeling_recast_llama.py
2
+ from .configuration_recast_llama import RECAST1B_llama
3
+ from transformers import PreTrainedModel
4
+ import math
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from typing import Optional, Tuple, Union, List
9
+ from transformers import AutoConfig
10
+ from transformers.utils import logging
11
+ from transformers.cache_utils import Cache, StaticCache
12
+ from transformers.modeling_outputs import CausalLMOutputWithPast
13
+ from transformers.generation import GenerationMixin
14
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
15
+ from transformers.models.llama.modeling_llama import (
16
+ LlamaDecoderLayer,
17
+ LlamaRotaryEmbedding,
18
+ LlamaRMSNorm,
19
+ apply_rotary_pos_emb,
20
+ repeat_kv,
21
+ )
22
+ from transformers.modeling_outputs import BaseModelOutputWithPast
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ class MLPTemplateBank(nn.Module):
28
+ def __init__(self, config, coef_rows, coef_columns):
29
+ super().__init__()
30
+ self.hidden_size = config.hidden_size
31
+ self.intermediate_size = config.intermediate_size
32
+ self.coef_shape = (coef_rows, coef_columns)
33
+
34
+ assert coef_columns is not None, "coef_columns must not be None"
35
+
36
+ # Ensure divisibility for proper reshaping
37
+ assert (self.hidden_size * self.intermediate_size) % coef_rows == 0, \
38
+ f"hidden_size * intermediate_size ({self.hidden_size * self.intermediate_size}) must be divisible by coef_rows ({coef_rows})"
39
+
40
+ template_size = self.hidden_size * self.intermediate_size // coef_rows
41
+
42
+ self.up_templates = nn.Parameter(
43
+ torch.randn(coef_columns, template_size)
44
+ )
45
+ self.gate_templates = nn.Parameter(
46
+ torch.randn(coef_columns, template_size)
47
+ )
48
+
49
+ # Better initialization
50
+ nn.init.xavier_uniform_(self.up_templates)
51
+ nn.init.xavier_uniform_(self.gate_templates)
52
+
53
+ def forward(self, up_coeffs, gate_coeffs):
54
+ # Compute chunked weights
55
+ up_chunks = torch.matmul(up_coeffs, self.up_templates)
56
+ gate_chunks = torch.matmul(gate_coeffs, self.gate_templates)
57
+
58
+ # Reshape to final weight matrices
59
+ up_weights = up_chunks.reshape(self.intermediate_size, self.hidden_size)
60
+ gate_weights = gate_chunks.reshape(self.intermediate_size, self.hidden_size)
61
+
62
+ return up_weights, gate_weights
63
+
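+ # Worked example of the bank mechanics (the numbers below are assumptions for
+ # illustration, not values read from any config): with hidden_size=2048,
+ # intermediate_size=8192, coef_rows=64 and coef_columns=128:
+ #   template_size        = 2048 * 8192 // 64 = 262144
+ #   up_templates          : (128, 262144)   shared by every layer in the group
+ #   up_coefficients       : (64, 128)       owned by each SharedLlamaMLP below
+ #   matmul(coeffs, templ) : (64, 262144) -> reshape -> (8192, 2048) up_proj weight
+ # i.e. each layer stores ~8K coefficients instead of a ~16.8M-parameter projection.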
64
+ class SharedLlamaMLP(nn.Module):
65
+ def __init__(self, config, bank):
66
+ super().__init__()
67
+ self.config = config
68
+ self.bank = bank
69
+ self.hidden_size = config.hidden_size
70
+ self.intermediate_size = config.intermediate_size
71
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
72
+
73
+ # Initialize coefficients with proper shapes
74
+ self.up_coefficients = nn.Parameter(torch.randn(bank.coef_shape))
75
+ self.gate_coefficients = nn.Parameter(torch.randn(bank.coef_shape))
76
+
77
+ # Orthogonal initialization for the coefficient matrices
78
+ nn.init.orthogonal_(self.up_coefficients)
79
+ nn.init.orthogonal_(self.gate_coefficients)
80
+
81
+ if config.mlp_bias:
82
+ self.gate_bias = nn.Parameter(torch.zeros(self.intermediate_size))
83
+ self.up_bias = nn.Parameter(torch.zeros(self.intermediate_size))
84
+ else:
85
+ self.register_parameter("gate_bias", None)
86
+ self.register_parameter("up_bias", None)
87
+
88
+ self.act_fn = F.silu
89
+
90
+ def forward(self, x):
91
+ # Generate weights using template bank
92
+ up_weights, gate_weights = self.bank(
93
+ self.up_coefficients,
94
+ self.gate_coefficients
95
+ )
96
+
97
+ # SwiGLU: SiLU(x @ W_gate.T + b_gate) * (x @ W_up.T + b_up), followed by down_proj
98
+ hidden_states = self.act_fn(F.linear(x, gate_weights, self.gate_bias)) * F.linear(x, up_weights, self.up_bias)
99
+ output = self.down_proj(hidden_states)
100
+
101
+ return output
102
+
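+ # Note: only the up and gate projections above are reconstructed from the shared
+ # bank; down_proj here (and o_proj in SharedLlamaAttention below) remain ordinary
+ # per-layer nn.Linear modules, so those weights are not shared within a group.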
103
+ class AttTemplateBank(nn.Module):
104
+ def __init__(self, config, coef_rows, coef_columns):
105
+ super().__init__()
106
+ self.hidden_size = config.hidden_size
107
+ self.num_heads = config.num_attention_heads
108
+ self.head_dim = config.hidden_size // config.num_attention_heads
109
+ self.num_key_value_heads = getattr(config, 'num_key_value_heads', config.num_attention_heads)
110
+ self.kv_dim = self.num_key_value_heads * self.head_dim
111
+ self.coef_shape = (coef_rows, coef_columns)
112
+
113
+ # Ensure divisibility
114
+ assert (self.hidden_size * self.hidden_size) % coef_rows == 0, \
115
+ "Q projection size must be divisible by coef_rows"
116
+ assert (self.kv_dim * self.hidden_size) % coef_rows == 0, \
117
+ "K/V projection size must be divisible by coef_rows"
118
+
119
+ # Create templates for Q, K, V
120
+ self.q_templates = nn.Parameter(
121
+ torch.randn(coef_columns, self.hidden_size * self.hidden_size // coef_rows)
122
+ )
123
+ self.k_templates = nn.Parameter(
124
+ torch.randn(coef_columns, self.kv_dim * self.hidden_size // coef_rows)
125
+ )
126
+ self.v_templates = nn.Parameter(
127
+ torch.randn(coef_columns, self.kv_dim * self.hidden_size // coef_rows)
128
+ )
129
+
130
+ # Initialize templates
131
+ nn.init.xavier_uniform_(self.q_templates)
132
+ nn.init.xavier_uniform_(self.k_templates)
133
+ nn.init.xavier_uniform_(self.v_templates)
134
+
135
+ def forward(self, q_coeffs, k_coeffs, v_coeffs):
136
+ # Compute chunked weights
137
+ q_chunks = torch.matmul(q_coeffs, self.q_templates)
138
+ k_chunks = torch.matmul(k_coeffs, self.k_templates)
139
+ v_chunks = torch.matmul(v_coeffs, self.v_templates)
140
+
141
+ # Reshape to final weight matrices
142
+ q_weights = q_chunks.reshape(self.hidden_size, self.hidden_size)
143
+ k_weights = k_chunks.reshape(self.kv_dim, self.hidden_size)
144
+ v_weights = v_chunks.reshape(self.kv_dim, self.hidden_size)
145
+
146
+ return q_weights, k_weights, v_weights
147
+
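+ # Worked example (head counts are assumptions for illustration): with
+ # hidden_size=2048, num_attention_heads=32 and num_key_value_heads=8,
+ # head_dim = 2048 // 32 = 64 and kv_dim = 8 * 64 = 512, so the bank emits
+ #   q_weights: (2048, 2048),  k_weights: (512, 2048),  v_weights: (512, 2048)
+ # which are exactly the shapes F.linear expects in SharedLlamaAttention below.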
148
+ class SharedLlamaAttention(nn.Module):
149
+ def __init__(self, config, layer_idx: Optional[int] = None, bank: Optional[AttTemplateBank] = None):
150
+ super().__init__()
151
+ self.config = config
152
+ self.bank = bank
153
+ self.layer_idx = layer_idx
154
+ self.attention_dropout = config.attention_dropout
155
+ self.hidden_size = config.hidden_size
156
+ self.num_heads = config.num_attention_heads
157
+ self.head_dim = self.hidden_size // self.num_heads
158
+ self.num_key_value_heads = getattr(config, 'num_key_value_heads', config.num_attention_heads)
159
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
160
+ self.max_position_embeddings = config.max_position_embeddings
161
+ self.rope_theta = getattr(config, 'rope_theta', 10000.0)
162
+ self.is_causal = True
163
+
164
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=getattr(config, 'attention_bias', False))
165
+ self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
166
+
167
+ # Initialize coefficients with proper shapes
168
+ self.q_coefficients = nn.Parameter(torch.randn(bank.coef_shape))
169
+ self.k_coefficients = nn.Parameter(torch.randn(bank.coef_shape))
170
+ self.v_coefficients = nn.Parameter(torch.randn(bank.coef_shape))
171
+
172
+ # Orthogonal initialization for the coefficient matrices
173
+ nn.init.orthogonal_(self.q_coefficients)
174
+ nn.init.orthogonal_(self.k_coefficients)
175
+ nn.init.orthogonal_(self.v_coefficients)
176
+
177
+ def forward(
178
+ self,
179
+ hidden_states,
180
+ attention_mask=None,
181
+ past_key_value=None,
182
+ cache_position=None,
183
+ position_embeddings=None,
184
+ position_ids=None,
185
+ output_attentions=False,
186
+ use_cache=False,
187
+ **kwargs,
188
+ ):
189
+ bsz, q_len, _ = hidden_states.size()
190
+
191
+ # Generate weights using template bank
192
+ q_weights, k_weights, v_weights = self.bank(
193
+ self.q_coefficients,
194
+ self.k_coefficients,
195
+ self.v_coefficients
196
+ )
197
+
198
+ # Apply projections
199
+ query_states = F.linear(hidden_states, q_weights)
200
+ key_states = F.linear(hidden_states, k_weights)
201
+ value_states = F.linear(hidden_states, v_weights)
202
+
203
+ # Reshape for multi-head attention
204
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
205
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
206
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
207
+
208
+ # Apply rotary embeddings
209
+ if position_embeddings is None:
210
+ cos, sin = self.rotary_emb(value_states, position_ids)
211
+ else:
212
+ cos, sin = position_embeddings
213
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
214
+
215
+ # Handle past key values
216
+ if past_key_value is not None:
217
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
218
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
219
+
220
+ # Repeat key/value for grouped query attention
221
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
222
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
223
+
224
+ # Compute attention
225
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
226
+
227
+ if attention_mask is not None:
228
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
229
+ attn_weights = attn_weights + causal_mask
230
+
231
+ # Apply softmax and dropout
232
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
233
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
234
+ attn_output = torch.matmul(attn_weights, value_states)
235
+
236
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
237
+ raise ValueError(
238
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
239
+ f" {attn_output.size()}"
240
+ )
241
+
242
+ attn_output = attn_output.transpose(1, 2).contiguous()
243
+ attn_output = attn_output.reshape(bsz, q_len, -1)
244
+ attn_output = self.o_proj(attn_output)
245
+
246
+ if not output_attentions:
247
+ attn_weights = None
248
+
249
+ return attn_output, attn_weights, past_key_value
250
+
251
+
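+ # The attention above always follows the eager path (explicit softmax(QK^T/sqrt(d))V)
+ # because Q/K/V weights are materialized from the bank at every forward call;
+ # repeat_kv duplicates each key/value head num_key_value_groups times so the
+ # grouped-query layout reduces to standard multi-head attention before the matmuls.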
252
+ def fixed_cross_entropy(
253
+ source,
254
+ target,
255
+ num_items_in_batch: Optional[int] = None,
256
+ ignore_index: int = -100,
257
+ **kwargs,
258
+ ):
259
+ reduction = "sum" if num_items_in_batch is not None else "mean"
260
+ loss = nn.functional.cross_entropy(
261
+ source, target, ignore_index=ignore_index, reduction=reduction
262
+ )
263
+ if reduction == "sum":
264
+ loss = loss / num_items_in_batch
265
+ return loss
266
+
267
+
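+ # Worked example for fixed_cross_entropy (numbers are illustrative): with
+ # num_items_in_batch=None it is a plain mean-reduced cross entropy; with
+ # num_items_in_batch=512 the token losses are summed and divided by 512, mirroring
+ # the Transformers pattern of normalizing by an externally supplied token count
+ # (e.g. accumulated across gradient-accumulation steps).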
268
+ class RECAST1B_llamaModel(PreTrainedModel):
269
+ config_class = RECAST1B_llama
270
+ base_model_prefix = "llama"
271
+ supports_gradient_checkpointing = True
272
+
273
+ def __init__(self, config):
274
+ super().__init__(config)
275
+ self.padding_idx = config.pad_token_id
276
+ self.vocab_size = config.vocab_size
277
+
278
+ self.embed_tokens = nn.Embedding(
279
+ config.vocab_size, config.hidden_size, self.padding_idx
280
+ )
281
+
282
+ original_config = AutoConfig.from_pretrained(
283
+ "meta-llama/Llama-3.2-1b", trust_remote_code=True
284
+ )
285
+ self.rotary_emb = LlamaRotaryEmbedding(
286
+ config=original_config,
287
+ )
288
+
289
+ # Create template banks first
290
+ self.mlp_banks = []
291
+ self.attn_banks = []
292
+ layers_per_group = config.num_hidden_layers // config.num_groups
293
+ # Explicitly calculate coef_width if not provided in config
294
+ if hasattr(config, "coef_width") and config.coef_width is not None:
295
+ coef_width = config.coef_width
296
+ else:
297
+ coef_width = config.coef_height * layers_per_group
298
+ config.coef_width = coef_width
299
+ print(
300
+ f"Model config: num_groups={config.num_groups}, layers_per_group={layers_per_group}"
301
+ )
302
+ print(f"Coefficient shape: ({config.coef_height}, {config.coef_width})")
303
+ mlp_banks = nn.ModuleList(
304
+ [
305
+ MLPTemplateBank(
306
+ config=config, coef_rows=config.coef_height, coef_columns=coef_width
307
+ )
308
+ for _ in range(config.num_groups)
309
+ ]
310
+ )
311
+
312
+ attn_banks = nn.ModuleList(
313
+ [
314
+ AttTemplateBank(
315
+ config=config, coef_rows=config.coef_height, coef_columns=coef_width
316
+ )
317
+ for _ in range(config.num_groups)
318
+ ]
319
+ )
320
+ self.mlp_banks = mlp_banks
321
+ self.attn_banks = attn_banks
322
+ # Create layers using LlamaDecoderLayer but replace MLPs
323
+ self.layers = nn.ModuleList()
324
+ for layer_idx in range(config.num_hidden_layers):
325
+ # Create standard LlamaDecoderLayer
326
+ decoder_layer = LlamaDecoderLayer(config, layer_idx)
327
+
328
+ # Replace its MLP with our SharedLlamaMLP
329
+ group_idx = layer_idx // layers_per_group
330
+ decoder_layer.mlp = SharedLlamaMLP(config, self.mlp_banks[group_idx])
331
+ decoder_layer.self_attn = SharedLlamaAttention(
332
+ config, layer_idx, self.attn_banks[group_idx]
333
+ )
334
+
335
+ self.layers.append(decoder_layer)
336
+
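+ # Example of the resulting mapping (numbers assumed for illustration): with
+ # num_hidden_layers=16 and num_groups=4, layers_per_group=4, so layers 0-3 share
+ # mlp_banks[0]/attn_banks[0], layers 4-7 share index 1, and so on
+ # (group_idx = layer_idx // layers_per_group).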
337
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
338
+ self.gradient_checkpointing = False
339
+
340
+ def forward(
341
+ self,
342
+ input_ids: torch.LongTensor = None,
343
+ attention_mask: Optional[torch.Tensor] = None,
344
+ position_ids: Optional[torch.LongTensor] = None,
345
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
346
+ inputs_embeds: Optional[torch.FloatTensor] = None,
347
+ use_cache: Optional[bool] = None,
348
+ output_attentions: Optional[bool] = None,
349
+ output_hidden_states: Optional[bool] = None,
350
+ return_dict: Optional[bool] = None,
351
+ cache_position: Optional[torch.LongTensor] = None,
352
+ **flash_attn_kwargs,
353
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
354
+ output_attentions = (
355
+ output_attentions
356
+ if output_attentions is not None
357
+ else self.config.output_attentions
358
+ )
359
+ output_hidden_states = (
360
+ output_hidden_states
361
+ if output_hidden_states is not None
362
+ else self.config.output_hidden_states
363
+ )
364
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
365
+ return_dict = (
366
+ return_dict if return_dict is not None else self.config.use_return_dict
367
+ )
368
+
369
+ if (input_ids is None) ^ (inputs_embeds is not None):
370
+ raise ValueError(
371
+ "You must specify exactly one of input_ids or inputs_embeds"
372
+ )
373
+
374
+ if self.gradient_checkpointing and self.training and use_cache:
375
+ logger.warning_once(
376
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
377
+ )
378
+ use_cache = False
379
+
380
+ if inputs_embeds is None:
381
+ inputs_embeds = self.embed_tokens(input_ids)
382
+ # Set up cache position if not provided
383
+ if cache_position is None:
384
+ past_seen_tokens = (
385
+ 0
386
+ if past_key_values is None
387
+ else (
388
+ past_key_values.get_seq_length()
389
+ if isinstance(past_key_values, Cache)
390
+ else past_key_values[0][0].size(-2) if past_key_values else 0
391
+ )
392
+ )
393
+ cache_position = torch.arange(
394
+ past_seen_tokens,
395
+ past_seen_tokens + inputs_embeds.shape[1],
396
+ device=inputs_embeds.device,
397
+ )
398
+ # Create position embeddings to be shared across the decoder layers
399
+ # Set up position IDs if not provided
400
+ if position_ids is None:
401
+ position_ids = cache_position.unsqueeze(0)
402
+ # Get updated causal mask
403
+ causal_mask = self._update_causal_mask(
404
+ attention_mask,
405
+ inputs_embeds,
406
+ cache_position,
407
+ past_key_values,
408
+ output_attentions,
409
+ )
410
+ hidden_states = inputs_embeds
411
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
412
+
413
+ # Initialize outputs
414
+ all_hidden_states = () if output_hidden_states else None
415
+ all_self_attns = () if output_attentions else None
416
+ next_decoder_cache = None
417
+
418
+ # Process through layers
419
+ for decoder_layer in self.layers:
420
+ if output_hidden_states:
421
+ all_hidden_states += (hidden_states,)
422
+
423
+ if self.gradient_checkpointing and self.training:
424
+ layer_outputs = self._gradient_checkpointing_func(
425
+ decoder_layer.__call__,
426
+ hidden_states,
427
+ causal_mask,
428
+ position_ids,
429
+ past_key_values,
430
+ output_attentions,
431
+ use_cache,
432
+ position_embeddings,
433
+ )
434
+ else:
435
+ layer_outputs = decoder_layer(
436
+ hidden_states,
437
+ attention_mask=causal_mask,
438
+ position_ids=position_ids,
439
+ past_key_value=past_key_values,
440
+ output_attentions=output_attentions,
441
+ use_cache=use_cache,
442
+ position_embeddings=position_embeddings,
443
+ **flash_attn_kwargs,
444
+ )
445
+
446
+ hidden_states = layer_outputs[0]
447
+
448
+ if use_cache:
449
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
450
+
451
+ if output_attentions:
452
+ all_self_attns += (layer_outputs[1],)
453
+
454
+ # Final layer norm
455
+ hidden_states = self.norm(hidden_states)
456
+
457
+ # Add last hidden state
458
+ if output_hidden_states:
459
+ all_hidden_states += (hidden_states,)
460
+
461
+ next_cache = next_decoder_cache if use_cache else None
462
+
463
+ if not return_dict:
464
+ return tuple(
465
+ v
466
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
467
+ if v is not None
468
+ )
469
+
470
+ return BaseModelOutputWithPast(
471
+ last_hidden_state=hidden_states,
472
+ past_key_values=next_cache,
473
+ hidden_states=all_hidden_states,
474
+ attentions=all_self_attns,
475
+ )
476
+
477
+ @classmethod
478
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
479
+ if isinstance(
480
+ pretrained_model_name_or_path, str
481
+ ) and pretrained_model_name_or_path.endswith(".pt"):
482
+ print("Loading from local checkpoint")
483
+ # Load from local checkpoint
484
+ config = kwargs.get("config", None)
485
+ if config is None:
486
+ config = AutoConfig.from_pretrained(
487
+ pretrained_model_name_or_path, trust_remote_code=True
488
+ )
489
+
490
+ model = cls(config)
491
+ checkpoint = torch.load(pretrained_model_name_or_path, map_location="cpu")
492
+ state_dict = checkpoint["model_state_dict"]
493
+ logger.info(
494
+ f"Loaded checkpoint from epoch {checkpoint.get('epoch')} with loss {checkpoint.get('loss')}"
495
+ )
496
+
497
+ missing_keys, unexpected_keys = model.load_state_dict(
498
+ state_dict, strict=False
499
+ )
500
+
501
+ if len(missing_keys) > 0:
502
+ logger.warning(f"Missing keys: {missing_keys}")
503
+ if len(unexpected_keys) > 0:
504
+ logger.warning(f"Unexpected keys: {unexpected_keys}")
505
+
506
+ return model
507
+ else:
508
+ print("Loading from hub")
509
+ # Load from hub using parent's from_pretrained
510
+ return super().from_pretrained(
511
+ pretrained_model_name_or_path, *model_args, **kwargs
512
+ )
513
+
514
+ def get_input_embeddings(self):
515
+ return self.embed_tokens
516
+
517
+ def set_input_embeddings(self, value):
518
+ self.embed_tokens = value
519
+
520
+ def _update_causal_mask(
521
+ self,
522
+ attention_mask: torch.Tensor,
523
+ input_tensor: torch.Tensor,
524
+ cache_position: torch.Tensor,
525
+ past_key_values: Cache,
526
+ output_attentions: bool,
527
+ ):
528
+ if self.config._attn_implementation == "flash_attention_2":
529
+ if attention_mask is not None and 0.0 in attention_mask:
530
+ return attention_mask
531
+ return None
532
+
533
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
534
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
535
+ # to infer the attention mask.
536
+ past_seen_tokens = (
537
+ past_key_values.get_seq_length() if past_key_values is not None else 0
538
+ )
539
+ using_static_cache = isinstance(past_key_values, StaticCache)
540
+
541
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
542
+ if (
543
+ self.config._attn_implementation == "sdpa"
544
+ and not using_static_cache
545
+ and not output_attentions
546
+ ):
547
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
548
+ attention_mask,
549
+ inputs_embeds=input_tensor,
550
+ past_key_values_length=past_seen_tokens,
551
+ is_training=self.training,
552
+ ):
553
+ return None
554
+
555
+ dtype, device = input_tensor.dtype, input_tensor.device
556
+ sequence_length = input_tensor.shape[1]
557
+ if using_static_cache:
558
+ target_length = past_key_values.get_max_cache_shape()
559
+ else:
560
+ target_length = (
561
+ attention_mask.shape[-1]
562
+ if isinstance(attention_mask, torch.Tensor)
563
+ else past_seen_tokens + sequence_length + 1
564
+ )
565
+
566
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
567
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
568
+ attention_mask,
569
+ sequence_length=sequence_length,
570
+ target_length=target_length,
571
+ dtype=dtype,
572
+ device=device,
573
+ cache_position=cache_position,
574
+ batch_size=input_tensor.shape[0],
575
+ )
576
+
577
+ if (
578
+ self.config._attn_implementation == "sdpa"
579
+ and attention_mask is not None
580
+ and attention_mask.device.type == "cuda"
581
+ and not output_attentions
582
+ ):
583
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
584
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
585
+ # Details: https://github.com/pytorch/pytorch/issues/110213
586
+ min_dtype = torch.finfo(dtype).min
587
+ causal_mask = AttentionMaskConverter._unmask_unattended(
588
+ causal_mask, min_dtype
589
+ )
590
+
591
+ return causal_mask
592
+
593
+ @staticmethod
594
+ def _prepare_4d_causal_attention_mask_with_cache_position(
595
+ attention_mask: torch.Tensor,
596
+ sequence_length: int,
597
+ target_length: int,
598
+ dtype: torch.dtype,
599
+ device: torch.device,
600
+ cache_position: torch.Tensor,
601
+ batch_size: int,
602
+ **kwargs,
603
+ ):
604
+ if attention_mask is not None and attention_mask.dim() == 4:
605
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
606
+ causal_mask = attention_mask
607
+ else:
608
+ min_dtype = torch.finfo(dtype).min
609
+ causal_mask = torch.full(
610
+ (sequence_length, target_length),
611
+ fill_value=min_dtype,
612
+ dtype=dtype,
613
+ device=device,
614
+ )
615
+ if sequence_length != 1:
616
+ causal_mask = torch.triu(causal_mask, diagonal=1)
617
+ causal_mask *= torch.arange(
618
+ target_length, device=device
619
+ ) > cache_position.reshape(-1, 1)
620
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
621
+ if attention_mask is not None:
622
+ causal_mask = (
623
+ causal_mask.clone()
624
+ ) # copy to contiguous memory for in-place edit
625
+ mask_length = attention_mask.shape[-1]
626
+ padding_mask = (
627
+ causal_mask[:, :, :, :mask_length]
628
+ + attention_mask[:, None, None, :]
629
+ )
630
+ padding_mask = padding_mask == 0
631
+ causal_mask[:, :, :, :mask_length] = causal_mask[
632
+ :, :, :, :mask_length
633
+ ].masked_fill(padding_mask, min_dtype)
634
+
635
+ return causal_mask
636
+
637
+
638
+ class RECAST1B_LlamaForCausalLM(PreTrainedModel, GenerationMixin):
639
+ _tied_weights_keys = ["lm_head.weight"]
640
+ _tp_plan = {"lm_head": "colwise_rep"}
641
+ config_class = RECAST1B_llama
642
+ base_model_prefix = "llama"
643
+ supports_gradient_checkpointing = True
644
+
645
+ def __init__(self, config):
646
+ super().__init__(config)
647
+ self.model = RECAST1B_llamaModel(config)
648
+ self.vocab_size = config.vocab_size
649
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
650
+
651
+ # Initialize weights and apply final processing
652
+ self.post_init()
653
+
654
+ def get_input_embeddings(self):
655
+ return self.model.embed_tokens
656
+
657
+ def set_input_embeddings(self, value):
658
+ self.model.embed_tokens = value
659
+
660
+ def get_output_embeddings(self):
661
+ return self.lm_head
662
+
663
+ def set_output_embeddings(self, new_embeddings):
664
+ self.lm_head = new_embeddings
665
+
666
+ def set_decoder(self, decoder):
667
+ self.model = decoder
668
+
669
+ def get_decoder(self):
670
+ return self.model
671
+
672
+ def loss_function(
673
+ self,
674
+ logits,
675
+ labels,
676
+ vocab_size: int,
677
+ num_items_in_batch: Optional[int] = None,
678
+ ignore_index: int = -100,
679
+ **kwargs,
680
+ ):
681
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
682
+ logits = logits.float()
683
+ # Shift so that tokens < n predict n
684
+ shift_logits = logits[..., :-1, :].contiguous()
685
+ shift_labels = labels[..., 1:].contiguous()
686
+ # Flatten the tokens
687
+ shift_logits = shift_logits.view(-1, vocab_size)
688
+ shift_labels = shift_labels.view(-1)
689
+ # Enable model parallelism
690
+ shift_labels = shift_labels.to(shift_logits.device)
691
+ loss = fixed_cross_entropy(
692
+ shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs
693
+ )
694
+ return loss
695
+
696
+ def forward(
697
+ self,
698
+ input_ids: torch.LongTensor = None,
699
+ attention_mask: Optional[torch.Tensor] = None,
700
+ position_ids: Optional[torch.LongTensor] = None,
701
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
702
+ inputs_embeds: Optional[torch.FloatTensor] = None,
703
+ labels: Optional[torch.LongTensor] = None,
704
+ use_cache: Optional[bool] = None,
705
+ output_attentions: Optional[bool] = None,
706
+ output_hidden_states: Optional[bool] = None,
707
+ return_dict: Optional[bool] = None,
708
+ cache_position: Optional[torch.LongTensor] = None,
709
+ num_logits_to_keep: int = 0,
710
+ **kwargs,
711
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
712
+ """
713
+ Args:
714
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
715
+ Labels for computing the masked language modeling loss. Indices should be in
716
+ `[0, ..., config.vocab_size]` or -100 (masked tokens).
717
+ num_logits_to_keep (`int`, *optional*):
718
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate all logits.
719
+ """
720
+ output_attentions = (
721
+ output_attentions
722
+ if output_attentions is not None
723
+ else self.config.output_attentions
724
+ )
725
+ output_hidden_states = (
726
+ output_hidden_states
727
+ if output_hidden_states is not None
728
+ else self.config.output_hidden_states
729
+ )
730
+ return_dict = (
731
+ return_dict if return_dict is not None else self.config.use_return_dict
732
+ )
733
+
734
+ outputs = self.model(
735
+ input_ids=input_ids,
736
+ attention_mask=attention_mask,
737
+ position_ids=position_ids,
738
+ past_key_values=past_key_values,
739
+ inputs_embeds=inputs_embeds,
740
+ use_cache=use_cache,
741
+ output_attentions=output_attentions,
742
+ output_hidden_states=output_hidden_states,
743
+ return_dict=return_dict,
744
+ cache_position=cache_position,
745
+ **kwargs,
746
+ )
747
+
748
+ hidden_states = outputs[0]
749
+ # Only compute necessary logits
750
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
751
+
752
+ loss = None
753
+ if labels is not None:
754
+ # Normalize the summed loss by the number of sequences in the batch
755
+ num_items_in_batch = (
756
+ input_ids.size(0) if input_ids is not None else inputs_embeds.size(0)
757
+ )
758
+ loss = self.loss_function(
759
+ logits=logits,
760
+ labels=labels,
761
+ vocab_size=self.config.vocab_size,
762
+ num_items_in_batch=num_items_in_batch,
763
+ **kwargs,
764
+ )
765
+
766
+ if not return_dict:
767
+ output = (logits,) + outputs[1:]
768
+ return (loss,) + output if loss is not None else output
769
+
770
+ return CausalLMOutputWithPast(
771
+ loss=loss,
772
+ logits=logits,
773
+ past_key_values=outputs.past_key_values,
774
+ hidden_states=outputs.hidden_states,
775
+ attentions=outputs.attentions,
776
+ )
777
+
778
+ def prepare_inputs_for_generation(
779
+ self,
780
+ input_ids,
781
+ past_key_values=None,
782
+ attention_mask=None,
783
+ inputs_embeds=None,
784
+ **kwargs,
785
+ ):
786
+ if past_key_values:
787
+ input_ids = input_ids[:, -1:]
788
+
789
+ position_ids = kwargs.get("position_ids", None)
790
+ if attention_mask is not None and position_ids is None:
791
+ # create position_ids on the fly for batch generation
792
+ position_ids = attention_mask.long().cumsum(-1) - 1
793
+ position_ids.masked_fill_(attention_mask == 0, 1)
794
+ if past_key_values:
795
+ position_ids = position_ids[:, -1].unsqueeze(-1)
796
+
797
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
798
+ if inputs_embeds is not None and past_key_values is None:
799
+ model_inputs = {"inputs_embeds": inputs_embeds}
800
+ else:
801
+ model_inputs = {"input_ids": input_ids}
802
+
803
+ model_inputs.update(
804
+ {
805
+ "position_ids": position_ids,
806
+ "past_key_values": past_key_values,
807
+ "use_cache": kwargs.get("use_cache"),
808
+ "attention_mask": attention_mask,
809
+ }
810
+ )
811
+ return model_inputs
812
+
813
+ @classmethod
814
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
815
+ if isinstance(
816
+ pretrained_model_name_or_path, str
817
+ ) and pretrained_model_name_or_path.endswith(".pt"):
818
+ print("Loading from local checkpoint")
819
+ config = kwargs.get("config", None)
820
+ if config is None:
821
+ config = AutoConfig.from_pretrained(
822
+ pretrained_model_name_or_path, trust_remote_code=True
823
+ )
824
+ model = torch.load(pretrained_model_name_or_path, map_location="cpu")
825
+ # model = cls(config)
826
+ # checkpoint = torch.load(pretrained_model_name_or_path, map_location="cpu")
827
+ # state_dict = checkpoint["model_state_dict"]
828
+
829
+ # missing_keys, unexpected_keys = model.load_state_dict(
830
+ # state_dict, strict=False
831
+ # )
832
+
833
+ # if len(missing_keys) > 0:
834
+ # logger.warning(f"Missing keys: {missing_keys}")
835
+ # if len(unexpected_keys) > 0:
836
+ # logger.warning(f"Unexpected keys: {unexpected_keys}")
837
+
838
+ return model
839
+ else:
840
+ print("Loading from hub")
841
+ return super().from_pretrained(
842
+ pretrained_model_name_or_path, *model_args, **kwargs
843
+ )
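+
+ # Usage sketch (illustrative only; "<repo_id>" is a placeholder and this assumes the
+ # hub config registers these classes via auto_map for trust_remote_code loading):
+ #
+ #   from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+ #   config = AutoConfig.from_pretrained("<repo_id>", trust_remote_code=True)
+ #   model = AutoModelForCausalLM.from_pretrained("<repo_id>", trust_remote_code=True)
+ #   tokenizer = AutoTokenizer.from_pretrained("<repo_id>")
+ #   inputs = tokenizer("Hello", return_tensors="pt")
+ #   out = model.generate(**inputs, max_new_tokens=8)
+ #   print(tokenizer.decode(out[0], skip_special_tokens=True))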