codefuse-admin commited on
Commit
5021525
1 Parent(s): b36ac6e

upload model from ant-group,[email protected]

Browse files
LICENSE.md ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright [2023] [Ant Group]
2
+ Licensed under the Apache License, Version 2.0 (the "License");
3
+ you may not use this file except in compliance with the License.
4
+ You may obtain a copy of the License at
5
+ http://www.apache.org/licenses/LICENSE-2.0
6
+
7
+ Unless required by applicable law or agreed to in writing, software
8
+ distributed under the License is distributed on an "AS IS" BASIS,
9
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ See the License for the specific language governing permissions and
11
+ limitations under the License.
12
+
13
+
14
+ Apache License
15
+ Version 2.0, January 2004
16
+ http://www.apache.org/licenses/
17
+
18
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
19
+
20
+ 1. Definitions.
21
+
22
+ "License" shall mean the terms and conditions for use, reproduction,
23
+ and distribution as defined by Sections 1 through 9 of this document.
24
+
25
+ "Licensor" shall mean the copyright owner or entity authorized by
26
+ the copyright owner that is granting the License.
27
+
28
+ "Legal Entity" shall mean the union of the acting entity and all
29
+ other entities that control, are controlled by, or are under common
30
+ control with that entity. For the purposes of this definition,
31
+ "control" means (i) the power, direct or indirect, to cause the
32
+ direction or management of such entity, whether by contract or
33
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
34
+ outstanding shares, or (iii) beneficial ownership of such entity.
35
+
36
+ "You" (or "Your") shall mean an individual or Legal Entity
37
+ exercising permissions granted by this License.
38
+
39
+ "Source" form shall mean the preferred form for making modifications,
40
+ including but not limited to software source code, documentation
41
+ source, and configuration files.
42
+
43
+ "Object" form shall mean any form resulting from mechanical
44
+ transformation or translation of a Source form, including but
45
+ not limited to compiled object code, generated documentation,
46
+ and conversions to other media types.
47
+
48
+ "Work" shall mean the work of authorship, whether in Source or
49
+ Object form, made available under the License, as indicated by a
50
+ copyright notice that is included in or attached to the work
51
+ (an example is provided in the Appendix below).
52
+
53
+ "Derivative Works" shall mean any work, whether in Source or Object
54
+ form, that is based on (or derived from) the Work and for which the
55
+ editorial revisions, annotations, elaborations, or other modifications
56
+ represent, as a whole, an original work of authorship. For the purposes
57
+ of this License, Derivative Works shall not include works that remain
58
+ separable from, or merely link (or bind by name) to the interfaces of,
59
+ the Work and Derivative Works thereof.
60
+
61
+ "Contribution" shall mean any work of authorship, including
62
+ the original version of the Work and any modifications or additions
63
+ to that Work or Derivative Works thereof, that is intentionally
64
+ submitted to Licensor for inclusion in the Work by the copyright owner
65
+ or by an individual or Legal Entity authorized to submit on behalf of
66
+ the copyright owner. For the purposes of this definition, "submitted"
67
+ means any form of electronic, verbal, or written communication sent
68
+ to the Licensor or its representatives, including but not limited to
69
+ communication on electronic mailing lists, source code control systems,
70
+ and issue tracking systems that are managed by, or on behalf of, the
71
+ Licensor for the purpose of discussing and improving the Work, but
72
+ excluding communication that is conspicuously marked or otherwise
73
+ designated in writing by the copyright owner as "Not a Contribution."
74
+
75
+ "Contributor" shall mean Licensor and any individual or Legal Entity
76
+ on behalf of whom a Contribution has been received by Licensor and
77
+ subsequently incorporated within the Work.
78
+
79
+ 2. Grant of Copyright License. Subject to the terms and conditions of
80
+ this License, each Contributor hereby grants to You a perpetual,
81
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
82
+ copyright license to reproduce, prepare Derivative Works of,
83
+ publicly display, publicly perform, sublicense, and distribute the
84
+ Work and such Derivative Works in Source or Object form.
85
+
86
+ 3. Grant of Patent License. Subject to the terms and conditions of
87
+ this License, each Contributor hereby grants to You a perpetual,
88
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
89
+ (except as stated in this section) patent license to make, have made,
90
+ use, offer to sell, sell, import, and otherwise transfer the Work,
91
+ where such license applies only to those patent claims licensable
92
+ by such Contributor that are necessarily infringed by their
93
+ Contribution(s) alone or by combination of their Contribution(s)
94
+ with the Work to which such Contribution(s) was submitted. If You
95
+ institute patent litigation against any entity (including a
96
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
97
+ or a Contribution incorporated within the Work constitutes direct
98
+ or contributory patent infringement, then any patent licenses
99
+ granted to You under this License for that Work shall terminate
100
+ as of the date such litigation is filed.
101
+
102
+ 4. Redistribution. You may reproduce and distribute copies of the
103
+ Work or Derivative Works thereof in any medium, with or without
104
+ modifications, and in Source or Object form, provided that You
105
+ meet the following conditions:
106
+
107
+ (a) You must give any other recipients of the Work or
108
+ Derivative Works a copy of this License; and
109
+
110
+ (b) You must cause any modified files to carry prominent notices
111
+ stating that You changed the files; and
112
+
113
+ (c) You must retain, in the Source form of any Derivative Works
114
+ that You distribute, all copyright, patent, trademark, and
115
+ attribution notices from the Source form of the Work,
116
+ excluding those notices that do not pertain to any part of
117
+ the Derivative Works; and
118
+
119
+ (d) If the Work includes a "NOTICE" text file as part of its
120
+ distribution, then any Derivative Works that You distribute must
121
+ include a readable copy of the attribution notices contained
122
+ within such NOTICE file, excluding those notices that do not
123
+ pertain to any part of the Derivative Works, in at least one
124
+ of the following places: within a NOTICE text file distributed
125
+ as part of the Derivative Works; within the Source form or
126
+ documentation, if provided along with the Derivative Works; or,
127
+ within a display generated by the Derivative Works, if and
128
+ wherever such third-party notices normally appear. The contents
129
+ of the NOTICE file are for informational purposes only and
130
+ do not modify the License. You may add Your own attribution
131
+ notices within Derivative Works that You distribute, alongside
132
+ or as an addendum to the NOTICE text from the Work, provided
133
+ that such additional attribution notices cannot be construed
134
+ as modifying the License.
135
+
136
+ You may add Your own copyright statement to Your modifications and
137
+ may provide additional or different license terms and conditions
138
+ for use, reproduction, or distribution of Your modifications, or
139
+ for any such Derivative Works as a whole, provided Your use,
140
+ reproduction, and distribution of the Work otherwise complies with
141
+ the conditions stated in this License.
142
+
143
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
144
+ any Contribution intentionally submitted for inclusion in the Work
145
+ by You to the Licensor shall be under the terms and conditions of
146
+ this License, without any additional terms or conditions.
147
+ Notwithstanding the above, nothing herein shall supersede or modify
148
+ the terms of any separate license agreement you may have executed
149
+ with Licensor regarding such Contributions.
150
+
151
+ 6. Trademarks. This License does not grant permission to use the trade
152
+ names, trademarks, service marks, or product names of the Licensor,
153
+ except as required for reasonable and customary use in describing the
154
+ origin of the Work and reproducing the content of the NOTICE file.
155
+
156
+ 7. Disclaimer of Warranty. Unless required by applicable law or
157
+ agreed to in writing, Licensor provides the Work (and each
158
+ Contributor provides its Contributions) on an "AS IS" BASIS,
159
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
160
+ implied, including, without limitation, any warranties or conditions
161
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
162
+ PARTICULAR PURPOSE. You are solely responsible for determining the
163
+ appropriateness of using or redistributing the Work and assume any
164
+ risks associated with Your exercise of permissions under this License.
165
+
166
+ 8. Limitation of Liability. In no event and under no legal theory,
167
+ whether in tort (including negligence), contract, or otherwise,
168
+ unless required by applicable law (such as deliberate and grossly
169
+ negligent acts) or agreed to in writing, shall any Contributor be
170
+ liable to You for damages, including any direct, indirect, special,
171
+ incidental, or consequential damages of any character arising as a
172
+ result of this License or out of the use or inability to use the
173
+ Work (including but not limited to damages for loss of goodwill,
174
+ work stoppage, computer failure or malfunction, or any and all
175
+ other commercial damages or losses), even if such Contributor
176
+ has been advised of the possibility of such damages.
177
+
178
+ 9. Accepting Warranty or Additional Liability. While redistributing
179
+ the Work or Derivative Works thereof, You may choose to offer,
180
+ and charge a fee for, acceptance of support, warranty, indemnity,
181
+ or other liability obligations and/or rights consistent with this
182
+ License. However, in accepting such obligations, You may act only
183
+ on Your own behalf and on Your sole responsibility, not on behalf
184
+ of any other Contributor, and only if You agree to indemnify,
185
+ defend, and hold each Contributor harmless for any liability
186
+ incurred by, or claims asserted against, such Contributor by reason
187
+ of your accepting any such warranty or additional liability.
188
+
189
+ END OF TERMS AND CONDITIONS
190
+
191
+ APPENDIX: How to apply the Apache License to your work.
192
+
193
+ To apply the Apache License to your work, attach the following
194
+ boilerplate notice, with the fields enclosed by brackets "[]"
195
+ replaced with your own identifying information. (Don't include
196
+ the brackets!) The text should be enclosed in the appropriate
197
+ comment syntax for the file format. We also recommend that a
198
+ file or class name and description of purpose be included on the
199
+ same "printed page" as the copyright notice for easier
200
+ identification within third-party archives.
201
+
202
+ Copyright [yyyy] [name of copyright owner]
203
+
204
+ Licensed under the Apache License, Version 2.0 (the "License");
205
+ you may not use this file except in compliance with the License.
206
+ You may obtain a copy of the License at
207
+
208
+ http://www.apache.org/licenses/LICENSE-2.0
209
+
210
+ Unless required by applicable law or agreed to in writing, software
211
+ distributed under the License is distributed on an "AS IS" BASIS,
212
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
213
+ See the License for the specific language governing permissions and
214
+ limitations under the License.
README.md CHANGED
@@ -1,5 +1,105 @@
1
- ---
2
- license: other
3
- license_name: license.md
4
- license_link: LICENSE
5
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+ <h1>
3
+ DevOps-Model-7B-Chat
4
+ </h1>
5
+ </div>
6
+
7
+ <p align="center">
8
+ 🤗 <a href="https://huggingface.co/codefuse-ai" target="_blank">Hugging Face</a> •
9
+ 🤖 <a href="https://modelscope.cn/organization/codefuse-ai" target="_blank">ModelScope</a>
10
+ </p>
11
+
12
+ DevOps-Model 是一个**开发运维大模型**,主要致力于在 DevOps 领域发挥实际价值。目前,DevOps-Model 能够帮助工程师回答在 DevOps 生命周期中遇到的问题。欢迎访问我们 Github 获取更多信息 [DevOps-Model](https://github.com/codefuse-ai/CodeFuse-DevOps-Model)
13
+
14
+ DevOps-Model-7B-Chat 是我们经过高质量 DevOps 语料训练基于 Qwen-7B 加训然后再经过对齐的 Chat 版本模型,我们的 Chat 模型在开源和 DevOps 领域相关的评测数据上可以取得同规模模型中的**最佳效果**。同时我们也开源了经过加训后的 [DevOps-Model-7B-Base](https://modelscope.cn/models/codefuse-ai/CodeFuse-DevOps-Model-7B-Base/summary) 模型,和 14B 参数量的[DevOps-Model-14B-Base](https://modelscope.cn/models/codefuse-ai/CodeFuse-DevOps-Model-14B-Base/summary) 和 [DevOps-Model-14B-Chat](https://modelscope.cn/models/codefuse-ai/CodeFuse-DevOps-Model-14B-Chat/summary) 。
15
+ <br>
16
+ 同时我们也在搭建 DevOps 领域专属的评测基准 [DevOpsEval](https://github.com/codefuse-ai/codefuse-devops-eval),用来更好评测 DevOps 领域模型的效果。
17
+
18
+ <br>
19
+ <br>
20
+
21
+ # 模型评测
22
+ 我们先选取了 CMMLU 和 CEval 两个评测数据集中和 DevOps 相关的一共六项考试。总计一共 574 道选择题,具体信息如下:
23
+
24
+ | 评测数据集 | 考试科目 | 题数 |
25
+ |-------|-------|-------|
26
+ | CMMLU | Computer science | 204 |
27
+ | CMMLU | Computer security | 171 |
28
+ | CMMLU | Machine learning | 122 |
29
+ | CEval | College programming | 37 |
30
+ | CEval | Computer architecture | 21 |
31
+ | CEval | Computernetwork | 19 |
32
+
33
+ 我们分别测试了 Zero-shot 和 Five-shot 的结果,我们的 DevOps-Model-7B-Chat 模型可以在测试的同规模的开源 Chat 模型中取得最高的成绩,后续我们也会进行更多的测试。
34
+
35
+ |模型|模型大小|Zero-shot 得分|Five-shot 得分|
36
+ |--|--|--|--|
37
+ |**DevOps-Model-7B-Chat**|**7B**|**62.20**|**64.11**|
38
+ |Qwen-7B-Chat|7B|46.00|52.44|
39
+ |Baichuan2-7B-Chat|7B|52.26|54.46|
40
+ |Internlm-7B-Chat|7B|52.61|55.75|
41
+
42
+
43
+ <br>
44
+
45
+ # 快速使用
46
+ 我们提供简单的示例来说明如何利用 🤗 Transformers 快速使用 Devops-Model-7B-Chat 模型
47
+
48
+ ## 要求
49
+ - python 3.8 及以上版本
50
+ - pytorch 2.0 及以上版本
51
+ - 建议使用CUDA 11.4及以上
52
+
53
+
54
+ ## 依赖项安装
55
+ 下载模型后,直接通过以下命令安装 requirements.txt 中的包就可以
56
+ ```bash
57
+ cd path_to_download_model
58
+ pip isntall -r requirements.txt
59
+ ```
60
+
61
+ ## 模型推理示例
62
+
63
+ ```python
64
+ from transformers import AutoModelForCausalLM, AutoTokenizer
65
+ from transformers.generation import GenerationConfig
66
+
67
+ tokenizer = AutoTokenizer.from_pretrained("path_to_DevOps-Model-7B-Chat", trust_remote_code=True)
68
+
69
+ model = AutoModelForCausalLM.from_pretrained("path_to_DevOps-Model-7B-Chat", device_map="auto", trust_remote_code=True, bf16=True).eval()
70
+
71
+ # 指定 generation_config
72
+ model.generation_config = GenerationConfig.from_pretrained("path_to_DevOps-Model-7B-Chat", trust_remote_code=True)
73
+
74
+ # 第一轮对话
75
+ resp, hist = model.chat(query='你是谁', tokenizer=tokenizer, history=None)
76
+ print(resp)
77
+ # 我是 DevOps-Model,一个由蚂蚁集团平台技术事业群风险智能团队和北京大学联合研发的人工智能机器人,可以与用户进行自然语言交互,并协助解答 DevOps 全生命周期中的各种问题。如果您有任何需要协助的问题或者想要进行闲聊,都可以和我交流哦。
78
+
79
+ # 第二轮对话
80
+ resp2, hist2 = model.chat(query='Java 中 HashMap 和 Hashtable 有什么区别', tokenizer=tokenizer, history=hist)
81
+ print(resp2)
82
+ # HashMap 和 Hashtable 都是 Java 中常用的哈希表实现,它们的主要区别在于:
83
+ # 1. Hashtable 是线程安全的,而 HashMap 不是线程安全的,因此在多线程环境下,Hashtable 的性能更稳定。
84
+ # 2. Hashtable 中的方法都是同步的,而 HashMap 的方法不是同步的,因此在多线程环境下,Hashtable 的性能更好。
85
+ # 3. Hashtable 中的 key 和 value 都必须实现 Serializable 接口,而 HashMap 中的 key 和 value 可以是任何对象,包括基本数据类型。
86
+ # 4. Hashtable 的初始容量是 11,而 HashMap 的初始容量是 16。
87
+ # 总之,如果需要在多线程环境下使用哈希表,并且需要保证线程安全,那么应该使用 Hashtable;如果不需要考虑线程安全,或者需要快速地进行哈希表操作,那么应该使用 HashMap。
88
+
89
+ # 第三轮对话
90
+ resp3, hist3 = model.chat(query='线程安全代表什么', tokenizer=tokenizer, history=hist2)
91
+ print(resp3)
92
+ # 线程安全是指在多线程环境下,程序能够正确地处理并发访问,并且��会出现数据竞争、死锁、饥饿等异常情况。线程安全的程序可以保证在不同的线程之间共享同一个数据结构时,数据的正确性和一致性。线程安全的实现通常需要使用同步机制,如锁、原子操作等,来保证对共享数据的访问是线程安全的。在 Java 中,可以通过 synchronized 关键字、Lock 接口等机制来实现线程安全。
93
+ ```
94
+
95
+
96
+
97
+ # 免责声明
98
+ 由于语言模型的特性,模型生成的内容可能包含幻觉或者歧视性言论。请谨慎使用 DevOps-Model 系列模型生成的内容。
99
+ 如果要公开使用或商用该模型服务,请注意服务方需承担由此产生的不良影响或有害言论的责任,本项目开发者不承担任何由使用本项目(包括但不限于数据、模型、代码等)导致的危害或损失。
100
+
101
+
102
+ # 致谢
103
+ 本项目参考了以下开源项目,在此对相关项目和研究开发人员表示感谢。
104
+ - [LLaMA-Efficient-Tuning](https://github.com/hiyouga/LLaMA-Efficient-Tuning)
105
+ - [Qwen-7B](https://github.com/QwenLM/Qwen-7B/tree/main)
config.json ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation": "swiglu",
3
+ "apply_residual_connection_post_layernorm": false,
4
+ "architectures": [
5
+ "QWenLMHeadModel"
6
+ ],
7
+ "attn_pdrop": 0.0,
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_qwen.QWenConfig",
10
+ "AutoModel": "modeling_qwen.QWenLMHeadModel",
11
+ "AutoModelForCausalLM": "modeling_qwen.QWenLMHeadModel"
12
+ },
13
+ "bf16": true,
14
+ "bias_dropout_fusion": true,
15
+ "bos_token_id": 151643,
16
+ "embd_pdrop": 0.0,
17
+ "eos_token_id": 151643,
18
+ "ffn_hidden_size": 22016,
19
+ "fp16": false,
20
+ "fp32": false,
21
+ "initializer_range": 0.02,
22
+ "kv_channels": 128,
23
+ "layer_norm_epsilon": 1e-06,
24
+ "model_type": "qwen",
25
+ "n_embd": 4096,
26
+ "n_head": 32,
27
+ "n_inner": null,
28
+ "n_layer": 32,
29
+ "n_positions": 6144,
30
+ "no_bias": true,
31
+ "onnx_safe": null,
32
+ "padded_vocab_size": 151936,
33
+ "params_dtype": "torch.bfloat16",
34
+ "pos_emb": "rotary",
35
+ "resid_pdrop": 0.1,
36
+ "rotary_emb_base": 10000,
37
+ "rotary_pct": 1.0,
38
+ "scale_attn_weights": true,
39
+ "seq_length": 2048,
40
+ "tie_word_embeddings": false,
41
+ "tokenizer_type": "QWenTokenizer",
42
+ "torch_dtype": "bfloat16",
43
+ "transformers_version": "4.32.0",
44
+ "use_cache": true,
45
+ "use_dynamic_ntk": true,
46
+ "use_flash_attn": true,
47
+ "use_logn_attn": true,
48
+ "vocab_size": 151936
49
+ }
configuration.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"framework":"Pytorch","task":"chatbot"}
configuration_qwen.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from transformers import PretrainedConfig
7
+
8
+
9
+ class QWenConfig(PretrainedConfig):
10
+ model_type = "qwen"
11
+ keys_to_ignore_at_inference = ["past_key_values"]
12
+ attribute_map = {
13
+ "hidden_size": "n_embd",
14
+ "num_attention_heads": "n_head",
15
+ "max_position_embeddings": "n_positions",
16
+ "num_hidden_layers": "n_layer",
17
+ }
18
+
19
+ def __init__(
20
+ self,
21
+ vocab_size=151851,
22
+ n_embd=4096,
23
+ n_layer=32,
24
+ n_head=32,
25
+ n_inner=None,
26
+ embd_pdrop=0.0,
27
+ attn_pdrop=0.0,
28
+ layer_norm_epsilon=1e-5,
29
+ initializer_range=0.02,
30
+ scale_attn_weights=True,
31
+ use_cache=True,
32
+ eos_token_id=151643,
33
+ apply_residual_connection_post_layernorm=False,
34
+ bf16=False,
35
+ fp16=False,
36
+ fp32=False,
37
+ kv_channels=128,
38
+ rotary_pct=1.0,
39
+ rotary_emb_base=10000,
40
+ use_dynamic_ntk=False,
41
+ use_logn_attn=False,
42
+ use_flash_attn=True,
43
+ ffn_hidden_size=22016,
44
+ no_bias=True,
45
+ tie_word_embeddings=False,
46
+ **kwargs,
47
+ ):
48
+ self.eos_token_id = eos_token_id
49
+ super().__init__(
50
+ eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
51
+ )
52
+
53
+ self.vocab_size = vocab_size
54
+ self.n_embd = n_embd
55
+ self.n_layer = n_layer
56
+ self.n_head = n_head
57
+ self.n_inner = n_inner
58
+ self.embd_pdrop = embd_pdrop
59
+ self.attn_pdrop = attn_pdrop
60
+ self.layer_norm_epsilon = layer_norm_epsilon
61
+ self.initializer_range = initializer_range
62
+ self.scale_attn_weights = scale_attn_weights
63
+ self.use_cache = use_cache
64
+ self.apply_residual_connection_post_layernorm = (
65
+ apply_residual_connection_post_layernorm
66
+ )
67
+ self.bf16 = bf16
68
+ self.fp16 = fp16
69
+ self.fp32 = fp32
70
+ self.kv_channels = kv_channels
71
+ self.rotary_pct = rotary_pct
72
+ self.rotary_emb_base = rotary_emb_base
73
+ self.use_dynamic_ntk = use_dynamic_ntk
74
+ self.use_logn_attn = use_logn_attn
75
+ self.use_flash_attn = use_flash_attn
76
+ self.ffn_hidden_size = ffn_hidden_size
77
+ self.no_bias = no_bias
78
+ self.tie_word_embeddings = tie_word_embeddings
generation_config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "chat_format": "chatml",
3
+ "eos_token_id": 151643,
4
+ "pad_token_id": 151643,
5
+ "max_window_size": 6144,
6
+ "max_new_tokens": 512,
7
+ "do_sample": true,
8
+ "top_k": 0,
9
+ "top_p": 0.5,
10
+ "transformers_version": "4.31.0"
11
+ }
modeling_qwen.py ADDED
@@ -0,0 +1,1219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import importlib
7
+ import math
8
+ from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint
13
+ from torch.cuda.amp import autocast
14
+
15
+ from torch.nn import CrossEntropyLoss
16
+ from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
17
+ from transformers.generation.logits_process import LogitsProcessorList
18
+
19
+ if TYPE_CHECKING:
20
+ from transformers.generation.streamers import BaseStreamer
21
+ from transformers.generation.utils import GenerateOutput
22
+ from transformers.modeling_outputs import (
23
+ BaseModelOutputWithPast,
24
+ CausalLMOutputWithPast,
25
+ )
26
+ from transformers.modeling_utils import PreTrainedModel
27
+ from transformers.utils import logging
28
+
29
+ try:
30
+ from einops import rearrange
31
+ except ImportError:
32
+ rearrange = None
33
+ from torch import nn
34
+
35
+ SUPPORT_CUDA = torch.cuda.is_available()
36
+ SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
37
+ SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
38
+
39
+ from .configuration_qwen import QWenConfig
40
+ from .qwen_generation_utils import (
41
+ HistoryType,
42
+ make_context,
43
+ decode_tokens,
44
+ get_stop_words_ids,
45
+ StopWordsLogitsProcessor,
46
+ )
47
+
48
+ # from loguru import logger
49
+ logger = logging.get_logger(__name__)
50
+
51
+ _CHECKPOINT_FOR_DOC = "qwen"
52
+ _CONFIG_FOR_DOC = "QWenConfig"
53
+
54
+ QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"]
55
+
56
+ _ERROR_BAD_CHAT_FORMAT = """\
57
+ We detect you are probably using the pretrained model (rather than chat model) for chatting, since the chat_format in generation_config is not "chatml".
58
+ If you are directly using the model downloaded from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat().
59
+ 我们检测到您可能在使用预训练模型(而非chat模型)进行多轮chat,因为您当前在generation_config指定的chat_format,并未设置为我们在对话中所支持的"chatml"格式。
60
+ 如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。
61
+ """
62
+
63
+ _SENTINEL = object()
64
+ _ERROR_STREAM_IN_CHAT = """\
65
+ Pass argument `stream` to model.chat() is buggy, deprecated, and marked for removal. Please use model.chat_stream(...) instead of model.chat(..., stream=True).
66
+ 向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。
67
+ """
68
+
69
+ apply_rotary_emb_func = None
70
+ rms_norm = None
71
+ flash_attn_unpadded_func = None
72
+
73
+
74
+ def _import_flash_attn():
75
+ global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func
76
+ try:
77
+ from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func
78
+ apply_rotary_emb_func = __apply_rotary_emb_func
79
+ print('Using flash_attn rope')
80
+ except ImportError:
81
+ logger.warn(
82
+ "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency "
83
+ "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary"
84
+ )
85
+
86
+ try:
87
+ from flash_attn.ops.rms_norm import rms_norm as __rms_norm
88
+ rms_norm = __rms_norm
89
+ print('Using flash_attn rms_norm')
90
+ except ImportError:
91
+ logger.warn(
92
+ "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency "
93
+ "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm"
94
+ )
95
+
96
+ try:
97
+ import flash_attn
98
+ if not hasattr(flash_attn, '__version__'):
99
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
100
+ else:
101
+ if int(flash_attn.__version__.split(".")[0]) >= 2:
102
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func
103
+ else:
104
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
105
+ flash_attn_unpadded_func = __flash_attn_unpadded_func
106
+
107
+ print('Using flash_attn attention func')
108
+ except ImportError:
109
+ logger.warn(
110
+ "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency "
111
+ "https://github.com/Dao-AILab/flash-attention"
112
+ )
113
+
114
+
115
+ class FlashSelfAttention(torch.nn.Module):
116
+ def __init__(
117
+ self,
118
+ causal=False,
119
+ softmax_scale=None,
120
+ attention_dropout=0.0,
121
+ ):
122
+ super().__init__()
123
+ assert flash_attn_unpadded_func is not None, (
124
+ "Please install FlashAttention first, " "e.g., with pip install flash-attn"
125
+ )
126
+ assert (
127
+ rearrange is not None
128
+ ), "Please install einops first, e.g., with pip install einops"
129
+ self.causal = causal
130
+ self.softmax_scale = softmax_scale
131
+ self.dropout_p = attention_dropout
132
+
133
+ def forward(self, q, k, v):
134
+ assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
135
+ assert all((i.is_cuda for i in (q, k, v)))
136
+ batch_size, seqlen_q = q.shape[0], q.shape[1]
137
+ seqlen_k = k.shape[1]
138
+ q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
139
+ cu_seqlens_q = torch.arange(
140
+ 0,
141
+ (batch_size + 1) * seqlen_q,
142
+ step=seqlen_q,
143
+ dtype=torch.int32,
144
+ device=q.device,
145
+ )
146
+
147
+ if self.training:
148
+ assert seqlen_k == seqlen_q
149
+
150
+ is_causal = self.causal
151
+ cu_seqlens_k = cu_seqlens_q
152
+ else:
153
+ is_causal = seqlen_q == seqlen_k
154
+ cu_seqlens_k = torch.arange(
155
+ 0,
156
+ (batch_size + 1) * seqlen_k,
157
+ step=seqlen_k,
158
+ dtype=torch.int32,
159
+ device=q.device,
160
+ )
161
+ self.dropout_p = 0
162
+ output = flash_attn_unpadded_func(
163
+ q,
164
+ k,
165
+ v,
166
+ cu_seqlens_q,
167
+ cu_seqlens_k,
168
+ seqlen_q,
169
+ seqlen_k,
170
+ self.dropout_p,
171
+ softmax_scale=self.softmax_scale,
172
+ causal=is_causal,
173
+ )
174
+
175
+ output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
176
+ return output
177
+
178
+
179
+ class QWenAttention(nn.Module):
180
+ def __init__(self, config, layer_number=None):
181
+ super().__init__()
182
+
183
+ max_positions = config.max_position_embeddings
184
+ self.register_buffer(
185
+ "bias",
186
+ torch.tril(
187
+ torch.ones((max_positions, max_positions), dtype=torch.bool)
188
+ ).view(1, 1, max_positions, max_positions),
189
+ persistent=False,
190
+ )
191
+ self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
192
+ self.layer_number = max(1, layer_number)
193
+ self.params_dtype = config.params_dtype
194
+ self.seq_length = config.seq_length
195
+
196
+ self.hidden_size = config.hidden_size
197
+ self.split_size = config.hidden_size
198
+ self.num_heads = config.num_attention_heads
199
+ self.head_dim = self.hidden_size // self.num_heads
200
+
201
+ self.use_flash_attn = config.use_flash_attn
202
+ self.scale_attn_weights = True
203
+
204
+ self.layer_idx = None
205
+
206
+ self.projection_size = config.kv_channels * config.num_attention_heads
207
+
208
+ assert self.projection_size % config.num_attention_heads == 0
209
+ self.hidden_size_per_attention_head = (
210
+ self.projection_size // config.num_attention_heads
211
+ )
212
+
213
+ self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)
214
+
215
+ self.c_proj = nn.Linear(
216
+ config.hidden_size, self.projection_size, bias=not config.no_bias
217
+ )
218
+
219
+ self.is_fp32 = not (config.bf16 or config.fp16)
220
+ if (
221
+ self.use_flash_attn
222
+ and flash_attn_unpadded_func is not None
223
+ and not self.is_fp32
224
+ ):
225
+ self.core_attention_flash = FlashSelfAttention(
226
+ causal=True, attention_dropout=config.attn_pdrop
227
+ )
228
+
229
+ self.bf16 = config.bf16
230
+
231
+ if config.rotary_pct == 1.0:
232
+ self.rotary_ndims = None
233
+ else:
234
+ assert config.rotary_pct < 1
235
+ self.rotary_ndims = int(
236
+ self.hidden_size_per_attention_head * config.rotary_pct
237
+ )
238
+ dim = (
239
+ self.rotary_ndims
240
+ if self.rotary_ndims is not None
241
+ else self.hidden_size_per_attention_head
242
+ )
243
+ self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
244
+
245
+ self.use_dynamic_ntk = config.use_dynamic_ntk
246
+ self.use_logn_attn = config.use_logn_attn
247
+
248
+ logn_list = [
249
+ math.log(i, self.seq_length) if i > self.seq_length else 1
250
+ for i in range(1, 32768)
251
+ ]
252
+ self.logn_tensor = torch.Tensor(logn_list)[None, :, None, None]
253
+ self._ntk_cached = 1.0
254
+
255
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
256
+
257
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
258
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
259
+
260
+ if self.scale_attn_weights:
261
+ attn_weights = attn_weights / torch.full(
262
+ [],
263
+ value.size(-1) ** 0.5,
264
+ dtype=attn_weights.dtype,
265
+ device=attn_weights.device,
266
+ )
267
+
268
+ query_length, key_length = query.size(-2), key.size(-2)
269
+ causal_mask = self.bias[
270
+ :, :, key_length - query_length : key_length, :key_length
271
+ ]
272
+ mask_value = torch.finfo(attn_weights.dtype).min
273
+ mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(
274
+ attn_weights.device
275
+ )
276
+ attn_weights = torch.where(
277
+ causal_mask, attn_weights.to(attn_weights.dtype), mask_value
278
+ )
279
+
280
+ if attention_mask is not None:
281
+ # Apply the attention mask
282
+ attn_weights = attn_weights + attention_mask
283
+
284
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
285
+
286
+ attn_weights = attn_weights.type(value.dtype)
287
+ attn_weights = self.attn_dropout(attn_weights)
288
+
289
+ if head_mask is not None:
290
+ attn_weights = attn_weights * head_mask
291
+
292
+ attn_output = torch.matmul(attn_weights, value)
293
+ attn_output = attn_output.transpose(1, 2)
294
+
295
+ return attn_output, attn_weights
296
+
297
+ def _upcast_and_reordered_attn(
298
+ self, query, key, value, attention_mask=None, head_mask=None
299
+ ):
300
+ bsz, num_heads, q_seq_len, dk = query.size()
301
+ _, _, k_seq_len, _ = key.size()
302
+
303
+ attn_weights = torch.empty(
304
+ bsz * num_heads,
305
+ q_seq_len,
306
+ k_seq_len,
307
+ dtype=torch.float32,
308
+ device=query.device,
309
+ )
310
+
311
+ scale_factor = 1.0
312
+ if self.scale_attn_weights:
313
+ scale_factor /= float(value.size(-1)) ** 0.5
314
+
315
+ with autocast(enabled=False):
316
+ q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(
317
+ -1, dk, k_seq_len
318
+ )
319
+ attn_weights = torch.baddbmm(
320
+ attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor
321
+ )
322
+ attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
323
+
324
+ query_length, key_length = query.size(-2), key.size(-2)
325
+ causal_mask = self.bias[
326
+ :, :, key_length - query_length : key_length, :key_length
327
+ ]
328
+ mask_value = torch.finfo(attn_weights.dtype).min
329
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
330
+ attn_weights.device
331
+ )
332
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
333
+
334
+ if attention_mask is not None:
335
+ attn_weights = attn_weights + attention_mask
336
+
337
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
338
+
339
+ if attn_weights.dtype != torch.float32:
340
+ raise RuntimeError(
341
+ "Error with upcasting, attn_weights does not have dtype torch.float32"
342
+ )
343
+ attn_weights = attn_weights.type(value.dtype)
344
+ attn_weights = self.attn_dropout(attn_weights)
345
+
346
+ if head_mask is not None:
347
+ attn_weights = attn_weights * head_mask
348
+
349
+ attn_output = torch.matmul(attn_weights, value)
350
+
351
+ return attn_output, attn_weights
352
+
353
+ def _split_heads(self, tensor, num_heads, attn_head_size):
354
+ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
355
+ tensor = tensor.view(new_shape)
356
+ return tensor
357
+
358
+ def _merge_heads(self, tensor, num_heads, attn_head_size):
359
+ tensor = tensor.contiguous()
360
+ new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
361
+ return tensor.view(new_shape)
362
+
363
+ def forward(
364
+ self,
365
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
366
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
367
+ attention_mask: Optional[torch.FloatTensor] = None,
368
+ head_mask: Optional[torch.FloatTensor] = None,
369
+ encoder_hidden_states: Optional[torch.Tensor] = None,
370
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
371
+ output_attentions: Optional[bool] = False,
372
+ use_cache: Optional[bool] = False,
373
+ ):
374
+
375
+ mixed_x_layer = self.c_attn(hidden_states)
376
+ query, key, value = mixed_x_layer.split(self.split_size, dim=2)
377
+
378
+ query = self._split_heads(query, self.num_heads, self.head_dim)
379
+ key = self._split_heads(key, self.num_heads, self.head_dim)
380
+ value = self._split_heads(value, self.num_heads, self.head_dim)
381
+
382
+ kv_seq_len = hidden_states.size()[1]
383
+ if layer_past:
384
+ # layer past[0] shape: bs * seq_len * head_num * dim
385
+ kv_seq_len += layer_past[0].shape[1]
386
+ if (
387
+ self.use_dynamic_ntk
388
+ and kv_seq_len == hidden_states.size()[1]
389
+ and not self.training
390
+ ):
391
+ context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
392
+ ntk_alpha = 2 ** math.ceil(context_value) - 1
393
+ ntk_alpha = max(ntk_alpha, 1)
394
+ self._ntk_cached = ntk_alpha
395
+ else:
396
+ ntk_alpha = self._ntk_cached
397
+ rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha).to(
398
+ hidden_states.device
399
+ )
400
+
401
+ if rotary_pos_emb is not None:
402
+ if isinstance(rotary_pos_emb, tuple):
403
+ rotary_pos_emb = rotary_pos_emb
404
+ else:
405
+ rotary_pos_emb = (rotary_pos_emb,) * 2
406
+
407
+ if rotary_pos_emb is not None:
408
+ q_pos_emb, k_pos_emb = rotary_pos_emb
409
+ # Slice the pos emb for current inference
410
+ cur_len = query.shape[1]
411
+ q_pos_emb = q_pos_emb[:, -cur_len:, :, :]
412
+ k_pos_emb = k_pos_emb[:, -cur_len:, :, :]
413
+ query = apply_rotary_pos_emb(query, q_pos_emb)
414
+ key = apply_rotary_pos_emb(key, k_pos_emb)
415
+
416
+ if layer_past is not None:
417
+ past_key, past_value = layer_past[0], layer_past[1]
418
+ key = torch.cat((past_key, key), dim=1)
419
+ value = torch.cat((past_value, value), dim=1)
420
+
421
+ if use_cache:
422
+ present = (key, value)
423
+ else:
424
+ present = None
425
+
426
+ if self.use_logn_attn and not self.training:
427
+ if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype:
428
+ self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
429
+ seq_start = key.size(1) - query.size(1)
430
+ seq_end = key.size(1)
431
+ logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
432
+ query = query * logn_tensor.expand_as(query)
433
+
434
+ if (
435
+ self.use_flash_attn
436
+ and flash_attn_unpadded_func is not None
437
+ and not self.is_fp32
438
+ and query.is_cuda
439
+ ):
440
+ q, k, v = query, key, value
441
+ context_layer = self.core_attention_flash(q, k, v)
442
+
443
+ context_layer = rearrange(
444
+ context_layer, "b s h d -> b s (h d)"
445
+ ).contiguous()
446
+ else:
447
+ query = query.permute(0, 2, 1, 3)
448
+ key = key.permute(0, 2, 1, 3)
449
+ value = value.permute(0, 2, 1, 3)
450
+ attn_output, attn_weight = self._attn(
451
+ query, key, value, attention_mask, head_mask
452
+ )
453
+ context_layer = self._merge_heads(
454
+ attn_output, self.num_heads, self.head_dim
455
+ )
456
+
457
+ attn_output = self.c_proj(context_layer)
458
+ outputs = (attn_output, present)
459
+ if output_attentions:
460
+ if (
461
+ self.use_flash_attn
462
+ and flash_attn_unpadded_func is not None
463
+ and not self.is_fp32
464
+ ):
465
+ raise ValueError("Cannot output attentions while using flash-attn")
466
+ else:
467
+ outputs += (attn_weight,)
468
+
469
+ return outputs
470
+
471
+
472
+ class QWenMLP(nn.Module):
473
+ def __init__(self, config):
474
+ super().__init__()
475
+ self.w1 = nn.Linear(
476
+ config.hidden_size, config.ffn_hidden_size // 2, bias=not config.no_bias
477
+ )
478
+ self.w2 = nn.Linear(
479
+ config.hidden_size, config.ffn_hidden_size // 2, bias=not config.no_bias
480
+ )
481
+ ff_dim_in = config.ffn_hidden_size // 2
482
+ self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)
483
+
484
+ def forward(self, hidden_states):
485
+ a1 = self.w1(hidden_states)
486
+ a2 = self.w2(hidden_states)
487
+ intermediate_parallel = a1 * F.silu(a2)
488
+ output = self.c_proj(intermediate_parallel)
489
+ return output
490
+
491
+
492
+ class QWenBlock(nn.Module):
493
+ def __init__(self, config, layer_idx=None, num_expert=1):
494
+ super().__init__()
495
+ self.num_expert = num_expert
496
+ self.layer_number = layer_idx
497
+ self.apply_residual_connection_post_layernorm = (
498
+ config.apply_residual_connection_post_layernorm
499
+ )
500
+ hidden_size = config.hidden_size
501
+ self.apply_residual_connection_post_layernorm = (
502
+ config.apply_residual_connection_post_layernorm
503
+ )
504
+ self.bf16 = config.bf16
505
+
506
+ self.ln_1 = RMSNorm(
507
+ hidden_size,
508
+ eps=config.layer_norm_epsilon,
509
+ )
510
+ self.attn = QWenAttention(config, layer_number=layer_idx)
511
+ self.ln_2 = RMSNorm(
512
+ hidden_size,
513
+ eps=config.layer_norm_epsilon,
514
+ )
515
+
516
+ self.mlp = QWenMLP(config)
517
+
518
+ def forward(
519
+ self,
520
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
521
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
522
+ attention_mask: Optional[torch.FloatTensor] = None,
523
+ head_mask: Optional[torch.FloatTensor] = None,
524
+ encoder_hidden_states: Optional[torch.Tensor] = None,
525
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
526
+ use_cache: Optional[bool] = False,
527
+ output_attentions: Optional[bool] = False,
528
+ ):
529
+ layernorm_output = self.ln_1(hidden_states)
530
+
531
+ attn_outputs = self.attn(
532
+ layernorm_output,
533
+ layer_past=layer_past,
534
+ attention_mask=attention_mask,
535
+ head_mask=head_mask,
536
+ use_cache=use_cache,
537
+ output_attentions=output_attentions,
538
+ )
539
+ attn_output = attn_outputs[0]
540
+
541
+ outputs = attn_outputs[1:]
542
+
543
+ if self.apply_residual_connection_post_layernorm:
544
+ residual = layernorm_output
545
+ else:
546
+ residual = hidden_states
547
+ layernorm_input = attn_output + residual
548
+
549
+ layernorm_output = self.ln_2(layernorm_input)
550
+
551
+ if self.apply_residual_connection_post_layernorm:
552
+ residual = layernorm_output
553
+ else:
554
+ residual = layernorm_input
555
+
556
+ mlp_output = self.mlp(layernorm_output)
557
+ hidden_states = residual + mlp_output
558
+
559
+ if use_cache:
560
+ outputs = (hidden_states,) + outputs
561
+ else:
562
+ outputs = (hidden_states,) + outputs[1:]
563
+
564
+ return outputs
565
+
566
+
567
+ class QWenPreTrainedModel(PreTrainedModel):
568
+ config_class = QWenConfig
569
+ base_model_prefix = "transformer"
570
+ is_parallelizable = False
571
+ supports_gradient_checkpointing = True
572
+ _no_split_modules = ["QWenBlock"]
573
+
574
+ def __init__(self, *inputs, **kwargs):
575
+ super().__init__(*inputs, **kwargs)
576
+
577
+ def _init_weights(self, module):
578
+ """Initialize the weights."""
579
+ if isinstance(module, nn.Linear):
580
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
581
+ if module.bias is not None:
582
+ module.bias.data.zero_()
583
+ elif isinstance(module, nn.Embedding):
584
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
585
+ if module.padding_idx is not None:
586
+ module.weight.data[module.padding_idx].zero_()
587
+ elif isinstance(module, RMSNorm):
588
+ module.weight.data.fill_(1.0)
589
+
590
+ for name, p in module.named_parameters():
591
+ if name == "c_proj.weight":
592
+ p.data.normal_(
593
+ mean=0.0,
594
+ std=(
595
+ self.config.initializer_range
596
+ / math.sqrt(2 * self.config.n_layer)
597
+ ),
598
+ )
599
+
600
+ def _set_gradient_checkpointing(self, module, value=False):
601
+ if isinstance(module, QWenModel):
602
+ module.gradient_checkpointing = value
603
+
604
+
605
+ class QWenModel(QWenPreTrainedModel):
606
+ _keys_to_ignore_on_load_missing = ["attn.masked_bias"]
607
+
608
+ def __init__(self, config):
609
+ super().__init__(config)
610
+ self.vocab_size = config.padded_vocab_size
611
+ self.num_hidden_layers = config.num_hidden_layers
612
+ self.embed_dim = config.hidden_size
613
+
614
+ max_sequence_length = config.max_position_embeddings
615
+ self.position_embedding_type = config.pos_emb
616
+ self.gradient_checkpointing = False
617
+
618
+ if self.position_embedding_type == "learned":
619
+ self.wpe = nn.Embedding(max_sequence_length, self.embed_dim)
620
+ self.init_method(self.position_embeddings.weight)
621
+ self._position_embeddings_key = "position_embeddings"
622
+ self.init_method(self.position_embeddings.weight)
623
+ else:
624
+ self.wpe = None
625
+ self._position_embeddings_key = ""
626
+
627
+ self.wte = nn.Embedding(self.vocab_size, self.embed_dim)
628
+
629
+ self.drop = nn.Dropout(config.embd_pdrop)
630
+ self.h = nn.ModuleList(
631
+ [
632
+ QWenBlock(
633
+ config,
634
+ layer_idx=i,
635
+ )
636
+ for i in range(config.num_hidden_layers)
637
+ ]
638
+ )
639
+ self.ln_f = RMSNorm(
640
+ self.embed_dim,
641
+ eps=config.layer_norm_epsilon,
642
+ )
643
+
644
+ self.post_init()
645
+
646
+ def get_input_embeddings(self):
647
+ return self.wte
648
+
649
+ def set_input_embeddings(self, new_embeddings):
650
+ self.wte = new_embeddings
651
+
652
+ def forward(
653
+ self,
654
+ input_ids: Optional[torch.LongTensor] = None,
655
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
656
+ attention_mask: Optional[torch.FloatTensor] = None,
657
+ token_type_ids: Optional[torch.LongTensor] = None,
658
+ position_ids: Optional[torch.LongTensor] = None,
659
+ head_mask: Optional[torch.FloatTensor] = None,
660
+ inputs_embeds: Optional[torch.FloatTensor] = None,
661
+ encoder_hidden_states: Optional[torch.Tensor] = None,
662
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
663
+ use_cache: Optional[bool] = None,
664
+ output_attentions: Optional[bool] = None,
665
+ output_hidden_states: Optional[bool] = None,
666
+ return_dict: Optional[bool] = None,
667
+ ):
668
+ output_attentions = (
669
+ output_attentions
670
+ if output_attentions is not None
671
+ else self.config.output_attentions
672
+ )
673
+ output_hidden_states = (
674
+ output_hidden_states
675
+ if output_hidden_states is not None
676
+ else self.config.output_hidden_states
677
+ )
678
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
679
+ return_dict = (
680
+ return_dict if return_dict is not None else self.config.use_return_dict
681
+ )
682
+
683
+ if input_ids is not None and inputs_embeds is not None:
684
+ raise ValueError(
685
+ "You cannot specify both input_ids and inputs_embeds at the same time"
686
+ )
687
+ elif input_ids is not None:
688
+ input_shape = input_ids.size()
689
+ input_ids = input_ids.view(-1, input_shape[-1])
690
+ batch_size = input_ids.shape[0]
691
+ elif inputs_embeds is not None:
692
+ input_shape = inputs_embeds.size()[:-1]
693
+ batch_size = inputs_embeds.shape[0]
694
+ else:
695
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
696
+
697
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
698
+
699
+ if token_type_ids is not None:
700
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
701
+ if position_ids is not None:
702
+ position_ids = position_ids.view(-1, input_shape[-1])
703
+
704
+ if past_key_values is None:
705
+ past_length = 0
706
+ past_key_values = tuple([None] * len(self.h))
707
+ else:
708
+ past_length = past_key_values[0][0].size(-2)
709
+
710
+ if position_ids is None:
711
+ position_ids = torch.arange(
712
+ past_length,
713
+ input_shape[-1] + past_length,
714
+ dtype=torch.long,
715
+ device=device,
716
+ )
717
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
718
+
719
+ if attention_mask is not None:
720
+ if batch_size <= 0:
721
+ raise ValueError("batch_size has to be defined and > 0")
722
+ attention_mask = attention_mask.view(batch_size, -1)
723
+ attention_mask = attention_mask[:, None, None, :]
724
+ attention_mask = attention_mask.to(dtype=self.dtype)
725
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
726
+ # attention_mask中mask掉的部分是-inf, 看到的部分是0
727
+
728
+ encoder_attention_mask = None
729
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
730
+
731
+ if inputs_embeds is None:
732
+ inputs_embeds = self.wte(input_ids)
733
+ hidden_states = inputs_embeds
734
+ if self.wpe is not None:
735
+ position_embeds = self.wpe(position_ids)
736
+ hidden_states = hidden_states + position_embeds
737
+
738
+ hidden_states = self.drop(hidden_states)
739
+ output_shape = input_shape + (hidden_states.size(-1),)
740
+
741
+ if self.gradient_checkpointing and self.training:
742
+ if use_cache:
743
+ logger.warning_once(
744
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
745
+ )
746
+ use_cache = False
747
+
748
+ presents = () if use_cache else None
749
+ all_self_attentions = () if output_attentions else None
750
+ all_hidden_states = () if output_hidden_states else None
751
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
752
+
753
+ if output_hidden_states:
754
+ all_hidden_states = all_hidden_states + (hidden_states,)
755
+
756
+ if self.gradient_checkpointing and self.training:
757
+
758
+ def create_custom_forward(module):
759
+ def custom_forward(*inputs):
760
+ # None for past_key_value
761
+ return module(*inputs, use_cache, output_attentions)
762
+
763
+ return custom_forward
764
+
765
+ outputs = torch.utils.checkpoint.checkpoint(
766
+ create_custom_forward(block),
767
+ hidden_states,
768
+ None,
769
+ attention_mask,
770
+ head_mask[i],
771
+ encoder_hidden_states,
772
+ encoder_attention_mask,
773
+ )
774
+ else:
775
+ outputs = block(
776
+ hidden_states,
777
+ layer_past=layer_past,
778
+ attention_mask=attention_mask,
779
+ head_mask=head_mask[i],
780
+ encoder_hidden_states=encoder_hidden_states,
781
+ encoder_attention_mask=encoder_attention_mask,
782
+ use_cache=use_cache,
783
+ output_attentions=output_attentions,
784
+ )
785
+
786
+ hidden_states = outputs[0]
787
+ if use_cache is True:
788
+ presents = presents + (outputs[2 if output_attentions else 1],)
789
+
790
+ if output_attentions:
791
+ all_self_attentions = all_self_attentions + (outputs[1],)
792
+
793
+ hidden_states = self.ln_f(hidden_states)
794
+ hidden_states = hidden_states.view(output_shape)
795
+
796
+ if not return_dict:
797
+ return tuple(
798
+ v for v in [hidden_states, presents, all_hidden_states] if v is not None
799
+ )
800
+
801
+ return BaseModelOutputWithPast(
802
+ last_hidden_state=hidden_states,
803
+ past_key_values=presents,
804
+ hidden_states=all_hidden_states,
805
+ attentions=all_self_attentions,
806
+ )
807
+
808
+
809
+ class QWenLMHeadModel(QWenPreTrainedModel):
810
+ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"]
811
+ _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"]
812
+
813
+ def __init__(self, config):
814
+ super().__init__(config)
815
+ assert (
816
+ config.bf16 + config.fp16 + config.fp32 <= 1
817
+ ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"
818
+
819
+ autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0
820
+
821
+ if autoset_precision:
822
+ if SUPPORT_BF16:
823
+ logger.warn(
824
+ "The model is automatically converting to bf16 for faster inference. "
825
+ "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
826
+ )
827
+ config.bf16 = True
828
+ elif SUPPORT_FP16:
829
+ logger.warn(
830
+ "The model is automatically converting to fp16 for faster inference. "
831
+ "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
832
+ )
833
+ config.fp16 = True
834
+ else:
835
+ config.fp32 = True
836
+
837
+ if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
838
+ logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
839
+ if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
840
+ logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster")
841
+ if config.fp32:
842
+ if SUPPORT_BF16:
843
+ logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
844
+ elif SUPPORT_FP16:
845
+ logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
846
+
847
+ if config.use_flash_attn == "auto":
848
+ if config.bf16 or config.fp16:
849
+ logger.warn("Try importing flash-attention for faster inference...")
850
+ config.use_flash_attn = True
851
+ else:
852
+ config.use_flash_attn = False
853
+ if config.use_flash_attn and config.fp32:
854
+ logger.warn("Flash attention will be disabled because it does NOT support fp32.")
855
+
856
+ if config.use_flash_attn:
857
+ _import_flash_attn()
858
+
859
+ self.transformer = QWenModel(config)
860
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
861
+
862
+ if config.bf16:
863
+ self.transformer.bfloat16()
864
+ self.lm_head.bfloat16()
865
+ if config.fp16:
866
+ self.transformer.half()
867
+ self.lm_head.half()
868
+ self.post_init()
869
+
870
+ def get_output_embeddings(self):
871
+ return self.lm_head
872
+
873
+ def set_output_embeddings(self, new_embeddings):
874
+ self.lm_head = new_embeddings
875
+
876
+ def prepare_inputs_for_generation(
877
+ self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
878
+ ):
879
+ token_type_ids = kwargs.get("token_type_ids", None)
880
+ if past_key_values:
881
+ input_ids = input_ids[:, -1].unsqueeze(-1)
882
+ if token_type_ids is not None:
883
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
884
+
885
+ attention_mask = kwargs.get("attention_mask", None)
886
+ position_ids = kwargs.get("position_ids", None)
887
+
888
+ if attention_mask is not None and position_ids is None:
889
+ position_ids = attention_mask.long().cumsum(-1) - 1
890
+ position_ids.masked_fill_(attention_mask == 0, 1)
891
+ if past_key_values:
892
+ position_ids = position_ids[:, -1].unsqueeze(-1)
893
+ else:
894
+ position_ids = None
895
+
896
+ if inputs_embeds is not None and past_key_values is None:
897
+ model_inputs = {"inputs_embeds": inputs_embeds}
898
+ else:
899
+ model_inputs = {"input_ids": input_ids}
900
+
901
+ model_inputs.update(
902
+ {
903
+ "past_key_values": past_key_values,
904
+ "use_cache": kwargs.get("use_cache"),
905
+ "position_ids": position_ids,
906
+ "attention_mask": attention_mask,
907
+ "token_type_ids": token_type_ids,
908
+ }
909
+ )
910
+ return model_inputs
911
+
912
+ def forward(
913
+ self,
914
+ input_ids: Optional[torch.LongTensor] = None,
915
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
916
+ attention_mask: Optional[torch.FloatTensor] = None,
917
+ token_type_ids: Optional[torch.LongTensor] = None,
918
+ position_ids: Optional[torch.LongTensor] = None,
919
+ head_mask: Optional[torch.FloatTensor] = None,
920
+ inputs_embeds: Optional[torch.FloatTensor] = None,
921
+ encoder_hidden_states: Optional[torch.Tensor] = None,
922
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
923
+ labels: Optional[torch.LongTensor] = None,
924
+ use_cache: Optional[bool] = None,
925
+ output_attentions: Optional[bool] = None,
926
+ output_hidden_states: Optional[bool] = None,
927
+ return_dict: Optional[bool] = None,
928
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
929
+
930
+ return_dict = (
931
+ return_dict if return_dict is not None else self.config.use_return_dict
932
+ )
933
+
934
+ transformer_outputs = self.transformer(
935
+ input_ids,
936
+ past_key_values=past_key_values,
937
+ attention_mask=attention_mask,
938
+ token_type_ids=token_type_ids,
939
+ position_ids=position_ids,
940
+ head_mask=head_mask,
941
+ inputs_embeds=inputs_embeds,
942
+ encoder_hidden_states=encoder_hidden_states,
943
+ encoder_attention_mask=encoder_attention_mask,
944
+ use_cache=use_cache,
945
+ output_attentions=output_attentions,
946
+ output_hidden_states=output_hidden_states,
947
+ return_dict=return_dict,
948
+ )
949
+ hidden_states = transformer_outputs[0]
950
+
951
+ lm_logits = self.lm_head(hidden_states)
952
+
953
+ loss = None
954
+ if labels is not None:
955
+ labels = labels.to(lm_logits.device)
956
+ shift_logits = lm_logits[..., :-1, :].contiguous()
957
+ shift_labels = labels[..., 1:].contiguous()
958
+ loss_fct = CrossEntropyLoss()
959
+ loss = loss_fct(
960
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
961
+ )
962
+
963
+ if not return_dict:
964
+ output = (lm_logits,) + transformer_outputs[1:]
965
+ return ((loss,) + output) if loss is not None else output
966
+
967
+ return CausalLMOutputWithPast(
968
+ loss=loss,
969
+ logits=lm_logits,
970
+ past_key_values=transformer_outputs.past_key_values,
971
+ hidden_states=transformer_outputs.hidden_states,
972
+ attentions=transformer_outputs.attentions,
973
+ )
974
+
975
+ @staticmethod
976
+ def _reorder_cache(
977
+ past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
978
+ ) -> Tuple[Tuple[torch.Tensor]]:
979
+
980
+ return tuple(
981
+ tuple(
982
+ past_state.index_select(0, beam_idx.to(past_state.device))
983
+ for past_state in layer_past
984
+ )
985
+ for layer_past in past_key_values
986
+ )
987
+
988
+ def chat(
989
+ self,
990
+ tokenizer: PreTrainedTokenizer,
991
+ query: str,
992
+ history: Optional[HistoryType],
993
+ system: str = "You are a helpful assistant.",
994
+ append_history: bool = True,
995
+ stream: Optional[bool] = _SENTINEL,
996
+ stop_words_ids: Optional[List[List[int]]] = None,
997
+ **kwargs,
998
+ ) -> Tuple[str, HistoryType]:
999
+ assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
1000
+ assert self.generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
1001
+ if history is None:
1002
+ history = []
1003
+ if stop_words_ids is None:
1004
+ stop_words_ids = []
1005
+
1006
+ raw_text, context_tokens = make_context(
1007
+ tokenizer,
1008
+ query,
1009
+ history=history,
1010
+ system=system,
1011
+ max_window_size=6144,
1012
+ chat_format=self.generation_config.chat_format,
1013
+ )
1014
+
1015
+ stop_words_ids.extend(get_stop_words_ids(
1016
+ self.generation_config.chat_format, tokenizer
1017
+ ))
1018
+ input_ids = torch.tensor([context_tokens]).to(self.device)
1019
+ outputs = self.generate(
1020
+ input_ids,
1021
+ stop_words_ids = stop_words_ids,
1022
+ return_dict_in_generate = False,
1023
+ **kwargs,
1024
+ )
1025
+
1026
+ response = decode_tokens(
1027
+ outputs[0],
1028
+ tokenizer,
1029
+ raw_text_len=len(raw_text),
1030
+ context_length=len(context_tokens),
1031
+ chat_format=self.generation_config.chat_format,
1032
+ verbose=False,
1033
+ errors='replace'
1034
+ )
1035
+
1036
+ if append_history:
1037
+ history.append((query, response))
1038
+
1039
+ return response, history
1040
+
1041
+ def chat_stream(
1042
+ self,
1043
+ tokenizer: PreTrainedTokenizer,
1044
+ query: str,
1045
+ history: Optional[HistoryType],
1046
+ system: str = "You are a helpful assistant.",
1047
+ stop_words_ids: Optional[List[List[int]]] = None,
1048
+ logits_processor: Optional[LogitsProcessorList] = None,
1049
+ **kwargs,
1050
+ ) -> Generator[str, Any, None]:
1051
+ assert self.generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
1052
+ if history is None:
1053
+ history = []
1054
+ if stop_words_ids is None:
1055
+ stop_words_ids = []
1056
+
1057
+ raw_text, context_tokens = make_context(
1058
+ tokenizer,
1059
+ query,
1060
+ history=history,
1061
+ system=system,
1062
+ max_window_size=6144,
1063
+ chat_format=self.generation_config.chat_format,
1064
+ )
1065
+
1066
+ stop_words_ids.extend(get_stop_words_ids(
1067
+ self.generation_config.chat_format, tokenizer
1068
+ ))
1069
+ if stop_words_ids is not None:
1070
+ stop_words_logits_processor = StopWordsLogitsProcessor(
1071
+ stop_words_ids=stop_words_ids,
1072
+ eos_token_id=self.generation_config.eos_token_id,
1073
+ )
1074
+ if logits_processor is None:
1075
+ logits_processor = LogitsProcessorList([stop_words_logits_processor])
1076
+ else:
1077
+ logits_processor.append(stop_words_logits_processor)
1078
+ input_ids = torch.tensor([context_tokens]).to(self.device)
1079
+
1080
+ from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
1081
+ self.__class__.generate_stream = NewGenerationMixin.generate
1082
+ self.__class__.sample_stream = NewGenerationMixin.sample_stream
1083
+ stream_config = StreamGenerationConfig(**self.generation_config.to_dict(), do_stream=True)
1084
+ def stream_generator():
1085
+ outputs = []
1086
+ for token in self.generate_stream(
1087
+ input_ids,
1088
+ return_dict_in_generate=False,
1089
+ generation_config=stream_config,
1090
+ logits_processor=logits_processor,
1091
+ seed=-1,
1092
+ **kwargs):
1093
+ outputs.append(token.item())
1094
+ yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore')
1095
+
1096
+ return stream_generator()
1097
+
1098
+ def generate(
1099
+ self,
1100
+ inputs: Optional[torch.Tensor] = None,
1101
+ generation_config: Optional[GenerationConfig] = None,
1102
+ logits_processor: Optional[LogitsProcessorList] = None,
1103
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
1104
+ prefix_allowed_tokens_fn: Optional[
1105
+ Callable[[int, torch.Tensor], List[int]]
1106
+ ] = None,
1107
+ synced_gpus: Optional[bool] = None,
1108
+ assistant_model: Optional["PreTrainedModel"] = None,
1109
+ streamer: Optional["BaseStreamer"] = None,
1110
+ **kwargs,
1111
+ ) -> Union[GenerateOutput, torch.LongTensor]:
1112
+ # Process stop_words_ids.
1113
+ stop_words_ids = kwargs.pop("stop_words_ids", None)
1114
+ if stop_words_ids is None and generation_config is not None:
1115
+ stop_words_ids = getattr(generation_config, "stop_words_ids", None)
1116
+ if stop_words_ids is None:
1117
+ stop_words_ids = getattr(self.generation_config, "stop_words_ids", None)
1118
+
1119
+ if stop_words_ids is not None:
1120
+ stop_words_logits_processor = StopWordsLogitsProcessor(
1121
+ stop_words_ids=stop_words_ids,
1122
+ eos_token_id=self.generation_config.eos_token_id,
1123
+ )
1124
+ if logits_processor is None:
1125
+ logits_processor = LogitsProcessorList([stop_words_logits_processor])
1126
+ else:
1127
+ logits_processor.append(stop_words_logits_processor)
1128
+
1129
+ return super().generate(
1130
+ inputs,
1131
+ generation_config=generation_config,
1132
+ logits_processor=logits_processor,
1133
+ stopping_criteria=stopping_criteria,
1134
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
1135
+ synced_gpus=synced_gpus,
1136
+ assistant_model=assistant_model,
1137
+ streamer=streamer,
1138
+ **kwargs,
1139
+ )
1140
+
1141
+
1142
+ class RotaryEmbedding(torch.nn.Module):
1143
+ def __init__(self, dim, base=10000):
1144
+ super().__init__()
1145
+ self.dim = dim
1146
+ self.base = base
1147
+ self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
1148
+ if importlib.util.find_spec("einops") is None:
1149
+ raise RuntimeError("einops is required for Rotary Embedding")
1150
+
1151
+ self._rotary_pos_emb_cache = None
1152
+ self._seq_len_cached = 0
1153
+ self._ntk_alpha_cached = 1.0
1154
+
1155
+ def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
1156
+ seqlen = max_seq_len + offset
1157
+ if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
1158
+ base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
1159
+ self.inv_freq = 1.0 / (
1160
+ base
1161
+ ** (
1162
+ torch.arange(0, self.dim, 2, device=self.inv_freq.device).float()
1163
+ / self.dim
1164
+ )
1165
+ )
1166
+ self._seq_len_cached = max(2 * seqlen, 16)
1167
+ self._ntk_alpha_cached = ntk_alpha
1168
+ seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device)
1169
+ freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
1170
+ emb = torch.cat((freqs, freqs), dim=-1)
1171
+ from einops import rearrange
1172
+
1173
+ self._rotary_pos_emb_cache = rearrange(emb, "n d -> 1 n 1 d")
1174
+
1175
+ def forward(self, max_seq_len, offset=0, ntk_alpha=1.0):
1176
+ self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha)
1177
+ return self._rotary_pos_emb_cache[:, offset : offset + max_seq_len]
1178
+
1179
+
1180
+ def _rotate_half(x):
1181
+ from einops import rearrange
1182
+
1183
+ x = rearrange(x, "... (j d) -> ... j d", j=2)
1184
+ x1, x2 = x.unbind(dim=-2)
1185
+ return torch.cat((-x2, x1), dim=-1)
1186
+
1187
+
1188
+ def apply_rotary_pos_emb(t, freqs):
1189
+ if apply_rotary_emb_func is not None:
1190
+ t_ = t.float()
1191
+ freqs = freqs.squeeze(0).squeeze(1)
1192
+ cos = freqs[:, : freqs.shape[-1] // 2].cos()
1193
+ sin = freqs[:, : freqs.shape[-1] // 2].sin()
1194
+ output = apply_rotary_emb_func(t_, cos, sin).type_as(t)
1195
+ return output
1196
+ else:
1197
+ rot_dim = freqs.shape[-1]
1198
+ t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
1199
+ t_ = t_.float()
1200
+ t_pass_ = t_pass_.float()
1201
+ t_ = (t_ * freqs.cos()) + (_rotate_half(t_) * freqs.sin())
1202
+ return torch.cat((t_, t_pass_), dim=-1).type_as(t)
1203
+
1204
+
1205
+ class RMSNorm(torch.nn.Module):
1206
+ def __init__(self, dim: int, eps: float = 1e-6):
1207
+ super().__init__()
1208
+ self.eps = eps
1209
+ self.weight = nn.Parameter(torch.ones(dim))
1210
+
1211
+ def _norm(self, x):
1212
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
1213
+
1214
+ def forward(self, x):
1215
+ if rms_norm is not None and x.is_cuda:
1216
+ return rms_norm(x, self.weight, self.eps)
1217
+ else:
1218
+ output = self._norm(x.float()).type_as(x)
1219
+ return output * self.weight
pytorch_model-00001-of-00002.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5e088e183a6a7bb4475e83172d9442d448d235333ed67a2b505fa805d44b3de5
3
+ size 9969772092
pytorch_model-00002-of-00002.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e0f5d211ab6c42228c2a5712c11fe0616ef6cafe22178e8dc8ce867dc9c57e34
3
+ size 5472963479
pytorch_model.bin.index.json ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 15442649088
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "pytorch_model-00002-of-00002.bin",
7
+ "transformer.h.0.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
8
+ "transformer.h.0.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
9
+ "transformer.h.0.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
10
+ "transformer.h.0.ln_1.weight": "pytorch_model-00001-of-00002.bin",
11
+ "transformer.h.0.ln_2.weight": "pytorch_model-00001-of-00002.bin",
12
+ "transformer.h.0.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
13
+ "transformer.h.0.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
14
+ "transformer.h.0.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
15
+ "transformer.h.1.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
16
+ "transformer.h.1.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
17
+ "transformer.h.1.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
18
+ "transformer.h.1.ln_1.weight": "pytorch_model-00001-of-00002.bin",
19
+ "transformer.h.1.ln_2.weight": "pytorch_model-00001-of-00002.bin",
20
+ "transformer.h.1.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
21
+ "transformer.h.1.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
22
+ "transformer.h.1.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
23
+ "transformer.h.10.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
24
+ "transformer.h.10.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
25
+ "transformer.h.10.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
26
+ "transformer.h.10.ln_1.weight": "pytorch_model-00001-of-00002.bin",
27
+ "transformer.h.10.ln_2.weight": "pytorch_model-00001-of-00002.bin",
28
+ "transformer.h.10.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
29
+ "transformer.h.10.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
30
+ "transformer.h.10.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
31
+ "transformer.h.11.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
32
+ "transformer.h.11.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
33
+ "transformer.h.11.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
34
+ "transformer.h.11.ln_1.weight": "pytorch_model-00001-of-00002.bin",
35
+ "transformer.h.11.ln_2.weight": "pytorch_model-00001-of-00002.bin",
36
+ "transformer.h.11.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
37
+ "transformer.h.11.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
38
+ "transformer.h.11.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
39
+ "transformer.h.12.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
40
+ "transformer.h.12.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
41
+ "transformer.h.12.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
42
+ "transformer.h.12.ln_1.weight": "pytorch_model-00001-of-00002.bin",
43
+ "transformer.h.12.ln_2.weight": "pytorch_model-00001-of-00002.bin",
44
+ "transformer.h.12.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
45
+ "transformer.h.12.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
46
+ "transformer.h.12.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
47
+ "transformer.h.13.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
48
+ "transformer.h.13.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
49
+ "transformer.h.13.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
50
+ "transformer.h.13.ln_1.weight": "pytorch_model-00001-of-00002.bin",
51
+ "transformer.h.13.ln_2.weight": "pytorch_model-00001-of-00002.bin",
52
+ "transformer.h.13.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
53
+ "transformer.h.13.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
54
+ "transformer.h.13.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
55
+ "transformer.h.14.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
56
+ "transformer.h.14.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
57
+ "transformer.h.14.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
58
+ "transformer.h.14.ln_1.weight": "pytorch_model-00001-of-00002.bin",
59
+ "transformer.h.14.ln_2.weight": "pytorch_model-00001-of-00002.bin",
60
+ "transformer.h.14.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
61
+ "transformer.h.14.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
62
+ "transformer.h.14.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
63
+ "transformer.h.15.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
64
+ "transformer.h.15.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
65
+ "transformer.h.15.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
66
+ "transformer.h.15.ln_1.weight": "pytorch_model-00001-of-00002.bin",
67
+ "transformer.h.15.ln_2.weight": "pytorch_model-00001-of-00002.bin",
68
+ "transformer.h.15.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
69
+ "transformer.h.15.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
70
+ "transformer.h.15.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
71
+ "transformer.h.16.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
72
+ "transformer.h.16.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
73
+ "transformer.h.16.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
74
+ "transformer.h.16.ln_1.weight": "pytorch_model-00001-of-00002.bin",
75
+ "transformer.h.16.ln_2.weight": "pytorch_model-00001-of-00002.bin",
76
+ "transformer.h.16.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
77
+ "transformer.h.16.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
78
+ "transformer.h.16.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
79
+ "transformer.h.17.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
80
+ "transformer.h.17.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
81
+ "transformer.h.17.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
82
+ "transformer.h.17.ln_1.weight": "pytorch_model-00001-of-00002.bin",
83
+ "transformer.h.17.ln_2.weight": "pytorch_model-00001-of-00002.bin",
84
+ "transformer.h.17.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
85
+ "transformer.h.17.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
86
+ "transformer.h.17.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
87
+ "transformer.h.18.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
88
+ "transformer.h.18.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
89
+ "transformer.h.18.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
90
+ "transformer.h.18.ln_1.weight": "pytorch_model-00001-of-00002.bin",
91
+ "transformer.h.18.ln_2.weight": "pytorch_model-00001-of-00002.bin",
92
+ "transformer.h.18.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
93
+ "transformer.h.18.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
94
+ "transformer.h.18.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
95
+ "transformer.h.19.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
96
+ "transformer.h.19.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
97
+ "transformer.h.19.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
98
+ "transformer.h.19.ln_1.weight": "pytorch_model-00001-of-00002.bin",
99
+ "transformer.h.19.ln_2.weight": "pytorch_model-00001-of-00002.bin",
100
+ "transformer.h.19.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
101
+ "transformer.h.19.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
102
+ "transformer.h.19.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
103
+ "transformer.h.2.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
104
+ "transformer.h.2.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
105
+ "transformer.h.2.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
106
+ "transformer.h.2.ln_1.weight": "pytorch_model-00001-of-00002.bin",
107
+ "transformer.h.2.ln_2.weight": "pytorch_model-00001-of-00002.bin",
108
+ "transformer.h.2.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
109
+ "transformer.h.2.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
110
+ "transformer.h.2.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
111
+ "transformer.h.20.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
112
+ "transformer.h.20.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
113
+ "transformer.h.20.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
114
+ "transformer.h.20.ln_1.weight": "pytorch_model-00001-of-00002.bin",
115
+ "transformer.h.20.ln_2.weight": "pytorch_model-00001-of-00002.bin",
116
+ "transformer.h.20.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
117
+ "transformer.h.20.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
118
+ "transformer.h.20.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
119
+ "transformer.h.21.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
120
+ "transformer.h.21.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
121
+ "transformer.h.21.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
122
+ "transformer.h.21.ln_1.weight": "pytorch_model-00001-of-00002.bin",
123
+ "transformer.h.21.ln_2.weight": "pytorch_model-00001-of-00002.bin",
124
+ "transformer.h.21.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
125
+ "transformer.h.21.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
126
+ "transformer.h.21.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
127
+ "transformer.h.22.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
128
+ "transformer.h.22.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
129
+ "transformer.h.22.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
130
+ "transformer.h.22.ln_1.weight": "pytorch_model-00002-of-00002.bin",
131
+ "transformer.h.22.ln_2.weight": "pytorch_model-00002-of-00002.bin",
132
+ "transformer.h.22.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
133
+ "transformer.h.22.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
134
+ "transformer.h.22.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
135
+ "transformer.h.23.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
136
+ "transformer.h.23.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
137
+ "transformer.h.23.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
138
+ "transformer.h.23.ln_1.weight": "pytorch_model-00002-of-00002.bin",
139
+ "transformer.h.23.ln_2.weight": "pytorch_model-00002-of-00002.bin",
140
+ "transformer.h.23.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
141
+ "transformer.h.23.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
142
+ "transformer.h.23.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
143
+ "transformer.h.24.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
144
+ "transformer.h.24.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
145
+ "transformer.h.24.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
146
+ "transformer.h.24.ln_1.weight": "pytorch_model-00002-of-00002.bin",
147
+ "transformer.h.24.ln_2.weight": "pytorch_model-00002-of-00002.bin",
148
+ "transformer.h.24.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
149
+ "transformer.h.24.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
150
+ "transformer.h.24.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
151
+ "transformer.h.25.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
152
+ "transformer.h.25.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
153
+ "transformer.h.25.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
154
+ "transformer.h.25.ln_1.weight": "pytorch_model-00002-of-00002.bin",
155
+ "transformer.h.25.ln_2.weight": "pytorch_model-00002-of-00002.bin",
156
+ "transformer.h.25.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
157
+ "transformer.h.25.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
158
+ "transformer.h.25.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
159
+ "transformer.h.26.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
160
+ "transformer.h.26.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
161
+ "transformer.h.26.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
162
+ "transformer.h.26.ln_1.weight": "pytorch_model-00002-of-00002.bin",
163
+ "transformer.h.26.ln_2.weight": "pytorch_model-00002-of-00002.bin",
164
+ "transformer.h.26.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
165
+ "transformer.h.26.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
166
+ "transformer.h.26.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
167
+ "transformer.h.27.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
168
+ "transformer.h.27.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
169
+ "transformer.h.27.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
170
+ "transformer.h.27.ln_1.weight": "pytorch_model-00002-of-00002.bin",
171
+ "transformer.h.27.ln_2.weight": "pytorch_model-00002-of-00002.bin",
172
+ "transformer.h.27.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
173
+ "transformer.h.27.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
174
+ "transformer.h.27.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
175
+ "transformer.h.28.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
176
+ "transformer.h.28.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
177
+ "transformer.h.28.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
178
+ "transformer.h.28.ln_1.weight": "pytorch_model-00002-of-00002.bin",
179
+ "transformer.h.28.ln_2.weight": "pytorch_model-00002-of-00002.bin",
180
+ "transformer.h.28.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
181
+ "transformer.h.28.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
182
+ "transformer.h.28.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
183
+ "transformer.h.29.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
184
+ "transformer.h.29.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
185
+ "transformer.h.29.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
186
+ "transformer.h.29.ln_1.weight": "pytorch_model-00002-of-00002.bin",
187
+ "transformer.h.29.ln_2.weight": "pytorch_model-00002-of-00002.bin",
188
+ "transformer.h.29.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
189
+ "transformer.h.29.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
190
+ "transformer.h.29.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
191
+ "transformer.h.3.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
192
+ "transformer.h.3.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
193
+ "transformer.h.3.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
194
+ "transformer.h.3.ln_1.weight": "pytorch_model-00001-of-00002.bin",
195
+ "transformer.h.3.ln_2.weight": "pytorch_model-00001-of-00002.bin",
196
+ "transformer.h.3.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
197
+ "transformer.h.3.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
198
+ "transformer.h.3.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
199
+ "transformer.h.30.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
200
+ "transformer.h.30.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
201
+ "transformer.h.30.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
202
+ "transformer.h.30.ln_1.weight": "pytorch_model-00002-of-00002.bin",
203
+ "transformer.h.30.ln_2.weight": "pytorch_model-00002-of-00002.bin",
204
+ "transformer.h.30.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
205
+ "transformer.h.30.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
206
+ "transformer.h.30.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
207
+ "transformer.h.31.attn.c_attn.bias": "pytorch_model-00002-of-00002.bin",
208
+ "transformer.h.31.attn.c_attn.weight": "pytorch_model-00002-of-00002.bin",
209
+ "transformer.h.31.attn.c_proj.weight": "pytorch_model-00002-of-00002.bin",
210
+ "transformer.h.31.ln_1.weight": "pytorch_model-00002-of-00002.bin",
211
+ "transformer.h.31.ln_2.weight": "pytorch_model-00002-of-00002.bin",
212
+ "transformer.h.31.mlp.c_proj.weight": "pytorch_model-00002-of-00002.bin",
213
+ "transformer.h.31.mlp.w1.weight": "pytorch_model-00002-of-00002.bin",
214
+ "transformer.h.31.mlp.w2.weight": "pytorch_model-00002-of-00002.bin",
215
+ "transformer.h.4.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
216
+ "transformer.h.4.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
217
+ "transformer.h.4.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
218
+ "transformer.h.4.ln_1.weight": "pytorch_model-00001-of-00002.bin",
219
+ "transformer.h.4.ln_2.weight": "pytorch_model-00001-of-00002.bin",
220
+ "transformer.h.4.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
221
+ "transformer.h.4.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
222
+ "transformer.h.4.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
223
+ "transformer.h.5.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
224
+ "transformer.h.5.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
225
+ "transformer.h.5.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
226
+ "transformer.h.5.ln_1.weight": "pytorch_model-00001-of-00002.bin",
227
+ "transformer.h.5.ln_2.weight": "pytorch_model-00001-of-00002.bin",
228
+ "transformer.h.5.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
229
+ "transformer.h.5.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
230
+ "transformer.h.5.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
231
+ "transformer.h.6.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
232
+ "transformer.h.6.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
233
+ "transformer.h.6.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
234
+ "transformer.h.6.ln_1.weight": "pytorch_model-00001-of-00002.bin",
235
+ "transformer.h.6.ln_2.weight": "pytorch_model-00001-of-00002.bin",
236
+ "transformer.h.6.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
237
+ "transformer.h.6.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
238
+ "transformer.h.6.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
239
+ "transformer.h.7.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
240
+ "transformer.h.7.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
241
+ "transformer.h.7.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
242
+ "transformer.h.7.ln_1.weight": "pytorch_model-00001-of-00002.bin",
243
+ "transformer.h.7.ln_2.weight": "pytorch_model-00001-of-00002.bin",
244
+ "transformer.h.7.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
245
+ "transformer.h.7.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
246
+ "transformer.h.7.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
247
+ "transformer.h.8.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
248
+ "transformer.h.8.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
249
+ "transformer.h.8.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
250
+ "transformer.h.8.ln_1.weight": "pytorch_model-00001-of-00002.bin",
251
+ "transformer.h.8.ln_2.weight": "pytorch_model-00001-of-00002.bin",
252
+ "transformer.h.8.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
253
+ "transformer.h.8.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
254
+ "transformer.h.8.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
255
+ "transformer.h.9.attn.c_attn.bias": "pytorch_model-00001-of-00002.bin",
256
+ "transformer.h.9.attn.c_attn.weight": "pytorch_model-00001-of-00002.bin",
257
+ "transformer.h.9.attn.c_proj.weight": "pytorch_model-00001-of-00002.bin",
258
+ "transformer.h.9.ln_1.weight": "pytorch_model-00001-of-00002.bin",
259
+ "transformer.h.9.ln_2.weight": "pytorch_model-00001-of-00002.bin",
260
+ "transformer.h.9.mlp.c_proj.weight": "pytorch_model-00001-of-00002.bin",
261
+ "transformer.h.9.mlp.w1.weight": "pytorch_model-00001-of-00002.bin",
262
+ "transformer.h.9.mlp.w2.weight": "pytorch_model-00001-of-00002.bin",
263
+ "transformer.ln_f.weight": "pytorch_model-00002-of-00002.bin",
264
+ "transformer.wte.weight": "pytorch_model-00001-of-00002.bin"
265
+ }
266
+ }
qwen.tiktoken ADDED
The diff for this file is too large to render. See raw diff
 
