shirayukikun committed on

Commit 5b88d12 · verified · 1 Parent(s): 825a110

Upload folder using huggingface_hub

README.md CHANGED
@@ -1,3 +1,301 @@
- ---
- license: apache-2.0
- ---

---
license: apache-2.0
language:
- ja
---

(The English section follows the Japanese one.)

# byGPT-JP-multi-lm-head 6.5B alpha

バイト単位のtokenizerを採用した日本語言語モデルです。
一度に4 tokens (bytes) ずつ予測するために,複数のlmヘッドを持つアーキテクチャを採用しています。
また,multi-byte predictionに適した独自のUnicode encodingを採用しています。
現在開発段階のモデルであり,十分な性能には達していません。

## 利用方法

[transformers version 4.56.1](https://github.com/huggingface/transformers/releases/tag/v4.56.1) において、動作確認しています。
他のバージョンでは動作しない可能性があります。

```python
import argparse

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


SAMPLE_INPUT_TEXTS = [
    "日本三景一覧:\n1. 広島県, 宮島\n2. 京都府, 天橋立\n3. 宮城県, ",
    "原文: I like to play soccer. 訳文: 私はサッカーをするのが好きです。\n原文: She enjoys reading books. 訳文: 彼女は本を読むのが好きです。\n原文: They went to the park. 訳文:",
]

def main(args):
    torch.manual_seed(args.seed)
    device = torch.device("cuda")

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    model.to(device)
    model.eval()

    input_texts = [f"{tokenizer.bos_token}{text}" for text in SAMPLE_INPUT_TEXTS]
    batch = tokenizer(
        input_texts, return_tensors="pt", padding="longest", add_special_tokens=False
    )
    batch = batch.to(device)
    decoded_ids = model.generate(
        input_ids=batch.input_ids,
        attention_mask=batch.attention_mask,
        eos_token_id=[tokenizer.encode("\n", add_special_tokens=False)],
        pad_token_id=tokenizer.pad_token_id,
        max_new_tokens=args.max_new_tokens,
        do_sample=False,
    )

    decoded_texts = tokenizer.batch_decode(decoded_ids, skip_special_tokens=False)
    for text in decoded_texts:
        print("===")
        print(f"Decoded: {text}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument(
        "--model_name_or_path",
        "-m",
        type=str,
        default="tohoku-nlp/bygpt-jp-multi-lm-head-6.5B-alpha",
        help="Path to the model or model identifier from huggingface.co/models."
    )
    parser.add_argument("--max_new_tokens", "-n", type=int, help="Maximum number of new tokens to generate.", default=160)
    parser.add_argument("--seed", "-s", type=int, help="Random seed", default=42)
    args = parser.parse_args()
    main(args)

```

### 利用上の注意点
本モデルは,一度に4 bytes (tokens) ずつ予測するため,特殊tokenも複数トークン (bytes) で構成されています。
そのため,例えば `tokenizer.eos_token` はlist of intです。
また,generate関数は `custom_generate` の機能により実装されており,利用可能な機能に制限があります。
また,このモデルはinstruction tuning等を実施していません。

## モデルアーキテクチャ

[Llama](https://arxiv.org/abs/2302.13971) アーキテクチャをベースとしています。
具体的には、以下のモジュールを採用しています。

- [SwiGLU](https://arxiv.org/abs/2002.05202)
- [Rotary Positional Embeddings (RoPE)](https://arxiv.org/abs/2104.09864)
- [Grouped Query Attention (GQA)](https://aclanthology.org/2023.emnlp-main.298/)

また,4 tokens (bytes) ずつ予測するため,以下を追加しています。

- 4つのlmヘッド
- 入力のembeddingを4 tokenごとにマージするモジュール

## 学習データ

[llm-jp-corpus-v3](https://gitlab.llm-jp.nii.ac.jp/datasets/llm-jp-corpus-v3) の日本語コーパスのサブセット (ja_cc, ja_warp_html, ja_warp_pdf, ja_wiki, kaken) を使用しました。

### 学習設定

| | tohoku-nlp/bygpt-jp-multi-lm-head-6.5B-alpha |
| ---- | ---- |
| Training Steps | 208,000 |
| Batch Size (tokens) | 5,898,240 |
| Max Learning Rate | 5.0E-4 |
| Min Learning Rate | 1.0E-5 |
| Learning Rate Warmup Steps | 2,000 |
| Scheduler | cosine |
| Optimizer | AdamW |
| Optimizer Config | beta_1 = 0.9, beta_2 = 0.999, eps = 1.0E-8 |
| Weight Decay | 0.01 |
| Gradient Clipping | 1.0 |
| Sequence Length | 11,520 |

学習には[Megatron-LM](https://arxiv.org/abs/1909.08053)をベースに,独自の変更を加えたコードベースを使用しています。

## ライセンス

このモデルは Apache License 2.0 の下で配布しています。

# 免責事項

本モデルの作者は本モデルを作成するにあたって、その内容、機能等について細心の注意を払っておりますが、モデルの出力が正確であるかどうか、安全なものであるか等について保証をするものではなく、何らの責任を負うものではありません。
本モデルの利用により、万一、利用者に何らかの不都合や損害が発生したとしても、モデルやデータセットの作者や作者の所属組織は何らの責任を負うものではありません。

## 謝辞

このモデルの学習にあたり様々な面でご協力いただきました [Tohoku NLP Group](https://www.nlp.ecei.tohoku.ac.jp/) の皆様に感謝いたします。

## 作成者
- [Keito Kudo](https://x.com/k8kudo)
- [Go Kamoda](https://x.com/go2oo2)
- [Daiki Shiono](https://x.com/onely7_deep)
- [Jun Suzuki](https://x.com/drJunSuzuki)


<br>
<br>
<br>
<br>

# byGPT-JP-multi-lm-head 6.5B alpha

This is a Japanese language model that uses a byte-level tokenizer.
It adopts an architecture with multiple LM heads that predict 4 tokens (bytes) at once,
together with a custom Unicode encoding suited to multi-byte prediction.
The model is still at the development stage and has not yet reached sufficient performance.

## Usage

We have confirmed that the model works with [transformers version 4.56.1](https://github.com/huggingface/transformers/releases/tag/v4.56.1).
It may not work with other versions.

```python
import argparse

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


SAMPLE_INPUT_TEXTS = [
    "日本三景一覧:\n1. 広島県, 宮島\n2. 京都府, 天橋立\n3. 宮城県, ",
    "原文: I like to play soccer. 訳文: 私はサッカーをするのが好きです。\n原文: She enjoys reading books. 訳文: 彼女は本を読むのが好きです。\n原文: They went to the park. 訳文:",
]

def main(args):
    torch.manual_seed(args.seed)
    device = torch.device("cuda")

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    model.to(device)
    model.eval()

    input_texts = [f"{tokenizer.bos_token}{text}" for text in SAMPLE_INPUT_TEXTS]
    batch = tokenizer(
        input_texts, return_tensors="pt", padding="longest", add_special_tokens=False
    )
    batch = batch.to(device)
    decoded_ids = model.generate(
        input_ids=batch.input_ids,
        attention_mask=batch.attention_mask,
        eos_token_id=[tokenizer.encode("\n", add_special_tokens=False)],
        pad_token_id=tokenizer.pad_token_id,
        max_new_tokens=args.max_new_tokens,
        do_sample=False,
    )

    decoded_texts = tokenizer.batch_decode(decoded_ids, skip_special_tokens=False)
    for text in decoded_texts:
        print("===")
        print(f"Decoded: {text}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument(
        "--model_name_or_path",
        "-m",
        type=str,
        default="tohoku-nlp/bygpt-jp-multi-lm-head-6.5B-alpha",
        help="Path to the model or model identifier from huggingface.co/models."
    )
    parser.add_argument("--max_new_tokens", "-n", type=int, help="Maximum number of new tokens to generate.", default=160)
    parser.add_argument("--seed", "-s", type=int, help="Random seed", default=42)
    args = parser.parse_args()
    main(args)

```

### Important Notes for Usage
Since this model predicts 4 bytes (tokens) at once, its special tokens are also composed of multiple tokens (bytes).
As a consequence, `tokenizer.eos_token`, for example, is a list of ints rather than a string.
The `generate` function is implemented through the `custom_generate` mechanism, so only a limited set of generation features is available.
Additionally, this model has not undergone instruction tuning.
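
To make the note above concrete, the sketch below shows one way to pass multi-byte stopping sequences to `generate`, mirroring the `eos_token_id=[tokenizer.encode("\n", add_special_tokens=False)]` call in the usage example. It assumes `tokenizer`, `model`, and `batch` have already been created as in that example; the stop strings are arbitrary illustrations, not recommendations.

```python
# Each stop sequence is a list of byte ids; generation halts once the output
# ends with any of them. This assumes the chosen strings encode to sequences
# of equal length (the model works in fixed 4-byte steps).
stop_strings = ["\n", "。"]
stop_sequences = [tokenizer.encode(s, add_special_tokens=False) for s in stop_strings]

decoded_ids = model.generate(
    input_ids=batch.input_ids,
    attention_mask=batch.attention_mask,
    eos_token_id=stop_sequences,         # multiple multi-byte stop sequences
    pad_token_id=tokenizer.pad_token_id,
    max_new_tokens=64,
    do_sample=False,
)
```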

## Model Architecture

The model is based on the [Llama](https://arxiv.org/abs/2302.13971) architecture.
Specifically, it adopts the following modules:

- [SwiGLU](https://arxiv.org/abs/2002.05202)
- [Rotary Positional Embeddings (RoPE)](https://arxiv.org/abs/2104.09864)
- [Grouped Query Attention (GQA)](https://aclanthology.org/2023.emnlp-main.298/)

In addition, to predict 4 tokens (bytes) at once, we have added (see the sketch below):

- 4 LM heads
- A module that merges the input embeddings of every 4 tokens
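
The sketch below illustrates the idea behind these two additions. It is not the repository's actual implementation (see `modeling_byllama_patch.py` for that); the shapes and the `"mean"` aggregation follow `config.json` (`hidden_size=4096`, `num_lm_heads=4`, `output_vocab_size=256`, `embedding_aggregator_type="mean"`).

```python
import torch
import torch.nn as nn

hidden_size, num_lm_heads, vocab_size = 4096, 4, 256

embed = nn.Embedding(vocab_size, hidden_size)  # byte-level input embeddings
lm_heads = nn.ModuleList(
    [nn.Linear(hidden_size, vocab_size, bias=False) for _ in range(num_lm_heads)]
)

byte_ids = torch.randint(0, vocab_size, (1, 16))  # 16 bytes -> 4 backbone positions

# 1) Merge the embeddings of every 4 consecutive bytes into one position
#    ("mean" aggregator), so the decoder runs at 1/4 of the byte length.
merged = embed(byte_ids).view(1, -1, num_lm_heads, hidden_size).mean(dim=2)

hidden = merged  # stand-in for the Llama-style decoder stack

# 2) Predict the next 4 bytes of each position with 4 separate LM heads.
logits = torch.stack([head(hidden) for head in lm_heads], dim=2)
print(logits.shape)  # torch.Size([1, 4, 4, 256]): positions x byte slots x byte vocab
```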

## Training Data

We used a subset of the Japanese corpus from [llm-jp-corpus-v3](https://gitlab.llm-jp.nii.ac.jp/datasets/llm-jp-corpus-v3) (ja_cc, ja_warp_html, ja_warp_pdf, ja_wiki, kaken).

### Training Configuration

| | tohoku-nlp/bygpt-jp-multi-lm-head-6.5B-alpha |
| ---- | ---- |
| Training Steps | 170,000 |
| Batch Size (tokens) | 5,898,240 |
| Max Learning Rate | 5.0E-4 |
| Min Learning Rate | 1.0E-5 |
| Learning Rate Warmup Steps | 2,000 |
| Scheduler | cosine |
| Optimizer | AdamW |
| Optimizer Config | beta_1 = 0.9, beta_2 = 0.999, eps = 1.0E-8 |
| Weight Decay | 0.01 |
| Gradient Clipping | 1.0 |
| Sequence Length | 11,520 |
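
(For reference, 5,898,240 tokens per step at a sequence length of 11,520 corresponds to 512 sequences per optimization step.)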
281
+
282
+ For training, we used a codebase based on [Megatron-LM](https://arxiv.org/abs/1909.08053) with our own custom modifications.
283
+
284
+ ## License
285
+
286
+ This model is distributed under the Apache License 2.0.
287
+
288
+ # Disclaimer
289
+
290
+ While the authors of this model have paid careful attention to its content and functionality during creation, we do not guarantee that the model's outputs are accurate or safe, and we assume no responsibility for them.
291
+ Even if users experience any inconvenience or damage due to the use of this model, the authors of the model and dataset and their affiliated organizations assume no responsibility.
292
+
293
+ ## Acknowledgments
294
+
295
+ We thank all members of the [Tohoku NLP Group](https://www.nlp.ecei.tohoku.ac.jp/) who cooperated with us in various aspects of training this model.
296
+
297
+ ## Authors
298
+ - [Keito Kudo](https://x.com/k8kudo)
299
+ - [Go Kamoda](https://x.com/go2oo2)
300
+ - [Daiki Shiono](https://x.com/onely7_deep)
301
+ - [Jun Suzuki](https://x.com/drJunSuzuki)
added_tokens.json ADDED
@@ -0,0 +1,356 @@
1
+ {
2
+ "<|begin_of_text|>": [
3
+ 194,
4
+ 128,
5
+ 64,
6
+ 0
7
+ ],
8
+ "<|cls|>": [
9
+ 195,
10
+ 128,
11
+ 64,
12
+ 0
13
+ ],
14
+ "<|end_header_id|>": [
15
+ 202,
16
+ 128,
17
+ 64,
18
+ 0
19
+ ],
20
+ "<|end_of_role|>": [
21
+ 203,
22
+ 128,
23
+ 64,
24
+ 0
25
+ ],
26
+ "<|end_of_text|>": [
27
+ 193,
28
+ 128,
29
+ 64,
30
+ 0
31
+ ],
32
+ "<|extra_id_0|>": [
33
+ 204,
34
+ 128,
35
+ 64,
36
+ 0
37
+ ],
38
+ "<|extra_id_10|>": [
39
+ 214,
40
+ 128,
41
+ 64,
42
+ 0
43
+ ],
44
+ "<|extra_id_11|>": [
45
+ 215,
46
+ 128,
47
+ 64,
48
+ 0
49
+ ],
50
+ "<|extra_id_12|>": [
51
+ 216,
52
+ 128,
53
+ 64,
54
+ 0
55
+ ],
56
+ "<|extra_id_13|>": [
57
+ 217,
58
+ 128,
59
+ 64,
60
+ 0
61
+ ],
62
+ "<|extra_id_14|>": [
63
+ 218,
64
+ 128,
65
+ 64,
66
+ 0
67
+ ],
68
+ "<|extra_id_15|>": [
69
+ 219,
70
+ 128,
71
+ 64,
72
+ 0
73
+ ],
74
+ "<|extra_id_16|>": [
75
+ 220,
76
+ 128,
77
+ 64,
78
+ 0
79
+ ],
80
+ "<|extra_id_17|>": [
81
+ 221,
82
+ 128,
83
+ 64,
84
+ 0
85
+ ],
86
+ "<|extra_id_18|>": [
87
+ 222,
88
+ 128,
89
+ 64,
90
+ 0
91
+ ],
92
+ "<|extra_id_19|>": [
93
+ 223,
94
+ 128,
95
+ 64,
96
+ 0
97
+ ],
98
+ "<|extra_id_1|>": [
99
+ 205,
100
+ 128,
101
+ 64,
102
+ 0
103
+ ],
104
+ "<|extra_id_20|>": [
105
+ 224,
106
+ 128,
107
+ 64,
108
+ 0
109
+ ],
110
+ "<|extra_id_21|>": [
111
+ 225,
112
+ 128,
113
+ 64,
114
+ 0
115
+ ],
116
+ "<|extra_id_22|>": [
117
+ 226,
118
+ 128,
119
+ 64,
120
+ 0
121
+ ],
122
+ "<|extra_id_23|>": [
123
+ 227,
124
+ 128,
125
+ 64,
126
+ 0
127
+ ],
128
+ "<|extra_id_24|>": [
129
+ 228,
130
+ 128,
131
+ 64,
132
+ 0
133
+ ],
134
+ "<|extra_id_25|>": [
135
+ 229,
136
+ 128,
137
+ 64,
138
+ 0
139
+ ],
140
+ "<|extra_id_26|>": [
141
+ 230,
142
+ 128,
143
+ 64,
144
+ 0
145
+ ],
146
+ "<|extra_id_27|>": [
147
+ 231,
148
+ 128,
149
+ 64,
150
+ 0
151
+ ],
152
+ "<|extra_id_28|>": [
153
+ 232,
154
+ 128,
155
+ 64,
156
+ 0
157
+ ],
158
+ "<|extra_id_29|>": [
159
+ 233,
160
+ 128,
161
+ 64,
162
+ 0
163
+ ],
164
+ "<|extra_id_2|>": [
165
+ 206,
166
+ 128,
167
+ 64,
168
+ 0
169
+ ],
170
+ "<|extra_id_30|>": [
171
+ 234,
172
+ 128,
173
+ 64,
174
+ 0
175
+ ],
176
+ "<|extra_id_31|>": [
177
+ 235,
178
+ 128,
179
+ 64,
180
+ 0
181
+ ],
182
+ "<|extra_id_32|>": [
183
+ 236,
184
+ 128,
185
+ 64,
186
+ 0
187
+ ],
188
+ "<|extra_id_33|>": [
189
+ 237,
190
+ 128,
191
+ 64,
192
+ 0
193
+ ],
194
+ "<|extra_id_34|>": [
195
+ 238,
196
+ 128,
197
+ 64,
198
+ 0
199
+ ],
200
+ "<|extra_id_35|>": [
201
+ 239,
202
+ 128,
203
+ 64,
204
+ 0
205
+ ],
206
+ "<|extra_id_36|>": [
207
+ 240,
208
+ 128,
209
+ 64,
210
+ 0
211
+ ],
212
+ "<|extra_id_37|>": [
213
+ 241,
214
+ 128,
215
+ 64,
216
+ 0
217
+ ],
218
+ "<|extra_id_38|>": [
219
+ 242,
220
+ 128,
221
+ 64,
222
+ 0
223
+ ],
224
+ "<|extra_id_39|>": [
225
+ 243,
226
+ 128,
227
+ 64,
228
+ 0
229
+ ],
230
+ "<|extra_id_3|>": [
231
+ 207,
232
+ 128,
233
+ 64,
234
+ 0
235
+ ],
236
+ "<|extra_id_40|>": [
237
+ 244,
238
+ 128,
239
+ 64,
240
+ 0
241
+ ],
242
+ "<|extra_id_41|>": [
243
+ 245,
244
+ 128,
245
+ 64,
246
+ 0
247
+ ],
248
+ "<|extra_id_42|>": [
249
+ 246,
250
+ 128,
251
+ 64,
252
+ 0
253
+ ],
254
+ "<|extra_id_43|>": [
255
+ 247,
256
+ 128,
257
+ 64,
258
+ 0
259
+ ],
260
+ "<|extra_id_44|>": [
261
+ 248,
262
+ 128,
263
+ 64,
264
+ 0
265
+ ],
266
+ "<|extra_id_45|>": [
267
+ 249,
268
+ 128,
269
+ 64,
270
+ 0
271
+ ],
272
+ "<|extra_id_46|>": [
273
+ 250,
274
+ 128,
275
+ 64,
276
+ 0
277
+ ],
278
+ "<|extra_id_4|>": [
279
+ 208,
280
+ 128,
281
+ 64,
282
+ 0
283
+ ],
284
+ "<|extra_id_5|>": [
285
+ 209,
286
+ 128,
287
+ 64,
288
+ 0
289
+ ],
290
+ "<|extra_id_6|>": [
291
+ 210,
292
+ 128,
293
+ 64,
294
+ 0
295
+ ],
296
+ "<|extra_id_7|>": [
297
+ 211,
298
+ 128,
299
+ 64,
300
+ 0
301
+ ],
302
+ "<|extra_id_8|>": [
303
+ 212,
304
+ 128,
305
+ 64,
306
+ 0
307
+ ],
308
+ "<|extra_id_9|>": [
309
+ 213,
310
+ 128,
311
+ 64,
312
+ 0
313
+ ],
314
+ "<|mask|>": [
315
+ 197,
316
+ 128,
317
+ 64,
318
+ 0
319
+ ],
320
+ "<|pad|>": [
321
+ 192,
322
+ 128,
323
+ 64,
324
+ 0
325
+ ],
326
+ "<|sep|>": [
327
+ 196,
328
+ 128,
329
+ 64,
330
+ 0
331
+ ],
332
+ "<|start_header_id|>": [
333
+ 201,
334
+ 128,
335
+ 64,
336
+ 0
337
+ ],
338
+ "<|vision_br|>": [
339
+ 199,
340
+ 128,
341
+ 64,
342
+ 0
343
+ ],
344
+ "<|vision_end|>": [
345
+ 200,
346
+ 128,
347
+ 64,
348
+ 0
349
+ ],
350
+ "<|vision_start|>": [
351
+ 198,
352
+ 128,
353
+ 64,
354
+ 0
355
+ ]
356
+ }
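
All of the special-token entries above follow the same pattern: four byte ids of the form `[x, 128, 64, 0]`, with only the first byte differing between tokens. This appears consistent with `custom_generate/generate.py` below, which adds the offsets (192, 128, 64, 0) back onto the four byte slots of each predicted position. The snippet below is an observation-based sketch of that pattern, not documented behavior; the helper name is hypothetical.

```python
# Hypothetical helper: strip the per-slot offsets that the generation code adds
# back to raw byte predictions (next_tokens[:, 0::4] += 192, etc.).
SLOT_OFFSETS = (192, 128, 64, 0)

def strip_slot_offsets(token_ids):
    """Subtract the per-slot offsets from a multiple-of-4 byte-id sequence."""
    assert len(token_ids) % 4 == 0
    return [b - SLOT_OFFSETS[i % 4] for i, b in enumerate(token_ids)]

print(strip_slot_offsets([194, 128, 64, 0]))  # <|begin_of_text|> -> [2, 0, 0, 0]
```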
config.json ADDED
@@ -0,0 +1,40 @@
1
+ {
2
+ "_name_or_path": "None",
3
+ "architectures": [
4
+ "ByLlamaPatchForCausalLM"
5
+ ],
6
+ "attention_bias": false,
7
+ "attention_dropout": 0.0,
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_byllama_patch.ByLlamaPatchConfig",
10
+ "AutoModel": "modeling_byllama_patch.ByLlamaPatchModel",
11
+ "AutoModelForCausalLM": "modeling_byllama_patch.ByLlamaPatchForCausalLM"
12
+ },
13
+ "bos_token_id": 1,
14
+ "embedding_aggregator_type": "mean",
15
+ "eos_token_id": 2,
16
+ "head_dim": 128,
17
+ "hidden_act": "silu",
18
+ "hidden_size": 4096,
19
+ "initializer_range": 0.02,
20
+ "input_embedding_dim": 4096,
21
+ "intermediate_size": 13312,
22
+ "max_position_embeddings": 5760,
23
+ "mlp_bias": false,
24
+ "model_type": "byllama_patch",
25
+ "num_attention_heads": 32,
26
+ "num_hidden_layers": 32,
27
+ "num_key_value_heads": 8,
28
+ "num_lm_heads": 4,
29
+ "output_vocab_size": 256,
30
+ "pretraining_tp": 1,
31
+ "qkv_bias": false,
32
+ "rms_norm_eps": 1e-05,
33
+ "rope_scaling": null,
34
+ "rope_theta": 10000.0,
35
+ "tie_word_embeddings": false,
36
+ "torch_dtype": "float32",
37
+ "transformers_version": "4.46.3",
38
+ "use_cache": true,
39
+ "vocab_size": 256
40
+ }
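
Because of the `auto_map` entries above, the custom configuration class is resolved from this repository when loading with `trust_remote_code=True`. A minimal sketch (the printed values simply echo the config shown above):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "tohoku-nlp/bygpt-jp-multi-lm-head-6.5B-alpha",
    trust_remote_code=True,
)
print(config.model_type)     # "byllama_patch"
print(config.num_lm_heads)   # 4 byte-level LM heads per position
print(config.vocab_size)     # 256 -- one id per byte value
```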
configuration_byllama_patch.py ADDED
@@ -0,0 +1,218 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """LLaMA model configuration"""
21
+
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.modeling_rope_utils import rope_config_validation
24
+
25
+
26
+ class ByLlamaPatchConfig(PretrainedConfig):
27
+ r"""
28
+ This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
29
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
30
+ defaults will yield a similar configuration to that of the LLaMA-7B.
31
+
32
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
33
+ documentation from [`PretrainedConfig`] for more information.
34
+
35
+
36
+ Args:
37
+ vocab_size (`int`, *optional*, defaults to 32000):
38
+ Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
39
+ `inputs_ids` passed when calling [`LlamaModel`]
40
+ hidden_size (`int`, *optional*, defaults to 4096):
41
+ Dimension of the hidden representations.
42
+ intermediate_size (`int`, *optional*, defaults to 11008):
43
+ Dimension of the MLP representations.
44
+ num_hidden_layers (`int`, *optional*, defaults to 32):
45
+ Number of hidden layers in the Transformer decoder.
46
+ num_attention_heads (`int`, *optional*, defaults to 32):
47
+ Number of attention heads for each attention layer in the Transformer decoder.
48
+ num_key_value_heads (`int`, *optional*):
49
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
50
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
51
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
52
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
53
+ by meanpooling all the original heads within that group. For more details checkout [this
54
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
55
+ `num_attention_heads`.
56
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
57
+ The non-linear activation function (function or string) in the decoder.
58
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
59
+ The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
60
+ Llama 2 up to 4096, CodeLlama up to 16384.
61
+ initializer_range (`float`, *optional*, defaults to 0.02):
62
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
63
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
64
+ The epsilon used by the rms normalization layers.
65
+ use_cache (`bool`, *optional*, defaults to `True`):
66
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
67
+ relevant if `config.is_decoder=True`.
68
+ pad_token_id (`int`, *optional*):
69
+ Padding token id.
70
+ bos_token_id (`int`, *optional*, defaults to 1):
71
+ Beginning of stream token id.
72
+ eos_token_id (`int`, *optional*, defaults to 2):
73
+ End of stream token id.
74
+ pretraining_tp (`int`, *optional*, defaults to 1):
75
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
76
+ document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
77
+ understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
78
+ results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
79
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
80
+ Whether to tie weight embeddings
81
+ rope_theta (`float`, *optional*, defaults to 10000.0):
82
+ The base period of the RoPE embeddings.
83
+ rope_scaling (`Dict`, *optional*):
84
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
85
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
86
+ accordingly.
87
+ Expected contents:
88
+ `rope_type` (`str`):
89
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
90
+ 'llama3'], with 'default' being the original RoPE implementation.
91
+ `factor` (`float`, *optional*):
92
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
93
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
94
+ original maximum pre-trained length.
95
+ `original_max_position_embeddings` (`int`, *optional*):
96
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
97
+ pretraining.
98
+ `attention_factor` (`float`, *optional*):
99
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
100
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
101
+ `factor` field to infer the suggested value.
102
+ `beta_fast` (`float`, *optional*):
103
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
104
+ ramp function. If unspecified, it defaults to 32.
105
+ `beta_slow` (`float`, *optional*):
106
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
107
+ ramp function. If unspecified, it defaults to 1.
108
+ `short_factor` (`List[float]`, *optional*):
109
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
110
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
111
+ size divided by the number of attention heads divided by 2
112
+ `long_factor` (`List[float]`, *optional*):
113
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
114
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
115
+ size divided by the number of attention heads divided by 2
116
+ `low_freq_factor` (`float`, *optional*):
117
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
118
+ `high_freq_factor` (`float`, *optional*):
119
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
120
+ attention_bias (`bool`, *optional*, defaults to `False`):
121
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
122
+ attention_dropout (`float`, *optional*, defaults to 0.0):
123
+ The dropout ratio for the attention probabilities.
124
+ mlp_bias (`bool`, *optional*, defaults to `False`):
125
+ Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
126
+ head_dim (`int`, *optional*):
127
+ The attention head dimension. If None, it will default to hidden_size // num_heads
128
+
129
+ ```python
130
+ >>> from transformers import LlamaModel, LlamaConfig
131
+
132
+ >>> # Initializing a LLaMA llama-7b style configuration
133
+ >>> configuration = LlamaConfig()
134
+
135
+ >>> # Initializing a model from the llama-7b style configuration
136
+ >>> model = LlamaModel(configuration)
137
+
138
+ >>> # Accessing the model configuration
139
+ >>> configuration = model.config
140
+ ```"""
141
+
142
+ model_type = "byllama_patch"
143
+ keys_to_ignore_at_inference = ["past_key_values"]
144
+
145
+ def __init__(
146
+ self,
147
+ vocab_size=32000,
148
+ hidden_size=4096,
149
+ intermediate_size=11008,
150
+ num_hidden_layers=32,
151
+ num_attention_heads=32,
152
+ num_key_value_heads=None,
153
+ hidden_act="silu",
154
+ max_position_embeddings=2048,
155
+ initializer_range=0.02,
156
+ rms_norm_eps=1e-6,
157
+ use_cache=True,
158
+ pad_token_id=None,
159
+ bos_token_id=1,
160
+ eos_token_id=2,
161
+ pretraining_tp=1,
162
+ tie_word_embeddings=False,
163
+ rope_theta=10000.0,
164
+ rope_scaling=None,
165
+ attention_bias=False,
166
+ qkv_bias=False,
167
+ attention_dropout=0.0,
168
+ mlp_bias=False,
169
+ head_dim=None,
170
+ num_lm_heads=4,
171
+ embedding_aggregator_type="linear",
172
+ input_embedding_dim=None,
173
+ output_vocab_size=None,
174
+ **kwargs,
175
+ ):
176
+ self.vocab_size = vocab_size
177
+ self.max_position_embeddings = max_position_embeddings
178
+ self.hidden_size = hidden_size
179
+ self.intermediate_size = intermediate_size
180
+ self.num_hidden_layers = num_hidden_layers
181
+ self.num_attention_heads = num_attention_heads
182
+
183
+ # for backward compatibility
184
+ if num_key_value_heads is None:
185
+ num_key_value_heads = num_attention_heads
186
+
187
+ self.num_key_value_heads = num_key_value_heads
188
+ self.hidden_act = hidden_act
189
+ self.initializer_range = initializer_range
190
+ self.rms_norm_eps = rms_norm_eps
191
+ self.pretraining_tp = pretraining_tp
192
+ self.use_cache = use_cache
193
+ self.rope_theta = rope_theta
194
+ self.rope_scaling = rope_scaling
195
+ self.attention_bias = attention_bias
196
+ self.qkv_bias = qkv_bias
197
+ self.attention_dropout = attention_dropout
198
+ self.mlp_bias = mlp_bias
199
+ self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
200
+ # Validate the correctness of rotary position embeddings parameters
201
+ # BC: if there is a 'type' field, copy it to 'rope_type'.
202
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
203
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
204
+ rope_config_validation(self)
205
+
206
+ # Custom attribute
207
+ self.num_lm_heads = num_lm_heads
208
+ self.embedding_aggregator_type = embedding_aggregator_type
209
+ self.input_embedding_dim = input_embedding_dim if input_embedding_dim is not None else hidden_size
210
+ self.output_vocab_size = output_vocab_size if output_vocab_size is not None else vocab_size
211
+
212
+ super().__init__(
213
+ pad_token_id=pad_token_id,
214
+ bos_token_id=bos_token_id,
215
+ eos_token_id=eos_token_id,
216
+ tie_word_embeddings=tie_word_embeddings,
217
+ **kwargs,
218
+ )
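
On top of the standard Llama configuration fields, `ByLlamaPatchConfig` adds `num_lm_heads`, `embedding_aggregator_type`, `input_embedding_dim`, and `output_vocab_size`, with the latter two falling back to `hidden_size` and `vocab_size` when left unset. A small sketch, assuming the file is importable as a local module:

```python
from configuration_byllama_patch import ByLlamaPatchConfig

config = ByLlamaPatchConfig(
    vocab_size=256,                    # byte-level vocabulary
    num_lm_heads=4,                    # predict 4 bytes per position
    embedding_aggregator_type="mean",  # how the 4 byte embeddings are merged
)
# Unset custom fields fall back to the base values (hidden_size=4096 by default).
print(config.output_vocab_size, config.input_embedding_dim)  # 256 4096
```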
custom_generate/generate.py ADDED
@@ -0,0 +1,309 @@
1
+ from typing import Optional, Union, List
2
+ import inspect
3
+
4
+ import torch
5
+ from torch import nn
6
+ from transformers.cache_utils import Cache, DynamicCache
7
+ from transformers.generation.logits_process import LogitsProcessorList
8
+ from transformers.generation.stopping_criteria import (
9
+ StoppingCriteriaList,
10
+ EosTokenCriteria,
11
+ MaxLengthCriteria,
12
+ )
13
+ from transformers.generation.utils import logging
14
+
15
+ logger = logging.get_logger(__name__)
16
+
17
+
18
+ class EosTokenCriteriaForPatch(EosTokenCriteria):
19
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
20
+ is_done = torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)
21
+ for eos_token_ids in self.eos_token_id:
22
+ eos_token_length = eos_token_ids.shape[-1]
23
+ suffix = input_ids[:, -eos_token_length:]
24
+ is_done |= torch.all(suffix == eos_token_ids, dim=-1)
25
+ return is_done
26
+
27
+
28
+
29
+ def get_initial_cache_position(model, input_ids, model_kwargs):
30
+ """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
31
+
32
+ assert input_ids.size(1) % model.config.num_lm_heads == 0, "Input length must be divisible by num_lm_heads"
33
+ seq_len = input_ids.size(1) // model.config.num_lm_heads
34
+ cache_position = torch.ones(seq_len, dtype=torch.int64).cumsum(0) - 1
35
+
36
+ past_length = 0
37
+ if model_kwargs.get("past_key_values") is not None:
38
+ cache = model_kwargs["past_key_values"]
39
+ past_length = 0
40
+ past_length = cache.get_seq_length()
41
+ cache_position = cache_position[past_length:]
42
+
43
+ model_kwargs["cache_position"] = cache_position
44
+ return model_kwargs
45
+
46
+
47
+
48
+
49
+
50
+ def prepare_inputs_for_generation(
51
+ model,
52
+ input_ids: torch.LongTensor,
53
+ past_key_values: Optional[Cache] = None,
54
+ attention_mask: Optional[torch.LongTensor] = None,
55
+ cache_position: Optional[torch.LongTensor] = None,
56
+ **kwargs,
57
+ ):
58
+ """
59
+ Prepare the model inputs for generation. In includes operations like computing the 4D attention mask or
60
+ slicing inputs given the existing cache.
61
+
62
+ See the forward pass in the model documentation for expected arguments (different models might have different
63
+ requirements for e.g. `past_key_values`). This function should work as is for most LLMs.
64
+ """
65
+
66
+ # 1. Handle BC:
67
+ model_inputs = {
68
+ "cache_position": cache_position
69
+ }
70
+ assert input_ids.size(1) % model.config.num_lm_heads == 0, "Input length must be divisible by num_lm_heads"
71
+
72
+ # 2. Generic cache-dependent input preparation
73
+ if past_key_values is not None:
74
+ model_inputs["past_key_values"] = past_key_values
75
+ if input_ids.shape[1] != cache_position.shape[0] * model.config.num_lm_heads:
76
+ indices = torch.arange(
77
+ cache_position[0] * model.config.num_lm_heads,
78
+ (cache_position[-1] + 1) * model.config.num_lm_heads,
79
+ device=cache_position.device,
80
+ )
81
+ input_ids = input_ids[:, indices]
82
+
83
+ # 3. Prepare base model inputs
84
+ input_ids_key = "input_ids"
85
+ # `clone` calls in this function ensure a consistent stride. See #32227
86
+ model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
87
+ model_inputs["inputs_embeds"] = None
88
+
89
+ # 4. Create missing `position_ids` on the fly
90
+ if (
91
+ attention_mask is not None
92
+ and kwargs.get("position_ids") is None
93
+ and "position_ids" in set(inspect.signature(model.forward).parameters.keys())
94
+ ):
95
+ bsz = input_ids.size(0)
96
+ agregated_attention_mask = attention_mask.view(
97
+ bsz, -1, model.config.num_lm_heads
98
+ ).all(dim=-1)
99
+ position_ids = agregated_attention_mask.long().cumsum(-1) - 1
100
+ position_ids.masked_fill_(agregated_attention_mask == 0, 1)
101
+ kwargs["position_ids"] = position_ids
102
+
103
+ # 5. Slice model inputs if it's an input that should have the same length as `input_ids`
104
+ model_input = kwargs.get("position_ids")
105
+ if model_input is not None:
106
+ if past_key_values is not None:
107
+ current_input_length = (
108
+ model_inputs["inputs_embeds"].shape[1]
109
+ if model_inputs["inputs_embeds"] is not None
110
+ else model_inputs[input_ids_key].shape[1]
111
+ )
112
+ assert current_input_length % model.config.num_lm_heads == 0, "Input length must be divisible by num_lm_heads"
113
+ current_input_length //= model.config.num_lm_heads
114
+ model_input = model_input[:, -current_input_length:]
115
+ model_input = model_input.clone(memory_format=torch.contiguous_format)
116
+ model_inputs["position_ids"] = model_input
117
+
118
+ if attention_mask is not None:
119
+ model_inputs["attention_mask"] = attention_mask
120
+
121
+ # 6. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
122
+ for key, value in kwargs.items():
123
+ if key not in model_inputs:
124
+ model_inputs[key] = value
125
+
126
+ # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
127
+ model_inputs.pop("labels", None)
128
+ return model_inputs
129
+
130
+
131
+
132
+ def extract_past_from_model_output(outputs):
133
+ past_key_values = None
134
+ cache_name = "past_key_values"
135
+ if "past_key_values" in outputs:
136
+ past_key_values = outputs.past_key_values
137
+ elif "mems" in outputs:
138
+ past_key_values = outputs.mems
139
+ elif "past_buckets_states" in outputs:
140
+ past_key_values = outputs.past_buckets_states
141
+ elif "cache_params" in outputs:
142
+ past_key_values = outputs.cache_params
143
+ cache_name = "cache_params"
144
+
145
+ return cache_name, past_key_values
146
+
147
+
148
+ def update_model_kwargs_for_generation(
149
+ model,
150
+ outputs,
151
+ model_kwargs,
152
+ ):
153
+ # update past_key_values keeping its naming used in model code
154
+ cache_name, cache = extract_past_from_model_output(outputs)
155
+ model_kwargs[cache_name] = cache
156
+ if getattr(outputs, "state", None) is not None:
157
+ model_kwargs["state"] = outputs.state
158
+
159
+ # update attention mask
160
+ if "attention_mask" in model_kwargs:
161
+ attention_mask = model_kwargs["attention_mask"]
162
+ model_kwargs["attention_mask"] = torch.cat(
163
+ [
164
+ attention_mask,
165
+ attention_mask.new_ones(
166
+ (attention_mask.shape[0], model.config.num_lm_heads)
167
+ ),
168
+ ],
169
+ dim=-1
170
+ )
171
+
172
+ if model_kwargs.get("use_cache", True):
173
+ model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
174
+ else:
175
+ past_positions = model_kwargs.pop("cache_position")
176
+ new_positions = torch.arange(
177
+ past_positions[-1] + 1, (past_positions[-1] + 1) + 1, dtype=past_positions.dtype
178
+ ).to(past_positions.device)
179
+ model_kwargs["cache_position"] = torch.cat((past_positions, new_positions))
180
+ return model_kwargs
181
+
182
+
183
+
184
+ @torch.no_grad()
185
+ def generate(
186
+ model,
187
+ input_ids: torch.LongTensor,
188
+ attention_mask: Optional[torch.LongTensor] = None,
189
+ max_new_tokens: Optional[int] = None,
190
+ max_length: Optional[int] = None,
191
+ eos_token_id: Optional[Union[List[int], List[List[int]]]] = None,
192
+ pad_token_id: Optional[List[int]] = None,
193
+ do_sample: bool = False,
194
+ logits_processor: Optional[LogitsProcessorList] = None,
195
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
196
+ ):
197
+ if logits_processor is None:
198
+ logits_processor = LogitsProcessorList()
199
+
200
+
201
+ if stopping_criteria is None:
202
+ stopping_criteria = StoppingCriteriaList()
203
+
204
+ if eos_token_id is not None:
205
+ eos_ids_tensor = torch.tensor(
206
+ eos_token_id, device=input_ids.device, dtype=torch.long
207
+ )
208
+ if eos_ids_tensor.dim() == 1:
209
+ eos_ids_tensor = eos_ids_tensor.unsqueeze(0)
210
+ stopping_criteria.append(
211
+ EosTokenCriteriaForPatch(eos_token_id=eos_ids_tensor)
212
+ )
213
+ if pad_token_id is None:
214
+ pad_token_id = eos_ids_tensor[0].clone()
215
+ logger.warning_once(
216
+ f"Setting `pad_token_id` to `eos_token_id`: {eos_token_id[0]} for open-end generation."
217
+ )
218
+ else:
219
+ pad_token_id = torch.tensor(
220
+ pad_token_id,
221
+ device=input_ids.device,
222
+ dtype=torch.long
223
+ )
224
+
225
+
226
+
227
+ if max_new_tokens is not None:
228
+ if max_length is not None:
229
+ logger.warning_once(
230
+ "`max_length` is ignored when `max_new_tokens` is set."
231
+ )
232
+ max_length = input_ids.shape[-1] + max_new_tokens
233
+
234
+ if max_length is not None:
235
+ stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
236
+
237
+ has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
238
+
239
+ model_kwargs = {}
240
+ if attention_mask is not None:
241
+ model_kwargs["attention_mask"] = attention_mask
242
+ else:
243
+ model_kwargs["attention_mask"] = torch.ones_like(input_ids)
244
+
245
+ dynamic_cache_kwargs = {"config": model.config.get_text_config(decoder=True)}
246
+ model_kwargs["past_key_values"] = DynamicCache(**dynamic_cache_kwargs)
247
+
248
+
249
+
250
+ batch_size, cur_len = input_ids.shape
251
+ scores = ()
252
+ this_peer_finished = False
253
+ unfinished_sequences = torch.ones(
254
+ batch_size, 1, dtype=torch.long, device=input_ids.device
255
+ )
256
+ model_kwargs = get_initial_cache_position(model, input_ids, model_kwargs)
257
+
258
+ while not this_peer_finished:
259
+ # prepare model inputs
260
+ model_inputs = prepare_inputs_for_generation(model, input_ids, **model_kwargs)
261
+
262
+ # forward pass to get next token
263
+ outputs = model(**model_inputs, return_dict=True)
264
+
265
+ # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
266
+ model_kwargs = update_model_kwargs_for_generation(
267
+ model,
268
+ outputs,
269
+ model_kwargs,
270
+ )
271
+
272
+ # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
273
+ # (the clone itself is always small)
274
+ next_token_logits = outputs.logits.clone()[:, -model.config.num_lm_heads:, :].float()
275
+ next_token_logits = next_token_logits.to(input_ids.device)
276
+
277
+ # pre-process distribution
278
+ next_token_scores = logits_processor(input_ids, next_token_logits)
279
+ scores += (next_token_scores,)
280
+
281
+ # token selection
282
+ if do_sample:
283
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
284
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
285
+ else:
286
+ next_tokens = torch.argmax(next_token_scores, dim=-1)
287
+
288
+ next_tokens[:, 0::4] += 192
289
+ next_tokens[:, 1::4] += 128
290
+ next_tokens[:, 2::4] += 64
291
+
292
+ # finished sentences should have their next token be a padding token
293
+ if has_eos_stopping_criteria:
294
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
295
+
296
+ # update generated ids, model inputs, and length for next step
297
+ input_ids = torch.cat([input_ids, next_tokens], dim=-1)
298
+
299
+ unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores).unsqueeze(-1)
300
+ this_peer_finished = unfinished_sequences.max() == 0
301
+ cur_len += model.config.num_lm_heads
302
+
303
+ # This is needed to properly delete outputs.logits which may be very large for first iteration
304
+ # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
305
+ del outputs
306
+
307
+
308
+ return input_ids
309
+
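
To make the stopping logic above concrete, here is a small standalone rendition of the suffix check performed by `EosTokenCriteriaForPatch`: a sequence counts as finished once its generated ids end with any of the multi-byte EOS sequences. The byte values are arbitrary placeholders.

```python
import torch

eos_sequences = torch.tensor([[202, 138, 74, 10]])         # one 4-byte stop sequence
generated = torch.tensor([[1, 2, 3, 4, 202, 138, 74, 10],  # ends with the stop sequence
                          [1, 2, 3, 4, 5, 6, 7, 8]])       # does not

is_done = torch.zeros(generated.shape[0], dtype=torch.bool)
for eos in eos_sequences:
    suffix = generated[:, -eos.shape[-1]:]
    is_done |= torch.all(suffix == eos, dim=-1)
print(is_done)  # tensor([ True, False])
```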
generation_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "transformers_version": "4.46.3"
6
+ }
model-00001-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc100c11ee3bbc83898c3cde6ae4d4a7c7a3814f7c9fc2fd1e1b497082be0687
3
+ size 4936898720
model-00002-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7ce10f08a3cf75b5663a41b5f461210f32be001a2bf92b2a3f57a83dc02e798
3
+ size 4999813312
model-00003-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8c3bf3759af6cff566ea61a4fd4225e8be0c298ba861e02bff776cef9653dbe
3
+ size 4966259040
model-00004-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9017d7ad64ad5733ea3642a28034b006211f77232afbf6099ac5170df713c86a
3
+ size 4999813352
model-00005-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:26cd84a858ec63a0d4300659bd733b4c6651016e091b6f6f6680bc1ba72a293d
3
+ size 4932704376
model-00006-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:871f0c555486dbd142e189cc458a5cc3094ed885e22965eaab20f6dd3382c290
3
+ size 1480673040
model.safetensors.index.json ADDED
@@ -0,0 +1,298 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 26316128256
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "model-00006-of-00006.safetensors",
7
+ "model.embed_tokens.weight": "model-00001-of-00006.safetensors",
8
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00006.safetensors",
9
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00006.safetensors",
10
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00006.safetensors",
11
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00006.safetensors",
12
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00006.safetensors",
13
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00006.safetensors",
14
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00006.safetensors",
15
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00006.safetensors",
16
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00006.safetensors",
17
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00006.safetensors",
18
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00006.safetensors",
19
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00006.safetensors",
20
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00006.safetensors",
21
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00006.safetensors",
22
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00006.safetensors",
23
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00006.safetensors",
24
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00006.safetensors",
25
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00006.safetensors",
26
+ "model.layers.10.input_layernorm.weight": "model-00002-of-00006.safetensors",
27
+ "model.layers.10.mlp.down_proj.weight": "model-00002-of-00006.safetensors",
28
+ "model.layers.10.mlp.gate_proj.weight": "model-00002-of-00006.safetensors",
29
+ "model.layers.10.mlp.up_proj.weight": "model-00002-of-00006.safetensors",
30
+ "model.layers.10.post_attention_layernorm.weight": "model-00002-of-00006.safetensors",
31
+ "model.layers.10.self_attn.k_proj.weight": "model-00002-of-00006.safetensors",
32
+ "model.layers.10.self_attn.o_proj.weight": "model-00002-of-00006.safetensors",
33
+ "model.layers.10.self_attn.q_proj.weight": "model-00002-of-00006.safetensors",
34
+ "model.layers.10.self_attn.v_proj.weight": "model-00002-of-00006.safetensors",
35
+ "model.layers.11.input_layernorm.weight": "model-00002-of-00006.safetensors",
36
+ "model.layers.11.mlp.down_proj.weight": "model-00002-of-00006.safetensors",
37
+ "model.layers.11.mlp.gate_proj.weight": "model-00002-of-00006.safetensors",
38
+ "model.layers.11.mlp.up_proj.weight": "model-00002-of-00006.safetensors",
39
+ "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00006.safetensors",
40
+ "model.layers.11.self_attn.k_proj.weight": "model-00002-of-00006.safetensors",
41
+ "model.layers.11.self_attn.o_proj.weight": "model-00002-of-00006.safetensors",
42
+ "model.layers.11.self_attn.q_proj.weight": "model-00002-of-00006.safetensors",
43
+ "model.layers.11.self_attn.v_proj.weight": "model-00002-of-00006.safetensors",
44
+ "model.layers.12.input_layernorm.weight": "model-00003-of-00006.safetensors",
45
+ "model.layers.12.mlp.down_proj.weight": "model-00003-of-00006.safetensors",
46
+ "model.layers.12.mlp.gate_proj.weight": "model-00003-of-00006.safetensors",
47
+ "model.layers.12.mlp.up_proj.weight": "model-00003-of-00006.safetensors",
48
+ "model.layers.12.post_attention_layernorm.weight": "model-00003-of-00006.safetensors",
49
+ "model.layers.12.self_attn.k_proj.weight": "model-00003-of-00006.safetensors",
50
+ "model.layers.12.self_attn.o_proj.weight": "model-00003-of-00006.safetensors",
51
+ "model.layers.12.self_attn.q_proj.weight": "model-00002-of-00006.safetensors",
52
+ "model.layers.12.self_attn.v_proj.weight": "model-00003-of-00006.safetensors",
53
+ "model.layers.13.input_layernorm.weight": "model-00003-of-00006.safetensors",
54
+ "model.layers.13.mlp.down_proj.weight": "model-00003-of-00006.safetensors",
55
+ "model.layers.13.mlp.gate_proj.weight": "model-00003-of-00006.safetensors",
56
+ "model.layers.13.mlp.up_proj.weight": "model-00003-of-00006.safetensors",
57
+ "model.layers.13.post_attention_layernorm.weight": "model-00003-of-00006.safetensors",
58
+ "model.layers.13.self_attn.k_proj.weight": "model-00003-of-00006.safetensors",
59
+ "model.layers.13.self_attn.o_proj.weight": "model-00003-of-00006.safetensors",
60
+ "model.layers.13.self_attn.q_proj.weight": "model-00003-of-00006.safetensors",
61
+ "model.layers.13.self_attn.v_proj.weight": "model-00003-of-00006.safetensors",
62
+ "model.layers.14.input_layernorm.weight": "model-00003-of-00006.safetensors",
63
+ "model.layers.14.mlp.down_proj.weight": "model-00003-of-00006.safetensors",
64
+ "model.layers.14.mlp.gate_proj.weight": "model-00003-of-00006.safetensors",
65
+ "model.layers.14.mlp.up_proj.weight": "model-00003-of-00006.safetensors",
66
+ "model.layers.14.post_attention_layernorm.weight": "model-00003-of-00006.safetensors",
67
+ "model.layers.14.self_attn.k_proj.weight": "model-00003-of-00006.safetensors",
68
+ "model.layers.14.self_attn.o_proj.weight": "model-00003-of-00006.safetensors",
69
+ "model.layers.14.self_attn.q_proj.weight": "model-00003-of-00006.safetensors",
70
+ "model.layers.14.self_attn.v_proj.weight": "model-00003-of-00006.safetensors",
71
+ "model.layers.15.input_layernorm.weight": "model-00003-of-00006.safetensors",
72
+ "model.layers.15.mlp.down_proj.weight": "model-00003-of-00006.safetensors",
73
+ "model.layers.15.mlp.gate_proj.weight": "model-00003-of-00006.safetensors",
74
+ "model.layers.15.mlp.up_proj.weight": "model-00003-of-00006.safetensors",
75
+ "model.layers.15.post_attention_layernorm.weight": "model-00003-of-00006.safetensors",
76
+ "model.layers.15.self_attn.k_proj.weight": "model-00003-of-00006.safetensors",
77
+ "model.layers.15.self_attn.o_proj.weight": "model-00003-of-00006.safetensors",
78
+ "model.layers.15.self_attn.q_proj.weight": "model-00003-of-00006.safetensors",
79
+ "model.layers.15.self_attn.v_proj.weight": "model-00003-of-00006.safetensors",
80
+ "model.layers.16.input_layernorm.weight": "model-00003-of-00006.safetensors",
81
+ "model.layers.16.mlp.down_proj.weight": "model-00003-of-00006.safetensors",
82
+ "model.layers.16.mlp.gate_proj.weight": "model-00003-of-00006.safetensors",
83
+ "model.layers.16.mlp.up_proj.weight": "model-00003-of-00006.safetensors",
84
+ "model.layers.16.post_attention_layernorm.weight": "model-00003-of-00006.safetensors",
85
+ "model.layers.16.self_attn.k_proj.weight": "model-00003-of-00006.safetensors",
86
+ "model.layers.16.self_attn.o_proj.weight": "model-00003-of-00006.safetensors",
87
+ "model.layers.16.self_attn.q_proj.weight": "model-00003-of-00006.safetensors",
88
+ "model.layers.16.self_attn.v_proj.weight": "model-00003-of-00006.safetensors",
89
+ "model.layers.17.input_layernorm.weight": "model-00003-of-00006.safetensors",
90
+ "model.layers.17.mlp.down_proj.weight": "model-00003-of-00006.safetensors",
91
+ "model.layers.17.mlp.gate_proj.weight": "model-00003-of-00006.safetensors",
92
+ "model.layers.17.mlp.up_proj.weight": "model-00003-of-00006.safetensors",
93
+ "model.layers.17.post_attention_layernorm.weight": "model-00003-of-00006.safetensors",
94
+ "model.layers.17.self_attn.k_proj.weight": "model-00003-of-00006.safetensors",
95
+ "model.layers.17.self_attn.o_proj.weight": "model-00003-of-00006.safetensors",
96
+ "model.layers.17.self_attn.q_proj.weight": "model-00003-of-00006.safetensors",
97
+ "model.layers.17.self_attn.v_proj.weight": "model-00003-of-00006.safetensors",
98
+ "model.layers.18.input_layernorm.weight": "model-00004-of-00006.safetensors",
99
+ "model.layers.18.mlp.down_proj.weight": "model-00004-of-00006.safetensors",
100
+ "model.layers.18.mlp.gate_proj.weight": "model-00004-of-00006.safetensors",
101
+ "model.layers.18.mlp.up_proj.weight": "model-00004-of-00006.safetensors",
102
+ "model.layers.18.post_attention_layernorm.weight": "model-00004-of-00006.safetensors",
103
+ "model.layers.18.self_attn.k_proj.weight": "model-00003-of-00006.safetensors",
104
+ "model.layers.18.self_attn.o_proj.weight": "model-00004-of-00006.safetensors",
105
+ "model.layers.18.self_attn.q_proj.weight": "model-00003-of-00006.safetensors",
106
+ "model.layers.18.self_attn.v_proj.weight": "model-00003-of-00006.safetensors",
107
+ "model.layers.19.input_layernorm.weight": "model-00004-of-00006.safetensors",
108
+ "model.layers.19.mlp.down_proj.weight": "model-00004-of-00006.safetensors",
109
+ "model.layers.19.mlp.gate_proj.weight": "model-00004-of-00006.safetensors",
110
+ "model.layers.19.mlp.up_proj.weight": "model-00004-of-00006.safetensors",
111
+ "model.layers.19.post_attention_layernorm.weight": "model-00004-of-00006.safetensors",
112
+ "model.layers.19.self_attn.k_proj.weight": "model-00004-of-00006.safetensors",
113
+ "model.layers.19.self_attn.o_proj.weight": "model-00004-of-00006.safetensors",
114
+ "model.layers.19.self_attn.q_proj.weight": "model-00004-of-00006.safetensors",
115
+ "model.layers.19.self_attn.v_proj.weight": "model-00004-of-00006.safetensors",
116
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00006.safetensors",
117
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00006.safetensors",
118
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00006.safetensors",
119
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00006.safetensors",
120
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00006.safetensors",
121
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00006.safetensors",
122
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00006.safetensors",
123
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00006.safetensors",
124
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00006.safetensors",
125
+ "model.layers.20.input_layernorm.weight": "model-00004-of-00006.safetensors",
126
+ "model.layers.20.mlp.down_proj.weight": "model-00004-of-00006.safetensors",
127
+ "model.layers.20.mlp.gate_proj.weight": "model-00004-of-00006.safetensors",
128
+ "model.layers.20.mlp.up_proj.weight": "model-00004-of-00006.safetensors",
129
+ "model.layers.20.post_attention_layernorm.weight": "model-00004-of-00006.safetensors",
130
+ "model.layers.20.self_attn.k_proj.weight": "model-00004-of-00006.safetensors",
131
+ "model.layers.20.self_attn.o_proj.weight": "model-00004-of-00006.safetensors",
132
+ "model.layers.20.self_attn.q_proj.weight": "model-00004-of-00006.safetensors",
133
+ "model.layers.20.self_attn.v_proj.weight": "model-00004-of-00006.safetensors",
134
+ "model.layers.21.input_layernorm.weight": "model-00004-of-00006.safetensors",
135
+ "model.layers.21.mlp.down_proj.weight": "model-00004-of-00006.safetensors",
136
+ "model.layers.21.mlp.gate_proj.weight": "model-00004-of-00006.safetensors",
137
+ "model.layers.21.mlp.up_proj.weight": "model-00004-of-00006.safetensors",
138
+ "model.layers.21.post_attention_layernorm.weight": "model-00004-of-00006.safetensors",
139
+ "model.layers.21.self_attn.k_proj.weight": "model-00004-of-00006.safetensors",
140
+ "model.layers.21.self_attn.o_proj.weight": "model-00004-of-00006.safetensors",
141
+ "model.layers.21.self_attn.q_proj.weight": "model-00004-of-00006.safetensors",
142
+ "model.layers.21.self_attn.v_proj.weight": "model-00004-of-00006.safetensors",
143
+ "model.layers.22.input_layernorm.weight": "model-00004-of-00006.safetensors",
144
+ "model.layers.22.mlp.down_proj.weight": "model-00004-of-00006.safetensors",
145
+ "model.layers.22.mlp.gate_proj.weight": "model-00004-of-00006.safetensors",
146
+ "model.layers.22.mlp.up_proj.weight": "model-00004-of-00006.safetensors",
147
+ "model.layers.22.post_attention_layernorm.weight": "model-00004-of-00006.safetensors",
148
+ "model.layers.22.self_attn.k_proj.weight": "model-00004-of-00006.safetensors",
149
+ "model.layers.22.self_attn.o_proj.weight": "model-00004-of-00006.safetensors",
150
+ "model.layers.22.self_attn.q_proj.weight": "model-00004-of-00006.safetensors",
151
+ "model.layers.22.self_attn.v_proj.weight": "model-00004-of-00006.safetensors",
152
+ "model.layers.23.input_layernorm.weight": "model-00004-of-00006.safetensors",
153
+ "model.layers.23.mlp.down_proj.weight": "model-00004-of-00006.safetensors",
154
+ "model.layers.23.mlp.gate_proj.weight": "model-00004-of-00006.safetensors",
155
+ "model.layers.23.mlp.up_proj.weight": "model-00004-of-00006.safetensors",
156
+ "model.layers.23.post_attention_layernorm.weight": "model-00004-of-00006.safetensors",
157
+ "model.layers.23.self_attn.k_proj.weight": "model-00004-of-00006.safetensors",
158
+ "model.layers.23.self_attn.o_proj.weight": "model-00004-of-00006.safetensors",
159
+ "model.layers.23.self_attn.q_proj.weight": "model-00004-of-00006.safetensors",
160
+ "model.layers.23.self_attn.v_proj.weight": "model-00004-of-00006.safetensors",
161
+ "model.layers.24.input_layernorm.weight": "model-00005-of-00006.safetensors",
162
+ "model.layers.24.mlp.down_proj.weight": "model-00005-of-00006.safetensors",
163
+ "model.layers.24.mlp.gate_proj.weight": "model-00005-of-00006.safetensors",
164
+ "model.layers.24.mlp.up_proj.weight": "model-00005-of-00006.safetensors",
165
+ "model.layers.24.post_attention_layernorm.weight": "model-00005-of-00006.safetensors",
166
+ "model.layers.24.self_attn.k_proj.weight": "model-00004-of-00006.safetensors",
167
+ "model.layers.24.self_attn.o_proj.weight": "model-00004-of-00006.safetensors",
168
+ "model.layers.24.self_attn.q_proj.weight": "model-00004-of-00006.safetensors",
169
+ "model.layers.24.self_attn.v_proj.weight": "model-00004-of-00006.safetensors",
170
+ "model.layers.25.input_layernorm.weight": "model-00005-of-00006.safetensors",
171
+ "model.layers.25.mlp.down_proj.weight": "model-00005-of-00006.safetensors",
172
+ "model.layers.25.mlp.gate_proj.weight": "model-00005-of-00006.safetensors",
173
+ "model.layers.25.mlp.up_proj.weight": "model-00005-of-00006.safetensors",
174
+ "model.layers.25.post_attention_layernorm.weight": "model-00005-of-00006.safetensors",
175
+ "model.layers.25.self_attn.k_proj.weight": "model-00005-of-00006.safetensors",
176
+ "model.layers.25.self_attn.o_proj.weight": "model-00005-of-00006.safetensors",
177
+ "model.layers.25.self_attn.q_proj.weight": "model-00005-of-00006.safetensors",
178
+ "model.layers.25.self_attn.v_proj.weight": "model-00005-of-00006.safetensors",
179
+ "model.layers.26.input_layernorm.weight": "model-00005-of-00006.safetensors",
180
+ "model.layers.26.mlp.down_proj.weight": "model-00005-of-00006.safetensors",
181
+ "model.layers.26.mlp.gate_proj.weight": "model-00005-of-00006.safetensors",
182
+ "model.layers.26.mlp.up_proj.weight": "model-00005-of-00006.safetensors",
183
+ "model.layers.26.post_attention_layernorm.weight": "model-00005-of-00006.safetensors",
184
+ "model.layers.26.self_attn.k_proj.weight": "model-00005-of-00006.safetensors",
185
+ "model.layers.26.self_attn.o_proj.weight": "model-00005-of-00006.safetensors",
186
+ "model.layers.26.self_attn.q_proj.weight": "model-00005-of-00006.safetensors",
187
+ "model.layers.26.self_attn.v_proj.weight": "model-00005-of-00006.safetensors",
188
+ "model.layers.27.input_layernorm.weight": "model-00005-of-00006.safetensors",
189
+ "model.layers.27.mlp.down_proj.weight": "model-00005-of-00006.safetensors",
190
+ "model.layers.27.mlp.gate_proj.weight": "model-00005-of-00006.safetensors",
191
+ "model.layers.27.mlp.up_proj.weight": "model-00005-of-00006.safetensors",
192
+ "model.layers.27.post_attention_layernorm.weight": "model-00005-of-00006.safetensors",
193
+ "model.layers.27.self_attn.k_proj.weight": "model-00005-of-00006.safetensors",
194
+ "model.layers.27.self_attn.o_proj.weight": "model-00005-of-00006.safetensors",
195
+ "model.layers.27.self_attn.q_proj.weight": "model-00005-of-00006.safetensors",
196
+ "model.layers.27.self_attn.v_proj.weight": "model-00005-of-00006.safetensors",
197
+ "model.layers.28.input_layernorm.weight": "model-00005-of-00006.safetensors",
198
+ "model.layers.28.mlp.down_proj.weight": "model-00005-of-00006.safetensors",
199
+ "model.layers.28.mlp.gate_proj.weight": "model-00005-of-00006.safetensors",
200
+ "model.layers.28.mlp.up_proj.weight": "model-00005-of-00006.safetensors",
201
+ "model.layers.28.post_attention_layernorm.weight": "model-00005-of-00006.safetensors",
202
+ "model.layers.28.self_attn.k_proj.weight": "model-00005-of-00006.safetensors",
203
+ "model.layers.28.self_attn.o_proj.weight": "model-00005-of-00006.safetensors",
204
+ "model.layers.28.self_attn.q_proj.weight": "model-00005-of-00006.safetensors",
205
+ "model.layers.28.self_attn.v_proj.weight": "model-00005-of-00006.safetensors",
206
+ "model.layers.29.input_layernorm.weight": "model-00005-of-00006.safetensors",
207
+ "model.layers.29.mlp.down_proj.weight": "model-00005-of-00006.safetensors",
208
+ "model.layers.29.mlp.gate_proj.weight": "model-00005-of-00006.safetensors",
209
+ "model.layers.29.mlp.up_proj.weight": "model-00005-of-00006.safetensors",
210
+ "model.layers.29.post_attention_layernorm.weight": "model-00005-of-00006.safetensors",
211
+ "model.layers.29.self_attn.k_proj.weight": "model-00005-of-00006.safetensors",
212
+ "model.layers.29.self_attn.o_proj.weight": "model-00005-of-00006.safetensors",
213
+ "model.layers.29.self_attn.q_proj.weight": "model-00005-of-00006.safetensors",
214
+ "model.layers.29.self_attn.v_proj.weight": "model-00005-of-00006.safetensors",
215
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00006.safetensors",
216
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00006.safetensors",
217
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00006.safetensors",
218
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00006.safetensors",
219
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00006.safetensors",
220
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00006.safetensors",
221
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00006.safetensors",
222
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00006.safetensors",
223
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00006.safetensors",
224
+ "model.layers.30.input_layernorm.weight": "model-00006-of-00006.safetensors",
225
+ "model.layers.30.mlp.down_proj.weight": "model-00006-of-00006.safetensors",
226
+ "model.layers.30.mlp.gate_proj.weight": "model-00006-of-00006.safetensors",
227
+ "model.layers.30.mlp.up_proj.weight": "model-00006-of-00006.safetensors",
228
+ "model.layers.30.post_attention_layernorm.weight": "model-00006-of-00006.safetensors",
229
+ "model.layers.30.self_attn.k_proj.weight": "model-00005-of-00006.safetensors",
230
+ "model.layers.30.self_attn.o_proj.weight": "model-00005-of-00006.safetensors",
231
+ "model.layers.30.self_attn.q_proj.weight": "model-00005-of-00006.safetensors",
232
+ "model.layers.30.self_attn.v_proj.weight": "model-00005-of-00006.safetensors",
233
+ "model.layers.31.input_layernorm.weight": "model-00006-of-00006.safetensors",
234
+ "model.layers.31.mlp.down_proj.weight": "model-00006-of-00006.safetensors",
235
+ "model.layers.31.mlp.gate_proj.weight": "model-00006-of-00006.safetensors",
236
+ "model.layers.31.mlp.up_proj.weight": "model-00006-of-00006.safetensors",
237
+ "model.layers.31.post_attention_layernorm.weight": "model-00006-of-00006.safetensors",
238
+ "model.layers.31.self_attn.k_proj.weight": "model-00006-of-00006.safetensors",
239
+ "model.layers.31.self_attn.o_proj.weight": "model-00006-of-00006.safetensors",
240
+ "model.layers.31.self_attn.q_proj.weight": "model-00006-of-00006.safetensors",
241
+ "model.layers.31.self_attn.v_proj.weight": "model-00006-of-00006.safetensors",
242
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00006.safetensors",
243
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00006.safetensors",
244
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00006.safetensors",
245
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00006.safetensors",
246
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00006.safetensors",
247
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00006.safetensors",
248
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00006.safetensors",
249
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00006.safetensors",
250
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00006.safetensors",
251
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00006.safetensors",
252
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00006.safetensors",
253
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00006.safetensors",
254
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00006.safetensors",
255
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00006.safetensors",
256
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00006.safetensors",
257
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00006.safetensors",
258
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00006.safetensors",
259
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00006.safetensors",
260
+ "model.layers.6.input_layernorm.weight": "model-00002-of-00006.safetensors",
261
+ "model.layers.6.mlp.down_proj.weight": "model-00002-of-00006.safetensors",
262
+ "model.layers.6.mlp.gate_proj.weight": "model-00002-of-00006.safetensors",
263
+ "model.layers.6.mlp.up_proj.weight": "model-00002-of-00006.safetensors",
264
+ "model.layers.6.post_attention_layernorm.weight": "model-00002-of-00006.safetensors",
265
+ "model.layers.6.self_attn.k_proj.weight": "model-00002-of-00006.safetensors",
266
+ "model.layers.6.self_attn.o_proj.weight": "model-00002-of-00006.safetensors",
267
+ "model.layers.6.self_attn.q_proj.weight": "model-00002-of-00006.safetensors",
268
+ "model.layers.6.self_attn.v_proj.weight": "model-00002-of-00006.safetensors",
269
+ "model.layers.7.input_layernorm.weight": "model-00002-of-00006.safetensors",
270
+ "model.layers.7.mlp.down_proj.weight": "model-00002-of-00006.safetensors",
271
+ "model.layers.7.mlp.gate_proj.weight": "model-00002-of-00006.safetensors",
272
+ "model.layers.7.mlp.up_proj.weight": "model-00002-of-00006.safetensors",
273
+ "model.layers.7.post_attention_layernorm.weight": "model-00002-of-00006.safetensors",
274
+ "model.layers.7.self_attn.k_proj.weight": "model-00002-of-00006.safetensors",
275
+ "model.layers.7.self_attn.o_proj.weight": "model-00002-of-00006.safetensors",
276
+ "model.layers.7.self_attn.q_proj.weight": "model-00002-of-00006.safetensors",
277
+ "model.layers.7.self_attn.v_proj.weight": "model-00002-of-00006.safetensors",
278
+ "model.layers.8.input_layernorm.weight": "model-00002-of-00006.safetensors",
279
+ "model.layers.8.mlp.down_proj.weight": "model-00002-of-00006.safetensors",
280
+ "model.layers.8.mlp.gate_proj.weight": "model-00002-of-00006.safetensors",
281
+ "model.layers.8.mlp.up_proj.weight": "model-00002-of-00006.safetensors",
282
+ "model.layers.8.post_attention_layernorm.weight": "model-00002-of-00006.safetensors",
283
+ "model.layers.8.self_attn.k_proj.weight": "model-00002-of-00006.safetensors",
284
+ "model.layers.8.self_attn.o_proj.weight": "model-00002-of-00006.safetensors",
285
+ "model.layers.8.self_attn.q_proj.weight": "model-00002-of-00006.safetensors",
286
+ "model.layers.8.self_attn.v_proj.weight": "model-00002-of-00006.safetensors",
287
+ "model.layers.9.input_layernorm.weight": "model-00002-of-00006.safetensors",
288
+ "model.layers.9.mlp.down_proj.weight": "model-00002-of-00006.safetensors",
289
+ "model.layers.9.mlp.gate_proj.weight": "model-00002-of-00006.safetensors",
290
+ "model.layers.9.mlp.up_proj.weight": "model-00002-of-00006.safetensors",
291
+ "model.layers.9.post_attention_layernorm.weight": "model-00002-of-00006.safetensors",
292
+ "model.layers.9.self_attn.k_proj.weight": "model-00002-of-00006.safetensors",
293
+ "model.layers.9.self_attn.o_proj.weight": "model-00002-of-00006.safetensors",
294
+ "model.layers.9.self_attn.q_proj.weight": "model-00002-of-00006.safetensors",
295
+ "model.layers.9.self_attn.v_proj.weight": "model-00002-of-00006.safetensors",
296
+ "model.norm.weight": "model-00006-of-00006.safetensors"
297
+ }
298
+ }
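The `weight_map` above assigns every parameter tensor to one of six safetensors shards, and a single layer can straddle two shards (for example, layer 24's attention projections live in `model-00004-of-00006.safetensors` while its layer norms and MLP live in `model-00005-of-00006.safetensors`). Below is a minimal sketch for sanity-checking the index, assuming the file has been downloaded from this repository into the working directory:

```python
import json
from collections import Counter

# Load the index shown in the diff above.
with open("model.safetensors.index.json", "r", encoding="utf-8") as f:
    index = json.load(f)

# weight_map maps each parameter name to the shard file that stores it.
shard_counts = Counter(index["weight_map"].values())
for shard, n_tensors in sorted(shard_counts.items()):
    print(f"{shard}: {n_tensors} tensors")
```

`from_pretrained` in `transformers` resolves this index automatically, so a check like this is only needed when inspecting or repacking the shards by hand.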
modeling_byllama_patch.py ADDED
@@ -0,0 +1,1299 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ import math
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+
28
+ from transformers.activations import ACT2FN
29
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
30
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
31
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
32
+ from transformers.modeling_outputs import (
33
+ BaseModelOutputWithPast,
34
+ CausalLMOutputWithPast,
35
+ )
36
+ from transformers.generation import GenerationMixin
37
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
38
+ from transformers.modeling_utils import PreTrainedModel
39
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
40
+ from transformers.utils import (
41
+ add_start_docstrings,
42
+ add_start_docstrings_to_model_forward,
43
+ is_flash_attn_greater_or_equal_2_10,
44
+ logging,
45
+ replace_return_docstrings,
46
+ )
47
+
48
+ from .configuration_byllama_patch import ByLlamaPatchConfig
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ _CHECKPOINT_FOR_DOC = "meta-llama/Llama-2-7b-hf"
53
+ _CONFIG_FOR_DOC = "ByLlamaPatchConfig"
54
+
55
+
56
+
57
+ class MeanEmbeddingAggregator(nn.Module):
58
+ def __init__(self, config: ByLlamaPatchConfig):
59
+ super().__init__()
60
+ self.config = config
61
+ self.group_size = config.num_lm_heads
62
+ assert config.hidden_size % self.group_size == 0, "hidden_size must be divisible by group_size"
63
+ assert config.hidden_size == config.input_embedding_dim, "hidden_size must be equal to input_embedding_dim"
64
+
65
+ def forward(self, embeddings: torch.Tensor):
66
+ """
67
+ Args:
68
+ embeddings (:obj:`torch.Tensor`): The embeddings to aggregate. Should have shape `(batch_size, seq_len, hidden_size)`.
69
+ Output:
70
+ group_embeddings (:obj:`torch.Tensor`): The aggregated embeddings. Should have shape `(batch_size, seq_len // group_size, hidden_size)`.
71
+ """
72
+ batch_size, seq_len, hidden_size = embeddings.size()
73
+ group_embeddings = embeddings.view(batch_size, seq_len // self.group_size, self.group_size, hidden_size)
74
+ group_embeddings = group_embeddings.sum(dim=2)
75
+ return group_embeddings
76
+
77
+
78
+ class LinearEmbeddingAggregator(nn.Module):
79
+ def __init__(self, config: ByLlamaPatchConfig):
80
+ super().__init__()
81
+ self.config = config
82
+ self.group_size = config.num_lm_heads
83
+ self.linear = nn.Linear(
84
+ config.input_embedding_dim * self.group_size,
85
+ config.hidden_size,
86
+ bias=False
87
+ )
88
+ assert config.hidden_size % self.group_size == 0, "hidden_size must be divisible by group_size"
89
+
90
+ def forward(self, embeddings: torch.Tensor):
91
+ """
92
+ Args:
93
+ embeddings (:obj:`torch.Tensor`): The embeddings to aggregate. Should have shape `(batch_size, seq_len, input_embedding_dim)`.
94
+ Output:
95
+ group_embeddings (:obj:`torch.Tensor`): The aggregated embeddings. Should have shape `(batch_size, seq_len // group_size, hidden_size)`.
96
+ """
97
+ batch_size, seq_len, input_embedding_dim = embeddings.size()
98
+ assert seq_len % self.group_size == 0, "seq_len must be divisible by group_size"
99
+ group_embeddings = embeddings.view(batch_size, seq_len // self.group_size, self.group_size * input_embedding_dim)
100
+ return self.linear(group_embeddings)
101
+
102
+
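# Illustrative sketch (not from this file): with num_lm_heads = 4, as described in the model card,
# the aggregators above fold every 4 consecutive byte embeddings into a single patch vector before
# the decoder layers run. The toy sizes below are assumptions chosen only to show the reshaping.
import torch

batch, n_bytes, dim = 2, 12, 8
emb = torch.randn(batch, n_bytes, dim)
# MeanEmbeddingAggregator path: sum each group of 4 byte embeddings -> (2, 3, 8)
pooled = emb.view(batch, n_bytes // 4, 4, dim).sum(dim=2)
# LinearEmbeddingAggregator path: concatenate each group of 4 -> (2, 3, 32),
# which a Linear(4 * dim, hidden_size) then projects back to the model width.
concat = emb.view(batch, n_bytes // 4, 4 * dim)
print(pooled.shape, concat.shape)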
103
+ class LlamaRMSNorm(nn.Module):
104
+ def __init__(self, hidden_size, eps=1e-6):
105
+ """
106
+ LlamaRMSNorm is equivalent to T5LayerNorm
107
+ """
108
+ super().__init__()
109
+ self.weight = nn.Parameter(torch.ones(hidden_size))
110
+ self.variance_epsilon = eps
111
+
112
+ def forward(self, hidden_states):
113
+ input_dtype = hidden_states.dtype
114
+ hidden_states = hidden_states.to(torch.float32)
115
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
116
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
117
+ return self.weight * hidden_states.to(input_dtype)
118
+
119
+ def extra_repr(self):
120
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
121
+
122
+
123
+ ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
124
+
125
+
126
+ class LlamaRotaryEmbedding(nn.Module):
127
+ def __init__(
128
+ self,
129
+ dim=None,
130
+ max_position_embeddings=2048,
131
+ base=10000,
132
+ device=None,
133
+ scaling_factor=1.0,
134
+ rope_type="default",
135
+ config: Optional[ByLlamaPatchConfig] = None,
136
+ ):
137
+ super().__init__()
138
+ # TODO (joao): remove the `if` below, only used for BC
139
+ self.rope_kwargs = {}
140
+ if config is None:
141
+ logger.warning_once(
142
+ "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the "
143
+ "`config` argument. All other arguments will be removed in v4.46"
144
+ )
145
+ self.rope_kwargs = {
146
+ "rope_type": rope_type,
147
+ "factor": scaling_factor,
148
+ "dim": dim,
149
+ "base": base,
150
+ "max_position_embeddings": max_position_embeddings,
151
+ }
152
+ self.rope_type = rope_type
153
+ self.max_seq_len_cached = max_position_embeddings
154
+ self.original_max_seq_len = max_position_embeddings
155
+ else:
156
+ # BC: "rope_type" was originally "type"
157
+ if config.rope_scaling is not None:
158
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
159
+ else:
160
+ self.rope_type = "default"
161
+ self.max_seq_len_cached = config.max_position_embeddings
162
+ self.original_max_seq_len = config.max_position_embeddings
163
+
164
+ self.config = config
165
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
166
+
167
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
168
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
169
+ self.original_inv_freq = self.inv_freq
170
+
171
+ def _dynamic_frequency_update(self, position_ids, device):
172
+ """
173
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
174
+ 1 - growing beyond the cached sequence length (allow scaling)
175
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
176
+ """
177
+ seq_len = torch.max(position_ids) + 1
178
+ if seq_len > self.max_seq_len_cached: # growth
179
+ inv_freq, self.attention_scaling = self.rope_init_fn(
180
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
181
+ )
182
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
183
+ self.max_seq_len_cached = seq_len
184
+
185
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
186
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
187
+ self.max_seq_len_cached = self.original_max_seq_len
188
+
189
+ @torch.no_grad()
190
+ def forward(self, x, position_ids):
191
+ if "dynamic" in self.rope_type:
192
+ self._dynamic_frequency_update(position_ids, device=x.device)
193
+
194
+ # Core RoPE block
195
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
196
+ position_ids_expanded = position_ids[:, None, :].float()
197
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
198
+ device_type = x.device.type
199
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
200
+ with torch.autocast(device_type=device_type, enabled=False):
201
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
202
+ emb = torch.cat((freqs, freqs), dim=-1)
203
+ cos = emb.cos()
204
+ sin = emb.sin()
205
+
206
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
207
+ cos = cos * self.attention_scaling
208
+ sin = sin * self.attention_scaling
209
+
210
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
211
+
212
+
213
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
214
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
215
+
216
+ def __init__(self, *args, **kwargs):
217
+ logger.warning_once(
218
+ "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
219
+ "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
220
+ )
221
+ kwargs["rope_type"] = "linear"
222
+ super().__init__(*args, **kwargs)
223
+
224
+
225
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
226
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
227
+
228
+ def __init__(self, *args, **kwargs):
229
+ logger.warning_once(
230
+ "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
231
+ "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
232
+ "__init__)."
233
+ )
234
+ kwargs["rope_type"] = "dynamic"
235
+ super().__init__(*args, **kwargs)
236
+
237
+
238
+ def rotate_half(x):
239
+ """Rotates half the hidden dims of the input."""
240
+ x1 = x[..., : x.shape[-1] // 2]
241
+ x2 = x[..., x.shape[-1] // 2 :]
242
+ return torch.cat((-x2, x1), dim=-1)
243
+
244
+
245
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
246
+ """Applies Rotary Position Embedding to the query and key tensors.
247
+
248
+ Args:
249
+ q (`torch.Tensor`): The query tensor.
250
+ k (`torch.Tensor`): The key tensor.
251
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
252
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
253
+ position_ids (`torch.Tensor`, *optional*):
254
+ Deprecated and unused.
255
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
256
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
257
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
258
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
259
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
260
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
261
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
262
+ Returns:
263
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
264
+ """
265
+ cos = cos.unsqueeze(unsqueeze_dim)
266
+ sin = sin.unsqueeze(unsqueeze_dim)
267
+ q_embed = (q * cos) + (rotate_half(q) * sin)
268
+ k_embed = (k * cos) + (rotate_half(k) * sin)
269
+ return q_embed, k_embed
270
+
271
+
272
+ class LlamaMLP(nn.Module):
273
+ def __init__(self, config):
274
+ super().__init__()
275
+ self.config = config
276
+ self.hidden_size = config.hidden_size
277
+ self.intermediate_size = config.intermediate_size
278
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
279
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
280
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
281
+ self.act_fn = ACT2FN[config.hidden_act]
282
+
283
+ def forward(self, x):
284
+ if self.config.pretraining_tp > 1:
285
+ slice = self.intermediate_size // self.config.pretraining_tp
286
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
287
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
288
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
289
+
290
+ gate_proj = torch.cat(
291
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
292
+ )
293
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
294
+
295
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
296
+ down_proj = [
297
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
298
+ ]
299
+ down_proj = sum(down_proj)
300
+ else:
301
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
302
+
303
+ return down_proj
304
+
305
+
306
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
307
+ """
308
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
309
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
310
+ """
311
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
312
+ if n_rep == 1:
313
+ return hidden_states
314
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
315
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
316
+
317
+
318
+ class LlamaAttention(nn.Module):
319
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
320
+
321
+ def __init__(self, config: ByLlamaPatchConfig, layer_idx: Optional[int] = None):
322
+ super().__init__()
323
+ self.config = config
324
+ self.layer_idx = layer_idx
325
+ if layer_idx is None:
326
+ logger.warning_once(
327
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
328
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
329
+ "when creating this class."
330
+ )
331
+
332
+ self.attention_dropout = config.attention_dropout
333
+ self.hidden_size = config.hidden_size
334
+ self.num_heads = config.num_attention_heads
335
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
336
+ self.num_key_value_heads = config.num_key_value_heads
337
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
338
+ self.max_position_embeddings = config.max_position_embeddings
339
+ self.rope_theta = config.rope_theta
340
+ self.is_causal = True
341
+
342
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias or getattr(self.config, "qkv_bias", False))
343
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias or getattr(self.config, "qkv_bias", False))
344
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias or getattr(self.config, "qkv_bias", False))
345
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
346
+
347
+ # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
348
+ self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
349
+
350
+ def forward(
351
+ self,
352
+ hidden_states: torch.Tensor,
353
+ attention_mask: Optional[torch.Tensor] = None,
354
+ position_ids: Optional[torch.LongTensor] = None,
355
+ past_key_value: Optional[Cache] = None,
356
+ output_attentions: bool = False,
357
+ use_cache: bool = False,
358
+ cache_position: Optional[torch.LongTensor] = None,
359
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
360
+ **kwargs,
361
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
362
+ bsz, q_len, _ = hidden_states.size()
363
+
364
+ if self.config.pretraining_tp > 1:
365
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
366
+ query_slices = self.q_proj.weight.split(
367
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
368
+ )
369
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
370
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
371
+
372
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
373
+ query_states = torch.cat(query_states, dim=-1)
374
+
375
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
376
+ key_states = torch.cat(key_states, dim=-1)
377
+
378
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
379
+ value_states = torch.cat(value_states, dim=-1)
380
+
381
+ else:
382
+ query_states = self.q_proj(hidden_states)
383
+ key_states = self.k_proj(hidden_states)
384
+ value_states = self.v_proj(hidden_states)
385
+
386
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
387
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
388
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
389
+
390
+ if position_embeddings is None:
391
+ logger.warning_once(
392
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
393
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
394
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
395
+ "removed and `position_embeddings` will be mandatory."
396
+ )
397
+ cos, sin = self.rotary_emb(value_states, position_ids)
398
+ else:
399
+ cos, sin = position_embeddings
400
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
401
+
402
+ if past_key_value is not None:
403
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
404
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
405
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
406
+
407
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
408
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
409
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
410
+
411
+ if attention_mask is not None: # no matter the length, we just slice it
412
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
413
+ attn_weights = attn_weights + causal_mask
414
+
415
+ # upcast attention to fp32
416
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
417
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
418
+ attn_output = torch.matmul(attn_weights, value_states)
419
+
420
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
421
+ raise ValueError(
422
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
423
+ f" {attn_output.size()}"
424
+ )
425
+
426
+ attn_output = attn_output.transpose(1, 2).contiguous()
427
+
428
+ attn_output = attn_output.reshape(bsz, q_len, -1)
429
+
430
+ if self.config.pretraining_tp > 1:
431
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
432
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
433
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
434
+ else:
435
+ attn_output = self.o_proj(attn_output)
436
+
437
+ if not output_attentions:
438
+ attn_weights = None
439
+
440
+ return attn_output, attn_weights, past_key_value
441
+
442
+
443
+ class LlamaFlashAttention2(LlamaAttention):
444
+ """
445
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stay
446
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
447
+ flash attention and deal with padding tokens in case the input contains any of them.
448
+ """
449
+
450
+ def __init__(self, *args, **kwargs):
451
+ super().__init__(*args, **kwargs)
452
+
453
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
454
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
455
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
456
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
457
+
458
+ def forward(
459
+ self,
460
+ hidden_states: torch.Tensor,
461
+ attention_mask: Optional[torch.LongTensor] = None,
462
+ position_ids: Optional[torch.LongTensor] = None,
463
+ past_key_value: Optional[Cache] = None,
464
+ output_attentions: bool = False,
465
+ use_cache: bool = False,
466
+ cache_position: Optional[torch.LongTensor] = None,
467
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
468
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
469
+ if isinstance(past_key_value, StaticCache):
470
+ raise ValueError(
471
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
472
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
473
+ )
474
+
475
+ output_attentions = False
476
+
477
+ bsz, q_len, _ = hidden_states.size()
478
+
479
+ query_states = self.q_proj(hidden_states)
480
+ key_states = self.k_proj(hidden_states)
481
+ value_states = self.v_proj(hidden_states)
482
+
483
+ # Flash attention requires the input to have the shape
484
+ # batch_size x seq_length x num_heads x head_dim
485
+ # therefore we just need to keep the original shape
486
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
487
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
488
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
489
+
490
+ if position_embeddings is None:
491
+ logger.warning_once(
492
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
493
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
494
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
495
+ "removed and `position_embeddings` will be mandatory."
496
+ )
497
+ cos, sin = self.rotary_emb(value_states, position_ids)
498
+ else:
499
+ cos, sin = position_embeddings
500
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
501
+
502
+ if past_key_value is not None:
503
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
504
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
505
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
506
+
507
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
508
+ # to be able to avoid many of these transpose/reshape/view.
509
+ query_states = query_states.transpose(1, 2)
510
+ key_states = key_states.transpose(1, 2)
511
+ value_states = value_states.transpose(1, 2)
512
+
513
+ dropout_rate = self.attention_dropout if self.training else 0.0
514
+
515
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
516
+ # therefore the input hidden states get silently cast in float32. Hence, we need to
517
+ # cast them back in the correct dtype just to be sure everything works as expected.
518
+ # This might slow down training & inference so it is recommended to not cast the LayerNorms
519
+ # in fp32. (LlamaRMSNorm handles it correctly)
520
+
521
+ input_dtype = query_states.dtype
522
+ if input_dtype == torch.float32:
523
+ if torch.is_autocast_enabled():
524
+ target_dtype = torch.get_autocast_gpu_dtype()
525
+ # Handle the case where the model is quantized
526
+ elif hasattr(self.config, "_pre_quantization_dtype"):
527
+ target_dtype = self.config._pre_quantization_dtype
528
+ else:
529
+ target_dtype = self.q_proj.weight.dtype
530
+
531
+ logger.warning_once(
532
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
533
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
534
+ f" {target_dtype}."
535
+ )
536
+
537
+ query_states = query_states.to(target_dtype)
538
+ key_states = key_states.to(target_dtype)
539
+ value_states = value_states.to(target_dtype)
540
+
541
+ attn_output = _flash_attention_forward(
542
+ query_states,
543
+ key_states,
544
+ value_states,
545
+ attention_mask,
546
+ q_len,
547
+ position_ids=position_ids,
548
+ dropout=dropout_rate,
549
+ sliding_window=getattr(self, "sliding_window", None),
550
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
551
+ is_causal=self.is_causal,
552
+ )
553
+
554
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
555
+ attn_output = self.o_proj(attn_output)
556
+
557
+ if not output_attentions:
558
+ attn_weights = None
559
+
560
+ return attn_output, attn_weights, past_key_value
561
+
562
+
563
+ class LlamaSdpaAttention(LlamaAttention):
564
+ """
565
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
566
+ `LlamaAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
567
+ SDPA API.
568
+ """
569
+
570
+ # Adapted from LlamaAttention.forward
571
+ def forward(
572
+ self,
573
+ hidden_states: torch.Tensor,
574
+ attention_mask: Optional[torch.Tensor] = None,
575
+ position_ids: Optional[torch.LongTensor] = None,
576
+ past_key_value: Optional[Cache] = None,
577
+ output_attentions: bool = False,
578
+ use_cache: bool = False,
579
+ cache_position: Optional[torch.LongTensor] = None,
580
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
581
+ **kwargs,
582
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
583
+ if output_attentions:
584
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
585
+ logger.warning_once(
586
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
587
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
588
+ )
589
+ return super().forward(
590
+ hidden_states=hidden_states,
591
+ attention_mask=attention_mask,
592
+ position_ids=position_ids,
593
+ past_key_value=past_key_value,
594
+ output_attentions=output_attentions,
595
+ use_cache=use_cache,
596
+ cache_position=cache_position,
597
+ position_embeddings=position_embeddings,
598
+ )
599
+
600
+ bsz, q_len, _ = hidden_states.size()
601
+
602
+ query_states = self.q_proj(hidden_states)
603
+ key_states = self.k_proj(hidden_states)
604
+ value_states = self.v_proj(hidden_states)
605
+
606
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
607
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
608
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
609
+
610
+ if position_embeddings is None:
611
+ logger.warning_once(
612
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
613
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
614
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
615
+ "removed and `position_embeddings` will be mandatory."
616
+ )
617
+ cos, sin = self.rotary_emb(value_states, position_ids)
618
+ else:
619
+ cos, sin = position_embeddings
620
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
621
+
622
+ if past_key_value is not None:
623
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
624
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
625
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
626
+
627
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
628
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
629
+
630
+ causal_mask = attention_mask
631
+ if attention_mask is not None:
632
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
633
+
634
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
635
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
636
+ if query_states.device.type == "cuda" and causal_mask is not None:
637
+ query_states = query_states.contiguous()
638
+ key_states = key_states.contiguous()
639
+ value_states = value_states.contiguous()
640
+
641
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
642
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
643
+ is_causal = True if causal_mask is None and q_len > 1 else False
644
+
645
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
646
+ query_states,
647
+ key_states,
648
+ value_states,
649
+ attn_mask=causal_mask,
650
+ dropout_p=self.attention_dropout if self.training else 0.0,
651
+ is_causal=is_causal,
652
+ )
653
+
654
+ attn_output = attn_output.transpose(1, 2).contiguous()
655
+ attn_output = attn_output.view(bsz, q_len, -1)
656
+
657
+ attn_output = self.o_proj(attn_output)
658
+
659
+ return attn_output, None, past_key_value
660
+
661
+
662
+ LLAMA_ATTENTION_CLASSES = {
663
+ "eager": LlamaAttention,
664
+ "flash_attention_2": LlamaFlashAttention2,
665
+ "sdpa": LlamaSdpaAttention,
666
+ }
667
+
668
+
669
+ class LlamaDecoderLayer(nn.Module):
670
+ def __init__(self, config: ByLlamaPatchConfig, layer_idx: int):
671
+ super().__init__()
672
+ self.hidden_size = config.hidden_size
673
+
674
+ self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
675
+
676
+ self.mlp = LlamaMLP(config)
677
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
678
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
679
+
680
+ def forward(
681
+ self,
682
+ hidden_states: torch.Tensor,
683
+ attention_mask: Optional[torch.Tensor] = None,
684
+ position_ids: Optional[torch.LongTensor] = None,
685
+ past_key_value: Optional[Cache] = None,
686
+ output_attentions: Optional[bool] = False,
687
+ use_cache: Optional[bool] = False,
688
+ cache_position: Optional[torch.LongTensor] = None,
689
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
690
+ **kwargs,
691
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
692
+ """
693
+ Args:
694
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
695
+ attention_mask (`torch.FloatTensor`, *optional*):
696
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
697
+ query_sequence_length, key_sequence_length)` if default attention is used.
698
+ output_attentions (`bool`, *optional*):
699
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
700
+ returned tensors for more detail.
701
+ use_cache (`bool`, *optional*):
702
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
703
+ (see `past_key_values`).
704
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
705
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
706
+ Indices depicting the position of the input sequence tokens in the sequence
707
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
708
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
709
+ with `head_dim` being the embedding dimension of each attention head.
710
+ kwargs (`dict`, *optional*):
711
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
712
+ into the model
713
+ """
714
+ residual = hidden_states
715
+
716
+ hidden_states = self.input_layernorm(hidden_states)
717
+
718
+ # Self Attention
719
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
720
+ hidden_states=hidden_states,
721
+ attention_mask=attention_mask,
722
+ position_ids=position_ids,
723
+ past_key_value=past_key_value,
724
+ output_attentions=output_attentions,
725
+ use_cache=use_cache,
726
+ cache_position=cache_position,
727
+ position_embeddings=position_embeddings,
728
+ **kwargs,
729
+ )
730
+ hidden_states = residual + hidden_states
731
+
732
+ # Fully Connected
733
+ residual = hidden_states
734
+ hidden_states = self.post_attention_layernorm(hidden_states)
735
+ hidden_states = self.mlp(hidden_states)
736
+ hidden_states = residual + hidden_states
737
+
738
+ outputs = (hidden_states,)
739
+
740
+ if output_attentions:
741
+ outputs += (self_attn_weights,)
742
+
743
+ if use_cache:
744
+ outputs += (present_key_value,)
745
+
746
+ return outputs
747
+
748
+
749
+ LLAMA_START_DOCSTRING = r"""
750
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
751
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
752
+ etc.)
753
+
754
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
755
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
756
+ and behavior.
757
+
758
+ Parameters:
759
+ config ([`ByLlamaPatchConfig`]):
760
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
761
+ load the weights associated with the model, only the configuration. Check out the
762
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
763
+ """
764
+
765
+
766
+ @add_start_docstrings(
767
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
768
+ LLAMA_START_DOCSTRING,
769
+ )
770
+ class LlamaPreTrainedModel(PreTrainedModel):
771
+ config_class = ByLlamaPatchConfig
772
+ base_model_prefix = "model"
773
+ supports_gradient_checkpointing = True
774
+ _no_split_modules = ["LlamaDecoderLayer"]
775
+ _skip_keys_device_placement = ["past_key_values"]
776
+ _supports_flash_attn_2 = True
777
+ _supports_sdpa = True
778
+ _supports_cache_class = True
779
+ _supports_quantized_cache = True
780
+ _supports_static_cache = True
781
+
782
+ def _init_weights(self, module):
783
+ std = self.config.initializer_range
784
+ if isinstance(module, nn.Linear):
785
+ module.weight.data.normal_(mean=0.0, std=std)
786
+ if module.bias is not None:
787
+ module.bias.data.zero_()
788
+ elif isinstance(module, nn.Embedding):
789
+ module.weight.data.normal_(mean=0.0, std=std)
790
+ if module.padding_idx is not None:
791
+ module.weight.data[module.padding_idx].zero_()
792
+
793
+
794
+ LLAMA_INPUTS_DOCSTRING = r"""
795
+ Args:
796
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
797
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
798
+ it.
799
+
800
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
801
+ [`PreTrainedTokenizer.__call__`] for details.
802
+
803
+ [What are input IDs?](../glossary#input-ids)
804
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
805
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
806
+
807
+ - 1 for tokens that are **not masked**,
808
+ - 0 for tokens that are **masked**.
809
+
810
+ [What are attention masks?](../glossary#attention-mask)
811
+
812
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
813
+ [`PreTrainedTokenizer.__call__`] for details.
814
+
815
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
816
+ `past_key_values`).
817
+
818
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
819
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
820
+ information on the default strategy.
821
+
822
+ - 1 indicates the head is **not masked**,
823
+ - 0 indicates the head is **masked**.
824
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
825
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
826
+ config.n_positions - 1]`.
827
+
828
+ [What are position IDs?](../glossary#position-ids)
829
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
830
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
831
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
832
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
833
+
834
+ Two formats are allowed:
835
+ - a [`~cache_utils.Cache`] instance, see our
836
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
837
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
838
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
839
+ cache format.
840
+
841
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
842
+ legacy cache format will be returned.
843
+
844
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
845
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
846
+ of shape `(batch_size, sequence_length)`.
847
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
848
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
849
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
850
+ model's internal embedding lookup matrix.
851
+ use_cache (`bool`, *optional*):
852
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
853
+ `past_key_values`).
854
+ output_attentions (`bool`, *optional*):
855
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
856
+ tensors for more detail.
857
+ output_hidden_states (`bool`, *optional*):
858
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
859
+ more detail.
860
+ return_dict (`bool`, *optional*):
861
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
862
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
863
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
864
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
865
+ the complete sequence length.
866
+ """
867
+
868
+
869
+ @add_start_docstrings(
870
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
871
+ LLAMA_START_DOCSTRING,
872
+ )
873
+ class ByLlamaPatchModel(LlamaPreTrainedModel):
874
+ """
875
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
876
+
877
+ Args:
878
+ config: ByLlamaPatchConfig
879
+ """
880
+
881
+ def __init__(self, config: ByLlamaPatchConfig):
882
+ super().__init__(config)
883
+ self.vocab_size = config.vocab_size
884
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.input_embedding_dim)
885
+
886
+ if config.embedding_aggregator_type == "mean":
887
+ self.embedding_aggregator = MeanEmbeddingAggregator(config)
888
+ elif config.embedding_aggregator_type == "linear":
889
+ self.embedding_aggregator = LinearEmbeddingAggregator(config)
890
+ else:
891
+ raise ValueError(
892
+ f"Invalid `embedding_aggregator_type` in config: {config.embedding_aggregator_type}. "
893
+ )
894
+
895
+ self.layers = nn.ModuleList(
896
+ [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
897
+ )
898
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
899
+ self.rotary_emb = LlamaRotaryEmbedding(config=config)
900
+ self.gradient_checkpointing = False
901
+
902
+ # Initialize weights and apply final processing
903
+ self.post_init()
904
+
905
+ def get_input_embeddings(self):
906
+ return self.embed_tokens
907
+
908
+ def set_input_embeddings(self, value):
909
+ self.embed_tokens = value
910
+
911
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
912
+ def forward(
913
+ self,
914
+ input_ids: torch.LongTensor = None,
915
+ attention_mask: Optional[torch.Tensor] = None,
916
+ position_ids: Optional[torch.LongTensor] = None,
917
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
918
+ inputs_embeds: Optional[torch.FloatTensor] = None,
919
+ use_cache: Optional[bool] = None,
920
+ output_attentions: Optional[bool] = None,
921
+ output_hidden_states: Optional[bool] = None,
922
+ return_dict: Optional[bool] = None,
923
+ cache_position: Optional[torch.LongTensor] = None,
924
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
925
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
926
+ output_hidden_states = (
927
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
928
+ )
929
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
930
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
931
+ # Aggregate attention mask
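+ # (collapse to one entry per patch of `num_lm_heads` byte positions via `.all(dim=-1)`)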
932
+ if attention_mask is not None:
933
+ attention_mask = torch.ones_like(
934
+ input_ids, dtype=attention_mask.dtype, device=attention_mask.device
935
+ )
936
+ bsz = input_ids.size(0)
937
+ attention_mask = attention_mask.view(bsz, -1, self.config.num_lm_heads).all(dim=-1)
938
+
939
+
940
+ if (input_ids is None) ^ (inputs_embeds is not None):
941
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
942
+
943
+ if self.gradient_checkpointing and self.training and use_cache:
944
+ logger.warning_once(
945
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
946
+ )
947
+ use_cache = False
948
+
949
+ if inputs_embeds is None:
950
+ inputs_embeds = self.embed_tokens(input_ids)
951
+ inputs_embeds = self.embedding_aggregator(inputs_embeds)
952
+
953
+ # kept for BC (non `Cache` `past_key_values` inputs)
954
+ return_legacy_cache = False
955
+ if use_cache and not isinstance(past_key_values, Cache):
956
+ return_legacy_cache = True
957
+ if past_key_values is None:
958
+ past_key_values = DynamicCache()
959
+ else:
960
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
961
+ logger.warning_once(
962
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
963
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
964
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
965
+ )
966
+
967
+ if cache_position is None:
968
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
969
+ cache_position = torch.arange(
970
+ past_seen_tokens,
971
+ past_seen_tokens + inputs_embeds.shape[1],
972
+ device=inputs_embeds.device
973
+ )
974
+ if position_ids is None:
975
+ position_ids = cache_position.unsqueeze(0)
976
+
977
+ causal_mask = self._update_causal_mask(
978
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
979
+ )
980
+ hidden_states = inputs_embeds
981
+
982
+ # create position embeddings to be shared across the decoder layers
983
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
984
+
985
+ # decoder layers
986
+ all_hidden_states = () if output_hidden_states else None
987
+ all_self_attns = () if output_attentions else None
988
+ next_decoder_cache = None
989
+
990
+ for decoder_layer in self.layers:
991
+ if output_hidden_states:
992
+ all_hidden_states += (hidden_states,)
993
+
994
+ if self.gradient_checkpointing and self.training:
995
+ layer_outputs = self._gradient_checkpointing_func(
996
+ decoder_layer.__call__,
997
+ hidden_states,
998
+ causal_mask,
999
+ position_ids,
1000
+ past_key_values,
1001
+ output_attentions,
1002
+ use_cache,
1003
+ cache_position,
1004
+ position_embeddings,
1005
+ )
1006
+ else:
1007
+ layer_outputs = decoder_layer(
1008
+ hidden_states,
1009
+ attention_mask=causal_mask,
1010
+ position_ids=position_ids,
1011
+ past_key_value=past_key_values,
1012
+ output_attentions=output_attentions,
1013
+ use_cache=use_cache,
1014
+ cache_position=cache_position,
1015
+ position_embeddings=position_embeddings,
1016
+ )
1017
+
1018
+ hidden_states = layer_outputs[0]
1019
+
1020
+ if use_cache:
1021
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1022
+
1023
+ if output_attentions:
1024
+ all_self_attns += (layer_outputs[1],)
1025
+
1026
+ hidden_states = self.norm(hidden_states)
1027
+
1028
+ # add hidden states from the last decoder layer
1029
+ if output_hidden_states:
1030
+ all_hidden_states += (hidden_states,)
1031
+
1032
+ next_cache = next_decoder_cache if use_cache else None
1033
+ if return_legacy_cache:
1034
+ next_cache = next_cache.to_legacy_cache()
1035
+
1036
+ if not return_dict:
1037
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1038
+ return BaseModelOutputWithPast(
1039
+ last_hidden_state=hidden_states,
1040
+ past_key_values=next_cache,
1041
+ hidden_states=all_hidden_states,
1042
+ attentions=all_self_attns,
1043
+ )
1044
+
1045
+ def _update_causal_mask(
1046
+ self,
1047
+ attention_mask: torch.Tensor,
1048
+ input_tensor: torch.Tensor,
1049
+ cache_position: torch.Tensor,
1050
+ past_key_values: Cache,
1051
+ output_attentions: bool,
1052
+ ):
1053
+ if self.config._attn_implementation == "flash_attention_2":
1054
+ if attention_mask is not None and 0.0 in attention_mask:
1055
+ return attention_mask
1056
+ return None
1057
+
1058
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1059
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1060
+ # to infer the attention mask.
1061
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1062
+ using_static_cache = isinstance(past_key_values, StaticCache)
1063
+
1064
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
1065
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
1066
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1067
+ attention_mask,
1068
+ inputs_embeds=input_tensor,
1069
+ past_key_values_length=past_seen_tokens,
1070
+ is_training=self.training,
1071
+ ):
1072
+ return None
1073
+
1074
+ dtype, device = input_tensor.dtype, input_tensor.device
1075
+ sequence_length = input_tensor.shape[1]
1076
+ if using_static_cache:
1077
+ target_length = past_key_values.get_max_cache_shape()
1078
+ else:
1079
+ target_length = (
1080
+ attention_mask.shape[-1]
1081
+ if isinstance(attention_mask, torch.Tensor)
1082
+ else past_seen_tokens + sequence_length + 1
1083
+ )
1084
+
1085
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
1086
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
1087
+ attention_mask,
1088
+ sequence_length=sequence_length,
1089
+ target_length=target_length,
1090
+ dtype=dtype,
1091
+ device=device,
1092
+ cache_position=cache_position,
1093
+ batch_size=input_tensor.shape[0],
1094
+ )
1095
+
1096
+ if (
1097
+ self.config._attn_implementation == "sdpa"
1098
+ and attention_mask is not None
1099
+ and attention_mask.device.type == "cuda"
1100
+ and not output_attentions
1101
+ ):
1102
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1103
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1104
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1105
+ min_dtype = torch.finfo(dtype).min
1106
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1107
+
1108
+ return causal_mask
1109
+
1110
+ @staticmethod
1111
+ def _prepare_4d_causal_attention_mask_with_cache_position(
1112
+ attention_mask: torch.Tensor,
1113
+ sequence_length: int,
1114
+ target_length: int,
1115
+ dtype: torch.dtype,
1116
+ device: torch.device,
1117
+ cache_position: torch.Tensor,
1118
+ batch_size: int,
1119
+ **kwargs,
1120
+ ):
1121
+ """
1122
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1123
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
1124
+
1125
+ Args:
1126
+ attention_mask (`torch.Tensor`):
1127
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
1128
+ `(batch_size, 1, query_length, key_value_length)`.
1129
+ sequence_length (`int`):
1130
+ The sequence length being processed.
1131
+ target_length (`int`):
1132
+ The target length: when generating with static cache, the mask should be as long as the static cache,
1133
+ to account for the 0 padding, the part of the cache that is not filled yet.
1134
+ dtype (`torch.dtype`):
1135
+ The dtype to use for the 4D attention mask.
1136
+ device (`torch.device`):
1137
+ The device to place the 4D attention mask on.
1138
+ cache_position (`torch.Tensor`):
1139
+ Indices depicting the position of the input sequence tokens in the sequence.
1140
+ batch_size (`int`):
1141
+ Batch size.
1142
+ """
1143
+ if attention_mask is not None and attention_mask.dim() == 4:
1144
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
1145
+ causal_mask = attention_mask
1146
+ else:
1147
+ min_dtype = torch.finfo(dtype).min
1148
+ causal_mask = torch.full(
1149
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
1150
+ )
1151
+ if sequence_length != 1:
1152
+ causal_mask = torch.triu(causal_mask, diagonal=1)
1153
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1154
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
1155
+ if attention_mask is not None:
1156
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1157
+ mask_length = attention_mask.shape[-1]
1158
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
1159
+ padding_mask = padding_mask == 0
1160
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
1161
+ padding_mask, min_dtype
1162
+ )
1163
+
1164
+ return causal_mask
1165
+
1166
+
1167
+ class ByLlamaPatchForCausalLM(LlamaPreTrainedModel, GenerationMixin):
1168
+ _tied_weights_keys = ["lm_head.weight"]
1169
+
1170
+ def __init__(self, config):
1171
+ super().__init__(config)
1172
+ self.model = ByLlamaPatchModel(config)
1173
+ self.vocab_size = config.vocab_size
1174
+ self.lm_head = nn.Linear(
1175
+ config.hidden_size, config.output_vocab_size, bias=False
1176
+ )
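+ # `output_vocab_size` is assumed to stack the `num_lm_heads` per-byte distributions side by
+ # side; `forward` reshapes the logits back to one byte distribution per predicted position.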
1177
+ # Initialize weights and apply final processing
1178
+ self.post_init()
1179
+
1180
+ def get_input_embeddings(self):
1181
+ return self.model.embed_tokens
1182
+
1183
+ def set_input_embeddings(self, value):
1184
+ self.model.embed_tokens = value
1185
+
1186
+ def get_output_embeddings(self):
1187
+ return self.lm_head
1188
+
1189
+ def set_output_embeddings(self, new_embeddings):
1190
+ self.lm_head = new_embeddings
1191
+
1192
+ def set_decoder(self, decoder):
1193
+ self.model = decoder
1194
+
1195
+ def get_decoder(self):
1196
+ return self.model
1197
+
1198
+ @classmethod
1199
+ def can_generate(cls) -> bool:
1200
+ return True
1201
+
1202
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1203
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1204
+ def forward(
1205
+ self,
1206
+ input_ids: torch.LongTensor = None,
1207
+ attention_mask: Optional[torch.Tensor] = None,
1208
+ position_ids: Optional[torch.LongTensor] = None,
1209
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1210
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1211
+ labels: Optional[torch.LongTensor] = None,
1212
+ use_cache: Optional[bool] = None,
1213
+ output_attentions: Optional[bool] = None,
1214
+ output_hidden_states: Optional[bool] = None,
1215
+ return_dict: Optional[bool] = None,
1216
+ cache_position: Optional[torch.LongTensor] = None,
1217
+ num_logits_to_keep: int = 0,
1218
+ **loss_kwargs,
1219
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1220
+ r"""
1221
+ Args:
1222
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1223
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1224
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1225
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1226
+
1227
+ num_logits_to_keep (`int`, *optional*):
1228
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
1229
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
1230
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
1231
+
1232
+ Returns:
1233
+
1234
+ Example:
1235
+
1236
+ ```python
1237
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
1238
+
1239
+ >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
1240
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
1241
+
1242
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1243
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1244
+
1245
+ >>> # Generate
1246
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1247
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1248
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1249
+ ```"""
1250
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1251
+ output_hidden_states = (
1252
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1253
+ )
1254
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1255
+
1256
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1257
+ outputs = self.model(
1258
+ input_ids=input_ids,
1259
+ attention_mask=attention_mask,
1260
+ position_ids=position_ids,
1261
+ past_key_values=past_key_values,
1262
+ inputs_embeds=inputs_embeds,
1263
+ use_cache=use_cache,
1264
+ output_attentions=output_attentions,
1265
+ output_hidden_states=output_hidden_states,
1266
+ return_dict=return_dict,
1267
+ cache_position=cache_position,
1268
+ )
1269
+
1270
+ hidden_states = outputs[0]
1271
+ if self.config.pretraining_tp > 1:
1272
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1273
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1274
+ logits = torch.cat(logits, dim=-1)
1275
+ else:
1276
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1277
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
1278
+
1279
+
1280
+
1281
+ # Split logits tensor by num_heads
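+ # (bsz, seq_len, output_vocab_size) -> (bsz, seq_len * num_lm_heads, output_vocab_size // num_lm_heads),
+ # i.e. each patch position yields `num_lm_heads` separate byte predictions.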
1282
+ bsz, seq_len, _ = logits.size()
1283
+ logits = logits.view(bsz, seq_len * self.config.num_lm_heads, -1).contiguous()
1284
+
1285
+ loss = None
1286
+ if labels is not None:
1287
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs)
1288
+
1289
+ if not return_dict:
1290
+ output = (logits,) + outputs[1:]
1291
+ return (loss,) + output if loss is not None else output
1292
+
1293
+ return CausalLMOutputWithPast(
1294
+ loss=loss,
1295
+ logits=logits,
1296
+ past_key_values=outputs.past_key_values,
1297
+ hidden_states=outputs.hidden_states,
1298
+ attentions=outputs.attentions,
1299
+ )
special_tokens_map.json ADDED
@@ -0,0 +1,44 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<|begin_of_text|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "cls_token": {
10
+ "content": "<|cls|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "eos_token": {
17
+ "content": "<|end_of_text|>",
18
+ "lstrip": false,
19
+ "normalized": true,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "mask_token": {
24
+ "content": "<|mask|>",
25
+ "lstrip": false,
26
+ "normalized": true,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "pad_token": {
31
+ "content": "<|pad|>",
32
+ "lstrip": false,
33
+ "normalized": true,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "sep_token": {
38
+ "content": "<|sep|>",
39
+ "lstrip": false,
40
+ "normalized": true,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ }
44
+ }
tokenization_utf8_like_byte_v3.py ADDED
@@ -0,0 +1,1205 @@
1
+ # Copyright 2021 T5 Authors and HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Tokenization class for model ByT5."""
15
+
16
+ import warnings
17
+ from typing import (
18
+ Dict,
19
+ List,
20
+ Optional,
21
+ Union,
22
+ Tuple
23
+ )
24
+ import json
25
+ import os
26
+ import copy
27
+ import ast
28
+
29
+ import torch
30
+ import numpy as np
31
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
32
+ from transformers.tokenization_utils_base import (
33
+ BatchEncoding,
34
+ EncodedInput,
35
+ PaddingStrategy,
36
+ TruncationStrategy
37
+ )
38
+ from transformers.utils import logging
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+ SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
43
+ ADDED_TOKENS_FILE = "added_tokens.json"
44
+ TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
45
+
46
+ LARGE_INTEGER = int(1e20)
47
+
48
+ def make_serializeable(obj):
49
+ if isinstance(obj, dict):
50
+ return {str(k): make_serializeable(v) for k, v in obj.items()}
51
+ if isinstance(obj, list):
52
+ return [make_serializeable(v) for v in obj]
53
+ if isinstance(obj, tuple):
54
+ return make_serializeable(list(obj))
55
+ return obj
56
+
57
+
58
+ class ByteLMTokenizerV3(PreTrainedTokenizer):
59
+ """Byte tokenizer with completely seperate space for special tokens.
60
+
61
+ tok.pad Parameters
62
+ ----------
63
+ PreTrainedTokenizer : _type_
64
+ _description_
65
+
66
+ Returns
67
+ -------
68
+ _type_
69
+ _description_
70
+
71
+ Raises
72
+ ------
73
+ ValueError
74
+ _description_
75
+ ValueError
76
+ _description_
77
+ """
78
+
79
+ model_input_names: list[str] = ["input_ids", "attention_mask"]
80
+ reserve_sizes: list[int] = [59, 0, 0, 0]
81
+ byte_head_ints: list[int] = [
82
+ int("11000000", base=2),
83
+ int("10000000", base=2),
84
+ int("01000000", base=2),
85
+ int("00000000", base=2),
86
+ ]
87
+ byte_n_free_bits: list[int] = [6, 6, 6, 6]
88
+ patch_padding: bool
89
+ reserve_token_list: list[tuple[int]]
90
+
91
+ def __init__(
92
+ self,
93
+ patch_padding=True,
94
+ pad_token="<|pad|>",
95
+ eos_token="<|end_of_text|>",
96
+ bos_token="<|begin_of_text|>",
97
+ cls_token="<|cls|>",
98
+ sep_token="<|sep|>",
99
+ mask_token="<|mask|>",
100
+ vision_start_token="<|vision_start|>", # for vlm
101
+ vision_br_token="<|vision_br|>", # for vlm
102
+ vision_end_token="<|vision_end|>", # for vlm
103
+ start_header_id_token="<|start_header_id|>", # for it
104
+ end_header_id_token="<|end_header_id|>", # for it
105
+ eor_id="<|end_of_role|>", # for it
106
+ extra_ids=47,
107
+ **kwargs,
108
+ ) -> None:
109
+ assert np.prod(
110
+ [
111
+ 2**n_free_bits - reserve_size
112
+ for reserve_size, n_free_bits in zip(
113
+ self.reserve_sizes, self.byte_n_free_bits
114
+ )
115
+ ]
116
+ ) >= int(
117
+ "110000", base=16
118
+ ), "Not enough positions for all unicode. Too many reserve size."
119
+
120
+ self.patch_padding = patch_padding
121
+
122
+ # list up all reserve tokens
123
+ self._list_up_reserve_tokens()
124
+
125
+ _bos_token = (
126
+ AddedToken(bos_token, lstrip=False, rstrip=False)
127
+ if isinstance(bos_token, str)
128
+ else bos_token
129
+ )
130
+ _eos_token = (
131
+ AddedToken(eos_token, lstrip=False, rstrip=False)
132
+ if isinstance(eos_token, str)
133
+ else eos_token
134
+ )
135
+ _pad_token = (
136
+ AddedToken(pad_token, lstrip=False, rstrip=False)
137
+ if isinstance(pad_token, str)
138
+ else pad_token
139
+ )
140
+ _cls_token = (
141
+ AddedToken(cls_token, lstrip=False, rstrip=False)
142
+ if isinstance(cls_token, str)
143
+ else cls_token
144
+ )
145
+ _sep_token = (
146
+ AddedToken(sep_token, lstrip=False, rstrip=False)
147
+ if isinstance(sep_token, str)
148
+ else sep_token
149
+ )
150
+ _mask_token = (
151
+ AddedToken(mask_token, lstrip=False, rstrip=False)
152
+ if isinstance(mask_token, str)
153
+ else mask_token
154
+ )
155
+ _vision_start_token = (
156
+ AddedToken(vision_start_token, lstrip=False, rstrip=False)
157
+ if isinstance(vision_start_token, str)
158
+ else vision_start_token
159
+ )
160
+ _vision_br_token = (
161
+ AddedToken(vision_br_token, lstrip=False, rstrip=False)
162
+ if isinstance(vision_br_token, str)
163
+ else vision_br_token
164
+ )
165
+ _vision_end_token = (
166
+ AddedToken(vision_end_token, lstrip=False, rstrip=False)
167
+ if isinstance(vision_end_token, str)
168
+ else vision_end_token
169
+ )
170
+ _start_header_id_token = (
171
+ AddedToken(start_header_id_token, lstrip=False, rstrip=False)
172
+ if isinstance(start_header_id_token, str)
173
+ else start_header_id_token
174
+ )
175
+ _end_header_id_token = (
176
+ AddedToken(end_header_id_token, lstrip=False, rstrip=False)
177
+ if isinstance(end_header_id_token, str)
178
+ else end_header_id_token
179
+ )
180
+ _eor_id = (
181
+ AddedToken(eor_id, lstrip=False, rstrip=False)
182
+ if isinstance(eor_id, str)
183
+ else eor_id
184
+ )
185
+
186
+ self.offset = 0
187
+ self._added_tokens_decoder = {
188
+ self.reserve_token_list[i]: special_token
189
+ for i, special_token in enumerate(
190
+ [
191
+ _pad_token,
192
+ _eos_token,
193
+ _bos_token,
194
+ _cls_token,
195
+ _sep_token,
196
+ _mask_token,
197
+ _vision_start_token,
198
+ _vision_br_token,
199
+ _vision_end_token,
200
+ _start_header_id_token,
201
+ _end_header_id_token,
202
+ _eor_id,
203
+ ]
204
+ )
205
+ }
206
+
207
+ offset = len(self._added_tokens_decoder)
208
+ extra_tokens = {
209
+ self.reserve_token_list[j + offset]: AddedToken(
210
+ f"<|extra_id_{i}|>", lstrip=False, rstrip=False
211
+ )
212
+ for j, i in enumerate(range(extra_ids))
213
+ }
214
+ self._added_tokens_decoder.update(extra_tokens)
215
+
216
+ super().__init__(
217
+ bos_token=_bos_token,
218
+ eos_token=_eos_token,
219
+ pad_token=_pad_token,
220
+ cls_token=_cls_token,
221
+ sep_token=_sep_token,
222
+ mask_token=_mask_token,
223
+ vision_start_token=_vision_start_token,
224
+ vision_br_token=_vision_br_token,
225
+ vision_end_token=_vision_end_token,
226
+ start_header_id_token=_start_header_id_token,
227
+ end_header_id_token=_end_header_id_token,
228
+ eor_id=_eor_id,
229
+ **kwargs,
230
+ )
231
+
232
+ self._vocab_size = len(self.get_vocab())
233
+
234
+ def _list_up_reserve_tokens(self):
235
+ self.reserve_token_list = [
236
+ (
237
+ i + self.byte_head_ints[0],
238
+ self.byte_head_ints[1],
239
+ self.byte_head_ints[2],
240
+ self.byte_head_ints[3],
241
+ )
242
+ for i in range(self.reserve_sizes[0])
243
+ ]
244
+
245
+ @property
246
+ def vocab_size(self):
247
+ return self._vocab_size
248
+
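+ # `create_tree` enumerates every combination of non-reserved byte values across the patch
+ # positions (recursively over `byte_options`); `get_vocab` uses it to list all byte tokens.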
249
+ def create_tree(
250
+ self, byte_options: list[list[int]], byte_index: int, max_byte_index: int
251
+ ) -> list[list[int]]:
252
+ if byte_index == max_byte_index:
253
+ return [[reserve_option] for reserve_option in byte_options[byte_index]]
254
+
255
+ concat_list = []
256
+ for byte_reserve_option in byte_options[byte_index]:
257
+ if byte_reserve_option is not None:
258
+ concat_list += [
259
+ [byte_reserve_option] + following_bytes
260
+ if following_bytes != [None]
261
+ else [byte_reserve_option]
262
+ for following_bytes in self.create_tree(
263
+ byte_options=byte_options,
264
+ byte_index=byte_index + 1,
265
+ max_byte_index=max_byte_index,
266
+ )
267
+ ]
268
+ else:
269
+ concat_list.append([None])
270
+ return concat_list
271
+
272
+ def get_vocab(self):
273
+ byte_options = [
274
+ list(range(reserve_size, 2**n_free_bits))
275
+ for reserve_size, n_free_bits in zip(
276
+ self.reserve_sizes, self.byte_n_free_bits
277
+ )
278
+ ]
279
+
280
+ if not self.patch_padding:
281
+ for i in range(len(byte_options) - 1):
282
+ byte_options[i] += [None]
283
+
284
+ byte_options.reverse()
285
+ byte_tokens = self.create_tree(
286
+ byte_options=byte_options, byte_index=0, max_byte_index=3
287
+ )
288
+
289
+ byte_tokens = sorted(
290
+ byte_tokens,
291
+ key=lambda lst: sum([e * (256**i) for i, e in enumerate(lst)])
292
+ + 256 ** len(lst),
293
+ )
294
+
295
+ for byte_token_index in range(len(byte_tokens)):
296
+ byte_tokens[byte_token_index].reverse()
297
+ for position in range(len(byte_tokens[byte_token_index])):
298
+ byte_tokens[byte_token_index][position] += self.byte_head_ints[position]
299
+ byte_tokens[byte_token_index] = tuple(byte_tokens[byte_token_index])
300
+
301
+ vocab = {self.convert_ids_to_tokens(tokens): tokens for tokens in byte_tokens}
302
+ vocab.pop("")
303
+ vocab.update(self.added_tokens_encoder)
304
+ return vocab
305
+
306
+
307
+ def _get_padding_truncation_strategies(
308
+ self, padding=False, truncation=None, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
309
+ ):
310
+ """
311
+ Find the correct padding/truncation strategy
312
+ """
313
+
314
+ # Backward compatibility for previous behavior, maybe we should deprecate it:
315
+ # If you only set max_length, it activates truncation for max_length
316
+ if max_length is not None and padding is False and truncation is None:
317
+ if verbose:
318
+ if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False):
319
+ logger.warning(
320
+ "Truncation was not explicitly activated but `max_length` is provided a specific value, please"
321
+ " use `truncation=True` to explicitly truncate examples to max length. Defaulting to"
322
+ " 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the"
323
+ " tokenizer you can select this strategy more precisely by providing a specific strategy to"
324
+ " `truncation`."
325
+ )
326
+ self.deprecation_warnings["Truncation-not-explicitly-activated"] = True
327
+ truncation = "longest_first"
328
+
329
+ # Get padding strategy
330
+ if padding is not False:
331
+ if padding is True:
332
+ if verbose:
333
+ if max_length is not None and (
334
+ truncation is None or truncation is False or truncation == "do_not_truncate"
335
+ ):
336
+ warnings.warn(
337
+ "`max_length` is ignored when `padding`=`True` and there is no truncation strategy. "
338
+ "To pad to max length, use `padding='max_length'`."
339
+ )
340
+ padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch
341
+ elif not isinstance(padding, PaddingStrategy):
342
+ padding_strategy = PaddingStrategy(padding)
343
+ elif isinstance(padding, PaddingStrategy):
344
+ padding_strategy = padding
345
+ else:
346
+ padding_strategy = PaddingStrategy.DO_NOT_PAD
347
+
348
+ # Get truncation strategy
349
+ if truncation is not False and truncation is not None:
350
+ if truncation is True:
351
+ truncation_strategy = (
352
+ TruncationStrategy.LONGEST_FIRST
353
+ ) # Default to truncate the longest sequences in pairs of inputs
354
+ elif not isinstance(truncation, TruncationStrategy):
355
+ truncation_strategy = TruncationStrategy(truncation)
356
+ elif isinstance(truncation, TruncationStrategy):
357
+ truncation_strategy = truncation
358
+ else:
359
+ truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
360
+
361
+ # Set max length if needed
362
+ if max_length is None:
363
+ if padding_strategy == PaddingStrategy.MAX_LENGTH:
364
+ if self.model_max_length > LARGE_INTEGER:
365
+ if verbose:
366
+ if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False):
367
+ logger.warning(
368
+ "Asking to pad to max_length but no maximum length is provided and the model has no"
369
+ " predefined maximum length. Default to no padding."
370
+ )
371
+ self.deprecation_warnings["Asking-to-pad-to-max_length"] = True
372
+ padding_strategy = PaddingStrategy.DO_NOT_PAD
373
+ else:
374
+ max_length = self.model_max_length
375
+
376
+ if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
377
+ if self.model_max_length > LARGE_INTEGER:
378
+ if verbose:
379
+ if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False):
380
+ logger.warning(
381
+ "Asking to truncate to max_length but no maximum length is provided and the model has"
382
+ " no predefined maximum length. Default to no truncation."
383
+ )
384
+ self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True
385
+ truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
386
+ else:
387
+ max_length = self.model_max_length
388
+
389
+ # Test if we have a padding token
390
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD and self.pad_token is None:
391
+ raise ValueError(
392
+ "Asking to pad but the tokenizer does not have a padding token. "
393
+ "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
394
+ "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
395
+ )
396
+
397
+ # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
398
+ if (
399
+ truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
400
+ and padding_strategy != PaddingStrategy.DO_NOT_PAD
401
+ and pad_to_multiple_of is not None
402
+ and max_length is not None
403
+ and (max_length % pad_to_multiple_of != 0)
404
+ ):
405
+ raise ValueError(
406
+ "Truncation and padding are both activated but "
407
+ f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
408
+ )
409
+
410
+ return padding_strategy, truncation_strategy, max_length, kwargs
411
+
412
+
413
+
414
+ def _add_bos_if_not_present(self, token_ids: list[int]) -> list[int]:
415
+ """Do not add bos again if user already added it."""
416
+ if len(token_ids) > 0 and token_ids[0] == self.bos_token_id:
417
+ warnings.warn(
418
+ f"This sequence already has {self.bos_token}. In future versions this behavior may lead to duplicated"
419
+ " bos tokens being added."
420
+ )
421
+ return token_ids
422
+ else:
423
+ return list(self.bos_token_id) + token_ids
424
+
425
+
426
+ def _add_eos_if_not_present(self, token_ids: list[int]) -> list[int]:
427
+ """Do not add eos again if user already added it."""
428
+ if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
429
+ warnings.warn(
430
+ f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
431
+ " eos tokens being added."
432
+ )
433
+ return token_ids
434
+ else:
435
+ return token_ids + list(self.eos_token_id)
436
+
437
+ def _pad(
438
+ self,
439
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
440
+ max_length: Optional[int] = None,
441
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
442
+ pad_to_multiple_of: Optional[int] = None,
443
+ padding_side: Optional[str] = None,
444
+ return_attention_mask: Optional[bool] = None,
445
+ ) -> dict:
446
+ """
447
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
448
+
449
+ Args:
450
+ encoded_inputs:
451
+ Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
452
+ max_length: maximum length of the returned list and optionally padding length (see below).
453
+ Will truncate by taking into account the special tokens.
454
+ padding_strategy: PaddingStrategy to use for padding.
455
+
456
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
457
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
458
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
459
+ The tokenizer padding sides are defined in `padding_side` argument:
460
+
461
+ - 'left': pads on the left of the sequences
462
+ - 'right': pads on the right of the sequences
463
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
464
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
465
+ `>= 7.5` (Volta).
466
+ padding_side:
467
+ The side on which the model should have padding applied. Should be selected between ['right', 'left'].
468
+ Default value is picked from the class attribute of the same name.
469
+ return_attention_mask:
470
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
471
+ """
472
+ # Load from model defaults
473
+ if return_attention_mask is None:
474
+ return_attention_mask = "attention_mask" in self.model_input_names
475
+
476
+ required_input = encoded_inputs[self.model_input_names[0]]
477
+
478
+ if padding_strategy == PaddingStrategy.LONGEST:
479
+ max_length = len(required_input)
480
+
481
+ if (
482
+ max_length is not None
483
+ and pad_to_multiple_of is not None
484
+ and (max_length % pad_to_multiple_of != 0)
485
+ ):
486
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
487
+
488
+ needs_to_be_padded = (
489
+ padding_strategy != PaddingStrategy.DO_NOT_PAD
490
+ and len(required_input) != max_length
491
+ )
492
+
493
+ # Initialize attention mask if not present.
494
+ if return_attention_mask and "attention_mask" not in encoded_inputs:
495
+ encoded_inputs["attention_mask"] = [1] * len(required_input)
496
+
497
+ if needs_to_be_padded:
498
+ if self.patch_padding:
499
+ difference = (max_length - len(required_input)) // len(
500
+ self.byte_head_ints
501
+ )
502
+ mask_patch_size = 4
503
+ else:
504
+ difference = max_length - len(required_input)
505
+ mask_patch_size = 1
506
+
507
+ padding_side = (
508
+ padding_side if padding_side is not None else self.padding_side
509
+ )
510
+
511
+ if padding_side == "right":
512
+ if return_attention_mask:
513
+ encoded_inputs["attention_mask"] = (
514
+ encoded_inputs["attention_mask"]
515
+ + [0] * difference * mask_patch_size
516
+ )
517
+ if "token_type_ids" in encoded_inputs:
518
+ encoded_inputs["token_type_ids"] = (
519
+ encoded_inputs["token_type_ids"]
520
+ + list(self.pad_token_type_id) * difference
521
+ )
522
+ if "special_tokens_mask" in encoded_inputs:
523
+ encoded_inputs["special_tokens_mask"] = (
524
+ encoded_inputs["special_tokens_mask"]
525
+ + [1] * difference * mask_patch_size
526
+ )
527
+ encoded_inputs[self.model_input_names[0]] = (
528
+ required_input + list(self.pad_token_id) * difference
529
+ )
530
+ elif padding_side == "left":
531
+ if return_attention_mask:
532
+ encoded_inputs["attention_mask"] = [
533
+ 0
534
+ ] * difference * mask_patch_size + encoded_inputs["attention_mask"]
535
+ if "token_type_ids" in encoded_inputs:
536
+ encoded_inputs["token_type_ids"] = (
537
+ list(self.pad_token_type_id) * difference
538
+ + encoded_inputs["token_type_ids"]
539
+ )
540
+ if "special_tokens_mask" in encoded_inputs:
541
+ encoded_inputs["special_tokens_mask"] = [
542
+ 1
543
+ ] * difference * mask_patch_size + encoded_inputs[
544
+ "special_tokens_mask"
545
+ ]
546
+ encoded_inputs[self.model_input_names[0]] = (
547
+ list(self.pad_token_id) * difference + required_input
548
+ )
549
+ else:
550
+ raise ValueError(f"Invalid padding strategy:{padding_side}")
551
+
552
+ return encoded_inputs
553
+
554
+
555
+ def build_inputs_with_special_tokens(
556
+ self, token_ids_0: list[int], token_ids_1: list[int] | None = None
557
+ ) -> list[int]:
558
+ """
559
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
+ adding special tokens. A sequence has the following format:
+ - single sequence: `<|begin_of_text|> X <|end_of_text|>`
+ - pair of sequences: `<|begin_of_text|> A <|end_of_text|> <|begin_of_text|> B <|end_of_text|>`
563
+ Args:
564
+ token_ids_0 (`List[int]`):
565
+ List of IDs to which the special tokens will be added.
566
+ token_ids_1 (`List[int]`, *optional*):
567
+ Optional second list of IDs for sequence pairs.
568
+ Returns:
569
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
570
+ """
571
+ token_ids_0 = self._add_bos_if_not_present(token_ids_0)
572
+ token_ids_0 = self._add_eos_if_not_present(token_ids_0)
573
+ if token_ids_1 is None:
574
+ return token_ids_0
575
+ else:
576
+ token_ids_1 = self._add_bos_if_not_present(token_ids_1)
577
+ token_ids_1 = self._add_eos_if_not_present(token_ids_1)
578
+ return token_ids_0 + token_ids_1
579
+
580
+ def _tokenize(self, text: str) -> list[str]:
581
+ """Take as input a string and return a list of strings (tokens) for words/sub-words"""
582
+ token_ids = []
583
+ for c in text:
584
+ token_ids.extend(self.unicode_to_bytes(ord(c)))
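+ # each character expands to a patch of byte ids (a full 4-id patch when `patch_padding` is on)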
585
+
586
+ # Convert to string
587
+ token_ids = [str(i) for i in token_ids]
588
+ return token_ids
589
+
590
+ def _convert_token_to_id(self, token):
591
+ """Converts a token (str) in an id using the vocab."""
592
+ token_id = int(token) + self.offset
593
+ return token_id
594
+
595
+ def _convert_id_to_token(self, index):
596
+ """Converts an index (integer) in a token (str) using the vocab."""
597
+ return str(index - self.offset)
598
+
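+ # Unlike a standard tokenizer this returns a *list* of ids per token: special tokens map to
+ # multi-id tuples, while an ordinary byte token maps to a single id.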
599
+ def _convert_token_to_id_with_added_voc(self, token):
600
+ if token is None:
601
+ return None
602
+
603
+ if token in self._added_tokens_encoder:
604
+ return list(self._added_tokens_encoder[token])
605
+ return [self._convert_token_to_id(token)]
606
+
607
+ def convert_tokens_to_ids(
608
+ self, tokens: Union[str, List[str]]
609
+ ) -> Union[int, List[int]]:
610
+ """
611
+ Converts a token string (or a sequence of tokens) in a single integer id (or a sequence of ids), using the
612
+ vocabulary.
613
+
614
+ Args:
615
+ tokens (`str` or `List[str]`): One or several token(s) to convert to token id(s).
616
+
617
+ Returns:
618
+ `int` or `List[int]`: The token id or list of token ids.
619
+ """
620
+ if tokens is None:
621
+ return None
622
+
623
+ if isinstance(tokens, str):
624
+ return self._convert_token_to_id_with_added_voc(tokens)
625
+
626
+ ids = []
627
+ for token in tokens:
628
+ ids.extend(self._convert_token_to_id_with_added_voc(token))
629
+ return ids
630
+
631
+ def convert_bytes_for_single_char_to_char(self, ids: list[int]) -> str | None:
632
+ byte_ints = []
633
+ byte_offset = 1
634
+
635
+ if self.is_special_token(ids): # special token
636
+ return self.added_tokens_decoder[tuple(ids)].__str__()
637
+
638
+ for byte_position in range(1, len(ids) + 1):
639
+ byte_int = (
640
+ ids[-byte_position]
641
+ - self.byte_head_ints[-byte_position]
642
+ - self.reserve_sizes[-byte_position]
643
+ )
644
+ if byte_int != -self.reserve_sizes[-byte_position]: # not padding
645
+ byte_ints.append(byte_int * byte_offset)
646
+
647
+ byte_offset *= (
648
+ 2 ** self.byte_n_free_bits[-byte_position]
649
+ - self.reserve_sizes[-byte_position]
650
+ )
651
+
652
+ codepoint = sum(byte_ints)
653
+ if codepoint >= int("110000", base=16):
654
+ return None
655
+ else:
656
+ try:
657
+ return chr(codepoint)
658
+ except ValueError:
659
+ return None
660
+
661
+ # def is_special_token(self, ids: list[int]):
662
+ # return ids[0] < self.byte_head_ints[0] + (self.reserve_sizes[0] - 1)
663
+
664
+ def is_special_token(self, ids: list[int]):
665
+ return tuple(ids) in self._added_tokens_decoder
666
+
667
+ def convert_ids_to_tokens(
668
+ self, ids: list[int] | tuple[int], skip_special_tokens: bool = False
669
+ ) -> str | None:
670
+ """convert ids for single/multiple unicode character(s) to unicode character(s)"""
671
+
672
+ decoded_chars = ""
673
+
674
+ if isinstance(ids, tuple):
675
+ ids = list(ids)
676
+
677
+ if self.patch_padding:
678
+ for byte_position in range(0, len(ids), len(self.byte_head_ints)):
679
+ char_bytes = ids[
680
+ byte_position : byte_position + len(self.byte_head_ints)
681
+ ]
682
+ if (
683
+ skip_special_tokens and not self.is_special_token(char_bytes)
684
+ ) or not skip_special_tokens:
685
+ char = self.convert_bytes_for_single_char_to_char(char_bytes)
686
+ if char:
687
+ decoded_chars += char
688
+ return decoded_chars
689
+
690
+ if not self.is_special_token(ids): # not special token
691
+ byte_ints = []
692
+ byte_offset = 1
693
+ for byte_position in range(1, len(ids) + 1):
694
+ if ids[-byte_position] == 0:
695
+ break
696
+ byte_int = (
697
+ ids[-byte_position]
698
+ - self.byte_head_ints[-byte_position]
699
+ - self.reserve_sizes[-byte_position]
700
+ )
701
+ assert byte_int >= 0
702
+ byte_ints.append(byte_int * byte_offset)
703
+ byte_offset *= (
704
+ 2 ** self.byte_n_free_bits[-byte_position]
705
+ - self.reserve_sizes[-byte_position]
706
+ )
707
+
708
+ codepoint = sum(byte_ints)
709
+ if codepoint >= int("110000", base=16):
710
+ return None
711
+ else:
712
+ return chr(codepoint)
713
+ else: # special token
714
+ return self._added_tokens_decoder[tuple(ids)]
715
+
716
+ def unicode_to_bytes(self, codepoint: int) -> list[int]:
717
+ byte_list_reversed = []
718
+ for byte_position_from_right in range(len(self.byte_n_free_bits)):
719
+ byte_n_free_ids = (
720
+ 2 ** self.byte_n_free_bits[-1 - byte_position_from_right]
721
+ - self.reserve_sizes[-1 - byte_position_from_right]
722
+ )
723
+ byte_id = (
724
+ codepoint % byte_n_free_ids
725
+ + self.reserve_sizes[-1 - byte_position_from_right]
726
+ + self.byte_head_ints[-1 - byte_position_from_right]
727
+ )
728
+ codepoint //= byte_n_free_ids
729
+ byte_list_reversed.append(byte_id)
730
+
731
+ if codepoint == 0:
732
+ if self.patch_padding:
733
+ for pad_byte_position_from_right in range(
734
+ len(byte_list_reversed), len(self.byte_n_free_bits)
735
+ ):
736
+ byte_list_reversed.append(
737
+ self.byte_head_ints[-1 - pad_byte_position_from_right] + self.reserve_sizes[-1 - pad_byte_position_from_right]
738
+ )
739
+ byte_list_reversed.reverse()
740
+ return byte_list_reversed
741
+ raise ValueError("codepoint is too large")
742
+
743
+ # ByteTokenizer has no vocab file
744
+ def save_vocabulary(
745
+ self, save_directory: str, filename_prefix: str | None = None
746
+ ) -> tuple[str]:
747
+ return ()
748
+
749
+
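+ # Lays an RGB image out as byte patches: each pixel becomes a 4-id patch
+ # [<|vision_start|> lead byte, R, G, B] and every pixel row ends with a <|vision_br|> patch.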
750
+ def image_to_ids(self, image_data: list[list[list[int]]]) -> list[int]:
751
+ image_data = torch.tensor(image_data)
752
+ x, y, rgb = image_data.size()
753
+ assert rgb == 3
754
+ image_br_token = list(self.added_tokens_encoder["<|vision_br|>"])
755
+ image_special_byte_index = self.added_tokens_encoder["<|vision_start|>"][0]
756
+
757
+ # add img byte by padding to the beginning
758
+ image_data = torch.nn.functional.pad(
759
+ image_data, (1, 0), "constant", value=image_special_byte_index
760
+ ).view(x, y * 4)
761
+
762
+ image_data = torch.concat(
763
+ [image_data, torch.tensor(image_br_token * x).view(x, 4)], dim=1
764
+ ).view(-1)
765
+ return image_data.tolist()
766
+
767
+ def save_pretrained(
768
+ self,
769
+ save_directory: Union[str, os.PathLike],
770
+ legacy_format: Optional[bool] = None,
771
+ filename_prefix: Optional[str] = None,
772
+ push_to_hub: bool = False,
773
+ **kwargs,
774
+ ) -> Tuple[str]:
775
+ """
776
+ Save the full tokenizer state.
777
+
778
+
779
+ This method makes sure the full tokenizer can then be re-loaded using the
+ [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`] class method.
+
+ Warning: this won't save modifications you may have applied to the tokenizer after the instantiation (for
783
+ instance, modifying `tokenizer.do_lower_case` after creation).
784
+
785
+ Args:
786
+ save_directory (`str` or `os.PathLike`): The path to a directory where the tokenizer will be saved.
787
+ legacy_format (`bool`, *optional*):
788
+ Only applicable for a fast tokenizer. If unset (default), will save the tokenizer in the unified JSON
789
+ format as well as in legacy format if it exists, i.e. with tokenizer specific vocabulary and a separate
790
+ added_tokens files.
791
+
792
+ If `False`, will only save the tokenizer in the unified JSON format. This format is incompatible with
793
+ "slow" tokenizers (not powered by the *tokenizers* library), so the tokenizer will not be able to be
794
+ loaded in the corresponding "slow" tokenizer.
795
+
796
+ If `True`, will save the tokenizer in legacy format. If the "slow" tokenizer doesn't exist, a value
797
+ error is raised.
798
+ filename_prefix (`str`, *optional*):
799
+ A prefix to add to the names of the files saved by the tokenizer.
800
+ push_to_hub (`bool`, *optional*, defaults to `False`):
801
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
802
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
803
+ namespace).
804
+ kwargs (`Dict[str, Any]`, *optional*):
805
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
806
+
807
+ Returns:
808
+ A tuple of `str`: The files saved.
809
+ """
810
+ use_auth_token = kwargs.pop("use_auth_token", None)
811
+
812
+ if use_auth_token is not None:
813
+ warnings.warn(
814
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
815
+ FutureWarning,
816
+ )
817
+ if kwargs.get("token", None) is not None:
818
+ raise ValueError(
819
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
820
+ )
821
+ kwargs["token"] = use_auth_token
822
+
823
+ if os.path.isfile(save_directory):
824
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
825
+ return
826
+
827
+ os.makedirs(save_directory, exist_ok=True)
828
+
829
+ if push_to_hub:
830
+ commit_message = kwargs.pop("commit_message", None)
831
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
832
+ repo_id = self._create_repo(repo_id, **kwargs)
833
+ files_timestamps = self._get_files_timestamps(save_directory)
834
+
835
+ special_tokens_map_file = os.path.join(
836
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + SPECIAL_TOKENS_MAP_FILE
837
+ )
838
+ tokenizer_config_file = os.path.join(
839
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_CONFIG_FILE
840
+ )
841
+
842
+ tokenizer_config = copy.deepcopy(self.init_kwargs)
843
+
844
+ # Let's save the init kwargs
845
+ target_keys = set(self.init_kwargs.keys())
846
+ # Let's save the special tokens map (only the strings)
847
+ target_keys.update(["model_max_length", "clean_up_tokenization_spaces"])
848
+
849
+ for k in target_keys:
850
+ if hasattr(self, k):
851
+ tokenizer_config[k] = getattr(self, k)
852
+
853
+ # Let's make sure we properly save the special tokens.
854
+ tokenizer_config.update(self.special_tokens_map)
855
+
856
+ if self.chat_template is not None:
857
+ if isinstance(self.chat_template, dict):
858
+ # Chat template dicts are saved to the config as lists of dicts with fixed key names.
859
+ # They will be reconstructed as a single dict during loading.
860
+ tokenizer_config["chat_template"] = [{"name": k, "template": v} for k, v in self.chat_template.items()]
861
+ else:
862
+ tokenizer_config["chat_template"] = self.chat_template
863
+
864
+ if len(self.init_inputs) > 0:
865
+ tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs)
866
+ for file_id in self.vocab_files_names.keys():
867
+ tokenizer_config.pop(file_id, None)
868
+
869
+ # No type fields, so that both old fast and slow tokenizers can load it.
870
+ tokenizer_config = self.convert_added_tokens(tokenizer_config, add_type_field=True, save=True)
871
+
872
+ # Process added tokens separately: allows previous versions to ignore it!
873
+ added_tokens = {}
874
+ for key, value in self.added_tokens_decoder.items():
875
+ added_tokens[key] = value.__getstate__()
876
+ tokenizer_config["added_tokens_decoder"] = added_tokens
877
+
878
+ # Add tokenizer class to the tokenizer config to be able to reload it with from_pretrained
879
+ tokenizer_class = self.__class__.__name__
880
+ # Remove the Fast at the end unless we have a special `PreTrainedTokenizerFast`
881
+ if tokenizer_class.endswith("Fast") and tokenizer_class != "PreTrainedTokenizerFast":
882
+ tokenizer_class = tokenizer_class[:-4]
883
+ tokenizer_config["tokenizer_class"] = tokenizer_class
884
+ if getattr(self, "_auto_map", None) is not None:
885
+ tokenizer_config["auto_map"] = self._auto_map
886
+ if getattr(self, "_processor_class", None) is not None:
887
+ tokenizer_config["processor_class"] = self._processor_class
888
+
889
+ # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
890
+ # loaded from the Hub.
891
+ if self._auto_class is not None:
892
+ custom_object_save(self, save_directory, config=tokenizer_config)
893
+
894
+ # remove private information
895
+ if "name_or_path" in tokenizer_config:
896
+ tokenizer_config.pop("name_or_path")
897
+ tokenizer_config.pop("special_tokens_map_file", None)
898
+ tokenizer_config.pop("tokenizer_file", None)
899
+
900
+ with open(tokenizer_config_file, "w", encoding="utf-8") as f:
901
+ out_str = json.dumps(
902
+ make_serializeable(tokenizer_config),
903
+ indent=2,
904
+ sort_keys=True,
905
+ ensure_ascii=False
906
+ ) + "\n"
907
+ f.write(out_str)
908
+ logger.info(f"tokenizer config file saved in {tokenizer_config_file}")
909
+
910
+ # Sanitize AddedTokens in special_tokens_map
911
+
912
+ # Kept for forward compatibility; will be removed in transformers v5. Type fields are not saved for forward compatibility, and the `special` field should not be saved either.
913
+ write_dict = self.convert_added_tokens(self.special_tokens_map_extended, save=True, add_type_field=False)
914
+ with open(special_tokens_map_file, "w", encoding="utf-8") as f:
915
+ out_str = json.dumps(write_dict, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
916
+ f.write(out_str)
917
+ logger.info(f"Special tokens file saved in {special_tokens_map_file}")
918
+
919
+ file_names = (tokenizer_config_file, special_tokens_map_file)
920
+
921
+ save_files = self._save_pretrained(
922
+ save_directory=save_directory,
923
+ file_names=file_names,
924
+ legacy_format=legacy_format,
925
+ filename_prefix=filename_prefix,
926
+ )
927
+
928
+ if push_to_hub:
929
+ self._upload_modified_files(
930
+ save_directory,
931
+ repo_id,
932
+ files_timestamps,
933
+ commit_message=commit_message,
934
+ token=kwargs.get("token"),
935
+ )
936
+
937
+ return save_files
938
+
939
+
940
+ def _save_pretrained(
941
+ self,
942
+ save_directory: Union[str, os.PathLike],
943
+ file_names: Tuple[str],
944
+ legacy_format: Optional[bool] = None,
945
+ filename_prefix: Optional[str] = None,
946
+ ) -> Tuple[str]:
947
+ """
948
+ Save a tokenizer using the slow-tokenizer/legacy format: vocabulary + added tokens.
949
+
950
+ Fast tokenizers can also be saved in a unique JSON file containing {config + vocab + added-tokens} using the
951
+ specific [`~tokenization_utils_fast.PreTrainedTokenizerFast._save_pretrained`]
952
+ """
953
+ if legacy_format is False:
954
+ raise ValueError(
955
+ "Only fast tokenizers (instances of PreTrainedTokenizerFast) can be saved in non legacy format."
956
+ )
957
+
958
+ save_directory = str(save_directory)
959
+
960
+ added_tokens_file = os.path.join(
961
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + ADDED_TOKENS_FILE
962
+ )
963
+ # the new get_added_vocab() also returns special tokens and tokens that have an index < vocab_size
964
+ # added_vocab = {tok: index for tok, index in self.added_tokens_encoder.items() if index >= self.vocab_size}
965
+ added_vocab = {tok: list(index) for tok, index in self.added_tokens_encoder.items()}
966
+ if added_vocab:
967
+ with open(added_tokens_file, "w", encoding="utf-8") as f:
968
+ out_str = json.dumps(added_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
969
+ f.write(out_str)
970
+ logger.info(f"added tokens file saved in {added_tokens_file}")
971
+
972
+ vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix)
973
+
974
+ return file_names + vocab_files + (added_tokens_file,)
975
+
976
+
977
+
978
+ @classmethod
979
+ def _from_pretrained(
980
+ cls,
981
+ resolved_vocab_files,
982
+ pretrained_model_name_or_path,
983
+ init_configuration,
984
+ *init_inputs,
985
+ token=None,
986
+ cache_dir=None,
987
+ local_files_only=False,
988
+ _commit_hash=None,
989
+ _is_local=False,
990
+ trust_remote_code=False,
991
+ **kwargs,
992
+ ):
993
+ # We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json
994
+ # file or if `from_slow` is set to True.
995
+ from_slow = kwargs.get("from_slow", False)
996
+ gguf_file = kwargs.get("gguf_file", None)
997
+ has_tokenizer_file = resolved_vocab_files.get("tokenizer_file", None) is not None
998
+
999
+ # If one passes a GGUF file path to `gguf_file` there is no need for this check as the tokenizer will be
1000
+ # loaded directly from the GGUF file.
1001
+ if (from_slow or not has_tokenizer_file) and cls.slow_tokenizer_class is not None and not gguf_file:
1002
+ slow_tokenizer = (cls.slow_tokenizer_class)._from_pretrained(
1003
+ copy.deepcopy(resolved_vocab_files),
1004
+ pretrained_model_name_or_path,
1005
+ copy.deepcopy(init_configuration),
1006
+ *init_inputs,
1007
+ token=token,
1008
+ cache_dir=cache_dir,
1009
+ local_files_only=local_files_only,
1010
+ _commit_hash=_commit_hash,
1011
+ **(copy.deepcopy(kwargs)),
1012
+ )
1013
+ else:
1014
+ slow_tokenizer = None
1015
+
1016
+ # Prepare tokenizer initialization kwargs
1017
+ # Did we save some inputs and kwargs to reload?
1018
+ tokenizer_config_file = resolved_vocab_files.pop("tokenizer_config_file", None)
1019
+ if tokenizer_config_file is not None:
1020
+ with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
1021
+ init_kwargs = json.load(tokenizer_config_handle)
1022
+ # First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers.
1023
+
1024
+ config_tokenizer_class = init_kwargs.get("tokenizer_class")
1025
+ init_kwargs.pop("tokenizer_class", None)
1026
+ if not has_tokenizer_file:
1027
+ init_kwargs.pop("tokenizer_file", None)
1028
+ saved_init_inputs = init_kwargs.pop("init_inputs", ())
1029
+ if not init_inputs:
1030
+ init_inputs = saved_init_inputs
1031
+ else:
1032
+ config_tokenizer_class = None
1033
+ init_kwargs = init_configuration
1034
+
1035
+ if not _is_local:
1036
+ if "auto_map" in init_kwargs:
1037
+ # For backward compatibility with the old format.
1038
+ if isinstance(init_kwargs["auto_map"], (tuple, list)):
1039
+ init_kwargs["auto_map"] = {"AutoTokenizer": init_kwargs["auto_map"]}
1040
+
1041
+
1042
+ if config_tokenizer_class is None:
1043
+ # Matt: This entire block is only used to decide if the tokenizer class matches the class in the repo.
1044
+ # If not, it raises a warning, but otherwise continues. Since we mostly load tokenizers with
1045
+ # AutoTokenizer these days, it seems like a lot of work (and a source of bugs) for little gain.
1046
+ # Maybe we can just remove this entirely?
1047
+ from transformers.models.auto.configuration_auto import AutoConfig # tests_ignore
1048
+
1049
+ # Second attempt. If we have not yet found tokenizer_class, let's try to use the config.
1050
+ try:
1051
+ config = AutoConfig.from_pretrained(
1052
+ pretrained_model_name_or_path,
1053
+ token=token,
1054
+ cache_dir=cache_dir,
1055
+ local_files_only=local_files_only,
1056
+ trust_remote_code=trust_remote_code,
1057
+ _commit_hash=_commit_hash,
1058
+ )
1059
+ config_tokenizer_class = config.tokenizer_class
1060
+ except (OSError, ValueError, KeyError):
1061
+ # skip if an error occurred.
1062
+ config = None
1063
+ if config_tokenizer_class is None:
1064
+ # Third attempt. If we have not yet found the original type of the tokenizer
1065
+ # we are loading, we see if we can infer it from the type of the configuration file.
1066
+ from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES # tests_ignore
1067
+
1068
+ if hasattr(config, "model_type"):
1069
+ model_type = config.model_type
1070
+ else:
1071
+ # Fallback: use pattern matching on the string.
1072
+ model_type = None
1073
+ for pattern in TOKENIZER_MAPPING_NAMES.keys():
1074
+ if pattern in str(pretrained_model_name_or_path):
1075
+ model_type = pattern
1076
+ break
1077
+
1078
+ if model_type is not None:
1079
+ config_tokenizer_class, config_tokenizer_class_fast = TOKENIZER_MAPPING_NAMES.get(
1080
+ model_type, (None, None)
1081
+ )
1082
+ if config_tokenizer_class is None:
1083
+ config_tokenizer_class = config_tokenizer_class_fast
1084
+
1085
+ if config_tokenizer_class is not None:
1086
+ if cls.__name__.replace("Fast", "") != config_tokenizer_class.replace("Fast", ""):
1087
+ logger.warning(
1088
+ "The tokenizer class you load from this checkpoint is not the same type as the class this"
1089
+ " function is called from. It may result in unexpected tokenization. \nThe tokenizer class you"
1090
+ f" load from this checkpoint is '{config_tokenizer_class}'. \nThe class this function is called"
1091
+ f" from is '{cls.__name__}'."
1092
+ )
1093
+
1094
+ # Update with newly provided kwargs
1095
+ init_kwargs.update(kwargs)
1096
+
1097
+ # Merge resolved_vocab_files arguments in init_kwargs.
1098
+ added_tokens_file = resolved_vocab_files.pop("added_tokens_file", None)
1099
+ special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None)
1100
+ for args_name, file_path in resolved_vocab_files.items():
1101
+ if args_name not in init_kwargs:
1102
+ init_kwargs[args_name] = file_path
1103
+ tokenizer_file = resolved_vocab_files.pop("tokenizer_file", None)
1104
+
1105
+ if slow_tokenizer is not None:
1106
+ init_kwargs["__slow_tokenizer"] = slow_tokenizer
1107
+ init_kwargs["name_or_path"] = pretrained_model_name_or_path
1108
+
1109
+ #### Handle tokenizer serialization of added and special tokens
1110
+ added_tokens_decoder: Dict[int, AddedToken] = {}
1111
+ added_tokens_map: Dict[str, AddedToken] = {}
1112
+ # if we have info on the slow added tokens
1113
+ if "added_tokens_decoder" in init_kwargs:
1114
+ for idx, token in init_kwargs["added_tokens_decoder"].items():
1115
+ if isinstance(token, dict):
1116
+ token = AddedToken(**token)
1117
+ if isinstance(token, AddedToken):
1118
+ added_tokens_decoder[ast.literal_eval(idx)] = token
1119
+ added_tokens_map[str(token)] = token
1120
+ else:
1121
+ raise ValueError(
1122
+ f"Found a {token.__class__} in the saved `added_tokens_decoder`, should be a dictionary or an AddedToken instance"
1123
+ )
1124
+ else:
1125
+ # begin legacy: read the added_tokens_file and update kwargs with special_tokens_map if modified
1126
+ if special_tokens_map_file is not None:
1127
+ with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
1128
+ special_tokens_map = json.load(special_tokens_map_handle)
1129
+ for key, value in special_tokens_map.items():
1130
+ if key in kwargs and kwargs[key]:
1131
+ # This value has already been redefined by the kwargs
1132
+ # We keep this new value and ignore the one stored in the special_tokens_map_file
1133
+ continue
1134
+ if isinstance(value, dict):
1135
+ value = AddedToken(**value, special=True)
1136
+ elif key == "additional_special_tokens" and isinstance(value, list):
1137
+ additional_special_tokens = init_kwargs.pop("additional_special_tokens", []) or []
1138
+ for token in value:
1139
+ token = AddedToken(**token, special=True) if isinstance(token, dict) else token
1140
+ if token not in additional_special_tokens:
1141
+ additional_special_tokens.append(token)
1142
+ value = additional_special_tokens
1143
+ init_kwargs[key] = value
1144
+
1145
+ # slow -> slow|fast, legacy: convert the `"added_tokens.json"` file to `added_tokens_decoder`.
1146
+ # This is for legacy purposes. We don't add the tokens after init for efficiency.
1147
+ if added_tokens_file is not None:
1148
+ special_tokens = []
1149
+ for key in cls.SPECIAL_TOKENS_ATTRIBUTES & init_kwargs.keys():
1150
+ if init_kwargs[key] is not None:
1151
+ if key == "additional_special_tokens":
1152
+ special_tokens += [str(token) for token in init_kwargs[key]]
1153
+ else:
1154
+ special_tokens.append(str(init_kwargs[key]))
1155
+
1156
+ with open(added_tokens_file, encoding="utf-8") as added_tokens_handle:
1157
+ added_tok_encoder = json.load(added_tokens_handle)
1158
+ for str_token, index in added_tok_encoder.items():
1159
+ # if index not in added_tokens_decoder and str_token not in added_tokens_map:
1160
+ special = str_token in special_tokens
1161
+ added_tokens_decoder[index] = AddedToken(
1162
+ str_token, rstrip=False, lstrip=False, normalized=not special, special=special
1163
+ )
1164
+ added_tokens_map[str_token] = added_tokens_decoder[index]
1165
+
1166
+ # allows converting a fast -> slow: add the `tokenizer.json`'s `"added_tokens"` to the slow tokenizer
1167
+ # if `tokenizer_config.json` is `None`
1168
+ if tokenizer_file is not None:
1169
+ # This is for the slow tokenizer, so it can be done before instantiation.
1170
+ with open(tokenizer_file, encoding="utf-8") as tokenizer_file_handle:
1171
+ tokenizer_file_handle = json.load(tokenizer_file_handle)
1172
+ added_tokens = tokenizer_file_handle.pop("added_tokens")
1173
+ for serialized_tokens in added_tokens:
1174
+ idx = serialized_tokens.pop("id")
1175
+ added_tokens_decoder[idx] = AddedToken(**serialized_tokens)
1176
+ added_tokens_map[str(added_tokens_decoder[idx])] = added_tokens_decoder[idx]
1177
+ # end legacy
1178
+
1179
+ # Passing AddedTokens and not strings to the class to prevent it from casting the string to a different AddedToken
1180
+ # convert {'__type': 'AddedToken', 'content': '<ent>', 'lstrip': False, 'normalized': True, ...} to AddedTokens
1181
+ init_kwargs["added_tokens_decoder"] = added_tokens_decoder
1182
+ init_kwargs = cls.convert_added_tokens(init_kwargs, save=False)
1183
+ for key in cls.SPECIAL_TOKENS_ATTRIBUTES & init_kwargs.keys():
1184
+ if added_tokens_map != {} and init_kwargs[key] is not None:
1185
+ if key != "additional_special_tokens":
1186
+ init_kwargs[key] = added_tokens_map.get(str(init_kwargs[key]), init_kwargs[key])
1187
+
1188
+ # Instantiate the tokenizer.
1189
+ try:
1190
+ tokenizer = cls(*init_inputs, **init_kwargs)
1191
+ except OSError:
1192
+ raise OSError(
1193
+ "Unable to load vocabulary from file. "
1194
+ "Please check that the provided vocabulary is accessible and not corrupted."
1195
+ )
1196
+
1197
+ # if added_tokens_decoder != {} and max(list(added_tokens_decoder.keys())[-1], 0) > tokenizer.vocab_size:
1198
+ # logger.warning_advice(
1199
+ # "Special tokens have been added in the vocabulary, make sure the associated word embeddings are"
1200
+ # " fine-tuned or trained."
1201
+ # )
1202
+ return tokenizer
1203
+
1204
+
1205
+
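On the loading side, `_from_pretrained` has to undo the same encoding: the keys of `added_tokens_decoder` arrive as strings such as `"(192, 128, 64, 0)"` and are turned back into tuples with `ast.literal_eval`. A standalone sketch of that step (the sample keys mirror the config file below):

```python
import ast

# String keys as they appear in tokenizer_config.json.
serialized_ids = ["(192, 128, 64, 0)", "(193, 128, 64, 0)", "(194, 128, 64, 0)"]

# ast.literal_eval safely parses each string into a tuple of ints, exactly as
# `added_tokens_decoder[ast.literal_eval(idx)] = token` does above.
decoded_ids = [ast.literal_eval(key) for key in serialized_ids]
print(decoded_ids)  # [(192, 128, 64, 0), (193, 128, 64, 0), (194, 128, 64, 0)]
assert all(isinstance(i, tuple) and len(i) == 4 for i in decoded_ids)
```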
tokenizer_config.json ADDED
@@ -0,0 +1,546 @@
1
+ {
2
+ "added_tokens_decoder": {
3
+ "(192, 128, 64, 0)": {
4
+ "content": "<|pad|>",
5
+ "lstrip": false,
6
+ "normalized": true,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "(193, 128, 64, 0)": {
12
+ "content": "<|end_of_text|>",
13
+ "lstrip": false,
14
+ "normalized": true,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "(194, 128, 64, 0)": {
20
+ "content": "<|begin_of_text|>",
21
+ "lstrip": false,
22
+ "normalized": true,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "(195, 128, 64, 0)": {
28
+ "content": "<|cls|>",
29
+ "lstrip": false,
30
+ "normalized": true,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "(196, 128, 64, 0)": {
36
+ "content": "<|sep|>",
37
+ "lstrip": false,
38
+ "normalized": true,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ },
43
+ "(197, 128, 64, 0)": {
44
+ "content": "<|mask|>",
45
+ "lstrip": false,
46
+ "normalized": true,
47
+ "rstrip": false,
48
+ "single_word": false,
49
+ "special": true
50
+ },
51
+ "(198, 128, 64, 0)": {
52
+ "content": "<|vision_start|>",
53
+ "lstrip": false,
54
+ "normalized": true,
55
+ "rstrip": false,
56
+ "single_word": false,
57
+ "special": false
58
+ },
59
+ "(199, 128, 64, 0)": {
60
+ "content": "<|vision_br|>",
61
+ "lstrip": false,
62
+ "normalized": true,
63
+ "rstrip": false,
64
+ "single_word": false,
65
+ "special": false
66
+ },
67
+ "(200, 128, 64, 0)": {
68
+ "content": "<|vision_end|>",
69
+ "lstrip": false,
70
+ "normalized": true,
71
+ "rstrip": false,
72
+ "single_word": false,
73
+ "special": false
74
+ },
75
+ "(201, 128, 64, 0)": {
76
+ "content": "<|start_header_id|>",
77
+ "lstrip": false,
78
+ "normalized": true,
79
+ "rstrip": false,
80
+ "single_word": false,
81
+ "special": false
82
+ },
83
+ "(202, 128, 64, 0)": {
84
+ "content": "<|end_header_id|>",
85
+ "lstrip": false,
86
+ "normalized": true,
87
+ "rstrip": false,
88
+ "single_word": false,
89
+ "special": false
90
+ },
91
+ "(203, 128, 64, 0)": {
92
+ "content": "<|end_of_role|>",
93
+ "lstrip": false,
94
+ "normalized": true,
95
+ "rstrip": false,
96
+ "single_word": false,
97
+ "special": false
98
+ },
99
+ "(204, 128, 64, 0)": {
100
+ "content": "<|extra_id_0|>",
101
+ "lstrip": false,
102
+ "normalized": true,
103
+ "rstrip": false,
104
+ "single_word": false,
105
+ "special": false
106
+ },
107
+ "(205, 128, 64, 0)": {
108
+ "content": "<|extra_id_1|>",
109
+ "lstrip": false,
110
+ "normalized": true,
111
+ "rstrip": false,
112
+ "single_word": false,
113
+ "special": false
114
+ },
115
+ "(206, 128, 64, 0)": {
116
+ "content": "<|extra_id_2|>",
117
+ "lstrip": false,
118
+ "normalized": true,
119
+ "rstrip": false,
120
+ "single_word": false,
121
+ "special": false
122
+ },
123
+ "(207, 128, 64, 0)": {
124
+ "content": "<|extra_id_3|>",
125
+ "lstrip": false,
126
+ "normalized": true,
127
+ "rstrip": false,
128
+ "single_word": false,
129
+ "special": false
130
+ },
131
+ "(208, 128, 64, 0)": {
132
+ "content": "<|extra_id_4|>",
133
+ "lstrip": false,
134
+ "normalized": true,
135
+ "rstrip": false,
136
+ "single_word": false,
137
+ "special": false
138
+ },
139
+ "(209, 128, 64, 0)": {
140
+ "content": "<|extra_id_5|>",
141
+ "lstrip": false,
142
+ "normalized": true,
143
+ "rstrip": false,
144
+ "single_word": false,
145
+ "special": false
146
+ },
147
+ "(210, 128, 64, 0)": {
148
+ "content": "<|extra_id_6|>",
149
+ "lstrip": false,
150
+ "normalized": true,
151
+ "rstrip": false,
152
+ "single_word": false,
153
+ "special": false
154
+ },
155
+ "(211, 128, 64, 0)": {
156
+ "content": "<|extra_id_7|>",
157
+ "lstrip": false,
158
+ "normalized": true,
159
+ "rstrip": false,
160
+ "single_word": false,
161
+ "special": false
162
+ },
163
+ "(212, 128, 64, 0)": {
164
+ "content": "<|extra_id_8|>",
165
+ "lstrip": false,
166
+ "normalized": true,
167
+ "rstrip": false,
168
+ "single_word": false,
169
+ "special": false
170
+ },
171
+ "(213, 128, 64, 0)": {
172
+ "content": "<|extra_id_9|>",
173
+ "lstrip": false,
174
+ "normalized": true,
175
+ "rstrip": false,
176
+ "single_word": false,
177
+ "special": false
178
+ },
179
+ "(214, 128, 64, 0)": {
180
+ "content": "<|extra_id_10|>",
181
+ "lstrip": false,
182
+ "normalized": true,
183
+ "rstrip": false,
184
+ "single_word": false,
185
+ "special": false
186
+ },
187
+ "(215, 128, 64, 0)": {
188
+ "content": "<|extra_id_11|>",
189
+ "lstrip": false,
190
+ "normalized": true,
191
+ "rstrip": false,
192
+ "single_word": false,
193
+ "special": false
194
+ },
195
+ "(216, 128, 64, 0)": {
196
+ "content": "<|extra_id_12|>",
197
+ "lstrip": false,
198
+ "normalized": true,
199
+ "rstrip": false,
200
+ "single_word": false,
201
+ "special": false
202
+ },
203
+ "(217, 128, 64, 0)": {
204
+ "content": "<|extra_id_13|>",
205
+ "lstrip": false,
206
+ "normalized": true,
207
+ "rstrip": false,
208
+ "single_word": false,
209
+ "special": false
210
+ },
211
+ "(218, 128, 64, 0)": {
212
+ "content": "<|extra_id_14|>",
213
+ "lstrip": false,
214
+ "normalized": true,
215
+ "rstrip": false,
216
+ "single_word": false,
217
+ "special": false
218
+ },
219
+ "(219, 128, 64, 0)": {
220
+ "content": "<|extra_id_15|>",
221
+ "lstrip": false,
222
+ "normalized": true,
223
+ "rstrip": false,
224
+ "single_word": false,
225
+ "special": false
226
+ },
227
+ "(220, 128, 64, 0)": {
228
+ "content": "<|extra_id_16|>",
229
+ "lstrip": false,
230
+ "normalized": true,
231
+ "rstrip": false,
232
+ "single_word": false,
233
+ "special": false
234
+ },
235
+ "(221, 128, 64, 0)": {
236
+ "content": "<|extra_id_17|>",
237
+ "lstrip": false,
238
+ "normalized": true,
239
+ "rstrip": false,
240
+ "single_word": false,
241
+ "special": false
242
+ },
243
+ "(222, 128, 64, 0)": {
244
+ "content": "<|extra_id_18|>",
245
+ "lstrip": false,
246
+ "normalized": true,
247
+ "rstrip": false,
248
+ "single_word": false,
249
+ "special": false
250
+ },
251
+ "(223, 128, 64, 0)": {
252
+ "content": "<|extra_id_19|>",
253
+ "lstrip": false,
254
+ "normalized": true,
255
+ "rstrip": false,
256
+ "single_word": false,
257
+ "special": false
258
+ },
259
+ "(224, 128, 64, 0)": {
260
+ "content": "<|extra_id_20|>",
261
+ "lstrip": false,
262
+ "normalized": true,
263
+ "rstrip": false,
264
+ "single_word": false,
265
+ "special": false
266
+ },
267
+ "(225, 128, 64, 0)": {
268
+ "content": "<|extra_id_21|>",
269
+ "lstrip": false,
270
+ "normalized": true,
271
+ "rstrip": false,
272
+ "single_word": false,
273
+ "special": false
274
+ },
275
+ "(226, 128, 64, 0)": {
276
+ "content": "<|extra_id_22|>",
277
+ "lstrip": false,
278
+ "normalized": true,
279
+ "rstrip": false,
280
+ "single_word": false,
281
+ "special": false
282
+ },
283
+ "(227, 128, 64, 0)": {
284
+ "content": "<|extra_id_23|>",
285
+ "lstrip": false,
286
+ "normalized": true,
287
+ "rstrip": false,
288
+ "single_word": false,
289
+ "special": false
290
+ },
291
+ "(228, 128, 64, 0)": {
292
+ "content": "<|extra_id_24|>",
293
+ "lstrip": false,
294
+ "normalized": true,
295
+ "rstrip": false,
296
+ "single_word": false,
297
+ "special": false
298
+ },
299
+ "(229, 128, 64, 0)": {
300
+ "content": "<|extra_id_25|>",
301
+ "lstrip": false,
302
+ "normalized": true,
303
+ "rstrip": false,
304
+ "single_word": false,
305
+ "special": false
306
+ },
307
+ "(230, 128, 64, 0)": {
308
+ "content": "<|extra_id_26|>",
309
+ "lstrip": false,
310
+ "normalized": true,
311
+ "rstrip": false,
312
+ "single_word": false,
313
+ "special": false
314
+ },
315
+ "(231, 128, 64, 0)": {
316
+ "content": "<|extra_id_27|>",
317
+ "lstrip": false,
318
+ "normalized": true,
319
+ "rstrip": false,
320
+ "single_word": false,
321
+ "special": false
322
+ },
323
+ "(232, 128, 64, 0)": {
324
+ "content": "<|extra_id_28|>",
325
+ "lstrip": false,
326
+ "normalized": true,
327
+ "rstrip": false,
328
+ "single_word": false,
329
+ "special": false
330
+ },
331
+ "(233, 128, 64, 0)": {
332
+ "content": "<|extra_id_29|>",
333
+ "lstrip": false,
334
+ "normalized": true,
335
+ "rstrip": false,
336
+ "single_word": false,
337
+ "special": false
338
+ },
339
+ "(234, 128, 64, 0)": {
340
+ "content": "<|extra_id_30|>",
341
+ "lstrip": false,
342
+ "normalized": true,
343
+ "rstrip": false,
344
+ "single_word": false,
345
+ "special": false
346
+ },
347
+ "(235, 128, 64, 0)": {
348
+ "content": "<|extra_id_31|>",
349
+ "lstrip": false,
350
+ "normalized": true,
351
+ "rstrip": false,
352
+ "single_word": false,
353
+ "special": false
354
+ },
355
+ "(236, 128, 64, 0)": {
356
+ "content": "<|extra_id_32|>",
357
+ "lstrip": false,
358
+ "normalized": true,
359
+ "rstrip": false,
360
+ "single_word": false,
361
+ "special": false
362
+ },
363
+ "(237, 128, 64, 0)": {
364
+ "content": "<|extra_id_33|>",
365
+ "lstrip": false,
366
+ "normalized": true,
367
+ "rstrip": false,
368
+ "single_word": false,
369
+ "special": false
370
+ },
371
+ "(238, 128, 64, 0)": {
372
+ "content": "<|extra_id_34|>",
373
+ "lstrip": false,
374
+ "normalized": true,
375
+ "rstrip": false,
376
+ "single_word": false,
377
+ "special": false
378
+ },
379
+ "(239, 128, 64, 0)": {
380
+ "content": "<|extra_id_35|>",
381
+ "lstrip": false,
382
+ "normalized": true,
383
+ "rstrip": false,
384
+ "single_word": false,
385
+ "special": false
386
+ },
387
+ "(240, 128, 64, 0)": {
388
+ "content": "<|extra_id_36|>",
389
+ "lstrip": false,
390
+ "normalized": true,
391
+ "rstrip": false,
392
+ "single_word": false,
393
+ "special": false
394
+ },
395
+ "(241, 128, 64, 0)": {
396
+ "content": "<|extra_id_37|>",
397
+ "lstrip": false,
398
+ "normalized": true,
399
+ "rstrip": false,
400
+ "single_word": false,
401
+ "special": false
402
+ },
403
+ "(242, 128, 64, 0)": {
404
+ "content": "<|extra_id_38|>",
405
+ "lstrip": false,
406
+ "normalized": true,
407
+ "rstrip": false,
408
+ "single_word": false,
409
+ "special": false
410
+ },
411
+ "(243, 128, 64, 0)": {
412
+ "content": "<|extra_id_39|>",
413
+ "lstrip": false,
414
+ "normalized": true,
415
+ "rstrip": false,
416
+ "single_word": false,
417
+ "special": false
418
+ },
419
+ "(244, 128, 64, 0)": {
420
+ "content": "<|extra_id_40|>",
421
+ "lstrip": false,
422
+ "normalized": true,
423
+ "rstrip": false,
424
+ "single_word": false,
425
+ "special": false
426
+ },
427
+ "(245, 128, 64, 0)": {
428
+ "content": "<|extra_id_41|>",
429
+ "lstrip": false,
430
+ "normalized": true,
431
+ "rstrip": false,
432
+ "single_word": false,
433
+ "special": false
434
+ },
435
+ "(246, 128, 64, 0)": {
436
+ "content": "<|extra_id_42|>",
437
+ "lstrip": false,
438
+ "normalized": true,
439
+ "rstrip": false,
440
+ "single_word": false,
441
+ "special": false
442
+ },
443
+ "(247, 128, 64, 0)": {
444
+ "content": "<|extra_id_43|>",
445
+ "lstrip": false,
446
+ "normalized": true,
447
+ "rstrip": false,
448
+ "single_word": false,
449
+ "special": false
450
+ },
451
+ "(248, 128, 64, 0)": {
452
+ "content": "<|extra_id_44|>",
453
+ "lstrip": false,
454
+ "normalized": true,
455
+ "rstrip": false,
456
+ "single_word": false,
457
+ "special": false
458
+ },
459
+ "(249, 128, 64, 0)": {
460
+ "content": "<|extra_id_45|>",
461
+ "lstrip": false,
462
+ "normalized": true,
463
+ "rstrip": false,
464
+ "single_word": false,
465
+ "special": false
466
+ },
467
+ "(250, 128, 64, 0)": {
468
+ "content": "<|extra_id_46|>",
469
+ "lstrip": false,
470
+ "normalized": true,
471
+ "rstrip": false,
472
+ "single_word": false,
473
+ "special": false
474
+ }
475
+ },
476
+ "bos_token": "<|begin_of_text|>",
477
+ "clean_up_tokenization_spaces": true,
478
+ "cls_token": "<|cls|>",
479
+ "end_header_id_token": {
480
+ "__type": "AddedToken",
481
+ "content": "<|end_header_id|>",
482
+ "lstrip": false,
483
+ "normalized": true,
484
+ "rstrip": false,
485
+ "single_word": false,
486
+ "special": false
487
+ },
488
+ "eor_id": {
489
+ "__type": "AddedToken",
490
+ "content": "<|end_of_role|>",
491
+ "lstrip": false,
492
+ "normalized": true,
493
+ "rstrip": false,
494
+ "single_word": false,
495
+ "special": false
496
+ },
497
+ "eos_token": "<|end_of_text|>",
498
+ "mask_token": "<|mask|>",
499
+ "model_max_length": 1000000000000000019884624838656,
500
+ "pad_token": "<|pad|>",
501
+ "sep_token": "<|sep|>",
502
+ "start_header_id_token": {
503
+ "__type": "AddedToken",
504
+ "content": "<|start_header_id|>",
505
+ "lstrip": false,
506
+ "normalized": true,
507
+ "rstrip": false,
508
+ "single_word": false,
509
+ "special": false
510
+ },
511
+ "tokenizer_class": "ByteLMTokenizerV3",
512
+ "vision_br_token": {
513
+ "__type": "AddedToken",
514
+ "content": "<|vision_br|>",
515
+ "lstrip": false,
516
+ "normalized": true,
517
+ "rstrip": false,
518
+ "single_word": false,
519
+ "special": false
520
+ },
521
+ "vision_end_token": {
522
+ "__type": "AddedToken",
523
+ "content": "<|vision_end|>",
524
+ "lstrip": false,
525
+ "normalized": true,
526
+ "rstrip": false,
527
+ "single_word": false,
528
+ "special": false
529
+ },
530
+ "vision_start_token": {
531
+ "__type": "AddedToken",
532
+ "content": "<|vision_start|>",
533
+ "lstrip": false,
534
+ "normalized": true,
535
+ "rstrip": false,
536
+ "single_word": false,
537
+ "special": false
538
+ },
539
+ "auto_map": {
540
+ "AutoTokenizer": [
541
+ "tokenization_utf8_like_byte_v3.ByteLMTokenizerV3",
542
+ null
543
+ ]
544
+ },
545
+ "padding_side": "left"
546
+ }
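The `added_tokens_decoder` above keys every added token by its 4-byte id tuple; only the core control tokens (`<|pad|>`, `<|end_of_text|>`, `<|begin_of_text|>`, `<|cls|>`, `<|sep|>`, `<|mask|>`) are flagged `"special": true`, while the vision, header, and `<|extra_id_*|>` tokens are not. A small sketch that summarizes the file, assuming a local copy of `tokenizer_config.json` in the working directory:

```python
import ast
import json

with open("tokenizer_config.json", encoding="utf-8") as f:
    config = json.load(f)

# List every added token with its parsed 4-byte id and its "special" flag.
for key, entry in config["added_tokens_decoder"].items():
    token_id = ast.literal_eval(key)  # e.g. (192, 128, 64, 0)
    kind = "special" if entry["special"] else "regular"
    print(f"{entry['content']:<22} id={token_id} [{kind}]")

print("bos:", config["bos_token"], "eos:", config["eos_token"], "pad:", config["pad_token"])
```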