Update modeling_motif.py (#1)
- Update modeling_motif.py (f362538a86e469b8b3e07d66dd1d2f23aa4ef81b)
- bugfix (91c40cee07724f4fa12fec66cafa1a2546fd2dbd)
- Update modeling_motif.py (a55dcfd62edeef5966ab77bd04f3622b24cb4175)
- Update modeling_motif.py (2a76ec89b014032a79e5b67a9af06cd960efb51b)
- Update config.json (c28133974bedd8beb8e4d50db3c466e58798d139)
- Update modeling_motif.py (6230f7725953151d5c33872ba28180d75e9ecc7c)
- Update modeling_motif.py (db404ffc2c810b0325e72b181b96bb67535ec217)
- Update config.json (0d851ca4ecccf20f0e1368c2ff83c6e802d23e6d)
- Update modeling_motif.py (607612fcbf1067b65f29c92a4326ea261a97e0e1)
- Update modeling_motif.py (9b405398aa85b7621786d0793cd05f1343e661d9)
- Update modeling_motif.py (bdd0329e7e0cf8c3e2095f665244351a7e1e4c55)
- Update modeling_motif.py (38eae03f20a4d46b13ef774fab49a4043b2a5697)
- Update modeling_motif.py (8855d030415e3f65179c8aad8b3733a6d0ad644d)
- Update modeling_motif.py (20b97f12bd6cfbd1115d7c8d2e3ba19a27d55471)
- Update modeling_motif.py (72cc86d5a3b4d20faf06536fa16a841811c7b475)
- Update modeling_motif.py (80e1a1c498842e2d58fd8e7abf320753eb758ddc)
- Update config.json (9be8a4a64fa7d9cdcb7ec479918fa03034efde8a)
- Update modeling_motif.py (6d0fba540a217fb66acc35aca103c1981aa79ef0)
- Update modeling_motif.py (bd7180cf47aa477313e74d5b739105f2700bdddf)
- Update modeling_motif.py (5ac76cbcfa8efa114a15d46288178215c919dd7f)
- Update modeling_motif.py (5bd1ac1f84c476bf5b1a4fe7535db793809906de)
- Update modeling_motif.py (a4c14b0328abf4e3ae6b0213d777f75e0f3432e7)
- Update config.json (22c1361731f8565e0776c7aee2a0aa460657aa09)
- Update modeling_motif.py (7ebd625f70f9cce13df77f0520364d6f3b26994d)
- Update modeling_motif.py (d48c60537b21d0c3757b2292b47f6b1055e7ef1b)
- Update config.json (756b7c09f36b5a5bb689e8efbd894fd20c5699e5)
- Update config.json (8f9e2731310565ac33e8f6960b19c56be99698d2)
- Update config.json (4321f887f6e1fa1161e114d91ccacb2b8361e750)
- Update config.json (da650a07d5165dde18b47a5b07f5f8de47365f79)
- Update config.json (76d2431ef87346a08c6187122cd60964c3d035db)
- Update modeling_motif.py (cce6844530ccb252625158c7e4bfd5d603f0e475)
- Update modeling_motif.py (c4b7f5e79b10d819c7ced148c4aa2340bb4b82c8)
- Update config.json (498f8bc0c383405c33be3a06a1182ada10ff212d)
- Update modeling_motif.py (f76fc654662109d4549c8ed558fc641137e8a91d)
- Update config.json (8dbdeff6683e4e6555c9c4cf703eee8473de8059)
- Update config.json (1d63261c89ce9d31071d93ac43201d94937fe089)
- Update modeling_motif.py (cfa11bc205bf18888fe7f8dff1ecdce292a615a8)
- Update config.json (14c71d590d1b610bca1480c9066fdd9e18cdfdcb)
- Update config.json (a95fee7e4e0105080cf758bec90d2f6a16205702)
- Update modeling_motif.py (83d1f84e9a8c3bde8109b2a66e38d06f04972b39)
- Update config.json (37ed6924ae8c26c9fe130435bb1c23399dc9a51d)
- Update config.json (eb7a67bdbdc31bcc931a7eacc44f4a4a07751b17)
- Update modeling_motif.py (27913c020c6192fcde29c9e18165a9db8280bc8d)
- Update modeling_motif.py (ba7c5761e3eba518b64edf32d46422fec2175554)
- Update modeling_motif.py (7ec81c6e7fea3393085d6d9d94017bd4d8cbbfd6)
- Update configuration_motif.py (7d2479ad9374145522214574abfda2d012c6d5a4)
- Update modeling_motif.py (d56ef752fd1db815e87b88e6bf4791f558d6363b)
- Update modeling_motif.py (6187f7768db6406e4ba951b99281eb2274c92130)
- Update modeling_motif.py (a162402203d84f6c89c8159070e1457f5613005c)
- Update modeling_motif.py (0ff3917f6694c08e6768a2940b4ba8f508b209b8)
- Update modeling_motif.py (4293a010c114394bff10c64a621d22937df61f46)
- Update modeling_motif.py (7d405b99a275c6241c34279f548dc9a7680c6906)
- Update modeling_motif.py (03a80ebc5d1df4a3f81667d6a2d54608957cd57e)
- Update config.json (c8cabde5806c7009379fb69d1adfe066d2dab70f)
- Update modeling_motif.py (8bdf2ec91a74f9cfda0557ad56999915cb7ac9d1)
- Update modeling_motif.py (097873e443d8e170c89992527e03615d15a71a2e)
- Update modeling_motif.py (95a3a69d05f9284a6241310ba17a92d8db544c02)
- Update README.md (a14cdf2cc082f3be9d297d14287bb9ae7d4d3f88)
- Update configuration_motif.py (7ed4264a3a35b2b61b29dc31c9833135f4b812cd)
- Update generation_config.json (8d10c7cefc977930e3dc46c50b67409bcc28e650)
- Update README.md (48ae3c671d073dfff611e9e1c209fb2ae5f78b43)
- Update README.md (96a8455fc6078a63ad4bb3f340a76ff1c21bc7d7)
- Update README.md (9a018e3364b5300685f90cb5ceebbeb63fb78799)
- Update configuration_motif.py (a77c948640d4559106e5c0e0b59fb8c58a9af0f6)
Co-authored-by: Eunhwan Park <[email protected]>
- README.md +37 -1
- config.json +5 -59
- configuration_motif.py +5 -89
- generation_config.json +1 -1
- modeling_motif.py +108 -1001
README.md
@@ -195,4 +195,40 @@ The benchmarks and corresponding scores listed in the table below are taken dire
 |MBPP|0-shot|53.9|62.2|60.3|+11.87%|-3.05%|
 |MBPP+|0-shot|44.4|50.6|50.8|+14.41%|+0.40%|
 |MultiPL-E|0-shot|22.6|34.9|-|-|-|
-|||||**Average**|**+18.55%**|**+1.12%**|
+|||||**Average**|**+18.55%**|**+1.12%**|
+
+## How to use
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model = AutoModelForCausalLM.from_pretrained(
+    "Motif-Technologies/Motif-2.6B",
+    trust_remote_code = True,
+    _attn_implementation = "eager",  # also supports flash_attention_2
+).cuda()
+
+tokenizer = AutoTokenizer.from_pretrained(
+    "Motif-Technologies/Motif-2.6B",
+    trust_remote_code = True,
+)
+
+query = "What is the capital city of South Korea?"
+input_ids = tokenizer.apply_chat_template(
+    [
+        {'role': 'system', 'content': 'you are a helpful assistant'},
+        {'role': 'user', 'content': query},
+    ],
+    add_generation_prompt = True,
+    return_tensors = 'pt',
+).cuda()
+
+output = model.generate(input_ids, max_new_tokens=128, pad_token_id=tokenizer.eos_token_id)
+output = tokenizer.decode(output[0, input_ids.shape[-1]:], skip_special_tokens = True)
+print(output)
+
+"""
+The capital city of South Korea is Seoul. Located in the southern part of the country, Seoul is not only the largest city in South Korea but also one of the largest metropolitan areas in the world.
+It is a vibrant and dynamic city known for its rich history, cultural heritage, and modern amenities. Seoul is a major economic, cultural, and political center in East Asia, and it plays a crucial role in the region's politics, economy, and culture.
+The city is divided into different administrative districts, each with its own unique characteristics and attractions.
+"""
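For completeness, the comment in the snippet above notes that `flash_attention_2` is also supported. A minimal variant is sketched below; it is not part of the diff, it assumes the `flash-attn` package is installed, and the explicit `torch_dtype` is an added assumption (flash-attention kernels generally require fp16/bf16).

```python
# Sketch only: the README example switched to the flash_attention_2 path.
# Not taken from the diff; flash-attn must be installed, and torch_dtype is
# an assumption added here because flash-attention kernels expect fp16/bf16.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Motif-Technologies/Motif-2.6B",
    trust_remote_code=True,
    _attn_implementation="flash_attention_2",
    torch_dtype="bfloat16",
).cuda()
```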
config.json
@@ -8,82 +8,28 @@
     "AutoConfig": "configuration_motif.MotifConfig",
     "AutoModelForCausalLM": "modeling_motif.MotifForCausalLM"
   },
-  "bfloat16": true,
   "bos_token_id": 219396,
-  "continual_training": false,
-  "decoder_split_layers": [],
-  "decontam_attn": false,
-  "dim_model_base": 2048,
-  "dim_model_base_attn": 128,
-  "dim_model_base_init": 2048,
-  "dim_model_base_lmh": 1,
-  "dim_model_base_logits": 2048,
-  "dim_model_base_lr": 256,
-  "down_proj_alpha": 0.15625,
-  "embed_tokens_alpha": null,
-  "encoder_split_layers": [],
   "eos_token_id": 219395,
-  "first_expansion": false,
-  "fused_rope": true,
-  "gate_up_proj_alpha": 0.15625,
   "hidden_act": "poly_norm",
-  "hidden_act_moe": null,
   "hidden_size": 2048,
-  "hidden_states_shrink": 0.17677669529663687,
-  "init_scale_o": 1,
   "initializer_range": 2e-05,
-  "input_layernorm_alpha": null,
   "intermediate_size": 8192,
-  "k_proj_alpha": 0.15625,
-  "lm_head_alpha": null,
   "loss_reduction": "mean",
   "max_position_embeddings": 16384,
   "max_window_layers": 28,
-  "mix_attn": false,
   "model_type": "Motif",
-  "moe": false,
-  "moe_intermediate_size": null,
-  "moe_layer": false,
-  "muP": false,
-  "multi_token_heads": null,
-  "n_group": null,
-  "n_routed_experts": null,
-  "norm_alpha": null,
-  "norm_topk_prob": null,
   "num_attention_heads": 16,
   "num_hidden_layers": 32,
   "num_key_value_heads": 16,
-  "num_stages": false,
-  "o_proj_alpha": 0.15625,
-  "post_attention_layernorm_alpha": null,
-  "q_proj_alpha": 0.15625,
   "rms_norm_eps": 1e-06,
   "rope_scaling": null,
   "rope_theta": 500000.0,
-  "routed_scaling_factor": null,
-  "scale_emb": 1,
-  "scoring_func": null,
-  "seq_aux": null,
   "sliding_window": null,
-  "tensor_parallel": true,
   "tie_word_embeddings": true,
-  "[…]
-  "[…]
-  "torch_dtype": "float32",
-  "transformers_version": "4.51.3",
-  "use_advanced_parallelization": true,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.46.3",
   "use_bias": false,
-  "use_cache": […]
-  "use_emb_alpha": false,
-  "use_fused_mlp": null,
-  "use_moreh_attention": true,
-  "use_moreh_moe": false,
-  "use_mrope": false,
-  "use_norm_alpha": false,
-  "use_pipeline": false,
-  "use_qk_norm": false,
+  "use_cache": true,
   "use_sliding_window": false,
-  "[…]
-
-  "wesar_weights": false
-  }
+  "vocab_size": 219520
+}
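As a quick sanity check (not part of the diff), the trimmed config above can be loaded on its own; with `trust_remote_code=True` the `auto_map` entry resolves to `configuration_motif.MotifConfig`. The repo id below is taken from the README and the printed values from config.json itself.

```python
# Minimal sketch: load the slimmed-down config.json shown above.
from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "Motif-Technologies/Motif-2.6B",
    trust_remote_code=True,  # needed so auto_map -> configuration_motif.MotifConfig
)

print(config.model_type)           # "Motif"
print(config.hidden_size)          # 2048
print(config.num_attention_heads)  # 16
print(config.torch_dtype)          # torch.bfloat16 after this change
```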
configuration_motif.py
@@ -1,8 +1,9 @@
+import math
+from typing import Optional
+
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_rope_utils import rope_config_validation
 from transformers.utils import logging
-from typing import Optional
-import math
 
 logger = logging.get_logger(__name__)
 
@@ -13,11 +14,8 @@ class MotifConfig(PretrainedConfig):
     Motif model according to the specified arguments, defining the model architecture. Instantiating a configuration
     with the defaults will yield a similar configuration to that of
     Motif-102B [moreh/Motif-102B](https://huggingface.co/moreh/Motif-102B).
-
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
-
-
     Args:
         vocab_size (`int`, *optional*, defaults to 151936):
             Vocabulary size of the Motif model. Defines the number of different tokens that can be represented by the
@@ -97,16 +95,12 @@ class MotifConfig(PretrainedConfig):
             The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
-
     ```python
     >>> from transformers import MotifModel, MotifConfig
-
     >>> # Initializing a Motif style configuration
     >>> configuration = MotifConfig()
-
     >>> # Initializing a model from the Motif-102B style configuration
     >>> model = MotifModel(configuration)
-
     >>> # Accessing the model configuration
     >>> configuration = model.config
     ```"""
@@ -134,13 +128,8 @@ class MotifConfig(PretrainedConfig):
         sliding_window=4096,
         max_window_layers=28,
         attention_dropout=0.0,
-        multi_token_heads: Optional[int] = None,
         **kwargs,
     ):
-        """
-        Arguments:
-            multi_token_heads: If not None, use multi-token heads as in the paper https://arxiv.org/pdf/2404.19737
-        """
 
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
@@ -165,87 +154,14 @@ class MotifConfig(PretrainedConfig):
         self.rope_scaling = rope_scaling
         self.attention_dropout = attention_dropout
 
-        ###kwargs
-
-        # some scale factors
-
-        self.scale_emb = getattr(kwargs, "scale_emb", 1)
-        self.init_scale_o = getattr(kwargs, "init_scale_o", 1)
-
-        # muparam
-        self.hidden_states_shrink = 1 / math.sqrt(num_hidden_layers)
-        self.dim_model_base = hidden_size
-        self.dim_model_base_attn = (hidden_size // num_attention_heads)
-        self.dim_model_base_init = hidden_size
-        self.dim_model_base_lr = getattr(kwargs, "dim_model_base_lr", hidden_size//8)
-        self.dim_model_base_lmh = 1
-        self.dim_model_base_logits = hidden_size
-
-        self.muP = getattr(kwargs, "muP", False)
-        # proxy hidden size ( following YuLan-Mini )
-        # reparameterization(wesar_weights)
-        logger.info(kwargs)
-        self.wesar_weights = getattr(kwargs, "wesar_weights", False)
-        logger.info(f'initial wesar reparameterization : {self.wesar_weights}')
-
-        # alpha (scale factor)
-        self.embed_tokens_alpha = getattr(kwargs, "embed_tokens_alpha", None)
-        self.q_proj_alpha = getattr(kwargs, "q_proj_alpha", None)
-        self.k_proj_alpha = getattr(kwargs, "k_proj_alpha", None)
-        self.v_proj_alpha = getattr(kwargs, "v_proj_alpha", None)
-        self.o_proj_alpha = getattr(kwargs, "o_proj_alpha", None)
-        self.down_proj_alpha = getattr(kwargs, "down_proj_alpha", None)
-        self.gate_up_proj_alpha = getattr(kwargs, "gate_up_proj_alpha", None)
-        self.input_layernorm_alpha = getattr(kwargs, "input_layernorm_alpha", None)
-        self.post_attention_layernorm_alpha = getattr(kwargs, "post_attention_layernorm_alpha", None)
-        self.norm_alpha = getattr(kwargs, "norm_alpha", None)
-        self.lm_head_alpha = getattr(kwargs, "lm_head_alpha", None)
-        self.use_norm_alpha = getattr(kwargs, "use_norm_alpha", False)
-        self.use_emb_alpha = getattr(kwargs, "use_emb_alpha", False)
-
         # Validate the correctness of rotary position embeddings parameters
         # BC: if there is a 'type' field, move it to 'rope_type'.
         if self.rope_scaling is not None and "type" in self.rope_scaling:
             self.rope_scaling["rope_type"] = self.rope_scaling["type"]
         rope_config_validation(self)
-
-        self.multi_token_heads = multi_token_heads
-        self.multi_token_config_validation()
-
-        # moe
-        self.topk_method = getattr(kwargs, "topk_method", None)
-        self.scoring_func = getattr(kwargs, "scoring_func", None)
-        self.routed_scaling_factor = getattr(kwargs, "routed_scaling_factor", None)
-        self.norm_topk_prob = getattr(kwargs, "norm_topk_prob", None)
-        self.seq_aux = getattr(kwargs, "seq_aux", None)
-        self.hidden_act_moe = getattr(kwargs, "hidden_act_moe", None)
-
-        self.n_group = getattr(kwargs, "n_group", None)
-        self.n_routed_experts = getattr(kwargs, "n_routed_experts", None)
-        self.moe_intermediate_size = getattr(kwargs, "moe_intermediate_size", None)
-        self.topk_group = getattr(kwargs, "topk_group", None)
-
-        self.use_fused_mlp = getattr(kwargs, "use_fused_mlp", None)
-        self.use_moreh_moe = getattr(kwargs, "use_moreh_moe", False)
-        self.continual_training = getattr(kwargs, "continual_training", False)
-
-        # external
-        self.first_expansion = getattr(kwargs, "first_expansion", False)
-        self.moe_layer = getattr(kwargs, "moe_layer", False)
-
+
         super().__init__(
             tie_word_embeddings=tie_word_embeddings,
             **kwargs,
         )
-        logger.info(f' kwargs : {kwargs}')
-        logger.info(f'after wesar reparameterization : {self.wesar_weights}')
-
-    def multi_token_config_validation(self):
-        if self.multi_token_heads is not None:
-            assert isinstance(self.multi_token_heads, int) and self.multi_token_heads >= 1
+        logger.info(f' kwargs : {kwargs}')
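The only logic retained around `super().__init__()` is the rotary-scaling back-compat step; a standalone sketch of that renaming, with illustrative values only:

```python
# Sketch of the back-compat handling kept in MotifConfig.__init__ above:
# a legacy {"type": ...} rope_scaling dict gains a "rope_type" key before
# rope_config_validation() runs.
rope_scaling = {"type": "linear", "factor": 2.0}  # illustrative values

if rope_scaling is not None and "type" in rope_scaling:
    rope_scaling["rope_type"] = rope_scaling["type"]

print(rope_scaling)
# {'type': 'linear', 'factor': 2.0, 'rope_type': 'linear'}
```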
generation_config.json
@@ -6,5 +6,5 @@
     219405
   ],
   "transformers_version": "4.51.3",
-  "use_cache": […]
+  "use_cache": true
 }
modeling_motif.py
@@ -1,178 +1,32 @@
 import math
 from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
 from transformers.generation import GenerationMixin
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.modeling_flash_attention_utils import _flash_attention_forward
-from transformers.modeling_outputs import (
-    CausalLMOutputWithPast,
-    ModelOutput,
-)
 from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from transformers.modeling_utils import PreTrainedModel
 from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
-from transformers.utils import (
-    […]
-    add_start_docstrings_to_model_forward,
-    is_flash_attn_greater_or_equal_2_10,
-    is_flash_attn_2_available,
-    logging,
-    replace_return_docstrings,
-)
-from .configuration_motif import MotifConfig
-from dataclasses import dataclass
-
-import torch.nn.functional as F
-import time
-
-logger = logging.get_logger(__name__)
-
-if is_flash_attn_2_available():
-    from transformers.modeling_flash_attention_utils import _flash_attention_forward
-
-try:
-    moreh_ops = torch.ops.moreh
-    MorehRMSNorm = moreh_ops.T5LayerNorm
-    ScaledDotProductAttention = moreh_ops.scaled_dot_product_attention
-    MorehFlashAttention = moreh_ops.flash_attention
-    logger.warning_once("Using moreh ops")
-except AttributeError:
-    MorehRMSNorm = None
-    ScaledDotProductAttention = None
-    MorehFlashAttention = None
-    logger.warning_once("Failed to import moreh ops")
-
-
-[… ~40 removed lines: commented-out DEBUG / log_timing decorator scaffolding …]
-
-#_CHECKPOINT_FOR_DOC = "moreh/Motif-102B"
-_CONFIG_FOR_DOC = "MotifConfig"
-
-#from .moreh_moe import MorehMoeMLP, MorehMoeFusedMLP
-
-import torch
-from transformers.activations import ACT2CLS as _ACT2CLS
-from transformers.activations import ClassInstantier
-moreh_ops = torch.ops.moreh
-
-from typing import Callable, Dict, List, Tuple
-
-import torch
-
-
-# @log_timing
-[… ~60 removed lines: multi_head_forward_backward() helper implementing the memory-saving
-    forward-backward pattern from https://arxiv.org/abs/2404.19737 (multi-token-head support dropped) …]
 
 
 class PolyNorm(torch.nn.Module):
-    """
     A trainable activation function introduced in https://arxiv.org/html/2411.03884v1.
-    The code is copied from https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md
-    with the change `* torch.rsqrt` => `/ torch.sqrt` for potential MAF incompatibility.
     """
 
     def __init__(self, eps=1e-6):
@@ -189,29 +43,16 @@ class PolyNorm(torch.nn.Module):
             x ** 2) + self.weight[2] * self._norm(x) + self.bias
 
 
-
-    """
-    A trainable activation function introduced in https://arxiv.org/html/2411.03884v1.
-    The code is copied from https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md,
-    with the change `* torch.rsqrt` => `/ torch.sqrt` for potential MAF incompatibility.
-    """
-
-    def __init__(self, eps=1e-6):
-        super(PolyNorm_Test, self).__init__()
-        self.weight = torch.nn.Parameter(torch.ones(3) / 3)
-        self.bias = torch.nn.Parameter(torch.zeros(1))
-        self.eps = eps
-
-    def forward(self, x):
-
-        #return torch.nn.SiLU(x)
-        return moreh_ops.poly_norm(x, self.weight, self.bias)
-
-
-CUSTOM_ACT2CLS = {"poly_norm": PolyNorm_Test, "poly_norm_test": PolyNorm_Test}
 ACT2CLS = {**_ACT2CLS, **CUSTOM_ACT2CLS}
 ACT2FN = ClassInstantier(ACT2CLS)
 
 
 class MotifRMSNorm(nn.Module):
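The hunk above keeps the `PolyNorm` activation that `hidden_act: "poly_norm"` resolves to and drops the duplicate `PolyNorm_Test` variant, which dispatched to `moreh_ops.poly_norm`. Below is a plain-PyTorch sketch consistent with the retained fragment and docstring; the cubic term and the exact `_norm` expression follow the PolyCom reference and are assumptions rather than lines copied from this diff.

```python
import torch

class PolyNormSketch(torch.nn.Module):
    """Plain-PyTorch sketch of PolyNorm (weights start at 1/3, bias at 0).
    The `/ torch.sqrt` form mirrors the change noted in the docstring above;
    the x**3 term is assumed from the PolyCom reference implementation."""

    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3) / 3)
        self.bias = torch.nn.Parameter(torch.zeros(1))
        self.eps = eps

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # RMS-style normalization over the last dimension.
        return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (self.weight[0] * self._norm(x ** 3) + self.weight[1] * self._norm(
            x ** 2) + self.weight[2] * self._norm(x) + self.bias)
```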
@@ -235,7 +76,7 @@ class MotifRMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 
-ALL_LAYERNORM_LAYERS.append(MotifRMSNorm
 
 
 class MotifRotaryEmbeddingWithCache(nn.Module):
@@ -267,7 +108,6 @@ class MotifRotaryEmbeddingWithCache(nn.Module):
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
 
-        # Build here to make `torch.jit.trace` work.
         self._set_cos_sin_cache(seq_len=max_position_embeddings,
                                 device=self.inv_freq.device,
                                 dtype=torch.get_default_dtype())
@@ -288,12 +128,11 @@ class MotifRotaryEmbeddingWithCache(nn.Module):
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
 
         return (
-            self.cos_cached[…]
-            self.sin_cached[…]
         )
 
 
-# @log_timing
 class MotifRotaryEmbedding(nn.Module):
 
     def __init__(
@@ -324,7 +163,6 @@ class MotifRotaryEmbedding(nn.Module):
             self.max_seq_len_cached = max_position_embeddings
             self.original_max_seq_len = max_position_embeddings
         else:
-            # BC: "rope_type" was originally "type"
             if config.rope_scaling is not None:
                 self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
             else:
@@ -386,10 +224,10 @@ class MotifRotaryEmbedding(nn.Module):
 def rotate_half(x):
     """
     Rotates half of the dimensions of the input tensor using torch.roll and in-place negation.
-
     Args:
         x (torch.Tensor): The input tensor.
-
     Returns:
         torch.Tensor: A tensor where the latter half of the dimensions are negated
         and moved before the first half.
@@ -401,8 +239,7 @@ def rotate_half(x):
     return rotated_tensor
 
 
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1, fused_rope=True):
     """
     Applies rotary position embeddings to the input tensors.
 
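As documented above, `rotate_half` negates the latter half of the last dimension and rolls it in front of the first half, and the unfused rotary path quoted (commented out) in the next hunk combines it as `q * cos + rotate_half(q) * sin`. A self-contained sketch of that fallback, with placeholder shapes:

```python
import torch

def rotate_half_sketch(x: torch.Tensor) -> torch.Tensor:
    # Negate the latter half of the last dimension, then roll it in front of
    # the first half (matching the rotate_half docstring above).
    x = x.clone()
    half = x.shape[-1] // 2
    x[..., half:] = -x[..., half:]
    return torch.roll(x, shifts=half, dims=-1)

def apply_rope_sketch(q, k, cos, sin):
    # Unfused rotary embedding, as in the commented-out reference path:
    # q_embed = q * cos + rotate_half(q) * sin, and likewise for k.
    return q * cos + rotate_half_sketch(q) * sin, k * cos + rotate_half_sketch(k) * sin

q = torch.randn(1, 16, 8, 128)  # (batch, heads, seq, head_dim); placeholder sizes
k = torch.randn(1, 16, 8, 128)
cos, sin = torch.ones(8, 128), torch.zeros(8, 128)  # placeholder cos/sin caches
q_rot, k_rot = apply_rope_sketch(q, k, cos, sin)
print(q_rot.shape)  # torch.Size([1, 16, 8, 128])
```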
@@ -411,438 +248,47 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1, fus
         k (torch.Tensor): Key tensor of shape (B, NH, S, D_KV).
         cos (torch.Tensor): Cosine values for rotary embedding.
         sin (torch.Tensor): Sine values for rotary embedding.
-        unsqueeze_dim (int, optional): Dimension along which `cos` and `sin` are unsqueezed.
             Defaults to 1.
-        fused_rope (bool, optional): If True, applies fused rotary embeddings using
-            `moreh_ops.apply_rotary_emb`. If False, computes rotary embeddings manually.
-            Defaults to False.
 
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: Returns transformed query and key tensors after applying rotary embeddings.
     """
     '''
-    # (B, NH, S, D_KV) -> (B, S, NH, D_KV)
     cos = cos.unsqueeze(unsqueeze_dim)
     sin = sin.unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     '''
-
-    #sin = sin[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim]
-
-    q = q.transpose(1, 2)
-    k = k.transpose(1, 2)
-
-    # Expand 'batch' dim
-    cos = cos.expand(q.shape[0], *cos.shape[1:])
-    sin = sin.expand(q.shape[0], *sin.shape[1:])
-
-    q_embed = moreh_ops.apply_rotary_emb(q, cos, sin, opcode=1)
-    k_embed = moreh_ops.apply_rotary_emb(k, cos, sin, opcode=1)
 
-    # (B, S, NH, D_KV) -> (B, NH, S, D_KV)
-    q_embed = q_embed.transpose(1, 2)
-    k_embed = k_embed.transpose(1, 2)
 
-    return q_embed, k_embed
-
-
-# @log_timing
 class MotifMLP(nn.Module):
 
     def __init__(self, config):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=[…]
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=[…]
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=[…]
         self.act_fn = ACT2FN[config.hidden_act]
 
-        if config.wesar_weights:
-            self.gate_up_proj_alpha = nn.Parameter(torch.tensor(1) * config.gate_up_proj_alpha)
-            self.down_proj_alpha = nn.Parameter(torch.tensor(1) * config.down_proj_alpha)
-        else:
-            self.gate_up_proj_alpha = 1
-            self.down_proj_alpha = 1
-        if config.muP:
-            self.down_proj.__do_scale_tager__ = True
-            self.gate_proj.__do_scale_tager_mu_dim_model__ = True
-            self.up_proj.__do_scale_tager_mu_dim_model__ = True
-            self.down_proj.__do_scale_tager_mu_ffn__ = True
-
-
     def forward(self, hidden_state):
-        hidden_state
-        #hidden_state = self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))*
-        return self.down_proj_alpha*self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
-
-
-[… ~370 removed lines: MorehMoeFusedMLP, MoEGate, and MotifMoE classes (grouped MoeFanInLinear /
-    MoeFanOutLinear expert MLPs, sigmoid scoring with greedy / group_limited_greedy / noaux_tc top-k
-    gating, and the expert-parallel moe_infer dispatch) — unused MoE support dropped from this file …]
 
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-
-
     """
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-
     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
     if n_rep == 1:
         return hidden_states
@@ -851,32 +297,31 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
 
     return torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)
 
-
-# @log_timing
 class MotifAttention(nn.Module):
     """
     Differential Attention (DiffAttention) module.
 
-    Implements the Differential Attention from
     "DIFFERENTIAL TRANSFORMER" (https://arxiv.org/pdf/2410.05258).
 
     Overview
         Standard transformers often over-allocate attention to irrelevant context.
-        DiffAttention addresses this by computing attention as the difference between
-        two separate softmax attention maps, effectively canceling noise and promoting
         sparse, structured attention patterns.
 
     Reference Implementation
         https://github.com/microsoft/unilm/tree/master/Diff-Transformer
 
     Args
-        The differential attention mechanism computes attention as the difference of two softmax attention scores, weighted by a learnable scalar λ.
         λ is re-parameterized as λ = exp(λ_q1 · λ_k1) − exp(λ_q2 · λ_k2) + λ_init.
         - lambda_q1, lambda_q2 (nn.Parameter): Learnable vectors used to compute the first and second components of λ for query transformations.
         - lambda_k1, lambda_k2 (nn.Parameter): Learnable vectors used to compute the first and second components of λ for key transformations.
         - lambda_init (float): A constant used for initializing λ, typically set as λ_init = 0.8 − 0.6 × exp(−0.3 × (layer_index − 1)).
-
     """
 
     def __init__(self, config: MotifConfig, layer_idx: Optional[int] = None):
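The λ re-parameterization in this docstring maps directly onto the `lambda_1` / `lambda_2` / `lambda_full` lines kept later in `forward()`. A tiny numeric sketch (head_dim matches the 2048 hidden size / 16-head geometry in config.json; the layer index is arbitrary):

```python
import math
import torch

head_dim, layer_idx = 128, 5  # 2048 / 16 heads; layer index chosen arbitrarily
lambda_init = 0.8 - 0.6 * math.exp(-0.3 * (layer_idx - 1))

# Learnable vectors, initialised as in MotifAttention.__init__ (std 0.1).
lambda_q1 = torch.zeros(head_dim).normal_(mean=0.0, std=0.1)
lambda_k1 = torch.zeros(head_dim).normal_(mean=0.0, std=0.1)
lambda_q2 = torch.zeros(head_dim).normal_(mean=0.0, std=0.1)
lambda_k2 = torch.zeros(head_dim).normal_(mean=0.0, std=0.1)

lambda_1 = torch.exp(torch.sum(lambda_q1 * lambda_k1))
lambda_2 = torch.exp(torch.sum(lambda_q2 * lambda_k2))
lambda_full = lambda_1 - lambda_2 + lambda_init

# The two softmax maps are then combined as attn_1 - lambda_full * attn_2.
print(float(lambda_full))
```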
@@ -899,11 +344,7 @@ class MotifAttention(nn.Module):
|
|
899 |
self.rope_theta = config.rope_theta
|
900 |
self.is_causal = True
|
901 |
self.attention_dropout = config.attention_dropout
|
902 |
-
|
903 |
-
self.batch_num = config.batch_num
|
904 |
-
logger.info(f'self.batcn_num : {self.batch_num}')
|
905 |
-
except:
|
906 |
-
self.batch_num = None
|
907 |
if (self.head_dim * self.num_heads) != self.hidden_size:
|
908 |
raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
909 |
f" and `num_heads`: {self.num_heads}).")
|
@@ -912,61 +353,22 @@ class MotifAttention(nn.Module):
|
|
912 |
self.num_key_value_heads //= 2
|
913 |
self.n_rep = self.num_heads // self.num_key_value_heads
|
914 |
|
915 |
-
##mix attn
|
916 |
-
|
917 |
-
self.mix_attn = config.mix_attn
|
918 |
-
|
919 |
-
if self.mix_attn:
|
920 |
-
|
921 |
-
self.cq, self.ck = 6, 11
|
922 |
-
self.ch = 2
|
923 |
-
|
924 |
-
self.key_query_conv = nn.Conv2d(
|
925 |
-
in_channels=self.num_heads*2,
|
926 |
-
out_channels=self.num_heads*2,
|
927 |
-
kernel_size=(self.cq, self.ck),
|
928 |
-
padding="same",
|
929 |
-
groups=self.num_heads*2
|
930 |
-
)
|
931 |
-
|
932 |
-
self.head_conv = nn.Conv1d(
|
933 |
-
in_channels=self.num_heads,
|
934 |
-
out_channels=self.num_heads,
|
935 |
-
kernel_size=1,
|
936 |
-
padding=0,
|
937 |
-
groups=self.num_heads // self.ch
|
938 |
-
)
|
939 |
-
|
940 |
-
self.group_norm = nn.GroupNorm(num_groups=self.num_heads, num_channels=self.num_heads)
|
941 |
-
|
942 |
-
|
943 |
-
|
944 |
-
# re-init projections
|
945 |
self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
|
946 |
self.k_proj = nn.Linear(self.hidden_size, self.hidden_size // self.n_rep, bias=False)
|
947 |
self.v_proj = nn.Linear(self.hidden_size, self.hidden_size // self.n_rep, bias=False)
|
948 |
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
|
949 |
|
950 |
-
# init lambdas
|
951 |
for name in ["lambda_q1", "lambda_k1", "lambda_q2", "lambda_k2"]:
|
952 |
setattr(self, name, nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32)))
|
953 |
getattr(self, name).data.normal_(mean=0.0, std=0.1)
|
954 |
|
955 |
-
# Uses same norm as motif norm, without elementwise_affine option
|
956 |
self.subln = MotifRMSNorm(2 * self.head_dim, eps=1e-5)
|
957 |
self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * (layer_idx - 1))
|
958 |
|
959 |
-
self.rotary_emb =
|
960 |
max_position_embeddings=self.max_position_embeddings,
|
961 |
base=self.rope_theta)
|
962 |
|
963 |
-
for param in ["q_proj_alpha", "k_proj_alpha", "v_proj_alpha", "o_proj_alpha"]:
|
964 |
-
setattr(
|
965 |
-
self, param,
|
966 |
-
nn.Parameter(torch.tensor(getattr(config, param, 1.0), dtype=torch.float))
|
967 |
-
if config.wesar_weights else 1.0)
|
968 |
-
|
969 |
-
|
970 |
def forward(
|
971 |
self,
|
972 |
hidden_states: torch.Tensor,
|
@@ -976,15 +378,13 @@ class MotifAttention(nn.Module):
|
|
976 |
output_attentions: bool = False,
|
977 |
use_cache: bool = False,
|
978 |
cache_position: Optional[torch.LongTensor] = None,
|
979 |
-
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
980 |
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
981 |
bsz, q_len, _ = hidden_states.size()
|
982 |
|
983 |
-
query_states = self.q_proj(hidden_states)
|
984 |
-
key_states = self.k_proj(hidden_states)
|
985 |
-
value_states = self.v_proj(hidden_states)
|
986 |
-
|
987 |
-
## bsz, seq, n_heads, head_dim
|
988 |
|
989 |
query_states = query_states.view(bsz, q_len, 2 * self.num_heads, self.head_dim).transpose(1, 2)
|
990 |
key_states = key_states.view(bsz, q_len, 2 * self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
@@ -1006,17 +406,15 @@ class MotifAttention(nn.Module):
|
|
1006 |
key_states,
|
1007 |
cos,
|
1008 |
sin,
|
1009 |
-
|
1010 |
|
1011 |
if past_key_value is not None:
|
1012 |
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
1013 |
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
1014 |
|
1015 |
-
# repeat k/v heads if n_kv_heads < n_heads
|
1016 |
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
1017 |
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
1018 |
|
1019 |
-
## bsz, #haead, q_len, head_dim -> bsz, head, q_len, q_len
|
1020 |
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
1021 |
|
1022 |
kv_seq_len = key_states.shape[-2]
|
@@ -1025,49 +423,31 @@ class MotifAttention(nn.Module):
         attention_mask = torch.triu(
             torch.full((q_len, kv_seq_len), float("-inf"), dtype=attn_weights.dtype, device=attn_weights.device),
             1 + offset)
-
-        if self.mix_attn:
-            ## condition mask==0, value : 0
-            attn_weights = attn_weights.masked_fill( attention_mask == 0, 0)
-            attn_weights = self.key_query_conv(attn_weights)
-            attn_weights = attn_weights[:, :, :kv_seq_len, :kv_seq_len]
-
-        ###add attn
         attn_weights = attn_weights + attention_mask
 
-        # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
 
-        # differential transformer lambdas
         lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(attn_weights)
         lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(attn_weights)
         lambda_full = lambda_1 - lambda_2 + self.lambda_init
         attn_weights = attn_weights.view(bsz, self.num_heads, 2, q_len, -1)
         attn_weights = attn_weights[:, :, 0] - lambda_full * attn_weights[:, :, 1]
-        ##head_conv
-        if self.mix_attn:
-            attn_weights = attn_weights.view(bsz, self.num_heads, -1).contiguous()
-            attn_weights = self.head_conv(attn_weights)
-            attn_weights = attn_weights.view(bsz, self.num_heads, q_len, -1).contiguous()
 
-        ##shape : bsz, #heads, seq, head_dim
         attn_output = torch.matmul(attn_weights, value_states)
 
-
         attn_output = self.subln(attn_output)
         attn_output = attn_output * (1 - self.lambda_init)
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim * 2):
             raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                              f" {attn_output.size()}")
-
-        attn_output = self.group_norm(attn_output)
-
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
-        attn_output = self.o_proj(attn_output)
 
         if not output_attentions:
             attn_weights = None
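The eager path shown above is the clearest statement of the differential attention recipe: two softmax maps per logical head, combined with the λ re-parameterization described in the class docstring, then normalized and rescaled by `(1 - lambda_init)`. A toy sketch of that combination on random tensors (the RMS-style normalization is inlined and stands in for the `subln` module; shapes and initial values are illustrative only):

```python
import math
import torch

bsz, num_heads, q_len, head_dim = 2, 4, 6, 8
layer_idx = 3

# two attention maps per logical head, as produced by the doubled q/k heads
attn = torch.softmax(torch.randn(bsz, num_heads, 2, q_len, q_len), dim=-1)

# lambda re-parameterization: exp(lq1.lk1) - exp(lq2.lk2) + lambda_init
lq1, lk1, lq2, lk2 = (torch.randn(head_dim) * 0.1 for _ in range(4))
lambda_init = 0.8 - 0.6 * math.exp(-0.3 * (layer_idx - 1))
lambda_full = torch.exp((lq1 * lk1).sum()) - torch.exp((lq2 * lk2).sum()) + lambda_init

# differential map: first softmax map minus lambda times the second
diff_attn = attn[:, :, 0] - lambda_full * attn[:, :, 1]

value_states = torch.randn(bsz, num_heads, q_len, 2 * head_dim)
attn_output = diff_attn @ value_states  # (bsz, num_heads, q_len, 2*head_dim)

# inlined RMS-style normalization followed by the (1 - lambda_init) rescale
attn_output = attn_output / attn_output.pow(2).mean(-1, keepdim=True).add(1e-5).sqrt()
attn_output = attn_output * (1 - lambda_init)
```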
@@ -1075,7 +455,6 @@ class MotifAttention(nn.Module):
         return attn_output, attn_weights, past_key_value
 
 
-# @log_timing
 class MotifFlashAttention2(MotifAttention):
     """
     Motif flash attention module, following Motif attention module. This module inherits from `MotifAttention`
@@ -1085,18 +464,16 @@ class MotifFlashAttention2(MotifAttention):
     config.max_window_layers layers.
     """
 
-    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-
-        logger.info(f'flash attention True')
-
         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
         # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
 
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
 
     def _reshape_heads(self, tensor, batch_size, seq_len):
         """2-way head split tensor reshape"""
         return tensor.reshape(batch_size, seq_len, self.num_heads, 2, self.head_dim)
@@ -1106,55 +483,27 @@ class MotifFlashAttention2(MotifAttention):
         return tensor.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
 
     def _compute_attention(self, query_states, key_states, value_states, attention_mask, q_len, position_ids,
-                           dropout_rate, sliding_window
         """Flash Attention 2 implements"""
-
-
-
-
-            causal = self.is_causal
-        else:
-            causal = self.is_causal and q_len != 1
-
-        bsz = query_states.shape[0]
-
-        if batch_num:
-            query_states = query_states.reshape(bsz*q_len,self.num_heads,self.head_dim)
-            key_states = key_states.reshape(bsz*q_len,self.num_heads,self.head_dim)
-            value_states = value_states.reshape(bsz*q_len,self.num_heads,self.head_dim)
-
-            attn_out = moreh_ops.flash_attention_varlen_dp(query_states,
-                                                           key_states,
-                                                           value_states,
-                                                           attention_mask,
-                                                           attention_mask,
-                                                           max_seqlen_q=q_len,
-                                                           max_seqlen_kv=q_len,
-                                                           dropout_p=dropout_rate,
-                                                           softmax_scale=scale_factor,
-                                                           is_causal=causal,
-                                                           batch_num=batch_num)
-            attn_out = attn_out.reshape(bsz, q_len, self.num_heads, -1)
-        else:
-            return MorehFlashAttention(query_states,
-                                       key_states,
-                                       value_states,
-                                       padding_mask=attention_mask,
-                                       dropout_p=dropout_rate,
-                                       softmax_scale=scale_factor,
-                                       causal=causal)
-        return attn_out
         else:
-
-
-
                                                  attention_mask,
                                                  q_len,
                                                  position_ids=position_ids,
                                                  dropout=dropout_rate,
                                                  sliding_window=sliding_window,
-                                                 is_causal=causal,
                                                  use_top_left_mask=self._flash_attn_uses_top_left_mask)
 
     def forward(
         self,
@@ -1169,9 +518,9 @@ class MotifFlashAttention2(MotifAttention):
     ):
         bsz, q_len, _ = hidden_states.size()
 
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
 
         query_states = query_states.view(bsz, q_len, 2 * self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, 2 * self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -1192,13 +541,12 @@ class MotifFlashAttention2(MotifAttention):
                                                         key_states,
                                                         cos,
                                                         sin,
-
 
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
-        # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
         dropout_rate = 0.0 if not self.training else self.attention_dropout
@@ -1207,7 +555,7 @@ class MotifFlashAttention2(MotifAttention):
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in float16 just to be sure everything works as expected.
         input_dtype = query_states.dtype
-        if input_dtype == torch.float32:
             if torch.is_autocast_enabled():
                 target_dtype = torch.get_autocast_gpu_dtype()
                 # Handle the case where the model is quantized
@@ -1234,7 +582,7 @@ class MotifFlashAttention2(MotifAttention):
         value_states = value_states.transpose(1, 2)
 
         if (self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None
-                and self.layer_idx >= self.config.max_window_layers):
             sliding_window = self.config.sliding_window
         else:
             sliding_window = None
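The window selection above enables sliding-window attention only from `max_window_layers` onward. A small helper expressing the same rule with hypothetical values (the parameter names mirror the config fields used above, but the helper itself is not part of the file):

```python
from typing import Optional

def pick_sliding_window(layer_idx: int,
                        use_sliding_window: bool,
                        sliding_window: Optional[int],
                        max_window_layers: int) -> Optional[int]:
    """Return the window size for this layer, or None for full attention."""
    if use_sliding_window and sliding_window is not None and layer_idx >= max_window_layers:
        return sliding_window
    return None

# e.g. with max_window_layers=28, only layers 28 and above use the 4096-token window
assert pick_sliding_window(30, True, 4096, 28) == 4096
assert pick_sliding_window(10, True, 4096, 28) is None
```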
@@ -1254,12 +602,10 @@ class MotifFlashAttention2(MotifAttention):
         k1, k2 = k1.contiguous(), k2.contiguous()
         v1, v2 = v1.contiguous(), v2.contiguous()
 
-
-
-
-                         self._compute_attention(
-        attn21, attn22 = self._compute_attention(q2, k2, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window, is_moreh_attention, self.batch_num), \
-                         self._compute_attention(q2, k2, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window, is_moreh_attention, self.batch_num)
 
         attn1, attn2 = torch.cat([attn11, attn12], dim=-1), torch.cat([attn21, attn22], dim=-1)
 
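The removed calls above show the 2×2 evaluation pattern of the flash path: the (q1, k1) scores attend over v1 and v2, the (q2, k2) scores attend over v1 and v2, and each pair is concatenated back into one differential branch. A shape-only sketch using PyTorch's built-in scaled dot-product attention in place of the flash kernel (purely illustrative; it does not call the module's `_compute_attention`):

```python
import torch
import torch.nn.functional as F

bsz, heads, q_len, d = 2, 4, 6, 8
q1, q2, k1, k2 = (torch.randn(bsz, heads, q_len, d) for _ in range(4))
v1, v2 = (torch.randn(bsz, heads, q_len, d) for _ in range(2))

def attend(q, k, v):
    # causal attention as a stand-in for the flash-attention call
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

attn11, attn12 = attend(q1, k1, v1), attend(q1, k1, v2)
attn21, attn22 = attend(q2, k2, v1), attend(q2, k2, v2)

attn1 = torch.cat([attn11, attn12], dim=-1)  # (bsz, heads, q_len, 2*d)
attn2 = torch.cat([attn21, attn22], dim=-1)
```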
@@ -1277,16 +623,15 @@ class MotifFlashAttention2(MotifAttention):
         attn_output = attn_output * (1 - self.lambda_init)
 
         if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim * 2):
-            raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads,
                              f" {attn_output.size()}")
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-        attn_output = self.o_proj(attn_output)
 
         return attn_output, None, past_key_value
 
 
-# @log_timing
 class MotifSdpaAttention(MotifAttention):
     """
     Motif attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -1294,7 +639,6 @@ class MotifSdpaAttention(MotifAttention):
     SDPA API.
     """
 
-    # Adapted from MotifAttention.forward
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -1307,7 +651,6 @@ class MotifSdpaAttention(MotifAttention):
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
-            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
             logger.warning_once(
                 "MotifModel is using MotifSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
@@ -1343,8 +686,7 @@ class MotifSdpaAttention(MotifAttention):
         query_states, key_states = apply_rotary_pos_emb(query_states,
                                                         key_states,
                                                         cos,
-                                                        sin,
-                                                        fused_rope=self.config.fused_rope)
 
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
@@ -1380,45 +722,25 @@ class MotifSdpaAttention(MotifAttention):
 MOTIF_ATTENTION_CLASSES = {
     "eager": MotifAttention,
     "flash_attention_2": MotifFlashAttention2,
-    "sdpa": MotifSdpaAttention,
 }
 
 
-# @log_timing
 class MotifDecoderLayer(nn.Module):
 
-    def __init__(self, config: MotifConfig,
         super().__init__()
         self.hidden_size = config.hidden_size
-        if config.use_moreh_attention:
-            config._attn_implementation = "flash_attention_2"
         if config.sliding_window and config._attn_implementation != "flash_attention_2":
             logger.warning_once(
                 f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
                 "unexpected results may be encountered.")
-
-            self.self_attn = MOTIF_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
-        else:
-            self.self_attn = MOTIF_ATTENTION_CLASSES["eager"](config, layer_idx)
         self.mlp = MotifMLP(config)
-
-        self.
-
-        self.moe = MotifMoE(config)
-
-        RMSNorm = MorehRMSNorm if MorehRMSNorm is not None else MotifRMSNorm
-        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-        if config.wesar_weights and config.use_norm_alpha:
-            self.input_layernorm_alpha = nn.Parameter(torch.tensor(1).float())
-        else:
-            self.input_layernorm_alpha = 1
-
-        if config.wesar_weights and config.use_norm_alpha :
-            self.post_attention_layernorm_alpha = nn.Parameter(torch.tensor(1).float())
-        else:
-            self.post_attention_layernorm_alpha = 1
 
     def forward(
         self,
@@ -1456,7 +778,7 @@ class MotifDecoderLayer(nn.Module):
 
         residual = hidden_states
 
-        hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
@@ -1473,16 +795,8 @@ class MotifDecoderLayer(nn.Module):
 
         # Fully Connected
         residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
-
-        if self.moe is not None:
-            hidden_states, identity = self.moe(hidden_states)
-            ## add output of shared expert and output of small moe experts.
-            ## hidden state must be zero tensor (for first forward)
-            hidden_states += self.mlp(identity)
-        else:
-            hidden_states = self.mlp(hidden_states)
-
         hidden_states = residual + hidden_states
 
         outputs = (hidden_states, )
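With the MoE branch removed, the layer reduces to the standard pre-norm residual pattern: normalize, run the sub-module, add the saved residual, twice per layer. A stripped-down sketch of that control flow (the `attn` and `mlp` arguments are placeholders for the real sub-modules, and plain LayerNorm stands in for the RMSNorm used in the file):

```python
import torch
from torch import nn

class PreNormBlock(nn.Module):
    """Minimal pre-norm residual block: x + attn(norm1(x)), then x + mlp(norm2(x))."""

    def __init__(self, hidden_size: int, attn: nn.Module, mlp: nn.Module):
        super().__init__()
        self.attn = attn
        self.mlp = mlp
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

# toy usage with linear stand-ins for the attention and MLP sub-modules
block = PreNormBlock(64, nn.Linear(64, 64), nn.Linear(64, 64))
out = block(torch.randn(2, 10, 64))
```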
@@ -1532,45 +846,24 @@ class MotifPreTrainedModel(PreTrainedModel):
|
|
1532 |
def _init_weights(self, module):
|
1533 |
module_std = self.config.initializer_range
|
1534 |
if isinstance(module, nn.Linear):
|
1535 |
-
|
1536 |
-
|
1537 |
-
|
1538 |
-
if getattr(module, "__do_scale_tager_mu_o__", False):
|
1539 |
-
if self.config.dim_model_base_init is not None:
|
1540 |
-
module_std = module_std / math.sqrt(2*(self.config.hidden_size / self.config.dim_model_base_init)*self.config.num_hidden_layers)
|
1541 |
-
else:
|
1542 |
-
module_std = module_std
|
1543 |
-
elif getattr(module, "__do_scale_tager_mu_ffn__", False):
|
1544 |
-
if self.config.dim_model_base_init is not None:
|
1545 |
-
module_std = module_std = module_std / math.sqrt(2*(self.config.hidden_size / self.config.dim_model_base_init)*self.config.num_hidden_layers)
|
1546 |
-
else:
|
1547 |
-
module_std = module_std
|
1548 |
-
elif getattr(module, "__do_scale_tager_mu_dim_model__", False):
|
1549 |
-
if self.config.dim_model_base_init is not None:
|
1550 |
-
module_std = module_std / math.sqrt(self.config.hidden_size / self.config.dim_model_base_init)
|
1551 |
-
else:
|
1552 |
-
module_std = module_std
|
1553 |
-
elif getattr(module, "__do_scale_tager_mu_dim_base_model__", False):
|
1554 |
-
module_std = module_std / math.sqrt(self.config.dim_model_base_lmh) ### lmhead.. 1
|
1555 |
-
else:
|
1556 |
-
module_std = module_std
|
1557 |
-
|
1558 |
-
torch.nn.init.trunc_normal_(module.weight.data, mean=0.0, std=module_std, a=-3*module_std, b=3*module_std)
|
1559 |
if module.bias is not None:
|
1560 |
module.bias.data.zero_()
|
1561 |
|
1562 |
elif isinstance(module, nn.Embedding):
|
1563 |
-
|
|
|
1564 |
if module.padding_idx is not None:
|
1565 |
module.weight.data[module.padding_idx].zero_()
|
1566 |
|
1567 |
|
1568 |
@dataclass
|
1569 |
class MotifModelOutputWithPast(ModelOutput):
|
1570 |
-
"""
|
1571 |
-
This augments `BaseModelOutputWithPast` in `transformers.modeling_outputs` with new optional keys: `causal_mask`, `position_embeddings`.
|
1572 |
The optional keys are currently used in the following ways:
|
1573 |
-
- pass information to the token-wise last attention layers in multi-token training
|
1574 |
"""
|
1575 |
last_hidden_state: torch.FloatTensor = None
|
1576 |
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
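The removed initializer above draws weights from a normal distribution truncated at ±3σ, with the standard deviation optionally rescaled by model width and depth. A compact sketch of that idea (the single `depth_scale` argument is a simplified stand-in for the various `__do_scale_tager_*` branches, not the exact rule used in the file):

```python
import math
import torch
from torch import nn

def init_linear_trunc_normal(linear: nn.Linear, std: float, depth_scale: float = 1.0) -> None:
    """Truncated-normal init clipped at +/- 3 sigma, with an optional 1/sqrt(scale) rescale."""
    scaled_std = std / math.sqrt(depth_scale) if depth_scale != 1.0 else std
    nn.init.trunc_normal_(linear.weight, mean=0.0, std=scaled_std,
                          a=-3 * scaled_std, b=3 * scaled_std)
    if linear.bias is not None:
        nn.init.zeros_(linear.bias)

layer = nn.Linear(256, 256)
init_linear_trunc_normal(layer, std=0.02, depth_scale=2 * 32)  # e.g. 2 * num_hidden_layers
```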
@@ -1655,7 +948,6 @@ MOTIF_INPUTS_DOCSTRING = r"""
|
|
1655 |
"""
|
1656 |
|
1657 |
|
1658 |
-
# @log_timing
|
1659 |
@add_start_docstrings(
|
1660 |
"The bare Motif Model outputting raw hidden-states without any specific head on top.",
|
1661 |
MOTIF_START_DOCSTRING,
|
@@ -1672,23 +964,11 @@ class MotifModel(MotifPreTrainedModel):
|
|
1672 |
super().__init__(config)
|
1673 |
self.padding_idx = config.pad_token_id
|
1674 |
self.vocab_size = config.vocab_size
|
1675 |
-
self.multi_token_heads = config.multi_token_heads
|
1676 |
|
1677 |
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
1678 |
-
|
1679 |
-
|
1680 |
-
|
1681 |
-
num_hidden_layers = config.num_hidden_layers if self.multi_token_heads is None else config.num_hidden_layers - 1
|
1682 |
-
if config.moe:
|
1683 |
-
moe_layer = [True for i in range(num_hidden_layers)]
|
1684 |
-
else:
|
1685 |
-
moe_layer = [False for i in range(num_hidden_layers)]
|
1686 |
-
logger.info(f'current_moe layer { moe_layer }')
|
1687 |
-
self.layers = nn.ModuleList([MotifDecoderLayer(config = config, moe_layer= moe_layer[layer_idx],
|
1688 |
-
layer_idx=layer_idx) for layer_idx in range(num_hidden_layers)])
|
1689 |
-
self._attn_implementation = config._attn_implementation
|
1690 |
-
RMSNorm = MorehRMSNorm if MorehRMSNorm is not None else MotifRMSNorm
|
1691 |
-
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
1692 |
self.hidden_size = config.hidden_size
|
1693 |
self.num_heads = config.num_attention_heads
|
1694 |
self.head_dim = self.hidden_size // self.num_heads
|
@@ -1701,36 +981,6 @@ class MotifModel(MotifPreTrainedModel):
|
|
1701 |
self.gradient_checkpointing = False
|
1702 |
self.post_init()
|
1703 |
|
1704 |
-
self.use_pipeline = config.use_pipeline
|
1705 |
-
if self.use_pipeline:
|
1706 |
-
logger.info('use reinforced pp..')
|
1707 |
-
if config.num_stages==2:
|
1708 |
-
### moe version
|
1709 |
-
if config.decontam_attn:
|
1710 |
-
self.split_layers = [15]
|
1711 |
-
else:
|
1712 |
-
if num_hidden_layers == 32:
|
1713 |
-
self.split_layers = [15] # 14: 15,17 # 13: 14:18
|
1714 |
-
else:
|
1715 |
-
self.split_layers = [6]
|
1716 |
-
elif config.num_stages==3:
|
1717 |
-
self.split_layers = [9,20] ## 11, 11, 10
|
1718 |
-
elif config.num_stages==4:
|
1719 |
-
self.split_layers = [7,15,23] #7,9,9,7
|
1720 |
-
elif config.num_stages==16:
|
1721 |
-
self.split_layers = [1,3,5,7,9,11,13,15,17,19,21,23,25,27,29]
|
1722 |
-
logger.info(f' check the split layers (moe): {self.split_layers}')
|
1723 |
-
|
1724 |
-
self.scale_emb = 1
|
1725 |
-
|
1726 |
-
# Reparameterization <|_1_|>
|
1727 |
-
if config.wesar_weights :
|
1728 |
-
logger.info(f'config.wesar_weights {config.wesar_weights}')
|
1729 |
-
self.norm_alpha = nn.Parameter(torch.tensor(1).float())
|
1730 |
-
self.scale_emb = 10
|
1731 |
-
else:
|
1732 |
-
self.norm_alpha = 1
|
1733 |
-
|
1734 |
def get_input_embeddings(self):
|
1735 |
return self.embed_tokens
|
1736 |
|
@@ -1769,7 +1019,6 @@ class MotifModel(MotifPreTrainedModel):
|
|
1769 |
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
1770 |
use_cache = False
|
1771 |
|
1772 |
-
# kept for BC (non `Cache` `past_key_values` inputs)
|
1773 |
return_legacy_cache = False
|
1774 |
if use_cache and not isinstance(past_key_values, Cache):
|
1775 |
return_legacy_cache = True
|
@@ -1783,26 +1032,23 @@ class MotifModel(MotifPreTrainedModel):
|
|
1783 |
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)")
|
1784 |
|
1785 |
if inputs_embeds is None:
|
1786 |
-
inputs_embeds = self.embed_tokens(input_ids)
|
1787 |
|
1788 |
if cache_position is None:
|
1789 |
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
1790 |
cache_position = torch.arange(past_seen_tokens,
|
1791 |
past_seen_tokens + inputs_embeds.shape[1],
|
1792 |
device=inputs_embeds.device)
|
1793 |
-
position_ids = None
|
1794 |
if position_ids is None:
|
1795 |
position_ids = cache_position.unsqueeze(0)
|
1796 |
-
|
1797 |
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values,
|
1798 |
output_attentions)
|
1799 |
|
1800 |
hidden_states = inputs_embeds
|
1801 |
bsz, q_len, _ = hidden_states.size()
|
1802 |
-
# create position embeddings to be shared across the decoder layers
|
1803 |
position_embeddings = self.rotary_emb(hidden_states, seq_len=q_len)
|
1804 |
|
1805 |
-
# decoder layers
|
1806 |
all_hidden_states = () if output_hidden_states else None
|
1807 |
all_self_attns = () if output_attentions else None
|
1808 |
next_decoder_cache = None
|
@@ -1837,20 +1083,14 @@ class MotifModel(MotifPreTrainedModel):
|
|
1837 |
|
1838 |
hidden_states = layer_outputs[0]
|
1839 |
|
1840 |
-
|
1841 |
-
if self.use_pipeline and idx in self.split_layers:
|
1842 |
-
hidden_states = torch.moreh.pipeline_assign(hidden_states)
|
1843 |
-
|
1844 |
if use_cache:
|
1845 |
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
1846 |
|
1847 |
if output_attentions:
|
1848 |
all_self_attns += (layer_outputs[1], )
|
1849 |
|
1850 |
-
|
1851 |
-
|
1852 |
-
|
1853 |
-
# add hidden states from the last decoder layer
|
1854 |
if output_hidden_states:
|
1855 |
all_hidden_states += (hidden_states, )
|
1856 |
|
@@ -1881,8 +1121,6 @@ class MotifModel(MotifPreTrainedModel):
|
|
1881 |
output_attentions: bool,
|
1882 |
):
|
1883 |
if self.config._attn_implementation == "flash_attention_2":
|
1884 |
-
if MorehFlashAttention is not None:
|
1885 |
-
return attention_mask
|
1886 |
if attention_mask is not None and 0.0 in attention_mask:
|
1887 |
return attention_mask
|
1888 |
return None
|
@@ -1909,6 +1147,7 @@ class MotifModel(MotifPreTrainedModel):
|
|
1909 |
dtype, device = input_tensor.dtype, input_tensor.device
|
1910 |
min_dtype = torch.finfo(dtype).min
|
1911 |
sequence_length = input_tensor.shape[1]
|
|
|
1912 |
# SlidingWindowCache or StaticCache
|
1913 |
if using_sliding_window_cache or using_static_cache:
|
1914 |
target_length = past_key_values.get_max_cache_shape()
|
@@ -2003,7 +1242,6 @@ class MotifModel(MotifPreTrainedModel):
|
|
2003 |
return causal_mask
|
2004 |
|
2005 |
|
2006 |
-
# @log_timing
|
2007 |
class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
|
2008 |
_tied_weights_keys = ["lm_head.weight"]
|
2009 |
|
@@ -2011,35 +1249,14 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
|
|
2011 |
super().__init__(config)
|
2012 |
self.model = MotifModel(config)
|
2013 |
self.vocab_size = config.vocab_size
|
2014 |
-
self.multi_token_heads = config.multi_token_heads
|
2015 |
|
2016 |
-
|
2017 |
-
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
2018 |
-
else:
|
2019 |
-
self.tokenwise_last_layers = nn.ModuleList(
|
2020 |
-
[MotifDecoderLayer(config, config.num_hidden_layers - 1) for _ in range(self.multi_token_heads)])
|
2021 |
-
self.tokenwise_lm_heads = nn.ModuleList(
|
2022 |
-
[nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(self.multi_token_heads)])
|
2023 |
-
self.should_skip_separate_backward_pass = self.multi_token_heads is not None
|
2024 |
|
2025 |
# Initialize weights and apply final processing
|
2026 |
self.post_init()
|
2027 |
|
2028 |
-
# <|_3_|>
|
2029 |
-
if config.muP:
|
2030 |
-
self.lm_head.__do_scale_tager_mu_dim_base_model__=True
|
2031 |
-
|
2032 |
-
# <|_4_|>
|
2033 |
-
self.lm_head_alpha = 1
|
2034 |
-
if config.wesar_weights:
|
2035 |
-
self.lm_head_alpha = nn.Parameter(torch.tensor(1).float())
|
2036 |
-
|
2037 |
if getattr(config, "tie_word_embeddings", True):
|
2038 |
-
logger.info('tie embeddings')
|
2039 |
self.tie_weights()
|
2040 |
-
else:
|
2041 |
-
# <|_5_|>
|
2042 |
-
self.lm_head.__do_scale_tager_mu_dim_base_model__ = False
|
2043 |
|
2044 |
def get_input_embeddings(self):
|
2045 |
return self.model.embed_tokens
|
@@ -2059,101 +1276,7 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
|
|
2059 |
def get_decoder(self):
|
2060 |
return self.model
|
2061 |
|
2062 |
-
|
2063 |
-
hidden_states: torch.FloatTensor,
|
2064 |
-
outputs: MotifModelOutputWithPast,
|
2065 |
-
labels: torch.LongTensor,
|
2066 |
-
position_ids: Optional[torch.LongTensor],
|
2067 |
-
output_attentions: Optional[bool],
|
2068 |
-
use_cache: Optional[bool],
|
2069 |
-
cache_position: Optional[torch.LongTensor],
|
2070 |
-
return_dict: Optional[bool],
|
2071 |
-
num_logits_to_keep: int = 0) -> CausalLMOutputWithPast:
|
2072 |
-
"""
|
2073 |
-
This implements the main forward-backward procedure for multi-token model training proposed in
|
2074 |
-
the paper https://arxiv.org/abs/2404.19737.
|
2075 |
-
Essentially,
|
2076 |
-
- The multi-token model tries to predict n (instead of 1) tokens at a time.
|
2077 |
-
- Applying this only during training and using first-token prediction during inference is still helpful.
|
2078 |
-
- The change in architecture: when using n-token prediction, each token index (between 1 and n) has its own
|
2079 |
-
(1) last attention layer and (2) lm head.
|
2080 |
-
- The change in loss: sum of cross-entropy losses corresponding to each token index.
|
2081 |
-
- Custom forward-backward procedure for memory efficiency: refer to the implementation of `multi_head_forward_backward`.
|
2082 |
-
"""
|
2083 |
-
if not return_dict:
|
2084 |
-
raise NotImplementedError("return_dict must be True for multi-token training")
|
2085 |
-
|
2086 |
-
past_key_values = outputs.past_key_values
|
2087 |
-
causal_mask = outputs.causal_mask
|
2088 |
-
position_embeddings = outputs.position_embeddings
|
2089 |
-
|
2090 |
-
if labels is not None:
|
2091 |
-
labels = labels.to(hidden_states.device)
|
2092 |
-
|
2093 |
-
def _tokenwise_forward(hidden_states: torch.Tensor, token_idx):
|
2094 |
-
## Model forward
|
2095 |
-
layer = self.tokenwise_last_layers[token_idx]
|
2096 |
-
lm_head = self.tokenwise_lm_heads[token_idx]
|
2097 |
-
|
2098 |
-
layer_outputs = layer(
|
2099 |
-
hidden_states,
|
2100 |
-
attention_mask=causal_mask,
|
2101 |
-
position_ids=position_ids,
|
2102 |
-
past_key_values=past_key_values, # TODO: update past_key_values?
|
2103 |
-
output_attentions=output_attentions,
|
2104 |
-
use_cache=use_cache,
|
2105 |
-
cache_position=cache_position,
|
2106 |
-
position_embeddings=position_embeddings,
|
2107 |
-
)
|
2108 |
-
last_hidden_states = layer_outputs[0]
|
2109 |
-
if num_logits_to_keep > 0:
|
2110 |
-
assert labels is None
|
2111 |
-
last_hidden_states = last_hidden_states[:, -num_logits_to_keep:, :]
|
2112 |
-
tokenwise_logits = lm_head(last_hidden_states)
|
2113 |
-
|
2114 |
-
if labels is None:
|
2115 |
-
return {
|
2116 |
-
"loss": None,
|
2117 |
-
"logits": tokenwise_logits,
|
2118 |
-
}
|
2119 |
-
|
2120 |
-
## Compute loss
|
2121 |
-
shift_n = token_idx + 1
|
2122 |
-
shift_logits = tokenwise_logits[..., :-shift_n, :].contiguous()
|
2123 |
-
shift_labels = labels[..., shift_n:].contiguous()
|
2124 |
-
|
2125 |
-
loss_fct = CrossEntropyLoss()
|
2126 |
-
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
2127 |
-
shift_labels = shift_labels.view(-1)
|
2128 |
-
|
2129 |
-
tokenwise_loss = loss_fct(shift_logits, shift_labels)
|
2130 |
-
|
2131 |
-
return {
|
2132 |
-
"loss": tokenwise_loss,
|
2133 |
-
"logits": tokenwise_logits,
|
2134 |
-
}
|
2135 |
-
|
2136 |
-
head_fns = [
|
2137 |
-
lambda hidden_states, token_idx=token_idx: _tokenwise_forward(hidden_states, token_idx)
|
2138 |
-
for token_idx in range(self.multi_token_heads)
|
2139 |
-
]
|
2140 |
-
loss, logits = multi_head_forward_backward(hidden_states,
|
2141 |
-
head_fns,
|
2142 |
-
return_keys=("loss", "logits"),
|
2143 |
-
return_only_first_head=True)
|
2144 |
-
|
2145 |
-
if not return_dict:
|
2146 |
-
output = (logits, ) + outputs[1:]
|
2147 |
-
return (loss, ) + output
|
2148 |
-
|
2149 |
-
return CausalLMOutputWithPast(
|
2150 |
-
loss=loss,
|
2151 |
-
logits=logits,
|
2152 |
-
past_key_values=outputs.past_key_values,
|
2153 |
-
hidden_states=outputs.hidden_states,
|
2154 |
-
attentions=outputs.attentions,
|
2155 |
-
)
|
2156 |
-
|
2157 |
@add_start_docstrings_to_model_forward(MOTIF_INPUTS_DOCSTRING)
|
2158 |
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
2159 |
def forward(
|
@@ -2209,8 +1332,6 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
|
|
2209 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
2210 |
|
2211 |
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
2212 |
-
outputs_include_causal_mask = self.multi_token_heads is not None
|
2213 |
-
outputs_include_position_embeddings = self.multi_token_heads is not None
|
2214 |
outputs: MotifModelOutputWithPast = self.model(
|
2215 |
input_ids=input_ids,
|
2216 |
attention_mask=attention_mask,
|
@@ -2222,25 +1343,12 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
|
|
2222 |
output_hidden_states=output_hidden_states,
|
2223 |
return_dict=return_dict,
|
2224 |
cache_position=cache_position,
|
2225 |
-
outputs_include_causal_mask=outputs_include_causal_mask,
|
2226 |
-
outputs_include_position_embeddings=outputs_include_position_embeddings,
|
2227 |
)
|
2228 |
|
2229 |
hidden_states = outputs[0]
|
2230 |
|
2231 |
-
if self.multi_token_heads is not None:
|
2232 |
-
return self.multi_token_forward_backward(hidden_states,
|
2233 |
-
outputs,
|
2234 |
-
labels,
|
2235 |
-
position_ids,
|
2236 |
-
output_attentions,
|
2237 |
-
use_cache,
|
2238 |
-
cache_position,
|
2239 |
-
return_dict,
|
2240 |
-
num_logits_to_keep=num_logits_to_keep)
|
2241 |
-
|
2242 |
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
2243 |
-
hidden_states = hidden_states
|
2244 |
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
2245 |
logits = logits.float()
|
2246 |
|
@@ -2254,7 +1362,6 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
|
|
2254 |
loss_fct = CrossEntropyLoss()
|
2255 |
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
2256 |
shift_labels = shift_labels.view(-1)
|
2257 |
-
# Enable model parallelism
|
2258 |
shift_labels = shift_labels.to(shift_logits.device)
|
2259 |
loss = loss_fct(shift_logits, shift_labels)
|
2260 |
|
@@ -2268,4 +1375,4 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
|
|
2268 |
past_key_values=outputs.past_key_values,
|
2269 |
hidden_states=outputs.hidden_states,
|
2270 |
attentions=outputs.attentions,
|
2271 |
-
)
|
|
|
1 |
import math
|
2 |
+
from dataclasses import dataclass
|
3 |
from typing import List, Optional, Tuple, Union
|
4 |
|
5 |
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
import torch.utils.checkpoint
|
8 |
from torch import nn
|
9 |
from torch.nn import CrossEntropyLoss
|
10 |
+
from transformers.activations import ACT2CLS as _ACT2CLS
|
11 |
+
from transformers.activations import ClassInstantier
|
12 |
from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
13 |
from transformers.generation import GenerationMixin
|
14 |
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
15 |
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
16 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
|
|
|
|
|
|
|
17 |
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
18 |
from transformers.modeling_utils import PreTrainedModel
|
19 |
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
20 |
+
from transformers.utils import (add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available,
|
21 |
+
is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings)
|
|
|
|
|
|
22 |
|
23 |
+
from .configuration_motif import MotifConfig
|
24 |
|
25 |
|
26 |
class PolyNorm(torch.nn.Module):
|
27 |
+
"""
|
28 |
A trainable activation function introduced in https://arxiv.org/html/2411.03884v1.
|
29 |
+
The code is copied from https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md
|
|
|
30 |
"""
|
31 |
|
32 |
def __init__(self, eps=1e-6):
|
|
|
43 |
x ** 2) + self.weight[2] * self._norm(x) + self.bias
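Only the tail of the `PolyNorm` forward is visible in this hunk; per the linked PolyCom reference, the activation is a learned combination of RMS-normalized powers of the input plus a bias. A standalone sketch reconstructed under that reading (an approximation for illustration, not a copy of the class above):

```python
import torch

class PolyNormSketch(torch.nn.Module):
    """y = w0 * norm(x**3) + w1 * norm(x**2) + w2 * norm(x) + b, with an RMS-style norm."""

    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3) / 3)
        self.bias = torch.nn.Parameter(torch.zeros(1))
        self.eps = eps

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (self.weight[0] * self._norm(x ** 3)
                + self.weight[1] * self._norm(x ** 2)
                + self.weight[2] * self._norm(x)
                + self.bias)

out = PolyNormSketch()(torch.randn(2, 8))
```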
|
44 |
|
45 |
|
46 |
+
CUSTOM_ACT2CLS = {"poly_norm": PolyNorm}
|
|
|
|
|
47 |
ACT2CLS = {**_ACT2CLS, **CUSTOM_ACT2CLS}
|
48 |
ACT2FN = ClassInstantier(ACT2CLS)
|
49 |
|
50 |
+
logger = logging.get_logger(__name__)
|
51 |
+
|
52 |
+
if is_flash_attn_2_available():
|
53 |
+
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
54 |
+
|
55 |
+
_CONFIG_FOR_DOC = "MotifConfig"
|
56 |
|
57 |
|
58 |
class MotifRMSNorm(nn.Module):
|
|
|
76 |
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
77 |
|
78 |
|
79 |
+
ALL_LAYERNORM_LAYERS.append(MotifRMSNorm)
|
80 |
|
81 |
|
82 |
class MotifRotaryEmbeddingWithCache(nn.Module):
|
|
|
108 |
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
109 |
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
110 |
|
|
|
111 |
self._set_cos_sin_cache(seq_len=max_position_embeddings,
|
112 |
device=self.inv_freq.device,
|
113 |
dtype=torch.get_default_dtype())
|
|
|
128 |
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
129 |
|
130 |
return (
|
131 |
+
self.cos_cached[ :seq_len].to(dtype=x.dtype),
|
132 |
+
self.sin_cached[ :seq_len].to(dtype=x.dtype),
|
133 |
)
|
134 |
|
135 |
|
|
|
136 |
class MotifRotaryEmbedding(nn.Module):
|
137 |
|
138 |
def __init__(
|
|
|
163 |
self.max_seq_len_cached = max_position_embeddings
|
164 |
self.original_max_seq_len = max_position_embeddings
|
165 |
else:
|
|
|
166 |
if config.rope_scaling is not None:
|
167 |
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
168 |
else:
|
|
|
224 |
def rotate_half(x):
|
225 |
"""
|
226 |
Rotates half of the dimensions of the input tensor using torch.roll and in-place negation.
|
227 |
+
|
228 |
Args:
|
229 |
x (torch.Tensor): The input tensor.
|
230 |
+
|
231 |
Returns:
|
232 |
torch.Tensor: A tensor where the latter half of the dimensions are negated
|
233 |
and moved before the first half.
|
|
|
239 |
return rotated_tensor
|
240 |
|
241 |
|
242 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
|
|
243 |
"""
|
244 |
Applies rotary position embeddings to the input tensors.
|
245 |
|
|
|
248 |
k (torch.Tensor): Key tensor of shape (B, NH, S, D_KV).
|
249 |
cos (torch.Tensor): Cosine values for rotary embedding.
|
250 |
sin (torch.Tensor): Sine values for rotary embedding.
|
251 |
+
unsqueeze_dim (int, optional): Dimension along which `cos` and `sin` are unsqueezed.
|
252 |
Defaults to 1.
|
|
|
|
|
|
|
253 |
|
254 |
Returns:
|
255 |
Tuple[torch.Tensor, torch.Tensor]: Returns transformed query and key tensors after applying rotary embeddings.
|
256 |
"""
|
257 |
'''
|
258 |
+
# (B, NH, S, D_KV) -> (B, S, NH, D_KV)
|
259 |
cos = cos.unsqueeze(unsqueeze_dim)
|
260 |
sin = sin.unsqueeze(unsqueeze_dim)
|
261 |
q_embed = (q * cos) + (rotate_half(q) * sin)
|
262 |
k_embed = (k * cos) + (rotate_half(k) * sin)
|
263 |
'''
|
264 |
+
device = q.device
|
265 |
+
return map(
|
266 |
+
lambda x: (x * cos[position_ids].unsqueeze(unsqueeze_dim).to(device)) +
|
267 |
+
(rotate_half(x) * sin[position_ids].unsqueeze(unsqueeze_dim).to(device)), (q, k))
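As the `map` expression above shows, rotary embeddings are applied by gathering per-position cos/sin rows and mixing them with a half-rotation of the tensor. A small numeric sketch of the same recipe, self-contained with its own `rotate_half` and a toy cos/sin cache (sizes and the concatenation-based rotation are illustrative assumptions, not taken from the module above):

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

bsz, heads, seq, dim = 1, 2, 4, 8
q = torch.randn(bsz, heads, seq, dim)
position_ids = torch.arange(seq).unsqueeze(0)  # (bsz, seq)

# toy cache of 32 positions worth of cos/sin rows
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
angles = torch.outer(torch.arange(32).float(), inv_freq)
cos = torch.cat((angles.cos(), angles.cos()), dim=-1)  # (32, dim)
sin = torch.cat((angles.sin(), angles.sin()), dim=-1)

# gather per-position rows, broadcast over heads, and rotate
q_embed = q * cos[position_ids].unsqueeze(1) + rotate_half(q) * sin[position_ids].unsqueeze(1)
```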
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
268 |
|
|
|
|
|
|
|
269 |
|
|
|
|
|
|
|
|
|
270 |
class MotifMLP(nn.Module):
|
271 |
|
272 |
def __init__(self, config):
|
273 |
super().__init__()
|
274 |
self.hidden_size = config.hidden_size
|
275 |
self.intermediate_size = config.intermediate_size
|
276 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias)
|
277 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias)
|
278 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
|
279 |
self.act_fn = ACT2FN[config.hidden_act]
|
280 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
281 |
def forward(self, hidden_state):
|
282 |
+
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
|
|
|
|
|
|
|
|
283 |
|
284 |
|
285 |
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
286 |
+
|
287 |
+
|
288 |
"""
|
289 |
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
290 |
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
291 |
+
|
292 |
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
293 |
if n_rep == 1:
|
294 |
return hidden_states
|
|
|
297 |
"""
|
298 |
|
299 |
return torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)
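A quick shape check of the key/value expansion described in the docstring above, with illustrative sizes:

```python
import torch

kv = torch.randn(2, 4, 16, 64)  # (batch, num_key_value_heads, seq, head_dim)
expanded = torch.repeat_interleave(kv, dim=1, repeats=8)
print(expanded.shape)  # torch.Size([2, 32, 16, 64]) -> num_attention_heads = 32
```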
|
300 |
+
|
301 |
|
|
|
|
|
302 |
class MotifAttention(nn.Module):
|
303 |
"""
|
304 |
Differential Attention (DiffAttention) module.
|
305 |
|
306 |
+
Implements the Differential Attention from
|
307 |
"DIFFERENTIAL TRANSFORMER" (https://arxiv.org/pdf/2410.05258).
|
308 |
|
309 |
Overview
|
310 |
Standard transformers often over-allocate attention to irrelevant context.
|
311 |
+
DiffAttention addresses this by computing attention as the difference between
|
312 |
+
two separate softmax attention maps, effectively canceling noise and promoting
|
313 |
sparse, structured attention patterns.
|
314 |
|
315 |
Reference Implementation
|
316 |
https://github.com/microsoft/unilm/tree/master/Diff-Transformer
|
317 |
|
318 |
Args
|
319 |
+
The differential attention mechanism computes attention as the difference of two softmax attention scores, weighted by a learnable scalar λ.
|
320 |
λ is re-parameterized as λ = exp(λ_q1 · λ_k1) − exp(λ_q2 · λ_k2) + λ_init.
|
321 |
- lambda_q1, lambda_q2 (nn.Parameter): Learnable vectors used to compute the first and second components of λ for query transformations.
|
322 |
- lambda_k1, lambda_k2 (nn.Parameter): Learnable vectors used to compute the first and second components of λ for key transformations.
|
323 |
- lambda_init (float): A constant used for initializing λ, typically set as λ_init = 0.8 − 0.6 × exp(−0.3 × (layer_index − 1)).
|
324 |
+
|
325 |
"""
|
326 |
|
327 |
def __init__(self, config: MotifConfig, layer_idx: Optional[int] = None):
|
|
|
344 |
self.rope_theta = config.rope_theta
|
345 |
self.is_causal = True
|
346 |
self.attention_dropout = config.attention_dropout
|
347 |
+
|
|
|
|
|
|
|
|
|
348 |
if (self.head_dim * self.num_heads) != self.hidden_size:
|
349 |
raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
350 |
f" and `num_heads`: {self.num_heads}).")
|
|
|
353 |
self.num_key_value_heads //= 2
|
354 |
self.n_rep = self.num_heads // self.num_key_value_heads
|
355 |
|
|
|
|
|
|
356 |
self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
|
357 |
self.k_proj = nn.Linear(self.hidden_size, self.hidden_size // self.n_rep, bias=False)
|
358 |
self.v_proj = nn.Linear(self.hidden_size, self.hidden_size // self.n_rep, bias=False)
|
359 |
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
|
360 |
|
|
|
361 |
for name in ["lambda_q1", "lambda_k1", "lambda_q2", "lambda_k2"]:
|
362 |
setattr(self, name, nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32)))
|
363 |
getattr(self, name).data.normal_(mean=0.0, std=0.1)
|
364 |
|
|
|
365 |
self.subln = MotifRMSNorm(2 * self.head_dim, eps=1e-5)
|
366 |
self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * (layer_idx - 1))
|
367 |
|
368 |
+
self.rotary_emb = MotifRotaryEmbeddingWithCache(self.head_dim,
|
369 |
max_position_embeddings=self.max_position_embeddings,
|
370 |
base=self.rope_theta)
|
371 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
372 |
def forward(
|
373 |
self,
|
374 |
hidden_states: torch.Tensor,
|
|
|
378 |
output_attentions: bool = False,
|
379 |
use_cache: bool = False,
|
380 |
cache_position: Optional[torch.LongTensor] = None,
|
381 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
382 |
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
383 |
bsz, q_len, _ = hidden_states.size()
|
384 |
|
385 |
+
query_states = self.q_proj(hidden_states)
|
386 |
+
key_states = self.k_proj(hidden_states)
|
387 |
+
value_states = self.v_proj(hidden_states)
|
|
|
|
|
388 |
|
389 |
query_states = query_states.view(bsz, q_len, 2 * self.num_heads, self.head_dim).transpose(1, 2)
|
390 |
key_states = key_states.view(bsz, q_len, 2 * self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
|
|
406 |
key_states,
|
407 |
cos,
|
408 |
sin,
|
409 |
+
position_ids=position_ids)
|
410 |
|
411 |
if past_key_value is not None:
|
412 |
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
413 |
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
414 |
|
|
|
415 |
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
416 |
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
417 |
|
|
|
418 |
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
419 |
|
420 |
kv_seq_len = key_states.shape[-2]
|
|
|
423 |
attention_mask = torch.triu(
|
424 |
torch.full((q_len, kv_seq_len), float("-inf"), dtype=attn_weights.dtype, device=attn_weights.device),
|
425 |
1 + offset)
|
426 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
427 |
attn_weights = attn_weights + attention_mask
|
428 |
|
|
|
429 |
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
430 |
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
431 |
|
|
|
432 |
lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(attn_weights)
|
433 |
lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(attn_weights)
|
434 |
lambda_full = lambda_1 - lambda_2 + self.lambda_init
|
435 |
attn_weights = attn_weights.view(bsz, self.num_heads, 2, q_len, -1)
|
436 |
attn_weights = attn_weights[:, :, 0] - lambda_full * attn_weights[:, :, 1]
|
|
|
|
|
|
|
|
|
|
|
437 |
|
|
|
438 |
attn_output = torch.matmul(attn_weights, value_states)
|
439 |
|
|
|
440 |
attn_output = self.subln(attn_output)
|
441 |
attn_output = attn_output * (1 - self.lambda_init)
|
442 |
|
443 |
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim * 2):
|
444 |
raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
445 |
f" {attn_output.size()}")
|
446 |
+
|
|
|
|
|
447 |
attn_output = attn_output.transpose(1, 2).contiguous()
|
448 |
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
449 |
|
450 |
+
attn_output = self.o_proj(attn_output)
|
451 |
|
452 |
if not output_attentions:
|
453 |
attn_weights = None
|
|
|
455 |
return attn_output, attn_weights, past_key_value
|
456 |
|
457 |
|
|
|
458 |
class MotifFlashAttention2(MotifAttention):
|
459 |
"""
|
460 |
Motif flash attention module, following Motif attention module. This module inherits from `MotifAttention`
|
|
|
464 |
config.max_window_layers layers.
|
465 |
"""
|
466 |
|
|
|
467 |
def __init__(self, *args, **kwargs):
|
468 |
super().__init__(*args, **kwargs)
|
|
|
|
|
|
|
469 |
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
470 |
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
471 |
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
472 |
|
473 |
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
474 |
|
475 |
+
logger.info(f'flash attention is used {not self._flash_attn_uses_top_left_mask}')
|
476 |
+
|
477 |
def _reshape_heads(self, tensor, batch_size, seq_len):
|
478 |
"""2-way head split tensor reshape"""
|
479 |
return tensor.reshape(batch_size, seq_len, self.num_heads, 2, self.head_dim)
|
|
|
483 |
return tensor.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
|
484 |
|
485 |
def _compute_attention(self, query_states, key_states, value_states, attention_mask, q_len, position_ids,
|
486 |
+
dropout_rate, sliding_window):
|
487 |
"""Flash Attention 2 implements"""
|
488 |
+
_input_type = query_states.dtype
|
489 |
+
scale_factor = 1.0 / math.sqrt(self.head_dim)
|
490 |
+
if not self._flash_attn_uses_top_left_mask:
|
491 |
+
causal = self.is_causal
|
|
|
|
|
|
|
492 |
else:
|
493 |
+
causal = self.is_causal and q_len != 1
|
494 |
+
|
495 |
+
attn_out = _flash_attention_forward(query_states.bfloat16(),
|
496 |
+
key_states.bfloat16(),
|
497 |
+
value_states.bfloat16(),
|
498 |
attention_mask,
|
499 |
q_len,
|
500 |
position_ids=position_ids,
|
501 |
dropout=dropout_rate,
|
502 |
sliding_window=sliding_window,
|
503 |
+
is_causal=True,
|
504 |
+
softmax_scale=scale_factor,
|
505 |
use_top_left_mask=self._flash_attn_uses_top_left_mask)
|
506 |
+
return attn_out.to(_input_type)
|
507 |
|
508 |
def forward(
|
509 |
self,
|
|
|
518 |
):
|
519 |
bsz, q_len, _ = hidden_states.size()
|
520 |
|
521 |
+
query_states = self.q_proj(hidden_states)
|
522 |
+
key_states = self.k_proj(hidden_states)
|
523 |
+
value_states = self.v_proj(hidden_states)
|
524 |
|
525 |
query_states = query_states.view(bsz, q_len, 2 * self.num_heads, self.head_dim).transpose(1, 2)
|
526 |
key_states = key_states.view(bsz, q_len, 2 * self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
|
|
541 |
key_states,
|
542 |
cos,
|
543 |
sin,
|
544 |
+
position_ids=position_ids)
|
545 |
|
546 |
if past_key_value is not None:
|
547 |
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
548 |
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
549 |
|
|
|
550 |
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
551 |
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
552 |
dropout_rate = 0.0 if not self.training else self.attention_dropout
|
|
|
555 |
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
556 |
# cast them back in float16 just to be sure everything works as expected.
|
557 |
input_dtype = query_states.dtype
|
558 |
+
if input_dtype == torch.float32:
|
559 |
if torch.is_autocast_enabled():
|
560 |
target_dtype = torch.get_autocast_gpu_dtype()
|
561 |
# Handle the case where the model is quantized
|
|
|
582 |
value_states = value_states.transpose(1, 2)
|
583 |
|
584 |
if (self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None
|
585 |
+
and self.layer_idx >= self.config.max_window_layers):
|
586 |
sliding_window = self.config.sliding_window
|
587 |
else:
|
588 |
sliding_window = None
|
|
|
602 |
k1, k2 = k1.contiguous(), k2.contiguous()
|
603 |
v1, v2 = v1.contiguous(), v2.contiguous()
|
604 |
|
605 |
+
attn11, attn12 = self._compute_attention(q1, k1, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window), \
|
606 |
+
self._compute_attention(q1, k1, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window)
|
607 |
+
attn21, attn22 = self._compute_attention(q2, k2, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window), \
|
608 |
+
self._compute_attention(q2, k2, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window)
|
|
|
|
|
609 |
|
610 |
attn1, attn2 = torch.cat([attn11, attn12], dim=-1), torch.cat([attn21, attn22], dim=-1)
|
611 |
|
|
|
623 |
attn_output = attn_output * (1 - self.lambda_init)
|
624 |
|
625 |
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim * 2):
|
626 |
+
raise ValueError(f"`attn_output` should be of size {(bsz, q_len, self.num_heads, 2*self.head_dim)}, but is"
|
627 |
f" {attn_output.size()}")
|
628 |
|
629 |
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
630 |
+
attn_output = self.o_proj(attn_output)
|
631 |
|
632 |
return attn_output, None, past_key_value
|
633 |
|
634 |
|
|
|
635 |
class MotifSdpaAttention(MotifAttention):
|
636 |
"""
|
637 |
Motif attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
|
|
|
639 |
SDPA API.
|
640 |
"""
|
641 |
|
|
|
642 |
def forward(
|
643 |
self,
|
644 |
hidden_states: torch.Tensor,
|
|
|
651 |
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
652 |
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
653 |
if output_attentions:
|
|
|
654 |
logger.warning_once(
|
655 |
"MotifModel is using MotifSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
656 |
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
|
|
686 |
query_states, key_states = apply_rotary_pos_emb(query_states,
|
687 |
key_states,
|
688 |
cos,
|
689 |
+
sin)
|
|
|
690 |
|
691 |
if past_key_value is not None:
|
692 |
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
|
|
|
722 |
MOTIF_ATTENTION_CLASSES = {
|
723 |
"eager": MotifAttention,
|
724 |
"flash_attention_2": MotifFlashAttention2,
|
725 |
+
"sdpa": MotifAttention,
|
726 |
}
|
727 |
|
728 |
|
|
|
729 |
class MotifDecoderLayer(nn.Module):
|
730 |
|
731 |
+
def __init__(self, config: MotifConfig, layer_idx: int):
|
732 |
super().__init__()
|
733 |
self.hidden_size = config.hidden_size
|
|
|
|
|
734 |
if config.sliding_window and config._attn_implementation != "flash_attention_2":
|
735 |
logger.warning_once(
|
736 |
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
|
737 |
"unexpected results may be encountered.")
|
738 |
+
self.self_attn = MOTIF_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
|
|
|
|
|
|
|
739 |
self.mlp = MotifMLP(config)
|
740 |
+
|
741 |
+
self.input_layernorm = MotifRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
742 |
+
self.post_attention_layernorm = MotifRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
|
|
|
744 |
|
745 |
def forward(
|
746 |
self,
|
|
|
778 |
|
779 |
residual = hidden_states
|
780 |
|
781 |
+
hidden_states = self.input_layernorm(hidden_states)
|
782 |
|
783 |
# Self Attention
|
784 |
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
|
|
795 |
|
796 |
# Fully Connected
|
797 |
residual = hidden_states
|
798 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
799 |
+
hidden_states = self.mlp(hidden_states)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
800 |
hidden_states = residual + hidden_states
|
801 |
|
802 |
outputs = (hidden_states, )
|
|
|
846 |
def _init_weights(self, module):
|
847 |
module_std = self.config.initializer_range
|
848 |
if isinstance(module, nn.Linear):
|
849 |
+
module.weight.data.normal_(mean=0.0, std=module_std)
|
850 |
+
module.weight.data = torch.where(abs(module.weight.data) > module_std*3, 0, module.weight.data)
|
|
|
|
|
851 |
if module.bias is not None:
|
852 |
module.bias.data.zero_()
|
853 |
|
854 |
elif isinstance(module, nn.Embedding):
|
855 |
+
module.weight.data.normal_(mean=0.0, std=module_std)
|
856 |
+
module.weight.data = torch.where(abs(module.weight.data) > module_std*3, 0, module.weight.data)
|
857 |
if module.padding_idx is not None:
|
858 |
module.weight.data[module.padding_idx].zero_()
|
859 |
|
860 |
|
861 |
@dataclass
|
862 |
class MotifModelOutputWithPast(ModelOutput):
|
863 |
+
"""
|
864 |
+
This augments `BaseModelOutputWithPast` in `transformers.modeling_outputs` with new optional keys: `causal_mask`, `position_embeddings`.
|
865 |
The optional keys are currently used in the following ways:
|
866 |
+
- pass information to the token-wise last attention layers in multi-token training
|
867 |
"""
|
868 |
last_hidden_state: torch.FloatTensor = None
|
869 |
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
|
|
948 |
"""
|
949 |
|
950 |
|
|
|
951 |
@add_start_docstrings(
|
952 |
"The bare Motif Model outputting raw hidden-states without any specific head on top.",
|
953 |
MOTIF_START_DOCSTRING,
|
|
|
964 |
super().__init__(config)
|
965 |
self.padding_idx = config.pad_token_id
|
966 |
self.vocab_size = config.vocab_size
|
|
|
967 |
|
968 |
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
969 |
+
num_hidden_layers = config.num_hidden_layers
|
970 |
+
self.layers = nn.ModuleList([MotifDecoderLayer(config = config, layer_idx=layer_idx) for layer_idx in range(num_hidden_layers)])
|
971 |
+
self.norm = MotifRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
972 |
self.hidden_size = config.hidden_size
|
973 |
self.num_heads = config.num_attention_heads
|
974 |
self.head_dim = self.hidden_size // self.num_heads
|
|
|
981 |
self.gradient_checkpointing = False
|
982 |
self.post_init()
|
983 |
|
|
|
|
|
984 |
def get_input_embeddings(self):
|
985 |
return self.embed_tokens
|
986 |
|
|
|
1019 |
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
1020 |
use_cache = False
|
1021 |
|
|
|
1022 |
return_legacy_cache = False
|
1023 |
if use_cache and not isinstance(past_key_values, Cache):
|
1024 |
return_legacy_cache = True
|
|
|
1032 |
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)")
|
1033 |
|
1034 |
if inputs_embeds is None:
|
1035 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
1036 |
|
1037 |
if cache_position is None:
|
1038 |
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
1039 |
cache_position = torch.arange(past_seen_tokens,
|
1040 |
past_seen_tokens + inputs_embeds.shape[1],
|
1041 |
device=inputs_embeds.device)
|
|
|
1042 |
if position_ids is None:
|
1043 |
position_ids = cache_position.unsqueeze(0)
|
1044 |
+
|
1045 |
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values,
|
1046 |
output_attentions)
|
1047 |
|
1048 |
hidden_states = inputs_embeds
|
1049 |
bsz, q_len, _ = hidden_states.size()
|
|
|
1050 |
position_embeddings = self.rotary_emb(hidden_states, seq_len=q_len)
|
1051 |
|
|
|
1052 |
all_hidden_states = () if output_hidden_states else None
|
1053 |
all_self_attns = () if output_attentions else None
|
1054 |
next_decoder_cache = None
|
|
|
1083 |
|
1084 |
hidden_states = layer_outputs[0]
|
1085 |
|
|
|
|
|
|
|
|
|
1086 |
if use_cache:
|
1087 |
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
1088 |
|
1089 |
if output_attentions:
|
1090 |
all_self_attns += (layer_outputs[1], )
|
1091 |
|
1092 |
+
hidden_states = self.norm(hidden_states)
|
1093 |
+
|
|
|
|
|
1094 |
if output_hidden_states:
|
1095 |
all_hidden_states += (hidden_states, )
|
1096 |
|
|
|
1121 |
output_attentions: bool,
|
1122 |
):
|
1123 |
if self.config._attn_implementation == "flash_attention_2":
|
|
|
|
|
1124 |
if attention_mask is not None and 0.0 in attention_mask:
|
1125 |
return attention_mask
|
1126 |
return None
|
|
|
1147 |
dtype, device = input_tensor.dtype, input_tensor.device
|
1148 |
min_dtype = torch.finfo(dtype).min
|
1149 |
sequence_length = input_tensor.shape[1]
|
1150 |
+
|
1151 |
# SlidingWindowCache or StaticCache
|
1152 |
if using_sliding_window_cache or using_static_cache:
|
1153 |
target_length = past_key_values.get_max_cache_shape()
|
|
|
1242 |
return causal_mask
|
1243 |
|
1244 |
|
|
|
1245 |
class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
|
1246 |
_tied_weights_keys = ["lm_head.weight"]
|
1247 |
|
|
|
1249 |
super().__init__(config)
|
1250 |
self.model = MotifModel(config)
|
1251 |
self.vocab_size = config.vocab_size
|
|
|
1252 |
|
1253 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1254 |
|
1255 |
# Initialize weights and apply final processing
|
1256 |
self.post_init()
|
1257 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1258 |
if getattr(config, "tie_word_embeddings", True):
|
|
|
1259 |
self.tie_weights()
|
|
|
|
|
|
|
1260 |
|
1261 |
def get_input_embeddings(self):
|
1262 |
return self.model.embed_tokens
|
|
|
1276 |
def get_decoder(self):
|
1277 |
return self.model
|
1278 |
|
1279 |
+
|
|
|
|
|
|
1280 |
@add_start_docstrings_to_model_forward(MOTIF_INPUTS_DOCSTRING)
|
1281 |
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
1282 |
def forward(
|
|
|
1332 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1333 |
|
1334 |
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
|
|
|
1335 |
outputs: MotifModelOutputWithPast = self.model(
|
1336 |
input_ids=input_ids,
|
1337 |
attention_mask=attention_mask,
|
|
|
1343 |
output_hidden_states=output_hidden_states,
|
1344 |
return_dict=return_dict,
|
1345 |
cache_position=cache_position,
|
|
|
|
|
1346 |
)
|
1347 |
|
1348 |
hidden_states = outputs[0]
|
1349 |
|
|
|
|
1350 |
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
1351 |
+
hidden_states = hidden_states
|
1352 |
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
1353 |
logits = logits.float()
|
1354 |
|
|
|
1362 |
loss_fct = CrossEntropyLoss()
|
1363 |
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
1364 |
shift_labels = shift_labels.view(-1)
|
|
|
1365 |
shift_labels = shift_labels.to(shift_logits.device)
|
1366 |
loss = loss_fct(shift_logits, shift_labels)
|
1367 |
|
|
|
1375 |
past_key_values=outputs.past_key_values,
|
1376 |
hidden_states=outputs.hidden_states,
|
1377 |
attentions=outputs.attentions,
|
1378 |
+
)
|