Commit a66edf1
Parent(s): initial commit

Files changed:
- .gitattributes +35 -0
- .gitignore +2 -0
- LICENSE +60 -0
- README.md +192 -0
- config.json +255 -0
- configuration_japanese_instructblip_alpha.py +57 -0
- configuration_japanese_stablelm_alpha.py +120 -0
- japanese-instructblip-parrot.png +0 -0
- modeling_japanese_instructblip_alpha.py +62 -0
- modeling_japanese_stablelm_alpha.py +682 -0
- preprocessor_config.json +24 -0
- pytorch_model-00001-of-00004.bin +3 -0
- pytorch_model-00002-of-00004.bin +3 -0
- pytorch_model-00003-of-00004.bin +3 -0
- pytorch_model-00004-of-00004.bin +3 -0
- pytorch_model.bin.index.fp16.json +0 -0
- pytorch_model.bin.index.json +0 -0
- pytorch_model.fp16-00001-of-00002.bin +3 -0
- pytorch_model.fp16-00002-of-00002.bin +3 -0
- requirements.txt +2 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,2 @@
__pycache__
test.py
LICENSE
ADDED
@@ -0,0 +1,60 @@
JAPANESE STABLELM RESEARCH LICENSE AGREEMENT
Dated: August 7, 2023

"Agreement" means the terms and conditions for use, reproduction, distribution and modification of the Software Products set forth herein.

“Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software.

"Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person’s or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.

"Stability AI" or "we" means Stability AI Ltd.

"Software" means, collectively, Stability AI’s proprietary Japanese StableLM made available under this Agreement.

“Software Products” means Software and Documentation.

By using or distributing any portion or element of the Software Products, you agree to be bound by this Agreement.

1. License Rights and Redistribution.
a. Subject to your compliance with this Agreement and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Software Products to reproduce, distribute, and create derivative works of the Software Products for purposes other than commercial or production use.
b. You will not, and will not permit, assist or cause any third party to use, modify, copy, reproduce, create derivative works of, or distribute the Software Products (or any derivative works thereof, works incorporating the Software Products, or any data produced by the Software), in whole or in part, for any commercial or production purposes.
c. If you distribute or make the Software Products, or any derivative works thereof, available to a third party, you shall (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "Japanese StableLM is licensed under the Japanese StableLM Research License, Copyright (c) Stability AI Ltd. All Rights Reserved.”
d. The licenses granted to you under this Agreement are conditioned upon your compliance with the Documentation and this Agreement, including the Acceptable Use Policy below and as may be updated from time to time in the future on stability.ai, which is hereby incorporated by reference into this Agreement.
2. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS.
3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
4. Intellectual Property.
a. No trademark licenses are granted under this Agreement, and in connection with the Software Products, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products.
b. Subject to Stability AI’s ownership of the Software Products and derivatives made by or for Stability AI, with respect to any derivative works and modifications of the Software Products that are made by you, as between you and Stability AI, you are and will be the owner of such derivative works and modifications.
c. If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products in violation of this Agreement.
5. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Software Products. Sections 2-4 shall survive the termination of this Agreement.

—----------

Japanese StableLM Acceptable Use Policy

If you access, use, or distribute any Stability AI models, software, or other materials (“Stability Technology”) you agree to this Acceptable Use Policy (“Policy”).

We want everyone to use Stability Technology safely and responsibly. You agree you will not use, or allow others to use, Stability Technology to:
1. To violate the law or others’ rights (including intellectual property rights and the rights of data privacy and protection), nor will you promote, contribute to, encourage, facilitate, plan, incite, or further anyone else’s violation of the law or others’ rights;
2. To commit, promote, contribute to, facilitate, encourage, plan, incite, or further any of the following:
a. Violence or terrorism;
b. Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content;
c. Human trafficking, exploitation, and sexual violence;
d. Harassment, abuse, threatening, stalking, or bullying of individuals or groups of individuals;
e. Discrimination in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services on the basis of race, color, caste, religion, sex (including pregnancy, sexual orientation, or gender identity), national origin, age, disability, or genetic information (including family medical history) except as may be required by applicable law (such as the provision of social security benefits solely to people who meet certain age requirements under the law);
f. Creation of malicious code, malware, computer viruses or any activity that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system;
3. For purposes of or for the performance of:
a. Fully automated decision-making, including profiling, with respect to an individual or group of individuals which produces legal effects concerning such individual(s) or similarly significantly affects such individual(s);
b. Systematic or automated scraping, mining, extraction, or harvesting of personally identifiable data, or similar activity, from the output of any Stability Technology except with respect to data that you have provided as input to the Stability Technology and which you are legally entitled to process, for so long as you retain such entitlement;
c. Development, improvement, or manufacture of any weapons of mass destruction (such as nuclear, chemical, or biologic weapons), weapons of war (such as missiles or landmines), or any gain of function-related activities with respect to any pathogens;
d. Mission critical applications or systems where best industry practices require fail-safe controls or performance, including operation of nuclear facilities, aircraft navigation, electrical grids, communication systems, water treatment facilities, air traffic control, life support, weapons systems, or emergency locator or other emergency services;
4. To intentionally deceive or mislead others, including use of Japanese StableLM related to the following:
i. Generating, promoting, or furthering fraud or the creation or promotion of disinformation;
ii. Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content;
iii. Generating, promoting, or further distributing spam;
iv. Impersonating another individual without consent, authorization, or legal right
v. Representing or misleading people into believing that the use of Japanese StableLM or outputs are human-generated;
vi. Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement;
vii. Generating or facilitating large-scale political advertisements, propaganda, or influence campaigns;
5. Fail to appropriately disclose to end users any known dangers of your AI system or misrepresent or mislead with respect to its abilities.
Nothing in this AUP is intended to prevent or impede any good faith research, testing, or evaluation of Japanese StableLM, or publication related to any of the foregoing. If you discover any flaws in Japanese StableLM that may be harmful to people in any way, we encourage you to notify us and give us a chance to remedy such flaws before others can exploit them. If you have questions about this AUP, contact us at [email protected].
README.md
ADDED
@@ -0,0 +1,192 @@
---
language:
- ja
tags:
- instructblip
- vision
- image-captioning
- japanese-stablelm
pipeline_tag: image-to-text
license:
- other
extra_gated_heading: Access Japanese StableLM Instruct Alpha
extra_gated_description: This repository is publicly accessible, but you have to accept the conditions to access its files and content.
extra_gated_button_content: Access repository
extra_gated_fields:
  Name: text
  Email: text
  Organization: text
  I agree to accept the conditions and share above info with Stability AI: checkbox
extra_gated_prompt: |
  ### JAPANESE STABLELM RESEARCH LICENSE AGREEMENT
  Dated: August 7, 2023

  "Agreement" means the terms and conditions for use, reproduction, distribution and modification of the Software Products set forth herein.

  “Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software.

  "Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person’s or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.

  "Stability AI" or "we" means Stability AI Ltd.

  "Software" means, collectively, Stability AI’s proprietary Japanese StableLM made available under this Agreement.

  “Software Products” means Software and Documentation.

  By using or distributing any portion or element of the Software Products, you agree to be bound by this Agreement.
  - License Rights and Redistribution.
  - Subject to your compliance with this Agreement and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Software Products to reproduce, distribute, and create derivative works of the Software Products for purposes other than commercial or production use.
  - You will not, and will not permit, assist or cause any third party to use, modify, copy, reproduce, create derivative works of, or distribute the Software Products (or any derivative works thereof, works incorporating the Software Products, or any data produced by the Software), in whole or in part, for any commercial or production purposes.
  - If you distribute or make the Software Products, or any derivative works thereof, available to a third party, you shall (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "Japanese StableLM is licensed under the Japanese StableLM Research License, Copyright (c) Stability AI Ltd. All Rights Reserved.”
  - The licenses granted to you under this Agreement are conditioned upon your compliance with the Documentation and this Agreement, including the Acceptable Use Policy below and as may be updated from time to time in the future on stability.ai, which is hereby incorporated by reference into this Agreement.
  - Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS.
  - Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
  - Intellectual Property.
  - No trademark licenses are granted under this Agreement, and in connection with the Software Products, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products.
  - Subject to Stability AI’s ownership of the Software Products and derivatives made by or for Stability AI, with respect to any derivative works and modifications of the Software Products that are made by you, as between you and Stability AI, you are and will be the owner of such derivative works and modifications.
  - If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products in violation of this Agreement.
  - Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Software Products. Sections 2-4 shall survive the termination of this Agreement.
  —----------
  ### Japanese StableLM Acceptable Use Policy
  If you access, use, or distribute any Stability AI models, software, or other materials (“Stability Technology”) you agree to this Acceptable Use Policy (“Policy”).
  We want everyone to use Stability Technology safely and responsibly. You agree you will not use, or allow others to use, Stability Technology to:
  - To violate the law or others’ rights (including intellectual property rights and the rights of data privacy and protection), nor will you promote, contribute to, encourage, facilitate, plan, incite, or further anyone else’s violation of the law or others’ rights;
  - To commit, promote, contribute to, facilitate, encourage, plan, incite, or further any of the following:
  - Violence or terrorism;
  - Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content;
  - Human trafficking, exploitation, and sexual violence;
  - Harassment, abuse, threatening, stalking, or bullying of individuals or groups of individuals;
  - Discrimination in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services on the basis of race, color, caste, religion, sex (including pregnancy, sexual orientation, or gender identity), national origin, age, disability, or genetic information (including family medical history) except as may be required by applicable law (such as the provision of social security benefits solely to people who meet certain age requirements under the law);
  - Creation of malicious code, malware, computer viruses or any activity that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system;
  - For purposes of or for the performance of:
  - Fully automated decision-making, including profiling, with respect to an individual or group of individuals which produces legal effects concerning such individual(s) or similarly significantly affects such individual(s);
  - Systematic or automated scraping, mining, extraction, or harvesting of personally identifiable data, or similar activity, from the output of any Stability Technology except with respect to data that you have provided as input to the Stability Technology and which you are legally entitled to process, for so long as you retain such entitlement;
  - Development, improvement, or manufacture of any weapons of mass destruction (such as nuclear, chemical, or biologic weapons), weapons of war (such as missiles or landmines), or any gain of function-related activities with respect to any pathogens;
  - Mission critical applications or systems where best industry practices require fail-safe controls or performance, including operation of nuclear facilities, aircraft navigation, electrical grids, communication systems, water treatment facilities, air traffic control, life support, weapons systems, or emergency locator or other emergency services;
  - To intentionally deceive or mislead others, including use of Japanese StableLM related to the following:
  - Generating, promoting, or furthering fraud or the creation or promotion of disinformation;
  - Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content;
  - Generating, promoting, or further distributing spam;
  - Impersonating another individual without consent, authorization, or legal right
  - Representing or misleading people into believing that the use of Japanese StableLM or outputs are human-generated;
  - Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement;
  - Generating or facilitating large-scale political advertisements, propaganda, or influence campaigns;
  - Fail to appropriately disclose to end users any known dangers of your AI system or misrepresent or mislead with respect to its abilities.
  Nothing in this AUP is intended to prevent or impede any good faith research, testing, or evaluation of Japanese StableLM, or publication related to any of the foregoing. If you discover any flaws in Japanese StableLM that may be harmful to people in any way, we encourage you to notify us and give us a chance to remedy such flaws before others can exploit them. If you have questions about this AUP, contact us at [email protected].
---

# Japanese InstructBLIP Alpha



## Model Details
Japanese InstructBLIP Alpha is a vision-language instruction-following model that can generate Japanese descriptions for an input image and, optionally, input text such as a question.


## Usage

First install the additional dependencies in [requirements.txt](./requirements.txt):

```sh
pip install sentencepiece einops
```


```python
import torch
from transformers import LlamaTokenizer, AutoModelForVision2Seq, BlipImageProcessor
from PIL import Image
import requests

# helper function to format input prompts
def build_prompt(prompt="", sep="\n\n### "):
    # system message: "Below is a combination of an instruction describing a task and contextual input. Write a response that appropriately fulfills the request."
    sys_msg = "以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。"
    p = sys_msg
    roles = ["指示", "応答"]  # "instruction", "response"
    user_query = "与えられた画像について、詳細に述べてください。"  # "Describe the given image in detail."
    msgs = [": \n" + user_query, ": "]
    if prompt:
        roles.insert(1, "入力")  # "input"
        msgs.insert(1, ": \n" + prompt)
    for role, msg in zip(roles, msgs):
        p += sep + role + msg
    return p

# load model
model = AutoModelForVision2Seq.from_pretrained("stabilityai/japanese-instructblip-alpha", trust_remote_code=True)
processor = BlipImageProcessor.from_pretrained("stabilityai/japanese-instructblip-alpha")
tokenizer = LlamaTokenizer.from_pretrained("novelai/nerdstash-tokenizer-v1", additional_special_tokens=['▁▁'])
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# prepare inputs
url = "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
prompt = ""  # input empty string for image captioning. You can also input questions as prompts
prompt = build_prompt(prompt)
inputs = processor(images=image, return_tensors="pt")
text_encoding = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
text_encoding["qformer_input_ids"] = text_encoding["input_ids"].clone()
text_encoding["qformer_attention_mask"] = text_encoding["attention_mask"].clone()
inputs.update(text_encoding)

# generate
outputs = model.generate(
    **inputs.to(device, dtype=model.dtype),
    num_beams=5,
    max_new_tokens=32,
    min_length=1,
)
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].strip()
print(generated_text)
# 桜と東京スカイツリー ("Cherry blossoms and Tokyo Skytree")
```


## Model Details
* **Developed by**: [Stability AI](https://stability.ai/)
* **Model type**: [InstructBLIP](https://arxiv.org/abs/2305.06500)
* **Language(s)**: Japanese
* **License**: [JAPANESE STABLELM RESEARCH LICENSE AGREEMENT](./LICENSE).

### Training
Japanese InstructBLIP Alpha leverages the [InstructBLIP](https://arxiv.org/abs/2305.06500) architecture. It consists of three components: a frozen vision image encoder, a Q-Former, and a frozen LLM. The vision encoder and the Q-Former were initialized with [Salesforce/instructblip-vicuna-7b](https://huggingface.co/Salesforce/instructblip-vicuna-7b). For the frozen LLM, the [Japanese-StableLM-Instruct-Alpha-7B](https://huggingface.co/stabilityai/japanese-stablelm-instruct-alpha-7b) model was used. During training, only the Q-Former was updated.

### Training Dataset
The training dataset includes the following public datasets:
- [CC12M](https://github.com/google-research-datasets/conceptual-12m) with captions translated into Japanese
- [MS-COCO](https://cocodataset.org/#home) with [STAIR Captions](http://captions.stair.center/)
- [Japanese Visual Genome VQA dataset](https://github.com/yahoojapan/ja-vg-vqa)

## Use and Limitations

### Intended Use

This model is intended to be used by the open-source community in chat-like applications in adherence with the research license.

### Limitations and bias

Although the aforementioned datasets help to steer the base language models into "safer" distributions of text, not all biases and toxicity can be mitigated through fine-tuning. We ask that users be mindful of such potential issues that can arise in generated responses. Do not treat model outputs as substitutes for human judgment or as sources of truth. Please use responsibly.


## How to cite
```bibtex
@misc{JapaneseInstructBLIPAlpha,
    url    = {https://huggingface.co/stabilityai/japanese-instructblip-alpha},
    title  = {Japanese InstructBLIP Alpha},
    author = {Shing, Makoto and Akiba, Takuya}
}
```

## Citations

```bibtex
@misc{dai2023instructblip,
    title         = {InstructBLIP: Towards General-purpose Vision-Language Models with Instruction Tuning},
    author        = {Wenliang Dai and Junnan Li and Dongxu Li and Anthony Meng Huat Tiong and Junqi Zhao and Weisheng Wang and Boyang Li and Pascale Fung and Steven Hoi},
    year          = {2023},
    eprint        = {2305.06500},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CV}
}
```
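The Training section of the README describes three components that the modeling code in this commit exposes as `vision_model`, `qformer`, and `language_model`. The following is a minimal sketch, not part of the commit itself; it assumes you have accepted the gated license and can download the weights, and the per-component parameter counts it prints are not stated anywhere in this repository.

```python
from transformers import AutoModelForVision2Seq

# Load the model exactly as in the README's Usage example (custom code from this repo).
model = AutoModelForVision2Seq.from_pretrained(
    "stabilityai/japanese-instructblip-alpha", trust_remote_code=True
)

# The three components named in the Training section.
components = {
    "vision encoder (frozen)": model.vision_model,
    "Q-Former (trained)": model.qformer,
    "language model (frozen)": model.language_model,
}
for name, module in components.items():
    n_params = sum(p.numel() for p in module.parameters())
    print(f"{name}: {n_params / 1e6:.1f}M parameters")
```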
config.json
ADDED
@@ -0,0 +1,255 @@
{
  "_name_or_path": "stabilityai/japanese-instructblip-alpha",
  "architectures": [
    "JapaneseInstructBlipAlphaForConditionalGeneration"
  ],
  "auto_map": {
    "AutoModelForVision2Seq": "modeling_japanese_instructblip_alpha.JapaneseInstructBlipAlphaForConditionalGeneration",
    "AutoConfig": "configuration_japanese_instructblip_alpha.JapaneseInstructBlipAlphaConfig"
  },
  "initializer_factor": 1.0,
  "initializer_range": 0.02,
  "model_type": "instructblip",
  "num_query_tokens": 32,
  "qformer_config": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": null,
    "attention_probs_dropout_prob": 0.1,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_frequency": 2,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_hidden_size": 1408,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "hidden_size": 768,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-12,
    "length_penalty": 1.0,
    "max_length": 20,
    "max_position_embeddings": 512,
    "min_length": 0,
    "model_type": "instructblip_qformer",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 12,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_hidden_layers": 12,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": 0,
    "position_embedding_type": "absolute",
    "prefix": null,
    "problem_type": null,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "transformers_version": "4.31.0",
    "typical_p": 1.0,
    "use_bfloat16": false,
    "vocab_size": 65535
  },
  "text_config": {
    "_name_or_path": "stabilityai/japanese-stablelm-instruct-alpha-7b",
    "add_cross_attention": false,
    "architectures": [
      "JapaneseStableLMAlphaForCausalLM"
    ],
    "auto_map": {
      "AutoConfig": "stabilityai/japanese-stablelm-instruct-alpha-7b--configuration_japanese_stablelm_alpha.JapaneseStableLMAlphaConfig",
      "AutoModelForCausalLM": "stabilityai/japanese-stablelm-instruct-alpha-7b--modeling_japanese_stablelm_alpha.JapaneseStableLMAlphaForCausalLM"
    },
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 3,
    "chunk_size_feed_forward": 0,
    "classifier_dropout": 0.1,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 3,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "silu",
    "hidden_size": 4096,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "initializer_range": 0.02,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-05,
    "length_penalty": 1.0,
    "max_length": 20,
    "max_position_embeddings": 2048,
    "min_length": 0,
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 32,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_hidden_layers": 32,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "prefix": null,
    "problem_type": null,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "rotary_emb_base": 10000,
    "rotary_pct": 0.25,
    "rotary_scale_base": 512,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": false,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": "float32",
    "torchscript": false,
    "transformers_version": "4.31.0",
    "typical_p": 1.0,
    "use_bfloat16": false,
    "use_bias_in_mlp": false,
    "use_cache": true,
    "use_parallel_residual": true,
    "vocab_size": 65535
  },
  "tie_word_embeddings": false,
  "torch_dtype": "float32",
  "transformers_version": null,
  "use_decoder_only_language_model": true,
  "vision_config": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "gelu",
    "hidden_size": 1408,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "image_size": 224,
    "initializer_range": 1e-10,
    "intermediate_size": 6144,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-06,
    "length_penalty": 1.0,
    "max_length": 20,
    "min_length": 0,
    "model_type": "instructblip_vision_model",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 16,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_hidden_layers": 39,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "patch_size": 14,
    "prefix": null,
    "problem_type": null,
    "pruned_heads": {},
    "qkv_bias": true,
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "suppress_tokens": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "transformers_version": "4.31.0",
    "typical_p": 1.0,
    "use_bfloat16": false
  }
}
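config.json nests three sub-configurations (`qformer_config`, `text_config`, `vision_config`) whose hidden sizes determine what the Q-Former cross-attends to and what the projection layer must bridge. A small sketch, not part of the commit and again assuming access to the gated repository, that loads the composite configuration through the `auto_map` entry and prints those sizes:

```python
from transformers import AutoConfig

# auto_map in config.json routes AutoConfig to JapaneseInstructBlipAlphaConfig.
config = AutoConfig.from_pretrained(
    "stabilityai/japanese-instructblip-alpha", trust_remote_code=True
)

print(config.num_query_tokens)            # 32 learned query tokens
print(config.vision_config.hidden_size)   # 1408, ViT width fed to the Q-Former cross-attention
print(config.qformer_config.hidden_size)  # 768, Q-Former width
print(config.text_config.hidden_size)     # 4096, Japanese StableLM width
```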
configuration_japanese_instructblip_alpha.py
ADDED
@@ -0,0 +1,57 @@
# coding=utf-8
# Copyright 2023 Stability and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Japanese InstructBLIP Alpha model configuration"""

from transformers import (
    PretrainedConfig,
    InstructBlipConfig,
    InstructBlipVisionConfig,
    InstructBlipQFormerConfig,
    AutoConfig,
)
from transformers.utils import logging
from .configuration_japanese_stablelm_alpha import JapaneseStableLMAlphaConfig


logger = logging.get_logger(__name__)


class JapaneseInstructBlipAlphaConfig(InstructBlipConfig):
    def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs):
        PretrainedConfig.__init__(self, **kwargs)

        if vision_config is None:
            vision_config = {}
            logger.info("vision_config is None. initializing the InstructBlipVisionConfig with default values.")

        if qformer_config is None:
            qformer_config = {}
            logger.info("qformer_config is None. Initializing the InstructBlipQFormerConfig with default values.")

        if text_config is None:
            text_config = {}
            logger.info("text_config is None. Initializing the text config with default values (`OPTConfig`).")
        self.vision_config = InstructBlipVisionConfig(**vision_config)
        self.qformer_config = InstructBlipQFormerConfig(**qformer_config)
        self.text_config = JapaneseStableLMAlphaConfig(**text_config)

        self.tie_word_embeddings = self.text_config.tie_word_embeddings
        self.is_encoder_decoder = self.text_config.is_encoder_decoder

        self.num_query_tokens = num_query_tokens
        self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
        self.use_decoder_only_language_model = True
        self.initializer_factor = 1.0
        self.initializer_range = 0.02
configuration_japanese_stablelm_alpha.py
ADDED
@@ -0,0 +1,120 @@
# coding=utf-8
# Copyright 2023 Stability and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" JapaneseStableLMAlpha model configuration"""

from transformers import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)

STABLE_LM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}


class JapaneseStableLMAlphaConfig(PretrainedConfig):
    r"""
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 65536):
            Vocabulary size of the JapaneseStableLMAlphaModel. Defines the number of different tokens that
            can be represented by the `inputs_ids` passed when calling [`JapaneseStableLMAlphaModel`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the decoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        intermediate_size (`int`, *optional*, defaults to 16384):
            Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer decoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string).
        rotary_pct (`float`, *optional*, defaults to 0.25):
            Percentage of hidden dimensions to allocate to rotary embeddings.
        rotary_emb_base (`int`, *optional*, defaults to 10000):
            Base for computing rotary embeddings frequency.
        rotary_scale_base (`int`, *optional*, defaults to 512):
            Base `scale` for computing XPos rotary embeddings scale.
        classifier_dropout (`float`, *optional*, defaults to 0.1):
            Argument used when doing token classification, used in the model
            [`StableLMForTokenClassification`]. The dropout ratio for the hidden layer.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with.
            Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
        initializer_range (`float`, *optional*, defaults to 1e-5):
            The standard deviation of the truncated_normal_initializer for initializing
            all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions
            (not used by all models). Only relevant if `config.is_decoder=True`.
        use_parallel_residual (`bool`, *optional*, defaults to `True`):
            Whether to use a "parallel" formulation in each Transformer layer,
            which can provide a slight training speedup at large scales.
    Example:

    ```python
    >>> from transformers import JapaneseStableLMAlphaConfig, JapaneseStableLMAlphaModel

    >>> # Initializing a JapaneseStableLMAlpha style configuration
    >>> configuration = JapaneseStableLMAlphaConfig()

    >>> # Initializing a model (with random weights) from the style configuration
    >>> model = JapaneseStableLMAlphaModel(configuration)  # doctest: +SKIP

    >>> # Accessing the model configuration
    >>> configuration = model.config  # doctest: +SKIP
    ```"""

    def __init__(
        self,
        vocab_size=65536,
        hidden_size=4096,
        num_hidden_layers=32,
        num_attention_heads=32,
        hidden_act="silu",
        rotary_pct=0.25,
        rotary_emb_base=10000,
        rotary_scale_base=512,
        classifier_dropout=0.1,
        max_position_embeddings=2048,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        use_cache=True,
        bos_token_id=3,
        eos_token_id=3,
        tie_word_embeddings=False,
        use_parallel_residual=True,
        use_bias_in_mlp=True,
        **kwargs,
    ):
        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.rotary_pct = rotary_pct
        self.rotary_emb_base = rotary_emb_base
        self.rotary_scale_base = rotary_scale_base
        self.classifier_dropout = classifier_dropout
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.use_cache = use_cache
        self.tie_word_embeddings = tie_word_embeddings
        self.use_parallel_residual = use_parallel_residual
        self.use_bias_in_mlp = use_bias_in_mlp
japanese-instructblip-parrot.png
ADDED
(binary image file)
modeling_japanese_instructblip_alpha.py
ADDED
@@ -0,0 +1,62 @@
# coding=utf-8
# Copyright 2023 Stability and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Japanese InstructBLIP Alpha model. """
import torch
from torch import nn
from transformers import (
    InstructBlipPreTrainedModel,
    InstructBlipVisionModel,
    InstructBlipQFormerModel,
    InstructBlipForConditionalGeneration,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
)
from transformers.utils import logging
from .modeling_japanese_stablelm_alpha import JapaneseStableLMAlphaForCausalLM
from .configuration_japanese_instructblip_alpha import JapaneseInstructBlipAlphaConfig


logger = logging.get_logger(__name__)


class JapaneseInstructBlipAlphaForConditionalGeneration(InstructBlipForConditionalGeneration):
    config_class = JapaneseInstructBlipAlphaConfig

    def __init__(self, config: JapaneseInstructBlipAlphaConfig):
        InstructBlipPreTrainedModel.__init__(self, config)

        self.vision_model = InstructBlipVisionModel(config.vision_config)

        self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
        self.qformer = InstructBlipQFormerModel(config.qformer_config)

        self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)

        if config.use_decoder_only_language_model:
            language_model = JapaneseStableLMAlphaForCausalLM(config.text_config)
        else:
            raise NotImplementedError
            language_model = AutoModelForSeq2SeqLM.from_config(config.text_config, trust_remote_code=True,)

        if language_model._no_split_modules is not None:
            self._no_split_modules.extend(language_model._no_split_modules)

        if language_model._keep_in_fp32_modules is not None:
            self._keep_in_fp32_modules.extend(language_model._keep_in_fp32_modules)

        self.language_model = language_model

        # Initialize weights and apply final processing
        self.post_init()
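For orientation, the constructor above creates 32 learned query tokens at the Q-Former width and a `language_projection` layer that maps Q-Former outputs into the language model's embedding space. The toy shape walk-through below is an editor's sketch with random tensors (sizes taken from config.json, not actual model weights); in the InstructBLIP parent class the projected queries are typically concatenated with the text token embeddings before generation.

```python
import torch
from torch import nn

# Sizes from config.json: 32 query tokens, Q-Former width 768, LLM width 4096.
batch_size, num_query_tokens = 1, 32
qformer_hidden, text_hidden = 768, 4096

query_output = torch.randn(batch_size, num_query_tokens, qformer_hidden)  # stand-in for Q-Former output
language_projection = nn.Linear(qformer_hidden, text_hidden)

visual_prompt = language_projection(query_output)
print(visual_prompt.shape)  # torch.Size([1, 32, 4096]); these vectors precede the text embeddings fed to the LLM
```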
modeling_japanese_stablelm_alpha.py
ADDED
@@ -0,0 +1,682 @@
# coding=utf-8
# Copyright 2023 Stability and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch JapaneseStableLMAlpha model. """
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from .configuration_japanese_stablelm_alpha import JapaneseStableLMAlphaConfig


logger = logging.get_logger(__name__)

class JapaneseStableLMAlphaPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = JapaneseStableLMAlphaConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, JapaneseStableLMAlphaModel):
            module.gradient_checkpointing = value

class JapaneseStableLMAlphaModel(JapaneseStableLMAlphaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_in

    def set_input_embeddings(self, value):
        self.embed_in = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        r"""
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * self.config.num_hidden_layers)
        else:
            past_length = past_key_values[0][0].size(-2)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # Attention mask.
        if attention_mask is not None:
            assert batch_size > 0, "batch_size has to be defined and > 0"
            attention_mask = attention_mask.view(batch_size, -1)
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask[:, None, None, :]

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and the dtype's smallest value for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if inputs_embeds is None:
            inputs_embeds = self.embed_in(input_ids)

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        presents = () if use_cache else None
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for layer_past
                        return module(*inputs, use_cache, None, output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    head_mask[i],
                )
            else:
                outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    head_mask=head_mask[i],
                    layer_past=layer_past,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )
            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)
            if output_attentions:
                all_attentions = all_attentions + (outputs[2 if use_cache else 1],)

        hidden_states = self.final_layer_norm(hidden_states)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )

class DecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.use_parallel_residual = config.use_parallel_residual
        self.input_layernorm = nn.LayerNorm(
            config.hidden_size,
            eps=config.layer_norm_eps,
            elementwise_affine=False,
        )
        self.post_attention_layernorm = nn.LayerNorm(
            config.hidden_size,
            eps=config.layer_norm_eps
        )
        self.attention = Attention(config)
        self.mlp = MLP(config)

    def forward(
        self,
        hidden_states: Optional[torch.FloatTensor],
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
    ):
        attention_layer_outputs = self.attention(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            layer_past=layer_past,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attention_layer_outputs[0]  # output_attn: attn_output, present, (attn_weights)
        outputs = attention_layer_outputs[1:]

        # Parallel residual: attention and MLP both read the same input hidden
        # states, and their outputs are added back onto the residual stream together.
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_output + attn_output

        if use_cache:
            outputs = (hidden_states,) + outputs  # hidden_states, present, (attn_weights)
        else:
            outputs = (hidden_states,) + outputs[1:]  # hidden_states, (attn_weights)

        return outputs

class MLP(nn.Module):
    def __init__(self, config: JapaneseStableLMAlphaConfig):
        super().__init__()
        hidden_size = config.hidden_size
        multiple_of = 256
        ff_dim = int(8 * hidden_size / 3)
        intermediate_size = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)

        self.packed_input_proj = torch.nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.out_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ff, ff_gate = self.packed_input_proj(x).chunk(2, dim=-1)
        return self.out_proj(ff * self.act(ff_gate))

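# Worked example of the sizing above (assuming hidden_size=4096, i.e. the 7B
# configuration, which is not shown in this file): ff_dim = int(8 * 4096 / 3)
# = 10922, and rounding up to the next multiple of 256 gives
# intermediate_size = 256 * ((10922 + 255) // 256) = 11008, so
# packed_input_proj maps 4096 -> 2 * 11008 and out_proj maps 11008 -> 4096.
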
class RotaryEmbedding(torch.nn.Module):
    """Based on Tri Dao's XPos: https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/layers/rotary.py"""
    def __init__(
        self,
        dim: int,
        max_position_embeddings: int,
        base: int = 10_000,
        scale_base: int = 512,
        device: str = None
    ):
        super().__init__()
        self.dim = dim
        self.seq_len_cached = max_position_embeddings

        # Set up `inv_freq` term
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Set up `scale` term
        self.scale_base = scale_base
        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base is not None else None
        )
        self.register_buffer("scale", scale)

        # Set up `cos_cached` and `sin_cached` cache terms
        t = torch.arange(self.seq_len_cached, device=device, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq)
        # freqs = torch.cat((freqs, freqs), dim=-1)
        seq_range = torch.arange(self.seq_len_cached, dtype=self.scale.dtype, device=self.scale.device)
        power = (seq_range - self.seq_len_cached // 2) / self.scale_base
        scale_cached = self.scale.to(device=power.device) ** power.unsqueeze(-1)
        # scale_cached = torch.cat((scale_cached, scale_cached), dim=-1)
        self.register_buffer("cos_cached", torch.cos(freqs) * scale_cached, persistent=False)
        self.register_buffer("sin_cached", torch.sin(freqs) * scale_cached, persistent=False)
        self.register_buffer("cos_k_cached", torch.cos(freqs) / scale_cached, persistent=False)
        self.register_buffer("sin_k_cached", torch.sin(freqs) / scale_cached, persistent=False)

    def forward(self, x, seq_len=None):
        if seq_len > self.seq_len_cached:
            # Grow the cached cos/sin tables when a longer sequence is seen.
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
            freqs = torch.outer(t, self.inv_freq)
            freqs = torch.cat((freqs, freqs), dim=-1)
            seq_range = torch.arange(self.seq_len_cached, dtype=self.scale.dtype, device=self.scale.device)
            power = (seq_range - self.seq_len_cached // 2) / self.scale_base
            scale_cached = self.scale.to(device=power.device) ** power.unsqueeze(-1)
            scale_cached = torch.cat((scale_cached, scale_cached), dim=-1)
            self.register_buffer("cos_cached", torch.cos(freqs) * scale_cached, persistent=False)
            self.register_buffer("sin_cached", torch.sin(freqs) * scale_cached, persistent=False)
            self.register_buffer("cos_k_cached", torch.cos(freqs) / scale_cached, persistent=False)
            self.register_buffer("sin_k_cached", torch.sin(freqs) / scale_cached, persistent=False)
        return (
            self.cos_cached[:seq_len, ...],
            self.sin_cached[:seq_len, ...],
            self.cos_k_cached[:seq_len, ...],
            self.sin_k_cached[:seq_len, ...],
        )


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids, cos_k=None, sin_k=None):
    """
    q, k: [bs, num_heads, seq_len, rot_dim]
    cos, sin: [seq_len, rot_dim / 2]
    position_ids: [bs, seq_len]
    """
    import einops

    cos = einops.repeat(cos, 's r -> s (2 r)')
    sin = einops.repeat(sin, 's r -> s (2 r)')
    cos_k = einops.repeat(cos_k, 's r -> s (2 r)')
    sin_k = einops.repeat(sin_k, 's r -> s (2 r)')
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, rot_dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, rot_dim]
    cos_k = cos_k[position_ids].unsqueeze(1)  # [bs, 1, seq_len, rot_dim]
    sin_k = sin_k[position_ids].unsqueeze(1)  # [bs, 1, seq_len, rot_dim]

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos_k) + (rotate_half(k) * sin_k)
    return q_embed, k_embed

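# Note on the XPos-style scaling above: queries are rotated with cos/sin tables
# multiplied by `scale ** power`, while keys use the same tables divided by that
# factor (the separate cos_k/sin_k caches). In the query-key dot product the
# absolute-position scaling cancels, leaving a factor that depends only on the
# relative distance between the two positions.
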
class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                "The hidden size is not divisible by the number of attention heads! Make sure to update them"
            )
        self.head_size = self.hidden_size // self.num_attention_heads

        max_positions = config.max_position_embeddings
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                1, 1, max_positions, max_positions
            ),
            persistent=False,
        )
        self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)

        self.rotary_ndims = int(self.head_size * config.rotary_pct)
        self.rotary_emb = RotaryEmbedding(
            self.rotary_ndims,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rotary_emb_base,
            scale_base=config.rotary_scale_base,
        )

        self.register_buffer(
            "norm_factor",
            torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()),
            persistent=False,
        )

        self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
        self.dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: torch.FloatTensor,
        position_ids: torch.LongTensor,
        head_mask: Optional[torch.FloatTensor] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ):
        has_layer_past = layer_past is not None

        # Compute QKV
        # Attention heads [batch, seq_len, hidden_size]
        # --> [batch, seq_len, (np * 3 * head_size)]
        qkv = self.query_key_value(hidden_states)

        # [batch, seq_len, (num_heads * 3 * head_size)]
        # --> [batch, seq_len, num_heads, 3 * head_size]
        new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
        qkv = qkv.view(*new_qkv_shape)

        # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
        query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
        key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
        value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)

        # Compute rotary embeddings on rotary_ndims
        query_rot = query[..., : self.rotary_ndims]
        query_pass = query[..., self.rotary_ndims :]
        key_rot = key[..., : self.rotary_ndims]
        key_pass = key[..., self.rotary_ndims :]

        # Compute token offset for rotary embeddings (when decoding)
        kv_seq_len = key.shape[-2]
        if has_layer_past:
            kv_seq_len += layer_past[0].shape[-2]

        # Add rotary embeddings to query and key
        # TODO: Check if using xpos
        cos, sin, cos_k, sin_k = self.rotary_emb(value, seq_len=kv_seq_len)
        query, key = apply_rotary_pos_emb(
            query_rot, key_rot, cos, sin, position_ids, cos_k=cos_k, sin_k=sin_k)

        query = torch.cat((query, query_pass), dim=-1)
        key = torch.cat((key, key_pass), dim=-1)

        # Cache QKV values
        if has_layer_past:
            past_key = layer_past[0]
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)
        present = (key, value) if use_cache else None

        # Compute attention
        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

        # Merge attn_head_size dim and num_attn_heads dim into hidden dim
        # [bs, seq_len, num_attention_heads, attn_head_size]
        attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
        attn_output = attn_output.view(attn_output.size(0), attn_output.size(1), self.num_attention_heads * self.head_size)

        attn_output = self.dense(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
        # compute causal mask from causal mask buffer

        batch_size, num_attention_heads, query_length, attn_head_size = query.size()
        key_length = key.size(-2)

        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]

        query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
        key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
        attn_scores = torch.zeros(
            batch_size * num_attention_heads,
            query_length,
            key_length,
            dtype=query.dtype,
            device=key.device,
        )
        attn_scores = torch.baddbmm(
            attn_scores,
            query,
            key.transpose(1, 2),
            beta=1.0,
            alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor),
        )
        attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)

        mask_value = torch.finfo(attn_scores.dtype).min
        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
        mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype, device=attn_scores.device)
        attn_scores = torch.where(causal_mask, attn_scores, mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_scores = attn_scores + attention_mask

        # NOTE: Upcast to float32
        attn_weights = nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).type_as(value)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights


def attention_mask_func(attention_scores, ltor_mask):
    attention_scores.masked_fill_(~ltor_mask, torch.finfo(attention_scores.dtype).min)
    return attention_scores

class JapaneseStableLMAlphaForCausalLM(JapaneseStableLMAlphaPreTrainedModel):
    _tied_weights_keys = ["embed_out.weight"]

    def __init__(self, config):
        super().__init__(config)

        self.transformer = JapaneseStableLMAlphaModel(config)
        self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.embed_out

    def set_output_embeddings(self, new_embeddings):
        self.embed_out = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Example:

        ```python
        >>> import torch
        >>> from transformers import LlamaTokenizer, JapaneseStableLMAlphaForCausalLM, JapaneseStableLMAlphaConfig

        >>> tokenizer = LlamaTokenizer.from_pretrained("novelai/nerdstash-tokenizer-v1")
        >>> config = JapaneseStableLMAlphaConfig.from_pretrained("stabilityai/stablelm-ja-base-alpha-7b")
        >>> config.is_decoder = True
        >>> model = JapaneseStableLMAlphaForCausalLM.from_pretrained("stabilityai/stablelm-ja-base-alpha-7b", config=config, trust_remote_code=True)

        >>> inputs = tokenizer("日本語の美しいところは、", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> prediction_logits = outputs.logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        lm_logits = self.embed_out(hidden_states)

        lm_loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(lm_logits.device)
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shift_logits = lm_logits[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithPast(
            loss=lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        input_shape = input_ids.shape

        # cut decoder_input_ids if past is used
        if past_key_values and past_key_values[0] is not None:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "attention_mask": attention_mask,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
        )

        return model_inputs

    def _reorder_cache(self, past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
            )
        return reordered_past
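For quick reference, a minimal generation sketch with the causal LM defined above. It reuses the checkpoint and tokenizer names from the `forward()` docstring (`novelai/nerdstash-tokenizer-v1`, `stabilityai/stablelm-ja-base-alpha-7b`); the loading class, dtype, and sampling settings are assumptions here, so defer to the model card for the exact invocation and any extra tokenizer arguments.

```python
import torch
from transformers import AutoModelForCausalLM, LlamaTokenizer

# Checkpoint names taken from the docstring above; adjust as needed.
tokenizer = LlamaTokenizer.from_pretrained("novelai/nerdstash-tokenizer-v1")
model = AutoModelForCausalLM.from_pretrained(
    "stabilityai/stablelm-ja-base-alpha-7b",
    trust_remote_code=True,     # pulls in the repository's custom modeling code
    torch_dtype=torch.float16,  # assumption: half precision for inference
)
model.eval()

inputs = tokenizer("日本語の美しいところは、", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(
        **inputs,
        max_new_tokens=64,
        do_sample=True,
        temperature=0.7,
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```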
preprocessor_config.json
ADDED
@@ -0,0 +1,24 @@
{
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.48145466,
    0.4578275,
    0.40821073
  ],
  "image_processor_type": "BlipImageProcessor",
  "image_std": [
    0.26862954,
    0.26130258,
    0.27577711
  ],
  "processor_class": "InstructBlipProcessor",
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}
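A short sketch of how this preprocessor configuration is consumed on the image side; the local path is a placeholder, and the full `InstructBlipProcessor` additionally needs the tokenizer files shipped with the repository, so treat this as an illustration only.

```python
from PIL import Image
from transformers import BlipImageProcessor

# Placeholder path: point this at a local clone of the repository
# containing the preprocessor_config.json above.
image_processor = BlipImageProcessor.from_pretrained("./japanese-instructblip-alpha")

image = Image.open("sample.jpg").convert("RGB")  # any RGB image
features = image_processor(images=image, return_tensors="pt")
# With the settings above (resize to 224x224, rescale by 1/255, CLIP-style
# mean/std normalization), pixel_values has shape (1, 3, 224, 224).
print(features.pixel_values.shape)
```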
pytorch_model-00001-of-00004.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:be7874a30f77cac685b8b1c446e7ac97ea2a066c6e9dc4a75c43d7c45341eb9f
size 9928428361
pytorch_model-00002-of-00004.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dbca4cbebac59e57652738c1fbbccd18abbc77faff5971774a74050f40593ae0
size 9982874199
pytorch_model-00003-of-00004.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c9152e6b29b5067795d63786f91a185282101c6fdcad843c40b208cc968d137d
size 9714437907
pytorch_model-00004-of-00004.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a5a02c272a7eaf9b8fb1248a52522f89533a432703e4764400fbdfac4d2fde8a
size 3233898185
pytorch_model.bin.index.fp16.json
ADDED
The diff for this file is too large to render.
See raw diff
pytorch_model.bin.index.json
ADDED
The diff for this file is too large to render.
See raw diff
pytorch_model.fp16-00001-of-00002.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:edb6de9b5fc91eec3b6a57d6dacd14073001f37fa860e68775beaf3afb7d79dc
size 9955835753
pytorch_model.fp16-00002-of-00002.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:528041e8ff85ea59f4ba0508efc852ef55c9fe15cedd2956c6f79ee8630f81f6
size 6474190985
requirements.txt
ADDED
@@ -0,0 +1,2 @@
sentencepiece
einops