qwen_generation_utils.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """Generation support."""
7
+
8
+ from typing import Tuple, List, Union, Iterable
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from transformers import PreTrainedTokenizer
14
+ from transformers import logging
15
+ from transformers.generation import LogitsProcessor
16
+
17
+ logger = logging.get_logger(__name__)
18
+
19
+ # Types.
20
+ HistoryType = List[Tuple[str, str]]
21
+ TokensType = List[int]
22
+ BatchTokensType = List[List[int]]
23
+
24
+
25
+ def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType:
26
+ for tokens in batch:
27
+ context_length = len(tokens)
28
+ if context_length < seq_length:
29
+ tokens.extend([pad_id] * (seq_length - context_length))
30
+ return batch
31
+
32
+
33
+ def get_ltor_masks_and_position_ids(
34
+ data,
35
+ eod_token,
36
+ reset_position_ids,
37
+ reset_attention_mask,
38
+ eod_mask_loss,
39
+ ):
40
+ """Build masks and position id for left to right model."""
41
+
42
+ # Extract batch size and sequence length.
43
+ micro_batch_size, seq_length = data.size()
44
+
45
+ # Attention mask (lower triangular).
46
+ if reset_attention_mask:
47
+ att_mask_batch = micro_batch_size
48
+ else:
49
+ att_mask_batch = 1
50
+ attention_mask = torch.tril(
51
+ torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
52
+ ).view(att_mask_batch, 1, seq_length, seq_length)
53
+
54
+ # Loss mask.
55
+ loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
56
+ if eod_mask_loss:
57
+ loss_mask[data == eod_token] = 0.0
58
+
59
+ # Position ids.
60
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
61
+ position_ids = position_ids.unsqueeze(0).expand_as(data)
62
+ # We need to clone as the ids will be modifed based on batch index.
63
+ if reset_position_ids:
64
+ position_ids = position_ids.clone()
65
+
66
+ if reset_position_ids or reset_attention_mask:
67
+ # Loop through the batches:
68
+ for b in range(micro_batch_size):
69
+
70
+ # Find indecies where EOD token is.
71
+ eod_index = position_ids[b, data[b] == eod_token]
72
+ # Detach indecies from positions if going to modify positions.
73
+ if reset_position_ids:
74
+ eod_index = eod_index.clone()
75
+
76
+ # Loop through EOD indecies:
77
+ prev_index = 0
78
+ for j in range(eod_index.size()[0]):
79
+ i = eod_index[j]
80
+ # Mask attention loss.
81
+ if reset_attention_mask:
82
+ attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
83
+ # Reset positions.
84
+ if reset_position_ids:
85
+ position_ids[b, (i + 1) :] -= i + 1 - prev_index
86
+ prev_index = i + 1
87
+
88
+ # Convert attention mask to binary:
89
+ attention_mask = attention_mask < 0.5
90
+
91
+ return attention_mask, loss_mask, position_ids
92
+
93
+
94
+ def get_batch(context_tokens: torch.LongTensor, eod_id: int):
95
+ """Generate batch from context tokens."""
96
+ # Move to GPU.
97
+ tokens = context_tokens.contiguous().to(context_tokens.device)
98
+ # Get the attention mask and postition ids.
99
+ attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
100
+ tokens,
101
+ eod_id,
102
+ reset_position_ids=False,
103
+ reset_attention_mask=False,
104
+ eod_mask_loss=False,
105
+ )
106
+ return tokens, attention_mask, position_ids
107
+
108
+
109
+ def get_stop_words_ids(chat_format, tokenizer):
110
+ if chat_format == "raw":
111
+ stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
112
+ elif chat_format == "chatml":
113
+ stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
114
+ else:
115
+ raise NotImplementedError(f"Unknown chat format {chat_format!r}")
116
+ return stop_words_ids
117
+
118
+
119
+ def make_context(
120
+ tokenizer: PreTrainedTokenizer,
121
+ query: str,
122
+ history: List[Tuple[str, str]] = None,
123
+ system: str = "",
124
+ max_window_size: int = 6144,
125
+ chat_format: str = "chatml",
126
+ ):
127
+ if history is None:
128
+ history = []
129
+
130
+ if chat_format == "chatml":
131
+ im_start, im_end = "<|im_start|>", "<|im_end|>"
132
+ im_start_tokens = [tokenizer.im_start_id]
133
+ im_end_tokens = [tokenizer.im_end_id]
134
+ nl_tokens = tokenizer.encode("\n")
135
+
136
+ def _tokenize_str(role, content):
137
+ return f"{role}\n{content}", tokenizer.encode(
138
+ role, allowed_special=set()
139
+ ) + nl_tokens + tokenizer.encode(content, allowed_special=set())
140
+
141
+ system_text, system_tokens_part = _tokenize_str("system", system)
142
+ system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
143
+
144
+ raw_text = ""
145
+ context_tokens = []
146
+
147
+ for turn_query, turn_response in reversed(history):
148
+ query_text, query_tokens_part = _tokenize_str("user", turn_query)
149
+ query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
150
+ response_text, response_tokens_part = _tokenize_str(
151
+ "assistant", turn_response
152
+ )
153
+ response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
154
+
155
+ next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
156
+ prev_chat = (
157
+ f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
158
+ )
159
+
160
+ current_context_size = (
161
+ len(system_tokens) + len(next_context_tokens) + len(context_tokens)
162
+ )
163
+ if current_context_size < max_window_size:
164
+ context_tokens = next_context_tokens + context_tokens
165
+ raw_text = prev_chat + raw_text
166
+ else:
167
+ break
168
+
169
+ context_tokens = system_tokens + context_tokens
170
+ raw_text = f"{im_start}{system_text}{im_end}" + raw_text
171
+ context_tokens += (
172
+ nl_tokens
173
+ + im_start_tokens
174
+ + _tokenize_str("user", query)[1]
175
+ + im_end_tokens
176
+ + nl_tokens
177
+ + im_start_tokens
178
+ + tokenizer.encode("assistant")
179
+ + nl_tokens
180
+ )
181
+ raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
182
+
183
+ elif chat_format == "raw":
184
+ raw_text = query
185
+ context_tokens = tokenizer.encode(raw_text)
186
+ else:
187
+ raise NotImplementedError(f"Unknown chat format {chat_format!r}")
188
+
189
+ return raw_text, context_tokens
190
+
191
+
192
+ def _decode_default(
193
+ tokens: List[int],
194
+ *,
195
+ stop_words: List[str],
196
+ eod_words: List[str],
197
+ tokenizer: PreTrainedTokenizer,
198
+ raw_text_len: int,
199
+ verbose: bool = False,
200
+ return_end_reason: bool = False,
201
+ errors: str='replace',
202
+ ):
203
+ trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:]
204
+ if verbose:
205
+ print("\nRaw Generate: ", trim_decode_tokens)
206
+
207
+ end_reason = f"Gen length {len(tokens)}"
208
+ for stop_word in stop_words:
209
+ trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
210
+ for eod_word in eod_words:
211
+ if eod_word in trim_decode_tokens:
212
+ end_reason = f"Gen {eod_word!r}"
213
+ trim_decode_tokens = trim_decode_tokens.split(eod_word)[0]
214
+ trim_decode_tokens = trim_decode_tokens.strip()
215
+ if verbose:
216
+ print("\nEnd Reason:", end_reason)
217
+ print("\nGenerate: ", trim_decode_tokens)
218
+
219
+ if return_end_reason:
220
+ return trim_decode_tokens, end_reason
221
+ else:
222
+ return trim_decode_tokens
223
+
224
+
225
+ def _decode_chatml(
226
+ tokens: List[int],
227
+ *,
228
+ stop_words: List[str],
229
+ eod_token_ids: List[int],
230
+ tokenizer: PreTrainedTokenizer,
231
+ raw_text_len: int,
232
+ context_length: int,
233
+ verbose: bool = False,
234
+ return_end_reason: bool = False,
235
+ errors: str='replace'
236
+ ):
237
+ end_reason = f"Gen length {len(tokens)}"
238
+ eod_token_idx = context_length
239
+ for eod_token_idx in range(context_length, len(tokens)):
240
+ if tokens[eod_token_idx] in eod_token_ids:
241
+ end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
242
+ break
243
+
244
+ trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)[raw_text_len:]
245
+ if verbose:
246
+ print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens, errors=errors)[raw_text_len:])
247
+ print("\nRaw Generate:", trim_decode_tokens)
248
+ print("\nEnd Reason:", end_reason)
249
+ for stop_word in stop_words:
250
+ trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
251
+ trim_decode_tokens = trim_decode_tokens.strip()
252
+ if verbose:
253
+ print("\nGenerate:", trim_decode_tokens)
254
+
255
+ if return_end_reason:
256
+ return trim_decode_tokens, end_reason
257
+ else:
258
+ return trim_decode_tokens
259
+
260
+
261
+ def decode_tokens(
262
+ tokens: Union[torch.LongTensor, TokensType],
263
+ tokenizer: PreTrainedTokenizer,
264
+ raw_text_len: int,
265
+ context_length: int,
266
+ chat_format: str,
267
+ verbose: bool = False,
268
+ return_end_reason: bool = False,
269
+ errors: str="replace",
270
+ ) -> str:
271
+ if torch.is_tensor(tokens):
272
+ tokens = tokens.cpu().numpy().tolist()
273
+
274
+ if chat_format == "chatml":
275
+ return _decode_chatml(
276
+ tokens,
277
+ stop_words=[],
278
+ eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],
279
+ tokenizer=tokenizer,
280
+ raw_text_len=raw_text_len,
281
+ context_length=context_length,
282
+ verbose=verbose,
283
+ return_end_reason=return_end_reason,
284
+ errors=errors,
285
+ )
286
+ elif chat_format == "raw":
287
+ return _decode_default(
288
+ tokens,
289
+ stop_words=["<|endoftext|>"],
290
+ eod_words=["<|endoftext|>"],
291
+ tokenizer=tokenizer,
292
+ raw_text_len=raw_text_len,
293
+ verbose=verbose,
294
+ return_end_reason=return_end_reason,
295
+ errors=errors,
296
+ )
297
+ else:
298
+ raise NotImplementedError(f"Unknown chat format {chat_format!r}")
299
+
300
+
301
+ class StopWordsLogitsProcessor(LogitsProcessor):
302
+ """
303
+ :class:`transformers.LogitsProcessor` that enforces that when specified sequences appear, stop geration.
304
+
305
+ Args:
306
+ stop_words_ids (:obj:`List[List[int]]`):
307
+ List of list of token ids of stop ids. In order to get the tokens of the words
308
+ that should not appear in the generated text, use :obj:`tokenizer(bad_word,
309
+ add_prefix_space=True).input_ids`.
310
+ eos_token_id (:obj:`int`):
311
+ The id of the `end-of-sequence` token.
312
+ """
313
+
314
+ def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):
315
+
316
+ if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
317
+ raise ValueError(
318
+ f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}."
319
+ )
320
+ if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
321
+ raise ValueError(
322
+ f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
323
+ )
324
+ if any(
325
+ any(
326
+ (not isinstance(token_id, (int, np.integer)) or token_id < 0)
327
+ for token_id in stop_word_ids
328
+ )
329
+ for stop_word_ids in stop_words_ids
330
+ ):
331
+ raise ValueError(
332
+ f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
333
+ )
334
+
335
+ self.stop_words_ids = list(
336
+ filter(
337
+ lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids
338
+ )
339
+ )
340
+ self.eos_token_id = eos_token_id
341
+ for stop_token_seq in self.stop_words_ids:
342
+ assert (
343
+ len(stop_token_seq) > 0
344
+ ), "Stop words token sequences {} cannot have an empty list".format(
345
+ stop_words_ids
346
+ )
347
+
348
+ def __call__(
349
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor
350
+ ) -> torch.FloatTensor:
351
+ stopped_samples = self._calc_stopped_samples(input_ids)
352
+ for i, should_stop in enumerate(stopped_samples):
353
+ if should_stop:
354
+ scores[i, self.eos_token_id] = float(2**15)
355
+ return scores
356
+
357
+ def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
358
+ if len(tokens) == 0:
359
+ # if bad word tokens is just one token always ban it
360
+ return True
361
+ elif len(tokens) > len(prev_tokens):
362
+ # if bad word tokens are longer then prev input_ids they can't be equal
363
+ return False
364
+ elif prev_tokens[-len(tokens) :].tolist() == tokens:
365
+ # if tokens match
366
+ return True
367
+ else:
368
+ return False
369
+
370
+ def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
371
+ stopped_samples = []
372
+ for prev_input_ids_slice in prev_input_ids:
373
+ match = False
374
+ for stop_token_seq in self.stop_words_ids:
375
+ if self._tokens_match(prev_input_ids_slice, stop_token_seq):
376
+ # if tokens do not match continue
377
+ match = True
378
+ break
379
+ stopped_samples.append(match)
380
+
381
+ return stopped_samples
382
+
383
+
384
+ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
385
+ """This function has been mostly taken from huggingface conversational
386
+ ai code at
387
+ https://medium.com/huggingface/how-to-build-a-state-of-the-art-
388
+ conversational-ai-with-transfer-learning-2d818ac26313"""
389
+
390
+ if top_k > 0:
391
+ # Remove all tokens with a probability less than the
392
+ # last token of the top-k
393
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
394
+ logits[indices_to_remove] = filter_value
395
+
396
+ if top_p > 0.0:
397
+ # Cconvert to 1D
398
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
399
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
400
+
401
+ # Remove tokens with cumulative probability above the threshold
402
+ sorted_indices_to_remove = cumulative_probs > top_p
403
+ # Shift the indices to the right to keep also the first token
404
+ # above the threshold
405
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
406
+ sorted_indices_to_remove[..., 0] = 0
407
+ for i in range(sorted_indices.size(0)):
408
+ indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
409
+ logits[i][indices_to_remove] = filter_value
410
+
411
+ return logits
412
+
413
+
414
+ def switch(val1, val2, boolean):
415
+ boolean = boolean.type_as(val1)
416
+ return (1 - boolean) * val1 + boolean * val2
tokenization_qwen.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """Tokenization classes for QWen."""
7
+
8
+ import base64
9
+ import logging
10
+ import os
11
+ import unicodedata
12
+ from typing import Collection, Dict, List, Set, Tuple, Union
13
+
14
+ import tiktoken
15
+ from transformers import PreTrainedTokenizer, AddedToken
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken"}
21
+
22
+ PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
23
+ ENDOFTEXT = "<|endoftext|>"
24
+ IMSTART = "<|im_start|>"
25
+ IMEND = "<|im_end|>"
26
+ # as the default behavior is changed to allow special tokens in
27
+ # regular texts, the surface forms of special tokens need to be
28
+ # as different as possible to minimize the impact
29
+ EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
30
+ SPECIAL_TOKENS = (
31
+ ENDOFTEXT,
32
+ IMSTART,
33
+ IMEND,
34
+ ) + EXTRAS
35
+
36
+
37
+ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
38
+ with open(tiktoken_bpe_file, "rb") as f:
39
+ contents = f.read()
40
+ return {
41
+ base64.b64decode(token): int(rank)
42
+ for token, rank in (line.split() for line in contents.splitlines() if line)
43
+ }
44
+
45
+ class QWenTokenizer(PreTrainedTokenizer):
46
+ """QWen tokenizer."""
47
+
48
+ vocab_files_names = VOCAB_FILES_NAMES
49
+
50
+ def __init__(
51
+ self,
52
+ vocab_file,
53
+ errors="replace",
54
+ **kwargs,
55
+ ):
56
+ super().__init__(**kwargs)
57
+
58
+ self.errors = errors # how to handle errors in decoding
59
+
60
+ self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: dict[bytes, int]
61
+ self.special_tokens = {
62
+ token: index
63
+ for index, token in enumerate(
64
+ SPECIAL_TOKENS, start=len(self.mergeable_ranks)
65
+ )
66
+ }
67
+
68
+ enc = tiktoken.Encoding(
69
+ "Qwen",
70
+ pat_str=PAT_STR,
71
+ mergeable_ranks=self.mergeable_ranks,
72
+ special_tokens=self.special_tokens,
73
+ )
74
+ assert (
75
+ len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
76
+ ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
77
+
78
+ self.decoder = {
79
+ v: k for k, v in self.mergeable_ranks.items()
80
+ } # type: dict[int, bytes|str]
81
+ self.decoder.update({v: k for k, v in self.special_tokens.items()})
82
+
83
+ self.tokenizer = enc # type: tiktoken.Encoding
84
+
85
+ self.eod_id = self.tokenizer.eot_token
86
+ self.im_start_id = self.special_tokens[IMSTART]
87
+ self.im_end_id = self.special_tokens[IMEND]
88
+
89
+ def __len__(self) -> int:
90
+ return self.tokenizer.n_vocab
91
+
92
+ def get_vocab(self) -> Dict[bytes, int]:
93
+ return self.mergeable_ranks
94
+
95
+ def convert_tokens_to_ids(
96
+ self, tokens: Union[bytes, str, List[Union[bytes, str]]]
97
+ ) -> List[int]:
98
+ ids = []
99
+ if isinstance(tokens, (str, bytes)):
100
+ if tokens in self.special_tokens:
101
+ return self.special_tokens[tokens]
102
+ else:
103
+ return self.mergeable_ranks.get(tokens)
104
+ for token in tokens:
105
+ if token in self.special_tokens:
106
+ ids.append(self.special_tokens[token])
107
+ else:
108
+ ids.append(self.mergeable_ranks.get(token))
109
+ return ids
110
+
111
+ def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
112
+ if not special_tokens and new_tokens:
113
+ raise ValueError('Adding regular tokens is not supported')
114
+ for token in new_tokens:
115
+ surface_form = token.content if isinstance(token, AddedToken) else token
116
+ if surface_form not in SPECIAL_TOKENS:
117
+ raise ValueError('Adding unknown special tokens is not supported')
118
+ return 0
119
+
120
+ def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
121
+ """
122
+ Save only the vocabulary of the tokenizer (vocabulary).
123
+
124
+ Returns:
125
+ `Tuple(str)`: Paths to the files saved.
126
+ """
127
+ file_path = os.path.join(save_directory, "qwen.tiktoken")
128
+ with open(file_path, "w", encoding="utf8") as w:
129
+ for k, v in self.mergeable_ranks.items():
130
+ line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
131
+ w.write(line)
132
+ return (file_path,)
133
+
134
+ def tokenize(
135
+ self,
136
+ text: str,
137
+ allowed_special: Union[Set, str] = "all",
138
+ disallowed_special: Union[Collection, str] = (),
139
+ **kwargs,
140
+ ) -> List[Union[bytes, str]]:
141
+ """
142
+ Converts a string in a sequence of tokens.
143
+
144
+ Args:
145
+ text (`str`):
146
+ The sequence to be encoded.
147
+ allowed_special (`Literal["all"]` or `set`):
148
+ The surface forms of the tokens to be encoded as special tokens in regular texts.
149
+ Default to "all".
150
+ disallowed_special (`Literal["all"]` or `Collection`):
151
+ The surface forms of the tokens that should not be in regular texts and trigger errors.
152
+ Default to an empty tuple.
153
+
154
+ kwargs (additional keyword arguments, *optional*):
155
+ Will be passed to the underlying model specific encode method.
156
+
157
+ Returns:
158
+ `List[bytes|str]`: The list of tokens.
159
+ """
160
+ tokens = []
161
+ text = unicodedata.normalize("NFC", text)
162
+
163
+ # this implementation takes a detour: text -> token id -> token surface forms
164
+ for t in self.tokenizer.encode(
165
+ text, allowed_special=allowed_special, disallowed_special=disallowed_special
166
+ ):
167
+ tokens.append(self.decoder[t])
168
+ return tokens
169
+
170
+ def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
171
+ """
172
+ Converts a sequence of tokens in a single string.
173
+ """
174
+ text = ""
175
+ temp = b""
176
+ for t in tokens:
177
+ if isinstance(t, str):
178
+ if temp:
179
+ text += temp.decode("utf-8", errors=self.errors)
180
+ temp = b""
181
+ text += t
182
+ elif isinstance(t, bytes):
183
+ temp += t
184
+ else:
185
+ raise TypeError("token should only be of type types or str")
186
+ if temp:
187
+ text += temp.decode("utf-8", errors=self.errors)
188
+ return text
189
+
190
+ @property
191
+ def vocab_size(self):
192
+ return self.tokenizer.n_vocab
193
+
194
+ def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
195
+ """Converts an id to a token, special tokens included"""
196
+ if index in self.decoder:
197
+ return self.decoder[index]
198
+ raise ValueError("unknown ids")
199
+
200
+ def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
201
+ """Converts a token to an id using the vocab, special tokens included"""
202
+ if token in self.special_tokens:
203
+ return self.special_tokens[token]
204
+ if token in self.mergeable_ranks:
205
+ return self.mergeable_ranks[token]
206
+ raise ValueError("unknown token")
207
+
208
+ def _tokenize(self, text: str, **kwargs):
209
+ """
210
+ Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
211
+ vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
212
+
213
+ Do NOT take care of added tokens.
214
+ """
215
+ raise NotImplementedError
216
+
217
+ def _decode(
218
+ self,
219
+ token_ids: Union[int, List[int]],
220
+ skip_special_tokens: bool = False,
221
+ errors: str = None,
222
+ **kwargs,
223
+ ) -> str:
224
+ if isinstance(token_ids, int):
225
+ token_ids = [token_ids]
226
+ if skip_special_tokens:
227
+ token_ids = [i for i in token_ids if i < self.eod_id]
228
+ return self.tokenizer.decode(token_ids, errors=errors or self.errors)
tokenizer_config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoTokenizer": [
4
+ "tokenization_qwen.QWenTokenizer",
5
+ null
6
+ ]
7
+ },
8
+ "clean_up_tokenization_spaces": true,
9
+ "model_max_length": 8192,
10
+ "padding_side": "left",
11
+ "tokenizer_class": "QWenTokenizer"
12
+ }