leejunhyeok and eunhwanpark-motiftech committed on
Commit 0a87b16 · 1 Parent(s): cc0f6ea

Update modeling_motif.py (#1)


- Update modeling_motif.py (f362538a86e469b8b3e07d66dd1d2f23aa4ef81b)
- bugfix (91c40cee07724f4fa12fec66cafa1a2546fd2dbd)
- Update modeling_motif.py (a55dcfd62edeef5966ab77bd04f3622b24cb4175)
- Update modeling_motif.py (2a76ec89b014032a79e5b67a9af06cd960efb51b)
- Update config.json (c28133974bedd8beb8e4d50db3c466e58798d139)
- Update modeling_motif.py (6230f7725953151d5c33872ba28180d75e9ecc7c)
- Update modeling_motif.py (db404ffc2c810b0325e72b181b96bb67535ec217)
- Update config.json (0d851ca4ecccf20f0e1368c2ff83c6e802d23e6d)
- Update modeling_motif.py (607612fcbf1067b65f29c92a4326ea261a97e0e1)
- Update modeling_motif.py (9b405398aa85b7621786d0793cd05f1343e661d9)
- Update modeling_motif.py (bdd0329e7e0cf8c3e2095f665244351a7e1e4c55)
- Update modeling_motif.py (38eae03f20a4d46b13ef774fab49a4043b2a5697)
- Update modeling_motif.py (8855d030415e3f65179c8aad8b3733a6d0ad644d)
- Update modeling_motif.py (20b97f12bd6cfbd1115d7c8d2e3ba19a27d55471)
- Update modeling_motif.py (72cc86d5a3b4d20faf06536fa16a841811c7b475)
- Update modeling_motif.py (80e1a1c498842e2d58fd8e7abf320753eb758ddc)
- Update config.json (9be8a4a64fa7d9cdcb7ec479918fa03034efde8a)
- Update modeling_motif.py (6d0fba540a217fb66acc35aca103c1981aa79ef0)
- Update modeling_motif.py (bd7180cf47aa477313e74d5b739105f2700bdddf)
- Update modeling_motif.py (5ac76cbcfa8efa114a15d46288178215c919dd7f)
- Update modeling_motif.py (5bd1ac1f84c476bf5b1a4fe7535db793809906de)
- Update modeling_motif.py (a4c14b0328abf4e3ae6b0213d777f75e0f3432e7)
- Update config.json (22c1361731f8565e0776c7aee2a0aa460657aa09)
- Update modeling_motif.py (7ebd625f70f9cce13df77f0520364d6f3b26994d)
- Update modeling_motif.py (d48c60537b21d0c3757b2292b47f6b1055e7ef1b)
- Update config.json (756b7c09f36b5a5bb689e8efbd894fd20c5699e5)
- Update config.json (8f9e2731310565ac33e8f6960b19c56be99698d2)
- Update config.json (4321f887f6e1fa1161e114d91ccacb2b8361e750)
- Update config.json (da650a07d5165dde18b47a5b07f5f8de47365f79)
- Update config.json (76d2431ef87346a08c6187122cd60964c3d035db)
- Update modeling_motif.py (cce6844530ccb252625158c7e4bfd5d603f0e475)
- Update modeling_motif.py (c4b7f5e79b10d819c7ced148c4aa2340bb4b82c8)
- Update config.json (498f8bc0c383405c33be3a06a1182ada10ff212d)
- Update modeling_motif.py (f76fc654662109d4549c8ed558fc641137e8a91d)
- Update config.json (8dbdeff6683e4e6555c9c4cf703eee8473de8059)
- Update config.json (1d63261c89ce9d31071d93ac43201d94937fe089)
- Update modeling_motif.py (cfa11bc205bf18888fe7f8dff1ecdce292a615a8)
- Update config.json (14c71d590d1b610bca1480c9066fdd9e18cdfdcb)
- Update config.json (a95fee7e4e0105080cf758bec90d2f6a16205702)
- Update modeling_motif.py (83d1f84e9a8c3bde8109b2a66e38d06f04972b39)
- Update config.json (37ed6924ae8c26c9fe130435bb1c23399dc9a51d)
- Update config.json (eb7a67bdbdc31bcc931a7eacc44f4a4a07751b17)
- Update modeling_motif.py (27913c020c6192fcde29c9e18165a9db8280bc8d)
- Update modeling_motif.py (ba7c5761e3eba518b64edf32d46422fec2175554)
- Update modeling_motif.py (7ec81c6e7fea3393085d6d9d94017bd4d8cbbfd6)
- Update configuration_motif.py (7d2479ad9374145522214574abfda2d012c6d5a4)
- Update modeling_motif.py (d56ef752fd1db815e87b88e6bf4791f558d6363b)
- Update modeling_motif.py (6187f7768db6406e4ba951b99281eb2274c92130)
- Update modeling_motif.py (a162402203d84f6c89c8159070e1457f5613005c)
- Update modeling_motif.py (0ff3917f6694c08e6768a2940b4ba8f508b209b8)
- Update modeling_motif.py (4293a010c114394bff10c64a621d22937df61f46)
- Update modeling_motif.py (7d405b99a275c6241c34279f548dc9a7680c6906)
- Update modeling_motif.py (03a80ebc5d1df4a3f81667d6a2d54608957cd57e)
- Update config.json (c8cabde5806c7009379fb69d1adfe066d2dab70f)
- Update modeling_motif.py (8bdf2ec91a74f9cfda0557ad56999915cb7ac9d1)
- Update modeling_motif.py (097873e443d8e170c89992527e03615d15a71a2e)
- Update modeling_motif.py (95a3a69d05f9284a6241310ba17a92d8db544c02)
- Update README.md (a14cdf2cc082f3be9d297d14287bb9ae7d4d3f88)
- Update configuration_motif.py (7ed4264a3a35b2b61b29dc31c9833135f4b812cd)
- Update generation_config.json (8d10c7cefc977930e3dc46c50b67409bcc28e650)
- Update README.md (48ae3c671d073dfff611e9e1c209fb2ae5f78b43)
- Update README.md (96a8455fc6078a63ad4bb3f340a76ff1c21bc7d7)
- Update README.md (9a018e3364b5300685f90cb5ceebbeb63fb78799)
- Update configuration_motif.py (a77c948640d4559106e5c0e0b59fb8c58a9af0f6)


Co-authored-by: Eunhwan Park <[email protected]>

Files changed (5)
  1. README.md +37 -1
  2. config.json +5 -59
  3. configuration_motif.py +5 -89
  4. generation_config.json +1 -1
  5. modeling_motif.py +108 -1001
README.md CHANGED
@@ -195,4 +195,40 @@ The benchmarks and corresponding scores listed in the table below are taken dire
  |MBPP|0-shot|53.9|62.2|60.3|+11.87%|-3.05%|
  |MBPP+|0-shot|44.4|50.6|50.8|+14.41%|+0.40%|
  |MultiPL-E|0-shot|22.6|34.9|-|-|-|
- |||||**Average**|**+18.55%**|**+1.12%**|
+ |||||**Average**|**+18.55%**|**+1.12%**|
+
+
+ ## How to use
+ ```python
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ model = AutoModelForCausalLM.from_pretrained(
+     "Motif-Technologies/Motif-2.6B",
+     trust_remote_code = True,
+     _attn_implementation = "eager", # also supports flash_attention_2
+ ).cuda()
+
+ tokenizer = AutoTokenizer.from_pretrained(
+     "Motif-Technologies/Motif-2.6B",
+     trust_remote_code = True,
+ )
+
+ query = "What is the capital city of South Korea?"
+ input_ids = tokenizer.apply_chat_template(
+     [
+         {'role': 'system', 'content': 'you are a helpful assistant'},
+         {'role': 'user', 'content': query},
+     ],
+     add_generation_prompt = True,
+     return_tensors='pt',
+ ).cuda()
+
+ output = model.generate(input_ids, max_new_tokens=128, pad_token_id=tokenizer.eos_token_id)
+ output = tokenizer.decode(output[0, input_ids.shape[-1]:], skip_special_tokens = True)
+ print(output)
+
+ """
+ The capital city of South Korea is Seoul. Located in the southern part of the country, Seoul is not only the largest city in South Korea but also one of the largest metropolitan areas in the world.
+ It is a vibrant and dynamic city known for its rich history, cultural heritage, and modern amenities. Seoul is a major economic, cultural, and political center in East Asia, and it plays a crucial role in the region's politics, economy, and culture.
+ The city is divided into different administrative districts, each with its own unique characteristics and attractions.
+ """
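For illustration, a minimal variant of the loading call above that takes the FlashAttention-2 path mentioned in the comment. This is a sketch, assuming the `flash-attn` package and a compatible GPU are available:

```python
import torch
from transformers import AutoModelForCausalLM

# Same checkpoint as above, but with the flash_attention_2 implementation and
# bfloat16 weights (matching the torch_dtype declared in config.json below).
model = AutoModelForCausalLM.from_pretrained(
    "Motif-Technologies/Motif-2.6B",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    _attn_implementation="flash_attention_2",
).cuda()
```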
config.json CHANGED
@@ -8,82 +8,28 @@
  "AutoConfig": "configuration_motif.MotifConfig",
  "AutoModelForCausalLM": "modeling_motif.MotifForCausalLM"
  },
- "bfloat16": true,
  "bos_token_id": 219396,
- "continual_training": false,
- "decoder_split_layers": [],
- "decontam_attn": false,
- "dim_model_base": 2048,
- "dim_model_base_attn": 128,
- "dim_model_base_init": 2048,
- "dim_model_base_lmh": 1,
- "dim_model_base_logits": 2048,
- "dim_model_base_lr": 256,
- "down_proj_alpha": 0.15625,
- "embed_tokens_alpha": null,
- "encoder_split_layers": [],
  "eos_token_id": 219395,
- "first_expansion": false,
- "fused_rope": true,
- "gate_up_proj_alpha": 0.15625,
  "hidden_act": "poly_norm",
- "hidden_act_moe": null,
  "hidden_size": 2048,
- "hidden_states_shrink": 0.17677669529663687,
- "init_scale_o": 1,
  "initializer_range": 2e-05,
- "input_layernorm_alpha": null,
  "intermediate_size": 8192,
- "k_proj_alpha": 0.15625,
- "lm_head_alpha": null,
  "loss_reduction": "mean",
  "max_position_embeddings": 16384,
  "max_window_layers": 28,
- "mix_attn": false,
  "model_type": "Motif",
- "moe": false,
- "moe_intermediate_size": null,
- "moe_layer": false,
- "muP": false,
- "multi_token_heads": null,
- "n_group": null,
- "n_routed_experts": null,
- "norm_alpha": null,
- "norm_topk_prob": null,
  "num_attention_heads": 16,
  "num_hidden_layers": 32,
  "num_key_value_heads": 16,
- "num_stages": false,
- "o_proj_alpha": 0.15625,
- "post_attention_layernorm_alpha": null,
- "q_proj_alpha": 0.15625,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 500000.0,
- "routed_scaling_factor": null,
- "scale_emb": 1,
- "scoring_func": null,
- "seq_aux": null,
  "sliding_window": null,
- "tensor_parallel": true,
  "tie_word_embeddings": true,
- "topk_group": null,
- "topk_method": null,
- "torch_dtype": "float32",
- "transformers_version": "4.51.3",
- "use_advanced_parallelization": true,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.46.3",
  "use_bias": false,
- "use_cache": false,
- "use_emb_alpha": false,
- "use_fused_mlp": null,
- "use_moreh_attention": true,
- "use_moreh_moe": false,
- "use_mrope": false,
- "use_norm_alpha": false,
- "use_pipeline": false,
- "use_qk_norm": false,
+ "use_cache": true,
  "use_sliding_window": false,
- "v_proj_alpha": 0.15625,
- "vocab_size": 219520,
- "wesar_weights": false
- }
+ "vocab_size": 219520
+ }

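As a quick sanity check, the trimmed config can be loaded and inspected with `AutoConfig`; this is a sketch (not part of the repository), and the expected values in the comments simply mirror the keys shown in the diff above:

```python
from transformers import AutoConfig

# trust_remote_code is required because config.json maps AutoConfig to the
# custom configuration_motif.MotifConfig class.
config = AutoConfig.from_pretrained("Motif-Technologies/Motif-2.6B", trust_remote_code=True)

print(config.model_type)   # "Motif"
print(config.hidden_act)   # "poly_norm"
print(config.torch_dtype)  # torch.bfloat16
print(config.use_cache)    # True
```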
configuration_motif.py CHANGED
@@ -1,8 +1,9 @@
+ import math
+ from typing import Optional
+
  from transformers.configuration_utils import PretrainedConfig
  from transformers.modeling_rope_utils import rope_config_validation
  from transformers.utils import logging
- from typing import Optional
- import math

  logger = logging.get_logger(__name__)

@@ -13,11 +14,8 @@ class MotifConfig(PretrainedConfig):
  Motif model according to the specified arguments, defining the model architecture. Instantiating a configuration
  with the defaults will yield a similar configuration to that of
  Motif-102B [moreh/Motif-102B](https://huggingface.co/moreh/Motif-102B).
-
  Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  documentation from [`PretrainedConfig`] for more information.
-
-
  Args:
  vocab_size (`int`, *optional*, defaults to 151936):
  Vocabulary size of the Motif model. Defines the number of different tokens that can be represented by the
@@ -97,16 +95,12 @@ class MotifConfig(PretrainedConfig):
  The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
  attention_dropout (`float`, *optional*, defaults to 0.0):
  The dropout ratio for the attention probabilities.
-
  ```python
  >>> from transformers import MotifModel, MotifConfig
-
  >>> # Initializing a Motif style configuration
  >>> configuration = MotifConfig()
-
  >>> # Initializing a model from the Motif-102B style configuration
  >>> model = MotifModel(configuration)
-
  >>> # Accessing the model configuration
  >>> configuration = model.config
  ```"""
@@ -134,13 +128,8 @@ class MotifConfig(PretrainedConfig):
  sliding_window=4096,
  max_window_layers=28,
  attention_dropout=0.0,
- multi_token_heads: Optional[int] = None,
  **kwargs,
  ):
- """
- Arguments:
- multi_token_heads: If not None, use multi-token heads as in the paper https://arxiv.org/pdf/2404.19737
- """

  self.vocab_size = vocab_size
  self.max_position_embeddings = max_position_embeddings
@@ -165,87 +154,14 @@ class MotifConfig(PretrainedConfig):
  self.rope_scaling = rope_scaling
  self.attention_dropout = attention_dropout

- ###kwargs
-
- # some scale factors
-
- self.scale_emb = getattr(kwargs, "scale_emb", 1)
- self.init_scale_o = getattr(kwargs, "init_scale_o", 1)
-
- # muparam
- self.hidden_states_shrink = 1 / math.sqrt(num_hidden_layers)
- self.dim_model_base = hidden_size
- self.dim_model_base_attn = (hidden_size // num_attention_heads)
- self.dim_model_base_init = hidden_size
- self.dim_model_base_lr = getattr(kwargs, "dim_model_base_lr", hidden_size//8)
- self.dim_model_base_lmh = 1
- self.dim_model_base_logits = hidden_size
-
- self.muP = getattr(kwargs, "muP", False)
- # proxy hidden size ( following YuLan-Mini )
- # reparameterization(wesar_weights)
- logger.info(kwargs)
- self.wesar_weights = getattr(kwargs, "wesar_weights", False)
- logger.info(f'initial wesar reparameterization : {self.wesar_weights}')
-
- # alpha (scale factor)
- self.embed_tokens_alpha = getattr(kwargs, "embed_tokens_alpha", None)
- self.q_proj_alpha = getattr(kwargs, "q_proj_alpha", None)
- self.k_proj_alpha = getattr(kwargs, "k_proj_alpha", None)
- self.v_proj_alpha = getattr(kwargs, "v_proj_alpha", None)
- self.o_proj_alpha = getattr(kwargs, "o_proj_alpha", None)
- self.down_proj_alpha = getattr(kwargs, "down_proj_alpha", None)
- self.gate_up_proj_alpha = getattr(kwargs, "gate_up_proj_alpha", None)
- self.input_layernorm_alpha = getattr(kwargs, "input_layernorm_alpha", None)
- self.post_attention_layernorm_alpha = getattr(kwargs, "post_attention_layernorm_alpha", None)
- self.norm_alpha = getattr(kwargs, "norm_alpha", None)
- self.lm_head_alpha = getattr(kwargs, "lm_head_alpha", None)
- self.use_norm_alpha = getattr(kwargs, "use_norm_alpha", False)
- self.use_emb_alpha = getattr(kwargs, "use_emb_alpha", False)
-
  # Validate the correctness of rotary position embeddings parameters
  # BC: if there is a 'type' field, move it to 'rope_type'.
  if self.rope_scaling is not None and "type" in self.rope_scaling:
  self.rope_scaling["rope_type"] = self.rope_scaling["type"]
  rope_config_validation(self)
-
- self.multi_token_heads = multi_token_heads
- self.multi_token_config_validation()
-
-
-
- # moe
- self.topk_method = getattr(kwargs, "topk_method", None)
- self.scoring_func = getattr(kwargs, "scoring_func", None)
- self.routed_scaling_factor = getattr(kwargs, "routed_scaling_factor", None)
- self.norm_topk_prob = getattr(kwargs, "norm_topk_prob", None)
- self.seq_aux = getattr(kwargs, "seq_aux", None)
- self.hidden_act_moe = getattr(kwargs, "hidden_act_moe", None)
-
-
- self.n_group = getattr(kwargs, "n_group", None)
- self.n_routed_experts = getattr(kwargs, "n_routed_experts", None)
- self.moe_intermediate_size = getattr(kwargs, "moe_intermediate_size", None)
- self.topk_group = getattr(kwargs, "topk_group", None)
-
-
- self.use_fused_mlp = getattr(kwargs, "use_fused_mlp", None)
- self.use_moreh_moe = getattr(kwargs, "use_moreh_moe", False)
- self.continual_training = getattr(kwargs, "continual_training", False)
-
- # external
- self.first_expansion = getattr(kwargs, "first_expansion", False)
- self.moe_layer = getattr(kwargs, "moe_layer", False)
-
-
-
+
  super().__init__(
  tie_word_embeddings=tie_word_embeddings,
  **kwargs,
  )
- logger.info(f' kwargs : {kwargs}')
- logger.info(f'after wesar reparameterization : {self.wesar_weights}')
-
- def multi_token_config_validation(self):
- if self.multi_token_heads is not None:
- assert isinstance(self.multi_token_heads, int) and self.multi_token_heads >= 1
+ logger.info(f' kwargs : {kwargs}')

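For reference, a small sketch of the backward-compatibility step kept in the constructor above: a legacy `rope_scaling` dict that still uses the old `type` key is mirrored into `rope_type` before `rope_config_validation` runs. This is a standalone illustration, not the repository's code:

```python
# Legacy-style rope_scaling as a user might still pass it.
rope_scaling = {"type": "linear", "factor": 2.0}

# BC shim mirrored from MotifConfig.__init__ above.
if rope_scaling is not None and "type" in rope_scaling:
    rope_scaling["rope_type"] = rope_scaling["type"]

print(rope_scaling)
# {'type': 'linear', 'factor': 2.0, 'rope_type': 'linear'}
```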
generation_config.json CHANGED
@@ -6,5 +6,5 @@
  219405
  ],
  "transformers_version": "4.51.3",
- "use_cache": false
+ "use_cache": true
  }
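With `use_cache` now enabled in the generation config, `generate` reuses the KV cache by default; it can still be overridden per call. A short sketch, reusing the `model` and `input_ids` from the README example above:

```python
# KV cache on (the new default from generation_config.json) ...
output = model.generate(input_ids, max_new_tokens=128, use_cache=True)

# ... or explicitly disabled for a single call, e.g. to compare memory use.
output_no_cache = model.generate(input_ids, max_new_tokens=128, use_cache=False)
```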
modeling_motif.py CHANGED
@@ -1,178 +1,32 @@
1
  import math
 
2
  from typing import List, Optional, Tuple, Union
3
 
4
  import torch
 
5
  import torch.utils.checkpoint
6
  from torch import nn
7
  from torch.nn import CrossEntropyLoss
 
 
8
  from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
9
  from transformers.generation import GenerationMixin
10
  from transformers.modeling_attn_mask_utils import AttentionMaskConverter
11
  from transformers.modeling_flash_attention_utils import _flash_attention_forward
12
- from transformers.modeling_outputs import (
13
- CausalLMOutputWithPast,
14
- ModelOutput,
15
- )
16
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
17
  from transformers.modeling_utils import PreTrainedModel
18
  from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
19
- from transformers.utils import (
20
- add_start_docstrings,
21
- add_start_docstrings_to_model_forward,
22
- is_flash_attn_greater_or_equal_2_10,
23
- is_flash_attn_2_available,
24
- logging,
25
- replace_return_docstrings,
26
- )
27
- from .configuration_motif import MotifConfig
28
- from dataclasses import dataclass
29
-
30
- import torch.nn.functional as F
31
- import time
32
-
33
- logger = logging.get_logger(__name__)
34
-
35
- if is_flash_attn_2_available():
36
- from transformers.modeling_flash_attention_utils import _flash_attention_forward
37
-
38
- try:
39
- moreh_ops = torch.ops.moreh
40
- MorehRMSNorm = moreh_ops.T5LayerNorm
41
- ScaledDotProductAttention = moreh_ops.scaled_dot_product_attention
42
- MorehFlashAttention = moreh_ops.flash_attention
43
- logger.warning_once("Using moreh ops")
44
- except AttributeError:
45
- MorehRMSNorm = None
46
- ScaledDotProductAttention = None
47
- MorehFlashAttention = None
48
- logger.warning_once("Failed to import moreh ops")
49
-
50
-
51
-
52
- # DEBUG = False
53
- # logger.info(f"DEBUG: {DEBUG} : will log timing")
54
- # def log_timing(obj):
55
- # """Decorator to log timing of function or class execution"""
56
- # if isinstance(obj, type):
57
- # # If decorating a class
58
- # class TimedClass(obj):
59
- # def __getattribute__(self, name):
60
- # attr = super().__getattribute__(name)
61
- # if callable(attr) and not name.startswith('__'):
62
- # def timed_method(*args, **kwargs):
63
- # if not DEBUG:
64
- # return attr(*args, **kwargs)
65
- # if name != "forward":
66
- # return attr(*args, **kwargs)
67
-
68
- # start_time = time.time()
69
- # logger.info(f"Entering {obj.__name__}.{name}")
70
- # result = attr(*args, **kwargs)
71
- # end_time = time.time()
72
- # logger.info(f"Exiting {obj.__name__}.{name}, took {end_time - start_time:.4f} seconds")
73
- # return result
74
- # return timed_method
75
- # return attr
76
- # return TimedClass
77
- # else:
78
- # # If decorating a function
79
- # def wrapper(*args, **kwargs):
80
- # if not DEBUG:
81
- # return obj(*args, **kwargs)
82
-
83
- # start_time = time.time()
84
- # logger.info(f"Entering {obj.__name__}")
85
- # result = obj(*args, **kwargs)
86
- # end_time = time.time()
87
- # logger.info(f"Exiting {obj.__name__}, took {end_time - start_time:.4f} seconds")
88
- # return result
89
- # return wrapper
90
-
91
-
92
-
93
- #_CHECKPOINT_FOR_DOC = "moreh/Motif-102B"
94
- _CONFIG_FOR_DOC = "MotifConfig"
95
-
96
- #from .moreh_moe import MorehMoeMLP, MorehMoeFusedMLP
97
-
98
- import torch
99
- from transformers.activations import ACT2CLS as _ACT2CLS
100
- from transformers.activations import ClassInstantier
101
- moreh_ops = torch.ops.moreh
102
-
103
- from typing import Callable, Dict, List, Tuple
104
-
105
- import torch
106
-
107
-
108
- # @log_timing
109
- def multi_head_forward_backward(shared_activation: torch.Tensor,
110
- head_fns: List[Callable[[torch.Tensor], Dict[str, torch.Tensor]]],
111
- return_keys=("loss", ),
112
- return_only_first_head=True) -> Tuple[torch.Tensor, ...]:
113
- """
114
- The forward-backward pattern introduced in the paper https://arxiv.org/abs/2404.19737
115
- to reduce memory overhead due to activations from multiple heads.
116
-
117
- Args:
118
- - shared_activation: the shared activation across all heads
119
- - head_fns: the head-wise forward computations that start from `shared_activation`.
120
- it should output a dictionary of tensors with keys matching `return_keys`
121
- - return_keys: the keys to return in order
122
- - return_only_first_head: whether to return only the values from the first head
123
-
124
- Returns:
125
- - a tuple of return tensors
126
-
127
- Side effect:
128
- - (only when `torch.is_grad_enabled()`)
129
- the gradients accumulated as if `sum(head_fn(shared_activation)["loss"] for head_fn in head_fns).backward()` had been called
130
- """
131
- if not return_only_first_head:
132
- raise NotImplementedError
133
-
134
- return_key_set = set(return_keys)
135
- if "loss" not in return_key_set:
136
- raise Exception("'loss' is a required return key.")
137
-
138
- detached_shared_activation = shared_activation.detach()
139
- detached_shared_activation.requires_grad = True
140
- return_values = {key: None for key in return_keys}
141
- for head_idx, head_fn in enumerate(head_fns):
142
- if head_idx > 0 and not torch.is_grad_enabled():
143
- continue
144
-
145
- # forward pass for the head
146
- headwise_outputs = head_fn(detached_shared_activation)
147
- if set(headwise_outputs.keys()) != return_key_set:
148
- raise Exception(f"Headwise output keys {headwise_outputs.keys()} do not match return keys {return_keys}.")
149
-
150
- # backward pass for the head
151
- # effect 1: the parameters of the head
152
- # effect 2: gradient accumulated in `detached_shared_activation.grad`
153
- if torch.is_grad_enabled():
154
- headwise_loss = headwise_outputs["loss"]
155
- headwise_loss.backward(
156
- ) # NOTE: You do not need to retain graph since no graph is shared across backward passes
157
-
158
- if head_idx == 0:
159
- for key in return_keys:
160
- return_values[key] = headwise_outputs[key]
161
-
162
- assert all(value is not None for value in return_values.values())
163
-
164
- # backward pass for the shared part
165
- if torch.is_grad_enabled():
166
- shared_activation.backward(detached_shared_activation.grad)
167
 
168
- return tuple(return_values[key] for key in return_keys)
169
 
170
 
171
  class PolyNorm(torch.nn.Module):
172
- """
173
  A trainable activation function introduced in https://arxiv.org/html/2411.03884v1.
174
- The code is copied from https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md,
175
- with the change `* torch.rsqrt` => `/ torch.sqrt` for potential MAF incompatibility.
176
  """
177
 
178
  def __init__(self, eps=1e-6):
@@ -189,29 +43,16 @@ class PolyNorm(torch.nn.Module):
189
  x ** 2) + self.weight[2] * self._norm(x) + self.bias
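For reference, a self-contained sketch of the PolyNorm activation described above (a weighted sum of RMS-normalized powers of the input). This is an illustration based on the docstring and the partial forward shown in this hunk, not the repository's exact code:

```python
import torch


class PolyNormSketch(torch.nn.Module):
    """Trainable activation from arXiv:2411.03884: w0*norm(x^3) + w1*norm(x^2) + w2*norm(x) + b."""

    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3) / 3)
        self.bias = torch.nn.Parameter(torch.zeros(1))
        self.eps = eps

    def _norm(self, x):
        # RMS-style normalization over the last dimension (the `/ torch.sqrt` form noted above).
        return x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x):
        return (self.weight[0] * self._norm(x ** 3)
                + self.weight[1] * self._norm(x ** 2)
                + self.weight[2] * self._norm(x)
                + self.bias)
```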
190
 
191
 
192
- class PolyNorm_Test(torch.nn.Module):
193
- """
194
- A trainable activation function introduced in https://arxiv.org/html/2411.03884v1.
195
- The code is copied from https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md,
196
- with the change `* torch.rsqrt` => `/ torch.sqrt` for potential MAF incompatibility.
197
- """
198
-
199
- def __init__(self, eps=1e-6):
200
- super(PolyNorm_Test, self).__init__()
201
- self.weight = torch.nn.Parameter(torch.ones(3) / 3)
202
- self.bias = torch.nn.Parameter(torch.zeros(1))
203
- self.eps = eps
204
-
205
- def forward(self, x):
206
-
207
- #return torch.nn.SiLU(x)
208
- return moreh_ops.poly_norm(x, self.weight, self.bias)
209
-
210
-
211
- CUSTOM_ACT2CLS = {"poly_norm": PolyNorm_Test, "poly_norm_test": PolyNorm_Test}
212
  ACT2CLS = {**_ACT2CLS, **CUSTOM_ACT2CLS}
213
  ACT2FN = ClassInstantier(ACT2CLS)
214
 
 
 
 
 
 
 
215
 
216
 
217
  class MotifRMSNorm(nn.Module):
@@ -235,7 +76,7 @@ class MotifRMSNorm(nn.Module):
235
  return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
236
 
237
 
238
- ALL_LAYERNORM_LAYERS.append(MotifRMSNorm if MorehRMSNorm is None else MorehRMSNorm)
239
 
240
 
241
  class MotifRotaryEmbeddingWithCache(nn.Module):
@@ -267,7 +108,6 @@ class MotifRotaryEmbeddingWithCache(nn.Module):
267
  inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
268
  self.register_buffer("inv_freq", inv_freq, persistent=False)
269
 
270
- # Build here to make `torch.jit.trace` work.
271
  self._set_cos_sin_cache(seq_len=max_position_embeddings,
272
  device=self.inv_freq.device,
273
  dtype=torch.get_default_dtype())
@@ -288,12 +128,11 @@ class MotifRotaryEmbeddingWithCache(nn.Module):
288
  self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
289
 
290
  return (
291
- self.cos_cached[None, :seq_len].to(dtype=x.dtype),
292
- self.sin_cached[None, :seq_len].to(dtype=x.dtype),
293
  )
294
 
295
 
296
- # @log_timing
297
  class MotifRotaryEmbedding(nn.Module):
298
 
299
  def __init__(
@@ -324,7 +163,6 @@ class MotifRotaryEmbedding(nn.Module):
324
  self.max_seq_len_cached = max_position_embeddings
325
  self.original_max_seq_len = max_position_embeddings
326
  else:
327
- # BC: "rope_type" was originally "type"
328
  if config.rope_scaling is not None:
329
  self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
330
  else:
@@ -386,10 +224,10 @@ class MotifRotaryEmbedding(nn.Module):
386
  def rotate_half(x):
387
  """
388
  Rotates half of the dimensions of the input tensor using torch.roll and in-place negation.
389
-
390
  Args:
391
  x (torch.Tensor): The input tensor.
392
-
393
  Returns:
394
  torch.Tensor: A tensor where the latter half of the dimensions are negated
395
  and moved before the first half.
@@ -401,8 +239,7 @@ def rotate_half(x):
401
  return rotated_tensor
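A minimal equivalent of the rotation described in the docstring above, written with `torch.cat` for clarity; the repository's version achieves the same result with `torch.roll` and in-place negation:

```python
import torch


def rotate_half_sketch(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half, negate the second half and move it in front:
    # (x1, x2) -> (-x2, x1), which is the form rotary position embeddings need.
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((-x2, x1), dim=-1)
```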
402
 
403
 
404
- # @log_timing
405
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1, fused_rope=True):
406
  """
407
  Applies rotary position embeddings to the input tensors.
408
 
@@ -411,438 +248,47 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1, fus
411
  k (torch.Tensor): Key tensor of shape (B, NH, S, D_KV).
412
  cos (torch.Tensor): Cosine values for rotary embedding.
413
  sin (torch.Tensor): Sine values for rotary embedding.
414
- unsqueeze_dim (int, optional): Dimension along which `cos` and `sin` are unsqueezed.
415
  Defaults to 1.
416
- fused_rope (bool, optional): If True, applies fused rotary embeddings using
417
- `moreh_ops.apply_rotary_emb`. If False, computes rotary embeddings manually.
418
- Defaults to False.
419
 
420
  Returns:
421
  Tuple[torch.Tensor, torch.Tensor]: Returns transformed query and key tensors after applying rotary embeddings.
422
  """
423
  '''
424
- # (B, NH, S, D_KV) -> (B, S, NH, D_KV)
425
  cos = cos.unsqueeze(unsqueeze_dim)
426
  sin = sin.unsqueeze(unsqueeze_dim)
427
  q_embed = (q * cos) + (rotate_half(q) * sin)
428
  k_embed = (k * cos) + (rotate_half(k) * sin)
429
  '''
430
- #cos = cos[position_ids]
431
- #sin = sin[position_ids]
432
-
433
- #cos = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
434
- #sin = sin[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim]
435
-
436
- q = q.transpose(1, 2)
437
- k = k.transpose(1, 2)
438
-
439
- # Expand 'batch' dim
440
- cos = cos.expand(q.shape[0], *cos.shape[1:])
441
- sin = sin.expand(q.shape[0], *sin.shape[1:])
442
-
443
- q_embed = moreh_ops.apply_rotary_emb(q, cos, sin, opcode=1)
444
- k_embed = moreh_ops.apply_rotary_emb(k, cos, sin, opcode=1)
445
 
446
- # (B, S, NH, D_KV) -> (B, NH, S, D_KV)
447
- q_embed = q_embed.transpose(1, 2)
448
- k_embed = k_embed.transpose(1, 2)
449
 
450
- return q_embed, k_embed
451
-
452
-
453
- # @log_timing
454
  class MotifMLP(nn.Module):
455
 
456
  def __init__(self, config):
457
  super().__init__()
458
  self.hidden_size = config.hidden_size
459
  self.intermediate_size = config.intermediate_size
460
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
461
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
462
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
463
  self.act_fn = ACT2FN[config.hidden_act]
464
 
465
- if config.wesar_weights:
466
- self.gate_up_proj_alpha = nn.Parameter(torch.tensor(1) *config.gate_up_proj_alpha)
467
- self.down_proj_alpha = nn.Parameter(torch.tensor(1) * config.down_proj_alpha)
468
- else:
469
- self.gate_up_proj_alpha=1
470
- self.down_proj_alpha=1
471
- if config.muP:
472
- self.down_proj.__do_scale_tager__ = True
473
- self.gate_proj.__do_scale_tager_mu_dim_model__ = True
474
- self.up_proj.__do_scale_tager_mu_dim_model__ = True
475
- self.down_proj.__do_scale_tager_mu_ffn__ = True
476
-
477
-
478
  def forward(self, hidden_state):
479
- hidden_state = hidden_state*self.gate_up_proj_alpha
480
- #hidden_state = self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))*
481
- return self.down_proj_alpha*self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
482
-
483
-
484
- class MorehMoeFusedMLP(nn.Module):
485
- def __init__(self,
486
- ffn_dim,
487
- hidden_dim,
488
- hidden_act_moe,
489
- num_experts,
490
- num_groups=1,
491
- device=None,
492
- continual_training=False):
493
- super().__init__()
494
- self.ffn_dim = ffn_dim
495
- self.hidden_dim = hidden_dim
496
- self.hidden_act_moe = hidden_act_moe
497
-
498
- self.num_experts = num_experts
499
- self.num_groups = num_groups
500
-
501
- assert self.num_experts % self.num_groups == 0
502
- self.num_experts_per_group = self.num_experts // self.num_groups
503
-
504
- ## bsz, seq, group size, 2*ffn_size
505
-
506
- moreh_ops = torch.ops.moreh
507
- self.w13 = nn.ModuleList([
508
- moreh_ops.MoeFanInLinear(self.hidden_dim,
509
- self.ffn_dim * 2,
510
- bias=False,
511
- num_experts=self.num_experts_per_group,
512
- device=device)
513
- for _ in range(self.num_groups)
514
- ])
515
-
516
- self.w2 = nn.ModuleList([
517
- moreh_ops.MoeFanOutLinear(self.ffn_dim,
518
- self.hidden_dim,
519
- bias=False,
520
- num_experts=self.num_experts_per_group,
521
- device=device)
522
- for _ in range(self.num_groups)
523
- ])
524
-
525
- ## use silu?
526
- self.act_fn = ACT2FN[self.hidden_act_moe]
527
-
528
- if continual_training:
529
- logger.info('two optipons 1. zero init all weights, 2. add scaling param to moe output.')
530
- self._zero_init()
531
-
532
- def _zero_init(self):
533
- for module in self.w2:
534
- for n,param in module.named_parameters():
535
- logger.info(f'{n} {param.shape}')
536
- param.data.zero_()
537
-
538
-
539
- def forward(self, hidden_states, selected_experts, routing_weights):
540
- w13_final_output = None
541
- for group_idx in range(self.num_groups):
542
- w13_output_in_group = self._get_w13_output(hidden_states,
543
- selected_experts,
544
- group_idx)
545
- if w13_final_output is None:
546
- w13_final_output = w13_output_in_group
547
- else:
548
- w13_final_output += w13_output_in_group
549
-
550
- current_hidden_states = self.act_fn(
551
- w13_final_output[:, :, :, :self.ffn_dim]
552
- ) * w13_final_output[:, :, :, self.ffn_dim:]
553
-
554
- final_hidden_states = None
555
- for group_idx in range(self.num_groups):
556
- w2_output_in_group = self._get_w2_output(current_hidden_states,
557
- selected_experts,
558
- routing_weights, group_idx)
559
- if final_hidden_states is None:
560
- final_hidden_states = w2_output_in_group
561
- else:
562
- final_hidden_states += w2_output_in_group
563
- return final_hidden_states
564
-
565
- def _get_w13_output(self, hidden_states, selected_experts, group_idx):
566
- selected_experts_in_group = selected_experts - (
567
- group_idx * self.num_experts_per_group)
568
-
569
- w13_output = self.w13[group_idx](hidden_states,
570
- selected_experts_in_group)
571
- return w13_output
572
-
573
- def _get_w2_output(self, hidden_states, selected_experts, routing_weights,
574
- group_idx):
575
- selected_experts_in_group = selected_experts - (
576
- group_idx * self.num_experts_per_group)
577
- output = self.w2[group_idx](hidden_states, selected_experts_in_group,
578
- routing_weights)
579
- return output
580
-
581
-
582
- class MoEGate(nn.Module):
583
-
584
- def __init__(self, config):
585
- super().__init__()
586
- self.config = config
587
- self.top_k = config.num_experts_per_tok
588
- self.n_routed_experts = config.n_routed_experts
589
- self.routed_scaling_factor = config.routed_scaling_factor
590
- self.scoring_func = config.scoring_func
591
- self.seq_aux = config.seq_aux
592
- self.topk_method = config.topk_method
593
- self.n_group = config.n_group
594
- self.topk_group = config.topk_group
595
-
596
- # topk selection algorithm
597
- self.norm_topk_prob = config.norm_topk_prob
598
- self.gating_dim = config.hidden_size
599
- self.weight = nn.Parameter(
600
- torch.empty((self.n_routed_experts, self.gating_dim)))
601
- if self.topk_method == "noaux_tc":
602
- self.e_score_correction_bias = nn.Parameter(
603
- torch.empty((self.n_routed_experts)))
604
- self.reset_parameters()
605
-
606
- def reset_parameters(self) -> None:
607
- import torch.nn.init as init
608
-
609
- init.kaiming_uniform_(self.weight, a=math.sqrt(5))
610
-
611
- def forward(self, hidden_states):
612
- bsz, seq_len, h = hidden_states.shape
613
- ### compute gating score
614
- hidden_states = hidden_states.view(-1, h)
615
- logits = F.linear(hidden_states.type(torch.float32),
616
- self.weight.type(torch.float32), None)
617
- if self.scoring_func == "sigmoid":
618
- scores = logits.sigmoid()
619
- else:
620
- raise NotImplementedError(
621
- f"insupportable scoring function for MoE gating: {self.scoring_func}"
622
- )
623
-
624
- ### select top-k experts
625
- if self.topk_method == "greedy":
626
- topk_weight, topk_idx = torch.topk(scores,
627
- k=self.top_k,
628
- dim=-1,
629
- sorted=False)
630
- elif self.topk_method == "group_limited_greedy":
631
- group_scores = (scores.view(bsz * seq_len, self.n_group,
632
- -1).max(dim=-1).values) # [n, n_group]
633
- group_idx = torch.topk(group_scores,
634
- k=self.topk_group,
635
- dim=-1,
636
- sorted=False)[1] # [n, top_k_group]
637
- group_mask = torch.zeros_like(group_scores) # [n, n_group]
638
- group_mask.scatter_(1, group_idx, 1) # [n, n_group]
639
- score_mask = (group_mask.unsqueeze(-1).expand(
640
- bsz * seq_len, self.n_group,
641
- self.n_routed_experts // self.n_group).reshape(
642
- bsz * seq_len, -1)) # [n, e]
643
- tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
644
- topk_weight, topk_idx = torch.topk(tmp_scores,
645
- k=self.top_k,
646
- dim=-1,
647
- sorted=False)
648
- elif self.topk_method == "noaux_tc":
649
- ### will be used. ###
650
- scores_for_choice = scores.view(
651
- bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
652
- group_scores = (scores_for_choice.view(
653
- bsz * seq_len, self.n_group,
654
- -1).topk(2, dim=-1)[0].sum(dim=-1)) # [n, n_group]
655
- group_idx = torch.topk(group_scores,
656
- k=self.topk_group,
657
- dim=-1,
658
- sorted=False)[1] # [n, top_k_group]
659
- group_mask = torch.zeros_like(group_scores) # [n, n_group]
660
- group_mask.scatter_(1, group_idx, 1) # [n, n_group]
661
- score_mask = (group_mask.unsqueeze(-1).expand(
662
- bsz * seq_len, self.n_group,
663
- self.n_routed_experts // self.n_group).reshape(
664
- bsz * seq_len, -1)) # [n, e]
665
- tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(),
666
- 0.0) # [n, e]
667
- _, topk_idx = torch.topk(tmp_scores,
668
- k=self.top_k,
669
- dim=-1,
670
- sorted=False)
671
- topk_weight = scores.gather(1, topk_idx)
672
- else:
673
- raise NotImplementedError(
674
- f"insupportable TopK function for MoE gating: {self.topk_method}"
675
- )
676
-
677
- ### norm gate to sum 1
678
- if self.top_k > 1 and self.norm_topk_prob:
679
- denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
680
- topk_weight = topk_weight / denominator
681
- topk_weight = topk_weight * self.routed_scaling_factor # must multiply the scaling factor
682
-
683
- return topk_idx, topk_weight
684
-
685
-
686
- class MotifMoE(nn.Module):
687
- """
688
- A mixed expert module containing shared experts.
689
- """
690
- def __init__(self, config):
691
- super().__init__()
692
- self.config = config
693
- self.num_experts_per_tok = config.num_experts_per_tok
694
- self.use_moreh_moe = config.use_moreh_moe
695
- self.use_fused_mlp = config.use_fused_mlp
696
-
697
- if hasattr(config, "ep_size") and config.ep_size > 1:
698
- assert config.ep_size == dist.get_world_size()
699
- assert not config.use_moreh_moe
700
- self.ep_size = config.ep_size
701
- self.experts_per_rank = config.n_routed_experts // config.ep_size
702
- self.ep_rank = dist.get_rank()
703
- self.experts = nn.ModuleList([
704
- (DeepseekV3MLP(config,
705
- intermediate_size=config.moe_intermediate_size)
706
- if i >= self.ep_rank * self.experts_per_rank and i <
707
- (self.ep_rank + 1) * self.experts_per_rank else None)
708
- for i in range(config.n_routed_experts)
709
- ])
710
- else:
711
- self.ep_size = 1
712
- self.experts_per_rank = config.n_routed_experts
713
- self.ep_rank = 0
714
- if self.use_moreh_moe:
715
- if not self.use_fused_mlp:
716
- self.experts = MorehMoeMLP(
717
- ffn_dim=config.moe_intermediate_size,
718
- hidden_dim=config.hidden_size,
719
- hidden_act_moe=config.hidden_act_moe,
720
- num_experts=config.n_routed_experts,
721
- device=None)
722
- else:
723
- ## group expert.
724
- self.experts = MorehMoeFusedMLP(
725
- ffn_dim=config.moe_intermediate_size,
726
- hidden_dim=config.hidden_size,
727
- hidden_act_moe=config.hidden_act_moe,
728
- num_experts=config.n_routed_experts,
729
- num_groups=config.n_group,
730
- device=None,
731
- continual_training=config.continual_training,
732
- )
733
- else:
734
- self.experts = nn.ModuleList([
735
- DeepseekV3MLP(
736
- config, intermediate_size=config.moe_intermediate_size)
737
- for i in range(config.n_routed_experts)
738
- ])
739
-
740
- self.gate = MoEGate(config)
741
-
742
- def forward(self, hidden_states):
743
- identity = hidden_states
744
- orig_shape = hidden_states.shape
745
- topk_idx, topk_weight = self.gate(hidden_states)
746
- if self.use_moreh_moe:
747
- y = self.experts(hidden_states, topk_idx.view(*orig_shape[:-1], -1),
748
- topk_weight.view(*orig_shape[:-1], -1))
749
- y = y.type(hidden_states.dtype)
750
- else:
751
- hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
752
- flat_topk_idx = topk_idx.view(-1)
753
- if self.training:
754
- hidden_states = hidden_states.repeat_interleave(
755
- self.num_experts_per_tok, dim=0)
756
- y = torch.empty_like(hidden_states)
757
- for i, expert in enumerate(self.experts):
758
- y[flat_topk_idx == i] = expert(
759
- hidden_states[flat_topk_idx == i])
760
- y = (y.view(*topk_weight.shape, -1) *
761
- topk_weight.unsqueeze(-1)).sum(dim=1)
762
- y = y.type(hidden_states.dtype)
763
- y = y.view(*orig_shape)
764
- # y = AddAuxiliaryLoss.apply(y, aux_loss)
765
- else:
766
- y = self.moe_infer(hidden_states, topk_idx,
767
- topk_weight).view(*orig_shape)
768
- return y, identity
769
-
770
- @torch.no_grad()
771
- def moe_infer(self, x, topk_ids, topk_weight):
772
- cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
773
- cnts.scatter_(1, topk_ids, 1)
774
- tokens_per_expert = cnts.sum(dim=0)
775
- idxs = topk_ids.view(-1).argsort()
776
- sorted_tokens = x[idxs // topk_ids.shape[1]]
777
- sorted_tokens_shape = sorted_tokens.shape
778
- if self.ep_size > 1:
779
- tokens_per_ep_rank = tokens_per_expert.view(self.ep_size,
780
- -1).sum(dim=1)
781
- tokens_per_expert_group = tokens_per_expert.new_empty(
782
- tokens_per_expert.shape[0])
783
- dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
784
- output_splits = (tokens_per_expert_group.view(
785
- self.ep_size, -1).sum(1).cpu().numpy().tolist())
786
- gathered_tokens = sorted_tokens.new_empty(
787
- tokens_per_expert_group.sum(dim=0).cpu().item(),
788
- sorted_tokens.shape[1])
789
- input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
790
- dist.all_to_all(
791
- list(gathered_tokens.split(output_splits)),
792
- list(sorted_tokens.split(input_split_sizes)),
793
- )
794
- tokens_per_expert_post_gather = tokens_per_expert_group.view(
795
- self.ep_size, self.experts_per_rank).sum(dim=0)
796
- gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],),
797
- dtype=np.int32)
798
- s = 0
799
- for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
800
- gatherd_idxs[s:s + k] = i % self.experts_per_rank
801
- s += k
802
- gatherd_idxs = gatherd_idxs.argsort()
803
- sorted_tokens = gathered_tokens[gatherd_idxs]
804
- tokens_per_expert = tokens_per_expert_post_gather
805
- tokens_per_expert = tokens_per_expert.cpu().numpy()
806
-
807
- outputs = []
808
- start_idx = 0
809
- for i, num_tokens in enumerate(tokens_per_expert):
810
- end_idx = start_idx + num_tokens
811
- if num_tokens == 0:
812
- continue
813
- expert = self.experts[i + self.ep_rank * self.experts_per_rank]
814
- tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
815
- expert_out = expert(tokens_for_this_expert)
816
- outputs.append(expert_out)
817
- start_idx = end_idx
818
-
819
- outs = torch.cat(outputs,
820
- dim=0) if len(outputs) else sorted_tokens.new_empty(0)
821
- if self.ep_size > 1:
822
- new_x = torch.empty_like(outs)
823
- new_x[gatherd_idxs] = outs
824
- gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
825
- dist.all_to_all(
826
- list(gathered_tokens.split(input_split_sizes)),
827
- list(new_x.split(output_splits)),
828
- )
829
- outs = gathered_tokens
830
-
831
- new_x = torch.empty_like(outs)
832
- new_x[idxs] = outs
833
- final_out = (new_x.view(
834
- *topk_ids.shape, -1).type(topk_weight.dtype).mul_(
835
- topk_weight.unsqueeze(dim=-1)).sum(dim=1).type(new_x.dtype))
836
- return final_out
837
 
838
 
839
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
840
-
841
-
842
  """
843
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
844
  num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
845
-
846
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
847
  if n_rep == 1:
848
  return hidden_states
@@ -851,32 +297,31 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
851
  """
852
 
853
  return torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)
 
854
 
855
-
856
- # @log_timing
857
  class MotifAttention(nn.Module):
858
  """
859
  Differential Attention (DiffAttention) module.
860
 
861
- Implements the Differential Attention from
862
  "DIFFERENTIAL TRANSFORMER" (https://arxiv.org/pdf/2410.05258).
863
 
864
  Overview
865
  Standard transformers often over-allocate attention to irrelevant context.
866
- DiffAttention addresses this by computing attention as the difference between
867
- two separate softmax attention maps, effectively canceling noise and promoting
868
  sparse, structured attention patterns.
869
 
870
  Reference Implementation
871
  https://github.com/microsoft/unilm/tree/master/Diff-Transformer
872
 
873
  Args
874
- The differential attention mechanism computes attention as the difference of two softmax attention scores, weighted by a learnable scalar λ.
875
  λ is re-parameterized as λ = exp(λ_q1 · λ_k1) − exp(λ_q2 · λ_k2) + λ_init.
876
  - lambda_q1, lambda_q2 (nn.Parameter): Learnable vectors used to compute the first and second components of λ for query transformations.
877
  - lambda_k1, lambda_k2 (nn.Parameter): Learnable vectors used to compute the first and second components of λ for key transformations.
878
  - lambda_init (float): A constant used for initializing λ, typically set as λ_init = 0.8 − 0.6 × exp(−0.3 × (layer_index − 1)).
879
-
880
  """
881
 
882
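For reference, a compact sketch of the λ re-parameterization and the differential combination spelled out in the docstring above. Illustration only; the shapes follow the description, with `attn_weights` holding the two softmax maps stacked on dim 2:

```python
import math

import torch


def differential_combine(attn_weights, lambda_q1, lambda_k1, lambda_q2, lambda_k2, layer_idx):
    # lambda_init = 0.8 - 0.6 * exp(-0.3 * (layer_index - 1)), as stated above.
    lambda_init = 0.8 - 0.6 * math.exp(-0.3 * (layer_idx - 1))
    # lambda = exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2) + lambda_init
    lambda_1 = torch.exp(torch.sum(lambda_q1 * lambda_k1, dim=-1).float()).type_as(attn_weights)
    lambda_2 = torch.exp(torch.sum(lambda_q2 * lambda_k2, dim=-1).float()).type_as(attn_weights)
    lambda_full = lambda_1 - lambda_2 + lambda_init
    # attn_weights: (bsz, num_heads, 2, q_len, kv_len); subtract the second map from the first.
    return attn_weights[:, :, 0] - lambda_full * attn_weights[:, :, 1]
```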
  def __init__(self, config: MotifConfig, layer_idx: Optional[int] = None):
@@ -899,11 +344,7 @@ class MotifAttention(nn.Module):
899
  self.rope_theta = config.rope_theta
900
  self.is_causal = True
901
  self.attention_dropout = config.attention_dropout
902
- try:
903
- self.batch_num = config.batch_num
904
- logger.info(f'self.batcn_num : {self.batch_num}')
905
- except:
906
- self.batch_num = None
907
  if (self.head_dim * self.num_heads) != self.hidden_size:
908
  raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
909
  f" and `num_heads`: {self.num_heads}).")
@@ -912,61 +353,22 @@ class MotifAttention(nn.Module):
912
  self.num_key_value_heads //= 2
913
  self.n_rep = self.num_heads // self.num_key_value_heads
914
 
915
- ##mix attn
916
-
917
- self.mix_attn = config.mix_attn
918
-
919
- if self.mix_attn:
920
-
921
- self.cq, self.ck = 6, 11
922
- self.ch = 2
923
-
924
- self.key_query_conv = nn.Conv2d(
925
- in_channels=self.num_heads*2,
926
- out_channels=self.num_heads*2,
927
- kernel_size=(self.cq, self.ck),
928
- padding="same",
929
- groups=self.num_heads*2
930
- )
931
-
932
- self.head_conv = nn.Conv1d(
933
- in_channels=self.num_heads,
934
- out_channels=self.num_heads,
935
- kernel_size=1,
936
- padding=0,
937
- groups=self.num_heads // self.ch
938
- )
939
-
940
- self.group_norm = nn.GroupNorm(num_groups=self.num_heads, num_channels=self.num_heads)
941
-
942
-
943
-
944
- # re-init projections
945
  self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
946
  self.k_proj = nn.Linear(self.hidden_size, self.hidden_size // self.n_rep, bias=False)
947
  self.v_proj = nn.Linear(self.hidden_size, self.hidden_size // self.n_rep, bias=False)
948
  self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
949
 
950
- # init lambdas
951
  for name in ["lambda_q1", "lambda_k1", "lambda_q2", "lambda_k2"]:
952
  setattr(self, name, nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32)))
953
  getattr(self, name).data.normal_(mean=0.0, std=0.1)
954
 
955
- # Uses same norm as motif norm, without elementwise_affine option
956
  self.subln = MotifRMSNorm(2 * self.head_dim, eps=1e-5)
957
  self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * (layer_idx - 1))
958
 
959
- self.rotary_emb = MotifRotaryEmbedding(self.head_dim,
960
  max_position_embeddings=self.max_position_embeddings,
961
  base=self.rope_theta)
962
 
963
- for param in ["q_proj_alpha", "k_proj_alpha", "v_proj_alpha", "o_proj_alpha"]:
964
- setattr(
965
- self, param,
966
- nn.Parameter(torch.tensor(getattr(config, param, 1.0), dtype=torch.float))
967
- if config.wesar_weights else 1.0)
968
-
969
-
970
  def forward(
971
  self,
972
  hidden_states: torch.Tensor,
@@ -976,15 +378,13 @@ class MotifAttention(nn.Module):
976
  output_attentions: bool = False,
977
  use_cache: bool = False,
978
  cache_position: Optional[torch.LongTensor] = None,
979
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
980
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
981
  bsz, q_len, _ = hidden_states.size()
982
 
983
- query_states = self.q_proj(hidden_states) * self.q_proj_alpha
984
- key_states = self.k_proj(hidden_states) * self.k_proj_alpha
985
- value_states = self.v_proj(hidden_states) * self.v_proj_alpha
986
-
987
- ## bsz, seq, n_heads, head_dim
988
 
989
  query_states = query_states.view(bsz, q_len, 2 * self.num_heads, self.head_dim).transpose(1, 2)
990
  key_states = key_states.view(bsz, q_len, 2 * self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -1006,17 +406,15 @@ class MotifAttention(nn.Module):
1006
  key_states,
1007
  cos,
1008
  sin,
1009
- fused_rope=self.config.fused_rope)
1010
 
1011
  if past_key_value is not None:
1012
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
1013
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
1014
 
1015
- # repeat k/v heads if n_kv_heads < n_heads
1016
  key_states = repeat_kv(key_states, self.num_key_value_groups)
1017
  value_states = repeat_kv(value_states, self.num_key_value_groups)
1018
 
1019
- ## bsz, #haead, q_len, head_dim -> bsz, head, q_len, q_len
1020
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
1021
 
1022
  kv_seq_len = key_states.shape[-2]
@@ -1025,49 +423,31 @@ class MotifAttention(nn.Module):
1025
  attention_mask = torch.triu(
1026
  torch.full((q_len, kv_seq_len), float("-inf"), dtype=attn_weights.dtype, device=attn_weights.device),
1027
  1 + offset)
1028
- ##attn weights conv2d, softmax and add attention_mask
1029
- if self.mix_attn:
1030
- ## condition mask==0, value : 0
1031
- attn_weights = attn_weights.masked_fill( attention_mask == 0, 0)
1032
- attn_weights = self.key_query_conv(attn_weights)
1033
- attn_weights = attn_weights[:, :, :kv_seq_len, :kv_seq_len]
1034
-
1035
- ###add attn
1036
  attn_weights = attn_weights + attention_mask
1037
 
1038
- # upcast attention to fp32
1039
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
1040
  attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
1041
 
1042
- # differential transformer lambdas
1043
  lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(attn_weights)
1044
  lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(attn_weights)
1045
  lambda_full = lambda_1 - lambda_2 + self.lambda_init
1046
  attn_weights = attn_weights.view(bsz, self.num_heads, 2, q_len, -1)
1047
  attn_weights = attn_weights[:, :, 0] - lambda_full * attn_weights[:, :, 1]
1048
- ##head_conv
1049
- if self.mix_attn:
1050
- attn_weights = attn_weights.view(bsz, self.num_heads, -1).contiguous()
1051
- attn_weights = self.head_conv(attn_weights)
1052
- attn_weights = attn_weights.view(bsz, self.num_heads, q_len, -1).contiguous()
1053
 
1054
- ##shape : bsz, #heads, seq, head_dim
1055
  attn_output = torch.matmul(attn_weights, value_states)
1056
 
1057
-
1058
  attn_output = self.subln(attn_output)
1059
  attn_output = attn_output * (1 - self.lambda_init)
1060
 
1061
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim * 2):
1062
  raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
1063
  f" {attn_output.size()}")
1064
- if self.mix_attn:
1065
- attn_output = self.group_norm(attn_output)
1066
-
1067
  attn_output = attn_output.transpose(1, 2).contiguous()
1068
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
1069
 
1070
- attn_output = self.o_proj(attn_output) * self.o_proj_alpha
1071
 
1072
  if not output_attentions:
1073
  attn_weights = None
@@ -1075,7 +455,6 @@ class MotifAttention(nn.Module):
1075
  return attn_output, attn_weights, past_key_value
1076
 
1077
 
1078
- # @log_timing
1079
  class MotifFlashAttention2(MotifAttention):
1080
  """
1081
  Motif flash attention module, following Motif attention module. This module inherits from `MotifAttention`
@@ -1085,18 +464,16 @@ class MotifFlashAttention2(MotifAttention):
1085
  config.max_window_layers layers.
1086
  """
1087
 
1088
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
1089
  def __init__(self, *args, **kwargs):
1090
  super().__init__(*args, **kwargs)
1091
-
1092
- logger.info(f'flash attention True')
1093
-
1094
  # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
1095
  # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
1096
  # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
1097
 
1098
  self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
1099
 
 
 
1100
  def _reshape_heads(self, tensor, batch_size, seq_len):
1101
  """2-way head split tensor reshape"""
1102
  return tensor.reshape(batch_size, seq_len, self.num_heads, 2, self.head_dim)
@@ -1106,55 +483,27 @@ class MotifFlashAttention2(MotifAttention):
1106
  return tensor.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
1107
 
1108
  def _compute_attention(self, query_states, key_states, value_states, attention_mask, q_len, position_ids,
1109
- dropout_rate, sliding_window, is_moreh_attention, batch_num):
1110
  """Flash Attention 2 implements"""
1111
- if is_moreh_attention:
1112
- scale_factor = 1.0 / math.sqrt(self.head_dim)
1113
- # Copied from _flash_attention_forward
1114
- if not self._flash_attn_uses_top_left_mask:
1115
- causal = self.is_causal
1116
- else:
1117
- causal = self.is_causal and q_len != 1
1118
-
1119
- bsz = query_states.shape[0]
1120
-
1121
- if batch_num:
1122
- query_states = query_states.reshape(bsz*q_len,self.num_heads,self.head_dim)
1123
- key_states = key_states.reshape(bsz*q_len,self.num_heads,self.head_dim)
1124
- value_states = value_states.reshape(bsz*q_len,self.num_heads,self.head_dim)
1125
-
1126
- attn_out = moreh_ops.flash_attention_varlen_dp(query_states,
1127
- key_states,
1128
- value_states,
1129
- attention_mask,
1130
- attention_mask,
1131
- max_seqlen_q=q_len,
1132
- max_seqlen_kv=q_len,
1133
- dropout_p=dropout_rate,
1134
- softmax_scale=scale_factor,
1135
- is_causal=causal,
1136
- batch_num=batch_num)
1137
- attn_out = attn_out.reshape(bsz, q_len, self.num_heads, -1)
1138
- else:
1139
- return MorehFlashAttention(query_states,
1140
- key_states,
1141
- value_states,
1142
- padding_mask=attention_mask,
1143
- dropout_p=dropout_rate,
1144
- softmax_scale=scale_factor,
1145
- causal=causal)
1146
- return attn_out
1147
  else:
1148
- return _flash_attention_forward(query_states,
1149
- key_states,
1150
- value_states,
 
 
1151
  attention_mask,
1152
  q_len,
1153
  position_ids=position_ids,
1154
  dropout=dropout_rate,
1155
  sliding_window=sliding_window,
1156
- is_causal=self.is_causal,
 
1157
  use_top_left_mask=self._flash_attn_uses_top_left_mask)
 
1158
 
1159
  def forward(
1160
  self,
@@ -1169,9 +518,9 @@ class MotifFlashAttention2(MotifAttention):
1169
  ):
1170
  bsz, q_len, _ = hidden_states.size()
1171
 
1172
- query_states = self.q_proj(hidden_states) * self.q_proj_alpha
1173
- key_states = self.k_proj(hidden_states) * self.k_proj_alpha
1174
- value_states = self.v_proj(hidden_states) * self.v_proj_alpha
1175
 
1176
  query_states = query_states.view(bsz, q_len, 2 * self.num_heads, self.head_dim).transpose(1, 2)
1177
  key_states = key_states.view(bsz, q_len, 2 * self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -1192,13 +541,12 @@ class MotifFlashAttention2(MotifAttention):
1192
  key_states,
1193
  cos,
1194
  sin,
1195
- fused_rope=True)
1196
 
1197
  if past_key_value is not None:
1198
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
1199
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
1200
 
1201
- # repeat k/v heads if n_kv_heads < n_heads
1202
  key_states = repeat_kv(key_states, self.num_key_value_groups)
1203
  value_states = repeat_kv(value_states, self.num_key_value_groups)
1204
  dropout_rate = 0.0 if not self.training else self.attention_dropout
@@ -1207,7 +555,7 @@ class MotifFlashAttention2(MotifAttention):
1207
  # therefore the input hidden states gets silently casted in float32. Hence, we need
1208
  # cast them back in float16 just to be sure everything works as expected.
1209
  input_dtype = query_states.dtype
1210
- if input_dtype == torch.float32 and MorehFlashAttention is None:
1211
  if torch.is_autocast_enabled():
1212
  target_dtype = torch.get_autocast_gpu_dtype()
1213
  # Handle the case where the model is quantized
@@ -1234,7 +582,7 @@ class MotifFlashAttention2(MotifAttention):
1234
  value_states = value_states.transpose(1, 2)
1235
 
1236
  if (self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None
1237
- and self.layer_idx >= self.config.max_window_layers and MorehFlashAttention is None):
1238
  sliding_window = self.config.sliding_window
1239
  else:
1240
  sliding_window = None
@@ -1254,12 +602,10 @@ class MotifFlashAttention2(MotifAttention):
1254
  k1, k2 = k1.contiguous(), k2.contiguous()
1255
  v1, v2 = v1.contiguous(), v2.contiguous()
1256
 
1257
- is_moreh_attention = MorehFlashAttention is not None
1258
-
1259
- attn11, attn12 = self._compute_attention(q1, k1, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window, is_moreh_attention, self.batch_num), \
1260
- self._compute_attention(q1, k1, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window, is_moreh_attention, self.batch_num)
1261
- attn21, attn22 = self._compute_attention(q2, k2, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window, is_moreh_attention, self.batch_num), \
1262
- self._compute_attention(q2, k2, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window, is_moreh_attention, self.batch_num)
1263
 
1264
  attn1, attn2 = torch.cat([attn11, attn12], dim=-1), torch.cat([attn21, attn22], dim=-1)
1265
 
@@ -1277,16 +623,15 @@ class MotifFlashAttention2(MotifAttention):
1277
  attn_output = attn_output * (1 - self.lambda_init)
1278
 
1279
  if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim * 2):
1280
- raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
1281
  f" {attn_output.size()}")
1282
 
1283
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
1284
- attn_output = self.o_proj(attn_output) * self.o_proj_alpha
1285
 
1286
  return attn_output, None, past_key_value
1287
 
1288
 
1289
- # @log_timing
1290
  class MotifSdpaAttention(MotifAttention):
1291
  """
1292
  Motif attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -1294,7 +639,6 @@ class MotifSdpaAttention(MotifAttention):
1294
  SDPA API.
1295
  """
1296
 
1297
- # Adapted from MotifAttention.forward
1298
  def forward(
1299
  self,
1300
  hidden_states: torch.Tensor,
@@ -1307,7 +651,6 @@ class MotifSdpaAttention(MotifAttention):
1307
  position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
1308
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
1309
  if output_attentions:
1310
- # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
1311
  logger.warning_once(
1312
  "MotifModel is using MotifSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
1313
  'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
@@ -1343,8 +686,7 @@ class MotifSdpaAttention(MotifAttention):
1343
  query_states, key_states = apply_rotary_pos_emb(query_states,
1344
  key_states,
1345
  cos,
1346
- sin,
1347
- fused_rope=self.config.fused_rope)
1348
 
1349
  if past_key_value is not None:
1350
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
@@ -1380,45 +722,25 @@ class MotifSdpaAttention(MotifAttention):
1380
  MOTIF_ATTENTION_CLASSES = {
1381
  "eager": MotifAttention,
1382
  "flash_attention_2": MotifFlashAttention2,
1383
- "sdpa": MotifSdpaAttention,
1384
  }
1385
 
1386
 
1387
- # @log_timing
1388
  class MotifDecoderLayer(nn.Module):
1389
 
1390
- def __init__(self, config: MotifConfig, moe_layer: bool, layer_idx: int):
1391
  super().__init__()
1392
  self.hidden_size = config.hidden_size
1393
- if config.use_moreh_attention:
1394
- config._attn_implementation = "flash_attention_2"
1395
  if config.sliding_window and config._attn_implementation != "flash_attention_2":
1396
  logger.warning_once(
1397
  f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
1398
  "unexpected results may be encountered.")
1399
- if not config.mix_attn:
1400
- self.self_attn = MOTIF_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
1401
- else:
1402
- self.self_attn = MOTIF_ATTENTION_CLASSES["eager"](config, layer_idx)
1403
  self.mlp = MotifMLP(config)
1404
- ### moe
1405
- self.moe = None
1406
- if moe_layer:
1407
- self.moe = MotifMoE(config)
1408
-
1409
- RMSNorm = MorehRMSNorm if MorehRMSNorm is not None else MotifRMSNorm
1410
- self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1411
- self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1412
 
1413
- if config.wesar_weights and config.use_norm_alpha:
1414
- self.input_layernorm_alpha = nn.Parameter(torch.tensor(1).float())
1415
- else:
1416
- self.input_layernorm_alpha = 1
1417
-
1418
- if config.wesar_weights and config.use_norm_alpha :
1419
- self.post_attention_layernorm_alpha = nn.Parameter(torch.tensor(1).float())
1420
- else:
1421
- self.post_attention_layernorm_alpha = 1
1422
 
1423
  def forward(
1424
  self,
@@ -1456,7 +778,7 @@ class MotifDecoderLayer(nn.Module):
1456
 
1457
  residual = hidden_states
1458
 
1459
- hidden_states = self.input_layernorm(hidden_states) * self.input_layernorm_alpha
1460
 
1461
  # Self Attention
1462
  hidden_states, self_attn_weights, present_key_value = self.self_attn(
@@ -1473,16 +795,8 @@ class MotifDecoderLayer(nn.Module):
1473
 
1474
  # Fully Connected
1475
  residual = hidden_states
1476
- hidden_states = self.post_attention_layernorm(hidden_states) * self.post_attention_layernorm_alpha
1477
-
1478
- if self.moe is not None:
1479
- hidden_states, identity = self.moe(hidden_states)
1480
- ## add output of shared expert and output of small moe experts.
1481
- ## hidden state must be zero tensor (for first forward)
1482
- hidden_states += self.mlp(identity)
1483
- else:
1484
- hidden_states = self.mlp(hidden_states)
1485
-
1486
  hidden_states = residual + hidden_states
1487
 
1488
  outputs = (hidden_states, )
@@ -1532,45 +846,24 @@ class MotifPreTrainedModel(PreTrainedModel):
1532
  def _init_weights(self, module):
1533
  module_std = self.config.initializer_range
1534
  if isinstance(module, nn.Linear):
1535
- if getattr(module, "__do_scale_tager__", False):
1536
- module_std = module_std / self.config.init_scale_o
1537
-
1538
- if getattr(module, "__do_scale_tager_mu_o__", False):
1539
- if self.config.dim_model_base_init is not None:
1540
- module_std = module_std / math.sqrt(2*(self.config.hidden_size / self.config.dim_model_base_init)*self.config.num_hidden_layers)
1541
- else:
1542
- module_std = module_std
1543
- elif getattr(module, "__do_scale_tager_mu_ffn__", False):
1544
- if self.config.dim_model_base_init is not None:
1545
- module_std = module_std = module_std / math.sqrt(2*(self.config.hidden_size / self.config.dim_model_base_init)*self.config.num_hidden_layers)
1546
- else:
1547
- module_std = module_std
1548
- elif getattr(module, "__do_scale_tager_mu_dim_model__", False):
1549
- if self.config.dim_model_base_init is not None:
1550
- module_std = module_std / math.sqrt(self.config.hidden_size / self.config.dim_model_base_init)
1551
- else:
1552
- module_std = module_std
1553
- elif getattr(module, "__do_scale_tager_mu_dim_base_model__", False):
1554
- module_std = module_std / math.sqrt(self.config.dim_model_base_lmh) ### lmhead.. 1
1555
- else:
1556
- module_std = module_std
1557
-
1558
- torch.nn.init.trunc_normal_(module.weight.data, mean=0.0, std=module_std, a=-3*module_std, b=3*module_std)
1559
  if module.bias is not None:
1560
  module.bias.data.zero_()
1561
 
1562
  elif isinstance(module, nn.Embedding):
1563
- torch.nn.init.trunc_normal_(module.weight.data, mean=0.0, std=module_std, a=-3*module_std, b=3*module_std)
 
1564
  if module.padding_idx is not None:
1565
  module.weight.data[module.padding_idx].zero_()
1566
 
1567
 
1568
  @dataclass
1569
  class MotifModelOutputWithPast(ModelOutput):
1570
- """
1571
- This augments `BaseModelOutputWithPast` in `transformers.modeling_outputs` with new optional keys: `causal_mask`, `position_embeddings`.
1572
  The optional keys are currently used in the following ways:
1573
- - pass information to the token-wise last attention layers in multi-token training
1574
  """
1575
  last_hidden_state: torch.FloatTensor = None
1576
  past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@@ -1655,7 +948,6 @@ MOTIF_INPUTS_DOCSTRING = r"""
1655
  """
1656
 
1657
 
1658
- # @log_timing
1659
  @add_start_docstrings(
1660
  "The bare Motif Model outputting raw hidden-states without any specific head on top.",
1661
  MOTIF_START_DOCSTRING,
@@ -1672,23 +964,11 @@ class MotifModel(MotifPreTrainedModel):
1672
  super().__init__(config)
1673
  self.padding_idx = config.pad_token_id
1674
  self.vocab_size = config.vocab_size
1675
- self.multi_token_heads = config.multi_token_heads
1676
 
1677
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1678
- # NOTE: For multi-token models, the last decoder layers (one for each token index)
1679
- # are implemented as a part of `MotifModelForCausalLM` to enable a custom forward-backward procedure.
1680
-
1681
- num_hidden_layers = config.num_hidden_layers if self.multi_token_heads is None else config.num_hidden_layers - 1
1682
- if config.moe:
1683
- moe_layer = [True for i in range(num_hidden_layers)]
1684
- else:
1685
- moe_layer = [False for i in range(num_hidden_layers)]
1686
- logger.info(f'current_moe layer { moe_layer }')
1687
- self.layers = nn.ModuleList([MotifDecoderLayer(config = config, moe_layer= moe_layer[layer_idx],
1688
- layer_idx=layer_idx) for layer_idx in range(num_hidden_layers)])
1689
- self._attn_implementation = config._attn_implementation
1690
- RMSNorm = MorehRMSNorm if MorehRMSNorm is not None else MotifRMSNorm
1691
- self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1692
  self.hidden_size = config.hidden_size
1693
  self.num_heads = config.num_attention_heads
1694
  self.head_dim = self.hidden_size // self.num_heads
@@ -1701,36 +981,6 @@ class MotifModel(MotifPreTrainedModel):
1701
  self.gradient_checkpointing = False
1702
  self.post_init()
1703
 
1704
- self.use_pipeline = config.use_pipeline
1705
- if self.use_pipeline:
1706
- logger.info('use reinforced pp..')
1707
- if config.num_stages==2:
1708
- ### moe version
1709
- if config.decontam_attn:
1710
- self.split_layers = [15]
1711
- else:
1712
- if num_hidden_layers == 32:
1713
- self.split_layers = [15] # 14: 15,17 # 13: 14:18
1714
- else:
1715
- self.split_layers = [6]
1716
- elif config.num_stages==3:
1717
- self.split_layers = [9,20] ## 11, 11, 10
1718
- elif config.num_stages==4:
1719
- self.split_layers = [7,15,23] #7,9,9,7
1720
- elif config.num_stages==16:
1721
- self.split_layers = [1,3,5,7,9,11,13,15,17,19,21,23,25,27,29]
1722
- logger.info(f' check the split layers (moe): {self.split_layers}')
1723
-
1724
- self.scale_emb = 1
1725
-
1726
- # Reparameterization <|_1_|>
1727
- if config.wesar_weights :
1728
- logger.info(f'config.wesar_weights {config.wesar_weights}')
1729
- self.norm_alpha = nn.Parameter(torch.tensor(1).float())
1730
- self.scale_emb = 10
1731
- else:
1732
- self.norm_alpha = 1
1733
-
1734
  def get_input_embeddings(self):
1735
  return self.embed_tokens
1736
 
@@ -1769,7 +1019,6 @@ class MotifModel(MotifPreTrainedModel):
1769
  "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
1770
  use_cache = False
1771
 
1772
- # kept for BC (non `Cache` `past_key_values` inputs)
1773
  return_legacy_cache = False
1774
  if use_cache and not isinstance(past_key_values, Cache):
1775
  return_legacy_cache = True
@@ -1783,26 +1032,23 @@ class MotifModel(MotifPreTrainedModel):
1783
  "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)")
1784
 
1785
  if inputs_embeds is None:
1786
- inputs_embeds = self.embed_tokens(input_ids) * self.scale_emb
1787
 
1788
  if cache_position is None:
1789
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1790
  cache_position = torch.arange(past_seen_tokens,
1791
  past_seen_tokens + inputs_embeds.shape[1],
1792
  device=inputs_embeds.device)
1793
- position_ids = None
1794
  if position_ids is None:
1795
  position_ids = cache_position.unsqueeze(0)
1796
-
1797
  causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values,
1798
  output_attentions)
1799
 
1800
  hidden_states = inputs_embeds
1801
  bsz, q_len, _ = hidden_states.size()
1802
- # create position embeddings to be shared across the decoder layers
1803
  position_embeddings = self.rotary_emb(hidden_states, seq_len=q_len)
1804
 
1805
- # decoder layers
1806
  all_hidden_states = () if output_hidden_states else None
1807
  all_self_attns = () if output_attentions else None
1808
  next_decoder_cache = None
@@ -1837,20 +1083,14 @@ class MotifModel(MotifPreTrainedModel):
1837
 
1838
  hidden_states = layer_outputs[0]
1839
 
1840
-
1841
- if self.use_pipeline and idx in self.split_layers:
1842
- hidden_states = torch.moreh.pipeline_assign(hidden_states)
1843
-
1844
  if use_cache:
1845
  next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1846
 
1847
  if output_attentions:
1848
  all_self_attns += (layer_outputs[1], )
1849
 
1850
- # <|_2_|>
1851
- hidden_states = self.norm(hidden_states)* self.norm_alpha
1852
-
1853
- # add hidden states from the last decoder layer
1854
  if output_hidden_states:
1855
  all_hidden_states += (hidden_states, )
1856
 
@@ -1881,8 +1121,6 @@ class MotifModel(MotifPreTrainedModel):
1881
  output_attentions: bool,
1882
  ):
1883
  if self.config._attn_implementation == "flash_attention_2":
1884
- if MorehFlashAttention is not None:
1885
- return attention_mask
1886
  if attention_mask is not None and 0.0 in attention_mask:
1887
  return attention_mask
1888
  return None
@@ -1909,6 +1147,7 @@ class MotifModel(MotifPreTrainedModel):
1909
  dtype, device = input_tensor.dtype, input_tensor.device
1910
  min_dtype = torch.finfo(dtype).min
1911
  sequence_length = input_tensor.shape[1]
 
1912
  # SlidingWindowCache or StaticCache
1913
  if using_sliding_window_cache or using_static_cache:
1914
  target_length = past_key_values.get_max_cache_shape()
@@ -2003,7 +1242,6 @@ class MotifModel(MotifPreTrainedModel):
2003
  return causal_mask
2004
 
2005
 
2006
- # @log_timing
2007
  class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
2008
  _tied_weights_keys = ["lm_head.weight"]
2009
 
@@ -2011,35 +1249,14 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
2011
  super().__init__(config)
2012
  self.model = MotifModel(config)
2013
  self.vocab_size = config.vocab_size
2014
- self.multi_token_heads = config.multi_token_heads
2015
 
2016
- if self.multi_token_heads is None:
2017
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
2018
- else:
2019
- self.tokenwise_last_layers = nn.ModuleList(
2020
- [MotifDecoderLayer(config, config.num_hidden_layers - 1) for _ in range(self.multi_token_heads)])
2021
- self.tokenwise_lm_heads = nn.ModuleList(
2022
- [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(self.multi_token_heads)])
2023
- self.should_skip_separate_backward_pass = self.multi_token_heads is not None
2024
 
2025
  # Initialize weights and apply final processing
2026
  self.post_init()
2027
 
2028
- # <|_3_|>
2029
- if config.muP:
2030
- self.lm_head.__do_scale_tager_mu_dim_base_model__=True
2031
-
2032
- # <|_4_|>
2033
- self.lm_head_alpha = 1
2034
- if config.wesar_weights:
2035
- self.lm_head_alpha = nn.Parameter(torch.tensor(1).float())
2036
-
2037
  if getattr(config, "tie_word_embeddings", True):
2038
- logger.info('tie embeddings')
2039
  self.tie_weights()
2040
- else:
2041
- # <|_5_|>
2042
- self.lm_head.__do_scale_tager_mu_dim_base_model__ = False
2043
 
2044
  def get_input_embeddings(self):
2045
  return self.model.embed_tokens
@@ -2059,101 +1276,7 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
2059
  def get_decoder(self):
2060
  return self.model
2061
 
2062
- def multi_token_forward_backward(self,
2063
- hidden_states: torch.FloatTensor,
2064
- outputs: MotifModelOutputWithPast,
2065
- labels: torch.LongTensor,
2066
- position_ids: Optional[torch.LongTensor],
2067
- output_attentions: Optional[bool],
2068
- use_cache: Optional[bool],
2069
- cache_position: Optional[torch.LongTensor],
2070
- return_dict: Optional[bool],
2071
- num_logits_to_keep: int = 0) -> CausalLMOutputWithPast:
2072
- """
2073
- This implements the main forward-backward procedure for multi-token model training proposed in
2074
- the paper https://arxiv.org/abs/2404.19737.
2075
- Essentially,
2076
- - The multi-token model tries to predict n (instead of 1) tokens at a time.
2077
- - Applying this only during training and using first-token prediction during inference is still helpful.
2078
- - The change in architecture: when using n-token prediction, each token index (between 1 and n) has its own
2079
- (1) last attention layer and (2) lm head.
2080
- - The change in loss: sum of cross-entropy losses corresponding to each token index.
2081
- - Custom forward-backward procedure for memory efficiency: refer to the implementation of `multi_head_forward_backward`.
2082
- """
2083
- if not return_dict:
2084
- raise NotImplementedError("return_dict must be True for multi-token training")
2085
-
2086
- past_key_values = outputs.past_key_values
2087
- causal_mask = outputs.causal_mask
2088
- position_embeddings = outputs.position_embeddings
2089
-
2090
- if labels is not None:
2091
- labels = labels.to(hidden_states.device)
2092
-
2093
- def _tokenwise_forward(hidden_states: torch.Tensor, token_idx):
2094
- ## Model forward
2095
- layer = self.tokenwise_last_layers[token_idx]
2096
- lm_head = self.tokenwise_lm_heads[token_idx]
2097
-
2098
- layer_outputs = layer(
2099
- hidden_states,
2100
- attention_mask=causal_mask,
2101
- position_ids=position_ids,
2102
- past_key_values=past_key_values, # TODO: update past_key_values?
2103
- output_attentions=output_attentions,
2104
- use_cache=use_cache,
2105
- cache_position=cache_position,
2106
- position_embeddings=position_embeddings,
2107
- )
2108
- last_hidden_states = layer_outputs[0]
2109
- if num_logits_to_keep > 0:
2110
- assert labels is None
2111
- last_hidden_states = last_hidden_states[:, -num_logits_to_keep:, :]
2112
- tokenwise_logits = lm_head(last_hidden_states)
2113
-
2114
- if labels is None:
2115
- return {
2116
- "loss": None,
2117
- "logits": tokenwise_logits,
2118
- }
2119
-
2120
- ## Compute loss
2121
- shift_n = token_idx + 1
2122
- shift_logits = tokenwise_logits[..., :-shift_n, :].contiguous()
2123
- shift_labels = labels[..., shift_n:].contiguous()
2124
-
2125
- loss_fct = CrossEntropyLoss()
2126
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
2127
- shift_labels = shift_labels.view(-1)
2128
-
2129
- tokenwise_loss = loss_fct(shift_logits, shift_labels)
2130
-
2131
- return {
2132
- "loss": tokenwise_loss,
2133
- "logits": tokenwise_logits,
2134
- }
2135
-
2136
- head_fns = [
2137
- lambda hidden_states, token_idx=token_idx: _tokenwise_forward(hidden_states, token_idx)
2138
- for token_idx in range(self.multi_token_heads)
2139
- ]
2140
- loss, logits = multi_head_forward_backward(hidden_states,
2141
- head_fns,
2142
- return_keys=("loss", "logits"),
2143
- return_only_first_head=True)
2144
-
2145
- if not return_dict:
2146
- output = (logits, ) + outputs[1:]
2147
- return (loss, ) + output
2148
-
2149
- return CausalLMOutputWithPast(
2150
- loss=loss,
2151
- logits=logits,
2152
- past_key_values=outputs.past_key_values,
2153
- hidden_states=outputs.hidden_states,
2154
- attentions=outputs.attentions,
2155
- )
2156
-
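For reference, a minimal sketch (illustrative only, not the removed implementation) of the n-token objective described in the docstring above: head i predicts the token (i + 1) steps ahead and the per-head cross-entropy losses are summed. The removed code additionally runs a custom per-head forward-backward pass for memory efficiency, which this sketch omits.

import torch
import torch.nn.functional as F

def multi_token_loss_sketch(tokenwise_logits, labels):
    # tokenwise_logits: one (batch, seq, vocab) tensor per prediction head
    # labels: (batch, seq) token ids
    total = 0.0
    for token_idx, logits in enumerate(tokenwise_logits):
        shift_n = token_idx + 1                                   # head i looks (i + 1) tokens ahead
        shift_logits = logits[:, :-shift_n, :].reshape(-1, logits.size(-1))
        shift_labels = labels[:, shift_n:].reshape(-1)
        total = total + F.cross_entropy(shift_logits, shift_labels)
    return total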
2157
  @add_start_docstrings_to_model_forward(MOTIF_INPUTS_DOCSTRING)
2158
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
2159
  def forward(
@@ -2209,8 +1332,6 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
2209
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
2210
 
2211
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
2212
- outputs_include_causal_mask = self.multi_token_heads is not None
2213
- outputs_include_position_embeddings = self.multi_token_heads is not None
2214
  outputs: MotifModelOutputWithPast = self.model(
2215
  input_ids=input_ids,
2216
  attention_mask=attention_mask,
@@ -2222,25 +1343,12 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
2222
  output_hidden_states=output_hidden_states,
2223
  return_dict=return_dict,
2224
  cache_position=cache_position,
2225
- outputs_include_causal_mask=outputs_include_causal_mask,
2226
- outputs_include_position_embeddings=outputs_include_position_embeddings,
2227
  )
2228
 
2229
  hidden_states = outputs[0]
2230
 
2231
- if self.multi_token_heads is not None:
2232
- return self.multi_token_forward_backward(hidden_states,
2233
- outputs,
2234
- labels,
2235
- position_ids,
2236
- output_attentions,
2237
- use_cache,
2238
- cache_position,
2239
- return_dict,
2240
- num_logits_to_keep=num_logits_to_keep)
2241
-
2242
  # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
2243
- hidden_states = hidden_states * self.lm_head_alpha
2244
  logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
2245
  logits = logits.float()
2246
 
@@ -2254,7 +1362,6 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
2254
  loss_fct = CrossEntropyLoss()
2255
  shift_logits = shift_logits.view(-1, self.config.vocab_size)
2256
  shift_labels = shift_labels.view(-1)
2257
- # Enable model parallelism
2258
  shift_labels = shift_labels.to(shift_logits.device)
2259
  loss = loss_fct(shift_logits, shift_labels)
2260
 
@@ -2268,4 +1375,4 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
2268
  past_key_values=outputs.past_key_values,
2269
  hidden_states=outputs.hidden_states,
2270
  attentions=outputs.attentions,
2271
- )
 
1
  import math
2
+ from dataclasses import dataclass
3
  from typing import List, Optional, Tuple, Union
4
 
5
  import torch
6
+ import torch.nn.functional as F
7
  import torch.utils.checkpoint
8
  from torch import nn
9
  from torch.nn import CrossEntropyLoss
10
+ from transformers.activations import ACT2CLS as _ACT2CLS
11
+ from transformers.activations import ClassInstantier
12
  from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
13
  from transformers.generation import GenerationMixin
14
  from transformers.modeling_attn_mask_utils import AttentionMaskConverter
15
  from transformers.modeling_flash_attention_utils import _flash_attention_forward
16
+ from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
 
 
 
17
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
18
  from transformers.modeling_utils import PreTrainedModel
19
  from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
20
+ from transformers.utils import (add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available,
21
+ is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings)
22
 
23
+ from .configuration_motif import MotifConfig
24
 
25
 
26
  class PolyNorm(torch.nn.Module):
27
+ """
28
  A trainable activation function introduced in https://arxiv.org/html/2411.03884v1.
29
+ The code is copied from https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md
 
30
  """
31
 
32
  def __init__(self, eps=1e-6):
 
43
  x ** 2) + self.weight[2] * self._norm(x) + self.bias
44
 
45
 
46
+ CUSTOM_ACT2CLS = {"poly_norm": PolyNorm}
47
  ACT2CLS = {**_ACT2CLS, **CUSTOM_ACT2CLS}
48
  ACT2FN = ClassInstantier(ACT2CLS)
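Brief usage sketch of the registry extension above (names as registered in this file): once "poly_norm" is merged into ACT2FN, it can be selected by name from the config like any stock activation, and ClassInstantier instantiates it on lookup.

act = ACT2FN["poly_norm"]          # returns a PolyNorm instance
y = act(torch.randn(2, 8, 16))     # shape-preserving activation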
49
 
50
+ logger = logging.get_logger(__name__)
51
+
52
+ if is_flash_attn_2_available():
53
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
54
+
55
+ _CONFIG_FOR_DOC = "MotifConfig"
56
 
57
 
58
  class MotifRMSNorm(nn.Module):
 
76
  return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
77
 
78
 
79
+ ALL_LAYERNORM_LAYERS.append(MotifRMSNorm)
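The MotifRMSNorm body is elided in this hunk. For reference, a generic RMSNorm forward that the class is assumed to follow (root-mean-square normalization in float32, then a learned elementwise scale):

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    hidden = x.to(torch.float32)
    variance = hidden.pow(2).mean(-1, keepdim=True)   # mean of squares over the last dim
    hidden = hidden * torch.rsqrt(variance + eps)
    return (weight * hidden).to(x.dtype)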
80
 
81
 
82
  class MotifRotaryEmbeddingWithCache(nn.Module):
 
108
  inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
109
  self.register_buffer("inv_freq", inv_freq, persistent=False)
110
 
 
111
  self._set_cos_sin_cache(seq_len=max_position_embeddings,
112
  device=self.inv_freq.device,
113
  dtype=torch.get_default_dtype())
 
128
  self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
129
 
130
  return (
131
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
132
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
133
  )
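The body of `_set_cos_sin_cache` is not shown in this hunk; the cached tables are assumed to follow the usual RoPE construction, an outer product of positions with `inv_freq`, duplicated to the full rotary dimension. A sketch:

def build_rope_cache_sketch(dim, seq_len, base=10000.0, dtype=torch.float32):
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(seq_len, dtype=inv_freq.dtype)
    freqs = torch.outer(t, inv_freq)                 # (seq_len, dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)          # (seq_len, dim)
    return emb.cos().to(dtype), emb.sin().to(dtype)  # cos_cached, sin_cached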
134
 
135
 
 
136
  class MotifRotaryEmbedding(nn.Module):
137
 
138
  def __init__(
 
163
  self.max_seq_len_cached = max_position_embeddings
164
  self.original_max_seq_len = max_position_embeddings
165
  else:
 
166
  if config.rope_scaling is not None:
167
  self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
168
  else:
 
224
  def rotate_half(x):
225
  """
226
  Rotates half of the dimensions of the input tensor using torch.roll and in-place negation.
227
+
228
  Args:
229
  x (torch.Tensor): The input tensor.
230
+
231
  Returns:
232
  torch.Tensor: A tensor where the latter half of the dimensions are negated
233
  and moved before the first half.
 
239
  return rotated_tensor
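The function body is elided here apart from its return statement; a sketch consistent with the docstring above (negate the latter half of the last dimension, then roll it in front of the first half; done out of place here for clarity):

def rotate_half_sketch(x: torch.Tensor) -> torch.Tensor:
    half = x.shape[-1] // 2
    x = torch.cat((x[..., :half], -x[..., half:]), dim=-1)   # negate the latter half
    return torch.roll(x, shifts=half, dims=-1)               # move it before the first half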
240
 
241
 
242
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
 
243
  """
244
  Applies rotary position embeddings to the input tensors.
245
 
 
248
  k (torch.Tensor): Key tensor of shape (B, NH, S, D_KV).
249
  cos (torch.Tensor): Cosine values for rotary embedding.
250
  sin (torch.Tensor): Sine values for rotary embedding.
251
+ unsqueeze_dim (int, optional): Dimension along which `cos` and `sin` are unsqueezed.
252
  Defaults to 1.
 
 
 
253
 
254
  Returns:
255
  Tuple[torch.Tensor, torch.Tensor]: Returns transformed query and key tensors after applying rotary embeddings.
256
  """
257
  '''
258
+ # (B, NH, S, D_KV) -> (B, S, NH, D_KV)
259
  cos = cos.unsqueeze(unsqueeze_dim)
260
  sin = sin.unsqueeze(unsqueeze_dim)
261
  q_embed = (q * cos) + (rotate_half(q) * sin)
262
  k_embed = (k * cos) + (rotate_half(k) * sin)
263
  '''
264
+ device = q.device
265
+ return map(
266
+ lambda x: (x * cos[position_ids].unsqueeze(unsqueeze_dim).to(device)) +
267
+ (rotate_half(x) * sin[position_ids].unsqueeze(unsqueeze_dim).to(device)), (q, k))
 
270
  class MotifMLP(nn.Module):
271
 
272
  def __init__(self, config):
273
  super().__init__()
274
  self.hidden_size = config.hidden_size
275
  self.intermediate_size = config.intermediate_size
276
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias)
277
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias)
278
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
279
  self.act_fn = ACT2FN[config.hidden_act]
280
281
  def forward(self, hidden_state):
282
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
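The block above is the usual gated (SwiGLU-style) MLP: `gate_proj` and `up_proj` expand to `intermediate_size`, the activated gate multiplies the up projection elementwise, and `down_proj` maps back to `hidden_size`. A toy usage sketch with a hypothetical minimal config:

class _TinyCfg:                                   # hypothetical config, for illustration only
    hidden_size, intermediate_size = 64, 256
    use_bias, hidden_act = False, "silu"

mlp = MotifMLP(_TinyCfg())
out = mlp(torch.randn(2, 10, 64))                 # -> (2, 10, 64)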
283
 
284
 
285
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
286
+
287
+
288
  """
289
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
290
  num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
291
+
292
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
293
  if n_rep == 1:
294
  return hidden_states
 
297
  """
298
 
299
  return torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)
300
+
301
 
 
 
302
  class MotifAttention(nn.Module):
303
  """
304
  Differential Attention (DiffAttention) module.
305
 
306
+ Implements the Differential Attention from
307
  "DIFFERENTIAL TRANSFORMER" (https://arxiv.org/pdf/2410.05258).
308
 
309
  Overview
310
  Standard transformers often over-allocate attention to irrelevant context.
311
+ DiffAttention addresses this by computing attention as the difference between
312
+ two separate softmax attention maps, effectively canceling noise and promoting
313
  sparse, structured attention patterns.
314
 
315
  Reference Implementation
316
  https://github.com/microsoft/unilm/tree/master/Diff-Transformer
317
 
318
  Args
319
+ The differential attention mechanism computes attention as the difference of two softmax attention scores, weighted by a learnable scalar λ.
320
  λ is re-parameterized as λ = exp(λ_q1 · λ_k1) − exp(λ_q2 · λ_k2) + λ_init.
321
  - lambda_q1, lambda_q2 (nn.Parameter): Learnable vectors used to compute the first and second components of λ for query transformations.
322
  - lambda_k1, lambda_k2 (nn.Parameter): Learnable vectors used to compute the first and second components of λ for key transformations.
323
  - lambda_init (float): A constant used for initializing λ, typically set as λ_init = 0.8 − 0.6 × exp(−0.3 × (layer_index − 1)).
324
+
325
  """
326
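A numerical sketch of the λ reparameterization described above (toy values): λ is a per-layer scalar built from four learnable `head_dim`-sized vectors plus the depth-dependent `lambda_init`, and the final map is the difference of the two softmax maps weighted by it.

head_dim, layer_idx = 64, 3                       # assumed toy values
vecs = {name: torch.zeros(head_dim).normal_(0.0, 0.1)
        for name in ("lambda_q1", "lambda_k1", "lambda_q2", "lambda_k2")}
lambda_init = 0.8 - 0.6 * math.exp(-0.3 * (layer_idx - 1))
lambda_full = (torch.exp((vecs["lambda_q1"] * vecs["lambda_k1"]).sum())
               - torch.exp((vecs["lambda_q2"] * vecs["lambda_k2"]).sum())
               + lambda_init)
# attention is then softmax_map_1 - lambda_full * softmax_map_2 (see the forward pass below)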
 
327
  def __init__(self, config: MotifConfig, layer_idx: Optional[int] = None):
 
344
  self.rope_theta = config.rope_theta
345
  self.is_causal = True
346
  self.attention_dropout = config.attention_dropout
347
+
348
  if (self.head_dim * self.num_heads) != self.hidden_size:
349
  raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
350
  f" and `num_heads`: {self.num_heads}).")
 
353
  self.num_key_value_heads //= 2
354
  self.n_rep = self.num_heads // self.num_key_value_heads
355
 
356
  self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
357
  self.k_proj = nn.Linear(self.hidden_size, self.hidden_size // self.n_rep, bias=False)
358
  self.v_proj = nn.Linear(self.hidden_size, self.hidden_size // self.n_rep, bias=False)
359
  self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
360
 
 
361
  for name in ["lambda_q1", "lambda_k1", "lambda_q2", "lambda_k2"]:
362
  setattr(self, name, nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32)))
363
  getattr(self, name).data.normal_(mean=0.0, std=0.1)
364
 
 
365
  self.subln = MotifRMSNorm(2 * self.head_dim, eps=1e-5)
366
  self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * (layer_idx - 1))
367
 
368
+ self.rotary_emb = MotifRotaryEmbeddingWithCache(self.head_dim,
369
  max_position_embeddings=self.max_position_embeddings,
370
  base=self.rope_theta)
371
372
  def forward(
373
  self,
374
  hidden_states: torch.Tensor,
 
378
  output_attentions: bool = False,
379
  use_cache: bool = False,
380
  cache_position: Optional[torch.LongTensor] = None,
381
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
382
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
383
  bsz, q_len, _ = hidden_states.size()
384
 
385
+ query_states = self.q_proj(hidden_states)
386
+ key_states = self.k_proj(hidden_states)
387
+ value_states = self.v_proj(hidden_states)
 
 
388
 
389
  query_states = query_states.view(bsz, q_len, 2 * self.num_heads, self.head_dim).transpose(1, 2)
390
  key_states = key_states.view(bsz, q_len, 2 * self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
406
  key_states,
407
  cos,
408
  sin,
409
+ position_ids=position_ids)
410
 
411
  if past_key_value is not None:
412
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
413
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
414
 
 
415
  key_states = repeat_kv(key_states, self.num_key_value_groups)
416
  value_states = repeat_kv(value_states, self.num_key_value_groups)
417
 
 
418
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
419
 
420
  kv_seq_len = key_states.shape[-2]
 
423
  attention_mask = torch.triu(
424
  torch.full((q_len, kv_seq_len), float("-inf"), dtype=attn_weights.dtype, device=attn_weights.device),
425
  1 + offset)
426
+
427
  attn_weights = attn_weights + attention_mask
428
 
 
429
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
430
  attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
431
 
 
432
  lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(attn_weights)
433
  lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(attn_weights)
434
  lambda_full = lambda_1 - lambda_2 + self.lambda_init
435
  attn_weights = attn_weights.view(bsz, self.num_heads, 2, q_len, -1)
436
  attn_weights = attn_weights[:, :, 0] - lambda_full * attn_weights[:, :, 1]
437
 
 
438
  attn_output = torch.matmul(attn_weights, value_states)
439
 
 
440
  attn_output = self.subln(attn_output)
441
  attn_output = attn_output * (1 - self.lambda_init)
442
 
443
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim * 2):
444
  raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim * 2)}, but is"
445
  f" {attn_output.size()}")
446
+
 
 
447
  attn_output = attn_output.transpose(1, 2).contiguous()
448
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
449
 
450
+ attn_output = self.o_proj(attn_output)
451
 
452
  if not output_attentions:
453
  attn_weights = None
 
455
  return attn_output, attn_weights, past_key_value
456
 
457
 
 
458
  class MotifFlashAttention2(MotifAttention):
459
  """
460
  Motif flash attention module, following Motif attention module. This module inherits from `MotifAttention`
 
464
  config.max_window_layers layers.
465
  """
466
 
 
467
  def __init__(self, *args, **kwargs):
468
  super().__init__(*args, **kwargs)
 
 
 
469
  # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
470
  # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
471
  # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
472
 
473
  self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
474
 
475
+ logger.info(f'Flash Attention 2 enabled (bottom-right aligned causal mask: {not self._flash_attn_uses_top_left_mask})')
476
+
477
  def _reshape_heads(self, tensor, batch_size, seq_len):
478
  """2-way head split tensor reshape"""
479
  return tensor.reshape(batch_size, seq_len, self.num_heads, 2, self.head_dim)
 
483
  return tensor.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
484
 
485
  def _compute_attention(self, query_states, key_states, value_states, attention_mask, q_len, position_ids,
486
+ dropout_rate, sliding_window):
487
  """Flash Attention 2 implementation."""
488
+ _input_type = query_states.dtype
489
+ scale_factor = 1.0 / math.sqrt(self.head_dim)
490
+ if not self._flash_attn_uses_top_left_mask:
491
+ causal = self.is_causal
492
  else:
493
+ causal = self.is_causal and q_len != 1
494
+
495
+ attn_out = _flash_attention_forward(query_states.bfloat16(),
496
+ key_states.bfloat16(),
497
+ value_states.bfloat16(),
498
  attention_mask,
499
  q_len,
500
  position_ids=position_ids,
501
  dropout=dropout_rate,
502
  sliding_window=sliding_window,
503
+ is_causal=True,
504
+ softmax_scale=scale_factor,
505
  use_top_left_mask=self._flash_attn_uses_top_left_mask)
506
+ return attn_out.to(_input_type)
507
 
508
  def forward(
509
  self,
 
518
  ):
519
  bsz, q_len, _ = hidden_states.size()
520
 
521
+ query_states = self.q_proj(hidden_states)
522
+ key_states = self.k_proj(hidden_states)
523
+ value_states = self.v_proj(hidden_states)
524
 
525
  query_states = query_states.view(bsz, q_len, 2 * self.num_heads, self.head_dim).transpose(1, 2)
526
  key_states = key_states.view(bsz, q_len, 2 * self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
541
  key_states,
542
  cos,
543
  sin,
544
+ position_ids=position_ids)
545
 
546
  if past_key_value is not None:
547
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
548
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
549
 
 
550
  key_states = repeat_kv(key_states, self.num_key_value_groups)
551
  value_states = repeat_kv(value_states, self.num_key_value_groups)
552
  dropout_rate = 0.0 if not self.training else self.attention_dropout
 
555
  # therefore the input hidden states gets silently casted in float32. Hence, we need
556
  # cast them back in float16 just to be sure everything works as expected.
557
  input_dtype = query_states.dtype
558
+ if input_dtype == torch.float32:
559
  if torch.is_autocast_enabled():
560
  target_dtype = torch.get_autocast_gpu_dtype()
561
  # Handle the case where the model is quantized
 
582
  value_states = value_states.transpose(1, 2)
583
 
584
  if (self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None
585
+ and self.layer_idx >= self.config.max_window_layers):
586
  sliding_window = self.config.sliding_window
587
  else:
588
  sliding_window = None
 
602
  k1, k2 = k1.contiguous(), k2.contiguous()
603
  v1, v2 = v1.contiguous(), v2.contiguous()
604
 
605
+ attn11, attn12 = self._compute_attention(q1, k1, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window), \
606
+ self._compute_attention(q1, k1, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window)
607
+ attn21, attn22 = self._compute_attention(q2, k2, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window), \
608
+ self._compute_attention(q2, k2, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window)
 
 
609
 
610
  attn1, attn2 = torch.cat([attn11, attn12], dim=-1), torch.cat([attn21, attn22], dim=-1)
611
 
 
623
  attn_output = attn_output * (1 - self.lambda_init)
624
 
625
  if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim * 2):
626
+ raise ValueError(f"`attn_output` should be of size {(bsz, q_len, self.num_heads, 2*self.head_dim)}, but is"
627
  f" {attn_output.size()}")
628
 
629
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
630
+ attn_output = self.o_proj(attn_output)
631
 
632
  return attn_output, None, past_key_value
633
 
634
 
 
635
  class MotifSdpaAttention(MotifAttention):
636
  """
637
  Motif attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
 
639
  SDPA API.
640
  """
641
 
 
642
  def forward(
643
  self,
644
  hidden_states: torch.Tensor,
 
651
  position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
652
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
653
  if output_attentions:
 
654
  logger.warning_once(
655
  "MotifModel is using MotifSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
656
  'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
 
686
  query_states, key_states = apply_rotary_pos_emb(query_states,
687
  key_states,
688
  cos,
689
+ sin)
 
690
 
691
  if past_key_value is not None:
692
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
 
722
  MOTIF_ATTENTION_CLASSES = {
723
  "eager": MotifAttention,
724
  "flash_attention_2": MotifFlashAttention2,
725
+ "sdpa": MotifAttention,
726
  }
727
 
728
 
 
729
  class MotifDecoderLayer(nn.Module):
730
 
731
+ def __init__(self, config: MotifConfig, layer_idx: int):
732
  super().__init__()
733
  self.hidden_size = config.hidden_size
 
 
734
  if config.sliding_window and config._attn_implementation != "flash_attention_2":
735
  logger.warning_once(
736
  f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
737
  "unexpected results may be encountered.")
738
+ self.self_attn = MOTIF_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
 
 
 
739
  self.mlp = MotifMLP(config)
740
+
741
+ self.input_layernorm = MotifRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
742
+ self.post_attention_layernorm = MotifRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
744
 
745
  def forward(
746
  self,
 
778
 
779
  residual = hidden_states
780
 
781
+ hidden_states = self.input_layernorm(hidden_states)
782
 
783
  # Self Attention
784
  hidden_states, self_attn_weights, present_key_value = self.self_attn(
 
795
 
796
  # Fully Connected
797
  residual = hidden_states
798
+ hidden_states = self.post_attention_layernorm(hidden_states)
799
+ hidden_states = self.mlp(hidden_states)
800
  hidden_states = residual + hidden_states
801
 
802
  outputs = (hidden_states, )
 
846
  def _init_weights(self, module):
847
  module_std = self.config.initializer_range
848
  if isinstance(module, nn.Linear):
849
+ module.weight.data.normal_(mean=0.0, std=module_std)
850
+ module.weight.data = torch.where(abs(module.weight.data) > module_std*3, 0, module.weight.data)
851
  if module.bias is not None:
852
  module.bias.data.zero_()
853
 
854
  elif isinstance(module, nn.Embedding):
855
+ module.weight.data.normal_(mean=0.0, std=module_std)
856
+ module.weight.data = torch.where(abs(module.weight.data) > module_std*3, 0, module.weight.data)
857
  if module.padding_idx is not None:
858
  module.weight.data[module.padding_idx].zero_()
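A sketch of what the initializer above does with a toy tensor: weights are drawn from N(0, std) and any entry beyond 3·std is set to zero (zeroed rather than re-sampled, unlike `nn.init.trunc_normal_`).

std = 0.02                                        # assumed initializer_range
w = torch.empty(256, 256).normal_(mean=0.0, std=std)
w = torch.where(w.abs() > 3 * std, torch.zeros_like(w), w)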
859
 
860
 
861
  @dataclass
862
  class MotifModelOutputWithPast(ModelOutput):
863
+ """
864
+ This augments `BaseModelOutputWithPast` in `transformers.modeling_outputs` with new optional keys: `causal_mask`, `position_embeddings`.
865
  The optional keys are currently used in the following ways:
866
+ - pass information to the token-wise last attention layers in multi-token training
867
  """
868
  last_hidden_state: torch.FloatTensor = None
869
  past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
 
948
  """
949
 
950
 
 
951
  @add_start_docstrings(
952
  "The bare Motif Model outputting raw hidden-states without any specific head on top.",
953
  MOTIF_START_DOCSTRING,
 
964
  super().__init__(config)
965
  self.padding_idx = config.pad_token_id
966
  self.vocab_size = config.vocab_size
 
967
 
968
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
969
+ num_hidden_layers = config.num_hidden_layers
970
+ self.layers = nn.ModuleList([MotifDecoderLayer(config = config, layer_idx=layer_idx) for layer_idx in range(num_hidden_layers)])
971
+ self.norm = MotifRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
972
  self.hidden_size = config.hidden_size
973
  self.num_heads = config.num_attention_heads
974
  self.head_dim = self.hidden_size // self.num_heads
 
981
  self.gradient_checkpointing = False
982
  self.post_init()
983
984
  def get_input_embeddings(self):
985
  return self.embed_tokens
986
 
 
1019
  "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
1020
  use_cache = False
1021
 
 
1022
  return_legacy_cache = False
1023
  if use_cache and not isinstance(past_key_values, Cache):
1024
  return_legacy_cache = True
 
1032
  "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)")
1033
 
1034
  if inputs_embeds is None:
1035
+ inputs_embeds = self.embed_tokens(input_ids)
1036
 
1037
  if cache_position is None:
1038
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1039
  cache_position = torch.arange(past_seen_tokens,
1040
  past_seen_tokens + inputs_embeds.shape[1],
1041
  device=inputs_embeds.device)
 
1042
  if position_ids is None:
1043
  position_ids = cache_position.unsqueeze(0)
1044
+
1045
  causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values,
1046
  output_attentions)
1047
 
1048
  hidden_states = inputs_embeds
1049
  bsz, q_len, _ = hidden_states.size()
 
1050
  position_embeddings = self.rotary_emb(hidden_states, seq_len=q_len)
1051
 
 
1052
  all_hidden_states = () if output_hidden_states else None
1053
  all_self_attns = () if output_attentions else None
1054
  next_decoder_cache = None
 
1083
 
1084
  hidden_states = layer_outputs[0]
1085
1086
  if use_cache:
1087
  next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1088
 
1089
  if output_attentions:
1090
  all_self_attns += (layer_outputs[1], )
1091
 
1092
+ hidden_states = self.norm(hidden_states)
1093
+
 
 
1094
  if output_hidden_states:
1095
  all_hidden_states += (hidden_states, )
1096
 
 
1121
  output_attentions: bool,
1122
  ):
1123
  if self.config._attn_implementation == "flash_attention_2":
 
 
1124
  if attention_mask is not None and 0.0 in attention_mask:
1125
  return attention_mask
1126
  return None
 
1147
  dtype, device = input_tensor.dtype, input_tensor.device
1148
  min_dtype = torch.finfo(dtype).min
1149
  sequence_length = input_tensor.shape[1]
1150
+
1151
  # SlidingWindowCache or StaticCache
1152
  if using_sliding_window_cache or using_static_cache:
1153
  target_length = past_key_values.get_max_cache_shape()
 
1242
  return causal_mask
1243
 
1244
 
 
1245
  class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
1246
  _tied_weights_keys = ["lm_head.weight"]
1247
 
 
1249
  super().__init__(config)
1250
  self.model = MotifModel(config)
1251
  self.vocab_size = config.vocab_size
 
1252
 
1253
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1254
 
1255
  # Initialize weights and apply final processing
1256
  self.post_init()
1257
 
1258
  if getattr(config, "tie_word_embeddings", True):
 
1259
  self.tie_weights()
 
 
 
1260
 
1261
  def get_input_embeddings(self):
1262
  return self.model.embed_tokens
 
1276
  def get_decoder(self):
1277
  return self.model
1278
 
1279
+
1280
  @add_start_docstrings_to_model_forward(MOTIF_INPUTS_DOCSTRING)
1281
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1282
  def forward(
 
1332
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1333
 
1334
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
 
 
1335
  outputs: MotifModelOutputWithPast = self.model(
1336
  input_ids=input_ids,
1337
  attention_mask=attention_mask,
 
1343
  output_hidden_states=output_hidden_states,
1344
  return_dict=return_dict,
1345
  cache_position=cache_position,
 
 
1346
  )
1347
 
1348
  hidden_states = outputs[0]
1349
 
1350
  # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1351
+ hidden_states = hidden_states
1352
  logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
1353
  logits = logits.float()
1354
 
 
1362
  loss_fct = CrossEntropyLoss()
1363
  shift_logits = shift_logits.view(-1, self.config.vocab_size)
1364
  shift_labels = shift_labels.view(-1)
 
1365
  shift_labels = shift_labels.to(shift_logits.device)
1366
  loss = loss_fct(shift_logits, shift_labels)
1367
 
 
1375
  past_key_values=outputs.past_key_values,
1376
  hidden_states=outputs.hidden_states,
1377
  attentions=outputs.attentions,
1378
+ )