zamagi committed
Commit 5d45a16 · verified · 1 Parent(s): 1b71544

Model save

README.md ADDED
@@ -0,0 +1,355 @@
1
+ ---
2
+ library_name: transformers
3
+ tags:
4
+ - axolotl
5
+ - generated_from_trainer
6
+ datasets:
7
+ - Aratako/Magpie-Tanuki-Qwen2.5-72B-Answered
8
+ - Aratako/Open-Platypus-Japanese-masked-formatted
9
+ - llm-jp/wizardlm8x22b-logical-math-coding-sft-ja
10
+ - kanhatakeyama/ramdom-to-fixed-multiturn-Calm3
11
+ - llm-jp/Synthetic-JP-EN-Coding-Dataset
12
+ - llm-jp/magpie-sft-v1.0
13
+ model-index:
14
+ - name: plamo-2-1b-gorilla-chat5
15
+ results: []
16
+ ---
17
+
18
+ <!-- This model card has been generated automatically according to the information the Trainer had access to. You
19
+ should probably proofread and complete it, then remove this comment. -->
20
+
21
+ [<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
22
+ <details><summary>See axolotl config</summary>
23
+
24
+ axolotl version: `0.7.0`
25
+ ```yaml
26
+
27
+
28
+ # Model settings
29
+ base_model: /notebooks/plamo-2-1b-gorilla-chat2 # model name on Hugging Face
30
+ model_type: AutoModelForCausalLM # class used to load the model
31
+ tokenizer_type: AutoTokenizer # class used to load the tokenizer
32
+ trust_remote_code: true # trust remote custom code when loading the model
33
+
34
+ hub_model_id: zamagi/fft-1
35
+ hub_strategy: "end"
36
+ push_dataset_to_hub:
37
+ hf_use_auth_token: true
38
+
39
+ plugins:
40
+ - axolotl.integrations.liger.LigerPlugin
41
+ liger_cross_entropy: false
42
+ liger_rope: true
43
+ liger_rms_norm: true
44
+ liger_swiglu: true
45
+ liger_fused_linear_cross_entropy: true
46
+
47
+ # 8-bit/4-bit settings (8-bit mode reduces memory usage)
48
+ load_in_8bit: false # load the model with 8-bit quantization
49
+ load_in_4bit: false # 4-bit quantization is not used
50
+ strict: false # do not require an exact weight match (tolerates added tokens, etc.)
51
+
52
+ chat_template: tokenizer_default
53
+
54
+ # Dataset settings
55
+ datasets:
56
+ - path: Aratako/Magpie-Tanuki-Qwen2.5-72B-Answered
57
+ type: chat_template
58
+ field_messages: messages
59
+ message_property_mappings: # mapping of property names inside each message
60
+ role: role # field indicating the role (user/system/assistant)
61
+ content: content # field containing the message content
62
+ roles_to_train: ["assistant"] # roles to train on (only assistant turns are learned)
63
+ train_on_eos: last
64
+ # - path: Aratako/magpie-qwen2.5-32b-reasoning-100k-formatted
65
+ # type: chat_template
66
+ # field_messages: conversations
67
+ # message_field_role: role
68
+ # message_field_content: content
69
+ # roles_to_train: ["assistant"] # roles to train on (only assistant turns are learned)
70
+ # train_on_eos: last
71
+ # - path: Aratako/magpie-reasoning-llama-nemotron-70b-100k-filtered
72
+ # type: chat_template
73
+ # field_messages: conversations
74
+ # message_field_role: role
75
+ # message_field_content: content
76
+ - path: Aratako/Open-Platypus-Japanese-masked-formatted
77
+ type: chat_template
78
+ field_messages: conversations
79
+ message_property_mappings: # mapping of property names inside each message
80
+ role: role # field indicating the role (user/system/assistant)
81
+ content: content # field containing the message content
82
+ roles_to_train: ["assistant"] # roles to train on (only assistant turns are learned)
83
+ train_on_eos: last
84
+
85
+ - path: llm-jp/wizardlm8x22b-logical-math-coding-sft-ja
86
+ type: chat_template
87
+ field_messages: messages
88
+ message_property_mappings: # mapping of property names inside each message
89
+ role: role # field indicating the role (user/system/assistant)
90
+ content: content # field containing the message content
91
+ roles_to_train: ["assistant"] # roles to train on (only assistant turns are learned)
92
+ train_on_eos: last
93
+ - path: kanhatakeyama/ramdom-to-fixed-multiturn-Calm3
94
+ split: 20240806filtered
95
+ type: chat_template
96
+ field_messages: messages
97
+ message_property_mappings: # mapping of property names inside each message
98
+ role: role # field indicating the role (user/system/assistant)
99
+ content: content # field containing the message content
100
+ roles_to_train: ["assistant"] # roles to train on (only assistant turns are learned)
101
+ train_on_eos: last
102
+ # - path: Aratako/magpie-ultra-v0.1-formatted
103
+ # type: chat_template
104
+ # field_messages: conversations
105
+ # message_field_role: role
106
+ # message_field_content: content
107
+ # - path: Aratako/orca-agentinstruct-1M-v1-selected
108
+ # type: chat_template
109
+ # field_messages: messages
110
+ # message_field_role: role
111
+ # message_field_content: content
112
+ - path: llm-jp/Synthetic-JP-EN-Coding-Dataset
113
+ type: chat_template
114
+ field_messages: messages
115
+ message_property_mappings: # mapping of property names inside each message
116
+ role: role # field indicating the role (user/system/assistant)
117
+ content: content # field containing the message content
118
+ roles_to_train: ["assistant"] # roles to train on (only assistant turns are learned)
119
+ train_on_eos: last
120
+ - path: llm-jp/magpie-sft-v1.0 # dataset to use (dataset name on Hugging Face)
121
+ type: chat_template # use a conversation-format dataset
122
+ field_messages: conversations # field that stores the conversation data
123
+ message_property_mappings: # mapping of property names inside each message
124
+ role: role # field indicating the role (user/system/assistant)
125
+ content: content # field containing the message content
126
+ roles_to_train: ["assistant"] # roles to train on (only assistant turns are learned)
127
+ train_on_eos: last
128
+
129
+ shuffle_merged_datasets: true
130
+ dataset_prepared_path: /notebooks/data/fft-data
131
+ val_set_size: 0.002
132
+ output_dir: /notebooks/data/27b-fft-out-1
133
+ dataset_keep_in_memory: false
134
+
135
+ gpu_memory_limit: 48GiB
136
+
137
+ sequence_len: 2048
138
+ sample_packing: true
139
+ eval_sample_packing: false
140
+ pad_to_sequence_len: true
141
+
142
+ adapter:
143
+ lora_model_dir:
144
+ lora_r:
145
+ lora_alpha:
146
+ lora_dropout:
147
+ lora_target_linear:
148
+ lora_fan_in_fan_out:
149
+
150
+
151
+ # Training settings
152
+ gradient_accumulation_steps: 4
153
+ micro_batch_size: 8
154
+ num_epochs: 2
155
+ optimizer: paged_adamw_8bit
156
+ lr_scheduler:
157
+ cosine_min_lr_ratio: 0.1
158
+ learning_rate: 0.00001
159
+ max_steps: 10000
160
+
161
+ train_on_inputs: false
162
+ group_by_length: false
163
+ bf16: auto
164
+ fp16:
165
+ tf32: false
166
+
167
+ #wandb: false
168
+ #wandb_project: 27b-fft
169
+ #wandb_entity: aratako-lm
170
+ #wandb_watch:
171
+ #wandb_name: attempt-01
172
+ #wandb_log_model:
173
+
174
+ gradient_checkpointing: true
175
+ early_stopping_patience:
176
+ auto_resume_from_checkpoints: true
177
+ local_rank:
178
+ logging_steps: 1
179
+ xformers_attention:
180
+ flash_attention:
181
+
182
+ save_strategy: steps
183
+ save_steps: 100
184
+ save_total_limit: 2
185
+
186
+ warmup_steps: 50
187
+ eval_steps: 100
188
+ eval_batch_size: 1
189
+ eval_table_size:
190
+ eval_max_new_tokens:
191
+
192
+ debug:
193
+ deepspeed: /notebooks/axolotl/deepspeed_configs/zero3_bf16.json
194
+ weight_decay: 0.01
195
+ fsdp:
196
+ fsdp_config:
197
+
198
+
199
+ # Output / saving settings
200
+ output_dir: /notebooks/output/plamo-2-1b-gorilla-chat5 # output directory for checkpoints and the final model
201
+ hub_model_id: zamagi/plamo-2-1b-gorilla-chat5 # (optional) repository name when uploading to the Hugging Face Hub
202
+
203
+
204
+ ```
205
+
206
+ </details><br>
207
+
208
+ # plamo-2-1b-gorilla-chat5
209
+
210
+ This model was trained from the /notebooks/plamo-2-1b-gorilla-chat2 checkpoint (the `base_model` in the axolotl config above) on the Aratako/Magpie-Tanuki-Qwen2.5-72B-Answered, Aratako/Open-Platypus-Japanese-masked-formatted, llm-jp/wizardlm8x22b-logical-math-coding-sft-ja, kanhatakeyama/ramdom-to-fixed-multiturn-Calm3, llm-jp/Synthetic-JP-EN-Coding-Dataset and llm-jp/magpie-sft-v1.0 datasets.
211
+ It achieves the following results on the evaluation set:
212
+ - Loss: 1.2854
213
+
214
+ ## Model description
215
+
216
+ More information needed
217
+
218
+ ## Intended uses & limitations
219
+
220
+ More information needed
221
+
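A minimal inference sketch, assuming the checkpoint is loaded from the `zamagi/plamo-2-1b-gorilla-chat5` Hub repository and that the tokenizer ships a usable chat template (the axolotl config above sets `chat_template: tokenizer_default`). `trust_remote_code=True` is needed because the model and tokenizer use the custom `PlamoForCausalLM` / `PlamoTokenizer` classes.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "zamagi/plamo-2-1b-gorilla-chat5"  # hub_model_id from the training config
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, trust_remote_code=True
)

messages = [{"role": "user", "content": "日本語で自己紹介してください。"}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
)
with torch.no_grad():
    output = model.generate(input_ids, max_new_tokens=256, do_sample=True)
# Decode only the newly generated tokens after the prompt
print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))
```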
222
+ ## Training and evaluation data
223
+
224
+ More information needed
225
+
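All of the datasets listed above are consumed through axolotl's `chat_template` loader, so each record is expected to look roughly like the sketch below (the message field is `messages` or `conversations` depending on the dataset, with `role`/`content` mapped one-to-one via `message_property_mappings`); only assistant turns and the final EOS are trained on (`roles_to_train: ["assistant"]`, `train_on_eos: last`).

```python
# Illustrative record shape only; field names follow the config above.
example = {
    "messages": [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."},  # only assistant turns contribute to the loss
    ]
}
```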
226
+ ## Training procedure
227
+
228
+ ### Training hyperparameters
229
+
230
+ The following hyperparameters were used during training:
231
+ - learning_rate: 1e-05
232
+ - train_batch_size: 8
233
+ - eval_batch_size: 1
234
+ - seed: 42
235
+ - distributed_type: multi-GPU
236
+ - gradient_accumulation_steps: 4
237
+ - total_train_batch_size: 32
238
+ - optimizer: paged_adamw_8bit (betas=(0.9, 0.999), epsilon=1e-08, no additional optimizer arguments)
239
+ - lr_scheduler_type: cosine
240
+ - lr_scheduler_warmup_steps: 50
241
+ - training_steps: 10000
242
+
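The reported effective batch size is consistent with the other values: total_train_batch_size = train_batch_size × gradient_accumulation_steps × data-parallel world size = 8 × 4 × 1 = 32, which implies a single data-parallel rank under the DeepSpeed ZeRO-3 bf16 configuration referenced in the axolotl config.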
243
+ ### Training results
244
+
245
+ | Training Loss | Epoch | Step | Validation Loss |
246
+ |:-------------:|:------:|:-----:|:---------------:|
247
+ | 1.4277 | 0.0002 | 1 | 1.5568 |
248
+ | 1.3262 | 0.0196 | 100 | 1.4437 |
249
+ | 1.2695 | 0.0391 | 200 | 1.4289 |
250
+ | 1.4199 | 0.0587 | 300 | 1.4149 |
251
+ | 1.2383 | 0.0783 | 400 | 1.4073 |
252
+ | 1.418 | 0.0979 | 500 | 1.3987 |
253
+ | 1.2148 | 0.1174 | 600 | 1.3954 |
254
+ | 1.3301 | 0.1370 | 700 | 1.3906 |
255
+ | 1.3418 | 0.1566 | 800 | 1.3850 |
256
+ | 1.248 | 0.1762 | 900 | 1.3801 |
257
+ | 1.3027 | 0.1957 | 1000 | 1.3762 |
258
+ | 1.3965 | 0.2153 | 1100 | 1.3768 |
259
+ | 1.2422 | 0.2349 | 1200 | 1.3747 |
260
+ | 1.2969 | 0.2544 | 1300 | 1.3682 |
261
+ | 1.248 | 0.2740 | 1400 | 1.3629 |
262
+ | 1.3203 | 0.2936 | 1500 | 1.3582 |
263
+ | 1.2637 | 0.3132 | 1600 | 1.3576 |
264
+ | 1.3398 | 0.3327 | 1700 | 1.3559 |
265
+ | 1.1934 | 0.3523 | 1800 | 1.3508 |
266
+ | 1.1992 | 0.3719 | 1900 | 1.3525 |
267
+ | 1.1816 | 0.3914 | 2000 | 1.3475 |
268
+ | 1.1562 | 0.4110 | 2100 | 1.3441 |
269
+ | 1.373 | 0.4306 | 2200 | 1.3374 |
270
+ | 1.2188 | 0.4502 | 2300 | 1.3383 |
271
+ | 1.1738 | 0.4697 | 2400 | 1.3376 |
272
+ | 1.2344 | 0.4893 | 2500 | 1.3318 |
273
+ | 1.291 | 0.5089 | 2600 | 1.3289 |
274
+ | 1.2148 | 0.5285 | 2700 | 1.3254 |
275
+ | 1.248 | 0.5480 | 2800 | 1.3245 |
276
+ | 1.2988 | 0.5676 | 2900 | 1.3260 |
277
+ | 1.3359 | 0.5872 | 3000 | 1.3255 |
278
+ | 1.2109 | 0.6067 | 3100 | 1.3222 |
279
+ | 1.2656 | 0.6263 | 3200 | 1.3191 |
280
+ | 1.2109 | 0.6459 | 3300 | 1.3160 |
281
+ | 1.2676 | 0.6655 | 3400 | 1.3136 |
282
+ | 1.1426 | 0.6850 | 3500 | 1.3137 |
283
+ | 1.2422 | 0.7046 | 3600 | 1.3262 |
284
+ | 1.2188 | 0.7242 | 3700 | 1.3283 |
285
+ | 1.2891 | 0.7437 | 3800 | 1.3277 |
286
+ | 1.1758 | 0.7633 | 3900 | 1.3232 |
287
+ | 1.1846 | 0.7829 | 4000 | 1.3268 |
288
+ | 1.3418 | 0.8025 | 4100 | 1.3235 |
289
+ | 1.2812 | 0.8220 | 4200 | 1.3214 |
290
+ | 1.2793 | 0.8416 | 4300 | 1.3202 |
291
+ | 1.1758 | 0.8612 | 4400 | 1.3196 |
292
+ | 1.2188 | 0.8808 | 4500 | 1.3198 |
293
+ | 1.1719 | 0.9003 | 4600 | 1.3177 |
294
+ | 1.1738 | 0.9199 | 4700 | 1.3129 |
295
+ | 1.3555 | 0.9395 | 4800 | 1.3154 |
296
+ | 1.2207 | 0.9590 | 4900 | 1.3152 |
297
+ | 1.1445 | 0.9786 | 5000 | 1.3110 |
298
+ | 1.2891 | 0.9982 | 5100 | 1.3094 |
299
+ | 1.0527 | 1.0178 | 5200 | 1.3123 |
300
+ | 1.0527 | 1.0374 | 5300 | 1.3120 |
301
+ | 1.1777 | 1.0570 | 5400 | 1.3124 |
302
+ | 1.0879 | 1.0765 | 5500 | 1.3128 |
303
+ | 1.1836 | 1.0961 | 5600 | 1.3114 |
304
+ | 1.1406 | 1.1157 | 5700 | 1.3117 |
305
+ | 1.1152 | 1.1352 | 5800 | 1.3092 |
306
+ | 1.1387 | 1.1548 | 5900 | 1.3106 |
307
+ | 1.2715 | 1.1744 | 6000 | 1.3063 |
308
+ | 1.1855 | 1.1940 | 6100 | 1.3070 |
309
+ | 1.1895 | 1.2135 | 6200 | 1.3070 |
310
+ | 1.1309 | 1.2331 | 6300 | 1.3063 |
311
+ | 1.0918 | 1.2527 | 6400 | 1.3043 |
312
+ | 1.0977 | 1.2723 | 6500 | 1.3050 |
313
+ | 1.0332 | 1.2918 | 6600 | 1.3028 |
314
+ | 0.9697 | 1.3114 | 6700 | 1.3012 |
315
+ | 1.1504 | 1.3310 | 6800 | 1.3006 |
316
+ | 1.1152 | 1.3505 | 6900 | 1.3013 |
317
+ | 1.0127 | 1.3701 | 7000 | 1.2998 |
318
+ | 1.1387 | 1.3897 | 7100 | 1.2993 |
319
+ | 1.0664 | 1.4093 | 7200 | 1.2970 |
320
+ | 1.1299 | 1.4288 | 7300 | 1.2971 |
321
+ | 1.1406 | 1.4484 | 7400 | 1.2971 |
322
+ | 1.0684 | 1.4680 | 7500 | 1.2969 |
323
+ | 1.0938 | 1.4875 | 7600 | 1.2966 |
324
+ | 1.1221 | 1.5071 | 7700 | 1.2943 |
325
+ | 1.0771 | 1.5267 | 7800 | 1.2937 |
326
+ | 1.1211 | 1.5463 | 7900 | 1.2938 |
327
+ | 1.043 | 1.5658 | 8000 | 1.2941 |
328
+ | 1.0537 | 1.5854 | 8100 | 1.2924 |
329
+ | 1.0859 | 1.6050 | 8200 | 1.2918 |
330
+ | 1.1836 | 1.6246 | 8300 | 1.2911 |
331
+ | 1.2188 | 1.6441 | 8400 | 1.2906 |
332
+ | 1.0596 | 1.6637 | 8500 | 1.2912 |
333
+ | 1.041 | 1.6833 | 8600 | 1.2904 |
334
+ | 1.1367 | 1.7028 | 8700 | 1.2904 |
335
+ | 1.1006 | 1.7224 | 8800 | 1.2891 |
336
+ | 1.0996 | 1.7420 | 8900 | 1.2898 |
337
+ | 1.1387 | 1.7616 | 9000 | 1.2883 |
338
+ | 1.1543 | 1.7811 | 9100 | 1.2888 |
339
+ | 1.1328 | 1.8007 | 9200 | 1.2876 |
340
+ | 1.0801 | 1.8203 | 9300 | 1.2872 |
341
+ | 1.1855 | 1.8398 | 9400 | 1.2880 |
342
+ | 1.1113 | 1.8594 | 9500 | 1.2860 |
343
+ | 1.1289 | 1.8790 | 9600 | 1.2865 |
344
+ | 1.1543 | 1.8986 | 9700 | 1.2857 |
345
+ | 1.123 | 1.9181 | 9800 | 1.2856 |
346
+ | 1.0352 | 1.9377 | 9900 | 1.2857 |
347
+ | 0.9189 | 1.9573 | 10000 | 1.2854 |
348
+
349
+
350
+ ### Framework versions
351
+
352
+ - Transformers 4.49.0
353
+ - Pytorch 2.5.1+cu124
354
+ - Datasets 3.2.0
355
+ - Tokenizers 0.21.1
config.json ADDED
@@ -0,0 +1,51 @@
1
+ {
2
+ "_name_or_path": "/notebooks/plamo-2-1b-gorilla-chat2",
3
+ "architectures": [
4
+ "PlamoForCausalLM"
5
+ ],
6
+ "attention_window_size": 2048,
7
+ "auto_map": {
8
+ "AutoConfig": "modeling_plamo.PlamoConfig",
9
+ "AutoModelForCausalLM": "pfnet/plamo-2-1b--modeling_plamo.PlamoForCausalLM"
10
+ },
11
+ "bos_token_id": 1,
12
+ "capacity_factor": 1.0,
13
+ "eos_token_id": 2,
14
+ "eval_attention_n_bit": null,
15
+ "eval_mlp_n_bit": null,
16
+ "expert_dropout": 0.0,
17
+ "fp8_accum_dtype": "bfloat16",
18
+ "full_attention_idx": [],
19
+ "group_size": 1024,
20
+ "hidden_size": 2048,
21
+ "hidden_size_per_head": 128,
22
+ "image_feature_size": null,
23
+ "image_proj_type": "linear",
24
+ "image_token_id": null,
25
+ "intermediate_size": 8192,
26
+ "k_expert": null,
27
+ "linear_type": "fp8",
28
+ "mamba_chunk_size": 256,
29
+ "mamba_d_conv": 4,
30
+ "mamba_d_state": 64,
31
+ "mamba_enabled": true,
32
+ "mamba_num_heads": 32,
33
+ "mamba_step": 2,
34
+ "max_position_embeddings": 10485760,
35
+ "model_type": "plamo",
36
+ "n_expert": null,
37
+ "num_attention_heads": 16,
38
+ "num_hidden_layers": 16,
39
+ "num_key_value_heads": 1,
40
+ "rms_norm_eps": 1e-06,
41
+ "shared_intermediate_size": null,
42
+ "sliding_window": 2048,
43
+ "sparse_intermediate_size": null,
44
+ "sparse_step": null,
45
+ "tokenizer_class": "PlamoTokenizer",
46
+ "torch_dtype": "bfloat16",
47
+ "transformers_version": "4.49.0",
48
+ "use_cache": false,
49
+ "use_predefined_initial_state": false,
50
+ "vocab_size": 100000
51
+ }
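As a rough sketch (assuming the files in this commit are served from the `zamagi/plamo-2-1b-gorilla-chat5` repository), the hybrid Mamba/attention layout declared above can be inspected through the custom config class; `trust_remote_code=True` is required because of the `auto_map` entries.

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("zamagi/plamo-2-1b-gorilla-chat5", trust_remote_code=True)
print(config.model_type)                                      # "plamo"
print(config.mamba_enabled, config.mamba_step)                # True, 2 (hybrid Mamba/attention stack)
print(config.num_hidden_layers, config.num_key_value_heads)   # 16, 1
print(config.attention_window_size)                           # 2048 (sliding-window attention)
```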
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "do_sample": true,
5
+ "eos_token_id": 2,
6
+ "transformers_version": "4.49.0"
7
+ }
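These are only the default decoding settings (sampling enabled, BOS id 1, EOS id 2); a small sketch, assuming the same repository name, of reading them and overriding values per call:

```python
from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("zamagi/plamo-2-1b-gorilla-chat5")
print(gen_cfg.do_sample, gen_cfg.bos_token_id, gen_cfg.eos_token_id)  # True 1 2
# Per-call arguments still take precedence, e.g.:
# model.generate(input_ids, generation_config=gen_cfg, temperature=0.7, top_p=0.9, max_new_tokens=256)
```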
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6cdac87f6a55c17178e232bce355494dc99ad623b224baaa982f86e936984957
3
+ size 2582909184
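The single safetensors shard is 2,582,909,184 bytes; with the bfloat16 weights declared in config.json (2 bytes per parameter) that works out to roughly 2,582,909,184 / 2 ≈ 1.29 × 10^9 parameters, i.e. the ~1B-class size implied by the model name.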
modeling_plamo.py ADDED
@@ -0,0 +1,1699 @@
1
+ import enum
2
+ import math
3
+ import warnings
4
+ from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, Union
5
+
6
+ try:
7
+ # It is difficult to install mamba_ssm on a login node because
8
+ # it requires a GPU for installation
9
+ import mamba_ssm
10
+ except ModuleNotFoundError:
11
+ warnings.warn("mamba_ssm could not be imported", stacklevel=2)
12
+ try:
13
+ # It is difficult to install causal_conv1d on a login node because
14
+ # it requires a GPU for installation
15
+ import causal_conv1d.causal_conv1d_interface as causal_conv1d
16
+ except ModuleNotFoundError:
17
+ warnings.warn("causal_conv1d could not be imported", stacklevel=2)
18
+ import torch
19
+ from torch import nn
20
+ from torch.nn import functional as F
21
+ from transformers import PretrainedConfig, PreTrainedModel
22
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
23
+
24
+
25
+ def _is_first_token(mask: torch.Tensor) -> torch.Tensor:
26
+ assert mask.dtype == torch.bool
27
+ B, Nh, q_len, kv_len = mask.shape
28
+ mask = mask[:, :, :, -q_len:]
29
+ cont = q_len != kv_len
30
+ v = False if cont else True
31
+ out = torch.logical_not(torch.diagonal(mask, offset=-1, dim1=-2, dim2=-1).bool())
32
+ out = torch.cat(
33
+ [
34
+ torch.full(size=(B, Nh, 1), dtype=torch.bool, device=out.device, fill_value=v),
35
+ out,
36
+ ],
37
+ dim=-1,
38
+ )
39
+ return out
40
+
41
+
42
+ def _swiglu(h: torch.Tensor) -> torch.Tensor:
43
+ h0, h1 = h.chunk(2, dim=-1)
44
+ return torch.nn.functional.silu(h0) * h1
45
+
46
+
47
+ class RotaryEmbedding(torch.nn.Module):
48
+ def __init__(
49
+ self, dim: int, max_position_embeddings: int = 2048, base: int = 10000, device: Optional[torch.device] = None
50
+ ) -> None:
51
+ super().__init__()
52
+
53
+ self.dim = dim
54
+ self.max_position_embeddings = max_position_embeddings
55
+ self.base = base
56
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
57
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
58
+
59
+ # Build here to make `torch.jit.trace` work.
60
+ self._set_cos_sin_cache(
61
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
62
+ )
63
+
64
+ def _set_cos_sin_cache(self, seq_len: int, device: Any, dtype: Any) -> None:
65
+ self.max_seq_len_cached = seq_len
66
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) # type: ignore
67
+
68
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
69
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
70
+ emb = torch.cat((freqs, freqs), dim=-1)
71
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
72
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
73
+
74
+ def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
75
+ # x: [bs, num_attention_heads, seq_len, head_size]
76
+ if seq_len > self.max_seq_len_cached:
77
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
78
+
79
+ return (
80
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), # type: ignore
81
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), # type: ignore
82
+ )
83
+
84
+
85
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
86
+ """Rotates half the hidden dims of the input."""
87
+ x1 = x[..., : x.shape[-1] // 2]
88
+ x2 = x[..., x.shape[-1] // 2 :]
89
+ return torch.cat((-x2, x1), dim=-1)
90
+
91
+
92
+ def _rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
93
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
94
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
95
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
96
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
97
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
98
+ x_embed = (x * cos) + (_rotate_half(x) * sin)
99
+ return x_embed
100
+
101
+
102
+ class LinearType(str, enum.Enum):
103
+ Normal = "normal"
104
+ Fp8 = "fp8"
105
+ Fp8Retain = "fp8-retain"
106
+
107
+
108
+ class PlamoConfig(PretrainedConfig): # type: ignore
109
+ model_type: str = "plamo"
110
+
111
+ def __init__(
112
+ self,
113
+ hidden_size: int = 4096,
114
+ num_hidden_layers: int = 32,
115
+ rms_norm_eps: float = 1e-6,
116
+ tie_word_embeddings: bool = True,
117
+ # Attention
118
+ num_attention_heads: int = 32,
119
+ num_key_value_heads: int = 4,
120
+ hidden_size_per_head: int = 128,
121
+ max_position_embeddings: int = 2048,
122
+ attention_window_size: int = 2048,
123
+ full_attention_idx: list[int] | None = None,
124
+ # Mamba
125
+ mamba_d_state: int = 64,
126
+ mamba_d_conv: int = 4,
127
+ mamba_num_heads: int = 64,
128
+ mamba_step: int = 2,
129
+ mamba_chunk_size: int = 256,
130
+ mamba_enabled: bool = True,
131
+ # MLP
132
+ intermediate_size: int = 13312,
133
+ # Tokenizer
134
+ vocab_size: int = 32000,
135
+ tokenizer_class: str = "PlamoTokenizer",
136
+ pad_token_id: Optional[int] = None,
137
+ bos_token_id: int = 1,
138
+ eos_token_id: int = 2,
139
+ # Multimodal
140
+ image_token_id: Optional[int] = None,
141
+ image_feature_size: Optional[int] = None,
142
+ image_proj_type: Literal["linear", "mlp"] = "linear",
143
+ # FP8
144
+ linear_type: LinearType = LinearType.Normal,
145
+ fp8_accum_dtype: Optional[str] = None,
146
+ # Evaluation
147
+ eval_attention_n_bit: Optional[int] = None,
148
+ eval_mlp_n_bit: Optional[int] = None,
149
+ use_cache: bool = True,
150
+ **kwargs: Any,
151
+ ) -> None:
152
+ # max_position_embeddings is often used to determine the max length during inference,
153
+ # but samba should have extrapolation abilities
154
+ self.max_position_embeddings = max(10 * 1024 * 1024, max_position_embeddings)
155
+ self.hidden_size = hidden_size
156
+ self.rms_norm_eps = rms_norm_eps
157
+
158
+ self.num_hidden_layers = num_hidden_layers
159
+ self.num_attention_heads = num_attention_heads
160
+ self.hidden_size_per_head = hidden_size_per_head
161
+ self.num_key_value_heads = num_key_value_heads
162
+ self.attention_window_size = attention_window_size
163
+ self.full_attention_idx = full_attention_idx if full_attention_idx is not None else []
164
+
165
+ self.mamba_d_state = mamba_d_state
166
+ self.mamba_d_conv = mamba_d_conv
167
+ self.mamba_num_heads = mamba_num_heads
168
+ self.mamba_step = mamba_step
169
+ self.mamba_chunk_size = mamba_chunk_size
170
+ self.mamba_enabled = mamba_enabled
171
+
172
+ self.intermediate_size = intermediate_size
173
+
174
+ self.vocab_size = vocab_size
175
+
176
+ self.image_token_id = image_token_id
177
+ self.image_feature_size = image_feature_size
178
+ self.image_proj_type = image_proj_type
179
+
180
+ self.linear_type = linear_type
181
+ self.fp8_accum_dtype = fp8_accum_dtype
182
+
183
+ self.eval_attention_n_bit = eval_attention_n_bit
184
+ self.eval_mlp_n_bit = eval_mlp_n_bit
185
+ self.use_cache = use_cache
186
+
187
+ # fields for vLLM
188
+ self.sliding_window = attention_window_size
189
+
190
+ super().__init__(
191
+ tokenizer_class=tokenizer_class,
192
+ pad_token_id=pad_token_id,
193
+ bos_token_id=bos_token_id,
194
+ eos_token_id=eos_token_id,
195
+ tie_word_embeddings=tie_word_embeddings,
196
+ **kwargs,
197
+ )
198
+
199
+
200
+ class PlamoAttentionCache(torch.nn.Module):
201
+ def __init__(self, key: torch.Tensor, value: torch.Tensor) -> None:
202
+ super().__init__()
203
+ B, nh, L, c = key.shape
204
+ assert len(value.shape) == 4
205
+ assert value.shape[0] == B
206
+ assert value.shape[2] == L
207
+ self.register_parameter("key", torch.nn.Parameter(key, requires_grad=False))
208
+ self.register_parameter("value", torch.nn.Parameter(value, requires_grad=False))
209
+
210
+
211
+ class PlamoMambaCache(torch.nn.Module):
212
+ def __init__(self, conv_state: torch.Tensor, ssm_state: torch.Tensor) -> None:
213
+ super().__init__()
214
+ # conv_state: [B, C, d_conv]
215
+ # ssm_state: [B, nhead, nchannel_per_head, d_state]
216
+ assert len(conv_state.shape) == 3
217
+ assert len(ssm_state.shape) == 4
218
+ assert conv_state.shape[0] == ssm_state.shape[0]
219
+ self.register_parameter("conv_state", torch.nn.Parameter(conv_state, requires_grad=False))
220
+ self.register_parameter("ssm_state", torch.nn.Parameter(ssm_state, requires_grad=False))
221
+
222
+
223
+ PlamoLayerCache = PlamoAttentionCache | PlamoMambaCache
224
+
225
+
226
+ class PlamoCache(torch.nn.Module):
227
+ """
228
+ stores states of the model for fast decoding.
229
+ `transformers` uses `transformers.Cache` for this purpose, but the interface and variable names are
230
+ deeply dependent on the Transformers architecture (e.g., `key_states`), making it difficult to use with
231
+ other architectures (e.g., Mamba).
232
+ This class provides a similar interface to `transformers.Cache`, but is designed to also handle
233
+ the state of Mamba properly.
234
+ """
235
+
236
+ def __init__(self, config: PlamoConfig) -> None:
237
+ super().__init__()
238
+ self.config = config
239
+ self.cache = torch.nn.ModuleList([None for _ in range(config.num_hidden_layers)]) # type: ignore
240
+
241
+ def append_kv(self, key: torch.Tensor, value: torch.Tensor, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
242
+ c = self.cache[layer_idx]
243
+ if c is None:
244
+ return key, value
245
+ assert isinstance(c, PlamoAttentionCache)
246
+
247
+ def _validate(cache: torch.Tensor, new_tensor: torch.Tensor) -> None:
248
+ assert len(cache.shape) == 4
249
+ assert len(new_tensor.shape) == 4
250
+ assert cache.shape[0] == new_tensor.shape[0]
251
+ assert cache.shape[1] == new_tensor.shape[1]
252
+ assert cache.shape[3] == new_tensor.shape[3]
253
+
254
+ _validate(c.key, key)
255
+ _validate(c.value, value)
256
+ assert key.shape[2] == value.shape[2]
257
+ return torch.cat([c.key, key], dim=2), torch.cat([c.value, value], dim=2)
258
+
259
+ def update_attention(
260
+ self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int
261
+ ) -> PlamoAttentionCache:
262
+ full_attn = layer_idx in self.config.full_attention_idx
263
+ window_size = self.config.attention_window_size
264
+
265
+ if self.cache[layer_idx] is None:
266
+ if full_attn:
267
+ self.cache[layer_idx] = PlamoAttentionCache(key_states, value_states)
268
+ else:
269
+ self.cache[layer_idx] = PlamoAttentionCache(
270
+ key_states[:, :, -window_size:, :], value_states[:, :, -window_size:, :]
271
+ )
272
+ else:
273
+ c = self.cache[layer_idx]
274
+ assert isinstance(c, PlamoAttentionCache)
275
+ k, v = self.append_kv(key_states, value_states, layer_idx)
276
+ if full_attn:
277
+ c.key.data = k
278
+ c.value.data = v
279
+ else:
280
+ c.key.data = k[:, :, -window_size:, :]
281
+ c.value.data = v[:, :, -window_size:, :]
282
+ return self.cache[layer_idx] # type: ignore
283
+
284
+ def update_mamba(self, conv_state: torch.Tensor, ssm_state: torch.Tensor, layer_idx: int) -> PlamoMambaCache:
285
+ if self.cache[layer_idx] is None:
286
+ self.cache[layer_idx] = PlamoMambaCache(conv_state, ssm_state)
287
+ else:
288
+ c = self.cache[layer_idx]
289
+ assert isinstance(c, PlamoMambaCache)
290
+ assert c.conv_state.shape == conv_state.shape
291
+ assert c.ssm_state.shape == ssm_state.shape
292
+ c.conv_state.data = conv_state
293
+ c.ssm_state.data = ssm_state
294
+ return self.cache[layer_idx] # type: ignore
295
+
296
+ def __getitem__(self, layer_idx: int) -> PlamoLayerCache | None:
297
+ assert layer_idx < len(self.cache)
298
+ layer_cache = self.cache[layer_idx]
299
+ return layer_cache # type: ignore
300
+
301
+ def __len__(self) -> int:
302
+ return len(self.cache)
303
+
304
+ def get_seq_length(self, layer_idx: Optional[int] = None) -> int:
305
+ if layer_idx is not None:
306
+ c = self.cache[layer_idx]
307
+ assert isinstance(c, PlamoAttentionCache)
308
+ return c.key.shape[2] # type: ignore
309
+
310
+ sequence_length: int | None = None
311
+ for layer_cache in self.cache:
312
+ if isinstance(layer_cache, PlamoAttentionCache):
313
+ sequence_length = (
314
+ max(layer_cache.key.shape[2], sequence_length)
315
+ if sequence_length is not None
316
+ else layer_cache.key.shape[2]
317
+ )
318
+ assert sequence_length is not None
319
+ return sequence_length
320
+
321
+ def get_max_length(self) -> int | None:
322
+ return None
323
+
324
+ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
325
+ """Given the sequence length of the new inputs, returns the usable length of the cache."""
326
+ # Cache without size limit -> all cache is usable
327
+ # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
328
+ # length, we will need to evict part of the cache (and thus not all cache is usable)
329
+ max_length = self.get_max_length()
330
+ previous_seq_length = self.get_seq_length(layer_idx)
331
+ if max_length is not None and previous_seq_length + new_seq_length > max_length:
332
+ return max_length - new_seq_length
333
+ return previous_seq_length
334
+
335
+ def reorder_cache(self, beam_idx: torch.Tensor) -> None:
336
+ def _mamba(cache: PlamoMambaCache) -> PlamoMambaCache:
337
+ return PlamoMambaCache(
338
+ conv_state=cache.conv_state.index_select(0, beam_idx),
339
+ ssm_state=cache.ssm_state.index_select(0, beam_idx),
340
+ )
341
+
342
+ def _attention(cache: PlamoAttentionCache) -> PlamoAttentionCache:
343
+ return PlamoAttentionCache(
344
+ key=cache.key.index_select(0, beam_idx),
345
+ value=cache.value.index_select(0, beam_idx),
346
+ )
347
+
348
+ for i in range(len(self.cache)):
349
+ if self.cache[i] is None:
350
+ continue
351
+ layer_cache = self.cache[i]
352
+ if isinstance(layer_cache, PlamoMambaCache):
353
+ self.cache[i] = _mamba(layer_cache)
354
+ else:
355
+ assert isinstance(layer_cache, PlamoAttentionCache)
356
+ self.cache[i] = _attention(layer_cache)
357
+
358
+ @property
359
+ def seen_tokens(self) -> int | None:
360
+ return None
361
+
362
+
363
+ class DecoderInput(NamedTuple):
364
+ hidden_states: torch.Tensor
365
+ attention_mask: Optional[torch.Tensor] = None
366
+ past_states: Optional[PlamoCache] = None
367
+ output_hidden_states: Optional[bool] = False
368
+ output_attentions: Optional[bool] = False
369
+ gradient_checkpointing: bool = False
370
+ input_ids: Optional[torch.Tensor] = None
371
+
372
+
373
+ class DecoderOutput(NamedTuple):
374
+ hidden_states: torch.Tensor
375
+ all_hidden_states: Optional[Tuple[torch.Tensor, ...]]
376
+ all_self_attns: Optional[Tuple[torch.Tensor, ...]]
377
+
378
+
379
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
380
+ def _make_causal_mask(
381
+ input_ids_shape: Tuple[int, int], dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
382
+ ) -> torch.Tensor:
383
+ """
384
+ Make causal mask used for bi-directional self-attention.
385
+ """
386
+ bsz, tgt_len = input_ids_shape
387
+ mask = torch.full((tgt_len, tgt_len), float("-inf"), device=device)
388
+ mask_cond = torch.arange(mask.size(-1), device=device)
389
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
390
+ mask = mask.to(dtype)
391
+
392
+ if past_key_values_length > 0:
393
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
394
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
395
+
396
+
397
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
398
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor:
399
+ """
400
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
401
+ """
402
+ bsz, src_len = mask.size()
403
+ tgt_len = tgt_len if tgt_len is not None else src_len
404
+
405
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
406
+
407
+ inverted_mask = 1.0 - expanded_mask
408
+
409
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), float("-inf")) # type: ignore
410
+
411
+
412
+ def _rms_norm(
413
+ hidden_states: torch.Tensor, weight: Optional[torch.Tensor], eps: float, offset: float = 1.0
414
+ ) -> torch.Tensor:
415
+ input_dtype = hidden_states.dtype
416
+ hidden_states = hidden_states.to(torch.float32)
417
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
418
+ hidden_states = hidden_states * torch.rsqrt(variance + eps)
419
+ hidden_states = hidden_states.to(input_dtype)
420
+ if weight is not None:
421
+ hidden_states = (offset + weight) * hidden_states
422
+ return hidden_states
423
+
424
+
425
+ class RMSNorm(nn.Module):
426
+ def __init__(
427
+ self,
428
+ hidden_size: int,
429
+ eps: float = 1e-6,
430
+ offset: float = 1.0,
431
+ device: Optional[Union[torch.device, str]] = None,
432
+ ) -> None:
433
+ super().__init__()
434
+ self.weight = nn.Parameter(torch.zeros(hidden_size, device=device))
435
+ self.variance_epsilon = eps
436
+ self.offset = offset
437
+
438
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
439
+ return _rms_norm(hidden_states, self.weight, self.variance_epsilon, offset=self.offset)
440
+
441
+
442
+ def get_initial_dt_bias(num_heads: int) -> torch.Tensor:
443
+ dt_min = 0.001
444
+ dt_max = 0.1
445
+ dt = torch.exp(torch.rand(num_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
446
+ dt = torch.clamp(dt, 1e-4)
447
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
448
+ return inv_dt
449
+
450
+
451
+ def get_initial_A(num_heads: int) -> torch.Tensor:
452
+ A = torch.arange(1, num_heads + 1, dtype=torch.float32)
453
+ return torch.log(A)
454
+
455
+
456
+ def _bf16_supported_in_triton() -> bool:
457
+ # newer torch (2.2.0 and later?) supports bfloat16 even when using Voltas
458
+ # but triton cannot compile bf16 kernels for Volta
459
+ major, _ = torch.cuda.get_device_capability()
460
+ return major >= 8
461
+
462
+
463
+ def _get_trition_dtype(dtype: torch.dtype) -> torch.dtype:
464
+ if dtype != torch.bfloat16:
465
+ return dtype
466
+ if _bf16_supported_in_triton():
467
+ return dtype
468
+ return torch.float32
469
+
470
+
471
+ def ssd_update_state(
472
+ ssm_state: torch.Tensor,
473
+ x: torch.Tensor,
474
+ dt: torch.Tensor,
475
+ A: torch.Tensor,
476
+ B: torch.Tensor,
477
+ C: torch.Tensor,
478
+ D: torch.Tensor,
479
+ z: torch.Tensor,
480
+ dt_bias: torch.Tensor,
481
+ dt_softplus: bool,
482
+ ) -> torch.Tensor:
483
+ assert ssm_state.dtype == torch.float32
484
+ if dt.is_cuda:
485
+ dtype = _get_trition_dtype(x.dtype)
486
+ else:
487
+ dtype = x.dtype
488
+ if dt.is_cuda:
489
+ f = mamba_ssm.ops.triton.selective_state_update.selective_state_update
490
+ else:
491
+ f = mamba_ssm.ops.triton.selective_state_update.selective_state_update_ref
492
+
493
+ hidden_size_per_head = x.shape[-1]
494
+ d_state = B.shape[-1]
495
+ A = A[:, None, None].expand(-1, hidden_size_per_head, d_state).float()
496
+ dt = dt[..., None].expand(-1, -1, hidden_size_per_head)
497
+ dt_bias = dt_bias[:, None].expand(-1, hidden_size_per_head)
498
+ D = D[:, None].expand(-1, hidden_size_per_head)
499
+ assert ssm_state.dtype == torch.float32
500
+ out = f(
501
+ ssm_state,
502
+ x.to(dtype),
503
+ dt.to(dtype),
504
+ A.float(),
505
+ B.to(dtype),
506
+ C.to(dtype),
507
+ D.float(),
508
+ z.to(dtype),
509
+ dt_bias.float(),
510
+ dt_softplus=dt_softplus,
511
+ )
512
+ return out[:, None] # type: ignore
513
+
514
+
515
+ def _ssd_chunk_scan_combined_naive(
516
+ x: torch.Tensor,
517
+ dt: torch.Tensor,
518
+ A: torch.Tensor,
519
+ B: torch.Tensor,
520
+ C: torch.Tensor,
521
+ D: torch.Tensor,
522
+ z: torch.Tensor,
523
+ dt_bias: torch.Tensor,
524
+ dt_softplus: bool,
525
+ seq_idx: torch.Tensor | None,
526
+ ssm_state: torch.Tensor,
527
+ ) -> tuple[torch.Tensor, torch.Tensor]:
528
+ assert ssm_state.dtype == torch.float32
529
+ length = x.shape[1]
530
+ ys = []
531
+ for i in range(length):
532
+ if i != 0 and seq_idx is not None:
533
+ ssm_state = torch.where(
534
+ (seq_idx[:, i - 1] != seq_idx[:, i])[:, None, None, None],
535
+ torch.zeros_like(ssm_state),
536
+ ssm_state,
537
+ )
538
+ y = ssd_update_state(
539
+ ssm_state,
540
+ x[:, i],
541
+ dt[:, i],
542
+ A,
543
+ B[:, i],
544
+ C[:, i],
545
+ D,
546
+ z=z[:, i],
547
+ dt_bias=dt_bias,
548
+ dt_softplus=dt_softplus,
549
+ )
550
+ ys.append(y)
551
+ return torch.cat(ys, dim=1), ssm_state
552
+
553
+
554
+ def _ssd_chunk_scan_combined_cpu(
555
+ x: torch.Tensor,
556
+ dt: torch.Tensor,
557
+ A: torch.Tensor,
558
+ B: torch.Tensor,
559
+ C: torch.Tensor,
560
+ chunk_size: int,
561
+ D: torch.Tensor,
562
+ z: torch.Tensor,
563
+ dt_bias: torch.Tensor,
564
+ dt_softplus: bool,
565
+ ) -> tuple[torch.Tensor, torch.Tensor]:
566
+ # (bsize, nhead, nchunk, chunk_size)
567
+ dt = dt.float() # We want high precision for this before cumsum
568
+ dt = dt.permute(0, 2, 1).unflatten(2, (-1, chunk_size)) # type: ignore
569
+ if dt_bias is not None:
570
+ dt = dt + dt_bias[None, :, None, None]
571
+ if dt_softplus:
572
+ dt = F.softplus(dt)
573
+ dA = dt * A[None, :, None, None]
574
+ dA_cumsum = torch.cumsum(dA, dim=-1)
575
+
576
+ _, _, nheads, _ = x.shape
577
+ dstate = B.shape[-1]
578
+ _ = dt.shape[2]
579
+
580
+ with torch.profiler.record_function("ssd_chunk_scan_combined_cpu_chunk_state"):
581
+ # Following is equivalent to `mamba_ssm.ops.triton.ssd_combined.chunk_state_ref(B, x, dt, dA_cumsum)`
582
+ # But `einsum` in the above function is too slow in CPU.
583
+ x_ = torch.unflatten(x, 1, (-1, chunk_size))
584
+ assert B.shape[2] == nheads # B should be already expanded
585
+ B_ = torch.unflatten(B, 1, (-1, chunk_size)).to(x.dtype) # (bsize, nchunk, chunk_size, nheads, dstate)
586
+ decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum)).to(x.dtype)
587
+ dt_ = dt.to(x.dtype)
588
+
589
+ # einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B_, decay_states, dt_, x_)
590
+ B_ = B_.permute(0, 1, 3, 4, 2) # bchnl
591
+ tmp = dt_ * decay_states # bhcl
592
+ tmp = tmp.permute(0, 2, 1, 3)[:, :, :, None] # bch1l
593
+ tmp = B_ * tmp # bchnl
594
+ x_ = x_.permute(0, 1, 3, 2, 4) # bchlp
595
+ tmp = tmp @ x_ # bchnp
596
+ states = tmp.permute(0, 1, 2, 4, 3) # bchpn
597
+
598
+ states_dtype = states.dtype
599
+ if states.dtype not in [torch.float32, torch.float64]:
600
+ states = states.to(torch.float32)
601
+ with torch.profiler.record_function("ssd_chunk_scan_combined_cpu_state_passing"):
602
+ out, last_state = mamba_ssm.ops.triton.ssd_combined.state_passing_ref(
603
+ states.flatten(start_dim=-2, end_dim=-1),
604
+ dA_cumsum[:, :, :, -1],
605
+ )
606
+ states = torch.unflatten(out, -1, (-1, dstate))
607
+ last_state = torch.unflatten(last_state, -1, (-1, dstate))
608
+ states = states.to(states_dtype)
609
+ with torch.profiler.record_function("ssd_chunk_scan_combined_cpu_chunk_scan"):
610
+ out = mamba_ssm.ops.triton.ssd_combined.chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z)
611
+
612
+ return out, last_state
613
+
614
+
615
+ @torch.profiler.record_function("ssd_chunk_scan_combined")
616
+ def ssd_chunk_scan_combined(
617
+ x: torch.Tensor,
618
+ dt: torch.Tensor,
619
+ A: torch.Tensor,
620
+ B: torch.Tensor,
621
+ C: torch.Tensor,
622
+ chunk_size: int,
623
+ D: torch.Tensor,
624
+ z: torch.Tensor,
625
+ dt_bias: torch.Tensor,
626
+ dt_softplus: bool,
627
+ return_final_states: bool,
628
+ seq_idx: torch.Tensor | None,
629
+ ssm_state: torch.Tensor | None,
630
+ ) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
631
+ if seq_idx is not None:
632
+ assert seq_idx.dtype == torch.int32
633
+ assert ssm_state is None
634
+ assert not return_final_states
635
+ if ssm_state is not None:
636
+ assert ssm_state.dtype == torch.float32
637
+ assert seq_idx is None
638
+
639
+ length = x.shape[1]
640
+
641
+ """
642
+ state will be updated as follows:
643
+ ```
644
+ dt = softplus(dt)
645
+ dA = exp(dt * A)
646
+ state_next = state * dA + dB * x
647
+ ```
648
+
649
+ To avoid updating state, we set dt to -inf and x to 0
650
+ because `softplus(-inf) = 0` and `exp(0) = 1`
651
+ """
652
+ pad = (chunk_size - length % chunk_size) % chunk_size
653
+ x = torch.nn.functional.pad(x, pad=[0, 0, 0, 0, pad, 0], value=0.0)
654
+ dt = torch.nn.functional.pad(dt, pad=[0, 0, pad, 0], value=float("-inf"))
655
+ B = torch.nn.functional.pad(B, pad=[0, 0, 0, 0, pad, 0], value=0.0)
656
+ C = torch.nn.functional.pad(C, pad=[0, 0, 0, 0, pad, 0], value=0.0)
657
+ z = torch.nn.functional.pad(z, pad=[0, 0, 0, 0, pad, 0], value=0.0)
658
+ if seq_idx is not None:
659
+ seq_idx = torch.nn.functional.pad(seq_idx, pad=[pad, 0], value=0)
660
+
661
+ length = x.shape[1]
662
+ assert length % chunk_size == 0, (length, chunk_size)
663
+
664
+ if dt.is_cuda:
665
+ dtype = _get_trition_dtype(x.dtype)
666
+ out = mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined( # type: ignore
667
+ x.to(dtype),
668
+ dt.to(dtype),
669
+ A.float(),
670
+ B.to(dtype),
671
+ C.to(dtype),
672
+ chunk_size,
673
+ D=D.float(),
674
+ z=z.to(dtype),
675
+ initial_states=ssm_state,
676
+ dt_bias=dt_bias.float(),
677
+ dt_softplus=dt_softplus,
678
+ seq_idx=seq_idx,
679
+ return_final_states=return_final_states,
680
+ )
681
+ if return_final_states:
682
+ return out[0][:, pad:], out[1]
683
+ else:
684
+ assert isinstance(out, torch.Tensor)
685
+ return out[:, pad:]
686
+ else:
687
+ if ssm_state is None and seq_idx is None:
688
+ tmp = _ssd_chunk_scan_combined_cpu(
689
+ x,
690
+ dt,
691
+ A,
692
+ B,
693
+ C,
694
+ chunk_size,
695
+ D=D,
696
+ z=z,
697
+ dt_bias=dt_bias.float(),
698
+ dt_softplus=dt_softplus,
699
+ )
700
+ else:
701
+ if ssm_state is None:
702
+ bsize, _, num_heads, channel = x.shape
703
+ state = B.shape[-1]
704
+ ssm_state = torch.zeros(bsize, num_heads, channel, state, dtype=torch.float32, device=x.device)
705
+ tmp = _ssd_chunk_scan_combined_naive(
706
+ x, dt, A, B, C, D, z=z, dt_bias=dt_bias, dt_softplus=dt_softplus, seq_idx=seq_idx, ssm_state=ssm_state
707
+ )
708
+ tmp = (tmp[0][:, pad:], tmp[1])
709
+ if return_final_states:
710
+ return tmp
711
+ else:
712
+ return tmp[0]
713
+
714
+
715
+ def _causal_conv1d_update(
716
+ conv_state: torch.Tensor, weight: torch.Tensor, xBC: torch.Tensor
717
+ ) -> tuple[torch.Tensor, torch.Tensor]:
718
+ dtype = conv_state.dtype
719
+ xBC = xBC.to(dtype)
720
+ weight = weight.to(dtype)
721
+ if conv_state.is_cuda:
722
+ x = causal_conv1d.causal_conv1d_update(
723
+ x=xBC,
724
+ conv_state=conv_state,
725
+ weight=weight[:, 0, :],
726
+ activation="silu",
727
+ )
728
+ return x, conv_state
729
+ else:
730
+ x = causal_conv1d.causal_conv1d_update_ref(
731
+ x=xBC,
732
+ conv_state=conv_state,
733
+ weight=weight[:, 0, :],
734
+ activation="silu",
735
+ )
736
+ return x, conv_state
737
+
738
+
739
+ def _causal_conv1d_naive(
740
+ conv_state: torch.Tensor, weight: torch.Tensor, x: torch.Tensor, seq_idx: torch.Tensor | None
741
+ ) -> tuple[torch.Tensor, torch.Tensor]:
742
+ length = x.shape[-1]
743
+ out = torch.zeros_like(x)
744
+ for i in range(length):
745
+ if i != 0 and seq_idx is not None:
746
+ conv_state = torch.where(
747
+ (seq_idx[:, i - 1] != seq_idx[:, i])[:, None, None],
748
+ torch.zeros_like(conv_state),
749
+ conv_state,
750
+ )
751
+ out[:, :, i : i + 1], conv_state = _causal_conv1d_update(conv_state, weight, x[:, :, i : i + 1])
752
+ return out, conv_state
753
+
754
+
755
+ @torch.profiler.record_function("causal_conv1d")
756
+ def _causal_conv1d(
757
+ conv_state: torch.Tensor | None, weight: torch.Tensor, x: torch.Tensor, seq_idx: torch.Tensor | None
758
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
759
+ dtype = x.dtype
760
+ if conv_state is not None:
761
+ dtype = conv_state.dtype
762
+ assert seq_idx is None
763
+ if seq_idx is not None:
764
+ assert seq_idx.dtype == torch.int32
765
+ assert conv_state is None
766
+ weight = weight.to(dtype)
767
+ x = x.to(dtype)
768
+
769
+ return_final_states = conv_state is not None
770
+ if weight.is_cuda:
771
+ if x.stride(1) != 1:
772
+ # to channel-last format
773
+ x = x.transpose(-1, -2).contiguous().transpose(-1, -2)
774
+ if conv_state is not None:
775
+ if conv_state.stride(1) != 1:
776
+ # to channel-last format
777
+ conv_state = conv_state.transpose(-1, -2).contiguous().transpose(-1, -2)
778
+ tmp = causal_conv1d.causal_conv1d_fn(
779
+ x=x,
780
+ weight=weight[:, 0, :],
781
+ initial_states=conv_state,
782
+ return_final_states=conv_state is not None,
783
+ activation="silu",
784
+ seq_idx=seq_idx,
785
+ )
786
+ if conv_state is not None:
787
+ x, conv_state = tmp
788
+ else:
789
+ x = tmp
790
+ else:
791
+ if seq_idx is None:
792
+ x, conv_state = causal_conv1d.causal_conv1d_ref(
793
+ x=x,
794
+ initial_states=conv_state,
795
+ return_final_states=True,
796
+ weight=weight[:, 0, :],
797
+ activation="silu",
798
+ )
799
+ else:
800
+ if conv_state is None:
801
+ bsize = x.shape[0]
802
+ dim = weight.shape[0]
803
+ d_conv = weight.shape[-1]
804
+ conv_state = torch.zeros(bsize, dim, d_conv - 1, dtype=x.dtype, device=x.device)
805
+ x, conv_state = _causal_conv1d_naive(conv_state, weight, x, seq_idx)
806
+ if return_final_states:
807
+ return x, conv_state
808
+ else:
809
+ return x, None
810
+
811
+
812
+ class Mamba(torch.nn.Module):
813
+ def __init__(self, config: PlamoConfig, layer_idx: int) -> None:
814
+ super().__init__()
815
+ self.config = config
816
+ self.layer_idx = layer_idx
817
+ self.hidden_size = config.hidden_size
818
+ self.d_state = config.mamba_d_state
819
+ self.d_conv = config.mamba_d_conv
820
+ self.chunk_size = config.mamba_chunk_size
821
+ self.num_heads = config.mamba_num_heads
822
+ # TODO add mamba_hidden_size_per_head config (?)
823
+ self.hidden_size_per_head = config.hidden_size_per_head
824
+
825
+ self.intermediate_size = self.num_heads * self.hidden_size_per_head
826
+
827
+ self.in_proj = torch.nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
828
+ self.conv1d = torch.nn.Conv1d(
829
+ in_channels=self.intermediate_size,
830
+ out_channels=self.intermediate_size,
831
+ bias=False, # TODO the original implementation uses bias
832
+ kernel_size=self.d_conv,
833
+ groups=self.intermediate_size,
834
+ padding=0,
835
+ )
836
+ self.dt_dim = max(64, self.hidden_size // 16)
837
+ # Notes:
838
+ # Mamba2 removes this linear projection for simplicity (Figure 6 in the paper),
839
+ # but it may degrade the ability of context-length extrapolation.
840
+ self.bcdt_proj = torch.nn.Linear(
841
+ self.intermediate_size,
842
+ self.dt_dim + 2 * self.d_state,
843
+ bias=False,
844
+ )
845
+ self.dt_proj = torch.nn.Linear(self.dt_dim, self.num_heads, bias=False)
846
+
847
+ self.dt_bias = torch.nn.Parameter(get_initial_dt_bias(self.num_heads))
848
+ self.A_log = torch.nn.Parameter(get_initial_A(self.num_heads))
849
+ self.D = torch.nn.Parameter(torch.ones(self.num_heads))
850
+
851
+ # TODO norm weight before gating like Mamba2
852
+ self.dt_norm_weight = torch.nn.Parameter(torch.ones(self.dt_dim))
853
+ self.B_norm_weight = torch.nn.Parameter(torch.ones(self.d_state))
854
+ self.C_norm_weight = torch.nn.Parameter(torch.ones(self.d_state))
855
+
856
+ self.out_proj = torch.nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
857
+
858
+ def _no_weight_decay_param_names(self) -> set[str]:
859
+ return set(["D", "dt_bias", "A_log"])
860
+
861
+ def forward(
862
+ self,
863
+ hidden_states: torch.Tensor,
864
+ attention_mask: Optional[torch.Tensor] = None,
865
+ past_states: Optional[PlamoCache] = None,
866
+ ) -> Tuple[torch.Tensor, Optional[PlamoCache]]:
867
+ bsize, length, _ = hidden_states.shape
868
+ is_update = length == 1 and past_states is not None
869
+
870
+ bool_mask: torch.Tensor | None = None
871
+ seq_idx: torch.Tensor | None = None
872
+ if attention_mask is not None:
873
+ if len(attention_mask.shape) == 2:
874
+ attention_mask = attention_mask[None, None].expand(bsize, 1, -1, -1)
875
+ assert len(attention_mask.shape) == 4
876
+
877
+ if past_states is None:
878
+ # TODO: support seq_idx with cache
879
+ bool_mask_4d = attention_mask == 0
880
+ is_first_token = _is_first_token(bool_mask_4d)[:, 0, :]
881
+ seq_idx = torch.cumsum(is_first_token, dim=-1) - 1
882
+ seq_idx = seq_idx.to(torch.int32)
883
+
884
+ # `generate` function creates attention mask that contains past tokens,
885
+ # but mamba does not use them
886
+ attention_mask = attention_mask[:, 0, -length:, -length:]
887
+ bool_mask = torch.diagonal(attention_mask, dim1=-2, dim2=-1) == 0
888
+
889
+ conv_state: torch.Tensor | None
890
+ ssm_state: torch.Tensor | None
891
+ if past_states is None:
892
+ conv_state = None
893
+ ssm_state = None
894
+ elif past_states[self.layer_idx] is None:
895
+ conv_state = torch.zeros(
896
+ bsize, self.intermediate_size, self.d_conv - 1, dtype=hidden_states.dtype, device=hidden_states.device
897
+ )
898
+ ssm_state = torch.zeros(
899
+ bsize,
900
+ self.num_heads,
901
+ self.hidden_size_per_head,
902
+ self.d_state,
903
+ dtype=torch.float32,
904
+ device=hidden_states.device,
905
+ )
906
+ else:
907
+ c = past_states[self.layer_idx]
908
+ assert isinstance(c, PlamoMambaCache)
909
+ conv_state = c.conv_state
910
+ ssm_state = c.ssm_state
911
+
912
+ zx = self.in_proj(hidden_states)
913
+ zx = zx.reshape(bsize, length, self.num_heads, -1)
914
+ # z: (bsize, length, num_heads, hidden_size_per_head)
915
+ # x: (bsize, length, num_heads, hidden_size_per_head)
916
+ z, x = torch.split(zx, [self.hidden_size_per_head, self.hidden_size_per_head], dim=-1)
917
+
918
+ # conv
919
+ x = x.reshape(bsize, length, -1).transpose(1, 2) # (bsize, intermediate_size, length)
920
+ if bool_mask is not None:
921
+ x = torch.where(bool_mask[:, None, :], x, 0.0)
922
+ if is_update:
923
+ assert conv_state is not None
924
+ x, conv_state = _causal_conv1d_update(conv_state, self.conv1d.weight, x)
925
+ else:
926
+ x, conv_state = _causal_conv1d(conv_state, self.conv1d.weight, x, seq_idx=seq_idx)
927
+ x = x.to(dtype=hidden_states.dtype)
928
+ x = x.transpose(1, 2) # (bsize, length, intermediate_size)
929
+ x = x.reshape(bsize, length, -1)
930
+ # x: (bsize, length, num_heads, hidden_size_per_head)
931
+ # B: (bsize, length, 1, d_state)
932
+ # C: (bsize, length, 1, d_state)
933
+ # dt: (bsize, length, dt_dim)
934
+ BCdt = self.bcdt_proj(x)
935
+ x = x.reshape(bsize, length, self.num_heads, -1)
936
+ B, C, dt = torch.split(BCdt, [self.d_state, self.d_state, self.dt_dim], dim=-1)
937
+ B = B[:, :, None, :]
938
+ C = C[:, :, None, :]
939
+
940
+ A = -torch.exp(self.A_log.float()) # (num_heads,)
941
+ dt = _rms_norm(dt, None, self.config.rms_norm_eps) * self.dt_norm_weight[None, None, :]
942
+ B = _rms_norm(B, None, self.config.rms_norm_eps) * self.B_norm_weight[None, None, None, :]
943
+ C = _rms_norm(C, None, self.config.rms_norm_eps) * self.C_norm_weight[None, None, None, :]
944
+
945
+ # (bsize, length, num_heads, 1)
946
+ dt = self.dt_proj(dt)[..., None]
947
+
948
+ # TODO it may not be required
949
+ B = B.expand(-1, -1, self.num_heads, -1)
950
+ C = C.expand(-1, -1, self.num_heads, -1)
951
+
952
+ if bool_mask is not None:
953
+ """
954
+ state will be updated as follows:
955
+ ```
956
+ dt = softplus(dt)
957
+ dA = exp(dt * A)
958
+ state_next = state * dA + dB * x
959
+ ```
960
+
961
+ To avoid updating state, we set dt to -inf and x to 0
962
+ because `softplus(-inf) = 0` and `exp(0) = 1`
963
+ """
964
+ dt = torch.where(bool_mask[:, :, None, None], dt, float("-inf"))
965
+ x = torch.where(bool_mask[:, :, None, None], x, 0.0)
966
+
967
+ # ssm
968
+ if is_update:
969
+ assert ssm_state is not None
970
+ out = ssd_update_state(
971
+ ssm_state,
972
+ x[:, 0],
973
+ dt[:, 0].reshape(bsize, -1),
974
+ A,
975
+ B[:, 0],
976
+ C[:, 0],
977
+ D=self.D,
978
+ z=z[:, 0],
979
+ dt_bias=self.dt_bias,
980
+ dt_softplus=True,
981
+ )
982
+ else:
983
+ tmp = ssd_chunk_scan_combined(
984
+ x,
985
+ dt.reshape(bsize, length, -1),
986
+ A,
987
+ B,
988
+ C,
989
+ self.chunk_size,
990
+ D=self.D,
991
+ z=z,
992
+ dt_bias=self.dt_bias,
993
+ dt_softplus=True,
994
+ return_final_states=past_states is not None,
995
+ seq_idx=seq_idx,
996
+ ssm_state=ssm_state,
997
+ )
998
+ if past_states is not None:
999
+ out, ssm_state = tmp
1000
+ else:
1001
+ assert isinstance(tmp, torch.Tensor)
1002
+ out = tmp
1003
+
1004
+ y = self.out_proj(out.reshape(bsize, length, -1))
1005
+
1006
+ if past_states is not None:
1007
+ assert ssm_state is not None
1008
+ assert conv_state is not None
1009
+ past_states.update_mamba(conv_state, ssm_state, self.layer_idx)
1010
+
1011
+ return y, past_states
1012
+
1013
+
1014
+ def swa_mask(q_len: int, kv_len: int, device: torch.device, window_size: int) -> torch.Tensor:
1015
+ max_len = max(q_len, kv_len)
1016
+ mask = (
1017
+ torch.ones(max_len, max_len, dtype=torch.bool, device=device)
1018
+ .triu(diagonal=-window_size)
1019
+ .tril(diagonal=window_size)
1020
+ )
1021
+ return mask[-q_len:, -kv_len:]
1022
+
1023
+
1024
+ class Attention(torch.nn.Module):
1025
+ def __init__(self, config: PlamoConfig, layer_idx: int) -> None:
1026
+ super().__init__()
1027
+ self.config = config
1028
+ self.layer_idx = layer_idx
1029
+ self.hidden_size = config.hidden_size
1030
+ head_dim = config.hidden_size_per_head
1031
+ self.max_position_embeddings = config.max_position_embeddings
1032
+
1033
+ self.q_num_heads = config.num_attention_heads
1034
+ self.qk_dim = self.v_dim = head_dim
1035
+ self.k_num_heads = self.v_num_heads = config.num_key_value_heads
1036
+ assert self.q_num_heads % self.k_num_heads == 0
1037
+ self.n_group = self.q_num_heads // self.k_num_heads
1038
+
1039
+ self.q_proj_dim = self.q_num_heads * self.qk_dim
1040
+ self.k_proj_dim = self.k_num_heads * self.qk_dim
1041
+ self.v_proj_dim = self.k_num_heads * self.v_dim
1042
+ self.qkv_proj = nn.Linear(self.hidden_size, self.q_proj_dim + self.k_proj_dim + self.v_proj_dim, bias=False)
1043
+ self.o_proj = nn.Linear(self.q_num_heads * self.v_dim, self.hidden_size, bias=False)
1044
+
1045
+ self.q_weight = torch.nn.Parameter(torch.ones((self.q_num_heads, self.qk_dim)))
1046
+ self.k_weight = torch.nn.Parameter(torch.ones((self.k_num_heads, self.qk_dim)))
1047
+
1048
+ self.rotary_emb = RotaryEmbedding(self.qk_dim, max_position_embeddings=self.config.attention_window_size)
1049
+
1050
+ def forward(
1051
+ self,
1052
+ hidden_states: torch.Tensor,
1053
+ attention_mask: Optional[torch.Tensor] = None,
1054
+ past_states: Optional[PlamoCache] = None,
1055
+ output_attentions: bool = False,
1056
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[PlamoCache]]:
1057
+ bsz, q_len, _ = hidden_states.size()
1058
+
1059
+ qkv = self.qkv_proj(hidden_states)
1060
+ query_states, key_states, value_states = torch.split(
1061
+ qkv, [self.q_proj_dim, self.k_proj_dim, self.v_proj_dim], dim=-1
1062
+ )
1063
+ query_states = query_states.view(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(1, 2)
1064
+ key_states = key_states.view(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(1, 2)
1065
+ value_states = value_states.view(bsz, q_len, self.v_num_heads, self.v_dim).transpose(1, 2)
1066
+
1067
+ attn_dtype = query_states.dtype
1068
+
1069
+ query_states = _rms_norm(query_states, None, 1e-6) * self.q_weight[None, :, None]
1070
+ key_states = _rms_norm(key_states, None, 1e-6) * self.k_weight[None, :, None]
1071
+
1072
+ if past_states is not None:
1073
+ # reuse k, v, self_attention
1074
+ key_states_new = key_states
1075
+ value_states_new = value_states
1076
+ key_states, value_states = past_states.append_kv(key_states, value_states, self.layer_idx) # type: ignore
1077
+ past_states.update_attention(key_states_new, value_states_new, self.layer_idx)
1078
+
1079
+ kv_seq_len = key_states.shape[-2]
1080
+ device = hidden_states.device
1081
+ position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=device)[None]
1082
+ q_position_ids = position_ids[:, -query_states.shape[2] :]
1083
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
1084
+ query_states = _rotary_pos_emb(query_states, cos, sin, q_position_ids)
1085
+ key_states = _rotary_pos_emb(key_states, cos, sin, position_ids)
1086
+ # [bsz, nh, t, hd]
1087
+
1088
+ def _expand_kv(t: torch.Tensor, repeat: int, target: int) -> torch.Tensor:
1089
+ t = torch.repeat_interleave(t, repeat, dim=1)
1090
+ return t[:, :target]
1091
+
1092
+ # expand shared kv
1093
+ assert self.k_num_heads == self.v_num_heads
1094
+ key_states = _expand_kv(key_states, self.n_group, self.q_num_heads)
1095
+ value_states = _expand_kv(value_states, self.n_group, self.q_num_heads)
1096
+
1097
+ full_attn = self.layer_idx in self.config.full_attention_idx
1098
+
1099
+ query_states = query_states.to(attn_dtype)
1100
+ key_states = key_states.to(attn_dtype)
1101
+ value_states = value_states.to(attn_dtype)
1102
+ if attention_mask is not None and attention_mask.dtype != torch.bool:
1103
+ attention_mask = attention_mask.to(attn_dtype)
1104
+ if attention_mask is None:
1105
+ if not full_attn:
1106
+ assert key_states.shape[2] <= self.config.attention_window_size + 1
1107
+ attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=True)
1108
+ else:
1109
+ if attention_mask.dtype == torch.bool:
1110
+ attention_mask = torch.where(attention_mask, torch.tensor(0.0, dtype=torch.float), float("-inf"))
1111
+ if len(attention_mask.shape) == 2:
1112
+ attention_mask = attention_mask[None, None]
1113
+ assert len(attention_mask.shape) == 4
1114
+
1115
+ if not full_attn:
1116
+ m_swa = swa_mask(
1117
+ query_states.shape[2], key_states.shape[2], query_states.device, self.config.attention_window_size
1118
+ )
1119
+ # the `generate` function creates an attention mask that does not take the sliding window into account
1120
+ m_swa = m_swa[None, None]
1121
+ attention_mask = attention_mask[:, :, -query_states.shape[2] :, -key_states.shape[2] :]
1122
+ attention_mask = torch.where(m_swa, attention_mask, float("-inf"))
1123
+
1124
+ # like AttentionMaskConverter._unmask_unattended in huggingface transformers,
1125
+ # we need to attend to all tokens in masked rows for `scaled_dot_product_attention`
1126
+ bool_mask = torch.logical_not(torch.isneginf(attention_mask))
1127
+ valid_tokens = torch.sum(bool_mask, dim=-1).bool() # (..., q_len)
1128
+ attention_mask = torch.where(valid_tokens[..., None], attention_mask, float(0.0))
1129
+ attn_output = F.scaled_dot_product_attention(
1130
+ query_states, key_states, value_states, attn_mask=attention_mask
1131
+ )
1132
+
1133
+ attn_output = attn_output.transpose(1, 2)
1134
+
1135
+ attn_output = attn_output.reshape(bsz, q_len, self.q_num_heads * self.v_dim)
1136
+ attn_output = self.o_proj(attn_output)
1137
+
1138
+ if not output_attentions:
1139
+ attn_weights = None
1140
+
1141
+ return attn_output, attn_weights, past_states
1142
+
1143
+
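`_expand_kv` implements the usual grouped-query attention trick: each of the `k_num_heads` key/value heads is repeated `n_group` times so that every query head has a matching KV head. A standalone sketch with hypothetical head counts:

```python
import torch

bsz, kv_heads, q_heads, seq, dim = 1, 2, 8, 5, 16   # hypothetical sizes
n_group = q_heads // kv_heads                        # 4 query heads per KV head

k = torch.randn(bsz, kv_heads, seq, dim)
k_expanded = torch.repeat_interleave(k, n_group, dim=1)[:, :q_heads]

print(k_expanded.shape)                                  # torch.Size([1, 8, 5, 16])
assert torch.equal(k_expanded[:, 0], k_expanded[:, 3])   # query heads 0-3 share KV head 0
```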
1144
+ class MLP(nn.Module):
1145
+ def __init__(self, config: PlamoConfig) -> None:
1146
+ super().__init__()
1147
+ self.config = config
1148
+ self.hidden_size = config.hidden_size
1149
+ self.intermediate_size = config.intermediate_size
1150
+ self.gate_up_proj = torch.nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
1151
+ self.down_proj = torch.nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
1152
+
1153
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1154
+ h = self.gate_up_proj(x)
1155
+ h = _swiglu(h)
1156
+ return self.down_proj(h) # type: ignore
1157
+
1158
+
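`gate_up_proj` produces a single tensor of width `2 * intermediate_size`, which `_swiglu` (defined earlier in this file, outside this excerpt) reduces back to `intermediate_size`. A reference sketch of what such a helper typically computes, assuming the standard split-then-`silu(gate) * up` formulation:

```python
import torch
import torch.nn.functional as F

def swiglu_reference(h: torch.Tensor) -> torch.Tensor:
    # Assumed behaviour: split the doubled projection into (gate, up) halves
    # and gate the "up" half with SiLU.
    gate, up = h.chunk(2, dim=-1)
    return F.silu(gate) * up

h = torch.randn(2, 3, 8)              # e.g. intermediate_size = 4
print(swiglu_reference(h).shape)      # torch.Size([2, 3, 4])
```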
1159
+ class PlamoDecoderLayer(torch.nn.Module):
1160
+ def __init__(self, config: PlamoConfig, is_mamba: bool, layer_idx: int) -> None:
1161
+ super().__init__()
1162
+ self.config = config
1163
+ self.hidden_size = config.hidden_size
1164
+ self.is_mamba = is_mamba
1165
+ self.mixer: torch.nn.Module
1166
+ if is_mamba:
1167
+ self.mixer = Mamba(config, layer_idx)
1168
+ else:
1169
+ self.mixer = Attention(config, layer_idx)
1170
+ self.mlp = MLP(config)
1171
+ """
1172
+ Note: The model performance degraded when all offsets were set to 1.
1173
+ """
1174
+ self.pre_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0)
1175
+ self.post_mixer_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / 5)
1176
+ self.pre_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0)
1177
+ self.post_mlp_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, offset=1.0 / (5**1.5))
1178
+
1179
+ def forward(
1180
+ self,
1181
+ hidden_states: torch.Tensor,
1182
+ attention_mask: Optional[torch.Tensor] = None,
1183
+ past_state: Optional[PlamoCache] = None,
1184
+ output_attentions: Optional[bool] = False,
1185
+ ) -> Tuple[Any, ...]:
1186
+ # from LlamaDecoder
1187
+ residual = hidden_states
1188
+ hidden_states = self.pre_mixer_norm(hidden_states)
1189
+
1190
+ # Self Attention
1191
+ if self.is_mamba:
1192
+ hidden_states_sa, present_key_value = self.mixer(
1193
+ hidden_states=hidden_states,
1194
+ attention_mask=attention_mask,
1195
+ past_states=past_state,
1196
+ )
1197
+ self_attn_weights = None
1198
+ else:
1199
+ hidden_states_sa, self_attn_weights, present_key_value = self.mixer(
1200
+ hidden_states=hidden_states,
1201
+ attention_mask=attention_mask,
1202
+ past_states=past_state,
1203
+ output_attentions=output_attentions,
1204
+ )
1205
+
1206
+ hidden_states_sa = self.post_mixer_norm(hidden_states_sa)
1207
+ hidden_states = residual + hidden_states_sa
1208
+
1209
+ residual = hidden_states
1210
+ hidden_states = self.pre_mlp_norm(hidden_states)
1211
+
1212
+ # Fully Connected
1213
+ hidden_states_mlp = self.mlp(hidden_states)
1214
+
1215
+ # Residual
1216
+ hidden_states_mlp = self.post_mlp_norm(hidden_states_mlp)
1217
+ hidden_states = residual + hidden_states_mlp
1218
+
1219
+ outputs: Any = (hidden_states,)
1220
+
1221
+ if output_attentions:
1222
+ outputs += (self_attn_weights,)
1223
+
1224
+ return outputs # type: ignore
1225
+
1226
+
1227
+ def is_mamba(config: PlamoConfig, i: int) -> bool:
1228
+ if not config.mamba_enabled:
1229
+ return False
1230
+ assert config.mamba_step > 1
1231
+ assert i < config.num_hidden_layers
1232
+
1233
+ if config.num_hidden_layers <= (config.mamba_step // 2):
1234
+ # use attention in last layer
1235
+ return i != config.num_hidden_layers - 1
1236
+ return (i % config.mamba_step) != (config.mamba_step // 2)
1237
+
1238
+
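`is_mamba` interleaves the two mixer types: within every block of `mamba_step` layers, exactly one layer (the one at offset `mamba_step // 2` within the block) uses attention and the rest use Mamba; very shallow models instead get attention only in the last layer. A quick check with a hypothetical config:

```python
# Hypothetical config values, for illustration only.
class DummyConfig:
    mamba_enabled = True
    mamba_step = 2
    num_hidden_layers = 8

layout = [
    "mamba" if is_mamba(DummyConfig(), i) else "attention"
    for i in range(DummyConfig.num_hidden_layers)
]
print(layout)
# ['mamba', 'attention', 'mamba', 'attention', 'mamba', 'attention', 'mamba', 'attention']
```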
1239
+ class PlamoDecoder(torch.nn.Module):
1240
+ def __init__(self, config: PlamoConfig) -> None:
1241
+ super().__init__()
1242
+
1243
+ self.layers = torch.nn.ModuleList(
1244
+ [
1245
+ PlamoDecoderLayer(config, is_mamba=is_mamba(config, i), layer_idx=i)
1246
+ for i in range(config.num_hidden_layers)
1247
+ ]
1248
+ )
1249
+ self.gradient_checkpointing = False
1250
+
1251
+ def forward(self, x: DecoderInput) -> DecoderOutput:
1252
+ all_hidden_states: Optional[Tuple[torch.Tensor, ...]] = () if x.output_hidden_states else None
1253
+ all_self_attns: Optional[Tuple[torch.Tensor, ...]] = () if x.output_attentions else None
1254
+ hidden_states = x.hidden_states
1255
+
1256
+ for decoder_layer in self.layers:
1257
+ if x.output_hidden_states:
1258
+ assert all_hidden_states is not None
1259
+ all_hidden_states += (hidden_states,)
1260
+
1261
+ if self.training and x.gradient_checkpointing:
1262
+ layer_outputs = self._gradient_checkpointing_func(
1263
+ decoder_layer.__call__,
1264
+ hidden_states,
1265
+ x.attention_mask,
1266
+ x.past_states,
1267
+ x.output_attentions,
1268
+ )
1269
+ else:
1270
+ layer_outputs = decoder_layer(
1271
+ hidden_states,
1272
+ attention_mask=x.attention_mask,
1273
+ past_state=x.past_states,
1274
+ output_attentions=x.output_attentions,
1275
+ )
1276
+
1277
+ hidden_states = layer_outputs[0]
1278
+
1279
+ if x.output_attentions:
1280
+ assert layer_outputs[1] is not None
1281
+ assert all_self_attns is not None
1282
+ all_self_attns += (layer_outputs[1],)
1283
+ return DecoderOutput(hidden_states, all_hidden_states, all_self_attns)
1284
+
1285
+
1286
+ class PlamoPreTrainedModel(PreTrainedModel): # type: ignore
1287
+ config_class = PlamoConfig
1288
+ _no_split_modules: List[str]
1289
+ base_model_prefix = "model"
1290
+ supports_gradient_checkpointing = True
1291
+ _no_split_modules = ["PlamoDecoderLayer"]
1292
+ _skip_keys_device_placement = "past_key_values"
1293
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
1294
+
1295
+ def _init_weights(self, module: torch.nn.Module) -> None:
1296
+ std = 0.02
1297
+ if isinstance(module, nn.Linear):
1298
+ module.weight.data.normal_(mean=0.0, std=std)
1299
+ if module.bias is not None:
1300
+ module.bias.data.zero_()
1301
+ elif isinstance(module, nn.Embedding):
1302
+ module.weight.data.normal_(mean=0.0, std=std)
1303
+ if module.padding_idx is not None:
1304
+ module.weight.data[module.padding_idx].zero_()
1305
+
1306
+
1307
+ class PlamoModel(PlamoPreTrainedModel):
1308
+ def __init__(self, config: PlamoConfig):
1309
+ super().__init__(config)
1310
+ assert config.eval_attention_n_bit is None
1311
+ assert config.eval_mlp_n_bit is None
1312
+
1313
+ self.padding_idx = config.pad_token_id
1314
+ self.vocab_size = config.vocab_size
1315
+
1316
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1317
+ if config.image_feature_size is not None:
1318
+ if config.image_proj_type == "mlp":
1319
+ self.image_proj = MLPImageProjector(config) # type: ignore
1320
+ elif config.image_proj_type == "linear":
1321
+ self.image_proj = nn.Linear(config.image_feature_size, config.hidden_size, bias=False) # type: ignore
1322
+ else:
1323
+ raise ValueError(f"Unknown image_proj_type: {config.image_proj_type}")
1324
+ self.layers = PlamoDecoder(config) # type: ignore
1325
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1326
+
1327
+ self.gradient_checkpointing = False
1328
+ # Initialize weights and apply final processing
1329
+ self.post_init()
1330
+
1331
+ def get_input_embeddings(self) -> torch.nn.Embedding:
1332
+ return self.embed_tokens
1333
+
1334
+ def set_input_embeddings(self, value: torch.nn.Embedding) -> None:
1335
+ self.embed_tokens = value
1336
+
1337
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
1338
+ def _prepare_decoder_attention_mask(
1339
+ self,
1340
+ attention_mask: torch.Tensor,
1341
+ input_shape: Tuple[int, int],
1342
+ inputs_embeds: Optional[torch.Tensor],
1343
+ past_key_values_length: int,
1344
+ ) -> Optional[torch.Tensor]:
1345
+ # create causal mask
1346
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1347
+ combined_attention_mask: Optional[torch.Tensor] = None
1348
+ if input_shape[-1] > 1:
1349
+ assert inputs_embeds is not None
1350
+ combined_attention_mask = _make_causal_mask(
1351
+ input_shape,
1352
+ inputs_embeds.dtype,
1353
+ device=inputs_embeds.device,
1354
+ past_key_values_length=past_key_values_length,
1355
+ )
1356
+ input_shape = (input_shape[0], combined_attention_mask.shape[2])
1357
+
1358
+ if attention_mask is not None:
1359
+ if attention_mask.dim() == 4:
1360
+ # Custom 4D attention mask
1361
+ expanded_attn_mask = attention_mask
1362
+ else:
1363
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1364
+ assert inputs_embeds is not None
1365
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
1366
+ inputs_embeds.device
1367
+ )
1368
+ combined_attention_mask = (
1369
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
1370
+ )
1371
+
1372
+ return combined_attention_mask
1373
+
1374
+ def forward(
1375
+ self,
1376
+ input_ids: Optional[torch.LongTensor] = None,
1377
+ attention_mask: Optional[torch.Tensor] = None,
1378
+ position_ids: Optional[torch.Tensor] = None,
1379
+ past_key_values: Optional[PlamoCache] = None,
1380
+ inputs_embeds: Optional[torch.Tensor] = None,
1381
+ image_features: Optional[torch.Tensor] = None,
1382
+ use_cache: Optional[bool] = None,
1383
+ output_attentions: Optional[bool] = None,
1384
+ output_hidden_states: Optional[bool] = None,
1385
+ return_dict: Optional[bool] = None,
1386
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1387
+ assert input_ids is not None
1388
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1389
+ output_hidden_states = (
1390
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1391
+ )
1392
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1393
+
1394
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1395
+
1396
+ # retrieve input_ids and inputs_embeds
1397
+ if input_ids is not None and inputs_embeds is not None:
1398
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1399
+ elif input_ids is not None:
1400
+ batch_size, seq_length = input_ids.shape
1401
+ else:
1402
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
1403
+
1404
+ seq_length_with_past = seq_length
1405
+ past_key_values_length = 0
1406
+
1407
+ if past_key_values is not None:
1408
+ past_key_values_length = past_key_values.get_seq_length()
1409
+ seq_length_with_past = seq_length_with_past + past_key_values_length
1410
+
1411
+ if inputs_embeds is None:
1412
+ inputs_embeds = self.embed_tokens(input_ids)
1413
+
1414
+ if image_features is not None:
1415
+ assert self.config.image_token_id is not None
1416
+ image_embeds = self.image_proj(image_features)
1417
+ assert image_embeds.shape == inputs_embeds.shape, (image_embeds.shape, inputs_embeds.shape)
1418
+ mask = input_ids == self.config.image_token_id
1419
+ inputs_embeds[mask] = image_embeds[mask]
1420
+
1421
+ # embed positions
1422
+ require_attn_mask = False
1423
+ if not self.training or past_key_values is not None:
1424
+ require_attn_mask = True
1425
+ if seq_length_with_past >= self.config.attention_window_size:
1426
+ require_attn_mask = True
1427
+ if require_attn_mask and attention_mask is None:
1428
+ attention_mask = torch.ones(
1429
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
1430
+ )
1431
+ if attention_mask is not None:
1432
+ attention_mask = self._prepare_decoder_attention_mask(
1433
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
1434
+ )
1435
+
1436
+ hidden_states = inputs_embeds
1437
+
1438
+ if self.gradient_checkpointing and self.training:
1439
+ if use_cache:
1440
+ use_cache = False
1441
+
1442
+ if use_cache and past_key_values is None:
1443
+ past_key_values = PlamoCache(self.config)
1444
+
1445
+ # decoder layers
1446
+ out = self.layers(
1447
+ DecoderInput(
1448
+ hidden_states,
1449
+ attention_mask,
1450
+ past_key_values,
1451
+ output_hidden_states,
1452
+ output_attentions,
1453
+ self.gradient_checkpointing,
1454
+ )
1455
+ )
1456
+ assert isinstance(out, DecoderOutput)
1457
+ hidden_states = out.hidden_states
1458
+ all_hidden_states = out.all_hidden_states
1459
+ all_self_attns = out.all_self_attns
1460
+
1461
+ hidden_states = self.norm(hidden_states)
1462
+
1463
+ # add hidden states from the last decoder layer
1464
+ if output_hidden_states:
1465
+ assert all_hidden_states is not None
1466
+ all_hidden_states += (hidden_states,)
1467
+
1468
+ if not return_dict:
1469
+ return tuple(
1470
+ v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None
1471
+ )
1472
+ return BaseModelOutputWithPast(
1473
+ last_hidden_state=hidden_states,
1474
+ past_key_values=past_key_values,
1475
+ hidden_states=all_hidden_states,
1476
+ attentions=all_self_attns,
1477
+ )
1478
+
1479
+
1480
+ class PlamoForCausalLM(PlamoPreTrainedModel):
1481
+ _tied_weights_keys = ["lm_head.weight"]
1482
+
1483
+ # Without this, the model cannot be loaded into a meta device.
1484
+ # Relevant code:
1485
+ # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/modeling_utils.py#L4376-L4381
1486
+ # https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/modeling_utils.py#L356
1487
+ # https://github.com/pytorch/pytorch/blob/v2.4.1/torch/nn/modules/module.py#L2068
1488
+ _supports_param_buffer_assignment = False
1489
+
1490
+ def __init__(self, config: PlamoConfig) -> None:
1491
+ super().__init__(config)
1492
+ self.model = PlamoModel(config)
1493
+
1494
+ self.vocab_size = config.vocab_size
1495
+ vocab_size = ((self.vocab_size + 15) // 16) * 16
1496
+ self.lm_head: torch.nn.Module = nn.Linear(config.hidden_size, vocab_size, bias=False)
1497
+
1498
+ # Initialize weights and apply final processing
1499
+ self.post_init()
1500
+
1501
+ def get_input_embeddings(self) -> torch.nn.Embedding:
1502
+ return self.model.embed_tokens
1503
+
1504
+ def set_input_embeddings(self, value: torch.nn.Embedding) -> None:
1505
+ self.model.embed_tokens = value
1506
+
1507
+ def get_output_embeddings(self) -> torch.nn.Module:
1508
+ return self.lm_head
1509
+
1510
+ def set_output_embeddings(self, new_embeddings: torch.nn.Module) -> None:
1511
+ self.lm_head = new_embeddings
1512
+
1513
+ def set_decoder(self, decoder: PlamoModel) -> None:
1514
+ self.model = decoder
1515
+
1516
+ def get_decoder(self) -> PlamoModel:
1517
+ return self.model
1518
+
1519
+ def forward( # type: ignore
1520
+ self,
1521
+ input_ids: Optional[torch.LongTensor] = None,
1522
+ attention_mask: Optional[torch.Tensor] = None,
1523
+ position_ids: Optional[torch.Tensor] = None,
1524
+ past_key_values: Optional[PlamoCache] = None,
1525
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1526
+ image_features: Optional[torch.Tensor] = None,
1527
+ labels: Optional[torch.LongTensor] = None,
1528
+ use_cache: Optional[bool] = None,
1529
+ output_attentions: Optional[bool] = None,
1530
+ output_hidden_states: Optional[bool] = None,
1531
+ return_dict: Optional[bool] = None,
1532
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1533
+ r"""
1534
+ Args:
1535
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1536
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1537
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1538
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1539
+
1540
+ Returns:
1541
+
1542
+ Example:
1543
+
1544
+ ```python
1545
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
1546
+
1547
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1548
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1549
+
1550
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1551
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1552
+
1553
+ >>> # Generate
1554
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1555
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1556
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1557
+ ```"""
1558
+ assert input_ids is not None
1559
+
1560
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1561
+ output_hidden_states = (
1562
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1563
+ )
1564
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1565
+
1566
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1567
+ outputs = self.model(
1568
+ input_ids=input_ids,
1569
+ attention_mask=attention_mask,
1570
+ position_ids=position_ids,
1571
+ past_key_values=past_key_values,
1572
+ inputs_embeds=inputs_embeds,
1573
+ image_features=image_features,
1574
+ use_cache=use_cache,
1575
+ output_attentions=output_attentions,
1576
+ output_hidden_states=output_hidden_states,
1577
+ return_dict=return_dict,
1578
+ )
1579
+
1580
+ hidden_states = outputs[0]
1581
+ logits = self.lm_head(hidden_states)
1582
+ logits = logits[..., : self.vocab_size]
1583
+
1584
+ loss = None
1585
+ if labels is not None:
1586
+ # Shift so that tokens < n predict n
1587
+ shift_logits = logits[..., :-1, :].contiguous()
1588
+ shift_labels = labels[..., 1:].contiguous()
1589
+ # Flatten the tokens
1590
+ loss_fct = nn.CrossEntropyLoss()
1591
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1592
+ shift_labels = shift_labels.view(-1)
1593
+ # Enable model parallelism
1594
+ shift_labels = shift_labels.to(shift_logits.device)
1595
+ loss = loss_fct(shift_logits, shift_labels)
1596
+
1597
+ if not return_dict:
1598
+ output = (logits,) + outputs[1:]
1599
+ return (loss,) + output if loss is not None else output
1600
+
1601
+ return CausalLMOutputWithPast(
1602
+ loss=loss,
1603
+ logits=logits,
1604
+ past_key_values=outputs.past_key_values,
1605
+ hidden_states=outputs.hidden_states,
1606
+ attentions=outputs.attentions,
1607
+ )
1608
+
1609
+ def prepare_inputs_for_generation(
1610
+ self,
1611
+ input_ids: torch.Tensor,
1612
+ past_key_values: Optional[PlamoCache] = None,
1613
+ attention_mask: Optional[torch.Tensor] = None,
1614
+ inputs_embeds: Optional[torch.Tensor] = None,
1615
+ image_features: Optional[torch.Tensor] = None,
1616
+ **kwargs: Any,
1617
+ ) -> Dict[str, Any]:
1618
+ if past_key_values:
1619
+ input_ids = input_ids[:, -1:]
1620
+ if image_features is not None:
1621
+ image_features = image_features[:, -1:, :]
1622
+
1623
+ position_ids = kwargs.get("position_ids", None)
1624
+ if attention_mask is not None and position_ids is None:
1625
+ # create position_ids on the fly for batch generation
1626
+ position_ids = attention_mask.long().cumsum(-1) - 1
1627
+ position_ids.masked_fill_(attention_mask == 0, 1)
1628
+ if past_key_values:
1629
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1630
+
1631
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1632
+ if inputs_embeds is not None and past_key_values is None:
1633
+ model_inputs: Dict[str, Any] = {"inputs_embeds": inputs_embeds}
1634
+ else:
1635
+ model_inputs = {"input_ids": input_ids}
1636
+
1637
+ model_inputs.update(
1638
+ {
1639
+ "position_ids": position_ids,
1640
+ "past_key_values": past_key_values,
1641
+ "use_cache": kwargs.get("use_cache"),
1642
+ "attention_mask": attention_mask,
1643
+ "image_features": image_features,
1644
+ }
1645
+ )
1646
+ return model_inputs
1647
+
1648
+ @staticmethod
1649
+ def _reorder_cache(past_key_values: PlamoCache, beam_idx: torch.Tensor) -> PlamoCache:
1650
+ past_key_values.reorder_cache(beam_idx)
1651
+ return past_key_values
1652
+
1653
+
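Note that `lm_head` is built with the vocabulary size rounded up to the next multiple of 16 (a common padding choice for GPU efficiency), and the `forward` above slices the logits back down with `logits[..., : self.vocab_size]`. For example:

```python
vocab_size = 98_763                          # illustrative value
padded = ((vocab_size + 15) // 16) * 16      # 98_768, the actual lm_head width
print(padded, padded % 16)                   # 98768 0
# The extra padded columns are dropped again by `logits[..., :vocab_size]`.
```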
1654
+ class MLPImageProjector(nn.Module):
1655
+ def __init__(self, config: PlamoConfig) -> None:
1656
+ super().__init__()
1657
+ self.config = config
1658
+
1659
+ assert config.image_feature_size is not None # for typing
1660
+
1661
+ # nn.LayerNorm is not supported by PFVM, so use RMSNorm + Bias instead to approximate this.
1662
+ self.norm0 = RMSNorm(config.image_feature_size, eps=config.rms_norm_eps)
1663
+ self.bias0 = Bias(config.image_feature_size)
1664
+
1665
+ # PFVM doesn't support Linear with bias, so add bias manually afterwards.
1666
+ self.linear1 = nn.Linear(config.image_feature_size, config.hidden_size, bias=False)
1667
+ self.bias1 = Bias(config.hidden_size)
1668
+ self.act1 = nn.GELU()
1669
+
1670
+ self.linear2 = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
1671
+ self.bias2 = Bias(config.hidden_size)
1672
+
1673
+ def forward(
1674
+ self,
1675
+ hidden_states: torch.Tensor,
1676
+ ) -> torch.Tensor:
1677
+ hidden_states = self.norm0(hidden_states)
1678
+ hidden_states = self.bias0(hidden_states)
1679
+
1680
+ hidden_states = self.linear1(hidden_states)
1681
+ hidden_states = self.bias1(hidden_states)
1682
+ hidden_states = self.act1(hidden_states)
1683
+
1684
+ hidden_states = self.linear2(hidden_states)
1685
+ hidden_states = self.bias2(hidden_states)
1686
+
1687
+ return hidden_states
1688
+
1689
+
1690
+ class Bias(nn.Module):
1691
+ def __init__(self, num_features: int) -> None:
1692
+ super().__init__()
1693
+ self._bias = nn.Parameter(torch.zeros((num_features,)))
1694
+
1695
+ def forward(
1696
+ self,
1697
+ x: torch.Tensor,
1698
+ ) -> torch.Tensor:
1699
+ return x + self._bias
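The comment in `MLPImageProjector` notes that `RMSNorm` + `Bias` is used to approximate `nn.LayerNorm` because PFVM does not support the latter. The approximation drops LayerNorm's mean-centering step, as the small comparison below illustrates (toy tensors and plain functional ops, not the module classes above):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 8)
weight, bias = torch.ones(8), torch.zeros(8)

layer_norm = F.layer_norm(x, (8,), weight, bias)
rms = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
rms_plus_bias = rms * weight + bias   # roughly what RMSNorm followed by Bias computes

# The difference comes from the mean-centering term that RMSNorm omits.
print((layer_norm - rms_plus_bias).abs().max())
```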
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<|plamo:bos|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|plamo:eos|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<|plamo:pad|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<|plamo:unk|>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
tokenization_plamo.py ADDED
@@ -0,0 +1,392 @@
1
+ import json
2
+ import math
3
+ import os
4
+ from shutil import copyfile
5
+ from typing import Any, Optional, Tuple
6
+
7
+ import numpy as np
8
+
9
+ # NOTE: numba does not support type hints for njit: https://github.com/python/mypy/issues/16149
10
+ from numba import njit # type: ignore[attr-defined]
11
+ from numba.core import types
12
+ from numba.typed import Dict, List
13
+ from transformers.tokenization_utils import PreTrainedTokenizer
14
+ from transformers.utils import logging
15
+
16
+ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.jsonl"}
17
+ logger = logging.get_logger(__name__)
18
+
19
+ INVALID_SCORE = -20000000
20
+ UNKNOWN_SCORE = -10000000
21
+
22
+ TABLE_PIECE_LENGTH = 0
23
+ TABLE_TOKEN_ID = 1
24
+ TABLE_SCORE = 2
25
+ TABLE_PIECE_ID = 3
26
+
27
+ PATH_TOKEN_LENGTH = 0
28
+ PATH_TOKEN_ID = 1
29
+ PATH_NUM_TOKENS = 2
30
+
31
+
32
+ class AhoCorasick:
33
+ def __init__(self) -> None:
34
+ # List of tokens in the vocabulary.
35
+ self._tokens: list[str]
36
+
37
+ # A mapping from a byte code point to a token ID, used for byte fallback.
38
+ self._bytes: np.ndarray
39
+
40
+ # A mapping from a suffix's piece code to a suffix ID.
41
+ #
42
+ # Typically, the Aho-Corasick algorithm builds a Trie and adds suffix links between nodes
43
+ # of the Trie. In this implementation, a suffix ID corresponds to a node in the trie, and
44
+ # a piece code to an edge (in other words, a pair of a node and the next character).
45
+ #
46
+ # A piece code is a 64-bit integer:
47
+ # - The upper 32 bits store the Unicode code point of the first character.
48
+ # - The lower 32 bits store the suffix ID of the remaining suffix.
49
+ #
50
+ # A suffix ID is an integer indicating the starting position in the _table.
51
+ self._to_suffix_id: Dict[types.int64, types.int32]
52
+
53
+ # Flattened table representing the Trie structure for the Aho-Corasick algorithm.
54
+ # It stores information including scores for each piece (prefix) within each suffix.
55
+ # It is flattened for memory efficiency and performance. Suffixes are stored in
56
+ # lexicographical order of their reversed strings, which improves memory access locality
57
+ # when exploring new characters starting from the string's end. Pieces within a suffix are
58
+ # stored in the decreasing order of their lengths.
59
+ #
60
+ # Each piece (a prefix of the suffix) contains four pieces of information:
61
+ # - TABLE_PIECE_LENGTH: Length of the piece.
62
+ # - TABLE_TOKEN_ID: Token ID (or -1 if the piece is not a valid token).
63
+ # - TABLE_SCORE: Score (or INVALID_SCORE if the piece is not a valid token).
64
+ # - TABLE_PIECE_ID: Piece ID of the suffix.
65
+ #
66
+ # Each suffix also includes a sentinel row with a length of 1, a score of UNKNOWN_SCORE,
67
+ # and a token ID of -1. Sentinel rows are identified by the score being UNKNOWN_SCORE.
68
+ self._table: np.ndarray
69
+
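The 64-bit piece code described above can be packed and unpacked with plain shifts, mirroring the `ord(s[0]) << 32 | suffix_to_id[s[1:]]` expression used in `build` below. Illustrative values:

```python
first_char = "あ"            # first character of a suffix
rest_suffix_id = 42          # suffix ID of the remaining suffix (hypothetical)

piece_code = ord(first_char) << 32 | rest_suffix_id
assert piece_code >> 32 == ord(first_char)         # upper 32 bits: Unicode code point
assert piece_code & 0xFFFFFFFF == rest_suffix_id   # lower 32 bits: suffix ID
```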
70
+ def build(self, vocab: list[Any]) -> None:
71
+ self._bytes = np.zeros(256, dtype=np.int32)
72
+ self._to_suffix_id = Dict.empty(key_type=types.int64, value_type=types.int32)
73
+
74
+ # Build suffix_to_score and token_to_token_id.
75
+ # The suffix_to_score dictionary maps a suffix to its score. It also includes all suffixes
76
+ # of the token for the Trie structure for the Aho-Corasick algorithm. If a suffix is not a
77
+ # valid token, its score is set to math.nan.
78
+ # The token_to_token_id dictionary maps a token to its token ID.
79
+ suffix_to_score: dict[str, float] = {}
80
+ token_to_token_id: dict[str, int] = {}
81
+ self._tokens = []
82
+ for token_id, row in enumerate(vocab):
83
+ assert isinstance(row[0], str), row
84
+ assert isinstance(row[1], (int, float)), row
85
+
86
+ token = str(row[0])
87
+ self._tokens.append(token)
88
+ token_to_token_id[token] = token_id
89
+
90
+ # Special handling for byte tokens.
91
+ if len(row) > 2 and row[2] == "BYTE":
92
+ assert len(token) == 6 and token.startswith("<0x") and token.endswith(">"), row[0]
93
+ self._bytes[int(row[0][3:5], 16)] = token_id
94
+ continue
95
+
96
+ suffix_to_score[token] = float(row[1])
97
+ # Ensure that all suffixes are included in suffix_to_score.
98
+ for i in range(1, len(token)):
99
+ suffix_to_score[token[i:]] = suffix_to_score.get(token[i:], math.nan)
100
+
101
+ # Ensure all byte tokens are set.
102
+ for i in range(256):
103
+ assert self._bytes[i] != 0, f"Byte token for <0x{i:02X}> is not set."
104
+
105
+ # List suffixes in lexicographical order of their reversed strings.
106
+ suffixes = list(suffix_to_score.keys())
107
+ suffixes.append("")
108
+ suffixes.sort(key=lambda x: x[::-1])
109
+
110
+ # Build suffix_to_id, which is a mapping from a suffix to a suffix ID, and _to_suffix_id,
111
+ # which is a mapping from a piece code to a suffix ID.
112
+ suffix_to_id: dict[str, int] = {}
113
+ num_pieces = 0
114
+ for s in suffixes:
115
+ suffix_to_id[s] = num_pieces
116
+ if s != "":
117
+ self._to_suffix_id[ord(s[0]) << 32 | suffix_to_id[s[1:]]] = np.int32(num_pieces)
118
+ num_pieces += 1 + sum(s[:i] in suffix_to_score for i in range(1, len(s) + 1))
119
+ assert suffix_to_id[""] == 0, suffix_to_id[""]
120
+
121
+ # Build _table, which is a flattened table representing the Trie structure for the Aho-Corasick.
122
+ self._table = np.zeros((num_pieces, 4), dtype=np.int32)
123
+ i = 0
124
+ for suffix in suffixes:
125
+ # Add all prefixes of the suffix to the table.
126
+ for piece_length in range(len(suffix), 0, -1):
127
+ piece = suffix[:piece_length]
128
+ score = suffix_to_score.get(piece, None)
129
+ if score is None:
130
+ continue
131
+ self._table[i, TABLE_PIECE_LENGTH] = piece_length
132
+ self._table[i, TABLE_TOKEN_ID] = token_to_token_id.get(piece, -1)
133
+ self._table[i, TABLE_SCORE] = round(score * 1e4) if math.isfinite(score) else INVALID_SCORE
134
+ self._table[i, TABLE_PIECE_ID] = suffix_to_id[piece]
135
+ i += 1
136
+
137
+ # Add a sentinel row.
138
+ self._table[i, TABLE_PIECE_LENGTH] = 1
139
+ self._table[i, TABLE_TOKEN_ID] = -1
140
+ self._table[i, TABLE_SCORE] = UNKNOWN_SCORE
141
+ i += 1
142
+ assert i == num_pieces, (i, num_pieces)
143
+
144
+ @staticmethod
145
+ @njit
146
+ def _encode(
147
+ to_suffix_id: Dict[types.int64, types.int32],
148
+ table: np.ndarray,
149
+ bytes: np.ndarray,
150
+ data: np.ndarray,
151
+ ) -> np.ndarray:
152
+ # Initialize scores array with a high value and set the score at the end to 0.
153
+ # This array keeps track of the minimum cost (best score) to encode from each position to the end.
154
+ scores = np.full((len(data) + 1,), 2**60, dtype=np.int64)
155
+ scores[-1] = 0
156
+
157
+ # Path array to store the best path information.
158
+ # The path array keeps track of token length, token ID, and number of tokens needed to encode.
159
+ path = np.zeros((len(data) + 1, 3), dtype=np.int32)
160
+
161
+ # Initialize suffix_id to 0, which represents the root of the Trie.
162
+ suffix_id = 0
163
+
164
+ # Process the input data from the end to the beginning.
165
+ for i in range(len(data) - 1, -1, -1):
166
+ c = data[i]
167
+
168
+ # Find the next suffix ID by iterating the suffix IDs of prefixes of the current suffix.
169
+ # NOTE: If no suffix ID is found, suffix_id will be set to 0.
170
+ for p in range(suffix_id, len(table)):
171
+ suffix_id = to_suffix_id.get(c << 32 | table[p, TABLE_PIECE_ID], np.int32(0))
172
+ # If a next suffix ID is found or a sentinel row is reached, break the loop.
173
+ if suffix_id > 0 or table[p, TABLE_SCORE] == UNKNOWN_SCORE:
174
+ break
175
+
176
+ # Update the best path to the current position. If multiple paths have the same score,
177
+ # this chooses the longest prefix as the best path (table is sorted in the decreasing
178
+ # order of piece length).
179
+ for p in range(suffix_id, len(table)):
180
+ score = table[p, TABLE_SCORE]
181
+ if score > INVALID_SCORE:
182
+ piece_length = table[p, TABLE_PIECE_LENGTH]
183
+ s = scores[i + piece_length] - score
184
+ if s < scores[i]:
185
+ scores[i] = s
186
+ path[i, PATH_TOKEN_LENGTH] = piece_length
187
+ path[i, PATH_TOKEN_ID] = table[p, TABLE_TOKEN_ID]
188
+ path[i, PATH_NUM_TOKENS] = path[i + piece_length, PATH_NUM_TOKENS] + 1
189
+ if score == UNKNOWN_SCORE:
190
+ # Add number of bytes to represent `c` in UTF-8 (minus 1; 1 is already
191
+ # added above).
192
+ path[i, PATH_NUM_TOKENS] += (c >= 0x80) + (c >= 0x800) + (c >= 0x10000)
193
+
194
+ # If it reaches a sentinel row, break the loop.
195
+ if score == UNKNOWN_SCORE:
196
+ break
197
+
198
+ # Decode the best path from the beginning to get the token IDs.
199
+ pos = 0
200
+ token_ids = np.zeros(path[0, PATH_NUM_TOKENS], dtype=np.int32)
201
+ token_pos = 0
202
+ while pos < len(data):
203
+ if path[pos, PATH_TOKEN_ID] >= 0:
204
+ token_ids[token_pos] = path[pos, PATH_TOKEN_ID]
205
+ token_pos += 1
206
+ else:
207
+ # Fall back to byte tokens.
208
+ c = data[pos]
209
+ s = 1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000)
210
+ # Add byte tokens representing UTF-8 bytes.
211
+ for i in range(s):
212
+ b = c if s == 1 else (0xF00 >> s) & 0xFF if i == 0 else 0x80
213
+ token_ids[token_pos] = bytes[b | ((c >> (s - i - 1) * 6) & 0x3F)]
214
+ token_pos += 1
215
+
216
+ # Ensure that pos increases by at least 1.
217
+ assert path[pos, PATH_TOKEN_LENGTH] > 0, (pos, path[pos])
218
+ pos += path[pos, PATH_TOKEN_LENGTH]
219
+
220
+ return token_ids
221
+
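The byte-fallback branch near the end of `_encode` re-derives the UTF-8 bytes of a code point by hand: `(0xF00 >> s) & 0xFF` yields the leading-byte prefix (0xC0, 0xE0 or 0xF0 for 2-, 3- or 4-byte sequences) and each continuation byte starts from 0x80. The same arithmetic can be checked in isolation:

```python
c = ord("あ")                                          # U+3042, a 3-byte character in UTF-8
s = 1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000)    # number of UTF-8 bytes: 3

utf8 = []
for i in range(s):
    b = c if s == 1 else (0xF00 >> s) & 0xFF if i == 0 else 0x80
    utf8.append(b | ((c >> (s - i - 1) * 6) & 0x3F))

assert bytes(utf8) == "あ".encode("utf-8")             # b'\xe3\x81\x82'
```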
222
+ def encode(self, data: str) -> np.ndarray:
223
+ """Encodes a string into a sequence of token IDs."""
224
+ return np.asarray(
225
+ self._encode(
226
+ self._to_suffix_id,
227
+ self._table,
228
+ self._bytes,
229
+ # Convert a string into a numpy array of Unicode code points.
230
+ # NOTE: This skips UTF-32 BOM.
231
+ np.frombuffer(data.encode("utf-32"), dtype=np.int32)[1:],
232
+ )
233
+ )
234
+
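`encode` feeds `_encode` an array of Unicode code points rather than a Python string, so the numba-compiled kernel only ever sees integers; the `[1:]` slice drops the UTF-32 byte-order mark. For example:

```python
import numpy as np

cps = np.frombuffer("abc".encode("utf-32"), dtype=np.int32)
print(cps)        # [65279    97    98    99]  -> BOM (U+FEFF) followed by the code points
print(cps[1:])    # [97 98 99]
```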
235
+ def encode_as_tokens(self, data: str) -> list[str]:
236
+ """Encodes a string into a sequence of tokens."""
237
+ return [self._tokens[token_id] for token_id in self.encode(data)]
238
+
239
+
240
+ class PlamoTokenizer(PreTrainedTokenizer): # type: ignore
241
+ vocab_files_names = VOCAB_FILES_NAMES
242
+ model_input_names = ["input_ids", "attention_mask"]
243
+
244
+ _save_files = [
245
+ "special_tokens_map.json",
246
+ "tokenization_plamo.py",
247
+ "tokenizer.jsonl",
248
+ "tokenizer_config.json",
249
+ ]
250
+
251
+ def __init__(
252
+ self,
253
+ vocab_file: str,
254
+ unk_token: str = "<|plamo:unk|>",
255
+ bos_token: str = "<|plamo:bos|>",
256
+ eos_token: str = "<|plamo:eos|>",
257
+ pad_token: str = "<|plamo:pad|>",
258
+ cls_token: Optional[str] = None,
259
+ sep_token: Optional[str] = None,
260
+ mask_token: Optional[str] = None,
261
+ clean_up_tokenization_spaces: bool = False,
262
+ **kwargs: Any,
263
+ ) -> None:
264
+ """Tokenizer for PLaMo.
265
+
266
+ Args:
267
+ vocab_file (str): Vocabulary file path.
268
+ unk_token (str): Unknown token.
269
+ bos_token (str): Beginning of sentence token.
270
+ eos_token (str): End of sentence token.
271
+ pad_token (str): Padding token.
272
+ cls_token (str):
273
+ Classification token, to extract a summary of an input sequence leveraging self-attention along the
274
+ full depth of the model.
275
+ sep_token (str): Separation token, to separate context and query in an input sequence.
276
+ mask_token (str): Mask token, to use when training a model with masked-language modeling.
277
+ clean_up_tokenization_spaces (bool): Whether or not to clean up the tokenization spaces.
278
+ num_threads (int):
279
+ Number of threads. This value will be ignored if one of `PLAMO_TOKENIZER_NUM_THREADS` or
280
+ `RAYON_NUM_THREADS` is set as an environment variable.
281
+ """
282
+ if "add_bos_token" not in kwargs:
283
+ kwargs["add_bos_token"] = False
284
+ if "add_eos_token" not in kwargs:
285
+ kwargs["add_eos_token"] = False
286
+ self.data: list[Any] = [json.loads(line) for line in open(vocab_file, "r", encoding="utf-8")]
287
+ self.vocab: dict[str, int] = {v[0]: i for i, v in enumerate(self.data)}
288
+ self.aho_corasick = AhoCorasick()
289
+ self.aho_corasick.build(self.data)
290
+ self.vocab_file = vocab_file
291
+ self.add_bos_token = kwargs["add_bos_token"]
292
+ self.add_eos_token = kwargs["add_eos_token"]
293
+
294
+ super().__init__(
295
+ vocab_file=vocab_file,
296
+ unk_token=unk_token,
297
+ bos_token=bos_token,
298
+ eos_token=eos_token,
299
+ pad_token=pad_token,
300
+ cls_token=cls_token,
301
+ sep_token=sep_token,
302
+ mask_token=mask_token,
303
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
304
+ **kwargs,
305
+ )
306
+
307
+ # the functions below are copied from hf transformers LlamaTokenizer's implementation to fix the behaviour of the tokenizer
308
+ # https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/models/llama/tokenization_llama.py
309
+
310
+ def __getstate__(self) -> dict[str, Any]:
311
+ state = self.__dict__.copy()
312
+ state["aho_corasick"] = None
313
+ return state
314
+
315
+ def __setstate__(self, d: dict[str, Any]) -> None:
316
+ self.__dict__ = d
317
+ self.aho_corasick = AhoCorasick()
318
+ self.aho_corasick.build(self.data)
319
+
320
+ @property
321
+ def vocab_size(self) -> Any:
322
+ """Returns vocab size"""
323
+ return len(self.data)
324
+
325
+ def token_to_score(self, token: str) -> Optional[float]:
326
+ """Returns score of the token"""
327
+ token_id = self.vocab.get(token, None)
328
+ return None if token_id is None else self.data[token_id][1]
329
+
330
+ def get_vocab(self) -> dict[str, int]:
331
+ """Returns vocab as a dict"""
332
+ vocab = self.vocab.copy()
333
+ vocab.update(self.added_tokens_encoder)
334
+ return vocab
335
+
336
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
337
+ """Converts a sequence of tokens (string) into a single string."""
338
+ return b"".join(
339
+ [bytes([int(t[3:5], 16)]) if t.startswith("<0x") else t.encode("utf-8") for t in tokens]
340
+ ).decode("utf-8", errors="replace")
341
+
342
+ def _tokenize(self, text: str) -> Any:
343
+ """Returns a tokenized string."""
344
+ return self.aho_corasick.encode_as_tokens(text)
345
+
346
+ def _convert_token_to_id(self, token: str) -> Any:
347
+ """Converts a token (str) into an id using the vocab."""
348
+ return self.vocab.get(token, 0)
349
+
350
+ def _convert_id_to_token(self, index: int) -> Any:
351
+ """Converts an index (integer) into a token (str) using the vocab."""
352
+ return self.data[index][0]
353
+
354
+ def build_inputs_with_special_tokens(
355
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
356
+ ) -> List[int]:
357
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
358
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
359
+
360
+ output = bos_token_id + token_ids_0 + eos_token_id
361
+
362
+ if token_ids_1 is not None:
363
+ output = output + bos_token_id + token_ids_1 + eos_token_id
364
+
365
+ return output
366
+
367
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
368
+ """
369
+ Save the vocabulary and special tokens file to a directory.
370
+
371
+ Args:
372
+ save_directory (`str`):
373
+ The directory in which to save the vocabulary.
374
+
375
+ Returns:
376
+ `Tuple(str)`: Paths to the files saved.
377
+ """
378
+ if not os.path.isdir(save_directory):
379
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
380
+ return ("",)
381
+ out_vocab_file = os.path.join(
382
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
383
+ )
384
+
385
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
386
+ copyfile(self.vocab_file, out_vocab_file)
387
+ elif not os.path.isfile(self.vocab_file):
388
+ with open(out_vocab_file, "w") as f:
389
+ for token in self.data:
390
+ print(json.dumps(token, ensure_ascii=False), file=f)
391
+
392
+ return (out_vocab_file,)
tokenizer.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,57 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<|plamo:unk|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<|plamo:bos|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "<|plamo:eos|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "3": {
30
+ "content": "<|plamo:pad|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ }
37
+ },
38
+ "auto_map": {
39
+ "AutoTokenizer": [
40
+ "tokenization_plamo.PlamoTokenizer",
41
+ null
42
+ ]
43
+ },
44
+ "bos_token": "<|plamo:bos|>",
45
+ "chat_template": "{{bos_token}}{% for message in messages %}{% if message['role'] == 'user' %}{{ '\\n\\n### 指示:\\n' + message['content'] }}{% elif message['role'] == 'system' %}{{ '以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。' }}{% elif message['role'] == 'assistant' %}{{ '\\n\\n### 応答:\\n' + message['content'] + eos_token }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '\\n\\n### 応答:\\n' }}{% endif %}{% endfor %}",
46
+ "clean_up_tokenization_spaces": false,
47
+ "cls_token": null,
48
+ "eos_token": "<|plamo:eos|>",
49
+ "extra_special_tokens": {},
50
+ "local_file_only": true,
51
+ "mask_token": null,
52
+ "model_max_length": 1000000000000000019884624838656,
53
+ "pad_token": "<|plamo:pad|>",
54
+ "sep_token": null,
55
+ "tokenizer_class": "PlamoTokenizer",
56
+ "unk_token": "<|plamo:unk|>"
57
+ }
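For reference, the Jinja `chat_template` above produces an Alpaca-style Japanese prompt: the system turn is rendered as a fixed instruction line (its content is not interpolated), user turns become `### 指示:` blocks, assistant turns become `### 応答:` blocks terminated by the EOS token, and `add_generation_prompt=True` appends a trailing `### 応答:` header. A hand-expanded example of what `tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)` should yield:

```python
messages = [
    {"role": "system", "content": "any system text"},   # content is ignored by this template
    {"role": "user", "content": "日本の首都はどこですか?"},
]

expected = (
    "<|plamo:bos|>"
    "以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。"
    "\n\n### 指示:\n日本の首都はどこですか?"
    "\n\n### 応答:\n"
)
```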
training_args.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fe708c568c254cc8e776d9e31e935d284306b966caebe4e6027051d6ce43e470
3
+ size 8888