kaitos255 committed
Commit e29924d · 1 Parent(s): 2dfc504

initial commit

LICENSE.txt ADDED
@@ -0,0 +1,412 @@
1
+ Copyright 2025- Preferred Networks, Inc. All rights reserved.
2
+
3
+ Apache License
4
+ Version 2.0, January 2004
5
+ http://www.apache.org/licenses/
6
+
7
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8
+
9
+ 1. Definitions.
10
+
11
+ "License" shall mean the terms and conditions for use, reproduction,
12
+ and distribution as defined by Sections 1 through 9 of this document.
13
+
14
+ "Licensor" shall mean the copyright owner or entity authorized by
15
+ the copyright owner that is granting the License.
16
+
17
+ "Legal Entity" shall mean the union of the acting entity and all
18
+ other entities that control, are controlled by, or are under common
19
+ control with that entity. For the purposes of this definition,
20
+ "control" means (i) the power, direct or indirect, to cause the
21
+ direction or management of such entity, whether by contract or
22
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
23
+ outstanding shares, or (iii) beneficial ownership of such entity.
24
+
25
+ "You" (or "Your") shall mean an individual or Legal Entity
26
+ exercising permissions granted by this License.
27
+
28
+ "Source" form shall mean the preferred form for making modifications,
29
+ including but not limited to software source code, documentation
30
+ source, and configuration files.
31
+
32
+ "Object" form shall mean any form resulting from mechanical
33
+ transformation or translation of a Source form, including but
34
+ not limited to compiled object code, generated documentation,
35
+ and conversions to other media types.
36
+
37
+ "Work" shall mean the work of authorship, whether in Source or
38
+ Object form, made available under the License, as indicated by a
39
+ copyright notice that is included in or attached to the work
40
+ (an example is provided in the Appendix below).
41
+
42
+ "Derivative Works" shall mean any work, whether in Source or Object
43
+ form, that is based on (or derived from) the Work and for which the
44
+ editorial revisions, annotations, elaborations, or other modifications
45
+ represent, as a whole, an original work of authorship. For the purposes
46
+ of this License, Derivative Works shall not include works that remain
47
+ separable from, or merely link (or bind by name) to the interfaces of,
48
+ the Work and Derivative Works thereof.
49
+
50
+ "Contribution" shall mean any work of authorship, including
51
+ the original version of the Work and any modifications or additions
52
+ to that Work or Derivative Works thereof, that is intentionally
53
+ submitted to Licensor for inclusion in the Work by the copyright owner
54
+ or by an individual or Legal Entity authorized to submit on behalf of
55
+ the copyright owner. For the purposes of this definition, "submitted"
56
+ means any form of electronic, verbal, or written communication sent
57
+ to the Licensor or its representatives, including but not limited to
58
+ communication on electronic mailing lists, source code control systems,
59
+ and issue tracking systems that are managed by, or on behalf of, the
60
+ Licensor for the purpose of discussing and improving the Work, but
61
+ excluding communication that is conspicuously marked or otherwise
62
+ designated in writing by the copyright owner as "Not a Contribution."
63
+
64
+ "Contributor" shall mean Licensor and any individual or Legal Entity
65
+ on behalf of whom a Contribution has been received by Licensor and
66
+ subsequently incorporated within the Work.
67
+
68
+ 2. Grant of Copyright License. Subject to the terms and conditions of
69
+ this License, each Contributor hereby grants to You a perpetual,
70
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71
+ copyright license to reproduce, prepare Derivative Works of,
72
+ publicly display, publicly perform, sublicense, and distribute the
73
+ Work and such Derivative Works in Source or Object form.
74
+
75
+ 3. Grant of Patent License. Subject to the terms and conditions of
76
+ this License, each Contributor hereby grants to You a perpetual,
77
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78
+ (except as stated in this section) patent license to make, have made,
79
+ use, offer to sell, sell, import, and otherwise transfer the Work,
80
+ where such license applies only to those patent claims licensable
81
+ by such Contributor that are necessarily infringed by their
82
+ Contribution(s) alone or by combination of their Contribution(s)
83
+ with the Work to which such Contribution(s) was submitted. If You
84
+ institute patent litigation against any entity (including a
85
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
86
+ or a Contribution incorporated within the Work constitutes direct
87
+ or contributory patent infringement, then any patent licenses
88
+ granted to You under this License for that Work shall terminate
89
+ as of the date such litigation is filed.
90
+
91
+ 4. Redistribution. You may reproduce and distribute copies of the
92
+ Work or Derivative Works thereof in any medium, with or without
93
+ modifications, and in Source or Object form, provided that You
94
+ meet the following conditions:
95
+
96
+ (a) You must give any other recipients of the Work or
97
+ Derivative Works a copy of this License; and
98
+
99
+ (b) You must cause any modified files to carry prominent notices
100
+ stating that You changed the files; and
101
+
102
+ (c) You must retain, in the Source form of any Derivative Works
103
+ that You distribute, all copyright, patent, trademark, and
104
+ attribution notices from the Source form of the Work,
105
+ excluding those notices that do not pertain to any part of
106
+ the Derivative Works; and
107
+
108
+ (d) If the Work includes a "NOTICE" text file as part of its
109
+ distribution, then any Derivative Works that You distribute must
110
+ include a readable copy of the attribution notices contained
111
+ within such NOTICE file, excluding those notices that do not
112
+ pertain to any part of the Derivative Works, in at least one
113
+ of the following places: within a NOTICE text file distributed
114
+ as part of the Derivative Works; within the Source form or
115
+ documentation, if provided along with the Derivative Works; or,
116
+ within a display generated by the Derivative Works, if and
117
+ wherever such third-party notices normally appear. The contents
118
+ of the NOTICE file are for informational purposes only and
119
+ do not modify the License. You may add Your own attribution
120
+ notices within Derivative Works that You distribute, alongside
121
+ or as an addendum to the NOTICE text from the Work, provided
122
+ that such additional attribution notices cannot be construed
123
+ as modifying the License.
124
+
125
+ You may add Your own copyright statement to Your modifications and
126
+ may provide additional or different license terms and conditions
127
+ for use, reproduction, or distribution of Your modifications, or
128
+ for any such Derivative Works as a whole, provided Your use,
129
+ reproduction, and distribution of the Work otherwise complies with
130
+ the conditions stated in this License.
131
+
132
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
133
+ any Contribution intentionally submitted for inclusion in the Work
134
+ by You to the Licensor shall be under the terms and conditions of
135
+ this License, without any additional terms or conditions.
136
+ Notwithstanding the above, nothing herein shall supersede or modify
137
+ the terms of any separate license agreement you may have executed
138
+ with Licensor regarding such Contributions.
139
+
140
+ 6. Trademarks. This License does not grant permission to use the trade
141
+ names, trademarks, service marks, or product names of the Licensor,
142
+ except as required for reasonable and customary use in describing the
143
+ origin of the Work and reproducing the content of the NOTICE file.
144
+
145
+ 7. Disclaimer of Warranty. Unless required by applicable law or
146
+ agreed to in writing, Licensor provides the Work (and each
147
+ Contributor provides its Contributions) on an "AS IS" BASIS,
148
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149
+ implied, including, without limitation, any warranties or conditions
150
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151
+ PARTICULAR PURPOSE. You are solely responsible for determining the
152
+ appropriateness of using or redistributing the Work and assume any
153
+ risks associated with Your exercise of permissions under this License.
154
+
155
+ 8. Limitation of Liability. In no event and under no legal theory,
156
+ whether in tort (including negligence), contract, or otherwise,
157
+ unless required by applicable law (such as deliberate and grossly
158
+ negligent acts) or agreed to in writing, shall any Contributor be
159
+ liable to You for damages, including any direct, indirect, special,
160
+ incidental, or consequential damages of any character arising as a
161
+ result of this License or out of the use or inability to use the
162
+ Work (including but not limited to damages for loss of goodwill,
163
+ work stoppage, computer failure or malfunction, or any and all
164
+ other commercial damages or losses), even if such Contributor
165
+ has been advised of the possibility of such damages.
166
+
167
+ 9. Accepting Warranty or Additional Liability. While redistributing
168
+ the Work or Derivative Works thereof, You may choose to offer,
169
+ and charge a fee for, acceptance of support, warranty, indemnity,
170
+ or other liability obligations and/or rights consistent with this
171
+ License. However, in accepting such obligations, You may act only
172
+ on Your own behalf and on Your sole responsibility, not on behalf
173
+ of any other Contributor, and only if You agree to indemnify,
174
+ defend, and hold each Contributor harmless for any liability
175
+ incurred by, or claims asserted against, such Contributor by reason
176
+ of your accepting any such warranty or additional liability.
177
+
178
+ END OF TERMS AND CONDITIONS
179
+
180
+ APPENDIX: How to apply the Apache License to your work.
181
+
182
+ To apply the Apache License to your work, attach the following
183
+ boilerplate notice, with the fields enclosed by brackets "[]"
184
+ replaced with your own identifying information. (Don't include
185
+ the brackets!) The text should be enclosed in the appropriate
186
+ comment syntax for the file format. We also recommend that a
187
+ file or class name and description of purpose be included on the
188
+ same "printed page" as the copyright notice for easier
189
+ identification within third-party archives.
190
+
191
+ Copyright [yyyy] [name of copyright owner]
192
+
193
+ Licensed under the Apache License, Version 2.0 (the "License");
194
+ you may not use this file except in compliance with the License.
195
+ You may obtain a copy of the License at
196
+
197
+ http://www.apache.org/licenses/LICENSE-2.0
198
+
199
+ Unless required by applicable law or agreed to in writing, software
200
+ distributed under the License is distributed on an "AS IS" BASIS,
201
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202
+ See the License for the specific language governing permissions and
203
+ limitations under the License.
204
+
205
+ ---
206
+
207
+ This software contains modified code from the Hugging Face Transformers library, which is released under the Apache License 2.0.
208
+
209
+ ---
210
+ Copyright 2018- The Hugging Face team. All rights reserved.
211
+
212
+ Apache License
213
+ Version 2.0, January 2004
214
+ http://www.apache.org/licenses/
215
+
216
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
217
+
218
+ 1. Definitions.
219
+
220
+ "License" shall mean the terms and conditions for use, reproduction,
221
+ and distribution as defined by Sections 1 through 9 of this document.
222
+
223
+ "Licensor" shall mean the copyright owner or entity authorized by
224
+ the copyright owner that is granting the License.
225
+
226
+ "Legal Entity" shall mean the union of the acting entity and all
227
+ other entities that control, are controlled by, or are under common
228
+ control with that entity. For the purposes of this definition,
229
+ "control" means (i) the power, direct or indirect, to cause the
230
+ direction or management of such entity, whether by contract or
231
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
232
+ outstanding shares, or (iii) beneficial ownership of such entity.
233
+
234
+ "You" (or "Your") shall mean an individual or Legal Entity
235
+ exercising permissions granted by this License.
236
+
237
+ "Source" form shall mean the preferred form for making modifications,
238
+ including but not limited to software source code, documentation
239
+ source, and configuration files.
240
+
241
+ "Object" form shall mean any form resulting from mechanical
242
+ transformation or translation of a Source form, including but
243
+ not limited to compiled object code, generated documentation,
244
+ and conversions to other media types.
245
+
246
+ "Work" shall mean the work of authorship, whether in Source or
247
+ Object form, made available under the License, as indicated by a
248
+ copyright notice that is included in or attached to the work
249
+ (an example is provided in the Appendix below).
250
+
251
+ "Derivative Works" shall mean any work, whether in Source or Object
252
+ form, that is based on (or derived from) the Work and for which the
253
+ editorial revisions, annotations, elaborations, or other modifications
254
+ represent, as a whole, an original work of authorship. For the purposes
255
+ of this License, Derivative Works shall not include works that remain
256
+ separable from, or merely link (or bind by name) to the interfaces of,
257
+ the Work and Derivative Works thereof.
258
+
259
+ "Contribution" shall mean any work of authorship, including
260
+ the original version of the Work and any modifications or additions
261
+ to that Work or Derivative Works thereof, that is intentionally
262
+ submitted to Licensor for inclusion in the Work by the copyright owner
263
+ or by an individual or Legal Entity authorized to submit on behalf of
264
+ the copyright owner. For the purposes of this definition, "submitted"
265
+ means any form of electronic, verbal, or written communication sent
266
+ to the Licensor or its representatives, including but not limited to
267
+ communication on electronic mailing lists, source code control systems,
268
+ and issue tracking systems that are managed by, or on behalf of, the
269
+ Licensor for the purpose of discussing and improving the Work, but
270
+ excluding communication that is conspicuously marked or otherwise
271
+ designated in writing by the copyright owner as "Not a Contribution."
272
+
273
+ "Contributor" shall mean Licensor and any individual or Legal Entity
274
+ on behalf of whom a Contribution has been received by Licensor and
275
+ subsequently incorporated within the Work.
276
+
277
+ 2. Grant of Copyright License. Subject to the terms and conditions of
278
+ this License, each Contributor hereby grants to You a perpetual,
279
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
280
+ copyright license to reproduce, prepare Derivative Works of,
281
+ publicly display, publicly perform, sublicense, and distribute the
282
+ Work and such Derivative Works in Source or Object form.
283
+
284
+ 3. Grant of Patent License. Subject to the terms and conditions of
285
+ this License, each Contributor hereby grants to You a perpetual,
286
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
287
+ (except as stated in this section) patent license to make, have made,
288
+ use, offer to sell, sell, import, and otherwise transfer the Work,
289
+ where such license applies only to those patent claims licensable
290
+ by such Contributor that are necessarily infringed by their
291
+ Contribution(s) alone or by combination of their Contribution(s)
292
+ with the Work to which such Contribution(s) was submitted. If You
293
+ institute patent litigation against any entity (including a
294
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
295
+ or a Contribution incorporated within the Work constitutes direct
296
+ or contributory patent infringement, then any patent licenses
297
+ granted to You under this License for that Work shall terminate
298
+ as of the date such litigation is filed.
299
+
300
+ 4. Redistribution. You may reproduce and distribute copies of the
301
+ Work or Derivative Works thereof in any medium, with or without
302
+ modifications, and in Source or Object form, provided that You
303
+ meet the following conditions:
304
+
305
+ (a) You must give any other recipients of the Work or
306
+ Derivative Works a copy of this License; and
307
+
308
+ (b) You must cause any modified files to carry prominent notices
309
+ stating that You changed the files; and
310
+
311
+ (c) You must retain, in the Source form of any Derivative Works
312
+ that You distribute, all copyright, patent, trademark, and
313
+ attribution notices from the Source form of the Work,
314
+ excluding those notices that do not pertain to any part of
315
+ the Derivative Works; and
316
+
317
+ (d) If the Work includes a "NOTICE" text file as part of its
318
+ distribution, then any Derivative Works that You distribute must
319
+ include a readable copy of the attribution notices contained
320
+ within such NOTICE file, excluding those notices that do not
321
+ pertain to any part of the Derivative Works, in at least one
322
+ of the following places: within a NOTICE text file distributed
323
+ as part of the Derivative Works; within the Source form or
324
+ documentation, if provided along with the Derivative Works; or,
325
+ within a display generated by the Derivative Works, if and
326
+ wherever such third-party notices normally appear. The contents
327
+ of the NOTICE file are for informational purposes only and
328
+ do not modify the License. You may add Your own attribution
329
+ notices within Derivative Works that You distribute, alongside
330
+ or as an addendum to the NOTICE text from the Work, provided
331
+ that such additional attribution notices cannot be construed
332
+ as modifying the License.
333
+
334
+ You may add Your own copyright statement to Your modifications and
335
+ may provide additional or different license terms and conditions
336
+ for use, reproduction, or distribution of Your modifications, or
337
+ for any such Derivative Works as a whole, provided Your use,
338
+ reproduction, and distribution of the Work otherwise complies with
339
+ the conditions stated in this License.
340
+
341
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
342
+ any Contribution intentionally submitted for inclusion in the Work
343
+ by You to the Licensor shall be under the terms and conditions of
344
+ this License, without any additional terms or conditions.
345
+ Notwithstanding the above, nothing herein shall supersede or modify
346
+ the terms of any separate license agreement you may have executed
347
+ with Licensor regarding such Contributions.
348
+
349
+ 6. Trademarks. This License does not grant permission to use the trade
350
+ names, trademarks, service marks, or product names of the Licensor,
351
+ except as required for reasonable and customary use in describing the
352
+ origin of the Work and reproducing the content of the NOTICE file.
353
+
354
+ 7. Disclaimer of Warranty. Unless required by applicable law or
355
+ agreed to in writing, Licensor provides the Work (and each
356
+ Contributor provides its Contributions) on an "AS IS" BASIS,
357
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
358
+ implied, including, without limitation, any warranties or conditions
359
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
360
+ PARTICULAR PURPOSE. You are solely responsible for determining the
361
+ appropriateness of using or redistributing the Work and assume any
362
+ risks associated with Your exercise of permissions under this License.
363
+
364
+ 8. Limitation of Liability. In no event and under no legal theory,
365
+ whether in tort (including negligence), contract, or otherwise,
366
+ unless required by applicable law (such as deliberate and grossly
367
+ negligent acts) or agreed to in writing, shall any Contributor be
368
+ liable to You for damages, including any direct, indirect, special,
369
+ incidental, or consequential damages of any character arising as a
370
+ result of this License or out of the use or inability to use the
371
+ Work (including but not limited to damages for loss of goodwill,
372
+ work stoppage, computer failure or malfunction, or any and all
373
+ other commercial damages or losses), even if such Contributor
374
+ has been advised of the possibility of such damages.
375
+
376
+ 9. Accepting Warranty or Additional Liability. While redistributing
377
+ the Work or Derivative Works thereof, You may choose to offer,
378
+ and charge a fee for, acceptance of support, warranty, indemnity,
379
+ or other liability obligations and/or rights consistent with this
380
+ License. However, in accepting such obligations, You may act only
381
+ on Your own behalf and on Your sole responsibility, not on behalf
382
+ of any other Contributor, and only if You agree to indemnify,
383
+ defend, and hold each Contributor harmless for any liability
384
+ incurred by, or claims asserted against, such Contributor by reason
385
+ of your accepting any such warranty or additional liability.
386
+
387
+ END OF TERMS AND CONDITIONS
388
+
389
+ APPENDIX: How to apply the Apache License to your work.
390
+
391
+ To apply the Apache License to your work, attach the following
392
+ boilerplate notice, with the fields enclosed by brackets "[]"
393
+ replaced with your own identifying information. (Don't include
394
+ the brackets!) The text should be enclosed in the appropriate
395
+ comment syntax for the file format. We also recommend that a
396
+ file or class name and description of purpose be included on the
397
+ same "printed page" as the copyright notice for easier
398
+ identification within third-party archives.
399
+
400
+ Copyright [yyyy] [name of copyright owner]
401
+
402
+ Licensed under the Apache License, Version 2.0 (the "License");
403
+ you may not use this file except in compliance with the License.
404
+ You may obtain a copy of the License at
405
+
406
+ http://www.apache.org/licenses/LICENSE-2.0
407
+
408
+ Unless required by applicable law or agreed to in writing, software
409
+ distributed under the License is distributed on an "AS IS" BASIS,
410
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
411
+ See the License for the specific language governing permissions and
412
+ limitations under the License.
config.json ADDED
@@ -0,0 +1,43 @@
1
+ {
2
+ "architectures": [
3
+ "PlamoBiModel"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "auto_map": {
7
+ "AutoConfig": "modeling_plamo.PlamoConfig",
8
+ "AutoModel": "modeling_plamo.PlamoBiModel"
9
+ },
10
+ "bos_token_id": 1,
11
+ "capacity_factor": 1.0,
12
+ "eos_token_id": 1,
13
+ "eval_attention_n_bit": null,
14
+ "eval_mlp_n_bit": null,
15
+ "eval_offload_moe": false,
16
+ "expert_dropout": 0.0,
17
+ "fp8_accum_dtype": "bfloat16",
18
+ "group_size": 1024,
19
+ "hidden_size": 2048,
20
+ "hidden_size_per_head": 128,
21
+ "initializer_range": 0.02,
22
+ "intermediate_size": 8192,
23
+ "k_expert": null,
24
+ "linear_type": "fp8",
25
+ "max_length": 4096,
26
+ "max_position_embeddings": 4096,
27
+ "model_type": "plamo",
28
+ "n_expert": null,
29
+ "num_attention_heads": 16,
30
+ "num_hidden_layers": 16,
31
+ "num_key_value_heads": 1,
32
+ "pad_token_id": 3,
33
+ "rms_norm_eps": 1e-06,
34
+ "shared_intermediate_size": null,
35
+ "sparse_intermediate_size": null,
36
+ "sparse_step": null,
37
+ "tie_word_embeddings": false,
38
+ "tokenizer_class": "PlamoTokenizer",
39
+ "torch_dtype": "bfloat16",
40
+ "transformers_version": "4.47.0",
41
+ "use_cache": false,
42
+ "vocab_size": 50112
43
+ }
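
Note: config.json wires in the custom code through "auto_map" ("AutoConfig" → modeling_plamo.PlamoConfig, "AutoModel" → modeling_plamo.PlamoBiModel), so the checkpoint is meant to be loaded with remote code enabled. A minimal loading sketch, assuming a placeholder repository id (the actual repo id is not shown in this commit):

import torch
from transformers import AutoConfig, AutoModel

repo_id = "org/plamo-bi-checkpoint"  # hypothetical placeholder; substitute the real repository id

# auto_map in config.json resolves these to PlamoConfig / PlamoBiModel in modeling_plamo.py
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(
    repo_id,
    config=config,
    torch_dtype=torch.bfloat16,  # matches "torch_dtype": "bfloat16" above
    trust_remote_code=True,
)
model.eval()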
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e9e26cd19d9a90dc79d1c0e8d4755d881fe984a8582d7dcd776b9adb2dcf9f1
3
+ size 2101303432
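
Note: model.safetensors is committed as a Git LFS pointer, so the sha256 oid and byte size above can be used to check a downloaded copy. A small verification sketch, assuming the weights have already been fetched to a local path (the path is an assumption; the expected digest and size come from the pointer file):

import hashlib
from pathlib import Path

path = Path("model.safetensors")  # hypothetical local path after fetching the LFS object
expected_oid = "0e9e26cd19d9a90dc79d1c0e8d4755d881fe984a8582d7dcd776b9adb2dcf9f1"
expected_size = 2101303432

h = hashlib.sha256()
size = 0
with path.open("rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # stream in 1 MiB chunks
        h.update(chunk)
        size += len(chunk)

assert size == expected_size, "size does not match the LFS pointer"
assert h.hexdigest() == expected_oid, "sha256 does not match the LFS pointer"
print("model.safetensors matches the LFS pointer")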
modeling_plamo.py ADDED
@@ -0,0 +1,1089 @@
1
+ import enum
2
+ from typing import Any, List, NamedTuple, Optional, Tuple, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+ from transformers import AutoTokenizer, PretrainedConfig, PreTrainedModel
8
+ from transformers.modeling_attn_mask_utils import (
9
+ _prepare_4d_causal_attention_mask,
10
+ _prepare_4d_causal_attention_mask_for_sdpa,
11
+ )
12
+ from transformers.modeling_outputs import BaseModelOutputWithPast
13
+ from transformers.tokenization_utils_base import BatchEncoding
14
+
15
+
16
+ def _swiglu(h: torch.Tensor) -> torch.Tensor:
17
+ h0, h1 = h.chunk(2, dim=-1)
18
+ return torch.nn.functional.silu(h0) * h1
19
+
20
+
21
+ class PlamoAttentionCache:
22
+ def __init__(self, key: torch.Tensor, value: torch.Tensor) -> None:
23
+ B, nh, L, c = key.shape
24
+ assert len(value.shape) == 4
25
+ assert value.shape[0] == B
26
+ assert value.shape[2] == L
27
+ self.key = key
28
+ self.value = value
29
+
30
+ def _validate(self, cache: torch.Tensor, new_tensor: torch.Tensor) -> None:
31
+ assert len(cache.shape) == 4
32
+ assert len(new_tensor.shape) == 4
33
+ assert cache.shape[0] == new_tensor.shape[0]
34
+ assert cache.shape[1] == new_tensor.shape[1]
35
+ assert cache.shape[3] == new_tensor.shape[3]
36
+
37
+ def append_cache(self, k: torch.Tensor, v: torch.Tensor) -> None:
38
+ self._validate(self.key, k)
39
+ self._validate(self.value, v)
40
+ assert k.shape[2] == v.shape[2]
41
+ self.key = torch.cat([self.key, k], dim=2)
42
+ self.value = torch.cat([self.value, v], dim=2)
43
+
44
+ def sequence_length(self) -> int:
45
+ return self.key.shape[2]
46
+
47
+
48
+ PlamoLayerCache = PlamoAttentionCache
49
+
50
+ PlamoCache = list[PlamoLayerCache]
51
+
52
+
53
+ class DecoderInput(NamedTuple):
54
+ hidden_states: torch.Tensor
55
+ position_ids: torch.Tensor
56
+ attention_mask: Optional[torch.Tensor] = None
57
+ past_key_values: Optional[PlamoCache] = None
58
+ output_hidden_states: Optional[bool] = False
59
+ output_attentions: Optional[bool] = False
60
+ use_cache: Optional[bool] = False
61
+ gradient_checkpointing: bool = False
62
+ input_ids: Optional[torch.Tensor] = None
63
+
64
+
65
+ class DecoderOutput(NamedTuple):
66
+ hidden_states: torch.Tensor
67
+ all_hidden_states: Optional[Tuple[torch.Tensor, ...]]
68
+ all_self_attns: Optional[Tuple[torch.Tensor, ...]]
69
+ next_decoder_cache: Optional[PlamoCache]
70
+
71
+
72
+ class LinearType(str, enum.Enum):
73
+ Normal = "normal"
74
+ Fp8 = "fp8"
75
+ Fp8Retain = "fp8-retain"
76
+
77
+
78
+ class PlamoConfig(PretrainedConfig): # type: ignore
79
+ model_type: str = "plamo"
80
+
81
+ def __init__(
82
+ self,
83
+ vocab_size: int = 32000,
84
+ hidden_size: int = 4096,
85
+ intermediate_size: int = 13312,
86
+ num_hidden_layers: int = 32,
87
+ num_attention_heads: int = 32,
88
+ num_key_value_heads: int = 4,
89
+ hidden_size_per_head: int = 128,
90
+ max_position_embeddings: int = 2048,
91
+ initializer_range: float = 0.02,
92
+ rms_norm_eps: float = 1e-6,
93
+ use_cache: bool = True,
94
+ tokenizer_class: str = "PlamoTokenizer",
95
+ pad_token_id: Optional[int] = None,
96
+ bos_token_id: int = 1,
97
+ eos_token_id: int = 2,
98
+ tie_word_embeddings: bool = False,
99
+ n_expert: Optional[int] = None,
100
+ k_expert: Optional[int] = None,
101
+ expert_dropout: float = 0.0,
102
+ capacity_factor: float = 1.0,
103
+ group_size: int = 1024,
104
+ sparse_step: Optional[int] = None,
105
+ sparse_intermediate_size: Optional[int] = None,
106
+ shared_intermediate_size: Optional[int] = None,
107
+ linear_type: LinearType = LinearType.Normal,
108
+ fp8_accum_dtype: Optional[str] = None,
109
+ eval_attention_n_bit: Optional[int] = None,
110
+ eval_mlp_n_bit: Optional[int] = None,
111
+ eval_offload_moe: bool = False,
112
+ attention_dropout: float = 0.0,
113
+ **kwargs: Any,
114
+ ) -> None:
115
+ self.vocab_size = vocab_size
116
+ self.max_position_embeddings = max_position_embeddings
117
+ self.hidden_size = hidden_size
118
+ self.intermediate_size = intermediate_size
119
+ self.num_hidden_layers = num_hidden_layers
120
+ self.num_attention_heads = num_attention_heads
121
+ self.hidden_size_per_head = hidden_size_per_head
122
+
123
+ self.initializer_range = initializer_range
124
+ self.rms_norm_eps = rms_norm_eps
125
+ self.use_cache = use_cache
126
+
127
+ self.num_key_value_heads = num_key_value_heads
128
+
129
+ self.n_expert = n_expert
130
+ self.k_expert = k_expert
131
+ self.sparse_intermediate_size = sparse_intermediate_size
132
+ self.shared_intermediate_size = shared_intermediate_size
133
+ self.expert_dropout = expert_dropout
134
+ self.capacity_factor = capacity_factor
135
+ self.group_size = group_size
136
+ self.sparse_step = sparse_step
137
+
138
+ self.linear_type = linear_type
139
+ self.fp8_accum_dtype = fp8_accum_dtype
140
+
141
+ self.eval_attention_n_bit = eval_attention_n_bit
142
+ self.eval_mlp_n_bit = eval_mlp_n_bit
143
+ self.eval_offload_moe = eval_offload_moe
144
+
145
+ self.attention_dropout = attention_dropout
146
+
147
+ super().__init__(
148
+ tokenizer_class=tokenizer_class,
149
+ pad_token_id=pad_token_id,
150
+ bos_token_id=bos_token_id,
151
+ eos_token_id=eos_token_id,
152
+ tie_word_embeddings=tie_word_embeddings,
153
+ **kwargs,
154
+ )
155
+
156
+
157
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
158
+ def _make_causal_mask(
159
+ input_ids_shape: Tuple[int, int],
160
+ dtype: torch.dtype,
161
+ device: torch.device,
162
+ past_key_values_length: int = 0,
163
+ ) -> torch.Tensor:
164
+ """
165
+ Make causal mask used for bi-directional self-attention.
166
+ """
167
+ bsz, tgt_len = input_ids_shape
168
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
169
+ mask_cond = torch.arange(mask.size(-1), device=device)
170
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
171
+ mask = mask.to(dtype)
172
+
173
+ if past_key_values_length > 0:
174
+ mask = torch.cat(
175
+ [
176
+ torch.zeros(
177
+ tgt_len, past_key_values_length, dtype=dtype, device=device
178
+ ),
179
+ mask,
180
+ ],
181
+ dim=-1,
182
+ )
183
+ return mask[None, None, :, :].expand(
184
+ bsz, 1, tgt_len, tgt_len + past_key_values_length
185
+ )
186
+
187
+
188
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
189
+ def _expand_mask(
190
+ mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
191
+ ) -> torch.Tensor:
192
+ """
193
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
194
+ """
195
+ bsz, src_len = mask.size()
196
+ tgt_len = tgt_len if tgt_len is not None else src_len
197
+
198
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
199
+
200
+ inverted_mask = 1.0 - expanded_mask
201
+
202
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) # type: ignore
203
+
204
+
205
+ class RotaryEmbedding(torch.nn.Module):
206
+ def __init__(
207
+ self,
208
+ dim: int,
209
+ max_position_embeddings: int = 2048,
210
+ base: int = 10000,
211
+ device: Optional[torch.device] = None,
212
+ ) -> None:
213
+ super().__init__()
214
+
215
+ self.dim = dim
216
+ self.max_position_embeddings = max_position_embeddings
217
+ self.base = base
218
+ inv_freq = 1.0 / (
219
+ self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
220
+ )
221
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
222
+
223
+ # Build here to make `torch.jit.trace` work.
224
+ self._set_cos_sin_cache(
225
+ seq_len=max_position_embeddings,
226
+ device=self.inv_freq.device,
227
+ dtype=torch.get_default_dtype(),
228
+ )
229
+
230
+ def _set_cos_sin_cache(self, seq_len: int, device: Any, dtype: Any) -> None:
231
+ self.max_seq_len_cached = seq_len
232
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) # type: ignore
233
+
234
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
235
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
236
+ emb = torch.cat((freqs, freqs), dim=-1)
237
+ self.register_buffer(
238
+ "cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False
239
+ )
240
+ self.register_buffer(
241
+ "sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
242
+ )
243
+
244
+ def forward(
245
+ self, x: torch.Tensor, seq_len: int
246
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
247
+ # x: [bs, num_attention_heads, seq_len, head_size]
248
+ if seq_len > self.max_seq_len_cached:
249
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
250
+
251
+ return (
252
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), # type: ignore
253
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), # type: ignore
254
+ )
255
+
256
+
257
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
258
+ """Rotates half the hidden dims of the input."""
259
+ x1 = x[..., : x.shape[-1] // 2]
260
+ x2 = x[..., x.shape[-1] // 2 :]
261
+ return torch.cat((-x2, x1), dim=-1)
262
+
263
+
264
+ def _rotary_pos_emb(
265
+ x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor
266
+ ) -> torch.Tensor:
267
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
268
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
269
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
270
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
271
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
272
+ x_embed = (x * cos) + (_rotate_half(x) * sin)
273
+ return x_embed
274
+
275
+
276
+ def _rms_norm(
277
+ hidden_states: torch.Tensor, weight: Optional[torch.Tensor], eps: float
278
+ ) -> torch.Tensor:
279
+ input_dtype = hidden_states.dtype
280
+ hidden_states = hidden_states.to(torch.float32)
281
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
282
+ hidden_states = hidden_states * torch.rsqrt(variance + eps)
283
+ hidden_states = hidden_states.to(input_dtype)
284
+ if weight is not None:
285
+ hidden_states = weight * hidden_states
286
+ return hidden_states
287
+
288
+
289
+ class RMSNorm(nn.Module):
290
+ def __init__(
291
+ self,
292
+ hidden_size: int,
293
+ eps: float = 1e-6,
294
+ device: Optional[Union[torch.device, str]] = None,
295
+ ) -> None:
296
+ super().__init__()
297
+ self.weight = nn.Parameter(torch.ones(hidden_size, device=device))
298
+ self.variance_epsilon = eps
299
+
300
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
301
+ return _rms_norm(hidden_states, self.weight, self.variance_epsilon)
302
+
303
+
304
+ class Attention(torch.nn.Module):
305
+ def __init__(self, config: PlamoConfig) -> None:
306
+ super().__init__()
307
+ self.config = config
308
+ self.hidden_size = config.hidden_size
309
+ head_dim = config.hidden_size_per_head
310
+ self.max_position_embeddings = config.max_position_embeddings
311
+
312
+ self.q_num_heads = config.num_attention_heads
313
+ self.qk_dim = self.v_dim = head_dim
314
+ self.k_num_heads = self.v_num_heads = config.num_key_value_heads
315
+ assert self.q_num_heads % self.k_num_heads == 0
316
+ self.n_group = self.q_num_heads // self.k_num_heads
317
+
318
+ self.q_proj_dim = self.q_num_heads * self.qk_dim
319
+ self.k_proj_dim = self.k_num_heads * self.qk_dim
320
+ self.v_proj_dim = self.k_num_heads * self.v_dim
321
+ self.qkv_proj = nn.Linear(
322
+ self.hidden_size,
323
+ self.q_proj_dim + self.k_proj_dim + self.v_proj_dim,
324
+ bias=False,
325
+ )
326
+ self.o_proj = nn.Linear(
327
+ self.q_num_heads * self.v_dim, self.hidden_size, bias=False
328
+ )
329
+ self.rotary_emb = RotaryEmbedding(
330
+ self.qk_dim, max_position_embeddings=self.max_position_embeddings
331
+ )
332
+
333
+ self.q_weight = torch.nn.Parameter(torch.ones((self.q_num_heads, self.qk_dim)))
334
+ self.k_weight = torch.nn.Parameter(torch.ones((self.k_num_heads, self.qk_dim)))
335
+ self.is_causal = True
336
+ self.attention_dropout = config.attention_dropout
337
+
338
+ def forward(
339
+ self,
340
+ hidden_states: torch.Tensor,
341
+ attention_mask: Optional[torch.Tensor] = None,
342
+ position_ids: Optional[torch.Tensor] = None,
343
+ past_key_value: Optional[PlamoLayerCache] = None,
344
+ output_attentions: bool = False,
345
+ use_cache: bool = False,
346
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[PlamoLayerCache]]:
347
+ bsz, q_len, _ = hidden_states.size()
348
+
349
+ qkv = self.qkv_proj(hidden_states)
350
+ query_states, key_states, value_states = torch.split(
351
+ qkv, [self.q_proj_dim, self.k_proj_dim, self.v_proj_dim], dim=-1
352
+ )
353
+ query_states = query_states.view(
354
+ bsz, q_len, self.q_num_heads, self.qk_dim
355
+ ).transpose(1, 2)
356
+ key_states = key_states.view(
357
+ bsz, q_len, self.k_num_heads, self.qk_dim
358
+ ).transpose(1, 2)
359
+ value_states = value_states.view(
360
+ bsz, q_len, self.v_num_heads, self.v_dim
361
+ ).transpose(1, 2)
362
+
363
+ attn_dtype = query_states.dtype
364
+
365
+ query_states = (
366
+ _rms_norm(query_states, None, 1e-6) * self.q_weight[None, :, None]
367
+ )
368
+ key_states = _rms_norm(key_states, None, 1e-6) * self.k_weight[None, :, None]
369
+
370
+ if use_cache and past_key_value is None:
371
+ bsz, nhead_k, _, c_k = key_states.shape
372
+ _, nhead_v, _, c_v = value_states.shape
373
+ past_key_value = PlamoAttentionCache(
374
+ torch.zeros(
375
+ (bsz, nhead_k, 0, c_k),
376
+ dtype=key_states.dtype,
377
+ device=key_states.device,
378
+ ),
379
+ torch.zeros(
380
+ (bsz, nhead_v, 0, c_v),
381
+ dtype=value_states.dtype,
382
+ device=value_states.device,
383
+ ),
384
+ )
385
+
386
+ kv_seq_len = key_states.shape[-2]
387
+ if past_key_value is not None:
388
+ kv_seq_len += past_key_value.sequence_length()
389
+
390
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
391
+ assert position_ids is not None
392
+ query_states = _rotary_pos_emb(query_states, cos, sin, position_ids)
393
+ key_states = _rotary_pos_emb(key_states, cos, sin, position_ids)
394
+ # [bsz, nh, t, hd]
395
+
396
+ if past_key_value is not None:
397
+ # reuse k, v, self_attention
398
+ past_key_value.append_cache(key_states, value_states)
399
+ key_states = past_key_value.key
400
+ value_states = past_key_value.value
401
+
402
+ def _expand_kv(t: torch.Tensor, repeat: int, target: int) -> torch.Tensor:
403
+ t = torch.repeat_interleave(t, repeat, dim=1)
404
+ return t[:, :target]
405
+
406
+ # expand shared kv
407
+ assert self.k_num_heads == self.v_num_heads
408
+ key_states = _expand_kv(key_states, self.n_group, self.q_num_heads)
409
+ value_states = _expand_kv(value_states, self.n_group, self.q_num_heads)
410
+
411
+ query_states = query_states.to(attn_dtype)
412
+ key_states = key_states.to(attn_dtype)
413
+ value_states = value_states.to(attn_dtype)
414
+
415
+ if attention_mask is not None and attention_mask.dtype != torch.bool:
416
+ attention_mask = attention_mask.to(attn_dtype)
417
+
418
+ attn_output = F.scaled_dot_product_attention(
419
+ query_states,
420
+ key_states,
421
+ value_states,
422
+ attn_mask=attention_mask,
423
+ is_causal=self.is_causal,
424
+ dropout_p=self.attention_dropout if self.training else 0.0,
425
+ )
426
+ attn_output = attn_output.transpose(1, 2)
427
+
428
+ attn_output = attn_output.reshape(bsz, q_len, self.q_num_heads * self.v_dim)
429
+ attn_output = self.o_proj(attn_output)
430
+
431
+ if not output_attentions:
432
+ attn_weights = None
433
+
434
+ return attn_output, attn_weights, past_key_value
435
+
436
+
437
+ class DenseMLP(nn.Module):
438
+ def __init__(self, config: PlamoConfig) -> None:
439
+ super().__init__()
440
+ self.config = config
441
+ self.hidden_size = config.hidden_size
442
+ self.intermediate_size = config.intermediate_size
443
+ self.gate_up_proj = torch.nn.Linear(
444
+ self.hidden_size, self.intermediate_size * 2, bias=False
445
+ )
446
+ self.down_proj = torch.nn.Linear(
447
+ self.intermediate_size, self.hidden_size, bias=False
448
+ )
449
+
450
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
451
+ h = self.gate_up_proj(x)
452
+ h = _swiglu(h)
453
+ return self.down_proj(h) # type: ignore
454
+
455
+
456
+ def MLP(config: PlamoConfig, is_sparse: bool) -> torch.nn.Module:
457
+ return DenseMLP(config)
458
+
459
+
460
+ class PlamoDecoderLayer(torch.nn.Module):
461
+ def __init__(self, config: PlamoConfig, is_sparse: bool) -> None:
462
+ super().__init__()
463
+ self.config = config
464
+ self.hidden_size = config.hidden_size
465
+ self.self_attn = Attention(config)
466
+ self.mlp = MLP(config, is_sparse=is_sparse)
467
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
468
+ self.norm2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
469
+
470
+ def forward(
471
+ self,
472
+ hidden_states: torch.Tensor,
473
+ attention_mask: Optional[torch.Tensor] = None,
474
+ position_ids: Optional[torch.LongTensor] = None,
475
+ past_key_value: Optional[PlamoLayerCache] = None,
476
+ output_attentions: Optional[bool] = False,
477
+ use_cache: Optional[bool] = False,
478
+ ) -> Tuple[Any, ...]:
479
+ # from LlamaDecoder
480
+ residual = hidden_states
481
+ hidden_states = self.norm(hidden_states)
482
+
483
+ # Self Attention
484
+ hidden_states_sa, self_attn_weights, present_key_value = self.self_attn(
485
+ hidden_states=hidden_states,
486
+ attention_mask=attention_mask,
487
+ position_ids=position_ids,
488
+ past_key_value=past_key_value,
489
+ output_attentions=output_attentions,
490
+ use_cache=use_cache,
491
+ )
492
+
493
+ hidden_states = residual + hidden_states_sa
494
+
495
+ residual = hidden_states
496
+ hidden_states = self.norm2(hidden_states)
497
+
498
+ # Fully Connected
499
+ hidden_states_mlp = self.mlp(hidden_states)
500
+
501
+ # Residual
502
+ hidden_states = residual + hidden_states_mlp
503
+
504
+ outputs: Any = (hidden_states,)
505
+
506
+ if output_attentions:
507
+ outputs += (self_attn_weights,)
508
+
509
+ if use_cache:
510
+ outputs += (present_key_value,)
511
+
512
+ return outputs # type: ignore
513
+
514
+
515
+ def is_sparse(config: PlamoConfig, i: int) -> bool:
516
+ if config.sparse_step is None:
517
+ return False
518
+ if config.sparse_step == 1:
519
+ return True
520
+ return (i % config.sparse_step) == 1
521
+
522
+
523
+ class PlamoDecoder(torch.nn.Module):
524
+ def __init__(self, config: PlamoConfig) -> None:
525
+ super().__init__()
526
+
527
+ self.layers = torch.nn.ModuleList(
528
+ [
529
+ PlamoDecoderLayer(config, is_sparse=is_sparse(config, i))
530
+ for i in range(config.num_hidden_layers)
531
+ ]
532
+ )
533
+
534
+ def forward(self, x: DecoderInput) -> DecoderOutput:
535
+ all_hidden_states: Optional[Tuple[torch.Tensor, ...]] = (
536
+ () if x.output_hidden_states else None
537
+ )
538
+ all_self_attns: Optional[Tuple[torch.Tensor, ...]] = (
539
+ () if x.output_attentions else None
540
+ )
541
+ next_decoder_cache: Optional[PlamoCache] = [] if x.use_cache else None
542
+ hidden_states = x.hidden_states
543
+ for idx, decoder_layer in enumerate(self.layers):
544
+ if x.output_hidden_states:
545
+ assert all_hidden_states is not None
546
+ all_hidden_states += (hidden_states,)
547
+
548
+ past_key_value = (
549
+ x.past_key_values[idx] if x.past_key_values is not None else None
550
+ )
551
+
552
+ if self.training and x.gradient_checkpointing:
553
+
554
+ def create_custom_forward(module): # type: ignore
555
+ def custom_forward(*inputs): # type: ignore
556
+ # None for past_key_value
557
+ return module(*inputs, x.output_attentions, None)
558
+
559
+ return custom_forward
560
+
561
+ layer_outputs = torch.utils.checkpoint.checkpoint(
562
+ create_custom_forward(decoder_layer), # type: ignore
563
+ hidden_states,
564
+ x.attention_mask,
565
+ x.position_ids,
566
+ None,
567
+ use_reentrant=False,
568
+ )
569
+ else:
570
+ layer_outputs = decoder_layer(
571
+ hidden_states,
572
+ attention_mask=x.attention_mask,
573
+ position_ids=x.position_ids,
574
+ past_key_value=past_key_value,
575
+ output_attentions=x.output_attentions,
576
+ use_cache=x.use_cache,
577
+ )
578
+
579
+ hidden_states = layer_outputs[0]
580
+ if x.use_cache:
581
+ cache = layer_outputs[2 if x.output_attentions else 1]
582
+ assert cache is not None
583
+ assert next_decoder_cache is not None
584
+ next_decoder_cache += (cache,)
585
+
586
+ if x.output_attentions:
587
+ assert layer_outputs[1] is not None
588
+ assert all_self_attns is not None
589
+ all_self_attns += (layer_outputs[1],)
590
+ return DecoderOutput(
591
+ hidden_states, all_hidden_states, all_self_attns, next_decoder_cache
592
+ )
593
+
594
+
595
+ class PlamoPreTrainedModel(PreTrainedModel): # type: ignore
596
+ config_class = PlamoConfig
597
+ _no_split_modules: List[str]
598
+ base_model_prefix = "model"
599
+ supports_gradient_checkpointing = True
600
+ _supports_sdpa = True
601
+ _no_split_modules = ["PlamoDecoderLayer"]
602
+ _skip_keys_device_placement = "past_key_values"
603
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
604
+
605
+ def _init_weights(self, module: torch.nn.Module) -> None:
606
+ std = self.config.initializer_range
607
+ if isinstance(module, nn.Linear):
608
+ module.weight.data.normal_(mean=0.0, std=std)
609
+ if module.bias is not None:
610
+ module.bias.data.zero_()
611
+ elif isinstance(module, nn.Embedding):
612
+ module.weight.data.normal_(mean=0.0, std=std)
613
+ if module.padding_idx is not None:
614
+ module.weight.data[module.padding_idx].zero_()
615
+
616
+ def _set_gradient_checkpointing(
617
+ self, module: torch.nn.Module, value: bool = False
618
+ ) -> None:
619
+ module.gradient_checkpointing = value # type: ignore
620
+
621
+
622
+ class PlamoModel(PlamoPreTrainedModel):
623
+ def __init__(self, config: PlamoConfig):
624
+ super().__init__(config)
625
+ assert config.eval_attention_n_bit is None
626
+ assert config.eval_mlp_n_bit is None
627
+ assert not config.eval_offload_moe
628
+
629
+ self.padding_idx = config.pad_token_id
630
+ self.vocab_size = config.vocab_size
631
+
632
+ self.embed_tokens = nn.Embedding(
633
+ config.vocab_size, config.hidden_size, self.padding_idx
634
+ )
635
+ self.layers = PlamoDecoder(config) # type: ignore
636
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
637
+
638
+ self.gradient_checkpointing = False
639
+ # Initialize weights and apply final processing
640
+ self.post_init()
641
+
642
+ def get_input_embeddings(self) -> torch.nn.Embedding:
643
+ return self.embed_tokens
644
+
645
+ def set_input_embeddings(self, value: torch.nn.Embedding) -> None:
646
+ self.embed_tokens = value
647
+
648
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
649
+ def _prepare_decoder_attention_mask(
650
+ self,
651
+ attention_mask: torch.Tensor,
652
+ input_shape: Tuple[int, int],
653
+ inputs_embeds: Optional[torch.Tensor],
654
+ past_key_values_length: int,
655
+ ) -> Optional[torch.Tensor]:
656
+ # create causal mask
657
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
658
+ combined_attention_mask: Optional[torch.Tensor] = None
659
+ if input_shape[-1] > 1:
660
+ assert inputs_embeds is not None
661
+ combined_attention_mask = _make_causal_mask(
662
+ input_shape,
663
+ inputs_embeds.dtype,
664
+ device=inputs_embeds.device,
665
+ past_key_values_length=past_key_values_length,
666
+ )
667
+
668
+ if attention_mask is not None:
669
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
670
+ assert inputs_embeds is not None
671
+ expanded_attn_mask = _expand_mask(
672
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
673
+ ).to(inputs_embeds.device)
674
+ combined_attention_mask = (
675
+ expanded_attn_mask
676
+ if combined_attention_mask is None
677
+ else expanded_attn_mask + combined_attention_mask
678
+ )
679
+
680
+ return combined_attention_mask
681
+
682
+ def forward(
683
+ self,
684
+ input_ids: Optional[torch.LongTensor] = None,
685
+ attention_mask: Optional[torch.Tensor] = None,
686
+ position_ids: Optional[torch.Tensor] = None,
687
+ past_key_values: Optional[PlamoCache] = None,
688
+ inputs_embeds: Optional[torch.Tensor] = None,
689
+ use_cache: Optional[bool] = None,
690
+ output_attentions: Optional[bool] = None,
691
+ output_hidden_states: Optional[bool] = None,
692
+ return_dict: Optional[bool] = None,
693
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
694
+ assert input_ids is not None
695
+ output_attentions = (
696
+ output_attentions
697
+ if output_attentions is not None
698
+ else self.config.output_attentions
699
+ )
700
+ output_hidden_states = (
701
+ output_hidden_states
702
+ if output_hidden_states is not None
703
+ else self.config.output_hidden_states
704
+ )
705
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
706
+
707
+ return_dict = (
708
+ return_dict if return_dict is not None else self.config.use_return_dict
709
+ )
710
+
711
+ # retrieve input_ids and inputs_embeds
712
+ if input_ids is not None and inputs_embeds is not None:
713
+ raise ValueError(
714
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
715
+ )
716
+ elif input_ids is not None:
717
+ batch_size, seq_length = input_ids.shape
718
+ else:
719
+ raise ValueError(
720
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
721
+ )
722
+
723
+ seq_length_with_past = seq_length
724
+ past_key_values_length = 0
725
+
726
+ if past_key_values is not None:
727
+ past_key_values_length = past_key_values[0].sequence_length()
728
+ seq_length_with_past = seq_length_with_past + past_key_values_length
729
+
730
+ if position_ids is None:
731
+ device = input_ids.device
732
+ position_ids = torch.arange(
733
+ past_key_values_length,
734
+ seq_length + past_key_values_length,
735
+ dtype=torch.long,
736
+ device=device,
737
+ )
738
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
739
+ else:
740
+ position_ids = position_ids.view(-1, seq_length).long()
741
+
742
+ if inputs_embeds is None:
743
+ inputs_embeds = self.embed_tokens(input_ids)
744
+ # embed positions
745
+ if (
746
+ attention_mask is not None
747
+ or not self.training
748
+ or past_key_values is not None
749
+ ):
750
+ if attention_mask is None:
751
+ attention_mask = torch.ones(
752
+ (batch_size, seq_length_with_past),
753
+ dtype=torch.bool,
754
+ device=inputs_embeds.device,
755
+ )
756
+ # attention_mask = self._prepare_decoder_attention_mask(
757
+ # attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
758
+ # )
759
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
760
+ attention_mask,
761
+ (batch_size, seq_length),
762
+ inputs_embeds,
763
+ past_key_values_length,
764
+ )
765
+
766
+ hidden_states = inputs_embeds
767
+
768
+ if self.gradient_checkpointing and self.training:
769
+ if use_cache:
770
+ use_cache = False
771
+
772
+ # decoder layers
773
+ out = self.layers(
774
+ DecoderInput(
775
+ hidden_states,
776
+ position_ids,
777
+ attention_mask,
778
+ past_key_values,
779
+ output_hidden_states,
780
+ output_attentions,
781
+ use_cache,
782
+ self.gradient_checkpointing,
783
+ )
784
+ )
785
+ assert isinstance(out, DecoderOutput)
786
+ hidden_states = out.hidden_states
787
+ all_hidden_states = out.all_hidden_states
788
+ all_self_attns = out.all_self_attns
789
+ next_decoder_cache = out.next_decoder_cache
790
+
791
+ hidden_states = self.norm(hidden_states)
792
+
793
+ # add hidden states from the last decoder layer
794
+ if output_hidden_states:
795
+ assert all_hidden_states is not None
796
+ all_hidden_states += (hidden_states,)
797
+
798
+ next_cache = next_decoder_cache if use_cache else None
799
+ if not return_dict:
800
+ return tuple(
801
+ v
802
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
803
+ if v is not None
804
+ )
805
+ return BaseModelOutputWithPast(
806
+ last_hidden_state=hidden_states,
807
+ past_key_values=next_cache,
808
+ hidden_states=all_hidden_states,
809
+ attentions=all_self_attns,
810
+ )
811
+
812
+
813
+ class ModifiedAttention(Attention):
814
+ def __init__(self, config: PlamoConfig, **kwargs):
815
+ super().__init__(config, **kwargs)
816
+ self.is_causal = False
817
+
818
+
819
+ PLAMO_ATTENTION_CLASSES = {
820
+ "sdpa": ModifiedAttention,
821
+ }
822
+
823
+
824
+ class ModifiedPlamoDecoderLayer(PlamoDecoderLayer):
825
+ def __init__(self, config: PlamoConfig, is_sparse: bool):
826
+ nn.Module.__init__(self)
827
+ self.config = config
828
+ self.hidden_size = config.hidden_size
829
+
830
+ self.self_attn = PLAMO_ATTENTION_CLASSES[config._attn_implementation](
831
+ config=config
832
+ )
833
+
834
+ self.mlp = MLP(config, is_sparse=is_sparse)
835
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
836
+ self.norm2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
837
+
838
+
839
+ class ModifiedPlamoDecoder(PlamoDecoder):
840
+ def __init__(self, config: PlamoConfig) -> None:
841
+ nn.Module.__init__(self)
842
+ self.layers = nn.ModuleList(
843
+ [
844
+ ModifiedPlamoDecoderLayer(
845
+ config, is_sparse=is_sparse(config, layer_idx)
846
+ )
847
+ for layer_idx in range(config.num_hidden_layers)
848
+ ]
849
+ )
850
+
851
+
852
+ class PlamoBiModel(PlamoModel):
853
+ _no_split_modules = ["ModifiedPlamoDecoderLayer"]
854
+
855
+ def __init__(self, config: PlamoConfig):
856
+ PlamoPreTrainedModel.__init__(self, config)
857
+ self.padding_idx = config.pad_token_id
858
+ self.vocab_size = config.vocab_size
859
+
860
+ self.embed_tokens = nn.Embedding(
861
+ config.vocab_size, config.hidden_size, self.padding_idx
862
+ )
863
+
864
+ self.layers = ModifiedPlamoDecoder(config)
865
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
866
+ self.gradient_checkpointing = False
867
+ self._attn_implementation = config._attn_implementation
868
+ self.post_init()
869
+
870
+ def forward(
871
+ self,
872
+ input_ids: Optional[torch.LongTensor] = None,
873
+ attention_mask: Optional[torch.Tensor] = None,
874
+ position_ids: Optional[torch.Tensor] = None,
875
+ past_key_values: Optional[PlamoCache] = None,
876
+ inputs_embeds: Optional[torch.Tensor] = None,
877
+ use_cache: Optional[bool] = None,
878
+ output_attentions: Optional[bool] = None,
879
+ output_hidden_states: Optional[bool] = None,
880
+ return_dict: Optional[bool] = None,
881
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
882
+ assert input_ids is not None
883
+ output_attentions = (
884
+ output_attentions
885
+ if output_attentions is not None
886
+ else self.config.output_attentions
887
+ )
888
+ output_hidden_states = (
889
+ output_hidden_states
890
+ if output_hidden_states is not None
891
+ else self.config.output_hidden_states
892
+ )
893
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
894
+
895
+ return_dict = (
896
+ return_dict if return_dict is not None else self.config.use_return_dict
897
+ )
898
+
899
+ if input_ids is not None and inputs_embeds is not None:
900
+ raise ValueError(
901
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
902
+ )
903
+ elif input_ids is not None:
904
+ batch_size, seq_length = input_ids.shape
905
+ else:
906
+ raise ValueError(
907
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
908
+ )
909
+
910
+ seq_length_with_past = seq_length
911
+ past_key_values_length = 0
912
+
913
+ if past_key_values is not None:
914
+ past_key_values_length = past_key_values[0].sequence_length()
915
+ seq_length_with_past = seq_length_with_past + past_key_values_length
916
+
917
+ if position_ids is None:
918
+ device = input_ids.device
919
+ position_ids = torch.arange(
920
+ past_key_values_length,
921
+ seq_length + past_key_values_length,
922
+ dtype=torch.long,
923
+ device=device,
924
+ )
925
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
926
+ else:
927
+ position_ids = position_ids.view(-1, seq_length).long()
928
+
929
+ if inputs_embeds is None:
930
+ inputs_embeds = self.embed_tokens(input_ids)
931
+
932
+ if self._attn_implementation == "sdpa" and not output_attentions:
933
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
934
+ attention_mask,
935
+ (batch_size, seq_length),
936
+ inputs_embeds,
937
+ past_key_values_length,
938
+ )
939
+ else:
940
+ attention_mask = _prepare_4d_causal_attention_mask(
941
+ attention_mask,
942
+ (batch_size, seq_length),
943
+ inputs_embeds,
944
+ past_key_values_length,
945
+ sliding_window=self.config.sliding_window,
946
+ )
947
+ hidden_states = inputs_embeds
948
+
949
+ if self.gradient_checkpointing and self.training:
950
+ if use_cache:
951
+ use_cache = False
952
+
953
+ out = self.layers(
954
+ DecoderInput(
955
+ hidden_states,
956
+ position_ids,
957
+ attention_mask,
958
+ past_key_values,
959
+ output_hidden_states,
960
+ output_attentions,
961
+ use_cache,
962
+ self.gradient_checkpointing,
963
+ )
964
+ )
965
+
966
+ assert isinstance(out, DecoderOutput)
967
+ hidden_states = out.hidden_states
968
+ all_hidden_states = out.all_hidden_states
969
+ all_self_attns = out.all_self_attns
970
+ next_decoder_cache = out.next_decoder_cache
971
+
972
+ hidden_states = self.norm(hidden_states)
973
+
974
+ if output_hidden_states:
975
+ assert all_hidden_states is not None
976
+ all_hidden_states += (hidden_states,)
977
+
978
+ next_cache = next_decoder_cache if use_cache else None
979
+ if not return_dict:
980
+ return tuple(
981
+ v
982
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
983
+ if v is not None
984
+ )
985
+ return BaseModelOutputWithPast(
986
+ last_hidden_state=hidden_states,
987
+ past_key_values=next_cache,
988
+ hidden_states=all_hidden_states,
989
+ attentions=all_self_attns,
990
+ )
991
+
992
+ def _tokenize(
993
+ self,
994
+ texts: List[str],
995
+ tokenizer: AutoTokenizer,
996
+ add_special_tokens: bool = True,
997
+ ) -> BatchEncoding:
998
+ tokenizer.pad_token = tokenizer.eos_token
999
+ tokenizer.padding_side = "left"
1000
+
1001
+ return tokenizer(
1002
+ texts,
1003
+ return_tensors="pt",
1004
+ truncation=True,
1005
+ padding=True,
1006
+ max_length=self.config.max_length,
1007
+ add_special_tokens=add_special_tokens,
1008
+ )
1009
+
1010
+ def _tokenize_with_instruction(
1011
+ self,
1012
+ sentences: List[str],
1013
+ tokenizer: AutoTokenizer,
1014
+ instruction: str,
1015
+ add_special_tokens: bool = True,
1016
+ ) -> Tuple[BatchEncoding, torch.Tensor]:
1017
+ sentence_features = self._tokenize(
1018
+ sentences, tokenizer, add_special_tokens=False
1019
+ )
1020
+
1021
+ sentences_with_instruction = [instruction + sentence for sentence in sentences]
1022
+ sentence_features_with_instruction = self._tokenize(
1023
+ sentences_with_instruction, tokenizer, add_special_tokens
1024
+ )
1025
+
1026
+ embed_mask_list = []
1027
+ for i in range(len(sentences)):
1028
+ n_tokens = int(sentence_features["attention_mask"][i].sum().item())
1029
+ mask = torch.zeros_like(
1030
+ sentence_features_with_instruction["attention_mask"][i]
1031
+ )
1032
+ if n_tokens > 0:
1033
+ mask[-n_tokens:] = torch.ones(n_tokens, dtype=mask.dtype)
1034
+ embed_mask_list.append(mask.unsqueeze(0))
1035
+ embed_mask = torch.cat(embed_mask_list, dim=0)
1036
+
1037
+ return sentence_features_with_instruction, embed_mask
1038
+
1039
+ def _mean_pooling(
1040
+ self,
1041
+ sentence_features: BatchEncoding,
1042
+ last_hidden_state: torch.Tensor,
1043
+ embed_mask: Optional[torch.Tensor] = None,
1044
+ ) -> torch.Tensor:
1045
+ if embed_mask is None:
1046
+ mask = sentence_features["attention_mask"]
1047
+ else:
1048
+ mask = embed_mask
1049
+ sum_hidden = (
1050
+ last_hidden_state * mask.unsqueeze(-1).type_as(last_hidden_state)
1051
+ ).sum(dim=1)
1052
+ lengths = mask.sum(dim=1, keepdim=True).clamp(min=1)
1053
+ return sum_hidden / lengths
1054
+
1055
+ def encode(
1056
+ self,
1057
+ sentences: Union[str, List[str]],
1058
+ tokenizer: AutoTokenizer,
1059
+ instruction: str,
1060
+ ) -> torch.Tensor:
1061
+ if isinstance(sentences, str):
1062
+ sentences = [sentences]
1063
+
1064
+ sentence_features, embed_mask = self._tokenize_with_instruction(
1065
+ sentences,
1066
+ tokenizer,
1067
+ instruction=instruction,
1068
+ )
1069
+ sentence_features = sentence_features.to(self.device)
1070
+ embed_mask = embed_mask.to(self.device)
1071
+
1072
+ reps = self(**sentence_features)
1073
+ return self._mean_pooling(sentence_features, reps.last_hidden_state, embed_mask)
1074
+
1075
+ def encode_document(
1076
+ self,
1077
+ sentences: Union[str, List[str]],
1078
+ tokenizer: AutoTokenizer,
1079
+ ) -> torch.Tensor:
1080
+ default_document_instruction = ""
1081
+ return self.encode(sentences, tokenizer, default_document_instruction)
1082
+
1083
+ def encode_query(
1084
+ self,
1085
+ sentences: Union[str, List[str]],
1086
+ tokenizer: AutoTokenizer,
1087
+ ) -> torch.Tensor:
1088
+ default_query_instruction = "次の文章に対して、関連する文章を検索してください: "
1089
+ return self.encode(sentences, tokenizer, default_query_instruction)
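Usage note (not part of the committed files): a minimal sketch of how the bi-directional encoder above might be called. The repository id and the AutoModel mapping are assumptions for illustration, not taken from this commit; trust_remote_code=True is assumed so that PlamoBiModel and PlamoTokenizer are resolved from the repository code. encode_query prepends the Japanese retrieval instruction and, via embed_mask, mean-pools only over the original sentence tokens, while encode_document encodes the text with an empty instruction.

    # Hypothetical usage sketch; the repo id is a placeholder and the AutoModel mapping is assumed.
    import torch
    import torch.nn.functional as F
    from transformers import AutoModel, AutoTokenizer

    repo = "pfnet/plamo-embedding-1b"  # placeholder repository id
    tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
    model = AutoModel.from_pretrained(repo, trust_remote_code=True).eval()

    query = "PLaMoのトークナイザについて教えてください"
    docs = ["PLaMoのトークナイザはsentencepieceのtokenizer.modelを読み込みます。"]

    with torch.inference_mode():
        q_emb = model.encode_query(query, tokenizer)    # (1, hidden_size), instruction prepended internally
        d_emb = model.encode_document(docs, tokenizer)  # (len(docs), hidden_size), no instruction
        scores = F.cosine_similarity(q_emb, d_emb)      # one similarity score per document
    print(scores)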
special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|startoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "<|startoftext|>",
17
+ "unk_token": {
18
+ "content": "<|unknown|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
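For reference, the map above assigns the same <|startoftext|> token to bos, eos, and pad. A hedged sanity check (the repository id is a placeholder and trust_remote_code=True is assumed):

    # Hypothetical check of the special token mapping; repo id is a placeholder.
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("pfnet/plamo-embedding-1b", trust_remote_code=True)
    assert tok.bos_token == tok.eos_token == tok.pad_token == "<|startoftext|>"
    assert tok.unk_token == "<|unknown|>"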
tokenization_plamo.py ADDED
@@ -0,0 +1,191 @@
1
+ import os
2
+ import warnings
3
+ from shutil import copyfile
4
+ from typing import Any, Dict, List, Optional, Tuple
5
+
6
+ import sentencepiece as spm
7
+ from transformers.tokenization_utils import PreTrainedTokenizer
8
+ from transformers.utils import logging
9
+
10
+ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
11
+ logger = logging.get_logger(__name__)
12
+
13
+
14
+ def _get_tokenizer_threads(default: int = -1) -> int:
15
+ env_names = [
16
+ "PLAMO_TOKENIZER_NUM_THREADS",
17
+ "RAYON_NUM_THREADS",
18
+ ]
19
+ for name in env_names:
20
+ v = os.environ.get(name, None)
21
+ if v:
22
+ try:
23
+ return int(v)
24
+ except ValueError:
25
+ warnings.warn(
26
+ f"Value assigned to env `{name}` is not an integer. Current value is {v}",
27
+ category=RuntimeWarning,
28
+ stacklevel=2,
29
+ )
30
+ return default
31
+
32
+
33
+ class PlamoTokenizer(PreTrainedTokenizer): # type: ignore
34
+ vocab_files_names = VOCAB_FILES_NAMES
35
+ model_input_names = ["input_ids", "attention_mask"]
36
+
37
+ def __init__(
38
+ self,
39
+ vocab_file: str,
40
+ unk_token: str = "<unk>",
41
+ bos_token: str = "<s>",
42
+ eos_token: str = "</s>",
43
+ pad_token: str = "<pad>",
44
+ cls_token: str = "<cls>",
45
+ sep_token: str = "<sep>",
46
+ mask_token: str = "<mask>",
47
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
48
+ clean_up_tokenization_spaces: bool = False,
49
+ num_threads: int = -1,
50
+ **kwargs: Any,
51
+ ) -> None:
52
+ """Tokenizer for PLaMo.
53
+
54
+ Args:
55
+ vocab_file (str): Vocabulary file path.
56
+ unk_token (str): Unknown token.
57
+ bos_token (str): Beginning of sentence token.
58
+ eos_token (str): End of sentence token.
59
+ pad_token (str): Padding token.
60
+ cls_token (str):
61
+ Classification token, to extract a summary of an input sequence leveraging self-attention along the
62
+ full depth of the model.
63
+ sep_token (str): Separation token, to separate context and query in an input sequence.
64
+ mask_token (str): Mask token, to use when training a model with masked-language modeling.
65
+ sp_model_kwargs (Dict[str, Any] or None): kwargs for the sentencepiece model.
66
+ clean_up_tokenization_spaces (bool): Whether or not to clean up the tokenization spaces.
67
+ num_threads (int):
68
+ Number of threads. This value will be ignored if one of `PLAMO_TOKENIZER_NUM_THREADS` or
69
+ `RAYON_NUM_THREADS` is set as an environment variable.
70
+ """
71
+ if "add_bos_token" not in kwargs:
72
+ kwargs["add_bos_token"] = False
73
+ if "add_eos_token" not in kwargs:
74
+ kwargs["add_eos_token"] = False
75
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
76
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
77
+ self.sp_model.Init(model_file=vocab_file, num_threads=_get_tokenizer_threads(num_threads))
78
+ self.vocab_file = vocab_file
79
+ self.add_bos_token = kwargs["add_bos_token"]
80
+ self.add_eos_token = kwargs["add_eos_token"]
81
+
82
+ super().__init__(
83
+ vocab_file=vocab_file,
84
+ unk_token=unk_token,
85
+ bos_token=bos_token,
86
+ eos_token=eos_token,
87
+ pad_token=pad_token,
88
+ cls_token=cls_token,
89
+ sep_token=sep_token,
90
+ mask_token=mask_token,
91
+ sp_model_kwargs=sp_model_kwargs,
92
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
93
+ **kwargs,
94
+ )
95
+
96
+ # the functions below are copied from hf transformers LlamaTokenizer's implementation to fix the behaviour of the tokenizer
97
+ # https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/models/llama/tokenization_llama.py
98
+
99
+ def __getstate__(self) -> dict[str, Any]:
100
+ state = self.__dict__.copy()
101
+ state["sp_model"] = None
102
+ state["sp_model_proto"] = self.sp_model.serialized_model_proto()
103
+ return state
104
+
105
+ def __setstate__(self, d: dict[str, Any]) -> None:
106
+ self.__dict__ = d
107
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
108
+ self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
109
+
110
+ @property
111
+ def vocab_size(self) -> Any:
112
+ """Returns vocab size"""
113
+ return self.sp_model.get_piece_size()
114
+
115
+ def get_vocab(self) -> dict[str, int]:
116
+ """Returns vocab as a dict"""
117
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
118
+ vocab.update(self.added_tokens_encoder)
119
+ return vocab
120
+
121
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
123
+ """Converts a sequence of tokens (strings) into a single string."""
124
+ current_sub_tokens: List[str] = []
124
+ out_string = ""
125
+ prev_is_special = False
126
+ for i, token in enumerate(tokens):
127
+ # make sure that special tokens are not decoded using sentencepiece model
128
+ if token in self.all_special_tokens:
129
+ if not prev_is_special and i != 0:
130
+ out_string += " "
131
+ out_string += self.sp_model.decode(current_sub_tokens) + token
132
+ prev_is_special = True
133
+ current_sub_tokens = []
134
+ else:
135
+ current_sub_tokens.append(token)
136
+ prev_is_special = False
137
+ out_string += self.sp_model.decode(current_sub_tokens)
138
+ return out_string
139
+
140
+ def _tokenize(self, text: str) -> Any:
141
+ """Returns a tokenized string."""
142
+ return self.sp_model.encode(text, out_type=str)
143
+
144
+ def _convert_token_to_id(self, token: str) -> Any:
145
+ """Converts a token (str) in an id using the vocab."""
146
+ return self.sp_model.piece_to_id(token)
147
+
148
+ def _convert_id_to_token(self, index: int) -> Any:
149
+ """Converts an index (integer) in a token (str) using the vocab."""
150
+ token = self.sp_model.IdToPiece(index)
151
+ return token
152
+
153
+ def build_inputs_with_special_tokens(
154
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
155
+ ) -> List[int]:
156
+ bos_token_id = [self.bos_token_id] if self.add_bos_token else []
157
+ eos_token_id = [self.eos_token_id] if self.add_eos_token else []
158
+
159
+ output = bos_token_id + token_ids_0 + eos_token_id
160
+
161
+ if token_ids_1 is not None:
162
+ output = output + bos_token_id + token_ids_1 + eos_token_id
163
+
164
+ return output
165
+
166
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
167
+ """
168
+ Save the vocabulary and special tokens file to a directory.
169
+
170
+ Args:
171
+ save_directory (`str`):
172
+ The directory in which to save the vocabulary.
173
+
174
+ Returns:
175
+ `Tuple(str)`: Paths to the files saved.
176
+ """
177
+ if not os.path.isdir(save_directory):
178
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
179
+ return ("",)
180
+ out_vocab_file = os.path.join(
181
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
182
+ )
183
+
184
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
185
+ copyfile(self.vocab_file, out_vocab_file)
186
+ elif not os.path.isfile(self.vocab_file):
187
+ with open(out_vocab_file, "wb") as fi:
188
+ content_spiece_model = self.sp_model.serialized_model_proto()
189
+ fi.write(content_spiece_model)
190
+
191
+ return (out_vocab_file,)
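A small sketch of using PlamoTokenizer directly rather than through AutoTokenizer (hedged: the tokenizer.model path is a placeholder and the module is assumed to be importable from the working directory). It illustrates two behaviours visible above: _get_tokenizer_threads lets PLAMO_TOKENIZER_NUM_THREADS or RAYON_NUM_THREADS override the num_threads argument, and add_bos_token/add_eos_token default to False unless passed explicitly (the checked-in tokenizer_config.json turns add_bos_token on).

    # Hypothetical direct construction; the vocab path is a placeholder.
    import os
    os.environ["PLAMO_TOKENIZER_NUM_THREADS"] = "4"   # read by _get_tokenizer_threads at init time

    from tokenization_plamo import PlamoTokenizer

    tok = PlamoTokenizer(vocab_file="tokenizer.model")   # add_bos_token/add_eos_token default to False here
    print(tok.tokenize("埋め込みモデルのテストです。"))       # sentencepiece pieces via _tokenize
    print(tok("埋め込みモデルのテストです。")["input_ids"])   # ids via _convert_token_to_id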
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9603895be773fe5807f5183bf9279da4df3a81ce5941a1a9521e8b496201c69a
3
+ size 805457
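The entry above is a Git LFS pointer, not the sentencepiece model itself; the actual ~805 kB file is resolved on checkout or download. A hedged sketch of retrieving it programmatically (the repository id is a placeholder):

    # Hypothetical download of the file behind the LFS pointer; repo id is a placeholder.
    from huggingface_hub import hf_hub_download

    local_path = hf_hub_download(repo_id="pfnet/plamo-embedding-1b", filename="tokenizer.model")
    print(local_path)  # local cache path of the resolved sentencepiece model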
tokenizer_config.json ADDED
@@ -0,0 +1,57 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<|unknown|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<|startoftext|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "<|endoftext|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "3": {
30
+ "content": "<|pad|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ }
37
+ },
38
+ "auto_map": {
39
+ "AutoTokenizer": [
40
+ "tokenization_plamo.PlamoTokenizer",
41
+ null
42
+ ]
43
+ },
44
+ "bos_token": "<|startoftext|>",
45
+ "clean_up_tokenization_spaces": false,
46
+ "cls_token": null,
47
+ "eos_token": "<|startoftext|>",
48
+ "extra_special_tokens": {},
49
+ "local_file_only": true,
50
+ "mask_token": null,
51
+ "model_max_length": 1000000000000000019884624838656,
52
+ "pad_token": "<|startoftext|>",
53
+ "sep_token": null,
54
+ "sp_model_kwargs": {},
55
+ "tokenizer_class": "PlamoTokenizer",
56
+ "unk_token": "<|unknown|>"
57
+ }
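For context, a hedged sketch of how this config is consumed at load time (the repository id is a placeholder): the auto_map entry points AutoTokenizer at the slow tokenization_plamo.PlamoTokenizer (the null second slot means no fast tokenizer is provided), so trust_remote_code=True is required; add_bos_token is true, and bos/eos/pad all resolve to <|startoftext|> (id 1).

    # Hypothetical load through auto_map; repo id is a placeholder.
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(
        "pfnet/plamo-embedding-1b",  # placeholder
        trust_remote_code=True,      # required to import tokenization_plamo.PlamoTokenizer
    )
    print(type(tok).__name__)        # expected: PlamoTokenizer
    ids = tok("テスト")["input_ids"]
    print(ids[0], tok.bos_token_id)  # add_bos_token=true -> first id should be the <|startoftext|> id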