Update modeling_motif.py

#1
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
LICENSE CHANGED
@@ -8,18 +8,17 @@ Motif-2.6B Release Date: June 9, 2025
8
  "Motif Technologies" or "we" means Motif Technologies Corp.
9
  By clicking "I Accept" below or by using or distributing any portion or element of the Motif Materials, you agree to be bound by this Agreement.
10
  1. License Rights and Redistribution.
 
11
  a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Motif Technologies' intellectual property or other rights owned by Motif Technologies embodied in the Motif Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Motif Materials.
12
  b. Redistribution and Use.
13
+ i. If you distribute or make available the Motif Materials (or any derivative works thereof), or a product or service (including another AI model) that contains any of them, you shall (A) provide a copy of this Agreement with any such Motif Materials; and (B) prominently display "Built with Motif" on a related website, user interface, blogpost, about page, or product documentation. If you use the Motif Materials or any outputs or results of the Motif Materials to create, train, fine tune, or otherwise improve an AI model, which is distributed or made available, you shall also include "Motif-2.6B" at the beginning of any such AI model name.
14
+ ii. If you receive Motif Materials, or any derivative works thereof, from a Licensee as part of an integrated end user product, then Section 2 of this Agreement will not apply to you.
15
+ iii. You must retain in all copies of the Motif Materials that you distribute the following attribution notice within a "Notice" text file distributed as a part of such copies: "Motif-2.6B is licensed under the Motif-2.6B Community License, Copyright © Motif Technologies Corp. All Rights Reserved."
16
+ iv. Your use of the Motif Materials must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Motif Materials (available at https://motiftech.io), which is hereby incorporated by reference into this Agreement.
17
 
18
  2. Additional Commercial Terms. If, on the Motif-2.6B version release date, the monthly active users of the products or services made available by or for Licensee, or Licensee's affiliates, is greater than 700 million monthly active users in the preceding calendar month, you must request a license from Motif Technologies, which Motif Technologies may grant to you in its sole discretion, and you are not authorized to exercise any of the rights under this Agreement unless or until Motif Technologies otherwise expressly grants you such license.
19
  3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE MOTIF MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, AND MOTIF TECHNOLOGIES DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE MOTIF MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE MOTIF MATERIALS AND ANY OUTPUT AND RESULTS.
20
  4. Limitation of Liability. IN NO EVENT WILL MOTIF TECHNOLOGIES OR ITS SHAREHOLDER OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF MOTIF TECHNOLOGIES OR ITS SHAREHOLDER OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
21
+ 5. Intellectual Property.
22
  a. No trademark licenses are granted under this Agreement, and in connection with the Motif Materials, neither Motif Technologies nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Motif Materials or as set forth in this Section 5(a). Motif Technologies hereby grants you a license to use "Motif" (the "Mark") solely as required to comply with the last sentence of Section 1.b.i. All goodwill arising out of your use of the Mark will inure to the benefit of Motif Technologies.
23
  b. Subject to Motif Technologies' ownership of Motif Materials and derivatives made by or for Motif Technologies, with respect to any derivative works and modifications of the Motif Materials that are made by you, as between you and Motif Technologies, you are and will be the owner of such derivative works and modifications.
24
  c. If you institute litigation or other proceedings against Motif Technologies, Motif Technologies' shareholder or affiliate or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Motif Materials or Motif-2.6B outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Motif Technologies from and against any claim by any third party arising out of or related to your use or distribution of the Motif Materials.
README.md CHANGED
@@ -1,5 +1,234 @@
1
- ---
2
- license: other
3
- license_name: motif-license
4
- license_link: LICENSE
5
- ---
1
+ ---
2
+ license: other
3
+ license_name: motif-license
4
+ license_link: LICENSE
5
+ language:
6
+ - en
7
+ - ko
8
+ ---
9
+
10
+ *Last update: 8th June 2025*
11
+
12
+ # Introduction
13
+
14
+ We announce **Motif 2.6B**, a 2.6 billion parameter language model trained from scratch on AMD Instinct™ MI250X GPUs. Motif 2.6B marks our very first step toward building helpful, reliable AI aligned with human values. With this initial release, our goal is for Motif 2.6B to match the performance of well-known open-source models such as Gemma, Llama, and Phi — particularly those in the sLLM regime.
15
+
16
+ # Training information
17
+
18
+ - GPUs: 384 MI250X
19
+ - Training time: 42 days
20
+ - Training data: 2.4T tokens
21
+
22
+ *Notice: A detailed technical report will be released at a later time.*
23
+
24
+ # Evaluation
25
+
26
+ When models are released, their accompanying technical reports or papers often present benchmark results based on evaluation settings chosen by the developers. While this is a common and understandable practice, it can lead to challenges when comparing models across different organizations. The same model may yield different scores depending on evaluation conditions, and details of these conditions are not always fully disclosed. This lack of standardization can make it difficult for the open-source community to interpret and trust reported results. We therefore reference performance scores based on the official numbers reported by each model’s developers in their respective publications.
27
+
28
+ To illustrate how much evaluation scores can vary across reports, we provide concrete examples of benchmark score differences for major models in the **Evaluation Appendix**.
29
+
30
+ ### Comparison to Mistral 7B by Mistral AI
31
+
32
+ The benchmarks and corresponding scores listed in the table below are taken directly from the [Mistral 7B technical report](https://arxiv.org/pdf/2310.06825).
33
+
34
+ |Benchmark|Metric|Mistral 7B|Motif 2.6B|Improvement|
35
+ |---|---|---|---|---|
36
+ |MMLU|5-shot|60.1|57.93|-3.61%|
37
+ |HellaSwag|0-shot|81.3|61.35|-24.54%|
38
+ |WinoG|0-shot|75.3|59.91|-20.44%|
39
+ |PIQA|0-shot|83|75.95|-8.49%|
40
+ |Arc-e|0-shot|80|87.21|+9.01%|
41
+ |Arc-c|0-shot|55.5|74.2|+33.69%|
42
+ |NQ|5-shot|28.8|11.14|-61.32%|
43
+ |TriviaQA|5-shot|69.9|54.97|-21.36%|
44
+ |HumanEval|0-shot|30.5|68.3|+123.93%|
45
+ |MBPP|3-shot|47.5|60.3|+26.95%|
46
+ |MATH|4-shot, maj@4|13.1|40.2*|+206.87%|
47
+ |GSM8K|8-shot, maj@8|52.2|77.71|+48.87%|
48
+ ||||**Average**|**+33.77%**|
49
+
50
+ \* : We report the 4-shot, maj@1 score instead of the 4-shot, maj@4.
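For reference, the Improvement columns in this table and the tables below are the plain relative difference between the Motif 2.6B score and the baseline score, expressed as a percentage. A minimal sketch of that arithmetic, with the function name and the handful of spot-check values chosen by us (scores copied from the table above), not an official evaluation script:

```python
def improvement(motif: float, baseline: float) -> float:
    """Relative difference of a Motif 2.6B score over a baseline score, in percent."""
    return (motif - baseline) / baseline * 100

# (baseline, Motif 2.6B) score pairs copied from the Mistral 7B table above.
scores = {
    "MMLU (5-shot)": (60.1, 57.93),
    "HumanEval (0-shot)": (30.5, 68.3),
    "GSM8K (8-shot, maj@8)": (52.2, 77.71),
}

for name, (baseline, motif) in scores.items():
    print(f"{name}: {improvement(motif, baseline):+.2f}%")
# MMLU (5-shot): -3.61%
# HumanEval (0-shot): +123.93%
# GSM8K (8-shot, maj@8): +48.87%
```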
51
+
52
+ ### Comparison to the Gemma series by Google
53
+
54
+ #### Gemma 1 & 2
55
+ The benchmarks and corresponding scores listed in the table below are taken directly from the [Gemma 2 technical report](https://arxiv.org/abs/2408.00118).
56
+
57
+ *Note: Although referred to as "2B", Gemma 2 2B actually has <u>2.6 billion</u> parameters.*
58
+
59
+ |Benchmark|Metric|Gemma 1 2B|Gemma 1 7B|Gemma 2 2B|Gemma 2 9B|Motif 2.6B|Improvement(over 1 2B)|Improvement(over 1 7B)|Improvement(over 2 2B)|Improvement(over 2 9B)|
60
+ |---|---|---|---|---|---|---|---|---|---|---|
61
+ |MMLU|5-shot|42.3|64.4|52.2|71.3|57.93|+36.95%|-10.05%|+10.98%|-18.75%|
62
+ |ARC-C|25-shot|48.5|61.1|55.7|68.4|75.08|+54.80%|+22.88%|+34.79%|+9.77%|
63
+ |GSM8K|5-shot|15.1|51.8|24.3|68.6|67.85|+349.34%|+30.98%|+179.22%|-1.09%|
64
+ |AGIEval|3-5-shot|24.2|44.9|31.5|52.8|-|-|-|-|-|
65
+ |DROP|3-shot, F1|48.5|56.3|51.2|69.4|29.33|-39.53%|-47.90%|-42.71%|-57.74%|
66
+ |BBH|3-shot, CoT|35.2|59|41.9|68.2|48.56|+37.95%|-17.69%|+15.89%|-28.80%|
67
+ |Winogrande|5-shot|66.8|79|71.3|80.6|67.09|+0.43%|-15.08%|-5.90%|-16.76%|
68
+ |HellaSwag|10-shot|71.7|82.3|72.9|81.9|69.89|-2.52%|-15.08%|-4.13%|-14.66%|
69
+ |MATH|4-shot|11.8|24.3|16|36.6|40.2|+240.88%|+65.43%|+151.25%|+9.84%|
70
+ |ARC-e|0-shot|73.2|81.5|80.6|88|87.21|+19.14%|+7.01%|+8.20%|-0.90%|
71
+ |PIQA|0-shot|77.3|81.2|78.4|81.7|75.95|-1.75%|-6.47%|-3.13%|-7.04%|
72
+ |SIQA|0-shot|49.7|51.8|51.9|53.4|61.97|+24.69%|+19.63%|+19.40%|+16.05%|
73
+ |Boolq|0-shot|69.4|83.2|72.7|84.2|67.76|-2.36%|-18.56%|-6.80%|-19.52%|
74
+ |TriviaQA|5-shot|53.2|63.4|60.4|76.6|54.97|+3.33%|-13.30%|-8.99%|-28.24%|
75
+ |NQ|5-shot|12.5|23|17.1|29.2|10.91|-12.72%|-52.57%|-36.20%|-62.64%|
76
+ |HumanEval|pass@1|22|32.3|20.1|40.2|68.3|+210.45%|+111.46%|+239.80%|+69.90%|
77
+ |MBPP|3-shot|29.2|44.4|30.2|52.4|60.3|+106.51%|+35.81%|+99.67%|+15.08%|
78
+ |||||||**Average**|**+84.76%**|**+1.69%**|**+42.42%**|**-14.78%**|
79
+
80
+ #### Gemma 3
81
+ The benchmarks and corresponding scores listed in the table below are taken directly from the [Gemma 3 technical report](https://arxiv.org/abs/2503.19786).
82
+
83
+ |Benchmark|Metric|Gemma 3 1B|Gemma 3 4B|Motif 2.6B|Improvement(over 1B)|Improvement(over 4B)|
84
+ |---|---|---|---|---|---|---|
85
+ |HellaS|10-shot|62.3|77.2|69.89|+12.18%|-9.47%|
86
+ |BoolQ|0-shot|63.2|72.3|67.76|+7.22%|-6.28%|
87
+ |PIQA|0-shot|73.8|79.6|75.59|+2.43%|-5.04%|
88
+ |SIQA|0-shot|48.9|51.9|61.97|+26.73%|+19.40%|
89
+ |TQA|5-shot|39.8|65.8|54.97|+38.12%|-16.46%|
90
+ |NQ|5-shot|9.48|20|10.91|+15.08%|-45.45%|
91
+ |ARC-C|25-shot|38.4|56.2|75.08|+95.52%|+33.59%|
92
+ |ARC-E|0-shot|73|82.4|87.21|+19.47%|+5.84%|
93
+ |WinoG|5-shot|58.2|64.7|67.09|+15.27%|+3.69%|
94
+ |BBH|few-shot, CoT|28.4|50.9|48.56|+70.99%|-4.60%|
95
+ |Drop|1-shot, F1|42.4|60.1|29.33|-30.83%|-51.20%|
96
+ |MMLU|5-shot|-|59.6|57.93|-|-2.80%|
97
+ |MMLUpro|5-shot, CoT|-|29.2|-|-|-|
98
+ |AGIE|3-5-shot|-|42.1|-|-|-|
99
+ |MATH|4-shot, CoT|-|24.2|40.2|-|+66.12%|
100
+ |GSM8K|8-shot, CoT|-|38.4|77.71|-|+102.37%|
101
+ |GPQA Diamond|5-shot, CoT|-|15|31.81|-|+112.07%|
102
+ |MBPP|3-shot|-|46|60.3|-|+31.09%|
103
+ |HumanE|0-shot|-|36|68.3|-|+89.72%|
104
+ |IFEval|-|80.2|90.2|74.02|-7.71%|-17.94%|
105
+ |||||**Average**|**+22.04%**|**+16.93%**|
106
+
107
+ ### Comparison to the Llama series by Meta
108
+
109
+ #### Llama 3
110
+ The benchmarks and corresponding scores listed in the table below are taken directly from the [Llama 3 technical report](https://arxiv.org/abs/2407.21783).
111
+
112
+ |Benchmark|Metric|Llama 3 8B|Motif 2.6B|Improvement|
113
+ |---|---|---|---|---|
114
+ |MMLU|5-shot|69.4|57.93|-16.53%|
115
+ |MMLU|0-shot, CoT|73|57.95|-20.62%|
116
+ |MMLU-Pro|5-shot, CoT|48.3|-|-|
117
+ |IFEval|-|80.4|74.02|-7.94%|
118
+ |HumanEval|0-shot|72.6|68.3|-5.92%|
119
+ |MBPP|0-shot|72.8|57.93|-20.43%|
120
+ |GSM8K|8-shot, CoT|84.5|77.71|-8.04%|
121
+ |MATH|0-shot, CoT|51.9|49.68|-4.28%|
122
+ |ARC Challenge|0-shot|83.4|74.2|-11.03%|
123
+ |GPQA|0-shot, CoT|32.8|18.53|-43.51%|
124
+ ||||**Average**|**-15.36%**|
125
+
126
+ #### Llama 3.2
127
+ The benchmarks and corresponding scores listed in the table below are taken directly from the [Llama 3.2 official blog](https://ai.meta.com/blog/llama-3-2-connect-2024-vision-edge-mobile-devices/).
128
+
129
+ |Benchmark|Metric|Llama 3.2 1B|Llama 3.2 3B|Motif 2.6B|Improvement(over 1B)|Improvement(over 3B)|
130
+ |---|---|---|---|---|---|---|
131
+ |MMLU|0-shot|49.3|63.4|57.6|+16.75%|-9.21%|
132
+ |Open-rewrite eval*|0-shot, rougeL|41.6|40.1|-|-|-|
133
+ |TLDR9+|test, 1-shot, rougeL|16.8|19|-|-|-|
134
+ |IFEval|-|59.5|77.4|74.02|+24.40%|-4.37%|
135
+ |GSM8K|8-shot, CoT|44.4|77.7|74.9|+68.69%|-3.60%|
136
+ |MATH|0-shot, CoT|30.6|48|49.68|+62.35%|+3.50%|
137
+ |ARC Challenge|0-shot|59.4|78.6|74.2|+24.92%|-5.6%|
138
+ |GPQA|0-shot|27.2|32.8|25.45|-6.43%|-22.41%|
139
+ |Hellaswag|0-shot|41.2|69.8|61.35|+48.91%|-12.11%|
140
+ |||||**Average**|**+39.42%**|**-3.86%**|
141
+
142
+ ### Comparison to the Phi series by Microsoft
143
+ The benchmarks and corresponding scores listed in the table below are taken directly from the [Phi-3 technical report](https://arxiv.org/abs/2404.14219).
144
+
145
+ |Benchmark|Metric|Phi-3 3.8B|Phi-3 7B|Phi-2 2.7B|Motif 2.6B|Improvement(over 3.8B)|Improvement(over 7B)|Improvement(over 2.7B)|
146
+ |---|---|---|---|---|---|---|---|---|
147
+ |MMLU|5-shot|68.8|75.7|56.3|57.93|-15.80%|-23.47%|+2.90%|
148
+ |HellaSwag|5-shot|76.7|77|53.6|68.97|-10.08%|-10.43%|+28.68%|
149
+ |ANLI|7-shot|52.8|58.1|42.5|47.99|-9.11%|-17.40%|+12.92%|
150
+ |GSM-8K|8-shot, CoT|82.5|89.6|61.1|76.5|-7.27%|-14.62%|+25.20%|
151
+ |MATH|0-shot, CoT|41.3|34.6|-|49.68|+20.29%|+43.58%|-|
152
+ |MedQA|2-shot|53.8|65.4|40.9|42.1|-21.75%|-35.63%|+2.93%|
153
+ |AGIEval|0-shot|37.5|45.1|29.8|-|-|-|-|
154
+ |TriviaQA|5-shot|64|58.1|45.2|54.97|-14.11%|-5.39%|+21.62%|
155
+ |Arc-C|10-shot|84.9|90.7|75.9|75.17|-11.46%|-17.12%|-0.96%|
156
+ |Arc-E|10-shot|94.6|97|88.5|88.64|-6.30%|-8.62%|+0.16%|
157
+ |PIQA|5-shot|84.2|86.9|60.2|78.29|-7.02%|-9.91%|+30.05%|
158
+ |SociQA|5-shot|76.6|79.2|68.3|66.73|-12.89%|-15.74%|-2.3%|
159
+ |BigBench-Hard|3-shot, CoT|71.7|79.1|59.4|48.56|-32.27%|-38.61%|-18.25%|
160
+ |WinoGrande|5-shot|70.8|81.5|54.7|67.09|-5.24%|-17.68%|+22.65%|
161
+ |OpenBookQA|10-shot|83.2|88|73.6|87.8|+5.53%|-0.23%|+19.29%|
162
+ |BoolQ|2-shot|77.2|84.8|-|70.7|-8.42%|-16.63%|-|
163
+ |CommonSenseQA|10-shot|80.2|80|69.3|71.25|-11.16%|-10.94%|+2.81%|
164
+ |TruthfulQA|10-shot|65|70.2|-|52.07|-19.89%|-25.83%|-|
165
+ |HumanEval|0-shot|58.5|61|59|68.29|+16.74%|+11.95%|+15.75%|
166
+ |MBPP|3-shot|70|71.7|60.6|60.3|-13.86%|-15.90%|-0.50%|
167
+ |GPQA|2-shot, CoT|32.8|34.3|-|23.44|-28.54%|-31.66%|-|
168
+ |MT Bench|2R. Avg.|8.38|8.7|-|6.77|-19.21%|-22.18%|-|
169
+ ||||||**Average**|**-10.09%**|**-13.45%**|**+10.18%**|
170
+
171
+ ## Evaluation Appendix
172
+
173
+ In the comparisons presented above, Motif 2.6B showed average performance differences of -15.36% and -14.78% relative to Llama 3 8B and Gemma 2 9B, respectively, based on the benchmark scores reported in their original technical reports. However, when the comparison is based on the benchmarks and scores reported in the Qwen2.5 technical report, Motif 2.6B shows an average improvement of +18.55% over Llama 3 8B and +1.12% over Gemma 2 9B. See the table below for details.
174
+
175
+ ### Comparison to Llama 3 8B and Gemma 2 9B based on scores from the *Qwen2.5 technical report*
176
+ The benchmarks and corresponding scores listed in the table below are taken directly from the [Qwen2.5 technical report](https://arxiv.org/abs/2412.15115).
177
+
178
+ |Benchmark|Metric|Llama 3 8B|Gemma 2 9B|Motif 2.6B|Improvement(over Llama 3 8B)|Improvement(over Gemma 2 9B)|
179
+ |---|---|---|---|---|---|---|
180
+ |MMLU|5-shot|66.6|71.3|57.93|-13.02%|-18.75%|
181
+ |MMLU-pro|5-shot|35.4|44.7|28.4|-19.77%|-36.47%|
182
+ |MMLU-redux|5-shot|61.6|67.9|59.54|-3.34%|-12.31%|
183
+ |BBH|3-shot|57.7|68.2|39.28|-31.92%|-42.40%|
184
+ |ARC-C|25-shot|59.3|68.2|75.08|+26.61%|+10.09%|
185
+ |TruthfulQA|0-shot|44|45.3|41.55|-5.56%|-8.27%|
186
+ |Winogrande|5-shot|77.4|79.5|67.09|-13.32%|-15.61%|
187
+ |HellaSwag|10-shot|82.1|81.9|69.88|-14.88%|-14.68%|
188
+ |GPQA|5-shot|25.8|32.8|29.24|+13.33%|-10.85%|
189
+ |TheoremQA|5-shot|22.1|28.9|-|-|-|
190
+ |MATH|4-shot|20.5|37.7|40.2|+96.10%|+6.63%|
191
+ |MMLU-stem|5-shot|55.3|65.1|52.9|-4.34%|-18.74%|
192
+ |GSM8K|4-shot|55.3|70.7|68.84|+24.48%|-2.63%|
193
+ |HumanEval|0-shot|33.5|37.8|68.3|+103.88%|+80.69%|
194
+ |HumanEval+|0-shot|29.3|30.5|62.2|+112.29%|+103.93%|
195
+ |MBPP|0-shot|53.9|62.2|60.3|+11.87%|-3.05%|
196
+ |MBPP+|0-shot|44.4|50.6|50.8|+14.41%|+0.40%|
197
+ |MultiPL-E|0-shot|22.6|34.9|-|-|-|
198
+ |||||**Average**|**+18.55%**|**+1.12%**|
199
+
200
+
201
+ ## How to use
202
+ ```python
203
+ from transformers import AutoModelForCausalLM, AutoTokenizer
204
+
205
+ model = AutoModelForCausalLM.from_pretrained(
206
+ "Motif-Technologies/Motif-2.6B",
207
+ trust_remote_code = True,
208
+ _attn_implementation = "eager", # also supports flash_attention_2
209
+ ).cuda()
210
+
211
+ tokenizer = AutoTokenizer.from_pretrained(
212
+ "Motif-Technologies/Motif-2.6B",
213
+ trust_remote_code = True,
214
+ )
215
+
216
+ query = "What is the capital city of South Korea?"
217
+ input_ids = tokenizer.apply_chat_template(
218
+ [
219
+ {'role': 'system', 'content': 'You are a helpful assistant.'},
220
+ {'role': 'user', 'content': query},
221
+ ],
222
+ add_generation_prompt = True,
223
+ return_tensors='pt',
224
+ ).cuda()
225
+
226
+ output_ids = model.generate(input_ids, max_new_tokens=128, pad_token_id=tokenizer.eos_token_id)
227
+ output = tokenizer.decode(output_ids[0, input_ids.shape[-1]:], skip_special_tokens = True)
228
+ print(output)
229
+
230
+ """
231
+ The capital city of South Korea is Seoul. Located in the southern part of the country, Seoul is not only the largest city in South Korea but also one of the largest metropolitan areas in the world.
232
+ It is a vibrant and dynamic city known for its rich history, cultural heritage, and modern amenities. Seoul is a major economic, cultural, and political center in East Asia, and it plays a crucial role in the region's politics, economy, and culture.
233
+ The city is divided into different administrative districts, each with its own unique characteristics and attractions.
234
+ """
added_tokens.json ADDED
@@ -0,0 +1,127 @@
1
+ {
2
+ "</think>": 219406,
3
+ "<think>": 219404,
4
+ "<|assistant|>": 219402,
5
+ "<|beginoftext|>": 219396,
6
+ "<|dummy_id_100|>": 219505,
7
+ "<|dummy_id_101|>": 219506,
8
+ "<|dummy_id_102|>": 219507,
9
+ "<|dummy_id_103|>": 219508,
10
+ "<|dummy_id_104|>": 219509,
11
+ "<|dummy_id_105|>": 219510,
12
+ "<|dummy_id_106|>": 219511,
13
+ "<|dummy_id_107|>": 219512,
14
+ "<|dummy_id_108|>": 219513,
15
+ "<|dummy_id_109|>": 219514,
16
+ "<|dummy_id_10|>": 219414,
17
+ "<|dummy_id_110|>": 219515,
18
+ "<|dummy_id_111|>": 219516,
19
+ "<|dummy_id_112|>": 219517,
20
+ "<|dummy_id_113|>": 219518,
21
+ "<|dummy_id_114|>": 219519,
22
+ "<|dummy_id_11|>": 219415,
23
+ "<|dummy_id_12|>": 219417,
24
+ "<|dummy_id_13|>": 219418,
25
+ "<|dummy_id_14|>": 219419,
26
+ "<|dummy_id_15|>": 219420,
27
+ "<|dummy_id_16|>": 219421,
28
+ "<|dummy_id_17|>": 219422,
29
+ "<|dummy_id_18|>": 219423,
30
+ "<|dummy_id_19|>": 219424,
31
+ "<|dummy_id_20|>": 219425,
32
+ "<|dummy_id_21|>": 219426,
33
+ "<|dummy_id_22|>": 219427,
34
+ "<|dummy_id_23|>": 219428,
35
+ "<|dummy_id_24|>": 219429,
36
+ "<|dummy_id_25|>": 219430,
37
+ "<|dummy_id_26|>": 219431,
38
+ "<|dummy_id_27|>": 219432,
39
+ "<|dummy_id_28|>": 219433,
40
+ "<|dummy_id_29|>": 219434,
41
+ "<|dummy_id_30|>": 219435,
42
+ "<|dummy_id_31|>": 219436,
43
+ "<|dummy_id_32|>": 219437,
44
+ "<|dummy_id_33|>": 219438,
45
+ "<|dummy_id_34|>": 219439,
46
+ "<|dummy_id_35|>": 219440,
47
+ "<|dummy_id_36|>": 219441,
48
+ "<|dummy_id_37|>": 219442,
49
+ "<|dummy_id_38|>": 219443,
50
+ "<|dummy_id_39|>": 219444,
51
+ "<|dummy_id_3|>": 219407,
52
+ "<|dummy_id_40|>": 219445,
53
+ "<|dummy_id_41|>": 219446,
54
+ "<|dummy_id_42|>": 219447,
55
+ "<|dummy_id_43|>": 219448,
56
+ "<|dummy_id_44|>": 219449,
57
+ "<|dummy_id_45|>": 219450,
58
+ "<|dummy_id_46|>": 219451,
59
+ "<|dummy_id_47|>": 219452,
60
+ "<|dummy_id_48|>": 219453,
61
+ "<|dummy_id_49|>": 219454,
62
+ "<|dummy_id_4|>": 219408,
63
+ "<|dummy_id_50|>": 219455,
64
+ "<|dummy_id_51|>": 219456,
65
+ "<|dummy_id_52|>": 219457,
66
+ "<|dummy_id_53|>": 219458,
67
+ "<|dummy_id_54|>": 219459,
68
+ "<|dummy_id_55|>": 219460,
69
+ "<|dummy_id_56|>": 219461,
70
+ "<|dummy_id_57|>": 219462,
71
+ "<|dummy_id_58|>": 219463,
72
+ "<|dummy_id_59|>": 219464,
73
+ "<|dummy_id_5|>": 219409,
74
+ "<|dummy_id_60|>": 219465,
75
+ "<|dummy_id_61|>": 219466,
76
+ "<|dummy_id_62|>": 219467,
77
+ "<|dummy_id_63|>": 219468,
78
+ "<|dummy_id_64|>": 219469,
79
+ "<|dummy_id_65|>": 219470,
80
+ "<|dummy_id_66|>": 219471,
81
+ "<|dummy_id_67|>": 219472,
82
+ "<|dummy_id_68|>": 219473,
83
+ "<|dummy_id_69|>": 219474,
84
+ "<|dummy_id_6|>": 219410,
85
+ "<|dummy_id_70|>": 219475,
86
+ "<|dummy_id_71|>": 219476,
87
+ "<|dummy_id_72|>": 219477,
88
+ "<|dummy_id_73|>": 219478,
89
+ "<|dummy_id_74|>": 219479,
90
+ "<|dummy_id_75|>": 219480,
91
+ "<|dummy_id_76|>": 219481,
92
+ "<|dummy_id_77|>": 219482,
93
+ "<|dummy_id_78|>": 219483,
94
+ "<|dummy_id_79|>": 219484,
95
+ "<|dummy_id_7|>": 219411,
96
+ "<|dummy_id_80|>": 219485,
97
+ "<|dummy_id_81|>": 219486,
98
+ "<|dummy_id_82|>": 219487,
99
+ "<|dummy_id_83|>": 219488,
100
+ "<|dummy_id_84|>": 219489,
101
+ "<|dummy_id_85|>": 219490,
102
+ "<|dummy_id_86|>": 219491,
103
+ "<|dummy_id_87|>": 219492,
104
+ "<|dummy_id_88|>": 219493,
105
+ "<|dummy_id_89|>": 219494,
106
+ "<|dummy_id_8|>": 219412,
107
+ "<|dummy_id_90|>": 219495,
108
+ "<|dummy_id_91|>": 219496,
109
+ "<|dummy_id_92|>": 219497,
110
+ "<|dummy_id_93|>": 219498,
111
+ "<|dummy_id_94|>": 219499,
112
+ "<|dummy_id_95|>": 219500,
113
+ "<|dummy_id_96|>": 219501,
114
+ "<|dummy_id_97|>": 219502,
115
+ "<|dummy_id_98|>": 219503,
116
+ "<|dummy_id_99|>": 219504,
117
+ "<|dummy_id_9|>": 219413,
118
+ "<|endofprompt|>": 219416,
119
+ "<|endoftext|>": 219395,
120
+ "<|endofturn|>": 219405,
121
+ "<|fim_middle|>": 219398,
122
+ "<|fim_prefix|>": 219397,
123
+ "<|fim_suffix|>": 219399,
124
+ "<|startofturn|>": 219403,
125
+ "<|system|>": 219400,
126
+ "<|user|>": 219401
127
+ }
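The IDs above occupy the top of the 219,520-token vocabulary declared in config.json below; `<|beginoftext|>` and `<|endoftext|>` are the BOS/EOS ids, and `<|endofturn|>` is the second stop token listed in generation_config.json. A small sanity-check sketch (IDs copied from this file, not re-derived):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "Motif-Technologies/Motif-2.6B",
    trust_remote_code=True,
)

# Expected IDs are copied verbatim from added_tokens.json above.
expected = {
    "<|beginoftext|>": 219396,  # bos_token_id in config.json
    "<|endoftext|>": 219395,    # eos_token_id in config.json
    "<|endofturn|>": 219405,    # second eos id in generation_config.json
    "<think>": 219404,
    "</think>": 219406,
}

for token, token_id in expected.items():
    assert tokenizer.convert_tokens_to_ids(token) == token_id, token
```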
config.json ADDED
@@ -0,0 +1,35 @@
1
+ {
2
+ "absolute_position_embedding": false,
3
+ "architectures": [
4
+ "MotifForCausalLM"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_motif.MotifConfig",
9
+ "AutoModelForCausalLM": "modeling_motif.MotifForCausalLM"
10
+ },
11
+ "bos_token_id": 219396,
12
+ "eos_token_id": 219395,
13
+ "hidden_act": "poly_norm",
14
+ "hidden_size": 2048,
15
+ "initializer_range": 2e-05,
16
+ "intermediate_size": 8192,
17
+ "loss_reduction": "mean",
18
+ "max_position_embeddings": 16384,
19
+ "max_window_layers": 28,
20
+ "model_type": "Motif",
21
+ "num_attention_heads": 16,
22
+ "num_hidden_layers": 32,
23
+ "num_key_value_heads": 16,
24
+ "rms_norm_eps": 1e-06,
25
+ "rope_scaling": null,
26
+ "rope_theta": 500000.0,
27
+ "sliding_window": null,
28
+ "tie_word_embeddings": true,
29
+ "torch_dtype": "bfloat16",
30
+ "transformers_version": "4.46.3",
31
+ "use_bias": false,
32
+ "use_cache": true,
33
+ "use_sliding_window": false,
34
+ "vocab_size": 219520
35
+ }
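A few properties follow directly from these numbers: 16 attention heads over a hidden size of 2048 give a head dimension of 128, and `num_key_value_heads` equals `num_attention_heads`, so attention is plain multi-head rather than grouped-query. A minimal sketch that pulls the remote config and checks this (`AutoConfig` resolves to `MotifConfig` through the `auto_map` entry above):

```python
from transformers import AutoConfig

# trust_remote_code is required because auto_map points at configuration_motif.MotifConfig.
config = AutoConfig.from_pretrained(
    "Motif-Technologies/Motif-2.6B",
    trust_remote_code=True,
)

head_dim = config.hidden_size // config.num_attention_heads
print(config.model_type, head_dim)                                # Motif 128
print(config.num_key_value_heads == config.num_attention_heads)   # True -> MHA, no GQA
print(config.max_position_embeddings, config.rope_theta)          # 16384 500000.0
```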
configuration_motif.py ADDED
@@ -0,0 +1,167 @@
1
+ import math
2
+ from typing import Optional
3
+
4
+ from transformers.configuration_utils import PretrainedConfig
5
+ from transformers.modeling_rope_utils import rope_config_validation
6
+ from transformers.utils import logging
7
+
8
+ logger = logging.get_logger(__name__)
9
+
10
+
11
+ class MotifConfig(PretrainedConfig):
12
+ r"""
13
+ This is the configuration class to store the configuration of a [`MotifModel`]. It is used to instantiate a
14
+ Motif model according to the specified arguments, defining the model architecture. Instantiating a configuration
15
+ with the defaults will yield a similar configuration to that of
16
+ Motif-102B [moreh/Motif-102B](https://huggingface.co/moreh/Motif-102B).
17
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
18
+ documentation from [`PretrainedConfig`] for more information.
19
+ Args:
20
+ vocab_size (`int`, *optional*, defaults to 151936):
21
+ Vocabulary size of the Motif model. Defines the number of different tokens that can be represented by the
22
+ `inputs_ids` passed when calling [`MotifModel`]
23
+ hidden_size (`int`, *optional*, defaults to 4096):
24
+ Dimension of the hidden representations.
25
+ intermediate_size (`int`, *optional*, defaults to 22016):
26
+ Dimension of the MLP representations.
27
+ num_hidden_layers (`int`, *optional*, defaults to 32):
28
+ Number of hidden layers in the Transformer encoder.
29
+ num_attention_heads (`int`, *optional*, defaults to 32):
30
+ Number of attention heads for each attention layer in the Transformer encoder.
31
+ num_key_value_heads (`int`, *optional*, defaults to 32):
32
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
33
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
34
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
35
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
36
+ by meanpooling all the original heads within that group. For more details checkout [this
37
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
38
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
39
+ The non-linear activation function (function or string) in the decoder.
40
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
41
+ The maximum sequence length that this model might ever be used with.
42
+ initializer_range (`float`, *optional*, defaults to 0.02):
43
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
44
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
45
+ The epsilon used by the rms normalization layers.
46
+ use_cache (`bool`, *optional*, defaults to `True`):
47
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
48
+ relevant if `config.is_decoder=True`.
49
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
50
+ Whether the model's input and output word embeddings should be tied.
51
+ rope_theta (`float`, *optional*, defaults to 10000.0):
52
+ The base period of the RoPE embeddings.
53
+ rope_scaling (`Dict`, *optional*):
54
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
55
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
56
+ accordingly.
57
+ Expected contents:
58
+ `rope_type` (`str`):
59
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
60
+ 'llama3'], with 'default' being the original RoPE implementation.
61
+ `factor` (`float`, *optional*):
62
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
63
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
64
+ original maximum pre-trained length.
65
+ `original_max_position_embeddings` (`int`, *optional*):
66
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
67
+ pretraining.
68
+ `attention_factor` (`float`, *optional*):
69
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
70
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
71
+ `factor` field to infer the suggested value.
72
+ `beta_fast` (`float`, *optional*):
73
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
74
+ ramp function. If unspecified, it defaults to 32.
75
+ `beta_slow` (`float`, *optional*):
76
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
77
+ ramp function. If unspecified, it defaults to 1.
78
+ `short_factor` (`List[float]`, *optional*):
79
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
80
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
81
+ size divided by the number of attention heads divided by 2
82
+ `long_factor` (`List[float]`, *optional*):
83
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
84
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
85
+ size divided by the number of attention heads divided by 2
86
+ `low_freq_factor` (`float`, *optional*):
87
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
88
+ `high_freq_factor` (`float`, *optional*):
89
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
90
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
91
+ Whether to use sliding window attention.
92
+ sliding_window (`int`, *optional*, defaults to 4096):
93
+ Sliding window attention (SWA) window size. If not specified, will default to `4096`.
94
+ max_window_layers (`int`, *optional*, defaults to 28):
95
+ The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
96
+ attention_dropout (`float`, *optional*, defaults to 0.0):
97
+ The dropout ratio for the attention probabilities.
98
+ ```python
99
+ >>> from transformers import MotifModel, MotifConfig
100
+ >>> # Initializing a Motif style configuration
101
+ >>> configuration = MotifConfig()
102
+ >>> # Initializing a model from the Motif-102B style configuration
103
+ >>> model = MotifModel(configuration)
104
+ >>> # Accessing the model configuration
105
+ >>> configuration = model.config
106
+ ```"""
107
+
108
+ model_type = "Motif"
109
+ keys_to_ignore_at_inference = ["past_key_values"]
110
+
111
+ def __init__(
112
+ self,
113
+ vocab_size=151936,
114
+ hidden_size=4096,
115
+ intermediate_size=22016,
116
+ num_hidden_layers=32,
117
+ num_attention_heads=32,
118
+ num_key_value_heads=32,
119
+ hidden_act="silu",
120
+ max_position_embeddings=32768,
121
+ initializer_range=0.02,
122
+ rms_norm_eps=1e-6,
123
+ use_cache=True,
124
+ tie_word_embeddings=False,
125
+ rope_theta=10000.0,
126
+ rope_scaling=None,
127
+ use_sliding_window=False,
128
+ sliding_window=4096,
129
+ max_window_layers=28,
130
+ attention_dropout=0.0,
131
+ **kwargs,
132
+ ):
133
+
134
+ self.vocab_size = vocab_size
135
+ self.max_position_embeddings = max_position_embeddings
136
+ self.hidden_size = hidden_size
137
+ self.intermediate_size = intermediate_size
138
+ self.num_hidden_layers = num_hidden_layers
139
+ self.num_attention_heads = num_attention_heads
140
+ self.use_sliding_window = use_sliding_window
141
+ self.sliding_window = sliding_window if use_sliding_window else None
142
+ self.max_window_layers = max_window_layers
143
+
144
+ # for backward compatibility
145
+ if num_key_value_heads is None:
146
+ num_key_value_heads = num_attention_heads
147
+
148
+ self.num_key_value_heads = num_key_value_heads
149
+ self.hidden_act = hidden_act
150
+ self.initializer_range = initializer_range
151
+ self.rms_norm_eps = rms_norm_eps
152
+ self.use_cache = use_cache
153
+ self.rope_theta = rope_theta
154
+ self.rope_scaling = rope_scaling
155
+ self.attention_dropout = attention_dropout
156
+
157
+ # Validate the correctness of rotary position embeddings parameters
158
+ # BC: if there is a 'type' field, move it to 'rope_type'.
159
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
160
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
161
+ rope_config_validation(self)
162
+
163
+ super().__init__(
164
+ tie_word_embeddings=tie_word_embeddings,
165
+ **kwargs,
166
+ )
167
+ logger.info(f' kwargs : {kwargs}')
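Note that several fields present in config.json (for example `use_bias`, `loss_reduction` and `absolute_position_embedding`) are not named parameters of `__init__` and therefore travel through `**kwargs` into `PretrainedConfig`, which is what the final `logger.info` line surfaces. A hedged sketch of constructing the config locally with the released checkpoint's values, assuming this file is importable from the working directory:

```python
from configuration_motif import MotifConfig  # assumes configuration_motif.py is on the path

# Values mirror the released config.json; keys without an explicit __init__
# parameter (such as use_bias) fall through to **kwargs.
config = MotifConfig(
    vocab_size=219520,
    hidden_size=2048,
    intermediate_size=8192,
    num_hidden_layers=32,
    num_attention_heads=16,
    num_key_value_heads=16,
    hidden_act="poly_norm",
    max_position_embeddings=16384,
    rope_theta=500000.0,
    tie_word_embeddings=True,
    use_bias=False,
)
print(config.use_bias)  # PretrainedConfig stores unknown kwargs as attributes
```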
generation_config.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 219396,
4
+ "eos_token_id": [
5
+ 219395,
6
+ 219405
7
+ ],
8
+ "transformers_version": "4.51.3",
9
+ "use_cache": true
10
+ }
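Because `eos_token_id` is a list here, generation stops at whichever of `<|endoftext|>` (219395) or `<|endofturn|>` (219405) appears first; `generate()` picks this up automatically when the model is loaded with `from_pretrained`. A small sketch that passes the same list explicitly, reusing `model`, `tokenizer` and `input_ids` from the README example above:

```python
# Equivalent to relying on the stored generation_config.json; shown only to make
# the two stop tokens explicit.
output_ids = model.generate(
    input_ids,
    max_new_tokens=128,
    eos_token_id=[219395, 219405],   # <|endoftext|>, <|endofturn|>
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(output_ids[0, input_ids.shape[-1]:], skip_special_tokens=True))
```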
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c3585a3263814f762598b6fc4464430d61b069742fa543e118d78bbefe01da08
3
+ size 4952662512
model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9380f28b7892e62f48fbd07e0d533ddc0fee44a7ba0b6111408a8a49e577996
3
+ size 4966459400
model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db220fa89bb483e4f8c029335034beb866b1f323adcac350f3d77fe546cad46c
3
+ size 469808712
model.safetensors.index.json ADDED
@@ -0,0 +1,521 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 10388873728
4
+ },
5
+ "weight_map": {
6
+ "model.embed_tokens.weight": "model-00001-of-00003.safetensors",
7
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
8
+ "model.layers.0.mlp.act_fn.bias": "model-00001-of-00003.safetensors",
9
+ "model.layers.0.mlp.act_fn.weight": "model-00001-of-00003.safetensors",
10
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
11
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
12
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
13
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
14
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
15
+ "model.layers.0.self_attn.lambda_k1": "model-00001-of-00003.safetensors",
16
+ "model.layers.0.self_attn.lambda_k2": "model-00001-of-00003.safetensors",
17
+ "model.layers.0.self_attn.lambda_q1": "model-00001-of-00003.safetensors",
18
+ "model.layers.0.self_attn.lambda_q2": "model-00001-of-00003.safetensors",
19
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
20
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
21
+ "model.layers.0.self_attn.subln.weight": "model-00001-of-00003.safetensors",
22
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
23
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00003.safetensors",
24
+ "model.layers.1.mlp.act_fn.bias": "model-00001-of-00003.safetensors",
25
+ "model.layers.1.mlp.act_fn.weight": "model-00001-of-00003.safetensors",
26
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
27
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
28
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
29
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
30
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
31
+ "model.layers.1.self_attn.lambda_k1": "model-00001-of-00003.safetensors",
32
+ "model.layers.1.self_attn.lambda_k2": "model-00001-of-00003.safetensors",
33
+ "model.layers.1.self_attn.lambda_q1": "model-00001-of-00003.safetensors",
34
+ "model.layers.1.self_attn.lambda_q2": "model-00001-of-00003.safetensors",
35
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
36
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
37
+ "model.layers.1.self_attn.subln.weight": "model-00001-of-00003.safetensors",
38
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
39
+ "model.layers.10.input_layernorm.weight": "model-00001-of-00003.safetensors",
40
+ "model.layers.10.mlp.act_fn.bias": "model-00001-of-00003.safetensors",
41
+ "model.layers.10.mlp.act_fn.weight": "model-00001-of-00003.safetensors",
42
+ "model.layers.10.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
43
+ "model.layers.10.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
44
+ "model.layers.10.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
45
+ "model.layers.10.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
46
+ "model.layers.10.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
47
+ "model.layers.10.self_attn.lambda_k1": "model-00001-of-00003.safetensors",
48
+ "model.layers.10.self_attn.lambda_k2": "model-00001-of-00003.safetensors",
49
+ "model.layers.10.self_attn.lambda_q1": "model-00001-of-00003.safetensors",
50
+ "model.layers.10.self_attn.lambda_q2": "model-00001-of-00003.safetensors",
51
+ "model.layers.10.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
52
+ "model.layers.10.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
53
+ "model.layers.10.self_attn.subln.weight": "model-00001-of-00003.safetensors",
54
+ "model.layers.10.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
55
+ "model.layers.11.input_layernorm.weight": "model-00002-of-00003.safetensors",
56
+ "model.layers.11.mlp.act_fn.bias": "model-00002-of-00003.safetensors",
57
+ "model.layers.11.mlp.act_fn.weight": "model-00002-of-00003.safetensors",
58
+ "model.layers.11.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
59
+ "model.layers.11.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
60
+ "model.layers.11.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
61
+ "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
62
+ "model.layers.11.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
63
+ "model.layers.11.self_attn.lambda_k1": "model-00001-of-00003.safetensors",
64
+ "model.layers.11.self_attn.lambda_k2": "model-00001-of-00003.safetensors",
65
+ "model.layers.11.self_attn.lambda_q1": "model-00001-of-00003.safetensors",
66
+ "model.layers.11.self_attn.lambda_q2": "model-00001-of-00003.safetensors",
67
+ "model.layers.11.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
68
+ "model.layers.11.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
69
+ "model.layers.11.self_attn.subln.weight": "model-00001-of-00003.safetensors",
70
+ "model.layers.11.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
71
+ "model.layers.12.input_layernorm.weight": "model-00002-of-00003.safetensors",
72
+ "model.layers.12.mlp.act_fn.bias": "model-00002-of-00003.safetensors",
73
+ "model.layers.12.mlp.act_fn.weight": "model-00002-of-00003.safetensors",
74
+ "model.layers.12.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
75
+ "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
76
+ "model.layers.12.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
77
+ "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
78
+ "model.layers.12.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
79
+ "model.layers.12.self_attn.lambda_k1": "model-00002-of-00003.safetensors",
80
+ "model.layers.12.self_attn.lambda_k2": "model-00002-of-00003.safetensors",
81
+ "model.layers.12.self_attn.lambda_q1": "model-00002-of-00003.safetensors",
82
+ "model.layers.12.self_attn.lambda_q2": "model-00002-of-00003.safetensors",
83
+ "model.layers.12.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
84
+ "model.layers.12.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
85
+ "model.layers.12.self_attn.subln.weight": "model-00002-of-00003.safetensors",
86
+ "model.layers.12.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
87
+ "model.layers.13.input_layernorm.weight": "model-00002-of-00003.safetensors",
88
+ "model.layers.13.mlp.act_fn.bias": "model-00002-of-00003.safetensors",
89
+ "model.layers.13.mlp.act_fn.weight": "model-00002-of-00003.safetensors",
90
+ "model.layers.13.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
91
+ "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
92
+ "model.layers.13.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
93
+ "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
94
+ "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
95
+ "model.layers.13.self_attn.lambda_k1": "model-00002-of-00003.safetensors",
96
+ "model.layers.13.self_attn.lambda_k2": "model-00002-of-00003.safetensors",
97
+ "model.layers.13.self_attn.lambda_q1": "model-00002-of-00003.safetensors",
98
+ "model.layers.13.self_attn.lambda_q2": "model-00002-of-00003.safetensors",
99
+ "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
100
+ "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
101
+ "model.layers.13.self_attn.subln.weight": "model-00002-of-00003.safetensors",
102
+ "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
103
+ "model.layers.14.input_layernorm.weight": "model-00002-of-00003.safetensors",
104
+ "model.layers.14.mlp.act_fn.bias": "model-00002-of-00003.safetensors",
105
+ "model.layers.14.mlp.act_fn.weight": "model-00002-of-00003.safetensors",
106
+ "model.layers.14.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
107
+ "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
108
+ "model.layers.14.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
109
+ "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
110
+ "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
111
+ "model.layers.14.self_attn.lambda_k1": "model-00002-of-00003.safetensors",
112
+ "model.layers.14.self_attn.lambda_k2": "model-00002-of-00003.safetensors",
113
+ "model.layers.14.self_attn.lambda_q1": "model-00002-of-00003.safetensors",
114
+ "model.layers.14.self_attn.lambda_q2": "model-00002-of-00003.safetensors",
115
+ "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
116
+ "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
117
+ "model.layers.14.self_attn.subln.weight": "model-00002-of-00003.safetensors",
118
+ "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
119
+ "model.layers.15.input_layernorm.weight": "model-00002-of-00003.safetensors",
120
+ "model.layers.15.mlp.act_fn.bias": "model-00002-of-00003.safetensors",
121
+ "model.layers.15.mlp.act_fn.weight": "model-00002-of-00003.safetensors",
122
+ "model.layers.15.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
123
+ "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
124
+ "model.layers.15.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
125
+ "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
126
+ "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
127
+ "model.layers.15.self_attn.lambda_k1": "model-00002-of-00003.safetensors",
128
+ "model.layers.15.self_attn.lambda_k2": "model-00002-of-00003.safetensors",
129
+ "model.layers.15.self_attn.lambda_q1": "model-00002-of-00003.safetensors",
130
+ "model.layers.15.self_attn.lambda_q2": "model-00002-of-00003.safetensors",
131
+ "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
132
+ "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
133
+ "model.layers.15.self_attn.subln.weight": "model-00002-of-00003.safetensors",
134
+ "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
135
+ "model.layers.16.input_layernorm.weight": "model-00002-of-00003.safetensors",
136
+ "model.layers.16.mlp.act_fn.bias": "model-00002-of-00003.safetensors",
137
+ "model.layers.16.mlp.act_fn.weight": "model-00002-of-00003.safetensors",
138
+ "model.layers.16.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
139
+ "model.layers.16.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
140
+ "model.layers.16.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
141
+ "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
142
+ "model.layers.16.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
143
+ "model.layers.16.self_attn.lambda_k1": "model-00002-of-00003.safetensors",
144
+ "model.layers.16.self_attn.lambda_k2": "model-00002-of-00003.safetensors",
145
+ "model.layers.16.self_attn.lambda_q1": "model-00002-of-00003.safetensors",
146
+ "model.layers.16.self_attn.lambda_q2": "model-00002-of-00003.safetensors",
147
+ "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
148
+ "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
149
+ "model.layers.16.self_attn.subln.weight": "model-00002-of-00003.safetensors",
150
+ "model.layers.16.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
151
+ "model.layers.17.input_layernorm.weight": "model-00002-of-00003.safetensors",
152
+ "model.layers.17.mlp.act_fn.bias": "model-00002-of-00003.safetensors",
153
+ "model.layers.17.mlp.act_fn.weight": "model-00002-of-00003.safetensors",
154
+ "model.layers.17.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
155
+ "model.layers.17.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
156
+ "model.layers.17.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
157
+ "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
158
+ "model.layers.17.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
159
+ "model.layers.17.self_attn.lambda_k1": "model-00002-of-00003.safetensors",
160
+ "model.layers.17.self_attn.lambda_k2": "model-00002-of-00003.safetensors",
161
+ "model.layers.17.self_attn.lambda_q1": "model-00002-of-00003.safetensors",
162
+ "model.layers.17.self_attn.lambda_q2": "model-00002-of-00003.safetensors",
163
+ "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
164
+ "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
165
+ "model.layers.17.self_attn.subln.weight": "model-00002-of-00003.safetensors",
166
+ "model.layers.17.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
167
+ "model.layers.18.input_layernorm.weight": "model-00002-of-00003.safetensors",
168
+ "model.layers.18.mlp.act_fn.bias": "model-00002-of-00003.safetensors",
169
+ "model.layers.18.mlp.act_fn.weight": "model-00002-of-00003.safetensors",
170
+ "model.layers.18.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
171
+ "model.layers.18.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
172
+ "model.layers.18.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
173
+ "model.layers.18.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
174
+ "model.layers.18.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
175
+ "model.layers.18.self_attn.lambda_k1": "model-00002-of-00003.safetensors",
176
+ "model.layers.18.self_attn.lambda_k2": "model-00002-of-00003.safetensors",
177
+ "model.layers.18.self_attn.lambda_q1": "model-00002-of-00003.safetensors",
178
+ "model.layers.18.self_attn.lambda_q2": "model-00002-of-00003.safetensors",
179
+ "model.layers.18.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
180
+ "model.layers.18.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
181
+ "model.layers.18.self_attn.subln.weight": "model-00002-of-00003.safetensors",
182
+ "model.layers.18.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
183
+ "model.layers.19.input_layernorm.weight": "model-00002-of-00003.safetensors",
184
+ "model.layers.19.mlp.act_fn.bias": "model-00002-of-00003.safetensors",
185
+ "model.layers.19.mlp.act_fn.weight": "model-00002-of-00003.safetensors",
186
+ "model.layers.19.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
187
+ "model.layers.19.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
188
+ "model.layers.19.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
189
+ "model.layers.19.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
190
+ "model.layers.19.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
191
+ "model.layers.19.self_attn.lambda_k1": "model-00002-of-00003.safetensors",
192
+ "model.layers.19.self_attn.lambda_k2": "model-00002-of-00003.safetensors",
193
+ "model.layers.19.self_attn.lambda_q1": "model-00002-of-00003.safetensors",
194
+ "model.layers.19.self_attn.lambda_q2": "model-00002-of-00003.safetensors",
195
+ "model.layers.19.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
196
+ "model.layers.19.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
197
+ "model.layers.19.self_attn.subln.weight": "model-00002-of-00003.safetensors",
198
+ "model.layers.19.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
199
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00003.safetensors",
200
+ "model.layers.2.mlp.act_fn.bias": "model-00001-of-00003.safetensors",
201
+ "model.layers.2.mlp.act_fn.weight": "model-00001-of-00003.safetensors",
202
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
203
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
204
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
205
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
206
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
207
+ "model.layers.2.self_attn.lambda_k1": "model-00001-of-00003.safetensors",
208
+ "model.layers.2.self_attn.lambda_k2": "model-00001-of-00003.safetensors",
209
+ "model.layers.2.self_attn.lambda_q1": "model-00001-of-00003.safetensors",
210
+ "model.layers.2.self_attn.lambda_q2": "model-00001-of-00003.safetensors",
211
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
212
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
213
+ "model.layers.2.self_attn.subln.weight": "model-00001-of-00003.safetensors",
214
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
215
+ "model.layers.20.input_layernorm.weight": "model-00002-of-00003.safetensors",
216
+ "model.layers.20.mlp.act_fn.bias": "model-00002-of-00003.safetensors",
217
+ "model.layers.20.mlp.act_fn.weight": "model-00002-of-00003.safetensors",
218
+ "model.layers.20.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
219
+ "model.layers.20.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
220
+ "model.layers.20.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
221
+ "model.layers.20.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
222
+ "model.layers.20.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
223
+ "model.layers.20.self_attn.lambda_k1": "model-00002-of-00003.safetensors",
224
+ "model.layers.20.self_attn.lambda_k2": "model-00002-of-00003.safetensors",
225
+ "model.layers.20.self_attn.lambda_q1": "model-00002-of-00003.safetensors",
226
+ "model.layers.20.self_attn.lambda_q2": "model-00002-of-00003.safetensors",
227
+ "model.layers.20.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
228
+ "model.layers.20.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
229
+ "model.layers.20.self_attn.subln.weight": "model-00002-of-00003.safetensors",
230
+ "model.layers.20.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
231
+ "model.layers.21.input_layernorm.weight": "model-00002-of-00003.safetensors",
232
+ "model.layers.21.mlp.act_fn.bias": "model-00002-of-00003.safetensors",
233
+ "model.layers.21.mlp.act_fn.weight": "model-00002-of-00003.safetensors",
234
+ "model.layers.21.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
235
+ "model.layers.21.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
236
+ "model.layers.21.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
237
+ "model.layers.21.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
238
+ "model.layers.21.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
239
+ "model.layers.21.self_attn.lambda_k1": "model-00002-of-00003.safetensors",
240
+ "model.layers.21.self_attn.lambda_k2": "model-00002-of-00003.safetensors",
241
+ "model.layers.21.self_attn.lambda_q1": "model-00002-of-00003.safetensors",
242
+ "model.layers.21.self_attn.lambda_q2": "model-00002-of-00003.safetensors",
243
+ "model.layers.21.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
244
+ "model.layers.21.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
245
+ "model.layers.21.self_attn.subln.weight": "model-00002-of-00003.safetensors",
246
+ "model.layers.21.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
247
+ "model.layers.22.input_layernorm.weight": "model-00002-of-00003.safetensors",
248
+ "model.layers.22.mlp.act_fn.bias": "model-00002-of-00003.safetensors",
249
+ "model.layers.22.mlp.act_fn.weight": "model-00002-of-00003.safetensors",
250
+ "model.layers.22.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
251
+ "model.layers.22.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
252
+ "model.layers.22.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
253
+ "model.layers.22.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
254
+ "model.layers.22.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
255
+ "model.layers.22.self_attn.lambda_k1": "model-00002-of-00003.safetensors",
256
+ "model.layers.22.self_attn.lambda_k2": "model-00002-of-00003.safetensors",
257
+ "model.layers.22.self_attn.lambda_q1": "model-00002-of-00003.safetensors",
258
+ "model.layers.22.self_attn.lambda_q2": "model-00002-of-00003.safetensors",
259
+ "model.layers.22.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
260
+ "model.layers.22.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
261
+ "model.layers.22.self_attn.subln.weight": "model-00002-of-00003.safetensors",
262
+ "model.layers.22.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
263
+ "model.layers.23.input_layernorm.weight": "model-00002-of-00003.safetensors",
264
+ "model.layers.23.mlp.act_fn.bias": "model-00002-of-00003.safetensors",
265
+ "model.layers.23.mlp.act_fn.weight": "model-00002-of-00003.safetensors",
266
+ "model.layers.23.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
267
+ "model.layers.23.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
268
+ "model.layers.23.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
269
+ "model.layers.23.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
270
+ "model.layers.23.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
271
+ "model.layers.23.self_attn.lambda_k1": "model-00002-of-00003.safetensors",
272
+ "model.layers.23.self_attn.lambda_k2": "model-00002-of-00003.safetensors",
273
+ "model.layers.23.self_attn.lambda_q1": "model-00002-of-00003.safetensors",
274
+ "model.layers.23.self_attn.lambda_q2": "model-00002-of-00003.safetensors",
275
+ "model.layers.23.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
276
+ "model.layers.23.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
277
+ "model.layers.23.self_attn.subln.weight": "model-00002-of-00003.safetensors",
278
+ "model.layers.23.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
279
+ "model.layers.24.input_layernorm.weight": "model-00002-of-00003.safetensors",
280
+ "model.layers.24.mlp.act_fn.bias": "model-00002-of-00003.safetensors",
281
+ "model.layers.24.mlp.act_fn.weight": "model-00002-of-00003.safetensors",
282
+ "model.layers.24.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
283
+ "model.layers.24.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
284
+ "model.layers.24.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
285
+ "model.layers.24.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
286
+ "model.layers.24.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
287
+ "model.layers.24.self_attn.lambda_k1": "model-00002-of-00003.safetensors",
288
+ "model.layers.24.self_attn.lambda_k2": "model-00002-of-00003.safetensors",
289
+ "model.layers.24.self_attn.lambda_q1": "model-00002-of-00003.safetensors",
290
+ "model.layers.24.self_attn.lambda_q2": "model-00002-of-00003.safetensors",
291
+ "model.layers.24.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
292
+ "model.layers.24.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
293
+ "model.layers.24.self_attn.subln.weight": "model-00002-of-00003.safetensors",
294
+ "model.layers.24.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
295
+ "model.layers.25.input_layernorm.weight": "model-00002-of-00003.safetensors",
296
+ "model.layers.25.mlp.act_fn.bias": "model-00002-of-00003.safetensors",
297
+ "model.layers.25.mlp.act_fn.weight": "model-00002-of-00003.safetensors",
298
+ "model.layers.25.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
299
+ "model.layers.25.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
300
+ "model.layers.25.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
301
+ "model.layers.25.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
302
+ "model.layers.25.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
303
+ "model.layers.25.self_attn.lambda_k1": "model-00002-of-00003.safetensors",
304
+ "model.layers.25.self_attn.lambda_k2": "model-00002-of-00003.safetensors",
305
+ "model.layers.25.self_attn.lambda_q1": "model-00002-of-00003.safetensors",
306
+ "model.layers.25.self_attn.lambda_q2": "model-00002-of-00003.safetensors",
307
+ "model.layers.25.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
308
+ "model.layers.25.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
309
+ "model.layers.25.self_attn.subln.weight": "model-00002-of-00003.safetensors",
310
+ "model.layers.25.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
311
+ "model.layers.26.input_layernorm.weight": "model-00002-of-00003.safetensors",
312
+ "model.layers.26.mlp.act_fn.bias": "model-00002-of-00003.safetensors",
313
+ "model.layers.26.mlp.act_fn.weight": "model-00002-of-00003.safetensors",
314
+ "model.layers.26.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
315
+ "model.layers.26.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
316
+ "model.layers.26.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
317
+ "model.layers.26.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
318
+ "model.layers.26.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
319
+ "model.layers.26.self_attn.lambda_k1": "model-00002-of-00003.safetensors",
320
+ "model.layers.26.self_attn.lambda_k2": "model-00002-of-00003.safetensors",
321
+ "model.layers.26.self_attn.lambda_q1": "model-00002-of-00003.safetensors",
322
+ "model.layers.26.self_attn.lambda_q2": "model-00002-of-00003.safetensors",
323
+ "model.layers.26.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
324
+ "model.layers.26.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
325
+ "model.layers.26.self_attn.subln.weight": "model-00002-of-00003.safetensors",
326
+ "model.layers.26.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
327
+ "model.layers.27.input_layernorm.weight": "model-00002-of-00003.safetensors",
328
+ "model.layers.27.mlp.act_fn.bias": "model-00002-of-00003.safetensors",
329
+ "model.layers.27.mlp.act_fn.weight": "model-00002-of-00003.safetensors",
330
+ "model.layers.27.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
331
+ "model.layers.27.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
332
+ "model.layers.27.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
333
+ "model.layers.27.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
334
+ "model.layers.27.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
335
+ "model.layers.27.self_attn.lambda_k1": "model-00002-of-00003.safetensors",
336
+ "model.layers.27.self_attn.lambda_k2": "model-00002-of-00003.safetensors",
337
+ "model.layers.27.self_attn.lambda_q1": "model-00002-of-00003.safetensors",
338
+ "model.layers.27.self_attn.lambda_q2": "model-00002-of-00003.safetensors",
339
+ "model.layers.27.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
340
+ "model.layers.27.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
341
+ "model.layers.27.self_attn.subln.weight": "model-00002-of-00003.safetensors",
342
+ "model.layers.27.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
343
+ "model.layers.28.input_layernorm.weight": "model-00002-of-00003.safetensors",
344
+ "model.layers.28.mlp.act_fn.bias": "model-00002-of-00003.safetensors",
345
+ "model.layers.28.mlp.act_fn.weight": "model-00002-of-00003.safetensors",
346
+ "model.layers.28.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
347
+ "model.layers.28.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
348
+ "model.layers.28.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
349
+ "model.layers.28.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
350
+ "model.layers.28.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
351
+ "model.layers.28.self_attn.lambda_k1": "model-00002-of-00003.safetensors",
352
+ "model.layers.28.self_attn.lambda_k2": "model-00002-of-00003.safetensors",
353
+ "model.layers.28.self_attn.lambda_q1": "model-00002-of-00003.safetensors",
354
+ "model.layers.28.self_attn.lambda_q2": "model-00002-of-00003.safetensors",
355
+ "model.layers.28.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
356
+ "model.layers.28.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
357
+ "model.layers.28.self_attn.subln.weight": "model-00002-of-00003.safetensors",
358
+ "model.layers.28.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
359
+ "model.layers.29.input_layernorm.weight": "model-00002-of-00003.safetensors",
360
+ "model.layers.29.mlp.act_fn.bias": "model-00002-of-00003.safetensors",
361
+ "model.layers.29.mlp.act_fn.weight": "model-00002-of-00003.safetensors",
362
+ "model.layers.29.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
363
+ "model.layers.29.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
364
+ "model.layers.29.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
365
+ "model.layers.29.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
366
+ "model.layers.29.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
367
+ "model.layers.29.self_attn.lambda_k1": "model-00002-of-00003.safetensors",
368
+ "model.layers.29.self_attn.lambda_k2": "model-00002-of-00003.safetensors",
369
+ "model.layers.29.self_attn.lambda_q1": "model-00002-of-00003.safetensors",
370
+ "model.layers.29.self_attn.lambda_q2": "model-00002-of-00003.safetensors",
371
+ "model.layers.29.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
372
+ "model.layers.29.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
373
+ "model.layers.29.self_attn.subln.weight": "model-00002-of-00003.safetensors",
374
+ "model.layers.29.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
375
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
376
+ "model.layers.3.mlp.act_fn.bias": "model-00001-of-00003.safetensors",
377
+ "model.layers.3.mlp.act_fn.weight": "model-00001-of-00003.safetensors",
378
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
379
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
380
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
381
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
382
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
383
+ "model.layers.3.self_attn.lambda_k1": "model-00001-of-00003.safetensors",
384
+ "model.layers.3.self_attn.lambda_k2": "model-00001-of-00003.safetensors",
385
+ "model.layers.3.self_attn.lambda_q1": "model-00001-of-00003.safetensors",
386
+ "model.layers.3.self_attn.lambda_q2": "model-00001-of-00003.safetensors",
387
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
388
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
389
+ "model.layers.3.self_attn.subln.weight": "model-00001-of-00003.safetensors",
390
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
391
+ "model.layers.30.input_layernorm.weight": "model-00003-of-00003.safetensors",
392
+ "model.layers.30.mlp.act_fn.bias": "model-00003-of-00003.safetensors",
393
+ "model.layers.30.mlp.act_fn.weight": "model-00003-of-00003.safetensors",
394
+ "model.layers.30.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
395
+ "model.layers.30.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
396
+ "model.layers.30.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
397
+ "model.layers.30.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
398
+ "model.layers.30.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
399
+ "model.layers.30.self_attn.lambda_k1": "model-00002-of-00003.safetensors",
400
+ "model.layers.30.self_attn.lambda_k2": "model-00002-of-00003.safetensors",
401
+ "model.layers.30.self_attn.lambda_q1": "model-00002-of-00003.safetensors",
402
+ "model.layers.30.self_attn.lambda_q2": "model-00002-of-00003.safetensors",
403
+ "model.layers.30.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
404
+ "model.layers.30.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
405
+ "model.layers.30.self_attn.subln.weight": "model-00002-of-00003.safetensors",
406
+ "model.layers.30.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
407
+ "model.layers.31.input_layernorm.weight": "model-00003-of-00003.safetensors",
408
+ "model.layers.31.mlp.act_fn.bias": "model-00003-of-00003.safetensors",
409
+ "model.layers.31.mlp.act_fn.weight": "model-00003-of-00003.safetensors",
410
+ "model.layers.31.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
411
+ "model.layers.31.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
412
+ "model.layers.31.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
413
+ "model.layers.31.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
414
+ "model.layers.31.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
415
+ "model.layers.31.self_attn.lambda_k1": "model-00003-of-00003.safetensors",
416
+ "model.layers.31.self_attn.lambda_k2": "model-00003-of-00003.safetensors",
417
+ "model.layers.31.self_attn.lambda_q1": "model-00003-of-00003.safetensors",
418
+ "model.layers.31.self_attn.lambda_q2": "model-00003-of-00003.safetensors",
419
+ "model.layers.31.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
420
+ "model.layers.31.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
421
+ "model.layers.31.self_attn.subln.weight": "model-00003-of-00003.safetensors",
422
+ "model.layers.31.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
423
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
424
+ "model.layers.4.mlp.act_fn.bias": "model-00001-of-00003.safetensors",
425
+ "model.layers.4.mlp.act_fn.weight": "model-00001-of-00003.safetensors",
426
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
427
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
428
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
429
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
430
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
431
+ "model.layers.4.self_attn.lambda_k1": "model-00001-of-00003.safetensors",
432
+ "model.layers.4.self_attn.lambda_k2": "model-00001-of-00003.safetensors",
433
+ "model.layers.4.self_attn.lambda_q1": "model-00001-of-00003.safetensors",
434
+ "model.layers.4.self_attn.lambda_q2": "model-00001-of-00003.safetensors",
435
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
436
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
437
+ "model.layers.4.self_attn.subln.weight": "model-00001-of-00003.safetensors",
438
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
439
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00003.safetensors",
440
+ "model.layers.5.mlp.act_fn.bias": "model-00001-of-00003.safetensors",
441
+ "model.layers.5.mlp.act_fn.weight": "model-00001-of-00003.safetensors",
442
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
443
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
444
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
445
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
446
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
447
+ "model.layers.5.self_attn.lambda_k1": "model-00001-of-00003.safetensors",
448
+ "model.layers.5.self_attn.lambda_k2": "model-00001-of-00003.safetensors",
449
+ "model.layers.5.self_attn.lambda_q1": "model-00001-of-00003.safetensors",
450
+ "model.layers.5.self_attn.lambda_q2": "model-00001-of-00003.safetensors",
451
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
452
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
453
+ "model.layers.5.self_attn.subln.weight": "model-00001-of-00003.safetensors",
454
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
455
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00003.safetensors",
456
+ "model.layers.6.mlp.act_fn.bias": "model-00001-of-00003.safetensors",
457
+ "model.layers.6.mlp.act_fn.weight": "model-00001-of-00003.safetensors",
458
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
459
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
460
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
461
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
462
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
463
+ "model.layers.6.self_attn.lambda_k1": "model-00001-of-00003.safetensors",
464
+ "model.layers.6.self_attn.lambda_k2": "model-00001-of-00003.safetensors",
465
+ "model.layers.6.self_attn.lambda_q1": "model-00001-of-00003.safetensors",
466
+ "model.layers.6.self_attn.lambda_q2": "model-00001-of-00003.safetensors",
467
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
468
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
469
+ "model.layers.6.self_attn.subln.weight": "model-00001-of-00003.safetensors",
470
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
471
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00003.safetensors",
472
+ "model.layers.7.mlp.act_fn.bias": "model-00001-of-00003.safetensors",
473
+ "model.layers.7.mlp.act_fn.weight": "model-00001-of-00003.safetensors",
474
+ "model.layers.7.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
475
+ "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
476
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
477
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
478
+ "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
479
+ "model.layers.7.self_attn.lambda_k1": "model-00001-of-00003.safetensors",
480
+ "model.layers.7.self_attn.lambda_k2": "model-00001-of-00003.safetensors",
481
+ "model.layers.7.self_attn.lambda_q1": "model-00001-of-00003.safetensors",
482
+ "model.layers.7.self_attn.lambda_q2": "model-00001-of-00003.safetensors",
483
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
484
+ "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
485
+ "model.layers.7.self_attn.subln.weight": "model-00001-of-00003.safetensors",
486
+ "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
487
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00003.safetensors",
488
+ "model.layers.8.mlp.act_fn.bias": "model-00001-of-00003.safetensors",
489
+ "model.layers.8.mlp.act_fn.weight": "model-00001-of-00003.safetensors",
490
+ "model.layers.8.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
491
+ "model.layers.8.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
492
+ "model.layers.8.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
493
+ "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
494
+ "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
495
+ "model.layers.8.self_attn.lambda_k1": "model-00001-of-00003.safetensors",
496
+ "model.layers.8.self_attn.lambda_k2": "model-00001-of-00003.safetensors",
497
+ "model.layers.8.self_attn.lambda_q1": "model-00001-of-00003.safetensors",
498
+ "model.layers.8.self_attn.lambda_q2": "model-00001-of-00003.safetensors",
499
+ "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
500
+ "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
501
+ "model.layers.8.self_attn.subln.weight": "model-00001-of-00003.safetensors",
502
+ "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
503
+ "model.layers.9.input_layernorm.weight": "model-00001-of-00003.safetensors",
504
+ "model.layers.9.mlp.act_fn.bias": "model-00001-of-00003.safetensors",
505
+ "model.layers.9.mlp.act_fn.weight": "model-00001-of-00003.safetensors",
506
+ "model.layers.9.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
507
+ "model.layers.9.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
508
+ "model.layers.9.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
509
+ "model.layers.9.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
510
+ "model.layers.9.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
511
+ "model.layers.9.self_attn.lambda_k1": "model-00001-of-00003.safetensors",
512
+ "model.layers.9.self_attn.lambda_k2": "model-00001-of-00003.safetensors",
513
+ "model.layers.9.self_attn.lambda_q1": "model-00001-of-00003.safetensors",
514
+ "model.layers.9.self_attn.lambda_q2": "model-00001-of-00003.safetensors",
515
+ "model.layers.9.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
516
+ "model.layers.9.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
517
+ "model.layers.9.self_attn.subln.weight": "model-00001-of-00003.safetensors",
518
+ "model.layers.9.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
519
+ "model.norm.weight": "model-00003-of-00003.safetensors"
520
+ }
521
+ }
modeling_motif.py ADDED
@@ -0,0 +1,1378 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torch.utils.checkpoint
8
+ from torch import nn
9
+ from torch.nn import CrossEntropyLoss
10
+ from transformers.activations import ACT2CLS as _ACT2CLS
11
+ from transformers.activations import ClassInstantier
12
+ from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
13
+ from transformers.generation import GenerationMixin
14
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
15
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
16
+ from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
17
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
18
+ from transformers.modeling_utils import PreTrainedModel
19
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
20
+ from transformers.utils import (add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available,
21
+ is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings)
22
+
23
+ from .configuration_motif import MotifConfig
24
+
25
+
26
+ class PolyNorm(torch.nn.Module):
27
+ """
28
+ A trainable activation function introduced in https://arxiv.org/html/2411.03884v1.
29
+ The code is copied from https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md
30
+ """
31
+
32
+ def __init__(self, eps=1e-6):
33
+ super(PolyNorm, self).__init__()
34
+ self.weight = torch.nn.Parameter(torch.ones(3) / 3)
35
+ self.bias = torch.nn.Parameter(torch.zeros(1))
36
+ self.eps = eps
37
+
38
+ def _norm(self, x):
39
+ return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
40
+
41
+ def forward(self, x):
42
+ return self.weight[0] * self._norm(x ** 3) + self.weight[1] * self._norm(
43
+ x ** 2) + self.weight[2] * self._norm(x) + self.bias
44
+
45
+
46
+ CUSTOM_ACT2CLS = {"poly_norm": PolyNorm}
47
+ ACT2CLS = {**_ACT2CLS, **CUSTOM_ACT2CLS}
48
+ ACT2FN = ClassInstantier(ACT2CLS)
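A minimal usage sketch for the custom `poly_norm` activation registered above (the tensor shape is an arbitrary assumption; it relies only on the definitions in this file plus `torch`):

    import torch
    act = ACT2FN["poly_norm"]      # ClassInstantier builds a fresh PolyNorm()
    x = torch.randn(2, 4, 8)       # (batch, seq, hidden) -- illustrative shape only
    y = act(x)                     # weighted sum of RMS-normalized x, x**2, x**3 plus a bias
    assert y.shape == x.shape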
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ if is_flash_attn_2_available():
53
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
54
+
55
+ _CONFIG_FOR_DOC = "MotifConfig"
56
+
57
+
58
+ class MotifRMSNorm(nn.Module):
59
+
60
+ def __init__(self, hidden_size, eps=1e-6):
61
+ """
62
+ MotifRMSNorm is equivalent to T5LayerNorm
63
+ """
64
+ super().__init__()
65
+ self.weight = nn.Parameter(torch.ones(hidden_size))
66
+ self.variance_epsilon = eps
67
+
68
+ def forward(self, hidden_states):
69
+ input_dtype = hidden_states.dtype
70
+ hidden_states = hidden_states.to(torch.float32)
71
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
72
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
73
+ return self.weight * hidden_states.to(input_dtype)
74
+
75
+ def extra_repr(self):
76
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
77
+
78
+
79
+ ALL_LAYERNORM_LAYERS.append(MotifRMSNorm)
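A quick sanity-check sketch for `MotifRMSNorm` as defined above (hidden size and batch shape are arbitrary assumptions):

    import torch
    norm = MotifRMSNorm(hidden_size=8, eps=1e-6)
    h = torch.randn(2, 3, 8)
    out = norm(h)                  # with the default all-ones weight, each vector is scaled to unit RMS
    assert out.shape == h.shape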
80
+
81
+
82
+ class MotifRotaryEmbeddingWithCache(nn.Module):
83
+ """
84
+ Rotary positional embedding module with caching for efficiency.
85
+
86
+ Args:
87
+ dim (int): Dimensionality of the embedding.
88
+ max_position_embeddings (int): Maximum sequence length for caching. Default is 2048.
89
+ base (int): Base for computing inverse frequency. Default is 10000.
90
+ device (torch.device, optional): Device for tensor storage.
91
+
92
+ Methods:
93
+ forward(x, seq_len=None):
94
+ Computes cosine and sine embeddings for input sequence length.
95
+ Automatically updates cache if `seq_len` exceeds cached length.
96
+
97
+ Attributes:
98
+ inv_freq (torch.Tensor): Inverse frequency tensor for position encoding.
99
+ cos_cached (torch.Tensor): Cached cosine embeddings.
100
+ sin_cached (torch.Tensor): Cached sine embeddings.
101
+ """
102
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
103
+ super().__init__()
104
+
105
+ self.dim = dim
106
+ self.max_position_embeddings = max_position_embeddings
107
+ self.base = base
108
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
109
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
110
+
111
+ self._set_cos_sin_cache(seq_len=max_position_embeddings,
112
+ device=self.inv_freq.device,
113
+ dtype=torch.get_default_dtype())
114
+
115
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
116
+ self.max_seq_len_cached = seq_len
117
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
118
+
119
+ freqs = torch.outer(t, self.inv_freq)
120
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
121
+ emb = torch.cat((freqs, freqs), dim=-1)
122
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
123
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
124
+
125
+ def forward(self, x, seq_len=None):
126
+ # x: [bs, num_attention_heads, seq_len, head_size]
127
+ if seq_len > self.max_seq_len_cached:
128
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
129
+
130
+ return (
131
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
132
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
133
+ )
134
+
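A short usage sketch for the cached rotary embedding above; the dimensions are illustrative assumptions, and asking for a sequence longer than `max_position_embeddings` exercises the cache refresh in `forward`:

    import torch
    rope = MotifRotaryEmbeddingWithCache(dim=16, max_position_embeddings=32)
    x = torch.randn(1, 2, 48, 16)       # (bsz, num_heads, seq_len, head_dim)
    cos, sin = rope(x, seq_len=48)      # 48 > 32, so the cos/sin caches are recomputed
    assert cos.shape == sin.shape == (48, 16)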
135
+
136
+ class MotifRotaryEmbedding(nn.Module):
137
+
138
+ def __init__(
139
+ self,
140
+ dim=None,
141
+ max_position_embeddings=2048,
142
+ base=10000,
143
+ device=None,
144
+ scaling_factor=1.0,
145
+ rope_type="default",
146
+ config: Optional[MotifConfig] = None,
147
+ ):
148
+ super().__init__()
149
+ # TODO (joao): remove the `if` below, only used for BC
150
+ self.rope_kwargs = {}
151
+ if config is None:
152
+ logger.warning_once(
153
+ "`MotifRotaryEmbedding` can now be fully parameterized by passing the model config through the "
154
+ "`config` argument. All other arguments will be removed in v4.46")
155
+ self.rope_kwargs = {
156
+ "rope_type": rope_type,
157
+ "factor": scaling_factor,
158
+ "dim": dim,
159
+ "base": base,
160
+ "max_position_embeddings": max_position_embeddings,
161
+ }
162
+ self.rope_type = rope_type
163
+ self.max_seq_len_cached = max_position_embeddings
164
+ self.original_max_seq_len = max_position_embeddings
165
+ else:
166
+ if config.rope_scaling is not None:
167
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
168
+ else:
169
+ self.rope_type = "default"
170
+ self.max_seq_len_cached = config.max_position_embeddings
171
+ self.original_max_seq_len = config.max_position_embeddings
172
+
173
+ self.config = config
174
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
175
+
176
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
177
+
178
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
179
+ self.original_inv_freq = self.inv_freq
180
+
181
+ def _dynamic_frequency_update(self, position_ids, device):
182
+ """
183
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
184
+ 1 - growing beyond the cached sequence length (allow scaling)
185
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
186
+ """
187
+ seq_len = torch.max(position_ids) + 1
188
+ if seq_len > self.max_seq_len_cached: # growth
189
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config,
190
+ device,
191
+ seq_len=seq_len,
192
+ **self.rope_kwargs)
193
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
194
+ self.max_seq_len_cached = seq_len
195
+
196
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
197
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
198
+ self.max_seq_len_cached = self.original_max_seq_len
199
+
200
+ @torch.no_grad()
201
+ def forward(self, x, position_ids):
202
+ if "dynamic" in self.rope_type:
203
+ self._dynamic_frequency_update(position_ids, device=x.device)
204
+
205
+ # Core RoPE block
206
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
207
+ position_ids_expanded = position_ids[:, None, :].float()
208
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
209
+ device_type = x.device.type
210
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
211
+ with torch.autocast(device_type=device_type, enabled=False):
212
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
213
+ emb = torch.cat((freqs, freqs), dim=-1)
214
+ cos = emb.cos()
215
+ sin = emb.sin()
216
+
217
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
218
+ cos = cos * self.attention_scaling
219
+ sin = sin * self.attention_scaling
220
+
221
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
222
+
223
+
224
+ def rotate_half(x):
225
+ """
226
+ Rotates half of the dimensions of the input tensor using torch.roll and in-place negation.
227
+
228
+ Args:
229
+ x (torch.Tensor): The input tensor.
230
+
231
+ Returns:
232
+ torch.Tensor: A tensor where the latter half of the dimensions are negated
233
+ and moved before the first half.
234
+ """
235
+ half_size = x.shape[-1] // 2
236
+ rotated_tensor = torch.roll(x, shifts=-half_size, dims=-1)
237
+ rotated_tensor[..., :half_size] *= -1
238
+
239
+ return rotated_tensor
240
+
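The roll-and-negate form above is equivalent to the more common split-and-concatenate formulation; a small equivalence sketch (the reference function here is written only for comparison and is not part of the model):

    import torch
    def rotate_half_reference(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    x = torch.randn(2, 3, 8)
    assert torch.allclose(rotate_half(x), rotate_half_reference(x))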
241
+
242
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
243
+ """
244
+ Applies rotary position embeddings to the input tensors.
245
+
246
+ Args:
247
+ q (torch.Tensor): Query tensor of shape (B, NH, S, D_KV).
248
+ k (torch.Tensor): Key tensor of shape (B, NH, S, D_KV).
249
+ cos (torch.Tensor): Cosine values for rotary embedding.
250
+ sin (torch.Tensor): Sine values for rotary embedding.
251
+ unsqueeze_dim (int, optional): Dimension along which `cos` and `sin` are unsqueezed.
252
+ Defaults to 1.
253
+
254
+ Returns:
255
+ Tuple[torch.Tensor, torch.Tensor]: Returns transformed query and key tensors after applying rotary embeddings.
256
+ """
257
+ '''
258
+ # (B, NH, S, D_KV) -> (B, S, NH, D_KV)
259
+ cos = cos.unsqueeze(unsqueeze_dim)
260
+ sin = sin.unsqueeze(unsqueeze_dim)
261
+ q_embed = (q * cos) + (rotate_half(q) * sin)
262
+ k_embed = (k * cos) + (rotate_half(k) * sin)
263
+ '''
264
+ device = q.device
265
+ return map(
266
+ lambda x: (x * cos[position_ids].unsqueeze(unsqueeze_dim).to(device)) +
267
+ (rotate_half(x) * sin[position_ids].unsqueeze(unsqueeze_dim).to(device)), (q, k))
268
+
269
+
270
+ class MotifMLP(nn.Module):
271
+
272
+ def __init__(self, config):
273
+ super().__init__()
274
+ self.hidden_size = config.hidden_size
275
+ self.intermediate_size = config.intermediate_size
276
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias)
277
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias)
278
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
279
+ self.act_fn = ACT2FN[config.hidden_act]
280
+
281
+ def forward(self, hidden_state):
282
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
283
+
284
+
285
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
286
+
287
+
288
+ """
289
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
290
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
291
+
292
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
293
+ if n_rep == 1:
294
+ return hidden_states
295
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
296
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
297
+ """
298
+
299
+ return torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)
300
+
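A shape-level sketch of `repeat_kv` expanding grouped key/value heads (the sizes are illustrative assumptions):

    import torch
    kv = torch.randn(1, 2, 5, 16)       # (batch, num_key_value_heads, seq_len, head_dim)
    expanded = repeat_kv(kv, n_rep=4)   # each KV head is repeated 4 times along dim=1
    assert expanded.shape == (1, 8, 5, 16)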
301
+
302
+ class MotifAttention(nn.Module):
303
+ """
304
+ Differential Attention (DiffAttention) module.
305
+
306
+ Implements the Differential Attention from
307
+ "DIFFERENTIAL TRANSFORMER" (https://arxiv.org/pdf/2410.05258).
308
+
309
+ Overview
310
+ Standard transformers often over-allocate attention to irrelevant context.
311
+ DiffAttention addresses this by computing attention as the difference between
312
+ two separate softmax attention maps, effectively canceling noise and promoting
313
+ sparse, structured attention patterns.
314
+
315
+ Reference Implementation
316
+ https://github.com/microsoft/unilm/tree/master/Diff-Transformer
317
+
318
+ Args
319
+ The differential attention mechanism computes attention as the difference of two softmax attention scores, weighted by a learnable scalar λ.
320
+ λ is re-parameterized as λ = exp(λ_q1 · λ_k1) − exp(λ_q2 · λ_k2) + λ_init.
321
+ - lambda_q1, lambda_q2 (nn.Parameter): Learnable vectors used to compute the first and second components of λ for query transformations.
322
+ - lambda_k1, lambda_k2 (nn.Parameter): Learnable vectors used to compute the first and second components of λ for key transformations.
323
+ - lambda_init (float): A constant used for initializing λ, typically set as λ_init = 0.8 − 0.6 × exp(−0.3 × (layer_index − 1)).
324
+
325
+ """
326
+
327
+ def __init__(self, config: MotifConfig, layer_idx: Optional[int] = None):
328
+ super().__init__()
329
+ self.config = config
330
+ self.layer_idx = layer_idx
331
+ if layer_idx is None:
332
+ logger.warning_once(
333
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
334
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
335
+ "when creating this class.")
336
+
337
+
338
+ self.hidden_size = config.hidden_size
339
+ self.num_heads = config.num_attention_heads
340
+ self.head_dim = self.hidden_size // self.num_heads
341
+ self.num_key_value_heads = config.num_key_value_heads
342
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
343
+ self.max_position_embeddings = config.max_position_embeddings
344
+ self.rope_theta = config.rope_theta
345
+ self.is_causal = True
346
+ self.attention_dropout = config.attention_dropout
347
+
348
+ if (self.head_dim * self.num_heads) != self.hidden_size:
349
+ raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
350
+ f" and `num_heads`: {self.num_heads}).")
351
+
352
+ self.num_heads //= 2
353
+ self.num_key_value_heads //= 2
354
+ self.n_rep = self.num_heads // self.num_key_value_heads
355
+
356
+ self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
357
+ self.k_proj = nn.Linear(self.hidden_size, self.hidden_size // self.n_rep, bias=False)
358
+ self.v_proj = nn.Linear(self.hidden_size, self.hidden_size // self.n_rep, bias=False)
359
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
360
+
361
+ for name in ["lambda_q1", "lambda_k1", "lambda_q2", "lambda_k2"]:
362
+ setattr(self, name, nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32)))
363
+ getattr(self, name).data.normal_(mean=0.0, std=0.1)
364
+
365
+ self.subln = MotifRMSNorm(2 * self.head_dim, eps=1e-5)
366
+ self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * (layer_idx - 1))
367
+
368
+ self.rotary_emb = MotifRotaryEmbeddingWithCache(self.head_dim,
369
+ max_position_embeddings=self.max_position_embeddings,
370
+ base=self.rope_theta)
371
+
372
+ def forward(
373
+ self,
374
+ hidden_states: torch.Tensor,
375
+ attention_mask: Optional[torch.Tensor] = None,
376
+ position_ids: Optional[torch.LongTensor] = None,
377
+ past_key_value: Optional[Cache] = None,
378
+ output_attentions: bool = False,
379
+ use_cache: bool = False,
380
+ cache_position: Optional[torch.LongTensor] = None,
381
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
382
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
383
+ bsz, q_len, _ = hidden_states.size()
384
+
385
+ query_states = self.q_proj(hidden_states)
386
+ key_states = self.k_proj(hidden_states)
387
+ value_states = self.v_proj(hidden_states)
388
+
389
+ query_states = query_states.view(bsz, q_len, 2 * self.num_heads, self.head_dim).transpose(1, 2)
390
+ key_states = key_states.view(bsz, q_len, 2 * self.num_key_value_heads, self.head_dim).transpose(1, 2)
391
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, 2 * self.head_dim).transpose(1, 2)
392
+
393
+ kv_seq_len = key_states.shape[-2]
394
+ if position_embeddings is None:
395
+ logger.warning_once(
396
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
397
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
398
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
399
+ "removed and `position_embeddings` will be mandatory.")
400
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
401
+ else:
402
+ cos, sin = (self.rotary_emb(value_states, q_len + past_key_value.get_usable_length(q_len, self.layer_idx))
403
+ if use_cache else position_embeddings)
404
+
405
+ query_states, key_states = apply_rotary_pos_emb(query_states,
406
+ key_states,
407
+ cos,
408
+ sin,
409
+ position_ids=position_ids)
410
+
411
+ if past_key_value is not None:
412
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
413
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
414
+
415
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
416
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
417
+
418
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
419
+
420
+ kv_seq_len = key_states.shape[-2]
421
+ offset = kv_seq_len - q_len
422
+
423
+ attention_mask = torch.triu(
424
+ torch.full((q_len, kv_seq_len), float("-inf"), dtype=attn_weights.dtype, device=attn_weights.device),
425
+ 1 + offset)
426
+
427
+ attn_weights = attn_weights + attention_mask
428
+
429
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
430
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
431
+
432
+ lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(attn_weights)
433
+ lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(attn_weights)
434
+ lambda_full = lambda_1 - lambda_2 + self.lambda_init
435
+ attn_weights = attn_weights.view(bsz, self.num_heads, 2, q_len, -1)
436
+ attn_weights = attn_weights[:, :, 0] - lambda_full * attn_weights[:, :, 1]
437
+
438
+ attn_output = torch.matmul(attn_weights, value_states)
439
+
440
+ attn_output = self.subln(attn_output)
441
+ attn_output = attn_output * (1 - self.lambda_init)
442
+
443
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim * 2):
444
+ raise ValueError(f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim * 2)}, but is"
445
+ f" {attn_output.size()}")
446
+
447
+ attn_output = attn_output.transpose(1, 2).contiguous()
448
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
449
+
450
+ attn_output = self.o_proj(attn_output)
451
+
452
+ if not output_attentions:
453
+ attn_weights = None
454
+
455
+ return attn_output, attn_weights, past_key_value
456
+
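A small numeric sketch of the λ re-parameterization described in the class docstring (`head_dim` and `layer_idx` are arbitrary assumptions):

    import math
    import torch
    head_dim, layer_idx = 8, 3
    lambda_init = 0.8 - 0.6 * math.exp(-0.3 * (layer_idx - 1))
    lq1, lk1, lq2, lk2 = (torch.randn(head_dim) * 0.1 for _ in range(4))
    # lambda = exp(lq1 . lk1) - exp(lq2 . lk2) + lambda_init, a per-layer scalar
    lambda_full = torch.exp((lq1 * lk1).sum()) - torch.exp((lq2 * lk2).sum()) + lambda_init
    # the two attention maps are then combined as: attn = softmax_1 - lambda_full * softmax_2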
457
+
458
+ class MotifFlashAttention2(MotifAttention):
459
+ """
460
+ Motif flash attention module, following Motif attention module. This module inherits from `MotifAttention`
461
+ as the weights of the module stay untouched. The only required change would be on the forward pass
462
+ where it needs to correctly call the public API of flash attention and deal with padding tokens
463
+ in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
464
+ config.max_window_layers layers.
465
+ """
466
+
467
+ def __init__(self, *args, **kwargs):
468
+ super().__init__(*args, **kwargs)
469
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
470
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is a bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
471
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
472
+
473
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
474
+
475
+ logger.info(f'Flash attention with bottom-right aligned causal mask (flash_attn>=2.1): {not self._flash_attn_uses_top_left_mask}')
476
+
477
+ def _reshape_heads(self, tensor, batch_size, seq_len):
478
+ """2-way head split tensor reshape"""
479
+ return tensor.reshape(batch_size, seq_len, self.num_heads, 2, self.head_dim)
480
+
481
+ def _restore_shape(self, tensor, batch_size, seq_len):
482
+ """restore the (batch, seq, num_heads, head_dim) layout after the 2-way head split"""
483
+ return tensor.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
484
+
485
+ def _compute_attention(self, query_states, key_states, value_states, attention_mask, q_len, position_ids,
486
+ dropout_rate, sliding_window):
487
+ """Run a single Flash Attention 2 forward pass."""
488
+ _input_type = query_states.dtype
489
+ scale_factor = 1.0 / math.sqrt(self.head_dim)
490
+ if not self._flash_attn_uses_top_left_mask:
491
+ causal = self.is_causal
492
+ else:
493
+ causal = self.is_causal and q_len != 1
494
+
495
+ attn_out = _flash_attention_forward(query_states.bfloat16(),
496
+ key_states.bfloat16(),
497
+ value_states.bfloat16(),
498
+ attention_mask,
499
+ q_len,
500
+ position_ids=position_ids,
501
+ dropout=dropout_rate,
502
+ sliding_window=sliding_window,
503
+ is_causal=True,
504
+ softmax_scale=scale_factor,
505
+ use_top_left_mask=self._flash_attn_uses_top_left_mask)
506
+ return attn_out.to(_input_type)
507
+
508
+ def forward(
509
+ self,
510
+ hidden_states: torch.Tensor,
511
+ attention_mask: Optional[torch.Tensor] = None,
512
+ position_ids: Optional[torch.LongTensor] = None,
513
+ past_key_value: Optional[Cache] = None,
514
+ output_attentions: bool = False,
515
+ use_cache: bool = False,
516
+ cache_position: Optional[torch.LongTensor] = None,
517
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
518
+ ):
519
+ bsz, q_len, _ = hidden_states.size()
520
+
521
+ query_states = self.q_proj(hidden_states)
522
+ key_states = self.k_proj(hidden_states)
523
+ value_states = self.v_proj(hidden_states)
524
+
525
+ query_states = query_states.view(bsz, q_len, 2 * self.num_heads, self.head_dim).transpose(1, 2)
526
+ key_states = key_states.view(bsz, q_len, 2 * self.num_key_value_heads, self.head_dim).transpose(1, 2)
527
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, 2 * self.head_dim).transpose(1, 2)
528
+ kv_seq_len = key_states.shape[-2]
529
+ if position_embeddings is None:
530
+ logger.warning_once(
531
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
532
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
533
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
534
+ "removed and `position_embeddings` will be mandatory.")
535
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
536
+ else:
537
+ cos, sin = (self.rotary_emb(value_states, q_len + past_key_value.get_usable_length(q_len, self.layer_idx))
538
+ if use_cache else position_embeddings)
539
+
540
+ query_states, key_states = apply_rotary_pos_emb(query_states,
541
+ key_states,
542
+ cos,
543
+ sin,
544
+ position_ids=position_ids)
545
+
546
+ if past_key_value is not None:
547
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
548
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
549
+
550
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
551
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
552
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
553
+
554
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
555
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
556
+ # cast them back to float16 just to be sure everything works as expected.
557
+ input_dtype = query_states.dtype
558
+ if input_dtype == torch.float32:
559
+ if torch.is_autocast_enabled():
560
+ target_dtype = torch.get_autocast_gpu_dtype()
561
+ # Handle the case where the model is quantized
562
+ elif hasattr(self.config, "_pre_quantization_dtype"):
563
+ target_dtype = self.config._pre_quantization_dtype
564
+ else:
565
+ target_dtype = self.q_proj.weight.dtype
566
+
567
+ logger.warning_once(
568
+ f"The input hidden states seem to have been silently cast to float32, this might be related to"
569
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
570
+ f" {target_dtype}.")
571
+
572
+ query_states = query_states.to(target_dtype)
573
+ key_states = key_states.to(target_dtype)
574
+ value_states = value_states.to(target_dtype)
575
+
576
+ q_len = query_states.shape[-2]
577
+ kv_seq_len = key_states.shape[-2]
578
+
579
+ # Reshape to the expected shape for Flash Attention
580
+ query_states = query_states.transpose(1, 2)
581
+ key_states = key_states.transpose(1, 2)
582
+ value_states = value_states.transpose(1, 2)
583
+
584
+ if (self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None
585
+ and self.layer_idx >= self.config.max_window_layers):
586
+ sliding_window = self.config.sliding_window
587
+ else:
588
+ sliding_window = None
589
+
590
+ q = self._reshape_heads(query_states, bsz, q_len)
591
+ k = self._reshape_heads(key_states, bsz, kv_seq_len)
592
+ v = self._reshape_heads(value_states, bsz, kv_seq_len)
593
+
594
+ q1, q2 = q[..., 0, :], q[..., 1, :]
595
+ k1, k2 = k[..., 0, :], k[..., 1, :]
596
+ v1, v2 = v[..., 0, :], v[..., 1, :]
597
+
598
+ q1, q2, k1, k2, v1, v2 = map(lambda x: self._restore_shape(x, bsz, q_len if x is q1 or x is q2 else kv_seq_len),
599
+ (q1, q2, k1, k2, v1, v2))
600
+
601
+ q1, q2 = q1.contiguous(), q2.contiguous()
602
+ k1, k2 = k1.contiguous(), k2.contiguous()
603
+ v1, v2 = v1.contiguous(), v2.contiguous()
604
+
605
+ attn11, attn12 = self._compute_attention(q1, k1, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window), \
606
+ self._compute_attention(q1, k1, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window)
607
+ attn21, attn22 = self._compute_attention(q2, k2, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window), \
608
+ self._compute_attention(q2, k2, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window)
609
+
610
+ attn1, attn2 = torch.cat([attn11, attn12], dim=-1), torch.cat([attn21, attn22], dim=-1)
611
+
612
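+ # Combine the two attention maps following the differential-attention formulation:
+ # lambda_full = exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2) + lambda_init,
+ # attn_output = attn1 - lambda_full * attn2, followed by the `subln` normalization
+ # and the (1 - lambda_init) rescaling.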
+ lambda_q1 = self.lambda_q1.unsqueeze(0).expand([bsz, self.lambda_q1.shape[0]]) # bsz, num_head
613
+ lambda_q2 = self.lambda_q2.unsqueeze(0).expand([bsz, self.lambda_q2.shape[0]]) # bsz, num_head
614
+
615
+ lambda_1 = torch.exp(torch.sum(lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(attn1) # bsz
616
+ lambda_2 = torch.exp(torch.sum(lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(attn2) # bsz
617
+
618
+ lambda_full = lambda_1 - lambda_2 + self.lambda_init
619
+
620
+ attn_output = attn1 - lambda_full.view([bsz, 1, 1, 1]) * attn2
621
+
622
+ attn_output = self.subln(attn_output)
623
+ attn_output = attn_output * (1 - self.lambda_init)
624
+
625
+ if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim * 2):
626
+ raise ValueError(f"`attn_output` should be of size {(bsz, q_len, self.num_heads, 2*self.head_dim)}, but is"
627
+ f" {attn_output.size()}")
628
+
629
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
630
+ attn_output = self.o_proj(attn_output)
631
+
632
+ return attn_output, None, past_key_value
633
+
634
+
635
+ class MotifSdpaAttention(MotifAttention):
636
+ """
637
+ Motif attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
638
+ `MotifAttention`, as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
639
+ SDPA API.
640
+ """
641
+
642
+ def forward(
643
+ self,
644
+ hidden_states: torch.Tensor,
645
+ attention_mask: Optional[torch.Tensor] = None,
646
+ position_ids: Optional[torch.LongTensor] = None,
647
+ past_key_value: Optional[Cache] = None,
648
+ output_attentions: bool = False,
649
+ use_cache: bool = False,
650
+ cache_position: Optional[torch.LongTensor] = None,
651
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
652
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
653
+ if output_attentions:
654
+ logger.warning_once(
655
+ "MotifModel is using MotifSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
656
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
657
+ )
658
+ return super().forward(
659
+ hidden_states=hidden_states,
660
+ attention_mask=attention_mask,
661
+ position_ids=position_ids,
662
+ past_key_value=past_key_value,
663
+ output_attentions=output_attentions,
664
+ use_cache=use_cache,
665
+ )
666
+
667
+ bsz, q_len, _ = hidden_states.size()
668
+
669
+ query_states = self.q_proj(hidden_states)
670
+ key_states = self.k_proj(hidden_states)
671
+ value_states = self.v_proj(hidden_states)
672
+
673
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
674
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
675
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
676
+ kv_seq_len = key_states.shape[-2]
677
+ if position_embeddings is None:
678
+ logger.warning_once(
679
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
680
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
681
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
682
+ "removed and `position_embeddings` will be mandatory.")
683
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
684
+ else:
685
+ cos, sin = position_embeddings
686
+ query_states, key_states = apply_rotary_pos_emb(query_states,
687
+ key_states,
688
+ cos,
689
+ sin)
690
+
691
+ if past_key_value is not None:
692
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
693
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
694
+
695
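+ # Grouped-query attention: keys/values carry `num_key_value_heads` heads, so their
+ # flattened width is `hidden_size // num_key_value_groups`; the grouping itself is
+ # expected to be handled inside `ScaledDotProductAttention` via the `num_kv_groups`
+ # argument passed below.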
+ query_states = query_states.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
696
+ key_states = key_states.transpose(1, 2).reshape(bsz, q_len, self.hidden_size // self.num_key_value_groups)
697
+ value_states = value_states.transpose(1, 2).reshape(bsz, q_len, self.hidden_size // self.num_key_value_groups)
698
+
699
+ batch, query_length, key_length = query_states.size(0), query_states.size(-2), key_states.size(-2)
700
+ masked_bias = attention_mask.expand(batch, self.num_heads, query_length, key_length)
701
+
702
+ # Standard attention scaling: 1 / sqrt(head_dim)
703
+ scale_factor = 1.0
704
+ scale_factor /= float(self.head_dim) ** 0.5
705
+
706
+ attn_output = ScaledDotProductAttention(query_states,
707
+ key_states,
708
+ value_states,
709
+ masked_bias,
710
+ dropout_rate=0.0,
711
+ training=self.training,
712
+ attn_weight_scale_factor=scale_factor,
713
+ num_kv_groups=self.num_key_value_groups,
714
+ recompute_mode=False)
715
+ attn_output = attn_output.to(hidden_states.dtype)
716
+
717
+ attn_output = self.o_proj(attn_output)
718
+
719
+ return attn_output, None, past_key_value
720
+
721
+
722
+ MOTIF_ATTENTION_CLASSES = {
723
+ "eager": MotifAttention,
724
+ "flash_attention_2": MotifFlashAttention2,
725
+ "sdpa": MotifAttention,
726
+ }
727
+
728
+
729
+ class MotifDecoderLayer(nn.Module):
730
+
731
+ def __init__(self, config: MotifConfig, layer_idx: int):
732
+ super().__init__()
733
+ self.hidden_size = config.hidden_size
734
+ if config.sliding_window and config._attn_implementation != "flash_attention_2":
735
+ logger.warning_once(
736
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
737
+ "unexpected results may be encountered.")
738
+ self.self_attn = MOTIF_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
739
+ self.mlp = MotifMLP(config)
740
+
741
+ self.input_layernorm = MotifRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
742
+ self.post_attention_layernorm = MotifRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
743
+
744
+
745
+ def forward(
746
+ self,
747
+ hidden_states: torch.Tensor,
748
+ attention_mask: Optional[torch.Tensor] = None,
749
+ position_ids: Optional[torch.LongTensor] = None,
750
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
751
+ output_attentions: Optional[bool] = False,
752
+ use_cache: Optional[bool] = False,
753
+ cache_position: Optional[torch.LongTensor] = None,
754
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
755
+ **kwargs,
756
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
757
+ """
758
+ Args:
759
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
760
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
761
+ `(batch, sequence_length)` where padding elements are indicated by 0.
762
+ output_attentions (`bool`, *optional*):
763
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
764
+ returned tensors for more detail.
765
+ use_cache (`bool`, *optional*):
766
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
767
+ (see `past_key_values`).
768
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
769
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
770
+ Indices depicting the position of the input sequence tokens in the sequence.
771
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
772
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
773
+ with `head_dim` being the embedding dimension of each attention head.
774
+ kwargs (`dict`, *optional*):
775
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
776
+ into the model
777
+ """
778
+
779
+ residual = hidden_states
780
+
781
+ hidden_states = self.input_layernorm(hidden_states)
782
+
783
+ # Self Attention
784
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
785
+ hidden_states=hidden_states,
786
+ attention_mask=attention_mask,
787
+ position_ids=position_ids,
788
+ past_key_value=past_key_value,
789
+ output_attentions=output_attentions,
790
+ use_cache=use_cache,
791
+ cache_position=cache_position,
792
+ position_embeddings=position_embeddings,
793
+ )
794
+ hidden_states = residual + hidden_states
795
+
796
+ # Fully Connected
797
+ residual = hidden_states
798
+ hidden_states = self.post_attention_layernorm(hidden_states)
799
+ hidden_states = self.mlp(hidden_states)
800
+ hidden_states = residual + hidden_states
801
+
802
+ outputs = (hidden_states, )
803
+
804
+ if output_attentions:
805
+ outputs += (self_attn_weights, )
806
+
807
+ if use_cache:
808
+ outputs += (present_key_value, )
809
+
810
+ return outputs
811
+
812
+
813
+ MOTIF_START_DOCSTRING = r"""
814
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
815
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
816
+ etc.)
817
+
818
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
819
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
820
+ and behavior.
821
+
822
+ Parameters:
823
+ config ([`MotifConfig`]):
824
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
825
+ load the weights associated with the model, only the configuration. Check out the
826
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
827
+ """
828
+
829
+
830
+ @add_start_docstrings(
831
+ "The bare Motif Model outputting raw hidden-states without any specific head on top.",
832
+ MOTIF_START_DOCSTRING,
833
+ )
834
+ class MotifPreTrainedModel(PreTrainedModel):
835
+ config_class = MotifConfig
836
+ base_model_prefix = "model"
837
+ supports_gradient_checkpointing = True
838
+ _no_split_modules = ["MotifDecoderLayer"]
839
+ _skip_keys_device_placement = "past_key_values"
840
+ _supports_flash_attn_2 = True
841
+ _supports_sdpa = True
842
+ _supports_cache_class = True
843
+ _supports_quantized_cache = True
844
+ _supports_static_cache = True
845
+
846
+ def _init_weights(self, module):
847
+ module_std = self.config.initializer_range
848
+ if isinstance(module, nn.Linear):
849
+ module.weight.data.normal_(mean=0.0, std=module_std)
850
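+ # Truncated-normal-style initialization: entries drawn beyond 3 standard deviations
+ # are zeroed out (rather than re-sampled); the same rule is applied to Embedding weights below.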
+ module.weight.data = torch.where(abs(module.weight.data) > module_std*3, 0, module.weight.data)
851
+ if module.bias is not None:
852
+ module.bias.data.zero_()
853
+
854
+ elif isinstance(module, nn.Embedding):
855
+ module.weight.data.normal_(mean=0.0, std=module_std)
856
+ module.weight.data = torch.where(abs(module.weight.data) > module_std*3, 0, module.weight.data)
857
+ if module.padding_idx is not None:
858
+ module.weight.data[module.padding_idx].zero_()
859
+
860
+
861
+ @dataclass
862
+ class MotifModelOutputWithPast(ModelOutput):
863
+ """
864
+ This augments `BaseModelOutputWithPast` in `transformers.modeling_outputs` with new optional keys: `causal_mask`, `position_embeddings`.
865
+ The optional keys are currently used in the following ways:
866
+ - passing information to the token-wise last attention layers during multi-token training
867
+ """
868
+ last_hidden_state: torch.FloatTensor = None
869
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
870
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
871
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
872
+ causal_mask: Optional[torch.Tensor] = None
873
+ position_embeddings: Optional[torch.FloatTensor] = None
874
+
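+ # Illustrative usage sketch for the extra keys on `MotifModelOutputWithPast` above
+ # (`motif_model` is a placeholder for a `MotifModel` instance); they are populated only
+ # when explicitly requested from `MotifModel.forward`:
+ #     out = motif_model(input_ids,
+ #                       outputs_include_causal_mask=True,
+ #                       outputs_include_position_embeddings=True)
+ #     mask, rope_cos_sin = out.causal_mask, out.position_embeddings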
875
+
876
+ MOTIF_INPUTS_DOCSTRING = r"""
877
+ Args:
878
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
879
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
880
+ it.
881
+
882
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
883
+ [`PreTrainedTokenizer.__call__`] for details.
884
+
885
+ [What are input IDs?](../glossary#input-ids)
886
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
887
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
888
+
889
+ - 1 for tokens that are **not masked**,
890
+ - 0 for tokens that are **masked**.
891
+
892
+ [What are attention masks?](../glossary#attention-mask)
893
+
894
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
895
+ [`PreTrainedTokenizer.__call__`] for details.
896
+
897
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
898
+ `past_key_values`).
899
+
900
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
901
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
902
+ information on the default strategy.
903
+
904
+ - 1 indicates the head is **not masked**,
905
+ - 0 indicates the head is **masked**.
906
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
907
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
908
+ config.n_positions - 1]`.
909
+
910
+ [What are position IDs?](../glossary#position-ids)
911
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
912
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
913
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
914
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
915
+
916
+ Two formats are allowed:
917
+ - a [`~cache_utils.Cache`] instance, see our
918
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
919
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
920
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
921
+ cache format.
922
+
923
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
924
+ legacy cache format will be returned.
925
+
926
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
927
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
928
+ of shape `(batch_size, sequence_length)`.
929
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
930
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
931
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
932
+ model's internal embedding lookup matrix.
933
+ use_cache (`bool`, *optional*):
934
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
935
+ `past_key_values`).
936
+ output_attentions (`bool`, *optional*):
937
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
938
+ tensors for more detail.
939
+ output_hidden_states (`bool`, *optional*):
940
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
941
+ more detail.
942
+ return_dict (`bool`, *optional*):
943
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
944
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
945
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
946
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
947
+ the complete sequence length.
948
+ """
949
+
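+ # Illustrative sketch of the two `past_key_values` formats described above, assuming a recent
+ # `transformers` release that exposes `DynamicCache` at the top level (`model`, `input_ids` and
+ # `next_ids` are placeholder names):
+ #     from transformers import DynamicCache
+ #     cache = DynamicCache()                                            # Cache-class format
+ #     out = model(input_ids, use_cache=True, past_key_values=cache)
+ #     out = model(next_ids, use_cache=True, past_key_values=out.past_key_values)
+ # A legacy tuple-of-tuples is still accepted, but it is converted to a `DynamicCache` internally
+ # and triggers a deprecation warning (see `MotifModel.forward`).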
950
+
951
+ @add_start_docstrings(
952
+ "The bare Motif Model outputting raw hidden-states without any specific head on top.",
953
+ MOTIF_START_DOCSTRING,
954
+ )
955
+ class MotifModel(MotifPreTrainedModel):
956
+ """
957
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MotifDecoderLayer`]
958
+
959
+ Args:
960
+ config: MotifConfig
961
+ """
962
+
963
+ def __init__(self, config: MotifConfig):
964
+ super().__init__(config)
965
+ self.padding_idx = config.pad_token_id
966
+ self.vocab_size = config.vocab_size
967
+
968
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
969
+ num_hidden_layers = config.num_hidden_layers
970
+ self.layers = nn.ModuleList([MotifDecoderLayer(config=config, layer_idx=layer_idx) for layer_idx in range(num_hidden_layers)])
971
+ self.norm = MotifRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
972
+ self.hidden_size = config.hidden_size
973
+ self.num_heads = config.num_attention_heads
974
+ self.head_dim = self.hidden_size // self.num_heads
975
+ self.max_position_embeddings = config.max_position_embeddings
976
+ self.rope_theta = config.rope_theta
977
+ self.rotary_emb = MotifRotaryEmbeddingWithCache(self.head_dim,
978
+ max_position_embeddings=self.max_position_embeddings,
979
+ base=self.rope_theta)
980
+
981
+ self.gradient_checkpointing = False
982
+ self.post_init()
983
+
984
+ def get_input_embeddings(self):
985
+ return self.embed_tokens
986
+
987
+ def set_input_embeddings(self, value):
988
+ self.embed_tokens = value
989
+
990
+ @add_start_docstrings_to_model_forward(MOTIF_INPUTS_DOCSTRING)
991
+ def forward(
992
+ self,
993
+ input_ids: torch.LongTensor = None,
994
+ attention_mask: Optional[torch.Tensor] = None,
995
+ position_ids: Optional[torch.LongTensor] = None,
996
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
997
+ inputs_embeds: Optional[torch.FloatTensor] = None,
998
+ use_cache: Optional[bool] = None,
999
+ output_attentions: Optional[bool] = None,
1000
+ output_hidden_states: Optional[bool] = None,
1001
+ return_dict: Optional[bool] = None,
1002
+ cache_position: Optional[torch.LongTensor] = None,
1003
+ outputs_include_causal_mask: bool = False,
1004
+ outputs_include_position_embeddings: bool = False,
1005
+ ) -> Union[Tuple, MotifModelOutputWithPast]:
1006
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1007
+ output_hidden_states = (output_hidden_states
1008
+ if output_hidden_states is not None else self.config.output_hidden_states)
1009
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1010
+
1011
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1012
+
1013
+ if (input_ids is None) ^ (inputs_embeds is not None):
1014
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1015
+
1016
+ if self.gradient_checkpointing and self.training:
1017
+ if use_cache:
1018
+ logger.warning_once(
1019
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
1020
+ use_cache = False
1021
+
1022
+ return_legacy_cache = False
1023
+ if use_cache and not isinstance(past_key_values, Cache):
1024
+ return_legacy_cache = True
1025
+ if past_key_values is None:
1026
+ past_key_values = DynamicCache()
1027
+ else:
1028
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1029
+ logger.warning_once(
1030
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
1031
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
1032
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)")
1033
+
1034
+ if inputs_embeds is None:
1035
+ inputs_embeds = self.embed_tokens(input_ids)
1036
+
1037
+ if cache_position is None:
1038
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1039
+ cache_position = torch.arange(past_seen_tokens,
1040
+ past_seen_tokens + inputs_embeds.shape[1],
1041
+ device=inputs_embeds.device)
1042
+ if position_ids is None:
1043
+ position_ids = cache_position.unsqueeze(0)
1044
+
1045
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values,
1046
+ output_attentions)
1047
+
1048
+ hidden_states = inputs_embeds
1049
+ bsz, q_len, _ = hidden_states.size()
1050
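+ # Rotary cos/sin are computed once per forward pass and shared by every decoder layer
+ # through the `position_embeddings` argument, instead of being recomputed in each attention module.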
+ position_embeddings = self.rotary_emb(hidden_states, seq_len=q_len)
1051
+
1052
+ all_hidden_states = () if output_hidden_states else None
1053
+ all_self_attns = () if output_attentions else None
1054
+ next_decoder_cache = None
1055
+
1056
+ for idx, decoder_layer in enumerate(self.layers):
1057
+ if output_hidden_states:
1058
+ all_hidden_states += (hidden_states, )
1059
+
1060
+ if self.gradient_checkpointing and self.training:
1061
+ layer_outputs = self._gradient_checkpointing_func(
1062
+ decoder_layer.__call__,
1063
+ hidden_states,
1064
+ causal_mask,
1065
+ position_ids,
1066
+ past_key_values,
1067
+ output_attentions,
1068
+ use_cache,
1069
+ cache_position,
1070
+ position_embeddings,
1071
+ )
1072
+ else:
1073
+ layer_outputs = decoder_layer(
1074
+ hidden_states,
1075
+ attention_mask=causal_mask,
1076
+ position_ids=position_ids,
1077
+ past_key_value=past_key_values,
1078
+ output_attentions=output_attentions,
1079
+ use_cache=use_cache,
1080
+ cache_position=cache_position,
1081
+ position_embeddings=position_embeddings,
1082
+ )
1083
+
1084
+ hidden_states = layer_outputs[0]
1085
+
1086
+ if use_cache:
1087
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1088
+
1089
+ if output_attentions:
1090
+ all_self_attns += (layer_outputs[1], )
1091
+
1092
+ hidden_states = self.norm(hidden_states)
1093
+
1094
+ if output_hidden_states:
1095
+ all_hidden_states += (hidden_states, )
1096
+
1097
+ next_cache = next_decoder_cache if use_cache else None
1098
+ if return_legacy_cache:
1099
+ next_cache = next_cache.to_legacy_cache()
1100
+
1101
+ causal_mask_output = causal_mask if outputs_include_causal_mask else None
1102
+ position_embeddings_output = position_embeddings if outputs_include_position_embeddings else None
1103
+ if not return_dict:
1104
+ return tuple(v for v in [
1105
+ hidden_states, next_cache, all_hidden_states, all_self_attns, causal_mask_output,
1106
+ position_embeddings_output
1107
+ ] if v is not None)
1108
+ return MotifModelOutputWithPast(last_hidden_state=hidden_states,
1109
+ past_key_values=next_cache,
1110
+ hidden_states=all_hidden_states,
1111
+ attentions=all_self_attns,
1112
+ causal_mask=causal_mask_output,
1113
+ position_embeddings=position_embeddings_output)
1114
+
1115
+ def _update_causal_mask(
1116
+ self,
1117
+ attention_mask: torch.Tensor,
1118
+ input_tensor: torch.Tensor,
1119
+ cache_position: torch.Tensor,
1120
+ past_key_values: Cache,
1121
+ output_attentions: bool,
1122
+ ):
1123
+ if self.config._attn_implementation == "flash_attention_2":
1124
+ if attention_mask is not None and 0.0 in attention_mask:
1125
+ return attention_mask
1126
+ return None
1127
+
1128
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
1129
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
1130
+ # to infer the attention mask.
1131
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1132
+ using_static_cache = isinstance(past_key_values, StaticCache)
1133
+ using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
1134
+
1135
+ # When `output_attentions` is True, the SDPA implementation's forward method calls the eager implementation's forward
1136
+ if (self.config._attn_implementation == "sdpa" and not (using_static_cache or using_sliding_window_cache)
1137
+ and not output_attentions):
1138
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
1139
+ attention_mask,
1140
+ inputs_embeds=input_tensor,
1141
+ past_key_values_length=past_seen_tokens,
1142
+ sliding_window=self.config.sliding_window,
1143
+ is_training=self.training,
1144
+ ):
1145
+ return None
1146
+
1147
+ dtype, device = input_tensor.dtype, input_tensor.device
1148
+ min_dtype = torch.finfo(dtype).min
1149
+ sequence_length = input_tensor.shape[1]
1150
+
1151
+ # SlidingWindowCache or StaticCache
1152
+ if using_sliding_window_cache or using_static_cache:
1153
+ target_length = past_key_values.get_max_cache_shape()
1154
+ # DynamicCache or no cache
1155
+ else:
1156
+ target_length = (attention_mask.shape[-1]
1157
+ if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1)
1158
+
1159
+ # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
1160
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
1161
+ attention_mask,
1162
+ sequence_length=sequence_length,
1163
+ target_length=target_length,
1164
+ dtype=dtype,
1165
+ device=device,
1166
+ cache_position=cache_position,
1167
+ batch_size=input_tensor.shape[0],
1168
+ config=self.config,
1169
+ past_key_values=past_key_values,
1170
+ )
1171
+
1172
+ if (self.config._attn_implementation == "sdpa" and attention_mask is not None
1173
+ and attention_mask.device.type == "cuda" and not output_attentions):
1174
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
1175
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1176
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1177
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
1178
+
1179
+ return causal_mask
1180
+
1181
+ @staticmethod
1182
+ def _prepare_4d_causal_attention_mask_with_cache_position(
1183
+ attention_mask: torch.Tensor,
1184
+ sequence_length: int,
1185
+ target_length: int,
1186
+ dtype: torch.dtype,
1187
+ device: torch.device,
1188
+ cache_position: torch.Tensor,
1189
+ batch_size: int,
1190
+ config: MotifConfig,
1191
+ past_key_values: Cache,
1192
+ ):
1193
+ """
1194
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1195
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, leaves it unchanged.
1196
+
1197
+ Args:
1198
+ attention_mask (`torch.Tensor`):
1199
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
1200
+ sequence_length (`int`):
1201
+ The sequence length being processed.
1202
+ target_length (`int`):
1203
+ The target length: when generating with a static cache, the mask should be as long as the static cache to account for the zero padding (the part of the cache that is not filled yet).
1204
+ dtype (`torch.dtype`):
1205
+ The dtype to use for the 4D attention mask.
1206
+ device (`torch.device`):
1207
+ The device to place the 4D attention mask on.
1208
+ cache_position (`torch.Tensor`):
1209
+ Indices depicting the position of the input sequence tokens in the sequence.
1210
+ batch_size (`int`):
1211
+ Batch size.
1212
+ config (`MotifConfig`):
1213
+ The model's configuration class
1214
+ past_key_values (`Cache`):
1215
+ The cache class that is currently being used to generate
1216
+ """
1217
+ if attention_mask is not None and attention_mask.dim() == 4:
1218
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
1219
+ causal_mask = attention_mask
1220
+ else:
1221
+ min_dtype = torch.finfo(dtype).min
1222
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
1223
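+ # Positions strictly after `cache_position` (future tokens) keep the `min_dtype` fill and stay
+ # masked; with a sliding window, positions more than `sliding_window` tokens behind the query
+ # are masked out as well (handled below).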
+ diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
1224
+ if config.sliding_window is not None:
1225
+ # if we have a sliding window, we should not attend to tokens beyond the sliding window length, so we also mask them out
1226
+ # the check is needed to verify whether the current checkpoint was trained with a sliding window or not
1227
+ if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
1228
+ sliding_attend_mask = torch.arange(
1229
+ target_length, device=device) <= (cache_position.reshape(-1, 1) - config.sliding_window)
1230
+ diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
1231
+ causal_mask *= diagonal_attend_mask
1232
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
1233
+ if attention_mask is not None:
1234
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
1235
+ if attention_mask.shape[-1] > target_length:
1236
+ attention_mask = attention_mask[:, :target_length]
1237
+ mask_length = attention_mask.shape[-1]
1238
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
1239
+ padding_mask = padding_mask == 0
1240
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
1241
+ padding_mask, min_dtype)
1242
+ return causal_mask
1243
+
1244
+
1245
+ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
1246
+ _tied_weights_keys = ["lm_head.weight"]
1247
+
1248
+ def __init__(self, config: MotifConfig):
1249
+ super().__init__(config)
1250
+ self.model = MotifModel(config)
1251
+ self.vocab_size = config.vocab_size
1252
+
1253
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1254
+
1255
+ # Initialize weights and apply final processing
1256
+ self.post_init()
1257
+
1258
+ if getattr(config, "tie_word_embeddings", True):
1259
+ self.tie_weights()
1260
+
1261
+ def get_input_embeddings(self):
1262
+ return self.model.embed_tokens
1263
+
1264
+ def set_input_embeddings(self, value):
1265
+ self.model.embed_tokens = value
1266
+
1267
+ def get_output_embeddings(self):
1268
+ return self.lm_head
1269
+
1270
+ def set_output_embeddings(self, new_embeddings):
1271
+ self.lm_head = new_embeddings
1272
+
1273
+ def set_decoder(self, decoder):
1274
+ self.model = decoder
1275
+
1276
+ def get_decoder(self):
1277
+ return self.model
1278
+
1279
+
1280
+ @add_start_docstrings_to_model_forward(MOTIF_INPUTS_DOCSTRING)
1281
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1282
+ def forward(
1283
+ self,
1284
+ input_ids: torch.LongTensor = None,
1285
+ attention_mask: Optional[torch.Tensor] = None,
1286
+ position_ids: Optional[torch.LongTensor] = None,
1287
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1288
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1289
+ labels: Optional[torch.LongTensor] = None,
1290
+ use_cache: Optional[bool] = None,
1291
+ output_attentions: Optional[bool] = None,
1292
+ output_hidden_states: Optional[bool] = None,
1293
+ return_dict: Optional[bool] = None,
1294
+ cache_position: Optional[torch.LongTensor] = None,
1295
+ num_logits_to_keep: int = 0,
1296
+ **loss_kwargs,
1297
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1298
+ r"""
1299
+ Args:
1300
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1301
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1302
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1303
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1304
+
1305
+ num_logits_to_keep (`int`, *optional*):
1306
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
1307
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
1308
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
1309
+
1310
+ Returns:
1311
+
1312
+ Example:
1313
+
1314
+ ```python
1315
+ >>> from transformers import AutoTokenizer, MotifForCausalLM
1316
+
1317
+ >>> model = MotifForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1318
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1319
+
1320
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1321
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1322
+
1323
+ >>> # Generate
1324
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1325
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1326
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1327
+ ```"""
1328
+
1329
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1330
+ output_hidden_states = (output_hidden_states
1331
+ if output_hidden_states is not None else self.config.output_hidden_states)
1332
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1333
+
1334
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1335
+ outputs: MotifModelOutputWithPast = self.model(
1336
+ input_ids=input_ids,
1337
+ attention_mask=attention_mask,
1338
+ position_ids=position_ids,
1339
+ past_key_values=past_key_values,
1340
+ inputs_embeds=inputs_embeds,
1341
+ use_cache=use_cache,
1342
+ output_attentions=output_attentions,
1343
+ output_hidden_states=output_hidden_states,
1344
+ return_dict=return_dict,
1345
+ cache_position=cache_position,
1346
+ )
1347
+
1348
+ hidden_states = outputs[0]
1349
+
1350
+ # Only compute the necessary logits (the last `num_logits_to_keep` positions)
1351
+ hidden_states = hidden_states
1352
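+ # Note: `hidden_states[:, -0:, :]` selects every position, so the default
+ # `num_logits_to_keep=0` computes logits for the full sequence.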
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
1353
+ logits = logits.float()
1354
+
1355
+ loss = None
1356
+ if labels is not None:
1357
+ logits = logits
1358
+ # Shift so that tokens < n predict n
1359
+ shift_logits = logits[..., :-1, :].contiguous()
1360
+ shift_labels = labels[..., 1:].contiguous()
1361
+ # Flatten the tokens
1362
+ loss_fct = CrossEntropyLoss()
1363
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1364
+ shift_labels = shift_labels.view(-1)
1365
+ shift_labels = shift_labels.to(shift_logits.device)
1366
+ loss = loss_fct(shift_logits, shift_labels)
1367
+
1368
+ if not return_dict:
1369
+ output = (logits, ) + outputs[1:]
1370
+ return (loss, ) + output if loss is not None else output
1371
+
1372
+ return CausalLMOutputWithPast(
1373
+ loss=loss,
1374
+ logits=logits,
1375
+ past_key_values=outputs.past_key_values,
1376
+ hidden_states=outputs.hidden_states,
1377
+ attentions=outputs.attentions,
1378
+ )
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<|beginoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<|endoftext|>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:275139d476909028da05d4ee035aba88f0ca0dbfd0d395f72b7fc80fd7782e19
3
+ size 17264873
tokenizer_config.json ADDED
@@ -0,0 +1,1027 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "219395": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "219396": {
13
+ "content": "<|beginoftext|>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "219397": {
21
+ "content": "<|fim_prefix|>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "219398": {
29
+ "content": "<|fim_middle|>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ },
36
+ "219399": {
37
+ "content": "<|fim_suffix|>",
38
+ "lstrip": false,
39
+ "normalized": false,
40
+ "rstrip": false,
41
+ "single_word": false,
42
+ "special": true
43
+ },
44
+ "219400": {
45
+ "content": "<|system|>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false,
50
+ "special": true
51
+ },
52
+ "219401": {
53
+ "content": "<|user|>",
54
+ "lstrip": false,
55
+ "normalized": false,
56
+ "rstrip": false,
57
+ "single_word": false,
58
+ "special": true
59
+ },
60
+ "219402": {
61
+ "content": "<|assistant|>",
62
+ "lstrip": false,
63
+ "normalized": false,
64
+ "rstrip": false,
65
+ "single_word": false,
66
+ "special": true
67
+ },
68
+ "219403": {
69
+ "content": "<|startofturn|>",
70
+ "lstrip": false,
71
+ "normalized": false,
72
+ "rstrip": false,
73
+ "single_word": false,
74
+ "special": true
75
+ },
76
+ "219404": {
77
+ "content": "<think>",
78
+ "lstrip": false,
79
+ "normalized": false,
80
+ "rstrip": false,
81
+ "single_word": false,
82
+ "special": true
83
+ },
84
+ "219405": {
85
+ "content": "<|endofturn|>",
86
+ "lstrip": false,
87
+ "normalized": false,
88
+ "rstrip": false,
89
+ "single_word": false,
90
+ "special": true
91
+ },
92
+ "219406": {
93
+ "content": "</think>",
94
+ "lstrip": false,
95
+ "normalized": false,
96
+ "rstrip": false,
97
+ "single_word": false,
98
+ "special": true
99
+ },
100
+ "219407": {
101
+ "content": "<|dummy_id_3|>",
102
+ "lstrip": false,
103
+ "normalized": false,
104
+ "rstrip": false,
105
+ "single_word": false,
106
+ "special": true
107
+ },
108
+ "219408": {
109
+ "content": "<|dummy_id_4|>",
110
+ "lstrip": false,
111
+ "normalized": false,
112
+ "rstrip": false,
113
+ "single_word": false,
114
+ "special": true
115
+ },
116
+ "219409": {
117
+ "content": "<|dummy_id_5|>",
118
+ "lstrip": false,
119
+ "normalized": false,
120
+ "rstrip": false,
121
+ "single_word": false,
122
+ "special": true
123
+ },
124
+ "219410": {
125
+ "content": "<|dummy_id_6|>",
126
+ "lstrip": false,
127
+ "normalized": false,
128
+ "rstrip": false,
129
+ "single_word": false,
130
+ "special": true
131
+ },
132
+ "219411": {
133
+ "content": "<|dummy_id_7|>",
134
+ "lstrip": false,
135
+ "normalized": false,
136
+ "rstrip": false,
137
+ "single_word": false,
138
+ "special": true
139
+ },
140
+ "219412": {
141
+ "content": "<|dummy_id_8|>",
142
+ "lstrip": false,
143
+ "normalized": false,
144
+ "rstrip": false,
145
+ "single_word": false,
146
+ "special": true
147
+ },
148
+ "219413": {
149
+ "content": "<|dummy_id_9|>",
150
+ "lstrip": false,
151
+ "normalized": false,
152
+ "rstrip": false,
153
+ "single_word": false,
154
+ "special": true
155
+ },
156
+ "219414": {
157
+ "content": "<|dummy_id_10|>",
158
+ "lstrip": false,
159
+ "normalized": false,
160
+ "rstrip": false,
161
+ "single_word": false,
162
+ "special": true
163
+ },
164
+ "219415": {
165
+ "content": "<|dummy_id_11|>",
166
+ "lstrip": false,
167
+ "normalized": false,
168
+ "rstrip": false,
169
+ "single_word": false,
170
+ "special": true
171
+ },
172
+ "219416": {
173
+ "content": "<|endofprompt|>",
174
+ "lstrip": false,
175
+ "normalized": false,
176
+ "rstrip": false,
177
+ "single_word": false,
178
+ "special": true
179
+ },
180
+ "219417": {
181
+ "content": "<|dummy_id_12|>",
182
+ "lstrip": false,
183
+ "normalized": false,
184
+ "rstrip": false,
185
+ "single_word": false,
186
+ "special": true
187
+ },
188
+ "219418": {
189
+ "content": "<|dummy_id_13|>",
190
+ "lstrip": false,
191
+ "normalized": false,
192
+ "rstrip": false,
193
+ "single_word": false,
194
+ "special": true
195
+ },
196
+ "219419": {
197
+ "content": "<|dummy_id_14|>",
198
+ "lstrip": false,
199
+ "normalized": false,
200
+ "rstrip": false,
201
+ "single_word": false,
202
+ "special": true
203
+ },
204
+ "219420": {
205
+ "content": "<|dummy_id_15|>",
206
+ "lstrip": false,
207
+ "normalized": false,
208
+ "rstrip": false,
209
+ "single_word": false,
210
+ "special": true
211
+ },
212
+ "219421": {
213
+ "content": "<|dummy_id_16|>",
214
+ "lstrip": false,
215
+ "normalized": false,
216
+ "rstrip": false,
217
+ "single_word": false,
218
+ "special": true
219
+ },
220
+ "219422": {
221
+ "content": "<|dummy_id_17|>",
222
+ "lstrip": false,
223
+ "normalized": false,
224
+ "rstrip": false,
225
+ "single_word": false,
226
+ "special": true
227
+ },
228
+ "219423": {
229
+ "content": "<|dummy_id_18|>",
230
+ "lstrip": false,
231
+ "normalized": false,
232
+ "rstrip": false,
233
+ "single_word": false,
234
+ "special": true
235
+ },
236
+ "219424": {
237
+ "content": "<|dummy_id_19|>",
238
+ "lstrip": false,
239
+ "normalized": false,
240
+ "rstrip": false,
241
+ "single_word": false,
242
+ "special": true
243
+ },
244
+ "219425": {
245
+ "content": "<|dummy_id_20|>",
246
+ "lstrip": false,
247
+ "normalized": false,
248
+ "rstrip": false,
249
+ "single_word": false,
250
+ "special": true
251
+ },
252
+ "219426": {
253
+ "content": "<|dummy_id_21|>",
254
+ "lstrip": false,
255
+ "normalized": false,
256
+ "rstrip": false,
257
+ "single_word": false,
258
+ "special": true
259
+ },
260
+ "219427": {
261
+ "content": "<|dummy_id_22|>",
262
+ "lstrip": false,
263
+ "normalized": false,
264
+ "rstrip": false,
265
+ "single_word": false,
266
+ "special": true
267
+ },
268
+ "219428": {
269
+ "content": "<|dummy_id_23|>",
270
+ "lstrip": false,
271
+ "normalized": false,
272
+ "rstrip": false,
273
+ "single_word": false,
274
+ "special": true
275
+ },
276
+ "219429": {
277
+ "content": "<|dummy_id_24|>",
278
+ "lstrip": false,
279
+ "normalized": false,
280
+ "rstrip": false,
281
+ "single_word": false,
282
+ "special": true
283
+ },
284
+ "219430": {
285
+ "content": "<|dummy_id_25|>",
286
+ "lstrip": false,
287
+ "normalized": false,
288
+ "rstrip": false,
289
+ "single_word": false,
290
+ "special": true
291
+ },
292
+ "219431": {
293
+ "content": "<|dummy_id_26|>",
294
+ "lstrip": false,
295
+ "normalized": false,
296
+ "rstrip": false,
297
+ "single_word": false,
298
+ "special": true
299
+ },
300
+ "219432": {
301
+ "content": "<|dummy_id_27|>",
302
+ "lstrip": false,
303
+ "normalized": false,
304
+ "rstrip": false,
305
+ "single_word": false,
306
+ "special": true
307
+ },
308
+ "219433": {
309
+ "content": "<|dummy_id_28|>",
310
+ "lstrip": false,
311
+ "normalized": false,
312
+ "rstrip": false,
313
+ "single_word": false,
314
+ "special": true
315
+ },
316
+ "219434": {
317
+ "content": "<|dummy_id_29|>",
318
+ "lstrip": false,
319
+ "normalized": false,
320
+ "rstrip": false,
321
+ "single_word": false,
322
+ "special": true
323
+ },
324
+ "219435": {
325
+ "content": "<|dummy_id_30|>",
326
+ "lstrip": false,
327
+ "normalized": false,
328
+ "rstrip": false,
329
+ "single_word": false,
330
+ "special": true
331
+ },
332
+ "219436": {
333
+ "content": "<|dummy_id_31|>",
334
+ "lstrip": false,
335
+ "normalized": false,
336
+ "rstrip": false,
337
+ "single_word": false,
338
+ "special": true
339
+ },
340
+ "219437": {
341
+ "content": "<|dummy_id_32|>",
342
+ "lstrip": false,
343
+ "normalized": false,
344
+ "rstrip": false,
345
+ "single_word": false,
346
+ "special": true
347
+ },
348
+ "219438": {
349
+ "content": "<|dummy_id_33|>",
350
+ "lstrip": false,
351
+ "normalized": false,
352
+ "rstrip": false,
353
+ "single_word": false,
354
+ "special": true
355
+ },
356
+ "219439": {
357
+ "content": "<|dummy_id_34|>",
358
+ "lstrip": false,
359
+ "normalized": false,
360
+ "rstrip": false,
361
+ "single_word": false,
362
+ "special": true
363
+ },
364
+ "219440": {
365
+ "content": "<|dummy_id_35|>",
366
+ "lstrip": false,
367
+ "normalized": false,
368
+ "rstrip": false,
369
+ "single_word": false,
370
+ "special": true
371
+ },
372
+ "219441": {
373
+ "content": "<|dummy_id_36|>",
374
+ "lstrip": false,
375
+ "normalized": false,
376
+ "rstrip": false,
377
+ "single_word": false,
378
+ "special": true
379
+ },
380
+ "219442": {
381
+ "content": "<|dummy_id_37|>",
382
+ "lstrip": false,
383
+ "normalized": false,
384
+ "rstrip": false,
385
+ "single_word": false,
386
+ "special": true
387
+ },
388
+ "219443": {
389
+ "content": "<|dummy_id_38|>",
390
+ "lstrip": false,
391
+ "normalized": false,
392
+ "rstrip": false,
393
+ "single_word": false,
394
+ "special": true
395
+ },
396
+ "219444": {
397
+ "content": "<|dummy_id_39|>",
398
+ "lstrip": false,
399
+ "normalized": false,
400
+ "rstrip": false,
401
+ "single_word": false,
402
+ "special": true
403
+ },
404
+ "219445": {
405
+ "content": "<|dummy_id_40|>",
406
+ "lstrip": false,
407
+ "normalized": false,
408
+ "rstrip": false,
409
+ "single_word": false,
410
+ "special": true
411
+ },
412
+ "219446": {
413
+ "content": "<|dummy_id_41|>",
414
+ "lstrip": false,
415
+ "normalized": false,
416
+ "rstrip": false,
417
+ "single_word": false,
418
+ "special": true
419
+ },
420
+ "219447": {
421
+ "content": "<|dummy_id_42|>",
422
+ "lstrip": false,
423
+ "normalized": false,
424
+ "rstrip": false,
425
+ "single_word": false,
426
+ "special": true
427
+ },
428
+ "219448": {
429
+ "content": "<|dummy_id_43|>",
430
+ "lstrip": false,
431
+ "normalized": false,
432
+ "rstrip": false,
433
+ "single_word": false,
434
+ "special": true
435
+ },
436
+ "219449": {
437
+ "content": "<|dummy_id_44|>",
438
+ "lstrip": false,
439
+ "normalized": false,
440
+ "rstrip": false,
441
+ "single_word": false,
442
+ "special": true
443
+ },
444
+ "219450": {
445
+ "content": "<|dummy_id_45|>",
446
+ "lstrip": false,
447
+ "normalized": false,
448
+ "rstrip": false,
449
+ "single_word": false,
450
+ "special": true
451
+ },
452
+ "219451": {
453
+ "content": "<|dummy_id_46|>",
454
+ "lstrip": false,
455
+ "normalized": false,
456
+ "rstrip": false,
457
+ "single_word": false,
458
+ "special": true
459
+ },
460
+ "219452": {
461
+ "content": "<|dummy_id_47|>",
462
+ "lstrip": false,
463
+ "normalized": false,
464
+ "rstrip": false,
465
+ "single_word": false,
466
+ "special": true
467
+ },
468
+ "219453": {
469
+ "content": "<|dummy_id_48|>",
470
+ "lstrip": false,
471
+ "normalized": false,
472
+ "rstrip": false,
473
+ "single_word": false,
474
+ "special": true
475
+ },
476
+ "219454": {
477
+ "content": "<|dummy_id_49|>",
478
+ "lstrip": false,
479
+ "normalized": false,
480
+ "rstrip": false,
481
+ "single_word": false,
482
+ "special": true
483
+ },
484
+ "219455": {
485
+ "content": "<|dummy_id_50|>",
486
+ "lstrip": false,
487
+ "normalized": false,
488
+ "rstrip": false,
489
+ "single_word": false,
490
+ "special": true
491
+ },
492
+ "219456": {
493
+ "content": "<|dummy_id_51|>",
494
+ "lstrip": false,
495
+ "normalized": false,
496
+ "rstrip": false,
497
+ "single_word": false,
498
+ "special": true
499
+ },
500
+ "219457": {
501
+ "content": "<|dummy_id_52|>",
502
+ "lstrip": false,
503
+ "normalized": false,
504
+ "rstrip": false,
505
+ "single_word": false,
506
+ "special": true
507
+ },
508
+ "219458": {
509
+ "content": "<|dummy_id_53|>",
510
+ "lstrip": false,
511
+ "normalized": false,
512
+ "rstrip": false,
513
+ "single_word": false,
514
+ "special": true
515
+ },
516
+ "219459": {
517
+ "content": "<|dummy_id_54|>",
518
+ "lstrip": false,
519
+ "normalized": false,
520
+ "rstrip": false,
521
+ "single_word": false,
522
+ "special": true
523
+ },
524
+ "219460": {
525
+ "content": "<|dummy_id_55|>",
526
+ "lstrip": false,
527
+ "normalized": false,
528
+ "rstrip": false,
529
+ "single_word": false,
530
+ "special": true
531
+ },
532
+ "219461": {
533
+ "content": "<|dummy_id_56|>",
534
+ "lstrip": false,
535
+ "normalized": false,
536
+ "rstrip": false,
537
+ "single_word": false,
538
+ "special": true
539
+ },
540
+ "219462": {
541
+ "content": "<|dummy_id_57|>",
542
+ "lstrip": false,
543
+ "normalized": false,
544
+ "rstrip": false,
545
+ "single_word": false,
546
+ "special": true
547
+ },
548
+ "219463": {
549
+ "content": "<|dummy_id_58|>",
550
+ "lstrip": false,
551
+ "normalized": false,
552
+ "rstrip": false,
553
+ "single_word": false,
554
+ "special": true
555
+ },
556
+ "219464": {
557
+ "content": "<|dummy_id_59|>",
558
+ "lstrip": false,
559
+ "normalized": false,
560
+ "rstrip": false,
561
+ "single_word": false,
562
+ "special": true
563
+ },
564
+ "219465": {
565
+ "content": "<|dummy_id_60|>",
566
+ "lstrip": false,
567
+ "normalized": false,
568
+ "rstrip": false,
569
+ "single_word": false,
570
+ "special": true
571
+ },
572
+ "219466": {
573
+ "content": "<|dummy_id_61|>",
574
+ "lstrip": false,
575
+ "normalized": false,
576
+ "rstrip": false,
577
+ "single_word": false,
578
+ "special": true
579
+ },
580
+ "219467": {
581
+ "content": "<|dummy_id_62|>",
582
+ "lstrip": false,
583
+ "normalized": false,
584
+ "rstrip": false,
585
+ "single_word": false,
586
+ "special": true
587
+ },
588
+ "219468": {
589
+ "content": "<|dummy_id_63|>",
590
+ "lstrip": false,
591
+ "normalized": false,
592
+ "rstrip": false,
593
+ "single_word": false,
594
+ "special": true
595
+ },
596
+ "219469": {
597
+ "content": "<|dummy_id_64|>",
598
+ "lstrip": false,
599
+ "normalized": false,
600
+ "rstrip": false,
601
+ "single_word": false,
602
+ "special": true
603
+ },
604
+ "219470": {
605
+ "content": "<|dummy_id_65|>",
606
+ "lstrip": false,
607
+ "normalized": false,
608
+ "rstrip": false,
609
+ "single_word": false,
610
+ "special": true
611
+ },
612
+ "219471": {
613
+ "content": "<|dummy_id_66|>",
614
+ "lstrip": false,
615
+ "normalized": false,
616
+ "rstrip": false,
617
+ "single_word": false,
618
+ "special": true
619
+ },
620
+ "219472": {
621
+ "content": "<|dummy_id_67|>",
622
+ "lstrip": false,
623
+ "normalized": false,
624
+ "rstrip": false,
625
+ "single_word": false,
626
+ "special": true
627
+ },
628
+ "219473": {
629
+ "content": "<|dummy_id_68|>",
630
+ "lstrip": false,
631
+ "normalized": false,
632
+ "rstrip": false,
633
+ "single_word": false,
634
+ "special": true
635
+ },
636
+ "219474": {
637
+ "content": "<|dummy_id_69|>",
638
+ "lstrip": false,
639
+ "normalized": false,
640
+ "rstrip": false,
641
+ "single_word": false,
642
+ "special": true
643
+ },
644
+ "219475": {
645
+ "content": "<|dummy_id_70|>",
646
+ "lstrip": false,
647
+ "normalized": false,
648
+ "rstrip": false,
649
+ "single_word": false,
650
+ "special": true
651
+ },
652
+ "219476": {
653
+ "content": "<|dummy_id_71|>",
654
+ "lstrip": false,
655
+ "normalized": false,
656
+ "rstrip": false,
657
+ "single_word": false,
658
+ "special": true
659
+ },
660
+ "219477": {
661
+ "content": "<|dummy_id_72|>",
662
+ "lstrip": false,
663
+ "normalized": false,
664
+ "rstrip": false,
665
+ "single_word": false,
666
+ "special": true
667
+ },
668
+ "219478": {
669
+ "content": "<|dummy_id_73|>",
670
+ "lstrip": false,
671
+ "normalized": false,
672
+ "rstrip": false,
673
+ "single_word": false,
674
+ "special": true
675
+ },
676
+ "219479": {
677
+ "content": "<|dummy_id_74|>",
678
+ "lstrip": false,
679
+ "normalized": false,
680
+ "rstrip": false,
681
+ "single_word": false,
682
+ "special": true
683
+ },
684
+ "219480": {
685
+ "content": "<|dummy_id_75|>",
686
+ "lstrip": false,
687
+ "normalized": false,
688
+ "rstrip": false,
689
+ "single_word": false,
690
+ "special": true
691
+ },
692
+ "219481": {
693
+ "content": "<|dummy_id_76|>",
694
+ "lstrip": false,
695
+ "normalized": false,
696
+ "rstrip": false,
697
+ "single_word": false,
698
+ "special": true
699
+ },
700
+ "219482": {
701
+ "content": "<|dummy_id_77|>",
702
+ "lstrip": false,
703
+ "normalized": false,
704
+ "rstrip": false,
705
+ "single_word": false,
706
+ "special": true
707
+ },
708
+ "219483": {
709
+ "content": "<|dummy_id_78|>",
710
+ "lstrip": false,
711
+ "normalized": false,
712
+ "rstrip": false,
713
+ "single_word": false,
714
+ "special": true
715
+ },
716
+ "219484": {
717
+ "content": "<|dummy_id_79|>",
718
+ "lstrip": false,
719
+ "normalized": false,
720
+ "rstrip": false,
721
+ "single_word": false,
722
+ "special": true
723
+ },
724
+ "219485": {
725
+ "content": "<|dummy_id_80|>",
726
+ "lstrip": false,
727
+ "normalized": false,
728
+ "rstrip": false,
729
+ "single_word": false,
730
+ "special": true
731
+ },
732
+ "219486": {
733
+ "content": "<|dummy_id_81|>",
734
+ "lstrip": false,
735
+ "normalized": false,
736
+ "rstrip": false,
737
+ "single_word": false,
738
+ "special": true
739
+ },
740
+ "219487": {
741
+ "content": "<|dummy_id_82|>",
742
+ "lstrip": false,
743
+ "normalized": false,
744
+ "rstrip": false,
745
+ "single_word": false,
746
+ "special": true
747
+ },
748
+ "219488": {
749
+ "content": "<|dummy_id_83|>",
750
+ "lstrip": false,
751
+ "normalized": false,
752
+ "rstrip": false,
753
+ "single_word": false,
754
+ "special": true
755
+ },
756
+ "219489": {
757
+ "content": "<|dummy_id_84|>",
758
+ "lstrip": false,
759
+ "normalized": false,
760
+ "rstrip": false,
761
+ "single_word": false,
762
+ "special": true
763
+ },
764
+ "219490": {
765
+ "content": "<|dummy_id_85|>",
766
+ "lstrip": false,
767
+ "normalized": false,
768
+ "rstrip": false,
769
+ "single_word": false,
770
+ "special": true
771
+ },
772
+ "219491": {
773
+ "content": "<|dummy_id_86|>",
774
+ "lstrip": false,
775
+ "normalized": false,
776
+ "rstrip": false,
777
+ "single_word": false,
778
+ "special": true
779
+ },
780
+ "219492": {
781
+ "content": "<|dummy_id_87|>",
782
+ "lstrip": false,
783
+ "normalized": false,
784
+ "rstrip": false,
785
+ "single_word": false,
786
+ "special": true
787
+ },
788
+ "219493": {
789
+ "content": "<|dummy_id_88|>",
790
+ "lstrip": false,
791
+ "normalized": false,
792
+ "rstrip": false,
793
+ "single_word": false,
794
+ "special": true
795
+ },
796
+ "219494": {
797
+ "content": "<|dummy_id_89|>",
798
+ "lstrip": false,
799
+ "normalized": false,
800
+ "rstrip": false,
801
+ "single_word": false,
802
+ "special": true
803
+ },
804
+ "219495": {
805
+ "content": "<|dummy_id_90|>",
806
+ "lstrip": false,
807
+ "normalized": false,
808
+ "rstrip": false,
809
+ "single_word": false,
810
+ "special": true
811
+ },
812
+ "219496": {
813
+ "content": "<|dummy_id_91|>",
814
+ "lstrip": false,
815
+ "normalized": false,
816
+ "rstrip": false,
817
+ "single_word": false,
818
+ "special": true
819
+ },
820
+ "219497": {
821
+ "content": "<|dummy_id_92|>",
822
+ "lstrip": false,
823
+ "normalized": false,
824
+ "rstrip": false,
825
+ "single_word": false,
826
+ "special": true
827
+ },
828
+ "219498": {
829
+ "content": "<|dummy_id_93|>",
830
+ "lstrip": false,
831
+ "normalized": false,
832
+ "rstrip": false,
833
+ "single_word": false,
834
+ "special": true
835
+ },
836
+ "219499": {
837
+ "content": "<|dummy_id_94|>",
838
+ "lstrip": false,
839
+ "normalized": false,
840
+ "rstrip": false,
841
+ "single_word": false,
842
+ "special": true
843
+ },
844
+ "219500": {
845
+ "content": "<|dummy_id_95|>",
846
+ "lstrip": false,
847
+ "normalized": false,
848
+ "rstrip": false,
849
+ "single_word": false,
850
+ "special": true
851
+ },
852
+ "219501": {
853
+ "content": "<|dummy_id_96|>",
854
+ "lstrip": false,
855
+ "normalized": false,
856
+ "rstrip": false,
857
+ "single_word": false,
858
+ "special": true
859
+ },
860
+ "219502": {
861
+ "content": "<|dummy_id_97|>",
862
+ "lstrip": false,
863
+ "normalized": false,
864
+ "rstrip": false,
865
+ "single_word": false,
866
+ "special": true
867
+ },
868
+ "219503": {
869
+ "content": "<|dummy_id_98|>",
870
+ "lstrip": false,
871
+ "normalized": false,
872
+ "rstrip": false,
873
+ "single_word": false,
874
+ "special": true
875
+ },
876
+ "219504": {
877
+ "content": "<|dummy_id_99|>",
878
+ "lstrip": false,
879
+ "normalized": false,
880
+ "rstrip": false,
881
+ "single_word": false,
882
+ "special": true
883
+ },
884
+ "219505": {
885
+ "content": "<|dummy_id_100|>",
886
+ "lstrip": false,
887
+ "normalized": false,
888
+ "rstrip": false,
889
+ "single_word": false,
890
+ "special": true
891
+ },
892
+ "219506": {
893
+ "content": "<|dummy_id_101|>",
894
+ "lstrip": false,
895
+ "normalized": false,
896
+ "rstrip": false,
897
+ "single_word": false,
898
+ "special": true
899
+ },
900
+ "219507": {
901
+ "content": "<|dummy_id_102|>",
902
+ "lstrip": false,
903
+ "normalized": false,
904
+ "rstrip": false,
905
+ "single_word": false,
906
+ "special": true
907
+ },
908
+ "219508": {
909
+ "content": "<|dummy_id_103|>",
910
+ "lstrip": false,
911
+ "normalized": false,
912
+ "rstrip": false,
913
+ "single_word": false,
914
+ "special": true
915
+ },
916
+ "219509": {
917
+ "content": "<|dummy_id_104|>",
918
+ "lstrip": false,
919
+ "normalized": false,
920
+ "rstrip": false,
921
+ "single_word": false,
922
+ "special": true
923
+ },
924
+ "219510": {
925
+ "content": "<|dummy_id_105|>",
926
+ "lstrip": false,
927
+ "normalized": false,
928
+ "rstrip": false,
929
+ "single_word": false,
930
+ "special": true
931
+ },
932
+ "219511": {
933
+ "content": "<|dummy_id_106|>",
934
+ "lstrip": false,
935
+ "normalized": false,
936
+ "rstrip": false,
937
+ "single_word": false,
938
+ "special": true
939
+ },
940
+ "219512": {
941
+ "content": "<|dummy_id_107|>",
942
+ "lstrip": false,
943
+ "normalized": false,
944
+ "rstrip": false,
945
+ "single_word": false,
946
+ "special": true
947
+ },
948
+ "219513": {
949
+ "content": "<|dummy_id_108|>",
950
+ "lstrip": false,
951
+ "normalized": false,
952
+ "rstrip": false,
953
+ "single_word": false,
954
+ "special": true
955
+ },
956
+ "219514": {
957
+ "content": "<|dummy_id_109|>",
958
+ "lstrip": false,
959
+ "normalized": false,
960
+ "rstrip": false,
961
+ "single_word": false,
962
+ "special": true
963
+ },
964
+ "219515": {
965
+ "content": "<|dummy_id_110|>",
966
+ "lstrip": false,
967
+ "normalized": false,
968
+ "rstrip": false,
969
+ "single_word": false,
970
+ "special": true
971
+ },
972
+ "219516": {
973
+ "content": "<|dummy_id_111|>",
974
+ "lstrip": false,
975
+ "normalized": false,
976
+ "rstrip": false,
977
+ "single_word": false,
978
+ "special": true
979
+ },
980
+ "219517": {
981
+ "content": "<|dummy_id_112|>",
982
+ "lstrip": false,
983
+ "normalized": false,
984
+ "rstrip": false,
985
+ "single_word": false,
986
+ "special": true
987
+ },
988
+ "219518": {
989
+ "content": "<|dummy_id_113|>",
990
+ "lstrip": false,
991
+ "normalized": false,
992
+ "rstrip": false,
993
+ "single_word": false,
994
+ "special": true
995
+ },
996
+ "219519": {
997
+ "content": "<|dummy_id_114|>",
998
+ "lstrip": false,
999
+ "normalized": false,
1000
+ "rstrip": false,
1001
+ "single_word": false,
1002
+ "special": true
1003
+ }
1004
+ },
1005
+ "block_size": 2048,
1006
+ "bos_token": "<|beginoftext|>",
1007
+ "chat_template": "{{ bos_token }}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'assistant' and '</think>' in content %}{% set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}{% set content = content.split('</think>')[-1].lstrip('\n') %}{{ '<|startofturn|><|assistant|>\n\n<think>\n' + reasoning_content + '\n</think>\n\n' + content + '<|endofturn|>' }}{% else %}{{ '<|startofturn|><|' + message['role'] + '|>\n\n' + content + '<|endofturn|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|startofturn|><|assistant|>\n\n' }}{% endif %}",
1008
+ "clean_up_tokenization_spaces": false,
1009
+ "corruption_rate": 0.15,
1010
+ "eos_token": "<|endoftext|>",
1011
+ "extra_ids": 0,
1012
+ "extra_special_tokens": {},
1013
+ "fixed_vocab": true,
1014
+ "merges_file_path": "./data/merges.txt",
1015
+ "model_max_length": 1000000000000000019884624838656,
1016
+ "pad_token": "<|endoftext|>",
1017
+ "padding_side": "left",
1018
+ "seq_length": 2048,
1019
+ "tokenizer_class": "GPT2Tokenizer",
1020
+ "tokenizer_name": "/nfs-ssd/motif_1/tokenizers/ver5",
1021
+ "tokens": -1,
1022
+ "unk_token": "<|endoftext|>",
1023
+ "update_tokenizer": false,
1024
+ "use_moreh_tokenizer": false,
1025
+ "vocab_file_path": "./data/vocab.json",
1026
+ "vocab_size": 219395
1027
+ }
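For context, the `chat_template` added above wraps every turn as `<|startofturn|><|role|>\n\n…<|endofturn|>`, and for assistant turns that already contain a `</think>` tag it splits out the reasoning span and re-emits it between `<think>` and `</think>` ahead of the visible reply. A minimal usage sketch through the standard `transformers` API follows; the repo id is a placeholder, not something defined in this PR:

```python
from transformers import AutoTokenizer

# Placeholder repo id -- substitute the actual Motif-2.6B checkpoint location.
tokenizer = AutoTokenizer.from_pretrained("Motif-Technologies/Motif-2.6B")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
]

# With add_generation_prompt=True the template renders
#   <|beginoftext|><|startofturn|><|system|>\n\n...<|endofturn|>
#   <|startofturn|><|user|>\n\n...<|endofturn|><|startofturn|><|assistant|>\n\n
# so the model continues generation from the assistant header.
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)
```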
vocab.json ADDED
The diff for this file is too large to render.
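Since the vocab.json diff is not rendered, a quick local sanity check is to load the tokenizer from a checkout of this branch and confirm that the reserved `<|dummy_id_*|>` entries and the special tokens line up with the config above; the local path is a placeholder:

```python
from transformers import AutoTokenizer

# Placeholder path -- point this at a local checkout of the branch under review.
tok = AutoTokenizer.from_pretrained("./Motif-2.6B")

# added_tokens_decoder above maps ids 219474..219519 to <|dummy_id_69|>..<|dummy_id_114|>.
assert tok.convert_ids_to_tokens(219474) == "<|dummy_id_69|>"
assert tok.convert_ids_to_tokens(219519) == "<|dummy_id_114|>"

# bos/eos/pad/unk as declared in tokenizer_config.json.
print(tok.bos_token, tok.eos_token, tok.pad_token, tok.unk_token)

# Base vocab_size is 219395; len(tok) also counts the added special tokens.
print(len(tok))
```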