leejunhyeok committed on
Commit 80e1a1c · verified · 1 Parent(s): 72cc86d

Update modeling_motif.py

Files changed (1)
  1. modeling_motif.py +821 -101
modeling_motif.py CHANGED
@@ -1,5 +1,5 @@
1
  import math
2
- from typing import List, Optional, Tuple, Union, Callable, Dict
3
 
4
  import torch
5
  import torch.utils.checkpoint
@@ -28,25 +28,14 @@ from .configuration_motif import MotifConfig
28
  from dataclasses import dataclass
29
 
30
  import torch.nn.functional as F
31
- import time
32
-
33
- logger = logging.get_logger(__name__)
34
-
35
- if is_flash_attn_2_available():
36
- from transformers.modeling_flash_attention_utils import _flash_attention_forward
37
-
38
-
39
- _CONFIG_FOR_DOC = "MotifConfig"
40
 
41
  from transformers.activations import ACT2CLS as _ACT2CLS
42
  from transformers.activations import ClassInstantier
43
-
44
-
45
  class PolyNorm(torch.nn.Module):
46
- """
47
  A trainable activation function introduced in https://arxiv.org/html/2411.03884v1.
48
  The code is copied from https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md,
49
- with the change `* torch.rsqrt` => `/ torch.sqrt`
50
  """
51
 
52
  def __init__(self, eps=1e-6):
@@ -62,11 +51,52 @@ class PolyNorm(torch.nn.Module):
62
  return self.weight[0] * self._norm(x ** 3) + self.weight[1] * self._norm(
63
  x ** 2) + self.weight[2] * self._norm(x) + self.bias
64
 
65
 
66
- CUSTOM_ACT2CLS = {"poly_norm": PolyNorm}
67
  ACT2CLS = {**_ACT2CLS, **CUSTOM_ACT2CLS}
68
  ACT2FN = ClassInstantier(ACT2CLS)
69
 
70
 
71
  class MotifRMSNorm(nn.Module):
72
 
@@ -80,7 +110,8 @@ class MotifRMSNorm(nn.Module):
80
 
81
  def forward(self, hidden_states):
82
  input_dtype = hidden_states.dtype
83
- variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
 
84
  hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
85
  return self.weight * hidden_states.to(input_dtype)
86
 
@@ -88,21 +119,24 @@ class MotifRMSNorm(nn.Module):
88
  return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
89
 
90
 
91
- ALL_LAYERNORM_LAYERS.append(MotifRMSNorm)
92
 
93
 
94
  class MotifRotaryEmbeddingWithCache(nn.Module):
95
  """
96
  Rotary positional embedding module with caching for efficiency.
 
97
  Args:
98
  dim (int): Dimensionality of the embedding.
99
  max_position_embeddings (int): Maximum sequence length for caching. Default is 2048.
100
  base (int): Base for computing inverse frequency. Default is 10000.
101
  device (torch.device, optional): Device for tensor storage.
 
102
  Methods:
103
  forward(x, seq_len=None):
104
  Computes cosine and sine embeddings for input sequence length.
105
  Automatically updates cache if `seq_len` exceeds cached length.
 
106
  Attributes:
107
  inv_freq (torch.Tensor): Inverse frequency tensor for position encoding.
108
  cos_cached (torch.Tensor): Cached cosine embeddings.
@@ -138,8 +172,8 @@ class MotifRotaryEmbeddingWithCache(nn.Module):
138
  self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
139
 
140
  return (
141
- self.cos_cached[None, :seq_len].to(dtype=x.dtype),
142
- self.sin_cached[None, :seq_len].to(dtype=x.dtype),
143
  )
144
 
145
 
@@ -156,6 +190,7 @@ class MotifRotaryEmbedding(nn.Module):
156
  config: Optional[MotifConfig] = None,
157
  ):
158
  super().__init__()
 
159
  self.rope_kwargs = {}
160
  if config is None:
161
  logger.warning_once(
@@ -200,7 +235,7 @@ class MotifRotaryEmbedding(nn.Module):
200
  device,
201
  seq_len=seq_len,
202
  **self.rope_kwargs)
203
- self.register_buffer("inv_freq", inv_freq, persistent=False)
204
  self.max_seq_len_cached = seq_len
205
 
206
  if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
@@ -234,8 +269,10 @@ class MotifRotaryEmbedding(nn.Module):
234
  def rotate_half(x):
235
  """
236
  Rotates half of the dimensions of the input tensor using torch.roll and in-place negation.
 
237
  Args:
238
  x (torch.Tensor): The input tensor.
 
239
  Returns:
240
  torch.Tensor: A tensor where the latter half of the dimensions are negated
241
  and moved before the first half.
@@ -247,27 +284,60 @@ def rotate_half(x):
247
  return rotated_tensor
248
 
249
 
250
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
251
  """
252
  Applies rotary position embeddings to the input tensors.
 
253
  Args:
254
  q (torch.Tensor): Query tensor of shape (B, NH, S, D_KV).
255
  k (torch.Tensor): Key tensor of shape (B, NH, S, D_KV).
256
  cos (torch.Tensor): Cosine values for rotary embedding.
257
  sin (torch.Tensor): Sine values for rotary embedding.
258
- unsqueeze_dim (int, optional): Dimension along which `cos` and `sin` are unsqueezed.
259
  Defaults to 1.
 
 
 
 
260
  Returns:
261
  Tuple[torch.Tensor, torch.Tensor]: Returns transformed query and key tensors after applying rotary embeddings.
262
  """
263
- device = q.device
264
- return map(
265
- lambda x: (x * cos[position_ids].unsqueeze(unsqueeze_dim).to(device)) +
266
- (rotate_half(x) * sin[position_ids].unsqueeze(unsqueeze_dim).to(device)), (q, k))
267
 
268
 
269
  class MotifMLP(nn.Module):
270
-
271
  def __init__(self, config):
272
  super().__init__()
273
  self.hidden_size = config.hidden_size
@@ -277,33 +347,420 @@ class MotifMLP(nn.Module):
277
  self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
278
  self.act_fn = ACT2FN[config.hidden_act]
279
 
280
  def forward(self, hidden_state):
281
- return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
282
 
283
 
284
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
285
- return torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)
286
 
 
 
287
 
288
- # @log_timing
289
  class MotifAttention(nn.Module):
290
  """
291
  Differential Attention (DiffAttention) module.
292
- Implements the Differential Attention from
 
293
  "DIFFERENTIAL TRANSFORMER" (https://arxiv.org/pdf/2410.05258).
 
294
  Overview
295
  Standard transformers often over-allocate attention to irrelevant context.
296
- DiffAttention addresses this by computing attention as the difference between
297
- two separate softmax attention maps, effectively canceling noise and promoting
298
  sparse, structured attention patterns.
 
299
  Reference Implementation
300
  https://github.com/microsoft/unilm/tree/master/Diff-Transformer
 
301
  Args
302
- The differential attention mechanism computes attention as the difference of two softmax attention scores, weighted by a learnable scalar λ.
303
  λ is re-parameterized as λ = exp(λ_q1 · λ_k1) − exp(λ_q2 · λ_k2) + λ_init.
304
  - lambda_q1, lambda_q2 (nn.Parameter): Learnable vectors used to compute the first and second components of λ for query transformations.
305
  - lambda_k1, lambda_k2 (nn.Parameter): Learnable vectors used to compute the first and second components of λ for key transformations.
306
  - lambda_init (float): A constant used for initializing λ, typically set as λ_init = 0.8 − 0.6 × exp(−0.3 × (layer_index − 1)).
 
307
  """
308
 
309
  def __init__(self, config: MotifConfig, layer_idx: Optional[int] = None):
@@ -383,7 +840,7 @@ class MotifAttention(nn.Module):
383
  self.subln = MotifRMSNorm(2 * self.head_dim, eps=1e-5)
384
  self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * (layer_idx - 1))
385
 
386
- self.rotary_emb = MotifRotaryEmbedding(self.head_dim,
387
  max_position_embeddings=self.max_position_embeddings,
388
  base=self.rope_theta)
389
 
@@ -429,12 +886,12 @@ class MotifAttention(nn.Module):
429
  cos, sin = (self.rotary_emb(value_states, q_len + past_key_value.get_usable_length(q_len, self.layer_idx))
430
  if use_cache else position_embeddings)
431
 
432
- query_states, key_states = apply_rotary_pos_emb(
433
- query_states,
434
- key_states,
435
- cos,
436
- sin
437
- )
438
 
439
  if past_key_value is not None:
440
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
@@ -503,7 +960,6 @@ class MotifAttention(nn.Module):
503
  return attn_output, attn_weights, past_key_value
504
 
505
 
506
- # @log_timing
507
  class MotifFlashAttention2(MotifAttention):
508
  """
509
  Motif flash attention module, following Motif attention module. This module inherits from `MotifAttention`
@@ -517,7 +973,7 @@ class MotifFlashAttention2(MotifAttention):
517
  def __init__(self, *args, **kwargs):
518
  super().__init__(*args, **kwargs)
519
 
520
- logger.info(f'flash attention True')
521
 
522
  # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
523
  # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
@@ -525,6 +981,8 @@ class MotifFlashAttention2(MotifAttention):
525
 
526
  self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
527
 
 
 
528
  def _reshape_heads(self, tensor, batch_size, seq_len):
529
  """2-way head split tensor reshape"""
530
  return tensor.reshape(batch_size, seq_len, self.num_heads, 2, self.head_dim)
@@ -534,27 +992,59 @@ class MotifFlashAttention2(MotifAttention):
534
  return tensor.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
535
 
536
  def _compute_attention(self, query_states, key_states, value_states, attention_mask, q_len, position_ids,
537
- dropout_rate, sliding_window, batch_num):
538
  """Flash Attention 2 implements"""
 
539
  scale_factor = 1.0 / math.sqrt(self.head_dim)
540
- # Copied from _flash_attention_forward
541
  if not self._flash_attn_uses_top_left_mask:
542
  causal = self.is_causal
543
  else:
544
  causal = self.is_causal and q_len != 1
 
 
 
545
 
546
- bsz = query_states.shape[0]
 
 
 
547
 
548
- return _flash_attention_forward(query_states.bfloat16(),
549
- key_states.bfloat16(),
550
- value_states.bfloat16(),
551
- attention_mask,
552
- q_len,
553
- position_ids=position_ids,
554
- dropout=dropout_rate,
555
- sliding_window=sliding_window,
556
- is_causal=self.is_causal,
557
- use_top_left_mask=self._flash_attn_uses_top_left_mask)
558
 
559
  def forward(
560
  self,
@@ -588,12 +1078,12 @@ class MotifFlashAttention2(MotifAttention):
588
  cos, sin = (self.rotary_emb(value_states, q_len + past_key_value.get_usable_length(q_len, self.layer_idx))
589
  if use_cache else position_embeddings)
590
 
591
- query_states, key_states = apply_rotary_pos_emb(
592
- query_states,
593
- key_states,
594
- cos,
595
- sin
596
- )
597
 
598
  if past_key_value is not None:
599
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
@@ -604,6 +1094,28 @@ class MotifFlashAttention2(MotifAttention):
604
  value_states = repeat_kv(value_states, self.num_key_value_groups)
605
  dropout_rate = 0.0 if not self.training else self.attention_dropout
606
 
607
  q_len = query_states.shape[-2]
608
  kv_seq_len = key_states.shape[-2]
609
 
@@ -613,7 +1125,7 @@ class MotifFlashAttention2(MotifAttention):
613
  value_states = value_states.transpose(1, 2)
614
 
615
  if (self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None
616
- and self.layer_idx >= self.config.max_window_layers):
617
  sliding_window = self.config.sliding_window
618
  else:
619
  sliding_window = None
@@ -633,13 +1145,14 @@ class MotifFlashAttention2(MotifAttention):
633
  k1, k2 = k1.contiguous(), k2.contiguous()
634
  v1, v2 = v1.contiguous(), v2.contiguous()
635
 
636
- attn11, attn12 = self._compute_attention(q1, k1, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window, self.batch_num), \
637
- self._compute_attention(q1, k1, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window, self.batch_num)
638
- attn21, attn22 = self._compute_attention(q2, k2, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window, self.batch_num), \
639
- self._compute_attention(q2, k2, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window, self.batch_num)
640
 
641
-
642
- attn1, attn2 = torch.cat([attn11, attn12], dim=-1).float(), torch.cat([attn21, attn22], dim=-1).float()
 
 
 
 
643
 
644
  lambda_q1 = self.lambda_q1.unsqueeze(0).expand([bsz, self.lambda_q1.shape[0]]) # bsz, num_head
645
  lambda_q2 = self.lambda_q2.unsqueeze(0).expand([bsz, self.lambda_q2.shape[0]]) # bsz, num_head
@@ -655,16 +1168,15 @@ class MotifFlashAttention2(MotifAttention):
655
  attn_output = attn_output * (1 - self.lambda_init)
656
 
657
  if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim * 2):
658
- raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
659
  f" {attn_output.size()}")
660
 
661
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).bfloat16()
662
  attn_output = self.o_proj(attn_output) * self.o_proj_alpha
663
 
664
- return attn_output.float(), None, past_key_value
665
 
666
 
667
- # @log_timing
668
  class MotifSdpaAttention(MotifAttention):
669
  """
670
  Motif attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -758,16 +1270,17 @@ class MotifSdpaAttention(MotifAttention):
758
  MOTIF_ATTENTION_CLASSES = {
759
  "eager": MotifAttention,
760
  "flash_attention_2": MotifFlashAttention2,
761
- "sdpa": MotifSdpaAttention,
762
  }
763
 
764
 
765
  class MotifDecoderLayer(nn.Module):
766
 
767
- def __init__(self, config: MotifConfig, layer_idx: int):
768
  super().__init__()
769
  self.hidden_size = config.hidden_size
770
-
 
771
  if config.sliding_window and config._attn_implementation != "flash_attention_2":
772
  logger.warning_once(
773
  f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
@@ -777,8 +1290,12 @@ class MotifDecoderLayer(nn.Module):
777
  else:
778
  self.self_attn = MOTIF_ATTENTION_CLASSES["eager"](config, layer_idx)
779
  self.mlp = MotifMLP(config)
780
-
781
- RMSNorm = MotifRMSNorm
 
 
 
 
782
  self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
783
  self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
784
 
@@ -847,7 +1364,13 @@ class MotifDecoderLayer(nn.Module):
847
  residual = hidden_states
848
  hidden_states = self.post_attention_layernorm(hidden_states) * self.post_attention_layernorm_alpha
849
 
850
- hidden_states = self.mlp(hidden_states)
851
 
852
  hidden_states = residual + hidden_states
853
 
@@ -866,9 +1389,11 @@ MOTIF_START_DOCSTRING = r"""
866
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
867
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
868
  etc.)
 
869
  This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
870
  Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
871
  and behavior.
 
872
  Parameters:
873
  config ([`MotifConfig`]):
874
  Model configuration class with all the parameters of the model. Initializing with a config file does not
@@ -918,23 +1443,26 @@ class MotifPreTrainedModel(PreTrainedModel):
918
  module_std = module_std / math.sqrt(self.config.dim_model_base_lmh) ### lmhead.. 1
919
  else:
920
  module_std = module_std
921
-
922
- torch.nn.init.trunc_normal_(module.weight.data, mean=0.0, std=module_std, a=-3*module_std, b=3*module_std)
 
923
  if module.bias is not None:
924
  module.bias.data.zero_()
925
 
926
  elif isinstance(module, nn.Embedding):
927
- torch.nn.init.trunc_normal_(module.weight.data, mean=0.0, std=module_std, a=-3*module_std, b=3*module_std)
 
 
928
  if module.padding_idx is not None:
929
  module.weight.data[module.padding_idx].zero_()
930
 
931
 
932
  @dataclass
933
  class MotifModelOutputWithPast(ModelOutput):
934
- """
935
- This augments `BaseModelOutputWithPast` in `transformers.modeling_outputs` with new optional keys: `causal_mask`, `position_embeddings`.
936
  The optional keys are currently used in the following ways:
937
- - pass information to the token-wise last attention layers in multi-token training
938
  """
939
  last_hidden_state: torch.FloatTensor = None
940
  past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@@ -949,39 +1477,51 @@ MOTIF_INPUTS_DOCSTRING = r"""
949
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
950
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
951
  it.
 
952
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
953
  [`PreTrainedTokenizer.__call__`] for details.
 
954
  [What are input IDs?](../glossary#input-ids)
955
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
956
  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
 
957
  - 1 for tokens that are **not masked**,
958
  - 0 for tokens that are **masked**.
 
959
  [What are attention masks?](../glossary#attention-mask)
 
960
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
961
  [`PreTrainedTokenizer.__call__`] for details.
 
962
  If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
963
  `past_key_values`).
 
964
  If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
965
  and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
966
  information on the default strategy.
 
967
  - 1 indicates the head is **not masked**,
968
  - 0 indicates the head is **masked**.
969
  position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
970
  Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
971
  config.n_positions - 1]`.
 
972
  [What are position IDs?](../glossary#position-ids)
973
  past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
974
  Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
975
  blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
976
  returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
 
977
  Two formats are allowed:
978
  - a [`~cache_utils.Cache`] instance, see our
979
  [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
980
  - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
981
  shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
982
  cache format.
 
983
  The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
984
  legacy cache format will be returned.
 
985
  If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
986
  have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
987
  of shape `(batch_size, sequence_length)`.
@@ -1014,6 +1554,7 @@ MOTIF_INPUTS_DOCSTRING = r"""
1014
  class MotifModel(MotifPreTrainedModel):
1015
  """
1016
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MotifDecoderLayer`]
 
1017
  Args:
1018
  config: MotifConfig
1019
  """
@@ -1025,14 +1566,19 @@ class MotifModel(MotifPreTrainedModel):
1025
  self.multi_token_heads = config.multi_token_heads
1026
 
1027
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 
 
1028
 
1029
  num_hidden_layers = config.num_hidden_layers if self.multi_token_heads is None else config.num_hidden_layers - 1
1030
-
1031
- self.layers = nn.ModuleList([
1032
- MotifDecoderLayer(config = config, layer_idx=layer_idx) for layer_idx in range(num_hidden_layers)
1033
- ])
 
 
 
1034
  self._attn_implementation = config._attn_implementation
1035
- RMSNorm = MotifRMSNorm
1036
  self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1037
  self.hidden_size = config.hidden_size
1038
  self.num_heads = config.num_attention_heads
@@ -1046,6 +1592,34 @@ class MotifModel(MotifPreTrainedModel):
1046
  self.gradient_checkpointing = False
1047
  self.post_init()
1048
 
1049
  def get_input_embeddings(self):
1050
  return self.embed_tokens
1051
 
@@ -1084,6 +1658,7 @@ class MotifModel(MotifPreTrainedModel):
1084
  "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
1085
  use_cache = False
1086
 
 
1087
  return_legacy_cache = False
1088
  if use_cache and not isinstance(past_key_values, Cache):
1089
  return_legacy_cache = True
@@ -1097,17 +1672,17 @@ class MotifModel(MotifPreTrainedModel):
1097
  "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)")
1098
 
1099
  if inputs_embeds is None:
1100
- inputs_embeds = self.embed_tokens(input_ids)
1101
 
1102
  if cache_position is None:
1103
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1104
  cache_position = torch.arange(past_seen_tokens,
1105
  past_seen_tokens + inputs_embeds.shape[1],
1106
  device=inputs_embeds.device)
1107
- position_ids = None
1108
  if position_ids is None:
1109
  position_ids = cache_position.unsqueeze(0)
1110
-
1111
  causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values,
1112
  output_attentions)
1113
 
@@ -1150,6 +1725,10 @@ class MotifModel(MotifPreTrainedModel):
1150
  )
1151
 
1152
  hidden_states = layer_outputs[0]
 
 
 
 
1153
 
1154
  if use_cache:
1155
  next_decoder_cache = layer_outputs[2 if output_attentions else 1]
@@ -1157,8 +1736,9 @@ class MotifModel(MotifPreTrainedModel):
1157
  if output_attentions:
1158
  all_self_attns += (layer_outputs[1], )
1159
 
1160
- hidden_states = self.norm(hidden_states)
1161
-
 
1162
  # add hidden states from the last decoder layer
1163
  if output_hidden_states:
1164
  all_hidden_states += (hidden_states, )
@@ -1190,6 +1770,8 @@ class MotifModel(MotifPreTrainedModel):
1190
  output_attentions: bool,
1191
  ):
1192
  if self.config._attn_implementation == "flash_attention_2":
 
 
1193
  if attention_mask is not None and 0.0 in attention_mask:
1194
  return attention_mask
1195
  return None
@@ -1261,6 +1843,7 @@ class MotifModel(MotifPreTrainedModel):
1261
  """
1262
  Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1263
  `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
 
1264
  Args:
1265
  attention_mask (`torch.Tensor`):
1266
  A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
@@ -1318,14 +1901,33 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
1318
  self.vocab_size = config.vocab_size
1319
  self.multi_token_heads = config.multi_token_heads
1320
 
1321
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1322
 
1323
  # Initialize weights and apply final processing
1324
  self.post_init()
1325
-
1326
  if getattr(config, "tie_word_embeddings", True):
1327
  logger.info('tie embeddings')
1328
  self.tie_weights()
 
 
 
1329
 
1330
  def get_input_embeddings(self):
1331
  return self.model.embed_tokens
@@ -1345,7 +1947,101 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
1345
  def get_decoder(self):
1346
  return self.model
1347
 
1348
-
1349
  @add_start_docstrings_to_model_forward(MOTIF_INPUTS_DOCSTRING)
1350
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1351
  def forward(
@@ -1370,18 +2066,25 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
1370
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1371
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1372
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
1373
  num_logits_to_keep (`int`, *optional*):
1374
  Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
1375
  `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
1376
  token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
 
1377
  Returns:
 
1378
  Example:
 
1379
  ```python
1380
  >>> from transformers import AutoTokenizer, MotifForCausalLM
1381
- >>> model = MotifForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS, trust_remote_code = True)
1382
- >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER, trust_remote_code = True)
 
 
1383
  >>> prompt = "Hey, are you conscious? Can you talk to me?"
1384
  >>> inputs = tokenizer(prompt, return_tensors="pt")
 
1385
  >>> # Generate
1386
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1387
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
@@ -1394,6 +2097,8 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
1394
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1395
 
1396
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
 
 
1397
  outputs: MotifModelOutputWithPast = self.model(
1398
  input_ids=input_ids,
1399
  attention_mask=attention_mask,
@@ -1405,16 +2110,31 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
1405
  output_hidden_states=output_hidden_states,
1406
  return_dict=return_dict,
1407
  cache_position=cache_position,
 
 
1408
  )
1409
 
1410
  hidden_states = outputs[0]
1411
 
1412
  # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
 
1413
  logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
1414
  logits = logits.float()
1415
 
1416
  loss = None
1417
  if labels is not None:
 
1418
  # Shift so that tokens < n predict n
1419
  shift_logits = logits[..., :-1, :].contiguous()
1420
  shift_labels = labels[..., 1:].contiguous()
@@ -1436,4 +2156,4 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
1436
  past_key_values=outputs.past_key_values,
1437
  hidden_states=outputs.hidden_states,
1438
  attentions=outputs.attentions,
1439
- )
 
1
  import math
2
+ from typing import List, Optional, Tuple, Union
3
 
4
  import torch
5
  import torch.utils.checkpoint
 
28
  from dataclasses import dataclass
29
 
30
  import torch.nn.functional as F
31
 
32
  from transformers.activations import ACT2CLS as _ACT2CLS
33
  from transformers.activations import ClassInstantier
 
 
34
  class PolyNorm(torch.nn.Module):
35
+ """
36
  A trainable activation function introduced in https://arxiv.org/html/2411.03884v1.
37
  The code is copied from https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md,
38
+ with the change `* torch.rsqrt` => `/ torch.sqrt` to avoid a potential MAF incompatibility.
39
  """
40
 
41
  def __init__(self, eps=1e-6):
 
51
  return self.weight[0] * self._norm(x ** 3) + self.weight[1] * self._norm(
52
  x ** 2) + self.weight[2] * self._norm(x) + self.bias
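# A minimal standalone sketch of what PolyNorm computes, assuming `_norm(u)` is the
# RMS-style normalization `u / sqrt(mean(u**2, dim=-1) + eps)` used by the class above;
# names below are illustrative only, not part of the module.
import torch

def poly_norm_reference(x, weight, bias, eps=1e-6):
    def _norm(u):
        return u / torch.sqrt(u.pow(2).mean(-1, keepdim=True) + eps)
    # weighted mix of the normalized cubic, quadratic and linear terms, plus a bias
    return weight[0] * _norm(x ** 3) + weight[1] * _norm(x ** 2) + weight[2] * _norm(x) + bias

x_demo = torch.randn(2, 4, 8)
out_demo = poly_norm_reference(x_demo, torch.ones(3) / 3, torch.zeros(1))
assert out_demo.shape == x_demo.shape  # the activation preserves the input shape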
53
 
54
+ class PolyNorm_Test(torch.nn.Module):
55
+ """
56
+ A trainable activation function introduced in https://arxiv.org/html/2411.03884v1.
57
+ The code is copied from https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md,
58
+ with the change `* torch.rsqrt` => `/ torch.sqrt` for potential MAF incompatibility.
59
+ """
60
+
61
+ def __init__(self, eps=1e-6):
62
+ super(PolyNorm_Test, self).__init__()
63
+ self.weight = torch.nn.Parameter(torch.ones(3) / 3)
64
+ self.bias = torch.nn.Parameter(torch.zeros(1))
65
+ self.eps = eps
66
+
67
+ def forward(self, x):
68
+
69
+ #return torch.nn.SiLU(x)
70
+ return moreh_ops.poly_norm(x, self.weight, self.bias)
71
+
72
+
73
+ CUSTOM_ACT2CLS = {"poly_norm": PolyNorm, "poly_norm_test": PolyNorm_Test}
74
 
 
75
  ACT2CLS = {**_ACT2CLS, **CUSTOM_ACT2CLS}
76
  ACT2FN = ClassInstantier(ACT2CLS)
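# How the registry above is consumed (a sketch): `ClassInstantier` from transformers
# instantiates the mapped class on key access, so a config with `hidden_act = "poly_norm"`
# resolves to a ready-to-use PolyNorm module via `ACT2FN[config.hidden_act]`.
demo_act = ACT2FN["poly_norm"]           # a PolyNorm instance
demo_out = demo_act(torch.randn(2, 8))   # applied like any other activation module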
77
 
78
+ logger = logging.get_logger(__name__)
79
+
80
+ if is_flash_attn_2_available():
81
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
82
+
83
+ try:
84
+ moreh_ops = torch.ops.moreh
85
+ MorehRMSNorm = moreh_ops.T5LayerNorm
86
+ ScaledDotProductAttention = moreh_ops.scaled_dot_product_attention
87
+ MorehFlashAttention = moreh_ops.flash_attention
88
+ logger.warning_once("Using moreh ops")
89
+ except AttributeError:
90
+ MorehRMSNorm = None
91
+ ScaledDotProductAttention = None
92
+ MorehFlashAttention = None
93
+ logger.warning_once("Failed to import moreh ops")
94
+
95
+ #_CHECKPOINT_FOR_DOC = "moreh/Motif-102B"
96
+ _CONFIG_FOR_DOC = "MotifConfig"
97
+
98
+ #from .moreh_moe import MorehMoeMLP, MorehMoeFusedMLP
99
+
100
 
101
  class MotifRMSNorm(nn.Module):
102
 
 
110
 
111
  def forward(self, hidden_states):
112
  input_dtype = hidden_states.dtype
113
+ hidden_states = hidden_states.to(torch.float32)
114
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
115
  hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
116
  return self.weight * hidden_states.to(input_dtype)
117
 
 
119
  return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
120
 
121
 
122
+ ALL_LAYERNORM_LAYERS.append(MotifRMSNorm if MorehRMSNorm is None else MorehRMSNorm)
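# The normalization registered above, written out as a minimal reference (a sketch;
# `rms_norm_reference` is an illustrative name, not part of the module): cast to float32,
# divide by the root-mean-square over the last dimension, then rescale by the learned weight.
def rms_norm_reference(x, weight, eps=1e-6):
    x32 = x.float()
    out = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    return weight * out.to(x.dtype)

x_demo = torch.randn(2, 3, 8)
y_demo = rms_norm_reference(x_demo, torch.ones(8))
assert y_demo.shape == x_demo.shape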
123
 
124
 
125
  class MotifRotaryEmbeddingWithCache(nn.Module):
126
  """
127
  Rotary positional embedding module with caching for efficiency.
128
+
129
  Args:
130
  dim (int): Dimensionality of the embedding.
131
  max_position_embeddings (int): Maximum sequence length for caching. Default is 2048.
132
  base (int): Base for computing inverse frequency. Default is 10000.
133
  device (torch.device, optional): Device for tensor storage.
134
+
135
  Methods:
136
  forward(x, seq_len=None):
137
  Computes cosine and sine embeddings for input sequence length.
138
  Automatically updates cache if `seq_len` exceeds cached length.
139
+
140
  Attributes:
141
  inv_freq (torch.Tensor): Inverse frequency tensor for position encoding.
142
  cos_cached (torch.Tensor): Cached cosine embeddings.
 
172
  self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
173
 
174
  return (
175
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
176
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
177
  )
178
 
179
 
 
190
  config: Optional[MotifConfig] = None,
191
  ):
192
  super().__init__()
193
+ # TODO (joao): remove the `if` below, only used for BC
194
  self.rope_kwargs = {}
195
  if config is None:
196
  logger.warning_once(
 
235
  device,
236
  seq_len=seq_len,
237
  **self.rope_kwargs)
238
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
239
  self.max_seq_len_cached = seq_len
240
 
241
  if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
 
269
  def rotate_half(x):
270
  """
271
  Rotates half of the dimensions of the input tensor using torch.roll and in-place negation.
272
+
273
  Args:
274
  x (torch.Tensor): The input tensor.
275
+
276
  Returns:
277
  torch.Tensor: A tensor where the latter half of the dimensions are negated
278
  and moved before the first half.
 
284
  return rotated_tensor
285
 
286
 
287
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1, fused_rope=False):
288
  """
289
  Applies rotary position embeddings to the input tensors.
290
+
291
  Args:
292
  q (torch.Tensor): Query tensor of shape (B, NH, S, D_KV).
293
  k (torch.Tensor): Key tensor of shape (B, NH, S, D_KV).
294
  cos (torch.Tensor): Cosine values for rotary embedding.
295
  sin (torch.Tensor): Sine values for rotary embedding.
296
+ unsqueeze_dim (int, optional): Dimension along which `cos` and `sin` are unsqueezed.
297
  Defaults to 1.
298
+ fused_rope (bool, optional): If True, applies fused rotary embeddings using
299
+ `moreh_ops.apply_rotary_emb`. If False, computes rotary embeddings manually.
300
+ Defaults to False.
301
+
302
  Returns:
303
  Tuple[torch.Tensor, torch.Tensor]: Returns transformed query and key tensors after applying rotary embeddings.
304
  """
305
+ '''
306
+ # (B, NH, S, D_KV) -> (B, S, NH, D_KV)
307
+ cos = cos.unsqueeze(unsqueeze_dim)
308
+ sin = sin.unsqueeze(unsqueeze_dim)
309
+ q_embed = (q * cos) + (rotate_half(q) * sin)
310
+ k_embed = (k * cos) + (rotate_half(k) * sin)
311
+ '''
312
+ if not fused_rope:
313
+ device = q.device
314
+ return map(
315
+ lambda x: (x * cos[position_ids].unsqueeze(unsqueeze_dim).to(device)) +
316
+ (rotate_half(x) * sin[position_ids].unsqueeze(unsqueeze_dim).to(device)), (q, k))
317
+ else:
318
+ # (B, NH, S, D_KV) -> (B, S, NH, D_KV)
319
+ cos = cos[position_ids]
320
+ sin = sin[position_ids]
321
+
322
+ q = q.transpose(1, 2)
323
+ k = k.transpose(1, 2)
324
+
325
+ # Expand 'batch' dim
326
+ cos = cos.expand(q.shape[0], *cos.shape[1:])
327
+ sin = sin.expand(q.shape[0], *sin.shape[1:])
328
+
329
+ q_embed = moreh_ops.apply_rotary_emb(q, cos, sin, opcode=1)
330
+ k_embed = moreh_ops.apply_rotary_emb(k, cos, sin, opcode=1)
331
+
332
+ # (B, S, NH, D_KV) -> (B, NH, S, D_KV)
333
+ q_embed = q_embed.transpose(1, 2)
334
+ k_embed = k_embed.transpose(1, 2)
335
+
336
+ return q_embed, k_embed
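# A usage sketch for the non-fused path above. The shapes of `cos`/`sin` are assumptions
# inferred from the indexing `cos[position_ids]`: the cached tables are (max_seq, head_dim)
# and `position_ids` is (batch_or_1, seq). All names below are illustrative only.
bsz_demo, n_heads_demo, seq_demo, head_dim_demo = 2, 4, 16, 8
q_demo = torch.randn(bsz_demo, n_heads_demo, seq_demo, head_dim_demo)
k_demo = torch.randn(bsz_demo, n_heads_demo, seq_demo, head_dim_demo)
inv_freq_demo = 1.0 / (10000 ** (torch.arange(0, head_dim_demo, 2).float() / head_dim_demo))
freqs_demo = torch.outer(torch.arange(seq_demo).float(), inv_freq_demo)
emb_demo = torch.cat((freqs_demo, freqs_demo), dim=-1)   # (seq, head_dim)
cos_demo, sin_demo = emb_demo.cos(), emb_demo.sin()
position_ids_demo = torch.arange(seq_demo).unsqueeze(0)  # (1, seq)
q_rot, k_rot = apply_rotary_pos_emb(q_demo, k_demo, cos_demo, sin_demo,
                                    position_ids=position_ids_demo, fused_rope=False)
# q_rot and k_rot keep the (batch, num_heads, seq, head_dim) layout of the inputs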
337
 
338
 
339
  class MotifMLP(nn.Module):
340
+
341
  def __init__(self, config):
342
  super().__init__()
343
  self.hidden_size = config.hidden_size
 
347
  self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
348
  self.act_fn = ACT2FN[config.hidden_act]
349
 
350
+ if config.wesar_weights:
351
+ self.gate_up_proj_alpha = nn.Parameter(torch.tensor(1) * config.gate_up_proj_alpha)
352
+ self.down_proj_alpha = nn.Parameter(torch.tensor(1) * config.down_proj_alpha)
353
+ else:
354
+ self.gate_up_proj_alpha = 1
355
+ self.down_proj_alpha = 1
356
+ if config.muP:
357
+ self.down_proj.__do_scale_tager__ = True
358
+ self.gate_proj.__do_scale_tager_mu_dim_model__ = True
359
+ self.up_proj.__do_scale_tager_mu_dim_model__ = True
360
+ self.down_proj.__do_scale_tager_mu_ffn__ = True
361
+
362
+
363
  def forward(self, hidden_state):
364
+ hidden_state = hidden_state * self.gate_up_proj_alpha
365
+ #hidden_state = self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))*
366
+ return self.down_proj_alpha * self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
367
+
368
+
369
+ class MorehMoeFusedMLP(nn.Module):
370
+ def __init__(self,
371
+ ffn_dim,
372
+ hidden_dim,
373
+ hidden_act_moe,
374
+ num_experts,
375
+ num_groups=1,
376
+ device=None,
377
+ continual_training=False):
378
+ super().__init__()
379
+ self.ffn_dim = ffn_dim
380
+ self.hidden_dim = hidden_dim
381
+ self.hidden_act_moe = hidden_act_moe
382
+
383
+ self.num_experts = num_experts
384
+ self.num_groups = num_groups
385
+
386
+ assert self.num_experts % self.num_groups == 0
387
+ self.num_experts_per_group = self.num_experts // self.num_groups
388
+
389
+ ## bsz, seq, group size, 2*ffn_size
390
+
391
+ moreh_ops = torch.ops.moreh
392
+ self.w13 = nn.ModuleList([
393
+ moreh_ops.MoeFanInLinear(self.hidden_dim,
394
+ self.ffn_dim * 2,
395
+ bias=False,
396
+ num_experts=self.num_experts_per_group,
397
+ device=device)
398
+ for _ in range(self.num_groups)
399
+ ])
400
+
401
+ self.w2 = nn.ModuleList([
402
+ moreh_ops.MoeFanOutLinear(self.ffn_dim,
403
+ self.hidden_dim,
404
+ bias=False,
405
+ num_experts=self.num_experts_per_group,
406
+ device=device)
407
+ for _ in range(self.num_groups)
408
+ ])
409
+
410
+ ## use silu?
411
+ self.act_fn = ACT2FN[self.hidden_act_moe]
412
+
413
+ if continual_training:
414
+ logger.info('two options: 1. zero-init all weights, 2. add scaling param to moe output.')
415
+ self._zero_init()
416
+
417
+ def _zero_init(self):
418
+ for module in self.w2:
419
+ for n,param in module.named_parameters():
420
+ logger.info(f'{n} {param.shape}')
421
+ param.data.zero_()
422
+
423
+
424
+ def forward(self, hidden_states, selected_experts, routing_weights):
425
+ w13_final_output = None
426
+ for group_idx in range(self.num_groups):
427
+ w13_output_in_group = self._get_w13_output(hidden_states,
428
+ selected_experts,
429
+ group_idx)
430
+ if w13_final_output is None:
431
+ w13_final_output = w13_output_in_group
432
+ else:
433
+ w13_final_output += w13_output_in_group
434
+
435
+ current_hidden_states = self.act_fn(
436
+ w13_final_output[:, :, :, :self.ffn_dim]
437
+ ) * w13_final_output[:, :, :, self.ffn_dim:]
438
+
439
+ final_hidden_states = None
440
+ for group_idx in range(self.num_groups):
441
+ w2_output_in_group = self._get_w2_output(current_hidden_states,
442
+ selected_experts,
443
+ routing_weights, group_idx)
444
+ if final_hidden_states is None:
445
+ final_hidden_states = w2_output_in_group
446
+ else:
447
+ final_hidden_states += w2_output_in_group
448
+ return final_hidden_states
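# Note on the fused layout used above (descriptive only): each grouped w13 projection
# emits 2 * ffn_dim features per selected expert; the first ffn_dim columns act as the
# gate (passed through act_fn) and the last ffn_dim as the up projection, mirroring
# gate_proj/up_proj in MotifMLP, before w2 projects the product back to hidden_dim and
# applies the routing weights.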
449
+
450
+ def _get_w13_output(self, hidden_states, selected_experts, group_idx):
451
+ selected_experts_in_group = selected_experts - (
452
+ group_idx * self.num_experts_per_group)
453
+
454
+ w13_output = self.w13[group_idx](hidden_states,
455
+ selected_experts_in_group)
456
+ return w13_output
457
+
458
+ def _get_w2_output(self, hidden_states, selected_experts, routing_weights,
459
+ group_idx):
460
+ selected_experts_in_group = selected_experts - (
461
+ group_idx * self.num_experts_per_group)
462
+ output = self.w2[group_idx](hidden_states, selected_experts_in_group,
463
+ routing_weights)
464
+ return output
465
+
466
+
467
+ class MoEGate(nn.Module):
468
+
469
+ def __init__(self, config):
470
+ super().__init__()
471
+ self.config = config
472
+ self.top_k = config.num_experts_per_tok
473
+ self.n_routed_experts = config.n_routed_experts
474
+ self.routed_scaling_factor = config.routed_scaling_factor
475
+ self.scoring_func = config.scoring_func
476
+ self.seq_aux = config.seq_aux
477
+ self.topk_method = config.topk_method
478
+ self.n_group = config.n_group
479
+ self.topk_group = config.topk_group
480
+
481
+ # topk selection algorithm
482
+ self.norm_topk_prob = config.norm_topk_prob
483
+ self.gating_dim = config.hidden_size
484
+ self.weight = nn.Parameter(
485
+ torch.empty((self.n_routed_experts, self.gating_dim)))
486
+ if self.topk_method == "noaux_tc":
487
+ self.e_score_correction_bias = nn.Parameter(
488
+ torch.empty((self.n_routed_experts)))
489
+ self.reset_parameters()
490
+
491
+ def reset_parameters(self) -> None:
492
+ import torch.nn.init as init
493
+
494
+ init.kaiming_uniform_(self.weight, a=math.sqrt(5))
495
+
496
+ def forward(self, hidden_states):
497
+ bsz, seq_len, h = hidden_states.shape
498
+ ### compute gating score
499
+ hidden_states = hidden_states.view(-1, h)
500
+ logits = F.linear(hidden_states.type(torch.float32),
501
+ self.weight.type(torch.float32), None)
502
+ if self.scoring_func == "sigmoid":
503
+ scores = logits.sigmoid()
504
+ else:
505
+ raise NotImplementedError(
506
+ f"insupportable scoring function for MoE gating: {self.scoring_func}"
507
+ )
508
+
509
+ ### select top-k experts
510
+ if self.topk_method == "greedy":
511
+ topk_weight, topk_idx = torch.topk(scores,
512
+ k=self.top_k,
513
+ dim=-1,
514
+ sorted=False)
515
+ elif self.topk_method == "group_limited_greedy":
516
+ group_scores = (scores.view(bsz * seq_len, self.n_group,
517
+ -1).max(dim=-1).values) # [n, n_group]
518
+ group_idx = torch.topk(group_scores,
519
+ k=self.topk_group,
520
+ dim=-1,
521
+ sorted=False)[1] # [n, top_k_group]
522
+ group_mask = torch.zeros_like(group_scores) # [n, n_group]
523
+ group_mask.scatter_(1, group_idx, 1) # [n, n_group]
524
+ score_mask = (group_mask.unsqueeze(-1).expand(
525
+ bsz * seq_len, self.n_group,
526
+ self.n_routed_experts // self.n_group).reshape(
527
+ bsz * seq_len, -1)) # [n, e]
528
+ tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
529
+ topk_weight, topk_idx = torch.topk(tmp_scores,
530
+ k=self.top_k,
531
+ dim=-1,
532
+ sorted=False)
533
+ elif self.topk_method == "noaux_tc":
534
+ ### will be used. ###
535
+ scores_for_choice = scores.view(
536
+ bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
537
+ group_scores = (scores_for_choice.view(
538
+ bsz * seq_len, self.n_group,
539
+ -1).topk(2, dim=-1)[0].sum(dim=-1)) # [n, n_group]
540
+ group_idx = torch.topk(group_scores,
541
+ k=self.topk_group,
542
+ dim=-1,
543
+ sorted=False)[1] # [n, top_k_group]
544
+ group_mask = torch.zeros_like(group_scores) # [n, n_group]
545
+ group_mask.scatter_(1, group_idx, 1) # [n, n_group]
546
+ score_mask = (group_mask.unsqueeze(-1).expand(
547
+ bsz * seq_len, self.n_group,
548
+ self.n_routed_experts // self.n_group).reshape(
549
+ bsz * seq_len, -1)) # [n, e]
550
+ tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(),
551
+ 0.0) # [n, e]
552
+ _, topk_idx = torch.topk(tmp_scores,
553
+ k=self.top_k,
554
+ dim=-1,
555
+ sorted=False)
556
+ topk_weight = scores.gather(1, topk_idx)
557
+ else:
558
+ raise NotImplementedError(
559
+ f"insupportable TopK function for MoE gating: {self.topk_method}"
560
+ )
561
+
562
+ ### norm gate to sum 1
563
+ if self.top_k > 1 and self.norm_topk_prob:
564
+ denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
565
+ topk_weight = topk_weight / denominator
566
+ topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor
567
+
568
+ return topk_idx, topk_weight
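# The core of the "greedy" routing branch above, reduced to plain tensor ops so the
# returned shapes are easy to see (a sketch with illustrative values; the module may also
# renormalize the weights and multiply by routed_scaling_factor before returning):
bsz_g, seq_g, hidden_g, n_experts_g, top_k_g = 2, 3, 8, 4, 2
hs_g = torch.randn(bsz_g, seq_g, hidden_g)
gate_w_g = torch.randn(n_experts_g, hidden_g)
scores_g = F.linear(hs_g.view(-1, hidden_g).float(), gate_w_g.float()).sigmoid()
topk_weight_g, topk_idx_g = torch.topk(scores_g, k=top_k_g, dim=-1, sorted=False)
# topk_idx_g and topk_weight_g are both (bsz * seq_len, top_k), matching the return above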
569
+
570
+
571
+ class MotifMoE(nn.Module):
572
+ """
573
+ A mixed expert module containing shared experts.
574
+ """
575
+ def __init__(self, config):
576
+ super().__init__()
577
+ self.config = config
578
+ self.num_experts_per_tok = config.num_experts_per_tok
579
+ self.use_moreh_moe = config.use_moreh_moe
580
+ self.use_fused_mlp = config.use_fused_mlp
581
+
582
+ if hasattr(config, "ep_size") and config.ep_size > 1:
583
+ assert config.ep_size == dist.get_world_size()
584
+ assert not config.use_moreh_moe
585
+ self.ep_size = config.ep_size
586
+ self.experts_per_rank = config.n_routed_experts // config.ep_size
587
+ self.ep_rank = dist.get_rank()
588
+ self.experts = nn.ModuleList([
589
+ (DeepseekV3MLP(config,
590
+ intermediate_size=config.moe_intermediate_size)
591
+ if i >= self.ep_rank * self.experts_per_rank and i <
592
+ (self.ep_rank + 1) * self.experts_per_rank else None)
593
+ for i in range(config.n_routed_experts)
594
+ ])
595
+ else:
596
+ self.ep_size = 1
597
+ self.experts_per_rank = config.n_routed_experts
598
+ self.ep_rank = 0
599
+ if self.use_moreh_moe:
600
+ if not self.use_fused_mlp:
601
+ self.experts = MorehMoeMLP(
602
+ ffn_dim=config.moe_intermediate_size,
603
+ hidden_dim=config.hidden_size,
604
+ hidden_act_moe=config.hidden_act_moe,
605
+ num_experts=config.n_routed_experts,
606
+ device=None)
607
+ else:
608
+ ## group expert.
609
+ self.experts = MorehMoeFusedMLP(
610
+ ffn_dim=config.moe_intermediate_size,
611
+ hidden_dim=config.hidden_size,
612
+ hidden_act_moe=config.hidden_act_moe,
613
+ num_experts=config.n_routed_experts,
614
+ num_groups=config.n_group,
615
+ device=None,
616
+ continual_training=config.continual_training,
617
+ )
618
+ else:
619
+ self.experts = nn.ModuleList([
620
+ DeepseekV3MLP(
621
+ config, intermediate_size=config.moe_intermediate_size)
622
+ for i in range(config.n_routed_experts)
623
+ ])
624
+
625
+ self.gate = MoEGate(config)
626
+
627
+ def forward(self, hidden_states):
628
+ identity = hidden_states
629
+ orig_shape = hidden_states.shape
630
+ topk_idx, topk_weight = self.gate(hidden_states)
631
+ if self.use_moreh_moe:
632
+ y = self.experts(hidden_states, topk_idx.view(*orig_shape[:-1], -1),
633
+ topk_weight.view(*orig_shape[:-1], -1))
634
+ y = y.type(hidden_states.dtype)
635
+ else:
636
+ hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
637
+ flat_topk_idx = topk_idx.view(-1)
638
+ if self.training:
639
+ hidden_states = hidden_states.repeat_interleave(
640
+ self.num_experts_per_tok, dim=0)
641
+ y = torch.empty_like(hidden_states)
642
+ for i, expert in enumerate(self.experts):
643
+ y[flat_topk_idx == i] = expert(
644
+ hidden_states[flat_topk_idx == i])
645
+ y = (y.view(*topk_weight.shape, -1) *
646
+ topk_weight.unsqueeze(-1)).sum(dim=1)
647
+ y = y.type(hidden_states.dtype)
648
+ y = y.view(*orig_shape)
649
+ # y = AddAuxiliaryLoss.apply(y, aux_loss)
650
+ else:
651
+ y = self.moe_infer(hidden_states, topk_idx,
652
+ topk_weight).view(*orig_shape)
653
+ return y, identity
654
+
655
+ @torch.no_grad()
656
+ def moe_infer(self, x, topk_ids, topk_weight):
657
+ cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
658
+ cnts.scatter_(1, topk_ids, 1)
659
+ tokens_per_expert = cnts.sum(dim=0)
660
+ idxs = topk_ids.view(-1).argsort()
661
+ sorted_tokens = x[idxs // topk_ids.shape[1]]
662
+ sorted_tokens_shape = sorted_tokens.shape
663
+ if self.ep_size > 1:
664
+ tokens_per_ep_rank = tokens_per_expert.view(self.ep_size,
665
+ -1).sum(dim=1)
666
+ tokens_per_expert_group = tokens_per_expert.new_empty(
667
+ tokens_per_expert.shape[0])
668
+ dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
669
+ output_splits = (tokens_per_expert_group.view(
670
+ self.ep_size, -1).sum(1).cpu().numpy().tolist())
671
+ gathered_tokens = sorted_tokens.new_empty(
672
+ tokens_per_expert_group.sum(dim=0).cpu().item(),
673
+ sorted_tokens.shape[1])
674
+ input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
675
+ dist.all_to_all(
676
+ list(gathered_tokens.split(output_splits)),
677
+ list(sorted_tokens.split(input_split_sizes)),
678
+ )
679
+ tokens_per_expert_post_gather = tokens_per_expert_group.view(
680
+ self.ep_size, self.experts_per_rank).sum(dim=0)
681
+ gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],),
682
+ dtype=np.int32)
683
+ s = 0
684
+ for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
685
+ gatherd_idxs[s:s + k] = i % self.experts_per_rank
686
+ s += k
687
+ gatherd_idxs = gatherd_idxs.argsort()
688
+ sorted_tokens = gathered_tokens[gatherd_idxs]
689
+ tokens_per_expert = tokens_per_expert_post_gather
690
+ tokens_per_expert = tokens_per_expert.cpu().numpy()
691
+
692
+ outputs = []
693
+ start_idx = 0
694
+ for i, num_tokens in enumerate(tokens_per_expert):
695
+ end_idx = start_idx + num_tokens
696
+ if num_tokens == 0:
697
+ continue
698
+ expert = self.experts[i + self.ep_rank * self.experts_per_rank]
699
+ tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
700
+ expert_out = expert(tokens_for_this_expert)
701
+ outputs.append(expert_out)
702
+ start_idx = end_idx
703
+
704
+ outs = torch.cat(outputs,
705
+ dim=0) if len(outputs) else sorted_tokens.new_empty(0)
706
+ if self.ep_size > 1:
707
+ new_x = torch.empty_like(outs)
708
+ new_x[gatherd_idxs] = outs
709
+ gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
710
+ dist.all_to_all(
711
+ list(gathered_tokens.split(input_split_sizes)),
712
+ list(new_x.split(output_splits)),
713
+ )
714
+ outs = gathered_tokens
715
+
716
+ new_x = torch.empty_like(outs)
717
+ new_x[idxs] = outs
718
+ final_out = (new_x.view(
719
+ *topk_ids.shape, -1).type(topk_weight.dtype).mul_(
720
+ topk_weight.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype))
721
+ return final_out
722
 
723
 
724
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
725
+
726
+
727
+ """
728
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
729
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
730
+
731
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
732
+ if n_rep == 1:
733
+ return hidden_states
734
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
735
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
736
+ """
737
 
738
+ return torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)
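# Sanity check for the equivalence quoted in the docstring above: repeating each KV head
# n_rep times along dim 1 matches the expand/reshape variant (illustrative shapes only).
kv_demo = torch.randn(2, 4, 5, 8)          # (batch, num_key_value_heads, seqlen, head_dim)
n_rep_demo = 3
a_demo = torch.repeat_interleave(kv_demo, dim=1, repeats=n_rep_demo)
b_demo = kv_demo[:, :, None, :, :].expand(2, 4, n_rep_demo, 5, 8).reshape(2, 4 * n_rep_demo, 5, 8)
assert torch.equal(a_demo, b_demo)         # both: (batch, num_kv_heads * n_rep, seqlen, head_dim)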
739
+
740
 
 
741
  class MotifAttention(nn.Module):
742
  """
743
  Differential Attention (DiffAttention) module.
744
+
745
+ Implements the Differential Attention from
746
  "DIFFERENTIAL TRANSFORMER" (https://arxiv.org/pdf/2410.05258).
747
+
748
  Overview
749
  Standard transformers often over-allocate attention to irrelevant context.
750
+ DiffAttention addresses this by computing attention as the difference between
751
+ two separate softmax attention maps, effectively canceling noise and promoting
752
  sparse, structured attention patterns.
753
+
754
  Reference Implementation
755
  https://github.com/microsoft/unilm/tree/master/Diff-Transformer
756
+
757
  Args
758
+ The differential attention mechanism computes attention as the difference of two softmax attention scores, weighted by a learnable scalar λ.
759
  λ is re-parameterized as λ = exp(λ_q1 · λ_k1) − exp(λ_q2 · λ_k2) + λ_init.
760
  - lambda_q1, lambda_q2 (nn.Parameter): Learnable vectors used to compute the first and second components of λ for query transformations.
761
  - lambda_k1, lambda_k2 (nn.Parameter): Learnable vectors used to compute the first and second components of λ for key transformations.
762
  - lambda_init (float): A constant used for initializing λ, typically set as λ_init = 0.8 − 0.6 × exp(−0.3 × (layer_index − 1)).
763
+
764
  """
765
 
766
  def __init__(self, config: MotifConfig, layer_idx: Optional[int] = None):
 
840
  self.subln = MotifRMSNorm(2 * self.head_dim, eps=1e-5)
841
  self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * (layer_idx - 1))
842
 
843
+ self.rotary_emb = MotifRotaryEmbeddingWithCache(self.head_dim,
844
  max_position_embeddings=self.max_position_embeddings,
845
  base=self.rope_theta)
846
 
 
886
  cos, sin = (self.rotary_emb(value_states, q_len + past_key_value.get_usable_length(q_len, self.layer_idx))
887
  if use_cache else position_embeddings)
888
 
889
+ query_states, key_states = apply_rotary_pos_emb(query_states,
890
+ key_states,
891
+ cos,
892
+ sin,
893
+ position_ids=position_ids,
894
+ fused_rope=self.config.fused_rope)
895
 
896
  if past_key_value is not None:
897
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
 
960
  return attn_output, attn_weights, past_key_value
961
 
962
 
 
963
  class MotifFlashAttention2(MotifAttention):
964
  """
965
  Motif flash attention module, following Motif attention module. This module inherits from `MotifAttention`
 
973
  def __init__(self, *args, **kwargs):
974
  super().__init__(*args, **kwargs)
975
 
976
+
977
 
978
  # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
979
  # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
 
981
 
982
  self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
983
 
984
+ logger.info(f'flash attention is used {not self._flash_attn_uses_top_left_mask}')
985
+
986
  def _reshape_heads(self, tensor, batch_size, seq_len):
987
  """2-way head split tensor reshape"""
988
  return tensor.reshape(batch_size, seq_len, self.num_heads, 2, self.head_dim)
 
992
  return tensor.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
993
 
994
  def _compute_attention(self, query_states, key_states, value_states, attention_mask, q_len, position_ids,
995
+ dropout_rate, sliding_window, is_moreh_attention, batch_num):
996
  """Flash Attention 2 implements"""
997
+
998
  scale_factor = 1.0 / math.sqrt(self.head_dim)
999
+ # Copied from _flash_attention_forward
1000
  if not self._flash_attn_uses_top_left_mask:
1001
  causal = self.is_causal
1002
  else:
1003
  causal = self.is_causal and q_len != 1
1004
+
1005
+ if is_moreh_attention:
1006
+ bsz = query_states.shape[0]
1007
 
1008
+ if batch_num:
1009
+ query_states = query_states.reshape(bsz*q_len,self.num_heads,self.head_dim)
1010
+ key_states = key_states.reshape(bsz*q_len,self.num_heads,self.head_dim)
1011
+ value_states = value_states.reshape(bsz*q_len,self.num_heads,self.head_dim)
1012
 
1013
+ attn_out = moreh_ops.flash_attention_varlen_dp(query_states,
1014
+ key_states,
1015
+ value_states,
1016
+ attention_mask,
1017
+ attention_mask,
1018
+ max_seqlen_q=q_len,
1019
+ max_seqlen_kv=q_len,
1020
+ dropout_p=dropout_rate,
1021
+ softmax_scale=scale_factor,
1022
+ is_causal=causal,
1023
+ batch_num=batch_num)
1024
+ attn_out = attn_out.reshape(bsz, q_len, self.num_heads, -1)
1025
+ else:
1026
+ return MorehFlashAttention(query_states,
1027
+ key_states,
1028
+ value_states,
1029
+ padding_mask=attention_mask,
1030
+ dropout_p=dropout_rate,
1031
+ softmax_scale=scale_factor,
1032
+ causal=causal)
1033
+ return attn_out
1034
+ else:
1035
+ attn_out = _flash_attention_forward(query_states,
1036
+ key_states,
1037
+ value_states,
1038
+ attention_mask,
1039
+ q_len,
1040
+ position_ids=position_ids,
1041
+ dropout=dropout_rate,
1042
+ sliding_window=sliding_window,
1043
+ is_causal=True,
1044
+ softmax_scale=scale_factor,
1045
+ use_top_left_mask=self._flash_attn_uses_top_left_mask)
1046
+ #logger.info(attn_out)
1047
+ return attn_out
1048
 
1049
  def forward(
1050
  self,
 
1078
  cos, sin = (self.rotary_emb(value_states, q_len + past_key_value.get_usable_length(q_len, self.layer_idx))
1079
  if use_cache else position_embeddings)
1080
 
1081
+ query_states, key_states = apply_rotary_pos_emb(query_states,
1082
+ key_states,
1083
+ cos,
1084
+ sin,
1085
+ position_ids=position_ids,
1086
+ fused_rope=False)
1087
 
1088
  if past_key_value is not None:
1089
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
 
1094
  value_states = repeat_kv(value_states, self.num_key_value_groups)
1095
  dropout_rate = 0.0 if not self.training else self.attention_dropout
1096
 
1097
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
1098
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
1099
+ # cast them back in float16 just to be sure everything works as expected.
1100
+ input_dtype = query_states.dtype
1101
+ if input_dtype == torch.float32 and MorehFlashAttention is None:
1102
+ if torch.is_autocast_enabled():
1103
+ target_dtype = torch.get_autocast_gpu_dtype()
1104
+ # Handle the case where the model is quantized
1105
+ elif hasattr(self.config, "_pre_quantization_dtype"):
1106
+ target_dtype = self.config._pre_quantization_dtype
1107
+ else:
1108
+ target_dtype = self.q_proj.weight.dtype
1109
+
1110
+ logger.warning_once(
1111
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
1112
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
1113
+ f" {target_dtype}.")
1114
+
1115
+ query_states = query_states.to(target_dtype)
1116
+ key_states = key_states.to(target_dtype)
1117
+ value_states = value_states.to(target_dtype)
1118
+
1119
  q_len = query_states.shape[-2]
1120
  kv_seq_len = key_states.shape[-2]
1121
 
 
1125
  value_states = value_states.transpose(1, 2)
1126
 
1127
  if (self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None
1128
+ and self.layer_idx >= self.config.max_window_layers and MorehFlashAttention is None):
1129
  sliding_window = self.config.sliding_window
1130
  else:
1131
  sliding_window = None
 
1145
  k1, k2 = k1.contiguous(), k2.contiguous()
1146
  v1, v2 = v1.contiguous(), v2.contiguous()
1147
 
1148
+ is_moreh_attention = MorehFlashAttention is not None
 
 
 
1149
 
1150
+ attn11, attn12 = self._compute_attention(q1, k1, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window, is_moreh_attention, self.batch_num), \
1151
+ self._compute_attention(q1, k1, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window, is_moreh_attention, self.batch_num)
1152
+ attn21, attn22 = self._compute_attention(q2, k2, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window, is_moreh_attention, self.batch_num), \
1153
+ self._compute_attention(q2, k2, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window, is_moreh_attention, self.batch_num)
1154
+
1155
+ attn1, attn2 = torch.cat([attn11, attn12], dim=-1), torch.cat([attn21, attn22], dim=-1)
1156
 
1157
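+ # Differential-attention combine: the queries/keys were split into two groups, each attended against
+ # both value halves and concatenated above. The lambda_* parameters below reweight the two maps,
+ # presumably following the Differential Transformer recipe (https://arxiv.org/abs/2410.05258), where
+ # the second map is subtracted with a learned lambda before the final (1 - lambda_init) scaling.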
  lambda_q1 = self.lambda_q1.unsqueeze(0).expand([bsz, self.lambda_q1.shape[0]]) # bsz, num_head
1158
  lambda_q2 = self.lambda_q2.unsqueeze(0).expand([bsz, self.lambda_q2.shape[0]]) # bsz, num_head
 
1168
  attn_output = attn_output * (1 - self.lambda_init)
1169
 
1170
  if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim * 2):
1171
+ raise ValueError(f"`attn_output` should be of size {(bsz, q_len, self.num_heads, 2*self.head_dim)}, but is"
1172
  f" {attn_output.size()}")
1173
 
1174
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
1175
  attn_output = self.o_proj(attn_output) * self.o_proj_alpha
1176
 
1177
+ return attn_output, None, past_key_value
1178
 
1179
 
 
1180
  class MotifSdpaAttention(MotifAttention):
1181
  """
1182
  Motif attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
 
1270
  MOTIF_ATTENTION_CLASSES = {
1271
  "eager": MotifAttention,
1272
  "flash_attention_2": MotifFlashAttention2,
1273
+ "sdpa": MotifAttention,
1274
  }
1275
 
1276
 
1277
  class MotifDecoderLayer(nn.Module):
1278
 
1279
+ def __init__(self, config: MotifConfig, moe_layer: bool, layer_idx: int):
1280
  super().__init__()
1281
  self.hidden_size = config.hidden_size
1282
+ if config.use_moreh_attention:
1283
+ config._attn_implementation = "flash_attention_2"
1284
  if config.sliding_window and config._attn_implementation != "flash_attention_2":
1285
  logger.warning_once(
1286
  f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
 
1290
  else:
1291
  self.self_attn = MOTIF_ATTENTION_CLASSES["eager"](config, layer_idx)
1292
  self.mlp = MotifMLP(config)
1293
+ # Optional Mixture-of-Experts (MoE) block
1294
+ self.moe = None
1295
+ if moe_layer:
1296
+ self.moe = MotifMoE(config)
1297
+
1298
+ RMSNorm = MorehRMSNorm if MorehRMSNorm is not None else MotifRMSNorm
1299
  self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1300
  self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1301
 
 
1364
  residual = hidden_states
1365
  hidden_states = self.post_attention_layernorm(hidden_states) * self.post_attention_layernorm_alpha
1366
 
1367
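+ # Shared-expert MoE layout (as used below): `self.moe` routes tokens to the small experts and also
+ # returns `identity`, the pre-MoE input, which goes through the dense `self.mlp` acting as the
+ # always-on shared expert.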
+ if self.moe is not None:
1368
+ hidden_states, identity = self.moe(hidden_states)
1369
+ ## Add the output of the shared expert (self.mlp) to the output of the small MoE experts.
1370
+ ## hidden_states must be a zero tensor on the first forward pass.
1371
+ hidden_states += self.mlp(identity)
1372
+ else:
1373
+ hidden_states = self.mlp(hidden_states)
1374
 
1375
  hidden_states = residual + hidden_states
1376
 
 
1389
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1390
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1391
  etc.)
1392
+
1393
  This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1394
  Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1395
  and behavior.
1396
+
1397
  Parameters:
1398
  config ([`MotifConfig`]):
1399
  Model configuration class with all the parameters of the model. Initializing with a config file does not
 
1443
  module_std = module_std / math.sqrt(self.config.dim_model_base_lmh)  # lm_head scaling
1444
  else:
1445
  module_std = module_std
1446
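+ # Approximate truncated-normal init: sample N(0, module_std) and zero out entries beyond
+ # 3 * module_std (unlike trunc_normal_, outliers are zeroed rather than resampled).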
+ module.weight.data.normal_(mean=0.0, std=module_std)
1447
+ module.weight.data = torch.where(abs(module.weight.data) > module_std*3, 0, module.weight.data)
1448
+ #torch.nn.init.trunc_normal_(module.weight.data, mean=0.0, std=module_std, a=-3*module_std, b=3*module_std)
1449
  if module.bias is not None:
1450
  module.bias.data.zero_()
1451
 
1452
  elif isinstance(module, nn.Embedding):
1453
+ module.weight.data.normal_(mean=0.0, std=module_std)
1454
+ module.weight.data = torch.where(abs(module.weight.data) > module_std*3, 0, module.weight.data)
1455
+ #torch.nn.init.trunc_normal_(module.weight.data, mean=0.0, std=module_std, a=-3*module_std, b=3*module_std)
1456
  if module.padding_idx is not None:
1457
  module.weight.data[module.padding_idx].zero_()
1458
 
1459
 
1460
  @dataclass
1461
  class MotifModelOutputWithPast(ModelOutput):
1462
+ """
1463
+ This augments `BaseModelOutputWithPast` in `transformers.modeling_outputs` with new optional keys: `causal_mask`, `position_embeddings`.
1464
  The optional keys are currently used in the following ways:
1465
+ - pass information to the token-wise last attention layers in multi-token training
1466
  """
1467
  last_hidden_state: torch.FloatTensor = None
1468
  past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
 
1477
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1478
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1479
  it.
1480
+
1481
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1482
  [`PreTrainedTokenizer.__call__`] for details.
1483
+
1484
  [What are input IDs?](../glossary#input-ids)
1485
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1486
  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1487
+
1488
  - 1 for tokens that are **not masked**,
1489
  - 0 for tokens that are **masked**.
1490
+
1491
  [What are attention masks?](../glossary#attention-mask)
1492
+
1493
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1494
  [`PreTrainedTokenizer.__call__`] for details.
1495
+
1496
  If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
1497
  `past_key_values`).
1498
+
1499
  If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1500
  and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1501
  information on the default strategy.
1502
+
1503
  - 1 indicates the head is **not masked**,
1504
  - 0 indicates the head is **masked**.
1505
  position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1506
  Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1507
  config.n_positions - 1]`.
1508
+
1509
  [What are position IDs?](../glossary#position-ids)
1510
  past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
1511
  Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1512
  blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
1513
  returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
1514
+
1515
  Two formats are allowed:
1516
  - a [`~cache_utils.Cache`] instance, see our
1517
  [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
1518
  - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1519
  shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1520
  cache format.
1521
+
1522
  The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1523
  legacy cache format will be returned.
1524
+
1525
  If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1526
  have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1527
  of shape `(batch_size, sequence_length)`.
 
1554
  class MotifModel(MotifPreTrainedModel):
1555
  """
1556
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MotifDecoderLayer`]
1557
+
1558
  Args:
1559
  config: MotifConfig
1560
  """
 
1566
  self.multi_token_heads = config.multi_token_heads
1567
 
1568
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1569
+ # NOTE: For multi-token models, the last decoder layers (one for each token index)
1570
+ # are implemented as a part of `MotifModelForCausalLM` to enable a custom forward-backward procedure.
1571
 
1572
  num_hidden_layers = config.num_hidden_layers if self.multi_token_heads is None else config.num_hidden_layers - 1
1573
+ if config.moe:
1574
+ moe_layer = [True for i in range(num_hidden_layers)]
1575
+ else:
1576
+ moe_layer = [False for i in range(num_hidden_layers)]
1577
+ logger.info(f'MoE layer flags: {moe_layer}')
1578
+ self.layers = nn.ModuleList([MotifDecoderLayer(config=config, moe_layer=moe_layer[layer_idx],
1579
+ layer_idx=layer_idx) for layer_idx in range(num_hidden_layers)])
1580
  self._attn_implementation = config._attn_implementation
1581
+ RMSNorm = MorehRMSNorm if MorehRMSNorm is not None else MotifRMSNorm
1582
  self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1583
  self.hidden_size = config.hidden_size
1584
  self.num_heads = config.num_attention_heads
 
1592
  self.gradient_checkpointing = False
1593
  self.post_init()
1594
 
1595
+ self.use_pipeline = config.use_pipeline
1596
+ if self.use_pipeline:
1597
+ logger.info('Using reinforced pipeline parallelism.')
1598
+ if config.num_stages == 2:
1599
+ ### MoE version
1600
+ if config.decontam_attn:
1601
+ self.split_layers = [15]
1602
+ else:
1603
+ if num_hidden_layers == 32:
1604
+ self.split_layers = [14] # 14: 15,17 # 13: 14:18
1605
+ else:
1606
+ self.split_layers = [6]
1607
+ elif config.num_stages == 3:
1608
+ self.split_layers = [9, 20] ## 10, 11, 11
1609
+ else:
1610
+ self.split_layers = [6, 15, 24] #7(0,7),9(6,15),9(15,24),7(24,31)
1611
+ logger.info(f'Pipeline split layers (MoE): {self.split_layers}')
1612
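+ # `split_layers` holds the decoder-layer indices at which hidden states are handed to the next
+ # pipeline stage (see the `torch.moreh.pipeline_assign` call in `forward`), splitting the stack
+ # into `num_stages` roughly balanced chunks.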
+
1613
+ self.scale_emb = 1
1614
+
1615
+ # Reparameterization <|_1_|>
1616
+ if config.wesar_weights:
1617
+ logger.info(f'config.wesar_weights {config.wesar_weights}')
1618
+ self.norm_alpha = nn.Parameter(torch.tensor(1).float())
1619
+ self.scale_emb = 10
1620
+ else:
1621
+ self.norm_alpha = 1
1622
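+ # With `wesar_weights`, the residual stream is reparameterized: embeddings are scaled by
+ # `scale_emb` (10) and the final norm output by the learnable `norm_alpha`; MotifForCausalLM
+ # adds a matching `lm_head_alpha` on the lm head input.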
+
1623
  def get_input_embeddings(self):
1624
  return self.embed_tokens
1625
 
 
1658
  "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
1659
  use_cache = False
1660
 
1661
+ # kept for BC (non `Cache` `past_key_values` inputs)
1662
  return_legacy_cache = False
1663
  if use_cache and not isinstance(past_key_values, Cache):
1664
  return_legacy_cache = True
 
1672
  "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)")
1673
 
1674
  if inputs_embeds is None:
1675
+ inputs_embeds = self.embed_tokens(input_ids) * self.scale_emb
1676
 
1677
  if cache_position is None:
1678
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1679
  cache_position = torch.arange(past_seen_tokens,
1680
  past_seen_tokens + inputs_embeds.shape[1],
1681
  device=inputs_embeds.device)
1682
  if position_ids is None:
1684
  position_ids = cache_position.unsqueeze(0)
1685
+
1686
  causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values,
1687
  output_attentions)
1688
 
 
1725
  )
1726
 
1727
  hidden_states = layer_outputs[0]
1728
+
1729
+
1730
+ if self.use_pipeline and idx in self.split_layers:
1731
+ hidden_states = torch.moreh.pipeline_assign(hidden_states)
1732
 
1733
  if use_cache:
1734
  next_decoder_cache = layer_outputs[2 if output_attentions else 1]
 
1736
  if output_attentions:
1737
  all_self_attns += (layer_outputs[1], )
1738
 
1739
+ # <|_2_|>
1740
+ hidden_states = self.norm(hidden_states)* self.norm_alpha
1741
+
1742
  # add hidden states from the last decoder layer
1743
  if output_hidden_states:
1744
  all_hidden_states += (hidden_states, )
 
1770
  output_attentions: bool,
1771
  ):
1772
  if self.config._attn_implementation == "flash_attention_2":
1773
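+ # Moreh flash attention consumes the 2D padding mask directly, so it is returned unchanged
+ # instead of being dropped for the fully-unmasked case below.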
+ if MorehFlashAttention is not None:
1774
+ return attention_mask
1775
  if attention_mask is not None and 0.0 in attention_mask:
1776
  return attention_mask
1777
  return None
 
1843
  """
1844
  Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1845
  `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
1846
+
1847
  Args:
1848
  attention_mask (`torch.Tensor`):
1849
  A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
 
1901
  self.vocab_size = config.vocab_size
1902
  self.multi_token_heads = config.multi_token_heads
1903
 
1904
+ if self.multi_token_heads is None:
1905
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1906
+ else:
1907
+ self.tokenwise_last_layers = nn.ModuleList(
1908
+ # moe_layer=False: the token-wise last layers are assumed to be dense (no MoE).
+ [MotifDecoderLayer(config, moe_layer=False, layer_idx=config.num_hidden_layers - 1) for _ in range(self.multi_token_heads)])
1909
+ self.tokenwise_lm_heads = nn.ModuleList(
1910
+ [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(self.multi_token_heads)])
1911
+ self.should_skip_separate_backward_pass = self.multi_token_heads is not None
1912
 
1913
  # Initialize weights and apply final processing
1914
  self.post_init()
1915
+
1916
+ # <|_3_|>
1917
+ if config.muP:
1918
+ self.lm_head.__do_scale_tager_mu_dim_base_model__ = True
1919
+
1920
+ # <|_4_|>
1921
+ self.lm_head_alpha = 1
1922
+ if config.wesar_weights:
1923
+ self.lm_head_alpha = nn.Parameter(torch.tensor(1).float())
1924
+
1925
  if getattr(config, "tie_word_embeddings", True):
1926
  logger.info('tie embeddings')
1927
  self.tie_weights()
1928
+ else:
1929
+ # <|_5_|>
1930
+ self.lm_head.__do_scale_tager_mu_dim_base_model__ = False
1931
 
1932
  def get_input_embeddings(self):
1933
  return self.model.embed_tokens
 
1947
  def get_decoder(self):
1948
  return self.model
1949
 
1950
+ def multi_token_forward_backward(self,
1951
+ hidden_states: torch.FloatTensor,
1952
+ outputs: MotifModelOutputWithPast,
1953
+ labels: torch.LongTensor,
1954
+ position_ids: Optional[torch.LongTensor],
1955
+ output_attentions: Optional[bool],
1956
+ use_cache: Optional[bool],
1957
+ cache_position: Optional[torch.LongTensor],
1958
+ return_dict: Optional[bool],
1959
+ num_logits_to_keep: int = 0) -> CausalLMOutputWithPast:
1960
+ """
1961
+ This implements the main forward-backward procedure for multi-token model training proposed in
1962
+ the paper https://arxiv.org/abs/2404.19737.
1963
+ Essentially,
1964
+ - The multi-token model tries to predict n (instead of 1) tokens at a time.
1965
+ - Applying this only during training and using first-token prediction during inference is still helpful.
1966
+ - The change in architecture: when using n-token prediction, each token index (between 1 and n) has its own
1967
+ (1) last attention layer and (2) lm head.
1968
+ - The change in loss: sum of cross-entropy losses corresponding to each token index.
1969
+ - Custom forward-backward procedure for memory efficiency: refer to the implementation of `multi_head_forward_backward`.
1970
+ """
1971
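+ # Objective sketch: head i (0-indexed) has its own last decoder layer and lm head, predicts the
+ # token (i + 1) positions ahead, and contributes CE(logits_i[:, :-(i + 1)], labels[:, (i + 1):]);
+ # `multi_head_forward_backward` combines the per-head losses (and, for memory efficiency, handles
+ # their backward passes) as described in the docstring above.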
+ if not return_dict:
1972
+ raise NotImplementedError("return_dict must be True for multi-token training")
1973
+
1974
+ past_key_values = outputs.past_key_values
1975
+ causal_mask = outputs.causal_mask
1976
+ position_embeddings = outputs.position_embeddings
1977
+
1978
+ if labels is not None:
1979
+ labels = labels.to(hidden_states.device)
1980
+
1981
+ def _tokenwise_forward(hidden_states: torch.Tensor, token_idx):
1982
+ ## Model forward
1983
+ layer = self.tokenwise_last_layers[token_idx]
1984
+ lm_head = self.tokenwise_lm_heads[token_idx]
1985
+
1986
+ layer_outputs = layer(
1987
+ hidden_states,
1988
+ attention_mask=causal_mask,
1989
+ position_ids=position_ids,
1990
+ past_key_values=past_key_values, # TODO: update past_key_values?
1991
+ output_attentions=output_attentions,
1992
+ use_cache=use_cache,
1993
+ cache_position=cache_position,
1994
+ position_embeddings=position_embeddings,
1995
+ )
1996
+ last_hidden_states = layer_outputs[0]
1997
+ if num_logits_to_keep > 0:
1998
+ assert labels is None
1999
+ last_hidden_states = last_hidden_states[:, -num_logits_to_keep:, :]
2000
+ tokenwise_logits = lm_head(last_hidden_states)
2001
+
2002
+ if labels is None:
2003
+ return {
2004
+ "loss": None,
2005
+ "logits": tokenwise_logits,
2006
+ }
2007
+
2008
+ ## Compute loss
2009
+ shift_n = token_idx + 1
2010
+ shift_logits = tokenwise_logits[..., :-shift_n, :].contiguous()
2011
+ shift_labels = labels[..., shift_n:].contiguous()
2012
+
2013
+ loss_fct = CrossEntropyLoss()
2014
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
2015
+ shift_labels = shift_labels.view(-1)
2016
+
2017
+ tokenwise_loss = loss_fct(shift_logits, shift_labels)
2018
+
2019
+ return {
2020
+ "loss": tokenwise_loss,
2021
+ "logits": tokenwise_logits,
2022
+ }
2023
+
2024
+ head_fns = [
2025
+ lambda hidden_states, token_idx=token_idx: _tokenwise_forward(hidden_states, token_idx)
2026
+ for token_idx in range(self.multi_token_heads)
2027
+ ]
2028
+ loss, logits = multi_head_forward_backward(hidden_states,
2029
+ head_fns,
2030
+ return_keys=("loss", "logits"),
2031
+ return_only_first_head=True)
2032
+
2033
+ if not return_dict:
2034
+ output = (logits, ) + outputs[1:]
2035
+ return (loss, ) + output
2036
+
2037
+ return CausalLMOutputWithPast(
2038
+ loss=loss,
2039
+ logits=logits,
2040
+ past_key_values=outputs.past_key_values,
2041
+ hidden_states=outputs.hidden_states,
2042
+ attentions=outputs.attentions,
2043
+ )
2044
+
2045
  @add_start_docstrings_to_model_forward(MOTIF_INPUTS_DOCSTRING)
2046
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
2047
  def forward(
 
2066
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
2067
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
2068
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
2069
+
2070
  num_logits_to_keep (`int`, *optional*):
2071
  Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
2072
  `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
2073
  token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
2074
+
2075
  Returns:
2076
+
2077
  Example:
2078
+
2079
  ```python
2080
  >>> from transformers import AutoTokenizer, MotifForCausalLM
2081
+
2082
+ >>> model = MotifForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
2083
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
2084
+
2085
  >>> prompt = "Hey, are you conscious? Can you talk to me?"
2086
  >>> inputs = tokenizer(prompt, return_tensors="pt")
2087
+
2088
  >>> # Generate
2089
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
2090
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
 
2097
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2098
 
2099
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
2100
+ outputs_include_causal_mask = self.multi_token_heads is not None
2101
+ outputs_include_position_embeddings = self.multi_token_heads is not None
2102
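+ # For multi-token training, the token-wise last layers reuse the base model's causal mask and
+ # position embeddings, so ask the base model to return them.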
  outputs: MotifModelOutputWithPast = self.model(
2103
  input_ids=input_ids,
2104
  attention_mask=attention_mask,
 
2110
  output_hidden_states=output_hidden_states,
2111
  return_dict=return_dict,
2112
  cache_position=cache_position,
2113
+ outputs_include_causal_mask=outputs_include_causal_mask,
2114
+ outputs_include_position_embeddings=outputs_include_position_embeddings,
2115
  )
2116
 
2117
  hidden_states = outputs[0]
2118
 
2119
+ if self.multi_token_heads is not None:
2120
+ return self.multi_token_forward_backward(hidden_states,
2121
+ outputs,
2122
+ labels,
2123
+ position_ids,
2124
+ output_attentions,
2125
+ use_cache,
2126
+ cache_position,
2127
+ return_dict,
2128
+ num_logits_to_keep=num_logits_to_keep)
2129
+
2130
  # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
2131
+ hidden_states = hidden_states * self.lm_head_alpha
2132
  logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
2133
  logits = logits.float()
2134
 
2135
  loss = None
2136
  if labels is not None:
2137
  # Shift so that tokens < n predict n
2139
  shift_logits = logits[..., :-1, :].contiguous()
2140
  shift_labels = labels[..., 1:].contiguous()
 
2156
  past_key_values=outputs.past_key_values,
2157
  hidden_states=outputs.hidden_states,
2158
  attentions=outputs.attentions,
2159
+ )