christian-lms committed on
Commit 97009c5 · verified · 1 Parent(s): 635fdbc

Delete modeling_ernie4_5_moe.py

Files changed (1)
  1. modeling_ernie4_5_moe.py +0 -1504
modeling_ernie4_5_moe.py DELETED
@@ -1,1504 +0,0 @@
1
- # Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from copy import deepcopy
16
- from dataclasses import dataclass
17
- import functools
- from functools import partial
18
- from typing import Callable, Optional, Tuple, Union
19
-
20
- import torch
21
- import torch.nn.functional as F
22
- import torch.nn as nn
23
-
24
- from transformers.cache_utils import Cache, DynamicCache, StaticCache
25
- from transformers.generation import GenerationMixin
26
- from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
27
- from transformers.modeling_outputs import ModelOutput, MoeCausalLMOutputWithPast
28
- from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
29
- from transformers.modeling_attn_mask_utils import AttentionMaskConverter
30
- from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
31
- from transformers.processing_utils import Unpack
32
- from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, logging, is_torch_flex_attn_available
33
-
34
- from .configuration_ernie4_5_moe import Ernie4_5_MoeConfig
35
-
36
-
37
- if is_torch_flex_attn_available():
38
- from torch.nn.attention.flex_attention import BlockMask
39
-
40
- from transformers.integrations.flex_attention import make_flex_block_causal_mask
41
-
42
- logger = logging.get_logger(__name__)
43
-
44
-
45
- class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
46
-
47
- @dataclass
48
- class Erine4_5_MoeModelOutputWithPast(ModelOutput):
49
- last_hidden_state: Optional[torch.FloatTensor] = None
50
- past_key_values: Optional[Cache] = None
51
- hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
52
- attentions: Optional[tuple[torch.FloatTensor, ...]] = None
53
- router_loss: Optional[torch.FloatTensor] = None
54
- gate_logits: Optional[tuple[torch.FloatTensor, ...]] = None
55
- mtp_outputs: Optional[torch.FloatTensor] = None
56
-
57
-
58
- @dataclass
59
- class Ernie4_5_MoeCausalLMOutputWithPast(MoeCausalLMOutputWithPast):
60
- router_loss: Optional[torch.FloatTensor] = None
61
-
62
- def rotate_half(x):
63
- """Rotates half the hidden dims of the input."""
64
-
65
- x1 = x[..., 0::2]
66
- x2 = x[..., 1::2]
67
- return torch.stack((-x2, x1), dim=-1).reshape(x.shape)
68
-
69
-
70
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
71
- """
72
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
73
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
74
- """
75
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
76
- if n_rep == 1:
77
- return hidden_states
78
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
79
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
80
-
81
-
82
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
83
- """Applies Rotary Position Embedding to the query and key tensors.
84
-
85
- Args:
86
- q (`torch.Tensor`): The query tensor.
87
- k (`torch.Tensor`): The key tensor.
88
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
89
- sin (`torch.Tensor`): The sine part of the rotary embedding.
90
- position_ids (`torch.Tensor`, *optional*):
91
- Deprecated and unused.
92
- unsqueeze_dim (`int`, *optional*, defaults to 1):
93
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
94
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
95
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
96
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
97
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
98
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
99
- Returns:
100
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
101
- """
102
- orig_dtype = q.dtype
103
- sin_pos = torch.stack([sin, sin], dim=-1).reshape(*sin.shape[:-1],-1)
104
- cos_pos = torch.stack([cos, cos], dim=-1).reshape(*cos.shape[:-1],-1)
105
- q_embed = (q.float() * cos_pos) + (rotate_half(q).float() * sin_pos)
106
- k_embed = (k.float() * cos_pos) + (rotate_half(k).float() * sin_pos)
107
- return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
108
-
109
-
110
- def eager_attention_forward(
111
- module: nn.Module,
112
- query: torch.Tensor,
113
- key: torch.Tensor,
114
- value: torch.Tensor,
115
- attention_mask: Optional[torch.Tensor],
116
- scaling: float,
117
- dropout: float = 0.0,
118
- **kwargs,
119
- ):
120
- key_states = repeat_kv(key, module.num_key_value_groups)
121
- value_states = repeat_kv(value, module.num_key_value_groups)
122
-
123
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
124
- if attention_mask is not None:
125
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
126
- attn_weights = attn_weights + causal_mask.to(attn_weights.device)
127
-
128
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
129
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
130
- attn_output = torch.matmul(attn_weights, value_states)
131
- attn_output = attn_output.transpose(1, 2).contiguous()
132
-
133
- return attn_output, attn_weights
134
-
135
-
136
- def topk_gate_func(
137
- module: nn.Module,
138
- hidden_states: torch.Tensor,
139
- ):
140
- capacity = module.get_capacity(hidden_states.shape[0])
141
- with torch.autocast(device_type='cuda',dtype=torch.float32):
142
- logits = module.gate(hidden_states.float())
143
- router_loss = torch.zeros([1], dtype=torch.float32, device=hidden_states.device)
144
- router_loss = router_loss.detach()
145
- return logits, capacity, router_loss
146
-
147
-
148
- class Ernie4_5_ResidualWithDropout(nn.Module):
149
- """
150
- Fused dropout implementation with residual connection support.
151
-
152
- This layer combines dropout and residual addition in a single operation for better performance,
153
- particularly on GPU devices. The dropout is conditionally applied based on the probability.
154
-
155
- Args:
156
- prob (float): Dropout probability (between 0 and 1)
157
-
158
- Attributes:
159
- prob (float): Stores the dropout probability
160
- dropout (nn.Dropout): The actual dropout layer instance
161
- """
162
-
163
- def __init__(self, prob):
164
- """
165
- Initialize the fused dropout layer.
166
-
167
- Args:
168
- prob (float): Dropout probability (0 means no dropout)
169
- """
170
- super().__init__()
171
- self.prob = prob
172
- self.dropout = nn.Dropout(p=prob)
173
-
174
- def forward(self, x, y):
175
- """
176
- Forward pass of the fused dropout layer.
177
-
178
- Args:
179
- x (torch.Tensor): Input tensor to potentially apply dropout on
180
- y (torch.Tensor): Residual tensor to add to the (possibly dropped out) x
181
-
182
- Returns:
183
- torch.Tensor: Result of x (with optional dropout) + y
184
- """
185
- if self.prob > 0:
186
- x = self.dropout(x)
187
- output = x + y
188
-
189
- return output
190
-
191
-
192
- class Ernie4_5_Attention(nn.Module):
193
- """Multi-headed attention from 'Attention Is All You Need' paper"""
194
-
195
- def __init__(self, config, layer_idx=0):
196
- """
197
- Args:
198
- config (ErnieConfig): Model configuration.
199
- layer_idx (int, optional): Index in transformer stack. Defaults to 0.
200
- """
201
- super().__init__()
202
- self.layer_idx = layer_idx
203
- self.hidden_size = config.hidden_size
204
- self.num_heads = config.num_attention_heads
205
- self.num_key_value_heads = config.num_key_value_heads if config.num_key_value_heads is not None else self.num_heads
206
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
207
- self.head_dim = self.hidden_size // self.num_heads
208
- self.freq_allocation = config.freq_allocation if hasattr(config, "freq_allocation") else 0
209
- self.scaling = self.head_dim**-0.5
210
- self.attention_dropout = getattr(config, "attention_probs_dropout_prob", 0.0)
211
- self.is_causal = True
212
-
213
- self.q_proj = nn.Linear(
214
- self.hidden_size,
215
- self.num_heads * self.head_dim,
216
- bias=config.use_bias,
217
- )
218
-
219
- self.k_proj = nn.Linear(
220
- self.hidden_size,
221
- self.num_key_value_heads * self.head_dim,
222
- bias=config.use_bias,
223
- )
224
-
225
- self.v_proj = nn.Linear(
226
- self.hidden_size,
227
- self.num_key_value_heads * self.head_dim,
228
- bias=config.use_bias,
229
- )
230
-
231
- self.o_proj = nn.Linear(
232
- self.hidden_size,
233
- self.hidden_size,
234
- bias=config.use_bias,
235
- )
236
-
237
- self.config = config
238
-
239
-
240
- def forward(
241
- self,
242
- hidden_states: torch.Tensor,
243
- attention_mask: Optional[torch.Tensor] = None,
244
- past_key_value: Optional[Cache] = None,
245
- position_ids: Optional[torch.Tensor] = None,
246
- cache_position: Optional[torch.LongTensor] = None,
247
- position_embeddings: tuple[torch.Tensor, torch.Tensor] = None,
248
- **kwargs: Unpack[FlashAttentionKwargs],
249
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
250
- B, L = hidden_states.shape[:-1]
251
-
252
- query_states = self.q_proj(hidden_states).view(B, L, self.num_heads, -1).transpose(1, 2)
253
- key_states = self.k_proj(hidden_states).view(B, L, self.num_key_value_heads, -1).transpose(1, 2)
254
- value_states = self.v_proj(hidden_states).view(B, L, self.num_key_value_heads, -1).transpose(1, 2)
255
-
256
- cos, sin = position_embeddings
257
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
258
-
259
- if past_key_value is not None:
260
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
261
- cache_kwargs = {"cache_position": cache_position}
262
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
263
-
264
- attention_interface: Callable = eager_attention_forward
265
- if self.config._attn_implementation != "eager":
266
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
267
-
268
- attn_output, attn_weights = attention_interface(
269
- self,
270
- query_states,
271
- key_states,
272
- value_states,
273
- attention_mask,
274
- dropout=0.0 if not self.training else self.attention_dropout,
275
- scaling=self.scaling,
276
- **kwargs,
277
- )
278
- attn_output = attn_output.reshape(B, L, -1).contiguous()
279
- attn_output = self.o_proj(attn_output)
280
-
281
- return attn_output, attn_weights
282
-
283
-
284
- class Ernie4_5_MLP(nn.Module):
285
- """
286
- Ernie4_5_MLP - Gated Multi-Layer Perceptron module used in Ernie model.
287
- """
288
-
289
- def __init__(self, config,intermediate_size=None):
290
- """
291
- Initialize the MLP module with configuration options.
292
-
293
- Args:
294
- config: Model configuration object with attributes:
295
- - hidden_size: int
296
- - intermediate_size: int
297
- - use_bias: bool
298
- intermediate_size (int, optional): Overrides config.intermediate_size for this MLP when provided.
299
- """
300
- super().__init__()
301
- self.config = config
302
- self.hidden_size = config.hidden_size
303
- self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
304
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias)
305
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias)
306
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
307
-
308
-
309
- def forward(self, x):
310
- """
311
- Args:
312
- x (Tensor): shape [batch_size, seq_len, hidden_size]
313
-
314
- Returns:
315
- Tensor: shape [batch_size, seq_len, hidden_size]
316
- """
317
- down_proj = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
318
- return down_proj
319
-
320
-
321
- class Ernie4_5_MoeStatics(nn.Module):
322
- """
323
- Stores MoE (Mixture of Experts) statistics
324
- and expert usage information.
325
- """
326
-
327
- def __init__(self, config):
328
- """
329
- Initialize MoE statistics tracking.
330
-
331
- Args:
332
- config: Model configuration containing MoE parameters
333
- """
334
- super().__init__()
335
-
336
- num_experts = config.moe_num_experts
337
- num_experts_groups = 1
338
-
339
- self.e_score_correction_bias = nn.Parameter(
340
- torch.zeros(num_experts_groups, num_experts, dtype=torch.float32),
341
- requires_grad=False
342
- )
343
-
344
- class Ernie4_5_MoeMLP(nn.Module):
345
- """Mixture of Experts (MoE) variant of ERNIE's MLP layer."""
346
-
347
- def __init__(self,config):
348
- super().__init__()
349
- self.config = config
350
- self.k = config.moe_k
351
- self.sinkhorn_2gate = config.sinkhorn_2gate
352
- self.sinkhorn_temp = config.sinkhorn_temp
353
-
354
- moe_intermediate_size = config.moe_intermediate_size if config.moe_intermediate_size else config.intermediate_size
355
- self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32)
356
- if config.moe_gate_act == "softmax":
357
- self.gate_act = partial(F.softmax, dim=-1)
358
- elif config.moe_gate_act == "sigmoid":
359
- self.gate_act = F.sigmoid
360
- else:
361
- raise ValueError(f"{config.moe_gate_act} is not supported.")
362
-
363
- self.experts = nn.ModuleList(
364
- [Ernie4_5_MLP(config,moe_intermediate_size) for i in range(config.moe_num_experts)]
365
- )
366
-
367
- if config.moe_use_aux_free:
368
- self.moe_statics = Ernie4_5_MoeStatics(config)
369
-
370
- self.use_correction_bias = config.moe_use_aux_free
371
- self.num_local_experts = len(self.experts)
372
-
373
- self.shared_experts = self._init_shared_experts()
374
-
375
- def _init_shared_experts(self):
376
- """
377
- Initialize the shared expert module.
378
-
379
- Returns:
380
- shared_experts: Shared expert module, returns None if no shared experts are needed.
381
-
382
- """
383
- cfg = deepcopy(self.config)
384
- if getattr(cfg, 'moe_num_shared_experts', 0) > 0:
385
- if getattr(cfg, 'moe_intermediate_size', None):
386
- cfg.intermediate_size = cfg.moe_intermediate_size * cfg.moe_num_shared_experts
387
- else:
388
- cfg.intermediate_size = cfg.intermediate_size * cfg.moe_num_shared_experts
389
- shared_experts = Ernie4_5_MLP(cfg, cfg.intermediate_size)
390
- else:
391
- shared_experts = None
392
- return shared_experts
393
-
394
- def forward(
395
- self,
396
- input: torch.Tensor,
397
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
398
- """
399
- Forward pass through MoE layer.
400
-
401
- Args:
402
- input (Tensor): Input tensor of shape [s, d].
403
- token_type_ids: Optional tensor for token types.
404
-
405
- Returns:
406
- tuple: (output, combine_weights, router_loss, gate_logits)
407
- """
408
-
409
- if input.dim() == 3:
410
- orig_shape = input.shape
411
- input = input.reshape(-1, input.shape[-1])
412
- else:
413
- orig_shape = None
414
- assert input.dim() == 2, f"input Tensor must have dimensions: (s)equence, (d)im, got:{input.shape}"
415
-
416
- assert self.gate is not None
417
-
418
- gate_input = input
419
-
420
- (
421
- dispatched_input,
422
- combine_weights,
423
- dispatch_mask,
424
- scatter_index,
425
- router_loss,
426
- gate_logits,
427
- gate_prob
428
- ) = self.gate_and_dispatch(gate_input)
429
-
430
- expert_out = self.forward_experts(dispatched_input)
431
-
432
- combined_output = self.combine_expert_output(expert_out, combine_weights, scatter_index)
433
-
434
- if self.shared_experts is not None:
435
- shared_expert_out = self.shared_experts(gate_input)
436
- combined_output += shared_expert_out
437
-
438
- if orig_shape:
439
- combined_output = combined_output.reshape(orig_shape[:-1] + (combined_output.shape[-1],))
440
-
441
- return combined_output, combine_weights, router_loss, gate_logits
442
-
443
- def forward_experts(self, dispatched_input: torch.Tensor) -> torch.Tensor:
444
- """
445
- Forward pass through experts sequentially.
446
-
447
- Args:
448
- dispatched_input (Tensor): Input tensor of shape [num_experts, capacity, dim].
449
-
450
- Returns:
451
- Tensor: Expert outputs of shape [num_experts, capacity, dim].
452
- """
453
- true_experts = self.experts
454
- dispatched_input = dispatched_input.reshape(
455
- 1, self.num_local_experts, -1, dispatched_input.shape[-1]
456
- )
457
- expert_outputs = []
458
- if isinstance(self.experts, nn.ModuleList):
459
- chunks = dispatched_input.permute(1, 0, 2, 3).contiguous().unbind(0)
460
- assert len(chunks) == len(true_experts), f"{len(chunks)}, {len(true_experts)}"
461
- for chunk, expert in zip(chunks, true_experts):
462
- expert_outputs.append(expert(chunk))
463
- else:
464
- dispatched_input = dispatched_input.permute(1, 0, 2, 3).contiguous()
465
- orig_shape = dispatched_input.shape
466
- chunks = dispatched_input.reshape(orig_shape[0], -1, orig_shape[-1])
467
- chunks = self.experts(chunks)
468
- chunks = chunks.reshape(orig_shape[:-1] + (chunks.shape[-1],)).unbind(0)
469
- expert_outputs.extend(chunks)
470
-
471
- expert_output = torch.stack(expert_outputs, dim=1)
472
- return expert_output
473
-
474
- def moe_gate_dispatch(
475
- self,
476
- x: torch.Tensor,
477
- gate_logits: torch.Tensor,
478
- k: int,
479
- capacity: Optional[int],
480
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
481
- torch.Tensor, torch.Tensor]:
482
-
483
- S, H = x.shape
484
- E = gate_logits.shape[1]
485
- device = x.device
486
- topk_prob, topk_idx = torch.topk(gate_logits, k, dim=-1)
487
- combine_weights = topk_prob
488
- expert_id = topk_idx
489
- y = x.new_zeros((E, capacity, H))
490
- scatter_index = x.new_full((k, S), -1, dtype=torch.int32)
491
-
492
- # per-expert slot counters
493
- slot_counter = torch.zeros(E, dtype=torch.int32, device=device)
494
-
495
- for tok in range(S):
496
- for route in range(k):
497
- e = expert_id[tok, route].item()
498
- slot = slot_counter[e].item()
499
- if slot >= capacity:
500
- combine_weights[tok, route] = 0.0
501
- continue
502
-
503
- # record mapping & dispatch activation
504
- scatter_index[route, tok] = e * capacity + slot
505
- y[e, slot] = x[tok]
506
- slot_counter[e] += 1
507
-
508
- expert_offset = torch.cumsum(slot_counter, 0, dtype=torch.int64)
509
-
510
- return y, combine_weights, scatter_index, expert_offset, expert_id
511
-
512
- def combine_expert_output(self, expert_output: torch.Tensor, combine_weights: torch.Tensor, scatter_index: torch.Tensor) -> torch.Tensor:
513
- """
514
- Combine expert outputs using combination weights.
515
-
516
- Args:
517
- expert_output (Tensor): Expert outputs [num_experts, capacity, dim].
518
- combine_weights (Tensor): Combination weights.
519
- scatter_index (Tensor): Scatter indices.
520
-
521
- Returns:
522
- Tensor: Combined output [seqlen, dim].
523
- """
524
- expert_output = expert_output.reshape(-1, expert_output.shape[-1])
525
- combined_output = self.combining(expert_output, combine_weights, scatter_index)
526
- return combined_output
527
-
528
- def combining(self, x, combine_weights, scatter_index):
529
- """
530
- Combines and aggregates input matrix using combination weights.
531
-
532
- Args:
533
- x (Tensor): Input tensor of shape [num_experts * capacity, dim]
534
- combine_weights (Tensor): Combination weights of shape [seq, 2]
535
- scatter_index (Tensor): Scatter indices of shape [seq, 2]
536
-
537
- Returns:
538
- Tensor: Combined output tensor of shape [seq, dim]
539
- """
540
- dim = x.shape[-1]
541
-
542
- scatter_index = scatter_index.reshape([-1])
543
- num_k = combine_weights.shape[-1]
544
-
545
- combine_weights = combine_weights.unsqueeze(1)
546
-
547
- x = x[scatter_index].reshape([-1, num_k, dim])
548
-
549
- return torch.matmul(combine_weights, x).squeeze(1)
550
-
551
- def gate_and_dispatch(self, input):
552
- """
553
- Calculate gate and dispatch inputs.
554
-
555
- Args:
556
- input: Input tensor of shape [seq, dim]
557
-
558
- Returns:
559
- tuple: (dispatched_input, combine_weights, dispatch_mask,
560
- scatter_index, router_loss, gate_logits, gate_prob)
561
- """
562
- gate_logits, capacity, router_loss = topk_gate_func(
563
- self,
564
- input,
565
- )
566
-
567
- # capacity no use
568
- prob = self.gate_act(gate_logits)
569
- (
570
- dispatched_input,
571
- combine_weights_unnorm,
572
- scatter_index,
573
- dispatch_mask,
574
- _,
575
- ) = self.moe_gate_dispatch(input, prob, k=self.k, capacity=capacity)
576
- dispatch_mask = torch.diff(F.pad(dispatch_mask, (1, 0)))
577
-
578
- scatter_index.detach()
579
- dispatch_mask.detach()
580
-
581
- scatter_index = scatter_index.transpose(0, 1) # [k, s] -> [s, k]
582
- combine_weights = combine_weights_unnorm / torch.clamp(
583
- combine_weights_unnorm.sum(dim=-1, keepdim=True), min=1e-12
584
- )
585
- combine_weights = combine_weights.to(dtype=dispatched_input.dtype)
586
-
587
- return dispatched_input, combine_weights, dispatch_mask, scatter_index, router_loss, gate_logits, prob
588
-
589
- def get_capacity(self, num_tokens, cap_factor=None):
590
- """
591
- Calculate capacity based on number of tokens.
592
-
593
- Args:
594
- num_tokens: Number of input tokens
595
- cap_factor: Optional capacity factor override
596
-
597
- Returns:
598
- int: Calculated capacity
599
- """
600
- num_experts = self.config.moe_num_experts
601
- if cap_factor is not None:
602
- cap = cap_factor
603
- else:
604
- if self.training:
605
- cap = self.config.moe_capacity[0]
606
- elif num_tokens < num_experts:
607
- cap = self.config.moe_capacity[2]
608
- else:
609
- cap = self.config.moe_capacity[1]
610
-
611
- capacity = int(cap * num_tokens // num_experts)
612
- assert capacity > 0, f"requires capacity > 0, got cap={cap}, num_tokens={num_tokens}"
613
- return capacity
614
-
615
-
616
- class Ernie4_5_RMSNorm(nn.Module):
617
- """
618
- Ernie Root Mean Square Layer Normalization (Ernie4_5_RMSNorm) implementation.
619
-
620
- Ernie4_5_RMSNorm is a simplified version of LayerNorm that focuses on the root mean square of inputs,
621
- omitting the mean-centering operation. This provides computational efficiency while maintaining
622
- good performance.
623
-
624
- """
625
-
626
- def __init__(self, config):
627
- """
628
- Initialize RMSNorm layer.
629
-
630
- Args:
631
- config (ErnieConfig): Model configuration.
632
- """
633
- super().__init__()
634
- self.config = config
635
- self.hidden_size = config.hidden_size
636
- self.weight = nn.Parameter(torch.ones(config.hidden_size))
637
- self.variance_epsilon = config.rms_norm_eps
638
-
639
- def forward(self, hidden_states):
640
- """
641
- Apply RMS normalization to input hidden states.
642
-
643
- Args:
644
- hidden_states (Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
645
-
646
- Returns:
647
- Tensor: Normalized output tensor of same shape as input
648
- """
649
- input_dtype = hidden_states.dtype
650
- hidden_states = hidden_states.to(torch.float32)
651
- variance = hidden_states.pow(2).mean(dim=-1, keepdim=True)
652
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
653
-
654
- return self.weight * hidden_states.to(input_dtype)
655
-
656
-
657
- class Ernie4_5_RopeEmbedding(nn.Module):
658
- def __init__(self, config: Ernie4_5_MoeConfig, device=None):
659
- super().__init__()
660
- # BC: "rope_type" was originally "type"
661
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
662
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
663
- else:
664
- self.rope_type = "default"
665
- self.max_seq_len_cached = config.max_position_embeddings
666
- self.original_max_seq_len = config.max_position_embeddings
667
-
668
- self.config = config
669
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
670
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
671
- self.register_buffer("inv_freq", inv_freq, persistent=False)
672
- self.original_inv_freq = self.inv_freq
673
-
674
- @torch.no_grad()
675
- def forward(self, x, position_ids):
676
- inv_freq_expanded = self.inv_freq[None,None,:].float()
677
- position_ids_expanded = position_ids[...,None].float()
678
- freqs = (inv_freq_expanded.float() * position_ids_expanded.float())
679
- cos = torch.cos(freqs) * self.attention_scaling
680
- sin = torch.sin(freqs) * self.attention_scaling
681
- return cos, sin
682
- # return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
683
-
684
-
685
- class Ernie4_5_DecoderLayer(nn.Module):
686
- """A single transformer decoder layer in ERNIE-MoE model.
687
-
688
- Contains self-attention and feed-forward components with optional MoE (Mixture of Experts)
689
- support, residual connections, and layer normalization.
690
- """
691
-
692
- def __init__(self, config, layer_idx):
693
- """Initialize the decoder layer.
694
-
695
- Args:
696
- config (ErnieMoEConfig): Model configuration.
697
- layer_idx (int): Index of this layer in the transformer stack
698
- """
699
- super().__init__()
700
- self.hidden_size = config.hidden_size
701
- self.layer_idx = layer_idx
702
- self.config = config
703
- self.use_moe = config.use_moe
704
- self.self_attn = Ernie4_5_Attention(config, layer_idx)
705
-
706
- moe_layer_start_index = (
707
- min(config.moe_layer_start_index)
708
- if isinstance(config.moe_layer_start_index, (tuple, list))
709
- else config.moe_layer_start_index
710
- )
711
- moe_layer_end_index = (
712
- max(config.moe_layer_end_index)
713
- if isinstance(config.moe_layer_end_index, (tuple, list))
714
- else config.moe_layer_end_index
715
- )
716
-
717
- if (
718
- self.use_moe
719
- and ((layer_idx + 1) % config.moe_layer_interval == 0)
720
- and layer_idx >= moe_layer_start_index
721
- and layer_idx <= moe_layer_end_index
722
- ):
723
- self.mlp = Ernie4_5_MoeMLP(config)
724
- else:
725
- self.mlp = Ernie4_5_MLP(config)
726
-
727
- self.input_layernorm = Ernie4_5_RMSNorm(config)
728
- self.post_attention_layernorm = Ernie4_5_RMSNorm(config)
729
-
730
- self.residual_add1 = Ernie4_5_ResidualWithDropout(config.hidden_dropout_prob)
731
- self.residual_add2 = Ernie4_5_ResidualWithDropout(config.hidden_dropout_prob)
732
-
733
- def forward(
734
- self,
735
- hidden_states: torch.Tensor,
736
- attention_mask: Optional[torch.Tensor] = None,
737
- position_ids: Optional[torch.Tensor] = None,
738
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
739
- output_attentions: Optional[bool] = False,
740
- use_cache: Optional[bool] = False,
741
- cache_position: Optional[torch.LongTensor] = None,
742
- position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
743
- output_router_loss: bool = True,
744
- output_gate_logits: bool = True,
745
- **kwargs: Unpack[FlashAttentionKwargs],
746
- ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
747
- """Forward pass through the decoder layer.
748
-
749
- Args:
750
- hidden_states (torch.Tensor): Input tensor [batch_size, seq_len, hidden_size]
751
- attention_mask (Optional[torch.Tensor]): Attention mask tensor
752
- position_ids (Optional[torch.Tensor]): Position indices for rotary embeddings
753
- past_key_value (Optional[Tuple[torch.Tensor]]): Cached key/value states
754
- output_attentions (Optional[bool]): Whether to return attention weights
755
- use_cache (Optional[bool]): Whether to cache key/value states
756
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
757
- Indices depicting the position of the input sequence tokens in the sequence.
758
- position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
759
- Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
760
- with `head_dim` being the embedding dimension of each attention head.
761
- output_router_loss (bool): Whether to return MoE router loss
762
- output_gate_logits (bool): Whether to return MoE gate logits
763
-
764
- Returns:
765
- Union: Various output combinations depending on arguments:
766
- - Base case: Hidden states tensor
767
- - With attention: Tuple of (hidden_states, attention_weights)
768
- - With router loss: May include router loss in output tuple
769
- - With MoE gate logits: May include gate logits in output tuple
770
- """
771
- residual = hidden_states
772
-
773
- hidden_states = self.input_layernorm(hidden_states)
774
-
775
- # Self Attention
776
- hidden_states, self_attn_weights = self.self_attn(
777
- hidden_states=hidden_states,
778
- attention_mask=attention_mask,
779
- past_key_value=past_key_value,
780
- position_ids=position_ids,
781
- use_cache=use_cache,
782
- cache_position=cache_position,
783
- position_embeddings=position_embeddings,
784
- **kwargs,
785
- )
786
-
787
- hidden_states = self.residual_add1(hidden_states, residual)
788
-
789
- # Fully Connected
790
- residual = hidden_states
791
- hidden_states = self.post_attention_layernorm(hidden_states)
792
-
793
- router_loss = None
794
- gate_logits = None
795
-
796
- if isinstance(self.mlp, Ernie4_5_MoeMLP):
797
- hidden_states, _, router_loss, gate_logits = self.mlp(hidden_states)
798
- else:
799
- hidden_states = self.mlp(hidden_states)
800
-
801
- hidden_states = self.residual_add2(hidden_states, residual)
802
-
803
- outputs = (hidden_states,)
804
-
805
- if output_attentions:
806
- outputs += (self_attn_weights,)
807
-
808
- if output_router_loss:
809
- outputs += (router_loss,)
810
-
811
- if output_gate_logits:
812
- outputs += (gate_logits,)
813
-
814
- return outputs
815
-
816
-
817
- @auto_docstring
818
- class Ernie4_5_PretrainedModel(PreTrainedModel):
819
- """Base class for ERNIE pretrained models."""
820
- config_class = Ernie4_5_MoeConfig
821
- base_model_prefix = "model"
822
- supports_gradient_checkpointing = True
823
- _no_split_modules = ["Ernie4_5_DecoderLayer"]
824
- _skip_keys_device_placement = ["past_key_values"]
825
- _supports_flash_attn_2 = True
826
- _supports_sdpa = True
827
- _supports_flex_attn = True
828
- _supports_cache_class = True
829
- _supports_quantized_cache = True
830
- _supports_static_cache = False # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
831
-
832
-
833
- def subbatch(f, arg_idx, axis, bs, out_idx, same_arg_idx={}):
834
- """
835
- Converts a function to one that applies to subbatch of an input dimension.
836
- Useful for processing large tensors in smaller chunks to reduce memory usage.
837
-
838
- Args:
839
- f (Callable): Function to be subbatched.
840
- arg_idx ([int]): Indices of the inputs to be subbatched.
841
- axis ([int]): Indices of the dimensions to be subbatched for each input.
842
- bs (int): Subbatch size.
843
- out_idx (int): Dimension to concatenate outputs along.
844
- same_arg_idx (dict): Mapping of argument indices that share the same tensor.
845
-
846
- Returns:
847
- Callable: New function that processes inputs in subbatches.
848
- """
849
-
850
- @functools.wraps(f)
851
- def wrapper(*args, **kwargs):
852
-
853
- assert len(arg_idx) == len(axis), "Number of batching args and number of batching dims should match."
854
-
855
- inps = [args[i] for i in arg_idx]
856
- axis_width = [inp.shape[d] for inp, d in zip(inps, axis)]
857
- assert len(set(axis_width)) == 1, "Batch sizes should be kept equal."
858
-
859
- inp_axis = {idx: d for idx, d in zip(arg_idx, axis)}
860
-
861
- axis_width = axis_width[0]
862
- if axis_width < bs:
863
- return f(*args, **kwargs)
864
-
865
- outs = []
866
- for slice_at in range(0, axis_width, bs):
867
- _args = []
868
- for i, inp in enumerate(args):
869
- if i in same_arg_idx:
870
- assert (
871
- i > same_arg_idx[i]
872
- ), f"expect i > same_arg_idx[i], but got i: {i} and same_arg_idx[i]: {same_arg_idx[i]}"
873
- _args.append(_args[same_arg_idx[i]])
874
- elif i in arg_idx:
875
- d = inp_axis[i]
876
- start = slice_at
877
- end = min(inp.shape[d], slice_at + bs)
878
- # Build slice for all dims, only slice along axis d
879
- slices = [slice(None)] * inp.ndim
880
- slices[d] = slice(start, end)
881
- _args.append(inp[tuple(slices)])
882
- else:
883
- _args.append(inp)
884
-
885
- out = f(*_args, **kwargs)
886
- outs.append(out)
887
-
888
- return torch.cat(outs, dim=out_idx)
889
-
890
- return wrapper
891
-
892
-
893
- class ErniePretrainingCriterion(nn.Module):
894
- """Criterion for ERNIE pretraining task."""
895
-
896
- def __init__(self, config, return_tuple=True):
897
- """Initialize the pretraining criterion.
898
-
899
- Args:
900
- config (ErnieConfig): Model configuration.
901
- return_tuple (bool): Whether to return loss as tuple (loss, loss_sum). Defaults to True.
902
- """
903
- super().__init__()
904
- self.ignored_index = getattr(config, "ignored_index", -100)
905
- self.config = config
906
- self.return_tuple = return_tuple
907
-
908
- self.loss_func = nn.CrossEntropyLoss(reduction="none")
909
-
910
- def forward(self, prediction_scores, masked_lm_labels, loss_mask, router_loss=None, mtp_logits=None):
911
- """Compute the combined pretraining loss.
912
-
913
- Args:
914
- prediction_scores: Prediction scores tensor, [batch_size, seq_len, vocab_size]
915
- masked_lm_labels: Target labels tensor [batch_size, seq_len]
916
- loss_mask: Optional mask for valid tokens
917
- router_loss: Optional MoE router loss tensor
918
-
919
- Returns:
920
- Union:
921
- - If return_tuple=True: Tuple of (combined_loss, mlm_loss_sum)
922
- - If return_tuple=False: Combined loss tensor
923
- """
924
- if self.config.num_nextn_predict_layers > 0 and self.training:
925
- masked_lm_labels_ori = masked_lm_labels
926
- masked_lm_labels = masked_lm_labels[:, : -self.config.num_nextn_predict_layers]
927
- loss_mask = loss_mask[:, : -self.config.num_nextn_predict_layers]
928
- seq_length = masked_lm_labels.shape[1]
929
-
930
- res = self.forward_impl(prediction_scores, masked_lm_labels, loss_mask)
931
-
932
- if self.config.num_nextn_predict_layers > 0 and self.training:
933
- mtp_loss_res = []
934
- for depth in range(self.config.num_nextn_predict_layers):
935
- prediction_scores_cur_depth = mtp_logits[depth]
936
- masked_lm_labels_cur_depth = masked_lm_labels_ori[:, (depth + 1) : (depth + 1 + seq_length)]
937
- res_cur_depth = self.forward_impl(
938
- prediction_scores_cur_depth,
939
- masked_lm_labels_cur_depth,
940
- )
941
- mtp_loss_res.append(res_cur_depth)
942
-
943
- def add_loss(main_loss, loss):
944
- return main_loss + loss - loss.detach()
945
-
946
-
947
- if self.return_tuple:
948
- loss, loss_sum = res
949
- if self.config.num_nextn_predict_layers > 0 and self.training:
950
- loss = add_loss(
951
- loss, self.config.multi_token_pred_lambda * sum([x[0] for x in mtp_loss_res]) / len(mtp_loss_res)
952
- )
953
- loss_sum = loss_sum + self.config.multi_token_pred_lambda * sum(
954
- [x[1].detach() for x in mtp_loss_res]
955
- ) / len(mtp_loss_res)
956
- else:
957
- loss, loss_sum = res, None
958
- if self.config.num_nextn_predict_layers > 0 and self.training:
959
- loss = add_loss(
960
- loss, self.config.multi_token_pred_lambda * sum([x[0] for x in mtp_loss_res]) / len(mtp_loss_res)
961
- )
962
-
963
- if router_loss is not None and isinstance(router_loss, torch.Tensor):
964
- loss = loss + router_loss - router_loss.detach()
965
-
966
- return loss, loss_sum
967
-
968
-
969
- def loss_impl(self, prediction_scores: torch.Tensor, masked_lm_labels: torch.Tensor) -> torch.Tensor:
970
- """
971
- Core loss computation without reduction (but per-token).
972
-
973
- Args:
974
- prediction_scores (torch.Tensor): Logits tensor [batch_size, seq_len, vocab_size].
975
- masked_lm_labels (torch.Tensor): Target labels tensor [batch_size, seq_len].
976
-
977
- Returns:
978
- torch.Tensor: Unreduced loss tensor of shape [batch_size, seq_len].
979
- Losses are calculated in float32.
980
- """
981
- scores_float32 = prediction_scores.to(torch.float32)
982
- # prediction_scores: [batch_size, seq_len, vocab_size]
983
- # masked_lm_labels: [batch_size, seq_len]
984
- # Transpose prediction_scores to [batch_size, vocab_size, seq_len]
985
- unreduced_loss = self.loss_func(
986
- scores_float32.transpose(1, 2), # Shape: [batch_size, vocab_size, seq_len]
987
- masked_lm_labels.long() # Shape: [batch_size, seq_len], ensure long type
988
- )
989
- # unreduced_loss will be of shape [batch_size, seq_len] and dtype float32
990
- return unreduced_loss
991
-
992
- def forward_impl(self, prediction_scores, masked_lm_labels, loss_mask=None):
993
- prediction_scores_dims = len(prediction_scores.shape)
994
-
995
- loss_subbatch_seqlen_config_key = "loss_subbatch_seqlen"
996
- default_loss_subbatch_seqlen = 32768
997
-
998
- current_loss_subbatch_seqlen = getattr(self.config, loss_subbatch_seqlen_config_key, default_loss_subbatch_seqlen)
999
-
1000
- if prediction_scores_dims == 2 and prediction_scores.shape[0] > current_loss_subbatch_seqlen:
1001
- sb_loss_func = subbatch(
1002
- self.loss_impl, [0, 1], [0, 0], current_loss_subbatch_seqlen, 0
1003
- )
1004
- masked_lm_loss = sb_loss_func(prediction_scores, masked_lm_labels)
1005
- elif prediction_scores_dims == 3 and prediction_scores.shape[1] > current_loss_subbatch_seqlen:
1006
- sb_loss_func = subbatch(
1007
- self.loss_impl, [0, 1], [1, 1], current_loss_subbatch_seqlen, 1
1008
- )
1009
- masked_lm_loss = sb_loss_func(prediction_scores, masked_lm_labels)
1010
- else:
1011
- masked_lm_loss = self.loss_impl(prediction_scores, masked_lm_labels)
1012
-
1013
- if loss_mask is None:
1014
- loss_mask = masked_lm_labels != self.ignored_index
1015
-
1016
- loss_mask = loss_mask.reshape(-1).to(torch.float32)
1017
-
1018
- masked_lm_loss = torch.sum(masked_lm_loss.to(torch.float32).reshape(-1) * loss_mask)
1019
-
1020
- # The division will be in float32
1021
- loss = masked_lm_loss / loss_mask.sum()
1022
-
1023
- loss_sum = masked_lm_loss.sum().detach()
1024
-
1025
- if not self.return_tuple:
1026
- if self.training:
1027
- return loss
1028
- return loss_sum
1029
- return loss, loss_sum
1030
-
1031
- @auto_docstring
1032
- class Ernie4_5_Model(Ernie4_5_PretrainedModel):
1033
- """The core ERNIE transformer model with MoE (Mixture of Experts) support."""
1034
- _keep_in_fp32_modules = ['gate']
1035
- def __init__(self, config: Ernie4_5_MoeConfig):
1036
- """Initialize the ERNIE model architecture."""
1037
- super().__init__(config)
1038
- self.padding_idx = config.pad_token_id
1039
- self.vocab_size = config.vocab_size
1040
- self.hidden_size = config.hidden_size
1041
- self.config = config
1042
-
1043
- self.embed_tokens = nn.Embedding(
1044
- self.vocab_size,
1045
- self.hidden_size,
1046
- )
1047
-
1048
- self.layers = nn.ModuleList(
1049
- [
1050
- Ernie4_5_DecoderLayer(config, i)
1051
- for i in range(config.num_hidden_layers)
1052
- ]
1053
- )
1054
- self.norm = Ernie4_5_RMSNorm(config)
1055
- self.rotary_emb = Ernie4_5_RopeEmbedding(config=config)
1056
-
1057
- self.gradient_checkpointing = False
1058
-
1059
- if config.num_nextn_predict_layers > 0 and self.training:
1060
- self.mtp_block = nn.ModuleList(
1061
- [Ernie4_5_DecoderLayer(config, layer_idx) for layer_idx in range(config.num_nextn_predict_layers)]
1062
- )
1063
- self.mtp_emb_norm = nn.ModuleList(
1064
- [Ernie4_5_RMSNorm(config) for _ in range(config.num_nextn_predict_layers)]
1065
- )
1066
- self.mtp_hidden_norm = nn.ModuleList(
1067
- [Ernie4_5_RMSNorm(config) for _ in range(config.num_nextn_predict_layers)]
1068
- )
1069
- self.mtp_linear_proj = nn.ModuleList(
1070
- [nn.Linear(config.hidden_size * 2, config.hidden_size, bias=config.use_bias) for _ in range(config.num_nextn_predict_layers)]
1071
- )
1072
-
1073
- self.post_init()
1074
-
1075
- def get_input_embeddings(self):
1076
- """Get the input embedding layer."""
1077
- return self.embed_tokens
1078
-
1079
- def set_input_embeddings(self, value):
1080
- """Set new input embeddings."""
1081
- self.embed_tokens = value
1082
-
1083
- def forward(
1084
- self,
1085
- input_ids: Optional[torch.LongTensor] = None,
1086
- attention_mask: Optional[torch.Tensor] = None,
1087
- position_ids: Optional[torch.LongTensor] = None,
1088
- past_key_values: Optional[Cache] = None,
1089
- inputs_embeds: Optional[torch.FloatTensor] = None,
1090
- use_cache: Optional[bool] = None,
1091
- output_attentions: Optional[bool] = None,
1092
- output_hidden_states: Optional[bool] = None,
1093
- cache_position: Optional[torch.LongTensor] = None,
1094
- **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
1095
- ):
1096
- """Forward pass through the ERNIE model."""
1097
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1098
- output_hidden_states = (
1099
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1100
- )
1101
-
1102
- if (input_ids is None) ^ (inputs_embeds is not None):
1103
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1104
-
1105
- if self.gradient_checkpointing and self.training:
1106
- if use_cache:
1107
- logger.warning_once(
1108
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1109
- )
1110
- use_cache = False
1111
-
1112
- if use_cache and past_key_values is None:
1113
- past_key_values = DynamicCache()
1114
-
1115
- if inputs_embeds is None:
1116
- inputs_embeds = self.embed_tokens(input_ids)
1117
-
1118
- inputs_embeds = inputs_embeds.to(self.embed_tokens.weight.dtype)
1119
-
1120
- if cache_position is None:
1121
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1122
- cache_position = torch.arange(
1123
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1124
- )
1125
- if position_ids is None:
1126
- position_ids = cache_position.unsqueeze(0)
1127
-
1128
- seq_length = inputs_embeds.size(1)
1129
- if self.config.num_nextn_predict_layers > 0 and self.training:
1130
- seq_length -= self.config.num_nextn_predict_layers
1131
- seq_length_with_past = seq_length
1132
- if position_ids is not None:
1133
- position_ids = position_ids[:, :seq_length]
1134
- inputs_embeds_extra = inputs_embeds[:, -self.config.num_nextn_predict_layers :, :]
1135
- inputs_embeds = inputs_embeds[:, : -self.config.num_nextn_predict_layers, :]
1136
- inputs_embeds_ori = inputs_embeds
1137
-
1138
- causal_mask = self._update_causal_mask(
1139
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
1140
- )
1141
-
1142
- hidden_states = inputs_embeds
1143
-
1144
- # create position embeddings to be shared across the decoder layers
1145
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
1146
-
1147
- # decoder layers
1148
- all_hidden_states = () if output_hidden_states else None
1149
- all_self_attns = () if output_attentions else None
1150
- all_router_loss = torch.tensor(0.0, device=inputs_embeds.device) if self.config.use_moe else None
1151
- all_gate_logits = ()
1152
-
1153
- for decoder_layer in self.layers:
1154
- if output_hidden_states:
1155
- all_hidden_states += (hidden_states,)
1156
-
1157
- if self.gradient_checkpointing and self.training:
1158
- layer_outputs = self._gradient_checkpointing_func(
1159
- partial(decoder_layer.__call__, **flash_attn_kwargs),
1160
- hidden_states,
1161
- causal_mask,
1162
- position_ids,
1163
- past_key_values,
1164
- output_attentions,
1165
- use_cache,
1166
- cache_position,
1167
- position_embeddings,
1168
- )
1169
- else:
1170
- layer_outputs = decoder_layer(
1171
- hidden_states,
1172
- causal_mask,
1173
- position_ids,
1174
- past_key_values,
1175
- output_attentions,
1176
- use_cache,
1177
- cache_position,
1178
- position_embeddings,
1179
- **flash_attn_kwargs,
1180
- )
1181
-
1182
- hidden_states = layer_outputs[0]
1183
-
1184
- if output_attentions:
1185
- all_self_attns += (layer_outputs[1],)
1186
-
1187
- if self.config.use_moe:
1188
- layer_outputs, gate_logits = layer_outputs[:-1], layer_outputs[-1]
1189
- all_gate_logits = all_gate_logits + (gate_logits,)
1190
-
1191
- mtp_outputs = []
1192
- if self.config.num_nextn_predict_layers > 0 and self.training:
1193
- mtp_outputs.append(hidden_states)
1194
- for depth in range(self.config.num_nextn_predict_layers):
1195
- inputs_embeds_cur_depth = torch.concat(
1196
- [inputs_embeds_ori[:, (depth + 1) :, :], inputs_embeds_extra[:, : (depth + 1), :]], axis=1
1197
- )
1198
- inputs_embeds_cur_depth_norm = self.mtp_emb_norm[depth](inputs_embeds_cur_depth)
1199
- hidden_states_norm = self.mtp_hidden_norm[depth](hidden_states)
1200
-
1201
- inputs_embeds_cur_depth = self.mtp_linear_proj[depth](
1202
- torch.concat([inputs_embeds_cur_depth_norm, hidden_states_norm], axis=-1)
1203
- )
1204
-
1205
- decoder_layer = self.mtp_block[depth]
1206
- layer_outputs = decoder_layer(
1207
- inputs_embeds_cur_depth,
1208
- causal_mask,
1209
- position_ids,
1210
- past_key_values,
1211
- output_attentions,
1212
- use_cache,
1213
- cache_position,
1214
- position_embeddings,
1215
- **flash_attn_kwargs,
1216
- )
1217
- if isinstance(layer_outputs, (tuple, list)):
1218
- hidden_states = layer_outputs[0]
1219
- else:
1220
- hidden_states = layer_outputs
1221
-
1222
- if self.config.use_moe:
1223
- layer_outputs, gate_logits = layer_outputs[:-1], layer_outputs[-1]
1224
- all_gate_logits = all_gate_logits + (gate_logits,)
1225
-
1226
- mtp_outputs.append(hidden_states)
1227
- mtp_outputs = [self.norm(hidden_states) for depth, hidden_states in enumerate(mtp_outputs)]
1228
- hidden_states, mtp_outputs = mtp_outputs[0], mtp_outputs[1:]
1229
- else:
1230
- hidden_states = self.norm(hidden_states)
1231
-
1232
- # add hidden states from the last decoder layer
1233
- if output_hidden_states:
1234
- all_hidden_states += (hidden_states,)
1235
-
1236
- # assert all_router_loss is None, f'moe not support `return-dict`'
1237
- return Erine4_5_MoeModelOutputWithPast(
1238
- last_hidden_state=hidden_states,
1239
- past_key_values=past_key_values,
1240
- hidden_states=all_hidden_states,
1241
- attentions=all_self_attns,
1242
- router_loss=all_router_loss,
1243
- gate_logits=all_gate_logits,
1244
- mtp_outputs=mtp_outputs,
1245
- )
1246
-
1247
- def _update_causal_mask(
1248
- self,
1249
- attention_mask: Union[torch.Tensor, "BlockMask"],
1250
- input_tensor: torch.Tensor,
1251
- cache_position: torch.Tensor,
1252
- past_key_values: Cache,
1253
- output_attentions: bool = False,
1254
- ):
1255
- if self.config._attn_implementation == "flash_attention_2":
1256
- if attention_mask is not None and past_key_values is not None:
1257
- is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
1258
- if is_padding_right:
1259
- raise ValueError(
1260
- "You are attempting to perform batched generation with padding_side='right'"
1261
- " this may lead to unexpected behaviour for Flash Attention version of Ernie4_5. Make sure to "
1262
- " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1263
- )
1264
- if attention_mask is not None and 0.0 in attention_mask:
1265
- return attention_mask
1266
- return None
1267
- if self.config._attn_implementation == "flex_attention":
1268
- if isinstance(attention_mask, torch.Tensor):
1269
- attention_mask = make_flex_block_causal_mask(attention_mask)
1270
- return attention_mask
1271
-
1272
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1273
- # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1274
- # to infer the attention mask.
1275
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1276
- using_static_cache = isinstance(past_key_values, StaticCache)
1277
-
1278
- # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1279
- if (
1280
- self.config._attn_implementation == "sdpa"
1281
- and not using_static_cache
1282
- and not output_attentions
1283
- ):
1284
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
1285
- attention_mask,
1286
- inputs_embeds=input_tensor,
1287
- past_key_values_length=past_seen_tokens,
1288
- is_training=self.training,
1289
- ):
1290
- return None
1291
-
1292
- dtype = input_tensor.dtype
1293
- min_dtype = torch.finfo(dtype).min
1294
- sequence_length = input_tensor.shape[1]
1295
- # StaticCache
1296
- if using_static_cache:
1297
- target_length = past_key_values.get_max_cache_shape()
1298
- # DynamicCache or no cache
1299
- else:
1300
- target_length = (
1301
- attention_mask.shape[-1]
1302
- if isinstance(attention_mask, torch.Tensor)
1303
- else past_seen_tokens + sequence_length + 1
1304
- )
1305
-
1306
- # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
1307
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
1308
- attention_mask,
1309
- sequence_length=sequence_length,
1310
- target_length=target_length,
1311
- dtype=dtype,
1312
- cache_position=cache_position,
1313
- batch_size=input_tensor.shape[0],
1314
- config=self.config,
1315
- past_key_values=past_key_values,
1316
- )
1317
-
1318
- if (
1319
- self.config._attn_implementation == "sdpa"
1320
- and attention_mask is not None
1321
- and attention_mask.device.type in ["cuda", "xpu", "npu"]
1322
- and not output_attentions
1323
- ):
1324
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1325
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1326
- # Details: https://github.com/pytorch/pytorch/issues/110213
1327
- causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1328
-
1329
- return causal_mask
1330
-
1331
- @staticmethod
1332
- def _prepare_4d_causal_attention_mask_with_cache_position(
1333
- attention_mask: torch.Tensor,
1334
- sequence_length: int,
1335
- target_length: int,
1336
- dtype: torch.dtype,
1337
- cache_position: torch.Tensor,
1338
- batch_size: int,
1339
- config: Ernie4_5_MoeConfig,
1340
- past_key_values: Cache,
1341
- ):
1342
- """
1343
- Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1344
- `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
1345
-
1346
- Args:
1347
- attention_mask (`torch.Tensor`):
1348
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
1349
- sequence_length (`int`):
1350
- The sequence length being processed.
1351
- target_length (`int`):
1352
- The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
1353
- dtype (`torch.dtype`):
1354
- The dtype to use for the 4D attention mask.
1355
- cache_position (`torch.Tensor`):
1356
- Indices depicting the position of the input sequence tokens in the sequence.
1357
- batch_size (`torch.Tensor`):
1358
- Batch size.
1359
- config (`Ernie4_5_MoeConfig`):
1360
- The model's configuration class
1361
- past_key_values (`Cache`):
1362
- The cache class that is being used currently to generate
1363
- """
1364
- if attention_mask is not None and attention_mask.dim() == 4:
1365
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
1366
- causal_mask = attention_mask
1367
- else:
1368
- min_dtype = torch.finfo(dtype).min
1369
- causal_mask = torch.full(
1370
- (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
1371
- )
1372
- diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
1373
- -1, 1
1374
- )
1375
- text_config = config.get_text_config()
1376
- causal_mask *= diagonal_attend_mask
1377
- causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
1378
- if attention_mask is not None:
1379
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1380
- if attention_mask.shape[-1] > target_length:
1381
- attention_mask = attention_mask[:, :target_length]
1382
- mask_length = attention_mask.shape[-1]
1383
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
1384
- causal_mask.device
1385
- )
1386
- padding_mask = padding_mask == 0
1387
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
1388
- padding_mask, min_dtype
1389
- )
1390
- return causal_mask
1391
-
1392
- @auto_docstring
1393
- class Ernie4_5_MoeForCausalLM(Ernie4_5_PretrainedModel,GenerationMixin):
1394
- """ERNIE Mixture of Experts (MoE) model for causal language modeling."""
1395
-
1396
- _tied_weights_keys = ["lm_head.weight"]
1397
- _tp_plan = {"lm_head": "colwise_rep"}
1398
- _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
1399
-
1400
- def __init__(self, config):
1401
- """
1402
- Initializes the ERNIE MoE model for causal language modeling.
1403
-
1404
- Args:
1405
- config (Ernie4_5_MoeConfig): Model configuration.
1406
- """
1407
- super().__init__(config)
1408
- self.config = config
1409
- self.model = Ernie4_5_Model(config)
1410
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size,bias=config.weight_share_add_bias and config.use_bias) # TODO
1411
- self._loss_function = ErniePretrainingCriterion(config)
1412
-
1413
- # Initialize weights and apply final processing
1414
- self.post_init()
1415
-
1416
- def get_input_embeddings(self):
1417
- """Returns the input embeddings layer."""
1418
- return self.model.embed_tokens
1419
-
1420
- def set_input_embeddings(self, value):
1421
- """Sets the input embeddings layer."""
1422
- self.model.embed_tokens = value
1423
-
1424
- def get_output_embeddings(self):
1425
- """Returns the output embeddings (LM head)."""
1426
- return self.lm_head
1427
-
1428
- def set_output_embeddings(self, new_embeddings):
1429
- """Sets the output embeddings layer."""
1430
- self.lm_head = new_embeddings
1431
-
1432
- def set_decoder(self, decoder):
1433
- """Sets the ERNIE decoder model."""
1434
- self.model = decoder
1435
-
1436
- def get_decoder(self):
1437
- """Get the transformer decoder."""
1438
- return self.model
1439
-
1440
- @can_return_tuple
1441
- def forward(
1442
- self,
1443
- input_ids,
1444
- attention_mask=None,
1445
- position_ids=None,
1446
- past_key_values: Optional[list[torch.FloatTensor]] = None,
1447
- inputs_embeds=None,
1448
- labels=None,
1449
- loss_mask=None,
1450
- use_cache=False,
1451
- output_attentions: Optional[bool] = None,
1452
- output_hidden_states: Optional[bool] = None,
1453
- **kwargs: Unpack[KwargsForCausalLM],
1454
- ):
1455
- """
1456
- Forward pass for causal language modeling.
1457
- """
1458
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1459
- output_hidden_states = (
1460
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1461
- )
1462
-
1463
- outputs = self.model(
1464
- input_ids,
1465
- position_ids=position_ids,
1466
- attention_mask=attention_mask,
1467
- inputs_embeds=inputs_embeds,
1468
- use_cache=use_cache,
1469
- past_key_values=past_key_values,
1470
- output_attentions=output_attentions,
1471
- output_hidden_states=output_hidden_states,
1472
- **kwargs,
1473
- )
1474
-
1475
- hidden_states = outputs.last_hidden_state
1476
- mtp_outputs = outputs.mtp_outputs
1477
-
1478
- logits = self.lm_head(hidden_states)
1479
- mtp_logits = []
1480
- if len(mtp_outputs) > 0:
1481
- mtp_logits = [self.lm_head(_hidden_states) for _hidden_states in mtp_outputs]
1482
- loss, router_loss = None, None
1483
- if getattr(self.config, "use_moe", False):
1484
- router_loss = outputs.router_loss
1485
-
1486
- if labels is not None:
1487
- loss, _ = self.loss_function(logits, labels, loss_mask, router_loss, mtp_logits)
1488
-
1489
- return Ernie4_5_MoeCausalLMOutputWithPast(
1490
- loss=loss,
1491
- logits=logits,
1492
- past_key_values=outputs.past_key_values,
1493
- hidden_states=outputs.hidden_states,
1494
- attentions=outputs.attentions,
1495
- router_loss=router_loss,
1496
- )
1497
-
1498
-
1499
-
1500
- __all__ = [
1501
- "Ernie4_5_Model",
1502
- "Ernie4_5_MoeForCausalLM",
1503
- "Ernie4_5_PretrainedModel"
1504
- ]