devroop committed on
Commit 03561be · verified · 1 Parent(s): f2bf9b8

Upload folder using huggingface_hub

Files changed (45)
  1. .gitignore +148 -0
  2. .gradio/certificate.pem +31 -0
  3. LICENSE +204 -0
  4. README.md +196 -12
  5. README_zh-CN.md +184 -0
  6. app.py +379 -0
  7. configs/dataset_config.py +64 -0
  8. configs/lora_config.py +9 -0
  9. docs/images/demo_image.jpg +0 -0
  10. environment.yml +10 -0
  11. mmgpt/__init__.py +2 -0
  12. mmgpt/datasets/__init__.py +4 -0
  13. mmgpt/datasets/alpaca_gpt4_dataset.py +26 -0
  14. mmgpt/datasets/aokvqa_dataset.py +51 -0
  15. mmgpt/datasets/baize_dataset.py +86 -0
  16. mmgpt/datasets/builder.py +126 -0
  17. mmgpt/datasets/cc_sbu_align_dataset.py +107 -0
  18. mmgpt/datasets/clevr_dataset.py +74 -0
  19. mmgpt/datasets/coco_caption_dataset.py +119 -0
  20. mmgpt/datasets/dial_dataset.py +83 -0
  21. mmgpt/datasets/dolly_dataset.py +150 -0
  22. mmgpt/datasets/gqa_dataset.py +83 -0
  23. mmgpt/datasets/llava_dataset.py +18 -0
  24. mmgpt/datasets/nlvr_dataset.py +212 -0
  25. mmgpt/datasets/ocr_vqa_dataset.py +23 -0
  26. mmgpt/datasets/samplers/__init__.py +1 -0
  27. mmgpt/datasets/samplers/infinite_sampler.py +30 -0
  28. mmgpt/datasets/snli_ve_datasets.py +82 -0
  29. mmgpt/datasets/text_ocr_dataset.py +64 -0
  30. mmgpt/datasets/vqa_dataset.py +227 -0
  31. mmgpt/models/__init__.py +0 -0
  32. mmgpt/models/blip2/__init__.py +0 -0
  33. mmgpt/models/builder.py +74 -0
  34. mmgpt/models/open_flamingo/__init__.py +3 -0
  35. mmgpt/models/open_flamingo/builder.py +142 -0
  36. mmgpt/models/open_flamingo/flamingo.py +208 -0
  37. mmgpt/models/open_flamingo/flamingo_lm.py +131 -0
  38. mmgpt/models/open_flamingo/helpers.py +263 -0
  39. mmgpt/models/open_flamingo/utils.py +31 -0
  40. mmgpt/train/__init__.py +1 -0
  41. mmgpt/train/distributed.py +131 -0
  42. mmgpt/train/instruction_finetune.py +460 -0
  43. mmgpt/train/train_utils.py +251 -0
  44. requirements.txt +20 -0
  45. setup.py +50 -0
.gitignore ADDED
@@ -0,0 +1,148 @@
1
+ *.pt
2
+
3
+ wandb/
4
+
5
+ checkpoints/
6
+ tests/
7
+
8
+ # Byte-compiled / optimized / DLL files
9
+ __pycache__/
10
+ *.py[cod]
11
+ *$py.class
12
+
13
+ # C extensions
14
+ *.so
15
+
16
+ # Distribution / packaging
17
+ .Python
18
+ build/
19
+ develop-eggs/
20
+ dist/
21
+ downloads/
22
+ eggs/
23
+ .eggs/
24
+ lib/
25
+ lib64/
26
+ parts/
27
+ sdist/
28
+ var/
29
+ wheels/
30
+ pip-wheel-metadata/
31
+ share/python-wheels/
32
+ *.egg-info/
33
+ .installed.cfg
34
+ *.egg
35
+ MANIFEST
36
+
37
+ # PyInstaller
38
+ # Usually these files are written by a python script from a template
39
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
40
+ *.manifest
41
+ *.spec
42
+
43
+ # Installer logs
44
+ pip-log.txt
45
+ pip-delete-this-directory.txt
46
+
47
+ # Unit test / coverage reports
48
+ htmlcov/
49
+ .tox/
50
+ .nox/
51
+ .coverage
52
+ .coverage.*
53
+ .cache
54
+ nosetests.xml
55
+ coverage.xml
56
+ *.cover
57
+ *.py,cover
58
+ .hypothesis/
59
+ .pytest_cache/
60
+
61
+ # Translations
62
+ *.mo
63
+ *.pot
64
+
65
+ # Django stuff:
66
+ *.log
67
+ local_settings.py
68
+ db.sqlite3
69
+ db.sqlite3-journal
70
+
71
+ # Flask stuff:
72
+ instance/
73
+ .webassets-cache
74
+
75
+ # Scrapy stuff:
76
+ .scrapy
77
+
78
+ # Sphinx documentation
79
+ docs/_build/
80
+
81
+ # PyBuilder
82
+ target/
83
+
84
+ # Jupyter Notebook
85
+ .ipynb_checkpoints
86
+
87
+ # IPython
88
+ profile_default/
89
+ ipython_config.py
90
+
91
+ # pyenv
92
+ .python-version
93
+
94
+ # pipenv
95
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
97
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
98
+ # install all needed dependencies.
99
+ #Pipfile.lock
100
+
101
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
102
+ __pypackages__/
103
+
104
+ # Celery stuff
105
+ celerybeat-schedule
106
+ celerybeat.pid
107
+
108
+ # SageMath parsed files
109
+ *.sage.py
110
+
111
+ # Environments
112
+ .env
113
+ .venv
114
+ env/
115
+ venv/
116
+ ENV/
117
+ env.bak/
118
+ venv.bak/
119
+
120
+ # Pycharm project settings
121
+ .idea
122
+
123
+ # Spyder project settings
124
+ .spyderproject
125
+ .spyproject
126
+
127
+ # Rope project settings
128
+ .ropeproject
129
+
130
+ # mkdocs documentation
131
+ /site
132
+
133
+ # mypy
134
+ .mypy_cache/
135
+ .dmypy.json
136
+ dmypy.json
137
+
138
+ *.out
139
+ src/wandb
140
+ wandb
141
+
142
+ # Pyre type checker
143
+ .pyre/
144
+
145
+ # Training
146
+ batchscript*
147
+ work_dirs
148
+ data
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
LICENSE ADDED
@@ -0,0 +1,204 @@
1
+ Copyright 2018-2023 OpenMMLab. All rights reserved.
2
+
3
+ Apache License
4
+ Version 2.0, January 2004
5
+ http://www.apache.org/licenses/
6
+
7
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8
+
9
+ 1. Definitions.
10
+
11
+ "License" shall mean the terms and conditions for use, reproduction,
12
+ and distribution as defined by Sections 1 through 9 of this document.
13
+
14
+ "Licensor" shall mean the copyright owner or entity authorized by
15
+ the copyright owner that is granting the License.
16
+
17
+ "Legal Entity" shall mean the union of the acting entity and all
18
+ other entities that control, are controlled by, or are under common
19
+ control with that entity. For the purposes of this definition,
20
+ "control" means (i) the power, direct or indirect, to cause the
21
+ direction or management of such entity, whether by contract or
22
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
23
+ outstanding shares, or (iii) beneficial ownership of such entity.
24
+
25
+ "You" (or "Your") shall mean an individual or Legal Entity
26
+ exercising permissions granted by this License.
27
+
28
+ "Source" form shall mean the preferred form for making modifications,
29
+ including but not limited to software source code, documentation
30
+ source, and configuration files.
31
+
32
+ "Object" form shall mean any form resulting from mechanical
33
+ transformation or translation of a Source form, including but
34
+ not limited to compiled object code, generated documentation,
35
+ and conversions to other media types.
36
+
37
+ "Work" shall mean the work of authorship, whether in Source or
38
+ Object form, made available under the License, as indicated by a
39
+ copyright notice that is included in or attached to the work
40
+ (an example is provided in the Appendix below).
41
+
42
+ "Derivative Works" shall mean any work, whether in Source or Object
43
+ form, that is based on (or derived from) the Work and for which the
44
+ editorial revisions, annotations, elaborations, or other modifications
45
+ represent, as a whole, an original work of authorship. For the purposes
46
+ of this License, Derivative Works shall not include works that remain
47
+ separable from, or merely link (or bind by name) to the interfaces of,
48
+ the Work and Derivative Works thereof.
49
+
50
+ "Contribution" shall mean any work of authorship, including
51
+ the original version of the Work and any modifications or additions
52
+ to that Work or Derivative Works thereof, that is intentionally
53
+ submitted to Licensor for inclusion in the Work by the copyright owner
54
+ or by an individual or Legal Entity authorized to submit on behalf of
55
+ the copyright owner. For the purposes of this definition, "submitted"
56
+ means any form of electronic, verbal, or written communication sent
57
+ to the Licensor or its representatives, including but not limited to
58
+ communication on electronic mailing lists, source code control systems,
59
+ and issue tracking systems that are managed by, or on behalf of, the
60
+ Licensor for the purpose of discussing and improving the Work, but
61
+ excluding communication that is conspicuously marked or otherwise
62
+ designated in writing by the copyright owner as "Not a Contribution."
63
+
64
+ "Contributor" shall mean Licensor and any individual or Legal Entity
65
+ on behalf of whom a Contribution has been received by Licensor and
66
+ subsequently incorporated within the Work.
67
+
68
+ 2. Grant of Copyright License. Subject to the terms and conditions of
69
+ this License, each Contributor hereby grants to You a perpetual,
70
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71
+ copyright license to reproduce, prepare Derivative Works of,
72
+ publicly display, publicly perform, sublicense, and distribute the
73
+ Work and such Derivative Works in Source or Object form.
74
+
75
+ 3. Grant of Patent License. Subject to the terms and conditions of
76
+ this License, each Contributor hereby grants to You a perpetual,
77
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78
+ (except as stated in this section) patent license to make, have made,
79
+ use, offer to sell, sell, import, and otherwise transfer the Work,
80
+ where such license applies only to those patent claims licensable
81
+ by such Contributor that are necessarily infringed by their
82
+ Contribution(s) alone or by combination of their Contribution(s)
83
+ with the Work to which such Contribution(s) was submitted. If You
84
+ institute patent litigation against any entity (including a
85
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
86
+ or a Contribution incorporated within the Work constitutes direct
87
+ or contributory patent infringement, then any patent licenses
88
+ granted to You under this License for that Work shall terminate
89
+ as of the date such litigation is filed.
90
+
91
+ 4. Redistribution. You may reproduce and distribute copies of the
92
+ Work or Derivative Works thereof in any medium, with or without
93
+ modifications, and in Source or Object form, provided that You
94
+ meet the following conditions:
95
+
96
+ (a) You must give any other recipients of the Work or
97
+ Derivative Works a copy of this License; and
98
+
99
+ (b) You must cause any modified files to carry prominent notices
100
+ stating that You changed the files; and
101
+
102
+ (c) You must retain, in the Source form of any Derivative Works
103
+ that You distribute, all copyright, patent, trademark, and
104
+ attribution notices from the Source form of the Work,
105
+ excluding those notices that do not pertain to any part of
106
+ the Derivative Works; and
107
+
108
+ (d) If the Work includes a "NOTICE" text file as part of its
109
+ distribution, then any Derivative Works that You distribute must
110
+ include a readable copy of the attribution notices contained
111
+ within such NOTICE file, excluding those notices that do not
112
+ pertain to any part of the Derivative Works, in at least one
113
+ of the following places: within a NOTICE text file distributed
114
+ as part of the Derivative Works; within the Source form or
115
+ documentation, if provided along with the Derivative Works; or,
116
+ within a display generated by the Derivative Works, if and
117
+ wherever such third-party notices normally appear. The contents
118
+ of the NOTICE file are for informational purposes only and
119
+ do not modify the License. You may add Your own attribution
120
+ notices within Derivative Works that You distribute, alongside
121
+ or as an addendum to the NOTICE text from the Work, provided
122
+ that such additional attribution notices cannot be construed
123
+ as modifying the License.
124
+
125
+ You may add Your own copyright statement to Your modifications and
126
+ may provide additional or different license terms and conditions
127
+ for use, reproduction, or distribution of Your modifications, or
128
+ for any such Derivative Works as a whole, provided Your use,
129
+ reproduction, and distribution of the Work otherwise complies with
130
+ the conditions stated in this License.
131
+
132
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
133
+ any Contribution intentionally submitted for inclusion in the Work
134
+ by You to the Licensor shall be under the terms and conditions of
135
+ this License, without any additional terms or conditions.
136
+ Notwithstanding the above, nothing herein shall supersede or modify
137
+ the terms of any separate license agreement you may have executed
138
+ with Licensor regarding such Contributions.
139
+
140
+ 6. Trademarks. This License does not grant permission to use the trade
141
+ names, trademarks, service marks, or product names of the Licensor,
142
+ except as required for reasonable and customary use in describing the
143
+ origin of the Work and reproducing the content of the NOTICE file.
144
+
145
+ 7. Disclaimer of Warranty. Unless required by applicable law or
146
+ agreed to in writing, Licensor provides the Work (and each
147
+ Contributor provides its Contributions) on an "AS IS" BASIS,
148
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149
+ implied, including, without limitation, any warranties or conditions
150
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151
+ PARTICULAR PURPOSE. You are solely responsible for determining the
152
+ appropriateness of using or redistributing the Work and assume any
153
+ risks associated with Your exercise of permissions under this License.
154
+
155
+ 8. Limitation of Liability. In no event and under no legal theory,
156
+ whether in tort (including negligence), contract, or otherwise,
157
+ unless required by applicable law (such as deliberate and grossly
158
+ negligent acts) or agreed to in writing, shall any Contributor be
159
+ liable to You for damages, including any direct, indirect, special,
160
+ incidental, or consequential damages of any character arising as a
161
+ result of this License or out of the use or inability to use the
162
+ Work (including but not limited to damages for loss of goodwill,
163
+ work stoppage, computer failure or malfunction, or any and all
164
+ other commercial damages or losses), even if such Contributor
165
+ has been advised of the possibility of such damages.
166
+
167
+ 9. Accepting Warranty or Additional Liability. While redistributing
168
+ the Work or Derivative Works thereof, You may choose to offer,
169
+ and charge a fee for, acceptance of support, warranty, indemnity,
170
+ or other liability obligations and/or rights consistent with this
171
+ License. However, in accepting such obligations, You may act only
172
+ on Your own behalf and on Your sole responsibility, not on behalf
173
+ of any other Contributor, and only if You agree to indemnify,
174
+ defend, and hold each Contributor harmless for any liability
175
+ incurred by, or claims asserted against, such Contributor by reason
176
+ of your accepting any such warranty or additional liability.
177
+
178
+ END OF TERMS AND CONDITIONS
179
+
180
+ APPENDIX: How to apply the Apache License to your work.
181
+
182
+ To apply the Apache License to your work, attach the following
183
+ boilerplate notice, with the fields enclosed by brackets "[]"
184
+ replaced with your own identifying information. (Don't include
185
+ the brackets!) The text should be enclosed in the appropriate
186
+ comment syntax for the file format. We also recommend that a
187
+ file or class name and description of purpose be included on the
188
+ same "printed page" as the copyright notice for easier
189
+ identification within third-party archives.
190
+
191
+ Copyright 2018-2023 OpenMMLab.
192
+
193
+ Licensed under the Apache License, Version 2.0 (the "License");
194
+ you may not use this file except in compliance with the License.
195
+ You may obtain a copy of the License at
196
+
197
+ http://www.apache.org/licenses/LICENSE-2.0
198
+
199
+ Unless required by applicable law or agreed to in writing, software
200
+ distributed under the License is distributed on an "AS IS" BASIS,
201
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202
+ See the License for the specific language governing permissions and
203
+ limitations under the License.
204
+
README.md CHANGED
@@ -1,12 +1,196 @@
1
- ---
2
- title: Multimodal GPT
3
- emoji: 🐠
4
- colorFrom: indigo
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 5.21.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ ---
2
+ title: Multimodal-GPT
3
+ app_file: app.py
4
+ sdk: gradio
5
+ sdk_version: 5.21.0
6
+ ---
7
+ # 🤖 Multi-modal GPT
8
+
9
+ Train a multi-modal chatbot with visual and language instructions!
10
+
11
+ Based on the open-source multi-modal model [OpenFlamingo](https://github.com/mlfoundations/open_flamingo), we create various **visual instruction** data from open datasets, including VQA, Image Captioning, Visual Reasoning, Text OCR, and Visual Dialogue. We also train the language model component of OpenFlamingo on **language-only instruction** data.
12
+
13
+ The **joint training** of visual and language instructions effectively improves the model's performance! For more details, please refer to our [technical report](https://arxiv.org/abs/2305.04790).
14
+
15
+ You are welcome to join us!
16
+
17
+ </div>
18
+
19
+ <div align="center">
20
+
21
+ English | [简体中文](README_zh-CN.md)
22
+
23
+ </div>
24
+
25
+ <div align="center">
26
+ <a href="https://openmmlab.medium.com/" style="text-decoration:none;">
27
+ <img src="https://user-images.githubusercontent.com/25839884/219255827-67c1a27f-f8c5-46a9-811d-5e57448c61d1.png" width="3%" alt="" /></a>
28
+ <img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
29
+ <a href="https://discord.com/channels/1037617289144569886/1046608014234370059" style="text-decoration:none;">
30
+ <img src="https://user-images.githubusercontent.com/25839884/218347213-c080267f-cbb6-443e-8532-8e1ed9a58ea9.png" width="3%" alt="" /></a>
31
+ <img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
32
+ <a href="https://twitter.com/OpenMMLab" style="text-decoration:none;">
33
+ <img src="https://user-images.githubusercontent.com/25839884/218346637-d30c8a0f-3eba-4699-8131-512fb06d46db.png" width="3%" alt="" /></a>
34
+ <img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
35
+ <a href="https://www.youtube.com/openmmlab" style="text-decoration:none;">
36
+ <img src="https://user-images.githubusercontent.com/25839884/218346691-ceb2116a-465a-40af-8424-9f30d2348ca9.png" width="3%" alt="" /></a>
37
+ <img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
38
+ <a href="https://space.bilibili.com/1293512903" style="text-decoration:none;">
39
+ <img src="https://user-images.githubusercontent.com/25839884/219026751-d7d14cce-a7c9-4e82-9942-8375fca65b99.png" width="3%" alt="" /></a>
40
+ <img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
41
+ <a href="https://www.zhihu.com/people/openmmlab" style="text-decoration:none;">
42
+ <img src="https://user-images.githubusercontent.com/25839884/219026120-ba71e48b-6e94-4bd4-b4e9-b7d175b5e362.png" width="3%" alt="" /></a>
43
+ </div>
44
+
45
+ ## Features
46
+
47
+ - Supports various vision and language instruction data
48
+ - Parameter-efficient fine-tuning with LoRA
49
+ - Tunes vision and language at the same time, so they complement each other
50
+
51
+
52
+ ## Installation
53
+
54
+ To install the package in an existing environment, run
55
+
56
+ ```bash
57
+ git clone https://github.com/open-mmlab/Multimodal-GPT.git
58
+ cd Multimodal-GPT
59
+ pip install -r requirements.txt
60
+ pip install -v -e .
61
+ ```
62
+
63
+ or create a new conda environment
64
+
65
+ ```bash
66
+ conda env create -f environment.yml
67
+ ```
68
+
69
+
70
+ ## Launch Demo Locally
71
+
72
+ 1. Download the pre-trained weights.
73
+
74
+ Use [this script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py) to convert the LLaMA weights to the Hugging Face format.
75
+
76
+ Download the OpenFlamingo pre-trained model from [openflamingo/OpenFlamingo-9B](https://huggingface.co/openflamingo/OpenFlamingo-9B).
77
+
78
+ Download our LoRA weights from [here](https://download.openmmlab.com/mmgpt/v0/mmgpt-lora-v0-release.pt).
79
+
80
+ Then place these models in the `checkpoints` folder like this:
81
+
82
+ ```
83
+ checkpoints
84
+ ├── llama-7b_hf
85
+ │ ├── config.json
86
+ │ ├── pytorch_model-00001-of-00002.bin
87
+ │ ├── ......
88
+ │ └── tokenizer.model
89
+ ├── OpenFlamingo-9B
90
+ │   └── checkpoint.pt
91
+ └── mmgpt-lora-v0-release.pt
+ ```
92
+
93
+ 2. Launch the Gradio demo:
94
+
95
+ ```bash
96
+ python app.py
97
+ ```
98
+
99
+ ## Examples
100
+
101
+ ### Recipe:
102
+ ![image4](https://user-images.githubusercontent.com/12907710/234554562-8f3be88f-d563-47ba-97d9-ade8d47c46b0.png)
103
+
104
+ ### Travel plan:
105
+ ![image3](https://user-images.githubusercontent.com/12907710/234523464-80c4e3f0-f99f-4498-96ef-dc43ef89c64b.png)
106
+
107
+ ### Movie:
108
+ ![image2](https://user-images.githubusercontent.com/12907710/234523468-e11905a6-491f-4b87-934f-90da7d14d1c3.png)
109
+
110
+ ### Famous person:
111
+ ![image](https://user-images.githubusercontent.com/12907710/234523475-fd91f979-a344-4228-813f-6b55a1bc250f.png)
112
+
113
+
114
+ ## Fine-tuning
115
+
116
+ ### Prepare datasets
117
+
118
+ 1. [A-OKVQA](https://allenai.org/project/a-okvqa/home)
119
+
120
+ Download the annotations from [this link](https://prior-datasets.s3.us-east-2.amazonaws.com/aokvqa/aokvqa_v1p0.tar.gz) and unzip them to `data/aokvqa/annotations`.
121
+
122
+ It also requires images from the COCO dataset, which can be downloaded from [here](https://cocodataset.org/#home).
123
+
124
+ 2. [COCO Caption](https://cs.stanford.edu/people/karpathy/deepimagesent/)
125
+
126
+ Download from [this link](https://cs.stanford.edu/people/karpathy/deepimagesent/coco.zip) and unzip to `data/coco`.
127
+
128
+ It also requires images from the COCO dataset, which can be downloaded from [here](https://cocodataset.org/#home).
129
+
130
+ 3. [OCR VQA](https://ocr-vqa.github.io/)
131
+
132
+ Download from [this link](https://drive.google.com/drive/folders/1_GYPY5UkUy7HIcR0zq3ZCFgeZN7BAfm_?usp=sharing) and place in `data/OCR_VQA/`.
133
+
134
+ 4. [LLaVA](https://llava-vl.github.io/)
135
+
136
+ Download from [liuhaotian/LLaVA-Instruct-150K](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K) and place in `data/llava/`.
137
+
138
+ It also requires images from the COCO dataset, which can be downloaded from [here](https://cocodataset.org/#home).
139
+
140
+ 5. [Mini-GPT4](https://minigpt-4.github.io/)
141
+
142
+ Download from [Vision-CAIR/cc_sbu_align](https://huggingface.co/datasets/Vision-CAIR/cc_sbu_align) and place in `data/cc_sbu_align/`.
143
+
144
+ 6. [Dolly 15k](https://www.databricks.com/blog/2023/03/24/hello-dolly-democratizing-magic-chatgpt-open-models.html)
145
+
146
+ Download from [databricks/databricks-dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k) and place it in `data/dolly/databricks-dolly-15k.jsonl`.
147
+
148
+ 7. [Alpaca GPT4](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
149
+
150
+ Download it from [this link](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM/raw/main/data/alpaca_gpt4_data.json) and place it in `data/alpaca_gpt4/alpaca_gpt4_data.json`.
151
+
152
+ You can also customize the data paths in [configs/dataset_config.py](configs/dataset_config.py).
153
+
154
+ 8. [Baize](https://github.com/project-baize/baize-chatbot)
155
+
156
+ Download it from [this link](https://github.com/project-baize/baize-chatbot/blob/main/data/quora_chat_data.json) and place it in `data/baize/quora_chat_data.json`.
157
+
158
+
159
+ ## Start training
160
+
161
+ ```bash
162
+ torchrun --nproc_per_node=8 mmgpt/train/instruction_finetune.py \
163
+ --lm_path checkpoints/llama-7b_hf \
164
+ --tokenizer_path checkpoints/llama-7b_hf \
165
+ --pretrained_path checkpoints/OpenFlamingo-9B/checkpoint.pt \
166
+ --run_name train-my-gpt4 \
167
+ --learning_rate 1e-5 \
168
+ --lr_scheduler cosine \
169
+ --batch_size 1 \
170
+ --tuning_config configs/lora_config.py \
171
+ --dataset_config configs/dataset_config.py \
172
+ --report_to_wandb
173
+ ```
174
+
175
+
176
+ ## Acknowledgements
177
+
178
+ - [OpenFlamingo](https://github.com/mlfoundations/open_flamingo)
179
+ - [LAVIS](https://github.com/salesforce/LAVIS)
180
+ - [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca)
181
+ - [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4)
182
+ - [LLaVA](https://github.com/haotian-liu/LLaVA/tree/main)
183
+ - [Instruction Tuning with GPT-4](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
184
+
185
+ If you find our project useful for your research and applications, please cite using this BibTeX:
186
+
187
+ ```bibtex
188
+ @misc{gong2023multimodalgpt,
189
+ title={MultiModal-GPT: A Vision and Language Model for Dialogue with Humans},
190
+ author={Tao Gong and Chengqi Lyu and Shilong Zhang and Yudong Wang and Miao Zheng and Qian Zhao and Kuikun Liu and Wenwei Zhang and Ping Luo and Kai Chen},
191
+ year={2023},
192
+ eprint={2305.04790},
193
+ archivePrefix={arXiv},
194
+ primaryClass={cs.CV}
195
+ }
196
+ ```
README_zh-CN.md ADDED
@@ -0,0 +1,184 @@
1
+ # 🤖 Multi-modal GPT
2
+
3
+ 使用视觉和语言指令训练一个多模态聊天机器人!
4
+
5
+ 基于开源多模态模型 [OpenFlamingo](https://github.com/mlfoundations/open_flamingo),我们使用公开数据集创建了各种**视觉指令**数据,包括视觉问答、图像字幕、视觉推理、文本 OCR 和视觉对话。此外,我们还使用仅包含**语言指令**数据的语言模型组件进行了训练。
6
+
7
+ 视觉和语言指令的**联合训练**有效提高了模型的性能!更多细节请参阅我们的[技术报告](https://arxiv.org/abs/2305.04790)。
8
+
9
+ 欢迎加入我们!
10
+
11
+ </div>
12
+
13
+ <div align="center">
14
+
15
+ [English](README.md) | 简体中文
16
+
17
+ </div>
18
+
19
+ <div align="center">
20
+ <a href="https://openmmlab.medium.com/" style="text-decoration:none;">
21
+ <img src="https://user-images.githubusercontent.com/25839884/219255827-67c1a27f-f8c5-46a9-811d-5e57448c61d1.png" width="3%" alt="" /></a>
22
+ <img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
23
+ <a href="https://discord.com/channels/1037617289144569886/1046608014234370059" style="text-decoration:none;">
24
+ <img src="https://user-images.githubusercontent.com/25839884/218347213-c080267f-cbb6-443e-8532-8e1ed9a58ea9.png" width="3%" alt="" /></a>
25
+ <img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
26
+ <a href="https://twitter.com/OpenMMLab" style="text-decoration:none;">
27
+ <img src="https://user-images.githubusercontent.com/25839884/218346637-d30c8a0f-3eba-4699-8131-512fb06d46db.png" width="3%" alt="" /></a>
28
+ <img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
29
+ <a href="https://www.youtube.com/openmmlab" style="text-decoration:none;">
30
+ <img src="https://user-images.githubusercontent.com/25839884/218346691-ceb2116a-465a-40af-8424-9f30d2348ca9.png" width="3%" alt="" /></a>
31
+ <img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
32
+ <a href="https://space.bilibili.com/1293512903" style="text-decoration:none;">
33
+ <img src="https://user-images.githubusercontent.com/25839884/219026751-d7d14cce-a7c9-4e82-9942-8375fca65b99.png" width="3%" alt="" /></a>
34
+ <img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
35
+ <a href="https://www.zhihu.com/people/openmmlab" style="text-decoration:none;">
36
+ <img src="https://user-images.githubusercontent.com/25839884/219026120-ba71e48b-6e94-4bd4-b4e9-b7d175b5e362.png" width="3%" alt="" /></a>
37
+ </div>
38
+
39
+ ## 特性
40
+
41
+ - 支持各种视觉和语言指令数据
42
+ - 使用 LoRA 进行参数高效微调
43
+ - 同时调整视觉和语言,相互补充
44
+
45
+ ## 安装
46
+
47
+ 在一个已有环境中安装依赖包,运行以下指令
48
+
49
+ ```bash
50
+ git clone https://github.com/open-mmlab/Multimodal-GPT.git
51
+ cd Multimodal-GPT
52
+ pip install -r requirements.txt
53
+ pip install -v -e .
54
+ ```
55
+
56
+ 或者创建一个新的 conda 环境
57
+
58
+ ```bash
59
+ conda env create -f environment.yml
60
+ ```
61
+
62
+ ## Demo
63
+
64
+ 1. 下载预训练权重
65
+
66
+ 使用[这个脚本](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py)把 LLaMA 权重转换成 HuggingFace 格式。
67
+
68
+ 从 [openflamingo/OpenFlamingo-9B](https://huggingface.co/openflamingo/OpenFlamingo-9B) 下载 OpenFlamingo 预训练模型。
69
+
70
+ 从[这个链接](https://download.openmmlab.com/mmgpt/v0/mmgpt-lora-v0-release.pt) 下载我们的 LoRA 权重。
71
+
72
+ 然后把所有模型权重放到 `checkpoints` 文件夹下,目录结构如下:
73
+
74
+ ```
75
+ checkpoints
76
+ ├── llama-7b_hf
77
+ │ ├── config.json
78
+ │ ├── pytorch_model-00001-of-00002.bin
79
+ │ ├── ......
80
+ │ └── tokenizer.model
81
+ ├── OpenFlamingo-9B
82
+ │   └── checkpoint.pt
83
+ └── mmgpt-lora-v0-release.pt
+ ```
84
+
85
+ 2. 启动 gradio demo
86
+
87
+ ```bash
88
+ python app.py
89
+ ```
90
+
91
+ ## 示例
92
+
93
+ ### 菜单:
94
+ ![image4](https://user-images.githubusercontent.com/12907710/234554562-8f3be88f-d563-47ba-97d9-ade8d47c46b0.png)
95
+
96
+ ### 旅行计划:
97
+ ![image3](https://user-images.githubusercontent.com/12907710/234523464-80c4e3f0-f99f-4498-96ef-dc43ef89c64b.png)
98
+
99
+ ### 电影:
100
+ ![image2](https://user-images.githubusercontent.com/12907710/234523468-e11905a6-491f-4b87-934f-90da7d14d1c3.png)
101
+
102
+ ### 名人:
103
+ ![image](https://user-images.githubusercontent.com/12907710/234523475-fd91f979-a344-4228-813f-6b55a1bc250f.png)
104
+
105
+
106
+ ## 微调 Fine-tuning
107
+
108
+ ### 准备数据集
109
+
110
+ 1. [A-OKVQA](https://allenai.org/project/a-okvqa/home)
111
+
112
+ 从[这个链接](https://prior-datasets.s3.us-east-2.amazonaws.com/aokvqa/aokvqa_v1p0.tar.gz)下载标注,解压到 `data/aokvqa/annotations` 路径下。
113
+
114
+ 同时还需要 coco 数据集的图像,可以从[这里](https://cocodataset.org/#home)下载。
115
+
116
+ 2. [COCO Caption](https://cs.stanford.edu/people/karpathy/deepimagesent/)
117
+
118
+ 从[这个链接](https://cs.stanford.edu/people/karpathy/deepimagesent/coco.zip),解压到 `data/coco` 路径下。
119
+
120
+ 同时还需要 coco 数据集的图像,可以从[这里](https://cocodataset.org/#home)下载。
121
+
122
+ 3. [OCR VQA](https://ocr-vqa.github.io/)
123
+
124
+ 从 [这个链接](https://drive.google.com/drive/folders/1_GYPY5UkUy7HIcR0zq3ZCFgeZN7BAfm_?usp=sharing) 下载数据集,放到 `data/OCR_VQA/` 路径下。
125
+
126
+ 4. [LlaVA](https://llava-vl.github.io/)
127
+
128
+ 从 [liuhaotian/LLaVA-Instruct-150K](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K) 下载数据集,放到 `data/llava/` 路径下。
129
+
130
+ 同时还需要 coco 数据集的图像,可以从[这里](https://cocodataset.org/#home)下载。
131
+
132
+ 5. [Mini-GPT4](https://minigpt-4.github.io/)
133
+
134
+ 从 [Vision-CAIR/cc_sbu_align](https://huggingface.co/datasets/Vision-CAIR/cc_sbu_align) 下载数据集,放到 `data/cc_sbu_align/` 路径下。
135
+
136
+ 6. [Dolly 15k](https://www.databricks.com/blog/2023/03/24/hello-dolly-democratizing-magic-chatgpt-open-models.html)
137
+
138
+ 从 [databricks/databricks-dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k) 下载数据集,放到 `data/dolly/databricks-dolly-15k.jsonl` 路径下。
139
+
140
+ 7. [Alpaca GPT4](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
141
+
142
+ 从[这个链接](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM/raw/main/data/alpaca_gpt4_data.json) 下载数据集,放到 `data/alpaca_gpt4/alpaca_gpt4_data.json` 路径下。
143
+
144
+ 你也可以在 [configs/dataset_config.py](configs/dataset_config.py) 文件中自定义数据集路径。
145
+
146
+
147
+ ## 开启训练
148
+
149
+ ```bash
150
+ torchrun --nproc_per_node=8 mmgpt/train/instruction_finetune.py \
151
+ --lm_path checkpoints/llama-7b_hf \
152
+ --tokenizer_path checkpoints/llama-7b_hf \
153
+ --pretrained_path checkpoints/OpenFlamingo-9B/checkpoint.pt \
154
+ --run_name train-my-gpt4 \
155
+ --learning_rate 1e-5 \
156
+ --lr_scheduler cosine \
157
+ --batch_size 1 \
158
+ --tuning_config configs/lora_config.py \
159
+ --dataset_config configs/dataset_config.py \
160
+ --report_to_wandb
161
+ ```
162
+
163
+
164
+ ## 致谢
165
+
166
+ - [OpenFlamingo](https://github.com/mlfoundations/open_flamingo)
167
+ - [LAVIS](https://github.com/salesforce/LAVIS)
168
+ - [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca)
169
+ - [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4)
170
+ - [LLaVA](https://github.com/haotian-liu/LLaVA/tree/main)
171
+ - [Instruction Tuning with GPT-4](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
172
+
173
+ 如果你觉得我们的项目对你的研究和应用有帮助,请用以下 BibTeX 进行引用
174
+
175
+ ```bibtex
176
+ @misc{gong2023multimodalgpt,
177
+ title={MultiModal-GPT: A Vision and Language Model for Dialogue with Humans},
178
+ author={Tao Gong and Chengqi Lyu and Shilong Zhang and Yudong Wang and Miao Zheng and Qian Zhao and Kuikun Liu and Wenwei Zhang and Ping Luo and Kai Chen},
179
+ year={2023},
180
+ eprint={2305.04790},
181
+ archivePrefix={arXiv},
182
+ primaryClass={cs.CV}
183
+ }
184
+ ```
app.py ADDED
@@ -0,0 +1,379 @@
1
+ import os
2
+
3
+ import gradio as gr
4
+ import torch
5
+ from PIL import Image
6
+
7
+ from mmgpt.models.builder import create_model_and_transforms
8
+
9
+ TEMPLATE = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
10
+ response_split = "### Response:"
11
+
12
+
13
+ class Inferencer:
14
+
15
+ def __init__(self, finetune_path, llama_path, open_flamingo_path):
16
+ ckpt = torch.load(finetune_path, map_location="cpu", weights_only=False)
17
+ if "model_state_dict" in ckpt:
18
+ state_dict = ckpt["model_state_dict"]
19
+ # remove the "module." prefix
20
+ state_dict = {
21
+ k[7:]: v
22
+ for k, v in state_dict.items() if k.startswith("module.")
23
+ }
24
+ else:
25
+ state_dict = ckpt
26
+ tuning_config = ckpt.get("tuning_config")
27
+ if tuning_config is None:
28
+ print("tuning_config not found in checkpoint")
29
+ else:
30
+ print("tuning_config found in checkpoint: ", tuning_config)
31
+ model, image_processor, tokenizer = create_model_and_transforms(
32
+ model_name="open_flamingo",
33
+ clip_vision_encoder_path="ViT-L-14",
34
+ clip_vision_encoder_pretrained="openai",
35
+ lang_encoder_path=llama_path,
36
+ tokenizer_path=llama_path,
37
+ pretrained_model_path=open_flamingo_path,
38
+ tuning_config=tuning_config,
39
+ )
40
+ model.load_state_dict(state_dict, strict=False)
41
+ model.half()
42
+
43
+
44
+ device = torch.device("cpu")
45
+ model = model.to(device)
46
+
47
+ # model = model.to("cuda")
48
+ model.eval()
49
+ tokenizer.padding_side = "left"
50
+ tokenizer.add_eos_token = False
51
+ self.model = model
52
+ self.image_processor = image_processor
53
+ self.tokenizer = tokenizer
54
+
55
+ def __call__(self, prompt, imgpaths, max_new_token, num_beams, temperature,
56
+ top_k, top_p, do_sample):
57
+ device = torch.device("cpu")
58
+ if len(imgpaths) > 1:
59
+ raise gr.Error(
60
+ "Current only support one image, please clear gallery and upload one image"
61
+ )
62
+ lang_x = self.tokenizer([prompt], return_tensors="pt")
63
+ if len(imgpaths) == 0 or imgpaths is None:
64
+ for layer in self.model.lang_encoder._get_decoder_layers():
65
+ layer.condition_only_lang_x(True)
66
+ output_ids = self.model.lang_encoder.generate(
67
+ input_ids=lang_x["input_ids"].to(device),
68
+ attention_mask=lang_x["attention_mask"].to(device),
69
+ max_new_tokens=max_new_token,
70
+ num_beams=num_beams,
71
+ temperature=temperature,
72
+ top_k=top_k,
73
+ top_p=top_p,
74
+ do_sample=do_sample,
75
+ )[0]
76
+ for layer in self.model.lang_encoder._get_decoder_layers():
77
+ layer.condition_only_lang_x(False)
78
+ else:
79
+ images = (Image.open(fp) for fp in imgpaths)
80
+ vision_x = [self.image_processor(im).unsqueeze(0) for im in images]
81
+ vision_x = torch.cat(vision_x, dim=0)
82
+ vision_x = vision_x.unsqueeze(1).unsqueeze(0).half()
83
+
84
+ output_ids = self.model.generate(
85
+ vision_x=vision_x.to(device),
86
+ lang_x=lang_x["input_ids"].to(device),
87
+ attention_mask=lang_x["attention_mask"].to(device),
88
+ max_new_tokens=max_new_token,
89
+ num_beams=num_beams,
90
+ temperature=temperature,
91
+ top_k=top_k,
92
+ top_p=top_p,
93
+ do_sample=do_sample,
94
+ )[0]
95
+ generated_text = self.tokenizer.decode(
96
+ output_ids, skip_special_tokens=True)
97
+ # print(generated_text)
98
+ result = generated_text.split(response_split)[-1].strip()
99
+ return result
100
+
101
+
102
+ class PromptGenerator:
103
+
104
+ def __init__(
105
+ self,
106
+ prompt_template=TEMPLATE,
107
+ ai_prefix="Response",
108
+ user_prefix="Instruction",
109
+ sep: str = "\n\n### ",
110
+ buffer_size=0,
111
+ ):
112
+ self.all_history = list()
113
+ self.ai_prefix = ai_prefix
114
+ self.user_prefix = user_prefix
115
+ self.buffer_size = buffer_size
116
+ self.prompt_template = prompt_template
117
+ self.sep = sep
118
+
119
+ def add_message(self, role, message):
120
+ self.all_history.append([role, message])
121
+
122
+ def get_images(self):
123
+ img_list = list()
124
+ if self.buffer_size > 0:
125
+ all_history = self.all_history[-2 * (self.buffer_size + 1):]
126
+ elif self.buffer_size == 0:
127
+ all_history = self.all_history[-2:]
128
+ else:
129
+ all_history = self.all_history[:]
130
+ for his in all_history:
131
+ if type(his[-1]) == tuple:
132
+ img_list.append(his[-1][-1])
133
+ return img_list
134
+
135
+ def get_prompt(self):
136
+ format_dict = dict()
137
+ if "{user_prefix}" in self.prompt_template:
138
+ format_dict["user_prefix"] = self.user_prefix
139
+ if "{ai_prefix}" in self.prompt_template:
140
+ format_dict["ai_prefix"] = self.ai_prefix
141
+ prompt_template = self.prompt_template.format(**format_dict)
142
+ ret = prompt_template
143
+ if self.buffer_size > 0:
144
+ all_history = self.all_history[-2 * (self.buffer_size + 1):]
145
+ elif self.buffer_size == 0:
146
+ all_history = self.all_history[-2:]
147
+ else:
148
+ all_history = self.all_history[:]
149
+ context = []
150
+ have_image = False
151
+ for role, message in all_history[::-1]:
152
+ if message:
153
+ if type(message) is tuple and message[
154
+ 1] is not None and not have_image:
155
+ message, _ = message
156
+ context.append(self.sep + "Image:\n<image>" + self.sep +
157
+ role + ":\n" + message)
158
+ else:
159
+ context.append(self.sep + role + ":\n" + message)
160
+ else:
161
+ context.append(self.sep + role + ":\n")
162
+
163
+ ret += "".join(context[::-1])
164
+ return ret
165
+
166
+
167
+ def to_gradio_chatbot(prompt_generator):
168
+ ret = []
169
+ for i, (role, msg) in enumerate(prompt_generator.all_history):
170
+ if i % 2 == 0:
171
+ if type(msg) is tuple:
172
+ import base64
173
+ from io import BytesIO
174
+
175
+ msg, image = msg
176
+ if type(image) is str:
177
+ from PIL import Image
178
+
179
+ image = Image.open(image)
180
+ max_hw, min_hw = max(image.size), min(image.size)
181
+ aspect_ratio = max_hw / min_hw
182
+ max_len, min_len = 800, 400
183
+ shortest_edge = int(
184
+ min(max_len / aspect_ratio, min_len, min_hw))
185
+ longest_edge = int(shortest_edge * aspect_ratio)
186
+ H, W = image.size
187
+ if H > W:
188
+ H, W = longest_edge, shortest_edge
189
+ else:
190
+ H, W = shortest_edge, longest_edge
191
+ image = image.resize((H, W))
192
+ # image = image.resize((224, 224))
193
+ buffered = BytesIO()
194
+ image.save(buffered, format="JPEG")
195
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
196
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
197
+ msg = msg + img_str
198
+ ret.append([msg, None])
199
+ else:
200
+ ret[-1][-1] = msg
201
+ return ret
202
+
203
+
204
+ def bot(
205
+ text,
206
+ image,
207
+ state,
208
+ prompt,
209
+ ai_prefix,
210
+ user_prefix,
211
+ seperator,
212
+ history_buffer,
213
+ max_new_token,
214
+ num_beams,
215
+ temperature,
216
+ top_k,
217
+ top_p,
218
+ do_sample,
219
+ ):
220
+ state.prompt_template = prompt
221
+ state.ai_prefix = ai_prefix
222
+ state.user_prefix = user_prefix
223
+ state.sep = seperator
224
+ state.buffer_size = history_buffer
225
+ if image:
226
+ state.add_message(user_prefix, (text, image))
227
+ else:
228
+ state.add_message(user_prefix, text)
229
+ state.add_message(ai_prefix, None)
230
+ inputs = state.get_prompt()
231
+ image_paths = state.get_images()[-1:]
232
+
233
+ inference_results = inferencer(inputs, image_paths, max_new_token,
234
+ num_beams, temperature, top_k, top_p,
235
+ do_sample)
236
+ state.all_history[-1][-1] = inference_results
237
+ memory_allocated = str(round(torch.cuda.memory_allocated() / 1024**3,
238
+ 2)) + 'GB'
239
+ return state, to_gradio_chatbot(state), "", None, inputs, memory_allocated
240
+
241
+
242
+ def clear(state):
243
+ state.all_history = []
244
+ return state, to_gradio_chatbot(state), "", None, ""
245
+
246
+
247
+ title_markdown = ("""
248
+ # 🤖 Multi-modal GPT
249
+ [[Project]](https://github.com/open-mmlab/Multimodal-GPT.git)""")
250
+
251
+
252
+ def build_conversation_demo():
253
+ with gr.Blocks(title="Multi-modal GPT") as demo:
254
+ gr.Markdown(title_markdown)
255
+
256
+ state = gr.State(PromptGenerator())
257
+ with gr.Row():
258
+ with gr.Column(scale=3):
259
+ memory_allocated = gr.Textbox(
260
+ value=init_memory, label="Memory")
261
+ imagebox = gr.Image(type="filepath")
262
+ # TODO config parameters
263
+ with gr.Accordion(
264
+ "Parameters",
265
+ open=True,
266
+ ):
267
+ max_new_token_bar = gr.Slider(
268
+ 0, 1024, 512, label="max_new_token", step=1)
269
+ num_beams_bar = gr.Slider(
270
+ 0.0, 10, 3, label="num_beams", step=1)
271
+ temperature_bar = gr.Slider(
272
+ 0.0, 1.0, 1.0, label="temperature", step=0.01)
273
+ topk_bar = gr.Slider(0, 100, 20, label="top_k", step=1)
274
+ topp_bar = gr.Slider(0, 1.0, 1.0, label="top_p", step=0.01)
275
+ do_sample = gr.Checkbox(True, label="do_sample")
276
+ with gr.Accordion(
277
+ "Prompt",
278
+ open=False,
279
+ ):
280
+ with gr.Row():
281
+ ai_prefix = gr.Text("Response", label="AI Prefix")
282
+ user_prefix = gr.Text(
283
+ "Instruction", label="User Prefix")
284
+ seperator = gr.Text("\n\n### ", label="Separator")
285
+ history_buffer = gr.Slider(
286
+ -1, 10, -1, label="History buffer", step=1)
287
+ prompt = gr.Text(TEMPLATE, label="Prompt")
288
+ model_inputs = gr.Textbox(label="Actual inputs for Model")
289
+
290
+ with gr.Column(scale=6):
291
+ with gr.Row():
292
+ with gr.Column():
293
+ chatbot = gr.Chatbot(elem_id="chatbot", type="messages")
294
+ with gr.Row():
295
+ with gr.Column(scale=8):
296
+ textbox = gr.Textbox(
297
+ show_label=False,
298
+ placeholder="Enter text and press ENTER",
299
+ container=False)
300
+ submit_btn = gr.Button(value="Submit")
301
+ clear_btn = gr.Button(value="🗑️ Clear history")
302
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
303
+ gr.Examples(
304
+ examples=[
305
+ [
306
+ f"{cur_dir}/docs/images/demo_image.jpg",
307
+ "What is in this image?"
308
+ ],
309
+ ],
310
+ inputs=[imagebox, textbox],
311
+ )
312
+ textbox.submit(
313
+ bot,
314
+ [
315
+ textbox,
316
+ imagebox,
317
+ state,
318
+ prompt,
319
+ ai_prefix,
320
+ user_prefix,
321
+ seperator,
322
+ history_buffer,
323
+ max_new_token_bar,
324
+ num_beams_bar,
325
+ temperature_bar,
326
+ topk_bar,
327
+ topp_bar,
328
+ do_sample,
329
+ ],
330
+ [
331
+ state, chatbot, textbox, imagebox, model_inputs,
332
+ memory_allocated
333
+ ],
334
+ )
335
+ submit_btn.click(
336
+ bot,
337
+ [
338
+ textbox,
339
+ imagebox,
340
+ state,
341
+ prompt,
342
+ ai_prefix,
343
+ user_prefix,
344
+ seperator,
345
+ history_buffer,
346
+ max_new_token_bar,
347
+ num_beams_bar,
348
+ temperature_bar,
349
+ topk_bar,
350
+ topp_bar,
351
+ do_sample,
352
+ ],
353
+ [
354
+ state, chatbot, textbox, imagebox, model_inputs,
355
+ memory_allocated
356
+ ],
357
+ )
358
+ clear_btn.click(clear, [state],
359
+ [state, chatbot, textbox, imagebox, model_inputs])
360
+ return demo
361
+
362
+
363
+ if __name__ == "__main__":
364
+ llama_path = "checkpoints/llama-7b_hf"
365
+ open_flamingo_path = "checkpoints/OpenFlamingo-9B/checkpoint.pt"
366
+ finetune_path = "checkpoints/mmgpt-lora-v0-release.pt"
367
+
368
+ inferencer = Inferencer(
369
+ llama_path=llama_path,
370
+ open_flamingo_path=open_flamingo_path,
371
+ finetune_path=finetune_path)
372
+ init_memory = str(round(torch.cuda.memory_allocated() / 1024**3, 2)) + 'GB'
373
+ demo = build_conversation_demo()
374
+ demo.queue(max_size=3)
375
+ IP = "0.0.0.0"
376
+ PORT = 8997
377
+ demo.launch(server_name=IP, server_port=PORT, share=True)
378
+
379
+
configs/dataset_config.py ADDED
@@ -0,0 +1,64 @@
1
+ visual_datasets = [
2
+ dict(
3
+ type="llava",
4
+ vis_root="data/coco/train2017",
5
+ ann_paths=[
6
+ "data/llava/detail_23k.json",
7
+ "data/llava/complex_reasoning_77k.json",
8
+ ],
9
+ ),
10
+ dict(
11
+ type="llava_dial",
12
+ vis_root="data/coco/train2017",
13
+ ann_paths=[
14
+ "data/llava/conversation_58k.json",
15
+ ],
16
+ ),
17
+ dict(
18
+ type="aokvqa",
19
+ vis_root="data/coco/images",
20
+ ann_paths=[
21
+ "data/aokvqa/annotations/aokvqa_v1p0_train.json",
22
+ ],
23
+ sample=5000,
24
+ ),
25
+ dict(
26
+ type="minigpt4",
27
+ vis_root="data/cc_sbu_align/image",
28
+ ann_paths=[
29
+ "data/cc_sbu_align/filter_cap.json",
30
+ ],
31
+ ),
32
+ dict(
33
+ type="coco_caption",
34
+ vis_root="data/coco",
35
+ ann_paths=[
36
+ "data/coco/annotations/coco_karpathy_train_converted.json",
37
+ "data/coco/annotations/coco_karpathy_val.json",
38
+ ],
39
+ sample=512,
40
+ ),
41
+ dict(
42
+ type="ocr_vqa",
43
+ vis_root="data/OCR_VQA/image",
44
+ ann_paths=[
45
+ "data/OCR_VQA/downloaded_dataset.json",
46
+ ],
47
+ sample=512,
48
+ ),
49
+ ]
50
+
51
+ language_datasets = [
52
+ dict(
53
+ type="dolly",
54
+ ann_path="data/dolly/databricks-dolly-15k.jsonl",
55
+ ),
56
+ dict(
57
+ type="alpaca_gpt4",
58
+ ann_path="data/alpaca_gpt4/alpaca_gpt4_data.json",
59
+ ),
60
+ dict(
61
+ type="baize",
62
+ ann_path="data/baize/quora_chat_data.json",
63
+ ),
64
+ ]
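Since these config lists are plain Python, a custom dataset can be registered by appending another `dict`. The snippet below is a hypothetical sketch: the `type` value must match one of the branches in `mmgpt/datasets/builder.py`, and the `data/my_vqa/...` paths are placeholders rather than files included in this commit.

```python
# Hypothetical extra entry for configs/dataset_config.py (paths are placeholders).
visual_datasets += [
    dict(
        type="vqa",                            # handled by VQADataset in mmgpt/datasets/builder.py
        vis_root="data/my_vqa/images",         # root folder holding the images
        ann_paths=["data/my_vqa/train.json"],  # one or more annotation files
        sample=512,                            # optional: subsample the dataset, as in the entries above
    ),
]
```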
configs/lora_config.py ADDED
@@ -0,0 +1,9 @@
1
+ tuning_config = dict(
2
+ lora=True,
3
+ lora_target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "to_q", "to_kv", "to_out", "ff.1", "ff.3"],
4
+ lora_r=16,
5
+ lora_alpha=16,
6
+ lora_dropout=0.0,
7
+ vis=True,
8
+ unfrozen=[],
9
+ )
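For orientation, the sketch below shows how hyperparameters such as `lora_r` and `lora_target_modules` map onto a generic LoRA wrapper from the `peft` library. It is an illustrative assumption only: how `mmgpt/models/builder.py` actually consumes `tuning_config` is not visible in this commit view.

```python
# Hedged sketch: applying a LoRA config like the one above via the peft library.
# The key mapping (lora_r -> r, etc.) is an assumption for illustration.
from peft import LoraConfig, get_peft_model


def apply_lora(model, cfg: dict):
    lora_cfg = LoraConfig(
        r=cfg["lora_r"],
        lora_alpha=cfg["lora_alpha"],
        lora_dropout=cfg["lora_dropout"],
        target_modules=cfg["lora_target_modules"],
        bias="none",
    )
    # Base weights are frozen; only the injected low-rank adapters remain trainable.
    return get_peft_model(model, lora_cfg)
```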
docs/images/demo_image.jpg ADDED
environment.yml ADDED
@@ -0,0 +1,10 @@
1
+ name: mmgpt
2
+ channels:
3
+ - defaults
4
+ dependencies:
5
+ - python=3.9
6
+ - conda-forge::openjdk
7
+ - pip
8
+ - pip:
9
+ - -r requirements.txt
10
+ - -e .
mmgpt/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .models.builder import create_model_and_transforms
2
+ from .models.open_flamingo import Flamingo
mmgpt/datasets/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .builder import build_dataset # noqa: F401
2
+ from .dial_dataset import DialDataset # noqa: F401
3
+ from .samplers import InfiniteSampler # noqa: F401
4
+ from .vqa_dataset import VQADataset # noqa: F401
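These exports are the entry points for building the training data. As a rough sketch (the actual wiring inside `mmgpt/train/instruction_finetune.py` is not shown in this section, so the call below is an assumption), a list of config dicts goes through `build_dataset`, with the tokenizer and image processor produced by `create_model_and_transforms`:

```python
# Hedged sketch: turning dataset config dicts into a training dataset.
# build_dataset and the dataset classes exist in this commit; how
# instruction_finetune.py invokes them is assumed here.
from mmgpt.datasets import build_dataset


def build_train_dataset(tokenizer, image_processor):
    """`tokenizer` and `image_processor` are the objects returned by
    create_model_and_transforms (see app.py above)."""
    visual_cfgs = [
        dict(
            type="aokvqa",
            vis_root="data/coco/images",
            ann_paths=["data/aokvqa/annotations/aokvqa_v1p0_train.json"],
            sample=5000,
        ),
    ]
    # A list of configs is wrapped in a single ConcatDataset.
    return build_dataset(visual_cfgs, tokenizer=tokenizer, vis_processor=image_processor)
```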
mmgpt/datasets/alpaca_gpt4_dataset.py ADDED
@@ -0,0 +1,26 @@
1
+ import json
2
+
3
+ from mmgpt.datasets.dolly_dataset import DollyDataset
4
+
5
+
6
+ class AlpacaGPT4Dataset(DollyDataset):
7
+ """
8
+ ```json
9
+ [
10
+ {
11
+ "instruction": "Identify the odd one out.",
12
+ "input": "Twitter, Instagram, Telegram",
13
+ "output": "The odd one out is Telegram. Twitter and Instagram are social media platforms mainly for sharing information, images and videos while Telegram is a cloud-based instant messaging and voice-over-IP service."
14
+ },
15
+ ]
16
+ """
17
+
18
+ def load_annotation(self, ann_path):
19
+ self.annotation = json.load(open(ann_path, "r"))
20
+
21
+ def process_text(self, ann):
22
+ instruction = ann["instruction"]
23
+ input = ann["input"]
24
+ output = ann["output"]
25
+ instruction = self.prompter(instruction=instruction, input=input)
26
+ return dict(instruction=instruction, answer=output)
mmgpt/datasets/aokvqa_dataset.py ADDED
@@ -0,0 +1,51 @@
1
+ import random
2
+
3
+ from .vqa_dataset import VQADataset
4
+
5
+ REASON_QUESTIONS = [
6
+ "Why?",
7
+ "Why is this?",
8
+ "And why?",
9
+ "What is the reason?",
10
+ "And can you tell me why?",
11
+ "Can you tell me why?",
12
+ "Can you tell me the reason?",
13
+ ]
14
+
15
+
16
+ class AOKVQADataset(VQADataset):
17
+ def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs):
18
+ super().__init__(tokenizer, vis_processor, vis_root, ann_paths, **kwargs)
19
+
20
+ def process_text(self, ann):
21
+ question = ann["question"]
22
+ question = question + " " + random.choice(REASON_QUESTIONS)
23
+
24
+ choices = ann["choices"]
25
+ true_answer = choices[ann["correct_choice_idx"]]
26
+ answer = "The answer is " + true_answer + ". Because " + " ".join(ann["rationales"])
27
+
28
+ is_option = random.random() < self.option_prob and len(choices) > 1
29
+ if is_option:
30
+ instruction = self.prompter(question, choices)
31
+ else:
32
+ instruction = self.prompter(question)
33
+
34
+ instruction = self.prompter(question)
35
+ return dict(instruction=instruction, answer=answer)
36
+
37
+
38
+ def build_aokvqa_dataset(
39
+ tokenizer,
40
+ vis_processor,
41
+ vis_root="data/coco/images",
42
+ ann_paths=["data/aokvqa/annotations/aokvqa_v1p0_train.json"],
43
+ sample_image=False,
44
+ ):
45
+ return AOKVQADataset(
46
+ tokenizer=tokenizer,
47
+ vis_processor=vis_processor,
48
+ vis_root=vis_root,
49
+ ann_paths=ann_paths,
50
+ sample_image=sample_image,
51
+ )
mmgpt/datasets/baize_dataset.py ADDED
@@ -0,0 +1,86 @@
1
+ import json
2
+
3
+ from mmgpt.datasets.dolly_dataset import DollyDataset
4
+
5
+
6
+ TEMPLATE = {
7
+ "description": "Template used by Alpaca-LoRA.",
8
+ "prompt_choice": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{question}\n\n### Input:\n{options}\n\n### Response:\n",
9
+ "prompt_qa": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{question}\n\n### Response:\n",
10
+ "prompt_dial": "\n\n### Instruction:\n{question}\n\n### Response:\n",
11
+ "response_split": "### Response:",
12
+ }
13
+
14
+ class LangDialPrompter:
15
+ def __call__(self, question, options=None):
16
+ if options:
17
+ options = ", ".join(options)
18
+ res = TEMPLATE["prompt_choice"].format(image="<image>", question=question, options=options)
19
+ else:
20
+ res = TEMPLATE["prompt_dial"].format(question=question)
21
+ return res
22
+
23
+ def get_response(self, output: str) -> str:
24
+ return output.split(TEMPLATE["response_split"])[-1].strip()
25
+
26
+ class BaiZeDataset(DollyDataset):
27
+ """
28
+ ```json
29
+ [
30
+ {
31
+ "instruction": "Identify the odd one out.",
32
+ "input": "Twitter, Instagram, Telegram",
33
+ "output": "The odd one out is Telegram. Twitter and Instagram are social media platforms mainly for sharing information, images and videos while Telegram is a cloud-based instant messaging and voice-over-IP service."
34
+ },
35
+ ]
36
+ """
37
+ def __init__(self, *args, **kwargs):
38
+ super(BaiZeDataset, self).__init__(*args, **kwargs)
39
+ self.prompter = LangDialPrompter()
40
+
41
+ def load_annotation(self, ann_path):
42
+ self.annotation = json.load(open(ann_path, "r"))
43
+
44
+ def process_text(self, anns):
45
+ # TODO remove this
46
+ begin_string = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
47
+ convs = anns['input'].split("[|Human|] ")
48
+ conv_list = []
49
+ for conv_id, one_conv in enumerate(convs[1:-1]):
50
+ question, answer = one_conv.split("[|AI|] ")
51
+ question = question.replace("\n", "")
52
+ answer = answer.replace("\n", "")
53
+ instruction = self.prompter(question)
54
+ if conv_id == 0:
55
+ single_conv = dict(instruction=begin_string + instruction, answer=answer)
56
+ else:
57
+ single_conv = dict(instruction=instruction, answer=answer)
58
+ conv_list.append(single_conv)
59
+ return conv_list
60
+
61
+ def __getitem__(self, index):
62
+ ann = self.annotation[index]
63
+ text_list = self.process_text(ann)
64
+ res_list = []
65
+ for text in text_list:
66
+ single_res = self.tokenize(text)
67
+ single_res["instruction"] = text["instruction"]
68
+ single_res["answer"] = text["answer"]
69
+ res_list.append(single_res)
70
+
71
+ input_ids = []
72
+ attention_mask = []
73
+ labels = []
74
+ instruction = []
75
+ answer = []
76
+ for res in res_list:
77
+ input_ids.extend(res["input_ids"])
78
+ attention_mask.extend(res["attention_mask"])
79
+ labels.extend(res["labels"])
80
+ instruction.append(res["instruction"])
81
+ answer.append(res["answer"])
82
+
83
+ res = dict(
84
+ input_ids=input_ids, attention_mask=attention_mask, labels=labels, instruction=instruction, answer=answer
85
+ )
86
+ return res
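To make the transcript parsing in `BaiZeDataset.process_text` concrete, here is a minimal sketch; the record below is invented and only mimics the `[|Human|]`/`[|AI|]` delimiter convention the code assumes:

```python
# Invented Baize-style record: the whole dialog lives in the "input" field.
ann = {
    "input": "The conversation between human and AI assistant. "
             "[|Human|] What is LoRA? [|AI|] LoRA is a parameter-efficient fine-tuning method. "
             "[|Human|] "
}

convs = ann["input"].split("[|Human|] ")
# convs[0] is the preamble and convs[-1] is the trailing (usually unanswered) human turn,
# which is presumably why the dataset iterates over convs[1:-1].
for one_conv in convs[1:-1]:
    question, answer = one_conv.split("[|AI|] ")
    print(question.strip(), "->", answer.strip())
```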
mmgpt/datasets/builder.py ADDED
@@ -0,0 +1,126 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ from .alpaca_gpt4_dataset import AlpacaGPT4Dataset # noqa: F401
5
+ from .aokvqa_dataset import AOKVQADataset # noqa: F401
6
+ from .cc_sbu_align_dataset import CcSbuAlignDataset # noqa: F401
7
+ from .clevr_dataset import CLEVRDataset # noqa: F401
8
+ from .coco_caption_dataset import COCOCaptionDataset # noqa: F401
9
+ from .dial_dataset import DialDataset # noqa: F401
10
+ from .dolly_dataset import DollyDataset # noqa: F401
11
+ from .gqa_dataset import GQADataset # noqa: F401
12
+ from .llava_dataset import LlavaDataset # noqa: F401
13
+ from .nlvr_dataset import NLVRv1Dataset, NLVRv2Dataset # noqa: F401
14
+ from .ocr_vqa_dataset import OCRVQADataset # noqa: F401
15
+ from .snli_ve_datasets import SNLIVEDataset # noqa: F401
16
+ from .text_ocr_dataset import TextOCRDataset # noqa: F401
17
+ from .vqa_dataset import ConcatDataset, VQADataset # noqa: F401
18
+ from .baize_dataset import BaiZeDataset # noqa: F401
19
+
20
+
21
+ def build_dataset(dataset_config, **kwargs):
22
+ if isinstance(dataset_config, list):
23
+ datasets = [build_dataset(cfg, **kwargs) for cfg in dataset_config]
24
+ return ConcatDataset(datasets)
25
+ dataset_type = dataset_config.pop("type")
26
+ sample = dataset_config.pop("sample", -1)
27
+ if dataset_type == "llava":
28
+ dataset = LlavaDataset(
29
+ **dataset_config,
30
+ **kwargs,
31
+ )
32
+ elif dataset_type == "vqa":
33
+ dataset = VQADataset(
34
+ **dataset_config,
35
+ **kwargs,
36
+ )
37
+ elif dataset_type == "minigpt4":
38
+ dataset = CcSbuAlignDataset(
39
+ **dataset_config,
40
+ **kwargs,
41
+ )
42
+ elif dataset_type == "llava_dial":
43
+ dataset = DialDataset(
44
+ **dataset_config,
45
+ **kwargs,
46
+ )
47
+ elif dataset_type == "coco_dial":
48
+ dataset = DialDataset(
49
+ **dataset_config,
50
+ **kwargs,
51
+ )
52
+ elif dataset_type == "aokvqa":
53
+ dataset = AOKVQADataset(
54
+ **dataset_config,
55
+ **kwargs,
56
+ )
57
+ elif dataset_type == "okvqa":
58
+ dataset = VQADataset(
59
+ **dataset_config,
60
+ **kwargs,
61
+ )
62
+ elif dataset_type == "text_ocr":
63
+ dataset = TextOCRDataset(
64
+ **dataset_config,
65
+ **kwargs,
66
+ )
67
+ elif dataset_type == "ocr_vqa":
68
+ dataset = OCRVQADataset(
69
+ **dataset_config,
70
+ **kwargs,
71
+ )
72
+ elif dataset_type == "coco_caption":
73
+ dataset = COCOCaptionDataset(
74
+ **dataset_config,
75
+ **kwargs,
76
+ )
77
+ elif dataset_type == "gqa":
78
+ dataset = GQADataset(
79
+ **dataset_config,
80
+ **kwargs,
81
+ )
82
+ elif dataset_type == "clevr":
83
+ dataset = CLEVRDataset(
84
+ **dataset_config,
85
+ **kwargs,
86
+ )
87
+ elif dataset_type == "nlvrv1":
88
+ dataset = NLVRv1Dataset(
89
+ **dataset_config,
90
+ **kwargs,
91
+ )
92
+ elif dataset_type == "nlvrv2":
93
+ dataset = NLVRv2Dataset(
94
+ **dataset_config,
95
+ **kwargs,
96
+ )
97
+ elif dataset_type == "snlive":
98
+ dataset = SNLIVEDataset(
99
+ **dataset_config,
100
+ **kwargs,
101
+ )
102
+ elif dataset_type == "dolly":
103
+ dataset = DollyDataset(
104
+ **dataset_config,
105
+ **kwargs,
106
+ )
107
+ elif dataset_type == "alpaca_gpt4":
108
+ dataset = AlpacaGPT4Dataset(
109
+ **dataset_config,
110
+ **kwargs,
111
+ )
112
+ elif dataset_type == "baize":
113
+ dataset = BaiZeDataset(
114
+ **dataset_config,
115
+ **kwargs,
116
+ )
117
+ else:
118
+ raise NotImplementedError
119
+
120
+ if sample > 0:
121
+ random_indices = np.random.choice(len(dataset), min(sample, len(dataset)), replace=False)
122
+ subsample_dataset = torch.utils.data.Subset(dataset, random_indices)
123
+ subsample_dataset.collater = dataset.collater
124
+ return subsample_dataset
125
+ else:
126
+ return dataset
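A hedged usage sketch of `build_dataset`: the config keys mirror the constructor arguments above, but the paths, the `tokenizer`, and the `vis_processor` are placeholders rather than values shipped with this repo.

```python
import torch
from mmgpt.datasets.builder import build_dataset

# Hypothetical configs; a list of configs is wrapped into a ConcatDataset.
dataset_config = [
    dict(type="dolly", ann_path="data/dolly/databricks-dolly-15k.jsonl", sample=1000),
    dict(
        type="coco_caption",
        vis_root="data/coco/images",
        ann_paths=["data/coco/annotations/coco_karpathy_train.json"],  # placeholder path
    ),
]

# tokenizer and vis_processor are assumed to come from the model builder (see mmgpt/models).
train_dataset = build_dataset(dataset_config, tokenizer=tokenizer, vis_processor=vis_processor)
loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, collate_fn=train_dataset.collater)
```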
mmgpt/datasets/cc_sbu_align_dataset.py ADDED
@@ -0,0 +1,107 @@
1
+ import json
2
+ import os
3
+ import random
4
+
5
+ from PIL import Image
6
+
7
+ from .vqa_dataset import VQADataset, VQAPrompter
8
+
9
+ QUESTIONS = [
10
+ "please describe the image",
11
+ "can you describe the image",
12
+ "Could you provide a description of the image?",
13
+ "What do you see in this image?",
14
+ "Share your thoughts on the content of the image.",
15
+ "Please narrate what's happening in the picture.",
16
+ "Can you give a brief explanation of the image?",
17
+ "Describe the main elements and details present in the image.",
18
+ "In your own words, what is depicted in the image?",
19
+ "Can you outline the key aspects of the image?",
20
+ "What are the most striking features in this image?",
21
+ "Please provide a summary of the image's content.",
22
+ "Describe the overall theme or concept captured in the image.",
23
+ "How would you explain the image's composition and focus?",
24
+ "What is the focal point or main subject of the image?",
25
+ "How do the different components of the image interact with each other?",
26
+ "What would be a fitting caption for this image?",
27
+ "Can you create a concise description that captures the essence of the image?",
28
+ "How would you briefly summarize the content of this image in a phrase or sentence?",
29
+ "Please provide a catchy and relevant caption for this picture.",
30
+ "If you were to give this image a title, what would it be?",
31
+ "Describe the image in one creative sentence.",
32
+ "Please suggest a memorable phrase that encapsulates the image's content.",
33
+ "What engaging phrase would best represent this image?",
34
+ "Can you create an expressive caption that highlights the main theme of the image?",
35
+ "How would you sum up the image's story for a caption?",
36
+ "Provide an eye-catching caption that conveys the image's core message.",
37
+ "If you were to give this image a headline, what would it say?",
38
+ "Can you craft a captivating caption that communicates the essence of the image?",
39
+ "How would you describe the image's content in a powerful caption?",
40
+ "Please provide an inventive title to summarize the scene depicted in the image.",
41
+ "Compose a concise and striking phrase that reflects the image's key elements.",
42
+ "If you were to create a caption for this image, what would it be?",
43
+ "Offer a compelling caption that highlights the central focus of the image.",
44
+ "Can you produce a unique caption that encapsulates the image's overall mood?",
45
+ "Please generate an attention-grabbing caption that would best illustrate the events captured in this image",
46
+ "How would you express the image's main idea in an impactful sentence?",
47
+ "Please create a vivid and concise title that conveys the essence of the picture.",
48
+ "Compose an imaginative caption that reflects the image's most striking features.",
49
+ "What memorable statement would best represent the scene illustrated in this image?",
50
+ "Draft an evocative caption that brings the image to life for the reader.",
51
+ "Can you suggest an insightful caption that highlights the underlying message of the image?",
52
+ "What engaging phrase would effectively convey the action or subject matter depicted in this picture?",
53
+ "How would you encapsulate the image's core theme in a concise and expressive manner?",
54
+ "Please provide a creative and impactful title that captures the spirit of the image.",
55
+ "Craft a captivating caption that showcases the image's most prominent attributes.",
56
+ "What intriguing statement would best sum up the scene presented in this image?",
57
+ "Develop a descriptive caption that paints a vivid picture for the viewer.",
58
+ "Can you give a detailed account of the image's contents?",
59
+ "What are the key elements and features visible in this image?",
60
+ "How would you narrate the events or actions depicted in the picture?",
61
+ "Please share your observations about the various components present in the image.",
62
+ "What is the overall theme or concept captured in this image? Can you describe it?",
63
+ ]
64
+
65
+
66
+ class CcSbuAlignDataset(VQADataset):
67
+ def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, add_eos=True, ignore_instruction=True):
68
+ self.tokenizer = tokenizer
69
+ self.vis_root = vis_root
70
+
71
+ self.annotation = []
72
+ for ann_path in ann_paths:
73
+ self.annotation.extend(json.load(open(ann_path, "r"))["annotations"])
74
+
75
+ self.vis_processor = vis_processor
76
+ self.prompter = VQAPrompter()
77
+ self.add_eos = add_eos
78
+ self.ignore_instruction = ignore_instruction
79
+
80
+ def process_text(self, ann):
81
+ # random select a question
82
+ question = random.choice(QUESTIONS)
83
+ answer = ann["caption"]
84
+ instruction = self.prompter(question)
85
+ return dict(instruction=instruction, answer=answer)
86
+
87
+ def process_image(self, ann):
88
+ image_path = os.path.join(self.vis_root, ann["image_id"] + ".jpg")
89
+ image = Image.open(image_path).convert("RGB")
90
+
91
+ image = self.vis_processor(image)
92
+ return image
93
+
94
+
95
+ def build_ccsbualign_dataset(
96
+ tokenizer,
97
+ vis_processor,
98
+ vis_root="data/cc_sbu_align/image/",
99
+ ann_paths=["data/cc_sbu_align/filter_cap.json"],
100
+ **kwargs,
101
+ ):
102
+ return CcSbuAlignDataset(
103
+ tokenizer=tokenizer,
104
+ vis_processor=vis_processor,
105
+ vis_root=vis_root,
106
+ ann_paths=ann_paths,
107
+ )
mmgpt/datasets/clevr_dataset.py ADDED
@@ -0,0 +1,74 @@
1
+ import json
2
+ import os
3
+ import random
4
+ from collections import defaultdict
5
+
6
+ from PIL import Image
7
+
8
+ from .vqa_dataset import VQADataset
9
+
10
+
11
+ class CLEVRDataset(VQADataset):
12
+ """Visual Reasoning Dataset. It also contains Dialog.
13
+
14
+ Note: the images are relatively simple, with a few rendered objects on a plain background.
15
+ """
16
+
17
+ def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs):
18
+ super().__init__(tokenizer, vis_processor, vis_root, ann_paths=[], **kwargs)
19
+
20
+ self.annotation = self.load_annotations(ann_paths)
21
+ if self.sample_image:
22
+ print("randomly sample one annotation for each image")
23
+ self.annotation = self.parse_annotation(self.annotation)
24
+ self._add_instance_ids()
25
+
26
+ @staticmethod
27
+ def load_annotations(ann_paths):
28
+ annotation = []
29
+ for ann_path in ann_paths:
30
+ ann = json.load(open(ann_path, "r"))
31
+ annotation.extend(ann["questions"])
32
+ return annotation
33
+
34
+ def parse_annotation(self, annotation):
35
+ image_list = defaultdict(list)
36
+ for ann in annotation:
37
+ image_list[ann["image_filename"]].append(ann)
38
+ annotation = []
39
+ for ann_list in image_list.values():
40
+ annotation.append(random.choice(ann_list))
41
+ return annotation
42
+
43
+ def process_text(self, ann):
44
+ question = ann["question"]
45
+ answer = ann["answer"]
46
+ instruction = self.prompter(question)
47
+ return dict(instruction=instruction, answer=answer)
48
+
49
+ def process_image(self, ann):
50
+ split = ann["split"]
51
+ image_path = os.path.join(self.vis_root, split, ann["image_filename"])
52
+ image = Image.open(image_path).convert("RGB")
53
+
54
+ image = self.vis_processor(image)
55
+ return image
56
+
57
+
58
+ def build_clevr_dataset(
59
+ tokenizer,
60
+ vis_processor,
61
+ vis_root="data/clevr/CLEVR_v1.0/images",
62
+ ann_paths=[
63
+ "data/clevr/CLEVR_v1.0/questions/CLEVR_train_questions.json",
64
+ "data/clevr/CLEVR_v1.0/questions/CLEVR_val_questions.json",
65
+ ],
66
+ sample_image=False,
67
+ ):
68
+ return CLEVRDataset(
69
+ tokenizer=tokenizer,
70
+ vis_processor=vis_processor,
71
+ vis_root=vis_root,
72
+ ann_paths=ann_paths,
73
+ sample_image=sample_image,
74
+ )
mmgpt/datasets/coco_caption_dataset.py ADDED
@@ -0,0 +1,119 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import json
9
+ import os
10
+ import random
11
+
12
+ import numpy as np
13
+ from PIL import Image
14
+ from transformers import LlamaTokenizer
15
+
16
+ from .vqa_dataset import VQADataset
17
+
18
+ QUESTIONS = [
19
+ "please describe the image",
20
+ "can you describe the image",
21
+ "Could you provide a description of the image?",
22
+ "What do you see in this image?",
23
+ "Share your thoughts on the content of the image.",
24
+ "Please narrate what's happening in the picture.",
25
+ "Can you give a brief explanation of the image?",
26
+ "Describe the main elements and details present in the image.",
27
+ "In your own words, what is depicted in the image?",
28
+ "Can you outline the key aspects of the image?",
29
+ "What are the most striking features in this image?",
30
+ "Please provide a summary of the image's content.",
31
+ "Describe the overall theme or concept captured in the image.",
32
+ "How would you explain the image's composition and focus?",
33
+ "What is the focal point or main subject of the image?",
34
+ "How do the different components of the image interact with each other?",
35
+ "What would be a fitting caption for this image?",
36
+ "Can you create a concise description that captures the essence of the image?",
37
+ "How would you briefly summarize the content of this image in a phrase or sentence?",
38
+ "Please provide a catchy and relevant caption for this picture.",
39
+ "If you were to give this image a title, what would it be?",
40
+ "Describe the image in one creative sentence.",
41
+ "Please suggest a memorable phrase that encapsulates the image's content.",
42
+ "What engaging phrase would best represent this image?",
43
+ "Can you create an expressive caption that highlights the main theme of the image?",
44
+ "How would you sum up the image's story for a caption?",
45
+ "Provide an eye-catching caption that conveys the image's core message.",
46
+ "If you were to give this image a headline, what would it say?",
47
+ "Can you craft a captivating caption that communicates the essence of the image?",
48
+ "How would you describe the image's content in a powerful caption?",
49
+ "Please provide an inventive title to summarize the scene depicted in the image.",
50
+ "Compose a concise and striking phrase that reflects the image's key elements.",
51
+ "If you were to create a caption for this image, what would it be?",
52
+ "Offer a compelling caption that highlights the central focus of the image.",
53
+ "Can you produce a unique caption that encapsulates the image's overall mood?",
54
+ "Please generate an attention-grabbing caption that would best illustrate the events captured in this image",
55
+ "How would you express the image's main idea in an impactful sentence?",
56
+ "Please create a vivid and concise title that conveys the essence of the picture.",
57
+ "Compose an imaginative caption that reflects the image's most striking features.",
58
+ "What memorable statement would best represent the scene illustrated in this image?",
59
+ "Draft an evocative caption that brings the image to life for the reader.",
60
+ "Can you suggest an insightful caption that highlights the underlying message of the image?",
61
+ "What engaging phrase would effectively convey the action or subject matter depicted in this picture?",
62
+ "How would you encapsulate the image's core theme in a concise and expressive manner?",
63
+ "Please provide a creative and impactful title that captures the spirit of the image.",
64
+ "Craft a captivating caption that showcases the image's most prominent attributes.",
65
+ "What intriguing statement would best sum up the scene presented in this image?",
66
+ "Develop a descriptive caption that paints a vivid picture for the viewer.",
67
+ "Can you give a detailed account of the image's contents?",
68
+ "What are the key elements and features visible in this image?",
69
+ "How would you narrate the events or actions depicted in the picture?",
70
+ "Please share your observations about the various components present in the image.",
71
+ "What is the overall theme or concept captured in this image? Can you describe it?",
72
+ ]
73
+
74
+
75
+ class COCOCaptionDataset(VQADataset):
76
+ def __init__(
77
+ self, tokenizer, vis_processor=None, vis_root=None, ann_paths=[], add_eos=True, ignore_instruction=True
78
+ ):
79
+ """
80
+ vis_root (string): Root directory of images (e.g. coco/images/)
81
+ ann_root (string): directory to store the annotation file
82
+ """
83
+ self.tokenizer: LlamaTokenizer = tokenizer
84
+ self.vis_root = vis_root
85
+
86
+ self.annotation = []
87
+ for ann_path in ann_paths:
88
+ self.annotation.extend(json.load(open(ann_path, "r")))
89
+
90
+ self.vis_processor = vis_processor
91
+
92
+ instructions = []
93
+ for question in QUESTIONS:
94
+ # instruction = f"Below is a question about an image. Write a response to answer the question.\n\n### Image:\n<image>\n\n### Question:\n{question}\n\n### Answer:\n".format(
95
+ # question
96
+ # )
97
+ instruction = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Image:\n{image}\n\n### Instruction:\n{question}\n\n### Response:\n".format(
98
+ image="<image>", question=question
99
+ )
100
+ instructions.append(instruction)
101
+ self.instructions = instructions
102
+ self.add_eos = add_eos
103
+ self.ignore_instruction = ignore_instruction
104
+
105
+ def process_image(self, ann):
106
+ image_path = os.path.join(self.vis_root, ann["image"])
107
+ image = Image.open(image_path).convert("RGB")
108
+
109
+ image = self.vis_processor(image)
110
+ return image
111
+
112
+ def process_text(self, ann):
113
+ all_captions = ann["caption"]
114
+ if not isinstance(all_captions, list):
115
+ all_captions = [all_captions]
116
+ caption = random.choice(all_captions)
117
+ instruction = random.choice(self.instructions)
118
+
119
+ return dict(instruction=instruction, answer=caption)
mmgpt/datasets/dial_dataset.py ADDED
@@ -0,0 +1,83 @@
1
+ from .vqa_dataset import VQADataset
2
+
3
+ TEMPLATE = {
4
+ "description": "Template used by Alpaca-LoRA.",
5
+ # "prompt_choice": "Below is a multiple choice question about an image, along with answer options. Please choose the correct answer from these options.\n\n### Image:\n{image}\n\n### Question:\n{question}\n\n### Options:\n{options}\n\n### Answer:\n",
6
+ # "prompt_qa": "Below is a question about an image. Write a response to answer the question.\n\n### Image:\n{image}\n\n### Question:\n{question}\n\n### Answer:\n",
7
+ "prompt_choice": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Image:\n{image}\n\n### Instruction:\n{question}\n\n### Input:\n{options}\n\n### Response:\n",
8
+ "prompt_qa": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Image:\n{image}\n\n### Instruction:\n{question}\n\n### Response:\n",
9
+ "prompt_dial": "\n\n### Instruction:\n{question}\n\n### Response:\n",
10
+ "response_split": "### Response:",
11
+ }
12
+
13
+
14
+ class DialPrompter:
15
+ def __call__(self, question, options=None):
16
+ if options:
17
+ options = ", ".join(options)
18
+ res = TEMPLATE["prompt_choice"].format(image="<image>", question=question, options=options)
19
+ else:
20
+ res = TEMPLATE["prompt_dial"].format(question=question)
21
+ return res
22
+
23
+ def get_response(self, output: str) -> str:
24
+ return output.split(TEMPLATE["response_split"])[-1].strip()
25
+
26
+
27
+ class DialDataset(VQADataset):
28
+ def __init__(self, *args, **kwargs):
29
+ super(DialDataset, self).__init__(*args, **kwargs)
30
+ self.prompter = DialPrompter()
31
+
32
+ def _add_instance_ids(self, key="id"):
33
+ for idx, ann in enumerate(self.annotation):
34
+ ann[key] = str(idx)
35
+
36
+ def process_text(self, anns):
37
+ # TODO remove this
38
+ begin_string = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Image:\n{image}".format(
39
+ image="<image>"
40
+ )
41
+ num_convs = len(anns["conversations"]) // 2
42
+ conv_list = []
43
+ for conv_id in range(num_convs):
44
+ question = anns["conversations"][conv_id]["value"]
45
+ # remove '<image>' tag and '\n'
46
+ question = question.replace("<image>", "").replace("\n", "")
47
+ answer = anns["conversations"][conv_id + 1]["value"]
48
+ instruction = self.prompter(question)
49
+ if conv_id == 0:
50
+ single_conv = dict(instruction=begin_string + instruction, answer=answer)
51
+ else:
52
+ single_conv = dict(instruction=instruction, answer=answer)
53
+ conv_list.append(single_conv)
54
+ return conv_list
55
+
56
+ def __getitem__(self, index):
57
+ ann = self.annotation[index]
58
+ image = self.process_image(ann)
59
+ text_list = self.process_text(ann)
60
+ res_list = []
61
+ for text in text_list:
62
+ single_res = self.tokenize(text)
63
+ single_res["instruction"] = text["instruction"]
64
+ single_res["answer"] = text["answer"]
65
+ res_list.append(single_res)
66
+
67
+ input_ids = []
68
+ attention_mask = []
69
+ labels = []
70
+ instruction = []
71
+ answer = []
72
+ for res in res_list:
73
+ input_ids.extend(res["input_ids"])
74
+ attention_mask.extend(res["attention_mask"])
75
+ labels.extend(res["labels"])
76
+ instruction.append(res["instruction"])  # append: res["instruction"] is a single string
77
+ answer.append(res["answer"])
78
+
79
+ res = dict(
80
+ input_ids=input_ids, attention_mask=attention_mask, labels=labels, instruction=instruction, answer=answer
81
+ )
82
+ res.update(image=image)
83
+ return res
mmgpt/datasets/dolly_dataset.py ADDED
@@ -0,0 +1,150 @@
1
+ import copy
2
+ import json
3
+
4
+ import numpy as np
5
+ from torch.utils.data import Dataset
6
+ from transformers import LlamaTokenizer
7
+
8
+ TEMPLATE = {
9
+ "description": "Template used by LLM.",
10
+ "prompt_no_input_format": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
11
+ "prompt_with_input_format": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
12
+ "response_split": "### Response:",
13
+ }
14
+
15
+
16
+ class LMPrompter:
17
+ def __call__(self, instruction, input=None):
18
+ if input is None or len(input) == 0:
19
+ return TEMPLATE["prompt_no_input_format"].format(instruction=instruction)
20
+ else:
21
+ return TEMPLATE["prompt_with_input_format"].format(instruction=instruction, input=input)
22
+
23
+ def get_response(self, output: str) -> str:
24
+ return output.split(TEMPLATE["response_split"])[-1].strip()
25
+
26
+
27
+ class DollyDataset(Dataset):
28
+ """Each line of the annotation file is a json object with the following fields:
29
+
30
+ {
31
+ "instruction": "What is a dispersive prism?",
32
+ "context": "In optics, a dispersive prism is an optical prism that is used to disperse light, that is, to separate light into its spectral components (the colors of the rainbow). Different wavelengths (colors) of light will be deflected by the prism at different angles.[1] This is a result of the prism material's index of refraction varying with wavelength (dispersion). Generally, longer wavelengths (red) undergo a smaller deviation than shorter wavelengths (blue). The dispersion of white light into colors by a prism led Sir Isaac Newton to conclude that white light consisted of a mixture of different colors.",
33
+ "response": "A dispersive prism is an optical prism that disperses the light's different wavelengths at different angles. When white light is shined through a dispersive prism it will separate into the different colors of the rainbow.",
34
+ "category": "summarization"
35
+ }
36
+
37
+ """
38
+
39
+ def __init__(self, tokenizer, ann_path: str, add_eos=True, ignore_instruction=True, **kwargs):
40
+ """
41
+ ann_path (string): directory to store the annotation file
42
+ """
43
+ assert tokenizer.add_eos_token is False, "tokenizer should not add eos token by default"
44
+ self.tokenizer: LlamaTokenizer = tokenizer
45
+
46
+ self.annotation = []
47
+ self.prompter = LMPrompter()
48
+ self.add_eos = add_eos
49
+ self.ignore_instruction = ignore_instruction
50
+ self.load_annotation(ann_path)
51
+
52
+ def load_annotation(self, ann_path):
53
+ self.annotation = []
54
+ for line in open(ann_path, "r").readlines():
55
+ self.annotation.append(json.loads(line))
56
+
57
+ def __len__(self):
58
+ return len(self.annotation)
59
+
60
+ def process_text(self, ann):
61
+ instruction = ann["instruction"]
62
+ context = ann["context"]
63
+ response = ann["response"]
64
+ instruction = self.prompter(instruction=instruction, input=context)
65
+ return dict(instruction=instruction, answer=response)
66
+
67
+ def tokenize(self, text):
68
+ res = self.tokenizer(
69
+ text["instruction"] + text["answer"],
70
+ return_tensors=None,
71
+ padding="do_not_pad",
72
+ truncation=True,
73
+ max_length=512,
74
+ )
75
+
76
+ # manually add eos token
77
+ if res["input_ids"][-1] != self.tokenizer.eos_token_id and len(res["input_ids"]) < 512 and self.add_eos:
78
+ res["input_ids"].append(self.tokenizer.eos_token_id)
79
+ res["attention_mask"].append(1)
80
+ labels = copy.deepcopy(res["input_ids"])
81
+ # ignore instruction_token
82
+ if self.ignore_instruction:
83
+ instruction_token = self.tokenizer(
84
+ text["instruction"], return_tensors=None, padding="do_not_pad", truncation=True, max_length=512
85
+ )
86
+ labels = [-100] * len(instruction_token["input_ids"]) + labels[len(instruction_token["input_ids"]) :]
87
+
88
+ res.update(labels=labels)
89
+ return res
90
+
91
+ def __getitem__(self, index):
92
+ ann = self.annotation[index]
93
+ text = self.process_text(ann)
94
+ res = self.tokenize(text)
95
+ res.update(text)
96
+ return res
97
+
98
+ def collater(self, samples):
99
+ question_list, answer_list, input_id_list, attention_mask_list, labels_list = [], [], [], [], []
100
+
101
+ for sample in samples:
102
+ question_list.append(sample["instruction"])
103
+ answer_list.append(sample["answer"])
104
+ input_id_list.append(sample["input_ids"])
105
+ attention_mask_list.append(sample["attention_mask"])
106
+ labels_list.append(sample["labels"])
107
+
108
+ # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
109
+ # same length to return tensors.
110
+ max_label_length = max(len(l) for l in labels_list)
111
+ padding_side = self.tokenizer.padding_side
112
+ padded_labels = []
113
+ for l in labels_list:
114
+ remainder = [-100] * (max_label_length - len(l))
115
+ if isinstance(l, list):
116
+ l = l + remainder if padding_side == "right" else remainder + l
117
+ elif padding_side == "right":
118
+ l = np.concatenate([l, remainder]).astype(np.int64)
119
+ else:
120
+ l = np.concatenate([remainder, l]).astype(np.int64)
121
+ padded_labels.append(l)
122
+
123
+ padded_samples = self.tokenizer.pad(
124
+ {"input_ids": input_id_list, "attention_mask": attention_mask_list, "labels": padded_labels},
125
+ return_tensors="pt",
126
+ padding="longest",
127
+ )
128
+
129
+ labels = padded_samples["labels"]
130
+ labels[labels == self.tokenizer.pad_token_id] = -100
131
+ labels[:, 0] = -100
132
+ return {
133
+ "input_ids": padded_samples["input_ids"],
134
+ "attention_mask": padded_samples["attention_mask"],
135
+ "labels": labels,
136
+ "instruction": question_list,
137
+ "answer": answer_list,
138
+ }
139
+
140
+
141
+ def build_dolly_dataset(
142
+ tokenizer,
143
+ ann_path="data/dolly/databricks-dolly-15k.jsonl",
144
+ **kwargs,
145
+ ):
146
+ return DollyDataset(
147
+ tokenizer=tokenizer,
148
+ ann_path=ann_path,
149
+ **kwargs,
150
+ )
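As a quick illustration of the template defined at the top of this file, `LMPrompter` renders and parses Alpaca-style prompts deterministically:

```python
from mmgpt.datasets.dolly_dataset import LMPrompter

prompter = LMPrompter()

# Without context, the "### Input:" section is omitted (prompt_no_input_format).
print(prompter("Name three primary colors."))

# With context, it is included (prompt_with_input_format).
print(prompter("Summarize the passage.", input="In optics, a dispersive prism ..."))

# get_response keeps only the text after the final "### Response:" marker.
print(prompter.get_response("...### Response:\nRed, yellow and blue."))  # -> "Red, yellow and blue."
```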
mmgpt/datasets/gqa_dataset.py ADDED
@@ -0,0 +1,83 @@
1
+ import json
2
+ import os
3
+ import random
4
+ from collections import defaultdict
5
+
6
+ from PIL import Image
7
+
8
+ from .vqa_dataset import VQADataset
9
+
10
+
11
+ class GQADataset(VQADataset):
12
+ """Visual Reasoning Dataset."""
13
+
14
+ def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs):
15
+ super().__init__(tokenizer, vis_processor, vis_root, ann_paths=[], **kwargs)
16
+
17
+ self.annotation = self.load_annotations(ann_paths)
18
+ if self.sample_image:
19
+ print("randomly sample one annotation for each image")
20
+ self.annotation = self.parse_annotation(self.annotation)
21
+ self._add_instance_ids()
22
+ self.answer_prob = 1.0
23
+
24
+ @staticmethod
25
+ def load_annotations(ann_paths):
26
+ annotation = []
27
+ for ann_path in ann_paths:
28
+ ann = json.load(open(ann_path, "r"))
29
+ for k, v in ann.items():
30
+ v["question_id"] = k
31
+ annotation.append(v)
32
+ return annotation
33
+
34
+ def parse_annotation(self, annotation):
35
+ image_list = defaultdict(list)
36
+ for ann in annotation:
37
+ image_list[ann["imageId"]].append(ann)
38
+ annotation = []
39
+ for ann_list in image_list.values():
40
+ annotation.append(random.choice(ann_list))
41
+ return annotation
42
+
43
+ def process_text(self, ann):
44
+ question = ann["question"]
45
+
46
+ answer = ann["answer"]
47
+ full_answer = ann["fullAnswer"]
48
+
49
+ # TODO: check which one is better
50
+ # Randomly select between the short answer and the full answer; with answer_prob = 1.0 the full answer is always used
51
+ if random.random() < self.answer_prob:
52
+ select_answer = full_answer
53
+ else:
54
+ select_answer = answer
55
+
56
+ instruction = self.prompter(question)
57
+ return dict(instruction=instruction, answer=select_answer)
58
+
59
+ def process_image(self, ann):
60
+ image_path = os.path.join(self.vis_root, ann["imageId"] + ".jpg")
61
+ image = Image.open(image_path).convert("RGB")
62
+
63
+ image = self.vis_processor(image)
64
+ return image
65
+
66
+
67
+ def build_gqa_dataset(
68
+ tokenizer,
69
+ vis_processor,
70
+ vis_root="data/gqa/images",
71
+ ann_paths=[
72
+ "data/gqa/questions/train_all_questions/train_all_questions_0.json",
73
+ "data/gqa/questions/val_all_questions.json",
74
+ ],
75
+ sample_image=False,
76
+ ):
77
+ return GQADataset(
78
+ tokenizer=tokenizer,
79
+ vis_processor=vis_processor,
80
+ vis_root=vis_root,
81
+ ann_paths=ann_paths,
82
+ sample_image=sample_image,
83
+ )
mmgpt/datasets/llava_dataset.py ADDED
@@ -0,0 +1,18 @@
1
+ from .vqa_dataset import VQADataset
2
+
3
+
4
+ class LlavaDataset(VQADataset):
5
+ def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs):
6
+ super().__init__(tokenizer, vis_processor, vis_root, ann_paths, **kwargs)
7
+
8
+ def _add_instance_ids(self, key="id"):
9
+ for idx, ann in enumerate(self.annotation):
10
+ ann[key] = str(idx)
11
+
12
+ def process_text(self, ann):
13
+ question = ann["conversations"][0]["value"]
14
+ # remove '<image>' tag and '\n'
15
+ question = question.replace("<image>", "").replace("\n", "")
16
+ answer = ann["conversations"][1]["value"]
17
+ instruction = self.prompter(question)
18
+ return dict(instruction=instruction, answer=answer)
mmgpt/datasets/nlvr_dataset.py ADDED
@@ -0,0 +1,212 @@
1
+ import copy
2
+ import json
3
+ import os
4
+ import random
5
+ from collections import defaultdict
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from PIL import Image
10
+
11
+ from .vqa_dataset import VQADataset
12
+
13
+ QUESTIONS = [
14
+ "Is this true?",
15
+ "Is this right?",
16
+ "Can you confirm this information?" "Do you agree with this statement?",
17
+ "Does this align with your understanding?",
18
+ "How do you interpret this information?",
19
+ "Does this align with your understanding?",
20
+ "Can you confirm this?",
21
+ "Is this statement correct?",
22
+ "Could you verify this information?",
23
+ "Do you agree with this?",
24
+ "Is this accurate?",
25
+ "Can you validate this claim?",
26
+ "Are these details valid?",
27
+ "Is this factually correct?",
28
+ "Is the following information correct?",
29
+ "Could you please verify this fact?",
30
+ "Do you agree with this assertion?",
31
+ "Are these details accurate?",
32
+ "Does this claim hold true?",
33
+ ]
34
+
35
+
36
+ class NLVRv1Dataset(VQADataset):
37
+ """Visual Reasoning Dataset."""
38
+
39
+ def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs):
40
+ super().__init__(tokenizer, vis_processor, vis_root, ann_paths=[], **kwargs)
41
+
42
+ self.annotation = self.load_annotations(ann_paths)
43
+ if self.sample_image:
44
+ print("randomly sample one annotation for each image")
45
+ self.annotation = self.parse_annotation(self.annotation)
46
+ self._add_instance_ids()
47
+
48
+ @staticmethod
49
+ def load_annotations(ann_paths):
50
+ annotation = []
51
+ for ann_path in ann_paths:
52
+ if "train.json" in ann_path:
53
+ split = "train"
54
+ elif "dev.json" in ann_path:
55
+ split = "dev"
56
+ elif "test.json" in ann_path:
57
+ split = "test"
58
+ else:
59
+ raise ValueError(f"Unknown split for {ann_path}")
60
+
61
+ with open(ann_path, "r") as f:
62
+ for line in f.readlines():
63
+ line = line.strip()
64
+ if len(line) != 0:
65
+ ann = json.loads(line)
66
+ ann["split"] = split
67
+ annotation.append(ann)
68
+
69
+ return annotation
70
+
71
+ def parse_annotation(self, annotation):
72
+ image_list = defaultdict(list)
73
+ for ann in annotation:
74
+ img_key = f"{ann['split']}-{ann['identifier']}"
75
+ image_list[img_key].append(ann)
76
+ annotation = []
77
+ for ann_list in image_list.values():
78
+ annotation.append(random.choice(ann_list))
79
+ return annotation
80
+
81
+ def process_text(self, ann):
82
+ question = ann["sentence"] + " " + random.choice(QUESTIONS)
83
+ true_answer = ann["label"]
84
+
85
+ if random.random() < self.option_prob:
86
+ instruction = self.prompter(question, ["true", "false"])
87
+ else:
88
+ instruction = self.prompter(question)
89
+
90
+ return dict(instruction=instruction, answer=true_answer)
91
+
92
+ def process_image(self, ann):
93
+ # each question have 6 images, we can random select one of them.
94
+ # TODO: check whether using all 6 images?
95
+ random_id = random.randint(0, 5)
96
+ image_name = f"{ann['split']}-{ann['identifier']}-{random_id}.png"
97
+ image_path = os.path.join(self.vis_root, ann["split"], "images", ann["directory"], image_name)
98
+ image = Image.open(image_path).convert("RGB")
99
+
100
+ image = self.vis_processor(image)
101
+ return image
102
+
103
+
104
+ class NLVRv2Dataset(VQADataset):
105
+ """Visual Reasoning Dataset."""
106
+
107
+ def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs):
108
+ super().__init__(tokenizer, vis_processor, vis_root, ann_paths, **kwargs)
109
+ self.flip_prob = 0.5
110
+
111
+ def parse_annotation(self, annotation):
112
+ image_list = defaultdict(list)
113
+ for ann in annotation:
114
+ image_list[ann["images"][0]].append(ann)
115
+ # image_name_list = list(image_list.keys())
116
+ annotation = []
117
+ for ann_list in image_list.values():
118
+ annotation.append(random.choice(ann_list))
119
+ return annotation
120
+
121
+ def process_text(self, ann):
122
+ question = ann["sentence"] + " " + random.choice(QUESTIONS)
123
+ true_answer = ann["label"]
124
+
125
+ if random.random() < self.option_prob:
126
+ instruction = self.prompter(question, ["true", "false"])
127
+ else:
128
+ instruction = self.prompter(question)
129
+
130
+ return dict(instruction=instruction, answer=true_answer)
131
+
132
+ def process_image(self, ann):
133
+ image_0_path = os.path.join(self.vis_root, ann["images"][0])
134
+ image_1_path = os.path.join(self.vis_root, ann["images"][1])
135
+
136
+ image_0 = Image.open(image_0_path).convert("RGB")
137
+ image_1 = Image.open(image_1_path).convert("RGB")
138
+ image_0 = self.vis_processor(image_0)
139
+ image_1 = self.vis_processor(image_1)
140
+ return image_0, image_1
141
+
142
+ @staticmethod
143
+ def _flip(samples):
144
+ sentence = samples["sentence"]
145
+ image0, image1 = samples["image0"], samples["image1"]
146
+
147
+ if "left" not in sentence and "right" not in sentence:
148
+ if random.random() < 0.5:
149
+ image0, image1 = image1, image0
150
+ else:
151
+ if random.random() < 0.5:
152
+ sentence = sentence.replace("left", "[TEMP_TOKEN]")
153
+ sentence = sentence.replace("right", "left")
154
+ sentence = sentence.replace("[TEMP_TOKEN]", "right")
155
+
156
+ image0, image1 = image1, image0
157
+
158
+ samples["sentence"] = sentence
159
+ samples["image0"] = image0
160
+ samples["image1"] = image1
161
+
162
+ return samples
163
+
164
+ def __getitem__(self, index):
165
+ ann = copy.deepcopy(self.annotation[index])
166
+ image_0, image_1 = self.process_image(ann)
167
+ if random.random() < self.flip_prob:
168
+ samples = self._flip({"sentence": ann["sentence"], "image0": image_0, "image1": image_1})
169
+ image_0, image_1 = samples["image0"], samples["image1"]
170
+ ann["sentence"] = samples["sentence"]
171
+ # concat
172
+ # TODO: https://github.com/salesforce/LAVIS/blob/main/lavis/models/blip_models/blip_nlvr.py
173
+ # model logic need update if using nlvr2
174
+ image = torch.cat([image_0, image_1], dim=2)
175
+ image = F.interpolate(image[None, ...], size=(image_0.shape[1], image_0.shape[2]))[0]
176
+ text = self.process_text(ann)
177
+ res = self.tokenize(text)
178
+ res.update(image=image)
179
+ res.update(text)
180
+ return res
181
+
182
+
183
+ def build_nlvrv1_dataset(
184
+ tokenizer,
185
+ vis_processor,
186
+ vis_root="data/nlvr",
187
+ ann_paths=["data/nlvr//train/train.json"],
188
+ sample_image=False,
189
+ ):
190
+ return NLVRv1Dataset(
191
+ tokenizer=tokenizer,
192
+ vis_processor=vis_processor,
193
+ vis_root=vis_root,
194
+ ann_paths=ann_paths,
195
+ sample_image=sample_image,
196
+ )
197
+
198
+
199
+ def build_nlvrv2_dataset(
200
+ tokenizer,
201
+ vis_processor,
202
+ vis_root="data/nlvr2",
203
+ ann_paths=["data/nlvr2/annotations/nlvr_train.json"],
204
+ sample_image=False,
205
+ ):
206
+ return NLVRv2Dataset(
207
+ tokenizer=tokenizer,
208
+ vis_processor=vis_processor,
209
+ vis_root=vis_root,
210
+ ann_paths=ann_paths,
211
+ sample_image=sample_image,
212
+ )
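A minimal sketch of the NLVRv2 image pairing above: the two processed images are concatenated along the width axis and then resized back to a single image's resolution (shapes below are illustrative):

```python
import torch
import torch.nn.functional as F

# Two processed images, each (C, H, W).
image_0 = torch.randn(3, 224, 224)
image_1 = torch.randn(3, 224, 224)

pair = torch.cat([image_0, image_1], dim=2)                 # (3, 224, 448): side by side
pair = F.interpolate(pair[None, ...], size=(224, 224))[0]   # (3, 224, 224): squeezed back
```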
mmgpt/datasets/ocr_vqa_dataset.py ADDED
@@ -0,0 +1,23 @@
1
+ import os
2
+ import random
3
+
4
+ from PIL import Image
5
+
6
+ from .vqa_dataset import VQADataset
7
+
8
+
9
+ class OCRVQADataset(VQADataset):
10
+ def process_image(self, ann):
11
+ image_path = os.path.join(self.vis_root, ann["filename"])
12
+ image = Image.open(image_path).convert("RGB")
13
+
14
+ image = self.vis_processor(image)
15
+ return image
16
+
17
+ def process_text(self, ann):
18
+ index = random.choice(list(range(len(ann["questions"]))))
19
+ question = ann["questions"][index]
20
+ answer = ann["answers"][index]
21
+
22
+ instruction = self.prompter(question)
23
+ return dict(instruction=instruction, answer=answer)
mmgpt/datasets/samplers/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .infinite_sampler import InfiniteSampler
mmgpt/datasets/samplers/infinite_sampler.py ADDED
@@ -0,0 +1,30 @@
1
+ import itertools
2
+
3
+ import torch
4
+ from torch.utils.data.sampler import Sampler
5
+
6
+ from mmgpt.train.distributed import world_info_from_env
7
+
8
+
9
+ class InfiniteSampler(Sampler):
10
+ def __init__(self, dataset, shuffle: bool = True, seed: int = 0):
11
+ self._size = len(dataset)
12
+ self._shuffle = shuffle
13
+ self._seed = int(seed)
14
+ _, rank, world_size = world_info_from_env()
15
+
16
+ self._rank = rank
17
+ self._world_size = world_size
18
+
19
+ def __iter__(self):
20
+ start = self._rank
21
+ yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)
22
+
23
+ def _infinite_indices(self):
24
+ g = torch.Generator()
25
+ g.manual_seed(self._seed)
26
+ while True:
27
+ if self._shuffle:
28
+ yield from torch.randperm(self._size, generator=g).tolist()
29
+ else:
30
+ yield from torch.arange(self._size).tolist()
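A hedged usage sketch of `InfiniteSampler` with a PyTorch `DataLoader`, assuming a single-process run (so `world_info_from_env` reports rank 0 and world size 1); the dataset is a stand-in:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from mmgpt.datasets.samplers import InfiniteSampler

# Stand-in dataset; in training this would be one of the instruction datasets above.
toy_dataset = TensorDataset(torch.arange(10))
loader = DataLoader(toy_dataset, batch_size=4, sampler=InfiniteSampler(toy_dataset, shuffle=True))

it = iter(loader)
for _ in range(5):   # more steps than one pass over the data; the stream never raises StopIteration
    batch = next(it)
```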
mmgpt/datasets/snli_ve_datasets.py ADDED
@@ -0,0 +1,82 @@
1
+ import json
2
+ import os
3
+ import random
4
+ from collections import defaultdict
5
+
6
+ from PIL import Image
7
+
8
+ from .vqa_dataset import VQADataset
9
+
10
+ QUESTIONS = [
11
+ "What do you think of the above sentence?",
12
+ "Can you confirm this statement?",
13
+ "How do you interpret the given information?",
14
+ "What is your opinion on this matter?",
15
+ "Could you provide your perspective on this statement?",
16
+ "How would you respond to the provided claim?",
17
+ "What are your thoughts regarding the mentioned subject?",
18
+ "Can you elaborate on this idea in English?",
19
+ "Do you have any insights or feedback on this topic?",
20
+ "What's your take on the given statement?",
21
+ "What is your perspective on the given statement?",
22
+ "How would you interpret this remark?",
23
+ "Could you please provide your opinion on this?",
24
+ "Can you share your understanding of the above point?",
25
+ "Would you mind elaborating on this topic?",
26
+ "What are your views about the given statement?",
27
+ "How do you feel about the presented information?",
28
+ "Could you provide your perspective on this?",
29
+ "What is your opinion regarding this statement?",
30
+ "Can you share your thoughts about the mentioned claim?",
31
+ "How would you interpret the above comment?",
32
+ "Would you mind sharing your insights on this issue?",
33
+ ]
34
+
35
+
36
+ class SNLIVEDataset(VQADataset):
37
+ """Visual Reasoning Dataset."""
38
+
39
+ def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs):
40
+ super().__init__(tokenizer, vis_processor, vis_root, ann_paths=[], **kwargs)
41
+
42
+ self.annotation = self.load_annotations(ann_paths)
43
+ if self.sample_image:
44
+ print("randomly sample one annotation for each image")
45
+ self.annotation = self.parse_annotation(self.annotation)
46
+ self._add_instance_ids()
47
+
48
+ @staticmethod
49
+ def load_annotations(ann_paths):
50
+ annotation = []
51
+ for ann_path in ann_paths:
52
+ with open(ann_path, "r") as f:
53
+ for line in f.readlines():
54
+ line = line.strip()
55
+ if len(line) != 0:
56
+ ann = json.loads(line)
57
+ annotation.append(ann)
58
+ return annotation
59
+
60
+ def parse_annotation(self, annotation):
61
+ image_list = defaultdict(list)
62
+ for ann in annotation:
63
+ image_list[ann["Flickr30K_ID"]].append(ann)
64
+ annotation = []
65
+ for ann_list in image_list.values():
66
+ annotation.append(random.choice(ann_list))
67
+ return annotation
68
+
69
+ def process_text(self, ann):
70
+ question = ann["sentence2"] + " " + random.choice(QUESTIONS)
71
+ answer = ann["gold_label"]
72
+ if random.random() < self.option_prob:
73
+ instruction = self.prompter(question, ["entailment", "neutral", "contradiction"])
74
+ else:
75
+ instruction = self.prompter(question)
76
+ return dict(instruction=instruction, answer=answer)
77
+
78
+ def process_image(self, ann):
79
+ image_path = os.path.join(self.vis_root, ann["Flickr30K_ID"] + ".jpg")
80
+ image = Image.open(image_path).convert("RGB")
81
+ image = self.vis_processor(image)
82
+ return image
mmgpt/datasets/text_ocr_dataset.py ADDED
@@ -0,0 +1,64 @@
1
+ import json
2
+ import os
3
+ import random
4
+
5
+ import numpy as np
6
+ from PIL import Image
7
+ from transformers import LlamaTokenizer
8
+
9
+ from .vqa_dataset import VQADataset, VQAPrompter
10
+
11
+
12
+ class TextOCRDataset(VQADataset):
13
+ def __init__(
14
+ self, tokenizer, vis_processor=None, vis_root=None, ann_paths=[], add_eos=True, ignore_instruction=True
15
+ ):
16
+ """
17
+ vis_root (string): Root directory of images (e.g. coco/images/)
18
+ ann_root (string): directory to store the annotation file
19
+ """
20
+ assert tokenizer.add_eos_token is False, "tokenizer should not add eos token by default"
21
+ self.tokenizer: LlamaTokenizer = tokenizer
22
+ self.vis_root = vis_root
23
+
24
+ self.annotation = []
25
+ for ann_path in ann_paths:
26
+ self.annotation.extend(json.load(open(ann_path, "r"))["data"])
27
+
28
+ self.vis_processor = vis_processor
29
+
30
+ self._add_instance_ids()
31
+ self.option_prob = 0.5
32
+ self.prompter = VQAPrompter()
33
+ self.add_eos = add_eos
34
+ self.ignore_instruction = ignore_instruction
35
+
36
+ def process_image(self, ann):
37
+ image_path = os.path.join(self.vis_root, ann["image_id"] + ".jpg")
38
+ image = Image.open(image_path).convert("RGB")
39
+
40
+ image = self.vis_processor(image)
41
+ return image
42
+
43
+ def process_text(self, ann):
44
+ question = ann["question"]
45
+
46
+ answer_weight = {}
47
+ for answer in ann["answers"]:
48
+ if answer in answer_weight.keys():
49
+ answer_weight[answer] += 1 / len(ann["answers"])
50
+ else:
51
+ answer_weight[answer] = 1 / len(ann["answers"])
52
+
53
+ answers = list(answer_weight.keys())
54
+ weights = list(answer_weight.values())
55
+
56
+ # create instruction
57
+ true_answer = answers[np.argmax(weights)]
58
+ is_option = random.random() < self.option_prob and len(answers) > 1
59
+ if is_option:
60
+ instruction = self.prompter(question, answers)
61
+ else:
62
+ instruction = self.prompter(question)
63
+
64
+ return dict(instruction=instruction, answer=true_answer)
mmgpt/datasets/vqa_dataset.py ADDED
@@ -0,0 +1,227 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import copy
9
+ import json
10
+ import os
11
+ import random
12
+ from collections import defaultdict
13
+ from typing import Iterable
14
+
15
+ import numpy as np
16
+ import torch
17
+ from PIL import Image
18
+ from torch.utils.data import ConcatDataset, Dataset
19
+ from torch.utils.data.dataloader import default_collate
20
+ from transformers import LlamaTokenizer
21
+
22
+ TEMPLATE = {
23
+ "description": "Template used by Alpaca-LoRA.",
24
+ # "prompt_choice": "Below is a multiple choice question about an image, along with answer options. Please choose the correct answer from these options.\n\n### Image:\n{image}\n\n### Question:\n{question}\n\n### Input:\n{options}\n\n### Answer:\n",
25
+ # "prompt_qa": "Below is a question about an image. Write a response to answer the question.\n\n### Image:\n{image}\n\n### Question:\n{question}\n\n### Answer:\n",
26
+ "prompt_choice": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Image:\n{image}\n\n### Instruction:\n{question}\n\n### Input:\n{options}\n\n### Response:\n",
27
+ "prompt_qa": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Image:\n{image}\n\n### Instruction:\n{question}\n\n### Response:\n",
28
+ "response_split": "### Response:",
29
+ }
30
+
31
+
32
+ class VQAPrompter:
33
+ def __call__(self, question, options=None):
34
+ if options:
35
+ options = ", ".join(options)
36
+ res = TEMPLATE["prompt_choice"].format(image="<image>", question=question, options=options)
37
+ else:
38
+ res = TEMPLATE["prompt_qa"].format(image="<image>", question=question)
39
+ return res
40
+
41
+ def get_response(self, output: str) -> str:
42
+ return output.split(TEMPLATE["response_split"])[-1].strip()
43
+
44
+
45
+ class VQADataset(Dataset):
46
+ def __init__(
47
+ self,
48
+ tokenizer,
49
+ vis_processor=None,
50
+ vis_root=None,
51
+ ann_paths=[],
52
+ add_eos=True,
53
+ ignore_instruction=True,
54
+ sample_image=False,
55
+ ):
56
+ """
57
+ vis_root (string): Root directory of images (e.g. coco/images/)
58
+ ann_root (string): directory to store the annotation file
59
+ """
60
+ assert tokenizer.add_eos_token is False, "tokenizer should not add eos token by default"
61
+ self.tokenizer: LlamaTokenizer = tokenizer
62
+ self.vis_root = vis_root
63
+
64
+ self.annotation = []
65
+ for ann_path in ann_paths:
66
+ self.annotation.extend(json.load(open(ann_path, "r")))
67
+
68
+ self.sample_image = sample_image
69
+ if self.sample_image:
70
+ print("randomly sample one annotation for each image")
71
+ self.annotation = self.parse_annotation(self.annotation)
72
+
73
+ self.vis_processor = vis_processor
74
+
75
+ self._add_instance_ids()
76
+ self.option_prob = 0.5
77
+ self.prompter = VQAPrompter()
78
+ self.add_eos = add_eos
79
+ self.ignore_instruction = ignore_instruction
80
+
81
+ def parse_annotation(self, annotation):
82
+ image_list = defaultdict(list)
83
+ for ann in annotation:
84
+ image_list[ann["image"]].append(ann)
85
+ # image_name_list = list(image_list.keys())
86
+ annotation = []
87
+ for ann_list in image_list.values():
88
+ annotation.append(random.choice(ann_list))
89
+ return annotation
90
+
91
+ def __len__(self):
92
+ return len(self.annotation)
93
+
94
+ def _add_instance_ids(self, key="instance_id"):
95
+ for idx, ann in enumerate(self.annotation):
96
+ ann[key] = str(idx)
97
+
98
+ def process_image(self, ann):
99
+ image_path = os.path.join(self.vis_root, ann["image"])
100
+ image = Image.open(image_path).convert("RGB")
101
+
102
+ image = self.vis_processor(image)
103
+ return image
104
+
105
+ def process_text(self, ann):
106
+ question = ann["question"]
107
+
108
+ answer_weight = {}
109
+ for answer in ann["answer"]:
110
+ if answer in answer_weight.keys():
111
+ answer_weight[answer] += 1 / len(ann["answer"])
112
+ else:
113
+ answer_weight[answer] = 1 / len(ann["answer"])
114
+
115
+ answers = list(answer_weight.keys())
116
+ weights = list(answer_weight.values())
117
+
118
+ # create instruction
119
+ true_answer = answers[np.argmax(weights)]
120
+ is_option = random.random() < self.option_prob and len(answers) > 1
121
+ if is_option:
122
+ instruction = self.prompter(question, answers)
123
+ else:
124
+ instruction = self.prompter(question)
125
+
126
+ return dict(instruction=instruction, answer=true_answer)
127
+
128
+ def tokenize(self, text):
129
+ res = self.tokenizer(
130
+ text["instruction"] + text["answer"],
131
+ return_tensors=None,
132
+ padding="do_not_pad",
133
+ truncation=True,
134
+ max_length=512,
135
+ )
136
+
137
+ # manually add eos token
138
+ if res["input_ids"][-1] != self.tokenizer.eos_token_id and len(res["input_ids"]) < 512 and self.add_eos:
139
+ res["input_ids"].append(self.tokenizer.eos_token_id)
140
+ res["attention_mask"].append(1)
141
+ labels = copy.deepcopy(res["input_ids"])
142
+ # ignore instruction_token
143
+ if self.ignore_instruction:
144
+ instruction_token = self.tokenizer(
145
+ text["instruction"], return_tensors=None, padding="do_not_pad", truncation=True, max_length=512
146
+ )
147
+ labels = [-100] * len(instruction_token["input_ids"]) + labels[len(instruction_token["input_ids"]) :]
148
+
149
+ res.update(labels=labels)
150
+ return res
151
+
152
+ def __getitem__(self, index):
153
+ ann = self.annotation[index]
154
+ image = self.process_image(ann)
155
+ text = self.process_text(ann)
156
+ res = self.tokenize(text)
157
+ res.update(image=image)
158
+ res.update(text)
159
+ return res
160
+
161
+ def collater(self, samples):
162
+ image_list, question_list, answer_list, input_id_list, attention_mask_list, labels_list = [], [], [], [], [], []
163
+
164
+ for sample in samples:
165
+ image_list.append(sample["image"])
166
+ question_list.append(sample["instruction"])
167
+ answer_list.append(sample["answer"])
168
+ input_id_list.append(sample["input_ids"])
169
+ attention_mask_list.append(sample["attention_mask"])
170
+ labels_list.append(sample["labels"])
171
+
172
+ # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
173
+ # same length to return tensors.
174
+ max_label_length = max(len(l) for l in labels_list)
175
+ padding_side = self.tokenizer.padding_side
176
+ padded_labels = []
177
+ for l in labels_list:
178
+ remainder = [-100] * (max_label_length - len(l))
179
+ if isinstance(l, list):
180
+ l = l + remainder if padding_side == "right" else remainder + l
181
+ elif padding_side == "right":
182
+ l = np.concatenate([l, remainder]).astype(np.int64)
183
+ else:
184
+ l = np.concatenate([remainder, l]).astype(np.int64)
185
+ padded_labels.append(l)
186
+
187
+ padded_samples = self.tokenizer.pad(
188
+ {"input_ids": input_id_list, "attention_mask": attention_mask_list, "labels": padded_labels},
189
+ return_tensors="pt",
190
+ padding="longest",
191
+ )
192
+
193
+ labels = padded_samples["labels"]
194
+ media_token_id = self.tokenizer("<image>", add_special_tokens=False)["input_ids"][-1]
195
+ labels[labels == self.tokenizer.pad_token_id] = -100
196
+ labels[:, 0] = -100
197
+ labels[labels == media_token_id] = -100
198
+ return {
199
+ "image": torch.stack(image_list, dim=0),
200
+ "input_ids": padded_samples["input_ids"],
201
+ "attention_mask": padded_samples["attention_mask"],
202
+ "labels": labels,
203
+ "instruction": question_list,
204
+ "answer": answer_list,
205
+ }
206
+
207
+
208
+ class ConcatDataset(ConcatDataset):
209
+ def __init__(self, datasets: Iterable[Dataset]) -> None:
210
+ super().__init__(datasets)
211
+
212
+ def collater(self, samples):
213
+ # TODO For now only supports datasets with same underlying collater implementations
214
+
215
+ all_keys = set()
216
+ for s in samples:
217
+ all_keys.update(s)
218
+
219
+ shared_keys = all_keys
220
+ for s in samples:
221
+ shared_keys = shared_keys & set(s.keys())
222
+
223
+ samples_shared_keys = []
224
+ for s in samples:
225
+ samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
226
+
227
+ return self.datasets[0].collater(samples_shared_keys)
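The manual label padding above is needed because `tokenizer.pad` only pads `input_ids` and `attention_mask`, so labels must already share one length. A minimal, self-contained sketch of that step (plain Python lists with hypothetical values, not the real batch):

labels_list = [[-100, -100, 5, 6], [-100, 9, 10, 11, 12, 13]]
max_label_length = max(len(l) for l in labels_list)
padding_side = "right"  # mirrors tokenizer.padding_side in the collater above

padded_labels = []
for l in labels_list:
    remainder = [-100] * (max_label_length - len(l))  # -100 is ignored by the loss
    padded_labels.append(l + remainder if padding_side == "right" else remainder + l)

assert all(len(l) == max_label_length for l in padded_labels)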
mmgpt/models/__init__.py ADDED
File without changes
mmgpt/models/blip2/__init__.py ADDED
File without changes
mmgpt/models/builder.py ADDED
@@ -0,0 +1,74 @@
1
+ from .open_flamingo import create_model_and_transforms as create_open_flamingo_model_and_transforms
2
+ import torch.nn as nn
3
+ from transformers import LlamaTokenizer, LlamaForCausalLM
4
+
5
+ def create_model_and_transforms(
6
+ model_name: str,
7
+ clip_vision_encoder_path: str,
8
+ clip_vision_encoder_pretrained: str,
9
+ lang_encoder_path: str,
10
+ tokenizer_path: str,
11
+ tuning_config,
12
+ pretrained_model_path,
13
+ **kwargs,
14
+ ):
15
+ if model_name == "open_flamingo":
16
+ return create_open_flamingo_model_and_transforms(
17
+ clip_vision_encoder_path=clip_vision_encoder_path,
18
+ clip_vision_encoder_pretrained=clip_vision_encoder_pretrained,
19
+ lang_encoder_path=lang_encoder_path,
20
+ tokenizer_path=tokenizer_path,
21
+ tuning_config=tuning_config,
22
+ pretrained_model_path=pretrained_model_path,
23
+ **kwargs,
24
+ )
25
+ # TODO: support BLIP2
26
+ else:
27
+ raise ValueError(f"Unknown model name: {model_name}")
28
+
29
+ # only for debugging
30
+ def create_toy_model_and_transforms(
31
+ model_name: str,
32
+ clip_vision_encoder_path: str,
33
+ clip_vision_encoder_pretrained: str,
34
+ lang_encoder_path: str,
35
+ tokenizer_path: str,
36
+ tuning_config,
37
+ pretrained_model_path,
38
+ **kwargs,
39
+ ):
40
+ print("init toy vision encoder")
41
+ import torchvision
42
+
43
+ image_processor = torchvision.transforms.Compose(
44
+ [
45
+ torchvision.transforms.Resize((224, 224)),
46
+ torchvision.transforms.ToTensor(),
47
+ ]
48
+ )
49
+ print("init tokenizer")
50
+ text_tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path)
51
+ # add Flamingo special tokens to the tokenizer
52
+ text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
53
+ if text_tokenizer.pad_token is None:
54
+ # Issue: GPT models don't have a pad token, which we use to
55
+ # modify labels for the loss.
56
+ text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
57
+
58
+ class ToyModel(nn.Module):
59
+ def __init__(self, *args, **kwargs):
60
+ super().__init__()
61
+ self.input_embeddings = nn.Embedding(38000, 512)
62
+ self.layer = nn.Linear(512, 512)
63
+ self.config = {"hidden_size": 512}
64
+
65
+ def forward(self, lang_x, **kwargs):
66
+ x = self.input_embeddings(lang_x)
67
+ x = self.layer(x)
68
+ loss = x.sum()
69
+
70
+ return (loss,)
71
+
72
+ model = ToyModel()
73
+
74
+ return model, image_processor, text_tokenizer
mmgpt/models/open_flamingo/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .builder import create_model_and_transforms
2
+ from .flamingo import Flamingo
3
+ from .flamingo_lm import FlamingoLMMixin
mmgpt/models/open_flamingo/builder.py ADDED
@@ -0,0 +1,142 @@
1
+ """Modified from https://github.com/mlfoundations/open_flamingo"""
2
+ import open_clip
3
+ import torch
4
+ import torch.nn as nn
5
+ from bigmodelvis import Visualization
6
+ from peft import LoraConfig, get_peft_model
7
+ from transformers import LlamaForCausalLM, LlamaTokenizer
8
+
9
+ from .flamingo import Flamingo
10
+ from .flamingo_lm import FlamingoLMMixin
11
+ from .utils import extend_instance
12
+
13
+
14
+ def create_model_and_transforms(
15
+ clip_vision_encoder_path: str,
16
+ clip_vision_encoder_pretrained: str,
17
+ lang_encoder_path: str,
18
+ tokenizer_path: str,
19
+ decoder_layers_attr_name: str = None,
20
+ pretrained_model_path: str = None,
21
+ tuning_config=None,
22
+ **flamingo_kwargs,
23
+ ):
24
+ """
25
+ Initialize a Flamingo model from a pretrained vision encoder and language encoder.
26
+ Appends special tokens to the tokenizer and freezes backbones.
27
+
28
+ Args:
29
+ clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32")
30
+ clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k")
31
+ lang_encoder_path (str): path to pretrained language encoder
32
+ tokenizer_path (str): path to pretrained tokenizer
33
+ decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
34
+ Returns:
35
+ Flamingo: Flamingo model from pretrained vision and language encoders
36
+ Image processor: Pipeline to preprocess input images
37
+ Tokenizer: A tokenizer for the language model
38
+ """
39
+ print("init clip vision encoder")
40
+ vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
41
+ clip_vision_encoder_path, pretrained=clip_vision_encoder_pretrained
42
+ )
43
+ # set the vision encoder to output the visual features
44
+ vision_encoder.visual.output_tokens = True
45
+ print("init tokenizer")
46
+ text_tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path)
47
+ # add Flamingo special tokens to the tokenizer
48
+ text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
49
+ if text_tokenizer.pad_token is None:
50
+ # Issue: GPT models don't have a pad token, which we use to
51
+ # modify labels for the loss.
52
+ text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
53
+ text_tokenizer.bos_token_id = 1
54
+ text_tokenizer.eos_token_id = 2
55
+
56
+ print("init llama")
57
+ lang_encoder = LlamaForCausalLM.from_pretrained(lang_encoder_path)
58
+ extend_instance(lang_encoder, FlamingoLMMixin)
59
+
60
+ if decoder_layers_attr_name is None:
61
+ decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
62
+ lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
63
+ lang_encoder.resize_token_embeddings(len(text_tokenizer))
64
+
65
+ model = Flamingo(
66
+ vision_encoder,
67
+ lang_encoder,
68
+ text_tokenizer.encode("<|endofchunk|>")[-1],
69
+ text_tokenizer.encode("<image>")[-1],
70
+ vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["width"],
71
+ cross_attn_every_n_layers=4,
72
+ **flamingo_kwargs,
73
+ )
74
+
75
+ if pretrained_model_path is not None:
76
+ print(f"loading pretrained model from {pretrained_model_path}")
77
+ model.load_state_dict(torch.load(pretrained_model_path), strict=False)
78
+
79
+ # Freeze all parameters
80
+ model.requires_grad_(False)
81
+ assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
82
+
83
+ if tuning_config is not None:
84
+ model = prepare_model_for_tuning(model, tuning_config)
85
+ else:
86
+ raise ValueError("tuning_config must be provided")
87
+
88
+ print(
89
+ f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
90
+ )
91
+
92
+ return model, image_processor, text_tokenizer
93
+
94
+
95
+ def _infer_decoder_layers_attr_name(model):
96
+ for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
97
+ if k.lower() in model.__class__.__name__.lower():
98
+ return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]
99
+
100
+ raise ValueError(
101
+ f"We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
102
+ )
103
+
104
+
105
+ __KNOWN_DECODER_LAYERS_ATTR_NAMES = {
106
+ "opt": "model.decoder.layers",
107
+ "gptneo": "transformer.h",
108
+ "gptj": "transformer.h",
109
+ "gpt-j": "transformer.h",
110
+ "pythia": "gpt_neox.layers",
111
+ "llama": "model.layers",
112
+ }
113
+
114
+
115
+ def prepare_model_for_tuning(model: nn.Module, config):
116
+ if config.lora:
117
+ lora_config = LoraConfig(
118
+ r=config.lora_r,
119
+ lora_alpha=config.lora_alpha,
120
+ target_modules=config.lora_target_modules,
121
+ lora_dropout=config.lora_dropout,
122
+ bias="none", # won't use bias currently
123
+ modules_to_save=[], # TODO: might be helpful if saving a partial model
124
+ task_type="CAUSAL_LM",
125
+ )
126
+ model.lang_encoder = get_peft_model(model.lang_encoder, peft_config=lora_config)
127
+
128
+ # manually unfreeze modules; we use `substring` matching on parameter names
129
+ for name, param in model.named_parameters():
130
+ if any(substr in name for substr in config.unfrozen):
131
+ param.requires_grad = True
132
+
133
+ if config.vis and is_rank0():
134
+ Visualization(model).structure_graph()
135
+ return model
136
+
137
+
138
+ # temporary workaround, should use a common utils in the future
139
+ def is_rank0():
140
+ if not torch.distributed.is_initialized():
141
+ return True
142
+ return torch.distributed.get_rank() == 0
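The `substring` unfreezing in `prepare_model_for_tuning` matches raw parameter names. A small stand-in sketch of that pattern (toy module and a hypothetical `unfrozen` list, not the real tuning config):

import torch.nn as nn

toy = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
toy.requires_grad_(False)   # freeze everything, as done for the full model
unfrozen = ["1."]           # hypothetical: unfreeze only the second Linear

for name, param in toy.named_parameters():
    if any(substr in name for substr in unfrozen):
        param.requires_grad = True

print([n for n, p in toy.named_parameters() if p.requires_grad])  # ['1.weight', '1.bias']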
mmgpt/models/open_flamingo/flamingo.py ADDED
@@ -0,0 +1,208 @@
1
+ """Modified from https://github.com/mlfoundations/open_flamingo"""
2
+ import torch
3
+ from einops import rearrange
4
+ from torch import nn
5
+
6
+ from .helpers import PerceiverResampler
7
+
8
+
9
+ class Flamingo(nn.Module):
10
+ def __init__(
11
+ self,
12
+ vision_encoder: nn.Module,
13
+ lang_encoder: nn.Module,
14
+ eoc_token_id: int,
15
+ media_token_id: int,
16
+ vis_dim: int,
17
+ cross_attn_every_n_layers: int = 1,
18
+ use_media_placement_augmentation: bool = False,
19
+ ):
20
+ """
21
+ Args:
22
+ vision_encoder (nn.Module): HF CLIPModel
23
+ lang_encoder (nn.Module): HF causal language model
24
+ eoc_token_id (int): Token id for <|endofchunk|>
25
+ media_token_id (int): Token id for <image>
26
+ vis_dim (int): Dimension of the visual features.
27
+ Visual features are projected to match this shape along the last dimension.
28
+ cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1.
29
+ use_media_placement_augmentation (bool, optional): Whether to randomly assign images to the preceding or following text in training. Defaults to False.
30
+ """
31
+ super().__init__()
32
+ self.eoc_token_id = eoc_token_id
33
+ self.media_token_id = media_token_id
34
+ self.use_media_placement_augmentation = use_media_placement_augmentation
35
+ self.vis_dim = vis_dim
36
+ self.vision_encoder = vision_encoder
37
+ self.perceiver = PerceiverResampler(dim=self.vis_dim)
38
+ self.lang_encoder = lang_encoder
39
+ self.lang_encoder.init_flamingo(
40
+ media_token_id=media_token_id,
41
+ vis_hidden_size=self.vis_dim,
42
+ cross_attn_every_n_layers=cross_attn_every_n_layers,
43
+ use_media_placement_augmentation=self.use_media_placement_augmentation,
44
+ )
45
+
46
+ def forward(
47
+ self,
48
+ vision_x: torch.Tensor,
49
+ lang_x: torch.Tensor,
50
+ attention_mask: torch.Tensor = None,
51
+ labels: torch.Tensor = None,
52
+ use_cached_vision_x: bool = False,
53
+ clear_conditioned_layers: bool = True,
54
+ past_key_values=None,
55
+ use_cache: bool = False,
56
+ ):
57
+ """
58
+ Forward pass of Flamingo.
59
+
60
+ Args:
61
+ vision_x (torch.Tensor): Vision input
62
+ shape (B, T_img, F, C, H, W) with F=1
63
+ lang_x (torch.Tensor): Language input ids
64
+ shape (B, T_txt)
65
+ attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
66
+ labels (torch.Tensor, optional): Labels. Defaults to None.
67
+ clear_conditioned_layers: if True, clear the conditioned layers
68
+ once the foward pass is completed. Set this to false if the
69
+ same set of images will be reused in another subsequent
70
+ forward pass.
71
+ past_key_values: pre-computed values to pass to language model.
72
+ See past_key_values documentation in Hugging Face
73
+ CausalLM models.
74
+ use_cache: whether to use cached key values. See use_cache
75
+ documentation in Hugging Face CausalLM models.
76
+ """
77
+ if vision_x is None and use_cached_vision_x is False:
78
+ for layer in self.lang_encoder._get_decoder_layers():
79
+ layer.condition_only_lang_x(True)
80
+ output = self.lang_encoder(
81
+ input_ids=lang_x,
82
+ attention_mask=attention_mask,
83
+ labels=labels,
84
+ past_key_values=past_key_values,
85
+ use_cache=use_cache,
86
+ )
87
+ for layer in self.lang_encoder._get_decoder_layers():
88
+ layer.condition_only_lang_x(False)
89
+ return output
90
+ assert (
91
+ vision_x is not None
92
+ ) or use_cached_vision_x, "Must provide vision_x or set use_cached_vision_x to True."
93
+
94
+ if use_cached_vision_x:
95
+ # Case: use cached; vision_x should be cached and other
96
+ # vision-related inputs should not be provided.
97
+ assert vision_x is None, "Expect vision_x to be None when use_cached_vision_x is True."
98
+ assert self.lang_encoder.is_conditioned()
99
+
100
+ else:
101
+ # Case: do not use caching (i.e. this is a standard forward pass);
102
+ self._encode_vision_x(vision_x=vision_x)
103
+
104
+ output = self.lang_encoder(
105
+ input_ids=lang_x,
106
+ attention_mask=attention_mask,
107
+ labels=labels,
108
+ past_key_values=past_key_values,
109
+ use_cache=use_cache,
110
+ )
111
+
112
+ if clear_conditioned_layers:
113
+ self.lang_encoder.clear_conditioned_layers()
114
+
115
+ return output
116
+
117
+ def generate(
118
+ self,
119
+ vision_x: torch.Tensor,
120
+ lang_x: torch.Tensor,
121
+ attention_mask: torch.Tensor = None,
122
+ num_beams=1,
123
+ max_new_tokens=None,
124
+ temperature=1.0,
125
+ top_k=0,
126
+ top_p=1.0,
127
+ no_repeat_ngram_size=0,
128
+ prefix_allowed_tokens_fn=None,
129
+ length_penalty=1.0,
130
+ num_return_sequences=1,
131
+ do_sample=False,
132
+ early_stopping=False,
133
+ ):
134
+ """
135
+ Generate text conditioned on vision and language inputs.
136
+
137
+ Args:
138
+ vision_x (torch.Tensor): Vision input
139
+ shape (B, T_img, F, C, H, W)
140
+ images in the same chunk are collated along T_img, and frames are collated along F
141
+ currently only F=1 is supported (single-frame videos)
142
+ lang_x (torch.Tensor): Language input
143
+ shape (B, T_txt)
144
+ max_length (int, optional): Maximum length of the output. Defaults to None.
145
+ attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
146
+ num_beams (int, optional): Number of beams. Defaults to 1.
147
+ max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
148
+ temperature (float, optional): Temperature. Defaults to 1.0.
149
+ top_k (int, optional): Top k. Defaults to 0.
150
+ top_p (float, optional): Top p. Defaults to 1.0.
151
+ no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
152
+ length_penalty (float, optional): Length penalty. Defaults to 1.0.
153
+ num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
154
+ do_sample (bool, optional): Do sample. Defaults to False.
155
+ early_stopping (bool, optional): Early stopping. Defaults to False.
156
+ Returns:
157
+ torch.Tensor: lang_x with generated tokens appended to it
158
+ """
159
+ if num_beams > 1:
160
+ vision_x = vision_x.repeat_interleave(num_beams, dim=0)
161
+
162
+ self._encode_vision_x(vision_x=vision_x)
163
+
164
+ output = self.lang_encoder.generate(
165
+ lang_x,
166
+ attention_mask=attention_mask,
167
+ # eos_token_id=self.eoc_token_id,
168
+ num_beams=num_beams,
169
+ max_new_tokens=max_new_tokens,
170
+ temperature=temperature,
171
+ top_k=top_k,
172
+ top_p=top_p,
173
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
174
+ no_repeat_ngram_size=no_repeat_ngram_size,
175
+ length_penalty=length_penalty,
176
+ num_return_sequences=num_return_sequences,
177
+ do_sample=do_sample,
178
+ early_stopping=early_stopping,
179
+ )
180
+
181
+ self.lang_encoder.clear_conditioned_layers()
182
+ return output
183
+
184
+ def _encode_vision_x(self, vision_x: torch.Tensor):
185
+ """
186
+ Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
187
+ Args:
188
+ vision_x (torch.Tensor): Vision input
189
+ shape (B, T_img, F, C, H, W)
190
+ Images in the same chunk are collated along T_img, and frames are collated along F
191
+ Currently only F=1 is supported (single-frame videos)
192
+
193
+ rearrange code based on https://github.com/dhansmair/flamingo-mini
194
+ """
195
+
196
+ assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
197
+ b, T, F = vision_x.shape[:3]
198
+ assert F == 1, "Only single frame supported"
199
+
200
+ vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
201
+ with torch.no_grad():
202
+ vision_x = self.vision_encoder.visual(vision_x)[1]
203
+ vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
204
+
205
+ vision_x = self.perceiver(vision_x) # reshapes to (b, T, n, d)
206
+
207
+ for layer in self.lang_encoder._get_decoder_layers():
208
+ layer.condition_vis_x(vision_x)
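For reference, the 6-D vision tensor that `forward` and `_encode_vision_x` expect can be built like this for a batch of two single-image, single-frame samples (shapes only; values are random):

import torch

B, T_img, F, C, H, W = 2, 1, 1, 3, 224, 224
vision_x = torch.randn(B, T_img, F, C, H, W)   # (batch, images, frames, channels, height, width)
assert vision_x.ndim == 6 and vision_x.shape[2] == 1  # only F=1 (single frame) is supported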
mmgpt/models/open_flamingo/flamingo_lm.py ADDED
@@ -0,0 +1,131 @@
1
+ """Modified from https://github.com/mlfoundations/open_flamingo"""
2
+ import random
3
+
4
+ import torch.nn as nn
5
+
6
+ from .helpers import GatedCrossAttentionBlock
7
+ from .utils import getattr_recursive, setattr_recursive
8
+
9
+
10
+ class FlamingoLayer(nn.Module):
11
+ def __init__(self, gated_cross_attn_layer, decoder_layer):
12
+ super().__init__()
13
+ self.gated_cross_attn_layer = gated_cross_attn_layer
14
+ self.decoder_layer = decoder_layer
15
+ self.vis_x = None
16
+ self.media_locations = None
17
+ self.only_lang_x = False
18
+
19
+ def is_conditioned(self) -> bool:
20
+ """Check whether the layer is conditioned."""
21
+ return self.vis_x is not None
22
+
23
+ # Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/)
24
+ def condition_vis_x(self, vis_x):
25
+ self.vis_x = vis_x
26
+
27
+ def condition_only_lang_x(self, only_lang_x=False):
28
+ self.only_lang_x = only_lang_x
29
+
30
+ def condition_media_locations(self, media_locations):
31
+ self.media_locations = media_locations
32
+
33
+ def condition_attend_previous(self, attend_previous):
34
+ self.attend_previous = attend_previous
35
+
36
+ def forward(
37
+ self,
38
+ lang_x,
39
+ attention_mask=None,
40
+ **decoder_layer_kwargs,
41
+ ):
42
+ if self.gated_cross_attn_layer is None or self.only_lang_x:
43
+ return self.decoder_layer(lang_x, attention_mask=attention_mask, **decoder_layer_kwargs)
44
+
45
+ if self.vis_x is None:
46
+ raise ValueError("vis_x must be conditioned before forward pass")
47
+
48
+ if self.media_locations is None:
49
+ raise ValueError("media_locations must be conditioned before forward pass")
50
+
51
+ lang_x = self.gated_cross_attn_layer(
52
+ lang_x,
53
+ self.vis_x,
54
+ media_locations=self.media_locations,
55
+ attend_previous=self.attend_previous,
56
+ )
57
+ lang_x = self.decoder_layer(lang_x, attention_mask=attention_mask, **decoder_layer_kwargs)
58
+ return lang_x
59
+
60
+
61
+ class FlamingoLMMixin(nn.Module):
62
+ """
63
+ Mixin to add cross-attention layers to a language model.
64
+ """
65
+
66
+ def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
67
+ self.decoder_layers_attr_name = decoder_layers_attr_name
68
+
69
+ def _get_decoder_layers(self):
70
+ return getattr_recursive(self, self.decoder_layers_attr_name)
71
+
72
+ def _set_decoder_layers(self, value):
73
+ setattr_recursive(self, self.decoder_layers_attr_name, value)
74
+
75
+ def init_flamingo(
76
+ self,
77
+ media_token_id,
78
+ vis_hidden_size,
79
+ cross_attn_every_n_layers,
80
+ use_media_placement_augmentation,
81
+ ):
82
+ """
83
+ Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations.
84
+ """
85
+
86
+ self.gated_cross_attn_layers = nn.ModuleList(
87
+ [
88
+ GatedCrossAttentionBlock(dim=self.config.hidden_size, dim_visual=vis_hidden_size)
89
+ if (layer_idx + 1) % cross_attn_every_n_layers == 0
90
+ else None
91
+ for layer_idx, _ in enumerate(self._get_decoder_layers())
92
+ ]
93
+ )
94
+ self._set_decoder_layers(
95
+ nn.ModuleList(
96
+ [
97
+ FlamingoLayer(gated_cross_attn_layer, decoder_layer)
98
+ for gated_cross_attn_layer, decoder_layer in zip(
99
+ self.gated_cross_attn_layers, self._get_decoder_layers()
100
+ )
101
+ ]
102
+ )
103
+ )
104
+ self.media_token_id = media_token_id
105
+ self.use_media_placement_augmentation = use_media_placement_augmentation
106
+ self.initialized_flamingo = True
107
+
108
+ def forward(self, *input, **kwargs):
109
+ """Condition the Flamingo layers on the media locations before forward()"""
110
+ if not self.initialized_flamingo:
111
+ raise ValueError("Flamingo layers are not initialized. Please call `init_flamingo` first.")
112
+
113
+ input_ids = kwargs["input_ids"] if "input_ids" in kwargs else input[0]
114
+ media_locations = input_ids == self.media_token_id
115
+ attend_previous = (random.random() < 0.5) if self.use_media_placement_augmentation else False
116
+
117
+ for layer in self.get_decoder().layers:
118
+ layer.condition_media_locations(media_locations)
119
+ layer.condition_attend_previous(attend_previous)
120
+
121
+ return super().forward(*input, **kwargs) # Call the other parent's forward method
122
+
123
+ def is_conditioned(self) -> bool:
124
+ """Check whether all decoder layers are already conditioned."""
125
+ return all(l.is_conditioned() for l in self._get_decoder_layers())
126
+
127
+ def clear_conditioned_layers(self):
128
+ for layer in self._get_decoder_layers():
129
+ layer.condition_vis_x(None)
130
+ layer.condition_media_locations(None)
131
+ layer.condition_attend_previous(None)
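The media locations conditioned in `FlamingoLMMixin.forward` are just a boolean mask over the token ids. A tiny sketch with a hypothetical `<image>` token id:

import torch

media_token_id = 32001                           # hypothetical id assigned to the <image> token
input_ids = torch.tensor([[1, 32001, 15, 16, 2]])
media_locations = input_ids == media_token_id
print(media_locations)                           # tensor([[False,  True, False, False, False]])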
mmgpt/models/open_flamingo/helpers.py ADDED
@@ -0,0 +1,263 @@
1
+ """
2
+ Taken from https://github.com/lucidrains/flamingo-pytorch
3
+ """
4
+
5
+ import torch
6
+ from einops import rearrange, repeat
7
+ from einops_exts import rearrange_many
8
+ from torch import einsum, nn
9
+
10
+
11
+ def exists(val):
12
+ return val is not None
13
+
14
+
15
+ def FeedForward(dim, mult=4):
16
+ inner_dim = int(dim * mult)
17
+ return nn.Sequential(
18
+ nn.LayerNorm(dim),
19
+ nn.Linear(dim, inner_dim, bias=False),
20
+ nn.GELU(),
21
+ nn.Linear(inner_dim, dim, bias=False),
22
+ )
23
+
24
+
25
+ class PerceiverAttention(nn.Module):
26
+ def __init__(self, *, dim, dim_head=64, heads=8):
27
+ super().__init__()
28
+ self.scale = dim_head**-0.5
29
+ self.heads = heads
30
+ inner_dim = dim_head * heads
31
+
32
+ self.norm_media = nn.LayerNorm(dim)
33
+ self.norm_latents = nn.LayerNorm(dim)
34
+
35
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
36
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
37
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
38
+
39
+ def forward(self, x, latents):
40
+ """
41
+ Args:
42
+ x (torch.Tensor): image features
43
+ shape (b, T, n1, D)
44
+ latent (torch.Tensor): latent features
45
+ shape (b, T, n2, D)
46
+ """
47
+ x = self.norm_media(x)
48
+ latents = self.norm_latents(latents)
49
+
50
+ h = self.heads
51
+
52
+ q = self.to_q(latents)
53
+ kv_input = torch.cat((x, latents), dim=-2)
54
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
55
+ q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
56
+ q = q * self.scale
57
+
58
+ # attention
59
+ sim = einsum("... i d, ... j d -> ... i j", q, k)
60
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
61
+ attn = sim.softmax(dim=-1)
62
+
63
+ out = einsum("... i j, ... j d -> ... i d", attn, v)
64
+ out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
65
+ return self.to_out(out)
66
+
67
+
68
+ class PerceiverResampler(nn.Module):
69
+ def __init__(
70
+ self,
71
+ *,
72
+ dim,
73
+ depth=6,
74
+ dim_head=64,
75
+ heads=8,
76
+ num_latents=64,
77
+ max_num_media=None,
78
+ max_num_frames=None,
79
+ ff_mult=4,
80
+ ):
81
+ super().__init__()
82
+ self.latents = nn.Parameter(torch.randn(num_latents, dim))
83
+ self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None
84
+ self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None
85
+
86
+ self.layers = nn.ModuleList([])
87
+ for _ in range(depth):
88
+ self.layers.append(
89
+ nn.ModuleList(
90
+ [
91
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
92
+ FeedForward(dim=dim, mult=ff_mult),
93
+ ]
94
+ )
95
+ )
96
+
97
+ self.norm = nn.LayerNorm(dim)
98
+
99
+ def forward(self, x):
100
+ """
101
+ Args:
102
+ x (torch.Tensor): image features
103
+ shape (b, T, F, v, D)
104
+ Returns:
105
+ shape (b, T, n, D) where n is self.num_latents
106
+ """
107
+ b, T, F, v = x.shape[:4]
108
+
109
+ # frame and media time embeddings
110
+ if exists(self.frame_embs):
111
+ frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
112
+ x = x + frame_embs
113
+ x = rearrange(x, "b T F v d -> b T (F v) d") # flatten the frame and spatial dimensions
114
+ if exists(self.media_time_embs):
115
+ x = x + self.media_time_embs[:T]
116
+
117
+ # blocks
118
+ latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
119
+ for attn, ff in self.layers:
120
+ latents = attn(x, latents) + latents
121
+ latents = ff(latents) + latents
122
+ return self.norm(latents)
123
+
124
+
125
+ # gated cross attention
126
+
127
+
128
+ class MaskedCrossAttention(nn.Module):
129
+ def __init__(
130
+ self,
131
+ *,
132
+ dim,
133
+ dim_visual,
134
+ dim_head=64,
135
+ heads=8,
136
+ only_attend_immediate_media=True,
137
+ ):
138
+ super().__init__()
139
+ self.scale = dim_head**-0.5
140
+ self.heads = heads
141
+ inner_dim = dim_head * heads
142
+
143
+ self.norm = nn.LayerNorm(dim)
144
+
145
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
146
+ self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
147
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
148
+
149
+ # whether for text to only attend to immediate preceding image, or all previous images
150
+ self.only_attend_immediate_media = only_attend_immediate_media
151
+
152
+ def forward(self, x, media, media_locations=None, attend_previous=True):
153
+ """
154
+ Args:
155
+ x (torch.Tensor): text features
156
+ shape (B, T_txt, D_txt)
157
+ media (torch.Tensor): image features
158
+ shape (B, T_img, n, D_img) where n is the dim of the latents
159
+ media_locations: boolean mask identifying the media tokens in x
160
+ shape (B, T_txt)
161
+ attend_previous: bool
162
+ If false, ignores immediately preceding image and starts attending when following image
163
+ """
164
+ _, T_img, n = media.shape[:3]
165
+ h = self.heads
166
+
167
+ x = self.norm(x)
168
+
169
+ q = self.to_q(x)
170
+ media = rearrange(media, "b t n d -> b (t n) d")
171
+
172
+ k, v = self.to_kv(media).chunk(2, dim=-1)
173
+ q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
174
+
175
+ q = q * self.scale
176
+
177
+ sim = einsum("... i d, ... j d -> ... i j", q, k)
178
+
179
+ if exists(media_locations):
180
+ # at each boolean of True, increment the time counter (relative to media time)
181
+ text_time = media_locations.cumsum(dim=-1)
182
+ media_time = torch.arange(T_img, device=x.device) + 1
183
+
184
+ if not attend_previous:
185
+ text_time[~media_locations] += 1
186
+ # make sure max is still the number of images in the sequence
187
+ text_time[
188
+ text_time
189
+ > repeat(
190
+ torch.count_nonzero(media_locations, dim=1),
191
+ "b -> b i",
192
+ i=text_time.shape[1],
193
+ )
194
+ ] = 0
195
+
196
+ # text time must equal media time if only attending to most immediate image
197
+ # otherwise, as long as text time is greater than media time (if attending to all previous images / media)
198
+ mask_op = torch.eq if self.only_attend_immediate_media else torch.ge
199
+
200
+ text_to_media_mask = mask_op(
201
+ rearrange(text_time, "b i -> b 1 i 1"),
202
+ repeat(media_time, "j -> 1 1 1 (j n)", n=n),
203
+ )
204
+ sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
205
+
206
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
207
+ attn = sim.softmax(dim=-1)
208
+
209
+ if exists(media_locations) and self.only_attend_immediate_media:
210
+ # any text without a preceding media needs to have attention zeroed out
211
+ text_without_media_mask = text_time == 0
212
+ text_without_media_mask = rearrange(text_without_media_mask, "b i -> b 1 i 1")
213
+ attn = attn.masked_fill(text_without_media_mask, 0.0)
214
+
215
+ out = einsum("... i j, ... j d -> ... i d", attn, v)
216
+ out = rearrange(out, "b h n d -> b n (h d)")
217
+ return self.to_out(out)
218
+
219
+
220
+ class GatedCrossAttentionBlock(nn.Module):
221
+ def __init__(
222
+ self,
223
+ *,
224
+ dim,
225
+ dim_visual,
226
+ dim_head=64,
227
+ heads=8,
228
+ ff_mult=4,
229
+ only_attend_immediate_media=True,
230
+ ):
231
+ super().__init__()
232
+ self.attn = MaskedCrossAttention(
233
+ dim=dim,
234
+ dim_visual=dim_visual,
235
+ dim_head=dim_head,
236
+ heads=heads,
237
+ only_attend_immediate_media=only_attend_immediate_media,
238
+ )
239
+ self.attn_gate = nn.Parameter(torch.tensor([0.0]))
240
+
241
+ self.ff = FeedForward(dim, mult=ff_mult)
242
+ self.ff_gate = nn.Parameter(torch.tensor([0.0]))
243
+
244
+ def forward(
245
+ self,
246
+ x,
247
+ media,
248
+ media_locations=None,
249
+ attend_previous=True,
250
+ ):
251
+ x = (
252
+ self.attn(
253
+ x,
254
+ media,
255
+ media_locations=media_locations,
256
+ attend_previous=attend_previous,
257
+ )
258
+ * self.attn_gate.tanh()
259
+ + x
260
+ )
261
+ x = self.ff(x) * self.ff_gate.tanh() + x
262
+
263
+ return x
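Assuming the package (with `einops` and `einops_exts`) is installed, a small sketch of the resampler's shape contract: it maps `(b, T, F, v, D)` visual tokens to `(b, T, num_latents, D)`:

import torch
from mmgpt.models.open_flamingo.helpers import PerceiverResampler

resampler = PerceiverResampler(dim=32, depth=1, dim_head=8, heads=2, num_latents=4)
x = torch.randn(2, 1, 1, 16, 32)   # (b, T, F, v, D)
print(resampler(x).shape)          # torch.Size([2, 1, 4, 32])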
mmgpt/models/open_flamingo/utils.py ADDED
@@ -0,0 +1,31 @@
1
+ def extend_instance(obj, mixin):
2
+ """Apply mixins to a class instance after creation"""
3
+ base_cls = obj.__class__
4
+ base_cls_name = obj.__class__.__name__
5
+ obj.__class__ = type(
6
+ base_cls_name, (mixin, base_cls), {}
7
+ ) # mixin needs to go first for our forward() logic to work
8
+
9
+
10
+ def getattr_recursive(obj, att):
11
+ """
12
+ Return nested attribute of obj
13
+ Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
14
+ """
15
+ if att == "":
16
+ return obj
17
+ i = att.find(".")
18
+ if i < 0:
19
+ return getattr(obj, att)
20
+ else:
21
+ return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
22
+
23
+
24
+ def setattr_recursive(obj, att, val):
25
+ """
26
+ Set nested attribute of obj
27
+ Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
28
+ """
29
+ if "." in att:
30
+ obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
31
+ setattr(obj, att.split(".")[-1], val)
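A quick sketch of the recursive attribute helpers above (assuming the package is importable):

from types import SimpleNamespace
from mmgpt.models.open_flamingo.utils import getattr_recursive, setattr_recursive

obj = SimpleNamespace(a=SimpleNamespace(b=SimpleNamespace(c=1)))
assert getattr_recursive(obj, "a.b.c") == 1
setattr_recursive(obj, "a.b.c", 2)
assert obj.a.b.c == 2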
mmgpt/train/__init__.py ADDED
@@ -0,0 +1 @@
1
+
mmgpt/train/distributed.py ADDED
@@ -0,0 +1,131 @@
1
+ """Modified from https://github.com/mlfoundations/open_flamingo"""
2
+ import os
3
+
4
+ import torch
5
+
6
+ try:
7
+ import horovod.torch as hvd
8
+ except ImportError:
9
+ hvd = None
10
+
11
+
12
+ def is_global_master(args):
13
+ return args.rank == 0
14
+
15
+
16
+ def is_local_master(args):
17
+ return args.local_rank == 0
18
+
19
+
20
+ def is_master(args, local=False):
21
+ return is_local_master(args) if local else is_global_master(args)
22
+
23
+
24
+ def is_using_horovod():
25
+ # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
26
+ # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
27
+ ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
28
+ pmi_vars = ["PMI_RANK", "PMI_SIZE"]
29
+ if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]):
30
+ return True
31
+ else:
32
+ return False
33
+
34
+
35
+ def is_using_distributed():
36
+ if "WORLD_SIZE" in os.environ:
37
+ return int(os.environ["WORLD_SIZE"]) > 1
38
+ if "SLURM_NTASKS" in os.environ:
39
+ return int(os.environ["SLURM_NTASKS"]) > 1
40
+ return False
41
+
42
+
43
+ def world_info_from_env():
44
+ local_rank = 0
45
+ for v in (
46
+ "LOCAL_RANK",
47
+ "MPI_LOCALRANKID",
48
+ "SLURM_LOCALID",
49
+ "OMPI_COMM_WORLD_LOCAL_RANK",
50
+ ):
51
+ if v in os.environ:
52
+ local_rank = int(os.environ[v])
53
+ break
54
+ global_rank = 0
55
+ for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"):
56
+ if v in os.environ:
57
+ global_rank = int(os.environ[v])
58
+ break
59
+ world_size = 1
60
+ for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"):
61
+ if v in os.environ:
62
+ world_size = int(os.environ[v])
63
+ break
64
+
65
+ return local_rank, global_rank, world_size
66
+
67
+
68
+ def init_distributed_device(args):
69
+ # Distributed training = training on more than one GPU.
70
+ # Works in both single and multi-node scenarios.
71
+ args.distributed = False
72
+ args.world_size = 1
73
+ args.rank = 0 # global rank
74
+ args.local_rank = 0
75
+ if args.horovod:
76
+ assert hvd is not None, "Horovod is not installed"
77
+ hvd.init()
78
+ args.local_rank = int(hvd.local_rank())
79
+ args.rank = hvd.rank()
80
+ args.world_size = hvd.size()
81
+ args.distributed = True
82
+ os.environ["LOCAL_RANK"] = str(args.local_rank)
83
+ os.environ["RANK"] = str(args.rank)
84
+ os.environ["WORLD_SIZE"] = str(args.world_size)
85
+ elif is_using_distributed():
86
+ if "SLURM_PROCID" in os.environ:
87
+ # DDP via SLURM
88
+ args.local_rank, args.rank, args.world_size = world_info_from_env()
89
+ # SLURM var -> torch.distributed vars in case needed
90
+ os.environ["LOCAL_RANK"] = str(args.local_rank)
91
+ os.environ["RANK"] = str(args.rank)
92
+ os.environ["WORLD_SIZE"] = str(args.world_size)
93
+ torch.distributed.init_process_group(
94
+ backend=args.dist_backend,
95
+ init_method=args.dist_url,
96
+ world_size=args.world_size,
97
+ rank=args.rank,
98
+ )
99
+ else:
100
+ # DDP via torchrun, torch.distributed.launch
101
+ args.local_rank, _, _ = world_info_from_env()
102
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url)
103
+ args.world_size = torch.distributed.get_world_size()
104
+ args.rank = torch.distributed.get_rank()
105
+ args.distributed = True
106
+ else:
107
+ # needed to run on single gpu
108
+ torch.distributed.init_process_group(
109
+ backend=args.dist_backend,
110
+ init_method=args.dist_url,
111
+ world_size=1,
112
+ rank=0,
113
+ )
114
+
115
+ if torch.cuda.is_available():
116
+ if args.distributed and not args.no_set_device_rank:
117
+ device = "cuda:%d" % args.local_rank
118
+ else:
119
+ device = "cuda:0"
120
+ torch.cuda.set_device(device)
121
+ else:
122
+ device = "cpu"
123
+ args.device = device
124
+ device = torch.device(device)
125
+ return device
126
+
127
+
128
+ def is_rank0():
129
+ if not torch.distributed.is_initialized():
130
+ return True
131
+ return torch.distributed.get_rank() == 0
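For reference, `world_info_from_env` only reads environment variables, so it can be exercised without launching any processes (torchrun-style variables shown; the values are hypothetical):

import os
from mmgpt.train.distributed import world_info_from_env

os.environ.update({"LOCAL_RANK": "0", "RANK": "3", "WORLD_SIZE": "8"})
local_rank, global_rank, world_size = world_info_from_env()
print(local_rank, global_rank, world_size)  # 0 3 8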
mmgpt/train/instruction_finetune.py ADDED
@@ -0,0 +1,460 @@
1
+ """Modified from https://github.com/mlfoundations/open_flamingo"""
2
+
3
+ import argparse
4
+ import copy
5
+ import glob
6
+ import os
7
+ import random
8
+ import time
9
+
10
+ import numpy as np
11
+ import torch
12
+ import wandb
13
+ from mmengine import Config
14
+ from torch.nn.parallel import DistributedDataParallel as DDP
15
+ from torch.utils.data import DataLoader, DistributedSampler
16
+ from tqdm import tqdm
17
+ from transformers import (
18
+ get_constant_schedule_with_warmup,
19
+ get_cosine_schedule_with_warmup,
20
+ get_linear_schedule_with_warmup,
21
+ )
22
+
23
+ from mmgpt import create_model_and_transforms
24
+ from mmgpt.models.builder import create_toy_model_and_transforms
25
+ from mmgpt.datasets import InfiniteSampler, build_dataset
26
+ from mmgpt.train.distributed import init_distributed_device, world_info_from_env
27
+ from mmgpt.train.train_utils import AverageMeter, get_autocast, get_cast_dtype, get_checkpoint
28
+
29
+
30
+ def random_seed(seed=42, rank=0):
31
+ torch.manual_seed(seed + rank)
32
+ np.random.seed(seed + rank)
33
+ random.seed(seed + rank)
34
+
35
+
36
+ def main():
37
+ parser = argparse.ArgumentParser()
38
+ parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
39
+ parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
40
+ parser.add_argument("--lm_path", default="checkpoints/llama-7b_hf", type=str)
41
+ parser.add_argument(
42
+ "--tokenizer_path",
43
+ default="checkpoints/llama-7b_hf",
44
+ type=str,
45
+ help="path to tokenizer",
46
+ )
47
+ parser.add_argument(
48
+ "--pretrained_path",
49
+ default="checkpoints/OpenFlamingo-9B/checkpoint.pt",
50
+ type=str,
51
+ help="path to pretrained model",
52
+ )
53
+ parser.add_argument(
54
+ "--run_name",
55
+ type=str,
56
+ default="train-my-gpt4",
57
+ help="used to name saving directory and wandb run",
58
+ )
59
+ parser.add_argument("--use_media_placement_augmentation", action="store_true")
60
+ parser.add_argument("--offline", action="store_true")
61
+ parser.add_argument("--num_epochs", type=int, default=1)
62
+ parser.add_argument("--logging_steps", type=int, default=100, help="log loss every n steps")
63
+ # Sum of gradient optimization batch size
64
+ parser.add_argument(
65
+ "--resume_from_checkpoint",
66
+ type=str,
67
+ help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states",
68
+ default=None,
69
+ )
70
+ parser.add_argument(
71
+ "--delete_previous_checkpoint",
72
+ action="store_true",
73
+ help="delete previous checkpoint when saving new checkpoint",
74
+ )
75
+ parser.add_argument("--seed", type=int, default=42)
76
+ parser.add_argument("--learning_rate", default=1e-5, type=float)
77
+ parser.add_argument(
78
+ "--lr_scheduler",
79
+ default="constant",
80
+ type=str,
81
+ help="constant, linear, or cosine",
82
+ )
83
+ parser.add_argument("--warmup_steps", default=100, type=int)
84
+ parser.add_argument("--weight_decay", default=0.1, type=float)
85
+ parser.add_argument(
86
+ "--precision",
87
+ choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"],
88
+ default="amp",
89
+ help="Floating point precision.",
90
+ )
91
+ # data args
92
+ parser.add_argument("--workers", type=int, default=0)
93
+ parser.add_argument("--batch_size", type=int, default=1)
94
+ parser.add_argument("--dataset_config", type=str, default=None, help="path to dataset config file")
95
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=16)
96
+ # Finetune config
97
+ parser.add_argument("--tuning_config", type=str, default=None, help="path to tuning config file")
98
+ # distributed training args
99
+ parser.add_argument(
100
+ "--dist-url",
101
+ default="env://",
102
+ type=str,
103
+ help="url used to set up distributed training",
104
+ )
105
+ parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
106
+ parser.add_argument(
107
+ "--horovod",
108
+ default=False,
109
+ action="store_true",
110
+ help="Use horovod for distributed training.",
111
+ )
112
+ parser.add_argument(
113
+ "--no-set-device-rank",
114
+ default=False,
115
+ action="store_true",
116
+ help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
117
+ )
118
+ # wandb args
119
+ parser.add_argument("--report_to_wandb", default=False, action="store_true")
120
+ parser.add_argument(
121
+ "--wandb_project",
122
+ type=str,
123
+ )
124
+ parser.add_argument(
125
+ "--wandb_entity",
126
+ type=str,
127
+ )
128
+ parser.add_argument(
129
+ "--save_checkpoints_to_wandb",
130
+ default=False,
131
+ action="store_true",
132
+ help="save checkpoints to wandb",
133
+ )
134
+
135
+ args = parser.parse_args()
136
+
137
+ if args.save_checkpoints_to_wandb and not args.report_to_wandb:
138
+ raise ValueError("save_checkpoints_to_wandb requires report_to_wandb")
139
+
140
+ if args.offline:
141
+ os.environ["WANDB_MODE"] = "offline"
142
+ os.environ["TRANSFORMERS_OFFLINE"] = "1"
143
+
144
+ args.local_rank, args.rank, args.world_size = world_info_from_env()
145
+
146
+ if args.rank == 0:
147
+ if not os.path.exists(args.run_name):
148
+ os.makedirs(args.run_name)
149
+
150
+ device_id = init_distributed_device(args)
151
+
152
+ random_seed(args.seed)
153
+
154
+ if args.tuning_config is not None:
155
+ tuning_config = Config.fromfile(args.tuning_config)
156
+ else:
157
+ raise ValueError("tuning_config must be specified")
158
+
159
+ model, image_processor, tokenizer = create_model_and_transforms(
160
+ model_name="open_flamingo",
161
+ clip_vision_encoder_path=args.vision_encoder_path,
162
+ clip_vision_encoder_pretrained=args.vision_encoder_pretrained,
163
+ lang_encoder_path=args.lm_path,
164
+ tokenizer_path=args.tokenizer_path if args.tokenizer_path else args.lm_path,
165
+ use_media_placement_augmentation=args.use_media_placement_augmentation,
166
+ pretrained_model_path=args.pretrained_path,
167
+ tuning_config=tuning_config.tuning_config,
168
+ )
169
+
170
+ if args.dataset_config is not None:
171
+ dataset_config = Config.fromfile(args.dataset_config)
172
+ else:
173
+ raise ValueError("dataset_config must be specified")
174
+
175
+ dataset = build_dataset(
176
+ dataset_config=dataset_config.visual_datasets,
177
+ vis_processor=image_processor,
178
+ tokenizer=tokenizer,
179
+ )
180
+ train_dataloader = DataLoader(
181
+ dataset,
182
+ batch_size=args.batch_size,
183
+ num_workers=args.workers,
184
+ sampler=DistributedSampler(dataset, shuffle=True, drop_last=True),
185
+ collate_fn=dataset.collater,
186
+ )
187
+
188
+ # build language dataset and dataloader for multi-modality training
189
+ if dataset_config.get('language_datasets') is not None and len(dataset_config.language_datasets) > 0:
190
+ lang_dataset = build_dataset(
191
+ dataset_config=dataset_config.language_datasets,
192
+ tokenizer=tokenizer,
193
+ )
194
+ lang_dataloader = DataLoader(
195
+ lang_dataset,
196
+ batch_size=args.batch_size,
197
+ num_workers=args.workers,
198
+ sampler=InfiniteSampler(lang_dataset, shuffle=True),
199
+ collate_fn=lang_dataset.collater,
200
+ )
201
+ lang_dataloader = iter(lang_dataloader)
202
+ else:
203
+ lang_dataloader = None
204
+
205
+ random_seed(args.seed, args.rank)
206
+
207
+ print(f"Start running training on rank {args.rank}.")
208
+
209
+ if args.rank == 0 and args.report_to_wandb:
210
+ wandb.init(
211
+ project=args.wandb_project,
212
+ entity=args.wandb_entity,
213
+ name=args.run_name,
214
+ config=vars(args),
215
+ )
216
+
217
+ device_id = args.rank % torch.cuda.device_count()
218
+ model = model.to(device_id)
219
+
220
+ ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters=True)
221
+
222
+ def get_grouped_params(model):
223
+ params_with_wd, params_without_wd = [], []
224
+
225
+ def apply_decay(x):
226
+ return (
227
+ "gated_cross_attn_layer" in x
228
+ and "ff_gate" not in x
229
+ and "attn_gate" not in x
230
+ and "norm" not in x
231
+ and "bias" not in x
232
+ )
233
+
234
+ for n, p in model.named_parameters():
235
+ # if p.requires_grad:
236
+ if apply_decay(n):
237
+ params_with_wd.append(p)
238
+ else:
239
+ params_without_wd.append(p)
240
+
241
+ return [
242
+ {"params": params_with_wd, "weight_decay": args.weight_decay},
243
+ {"params": params_without_wd, "weight_decay": 0.0},
244
+ ]
245
+
246
+ optimizer = torch.optim.AdamW(get_grouped_params(ddp_model), lr=args.learning_rate)
247
+
248
+ total_training_steps = len(train_dataloader) * args.num_epochs
249
+
250
+ if args.rank == 0:
251
+ print(f"Total training steps: {total_training_steps}")
252
+
253
+ if args.lr_scheduler == "linear":
254
+ lr_scheduler = get_linear_schedule_with_warmup(
255
+ optimizer,
256
+ num_warmup_steps=args.warmup_steps,
257
+ num_training_steps=total_training_steps // args.gradient_accumulation_steps,
258
+ )
259
+ elif args.lr_scheduler == "cosine":
260
+ lr_scheduler = get_cosine_schedule_with_warmup(
261
+ optimizer,
262
+ num_warmup_steps=args.warmup_steps,
263
+ num_training_steps=total_training_steps // args.gradient_accumulation_steps,
264
+ )
265
+ else:
266
+ lr_scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps)
267
+
268
+ # check if a checkpoint exists for this run
269
+ if os.path.exists(f"{args.run_name}") and args.resume_from_checkpoint is None:
270
+ checkpoint_list = glob.glob(f"{args.run_name}/checkpoint_*.pt")
271
+ if len(checkpoint_list) == 0:
272
+ print(f"Found no checkpoints for run {args.run_name}.")
273
+ else:
274
+ args.resume_from_checkpoint = sorted(checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0]))[-1]
275
+ print(f"Found checkpoint {args.resume_from_checkpoint} for run {args.run_name}.")
276
+
277
+ resume_from_epoch = 0
278
+ if args.resume_from_checkpoint is not None:
279
+ if args.rank == 0:
280
+ print(f"Loading checkpoint from {args.resume_from_checkpoint}")
281
+ checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
282
+ ddp_model.load_state_dict(checkpoint["model_state_dict"], False)
283
+ optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
284
+ lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
285
+ resume_from_epoch = checkpoint["epoch"] + 1
286
+
287
+ ddp_model.train()
288
+
289
+ for epoch in range(resume_from_epoch, args.num_epochs):
290
+ train_dataloader.sampler.set_epoch(epoch)
291
+
292
+ train_one_epoch(
293
+ args=args,
294
+ model=ddp_model,
295
+ epoch=epoch,
296
+ tokenizer=tokenizer,
297
+ optimizer=optimizer,
298
+ lr_scheduler=lr_scheduler,
299
+ train_dataloader=train_dataloader,
300
+ language_dataloader=lang_dataloader,
301
+ device_id=device_id,
302
+ wandb=wandb,
303
+ )
304
+
305
+ if args.rank == 0:
306
+ if not os.path.exists(args.run_name):
307
+ os.makedirs(args.run_name)
308
+
309
+ checkpoint_dict = {
310
+ "epoch": epoch,
311
+ "model_state_dict": get_checkpoint(ddp_model),
312
+ "optimizer_state_dict": optimizer.state_dict(),
313
+ "lr_scheduler_state_dict": lr_scheduler.state_dict(),
314
+ "tuning_config": tuning_config,
315
+ }
316
+
317
+ print(f"Saving checkpoint to {args.run_name}/checkpoint_{epoch}.pt")
318
+ torch.save(checkpoint_dict, f"{args.run_name}/checkpoint_{epoch}.pt")
319
+ if args.report_to_wandb and args.save_checkpoints_to_wandb:
320
+ wandb.save(f"{args.run_name}/checkpoint_{epoch}.pt")
321
+
322
+ if args.delete_previous_checkpoint:
323
+ if epoch > 0:
324
+ os.remove(f"{args.run_name}/checkpoint_{epoch-1}.pt")
325
+ if args.rank == 0:
326
+ torch.save(
327
+ {"model_state_dict": get_checkpoint(ddp_model.module), "tuning_config": tuning_config},
328
+ f"{args.run_name}/final_weights.pt",
329
+ )
330
+ if args.report_to_wandb and args.save_checkpoints_to_wandb:
331
+ wandb.save(f"{args.run_name}/final_weights.pt")
332
+
333
+
334
+ def train_one_epoch(
335
+ args,
336
+ model,
337
+ epoch,
338
+ train_dataloader,
339
+ language_dataloader,
340
+ tokenizer,
341
+ optimizer,
342
+ lr_scheduler,
343
+ device_id,
344
+ wandb,
345
+ ):
346
+ num_batches_per_epoch = len(train_dataloader)
347
+
348
+ total_training_steps = num_batches_per_epoch * args.num_epochs
349
+
350
+ autocast = get_autocast(args.precision)
351
+ cast_dtype = get_cast_dtype(args.precision)
352
+
353
+ model.train()
354
+
355
+ # setup logging
356
+ step_time_m = AverageMeter() # time for one optimizer step (> 1 batch if using gradient accum)
357
+ data_time_m = (
358
+ AverageMeter()
359
+ ) # avg time to load one batch (= 1 batch regardless of gradient accum)
360
+ end = time.time()
361
+
362
+ # loop through dataloader
363
+ for num_steps, batch in tqdm(
364
+ enumerate(train_dataloader),
365
+ disable=args.rank != 0,
366
+ total=total_training_steps,
367
+ initial=(epoch * num_batches_per_epoch),
368
+ ):
369
+ data_time_m.update(time.time() - end)
370
+
371
+ global_step = num_steps + epoch * num_batches_per_epoch
372
+
373
+ #### VISION FORWARD PASS ####
374
+ images = batch["image"].to(device_id, dtype=cast_dtype, non_blocking=True).unsqueeze(1).unsqueeze(1)
375
+ input_ids = batch["input_ids"].to(device_id, dtype=cast_dtype, non_blocking=True)
376
+ attention_mask = batch["attention_mask"].to(device_id, dtype=cast_dtype, non_blocking=True)
377
+ labels = batch["labels"].to(device_id, dtype=cast_dtype, non_blocking=True)
378
+
379
+ with autocast():
380
+ loss_batch = model(
381
+ vision_x=images,
382
+ lang_x=input_ids,
383
+ attention_mask=attention_mask,
384
+ labels=labels,
385
+ )[0]
386
+ loss = loss_batch / args.gradient_accumulation_steps
387
+ loss_vision = loss # for logging
388
+
389
+ #### BACKWARD PASS ####
390
+ loss.backward()
391
+
392
+ #### LANGUAGE FORWARD PASS ####
393
+ if language_dataloader is not None:
394
+ batch_lang = next(language_dataloader)
395
+ lang_input_ids = batch_lang["input_ids"].to(device_id, dtype=cast_dtype, non_blocking=True)
396
+ lang_attention_mask = batch_lang["attention_mask"].to(device_id, dtype=cast_dtype, non_blocking=True)
397
+ lang_labels = batch_lang["labels"].to(device_id, dtype=cast_dtype, non_blocking=True)
398
+
399
+ with autocast():
400
+ lang_loss_batch = model(
401
+ vision_x=None,
402
+ lang_x=lang_input_ids,
403
+ attention_mask=lang_attention_mask,
404
+ labels=lang_labels,
405
+ )[0]
406
+ lang_loss = lang_loss_batch / args.gradient_accumulation_steps
407
+ #### BACKWARD PASS ####
408
+ lang_loss.backward()
409
+
410
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
411
+
412
+ # step optimizer and log
413
+ if (((num_steps + 1) % args.gradient_accumulation_steps) == 0) or (num_steps == num_batches_per_epoch - 1):
414
+ optimizer.step()
415
+ lr_scheduler.step()
416
+ optimizer.zero_grad()
417
+
418
+ # update step time and reset `end` on every rank, not just rank 0
419
+ step_time_m.update(time.time() - end)
420
+ end = time.time()
421
+
422
+ if args.rank == 0 and args.report_to_wandb:
423
+ # compute within rank 0
424
+ samples_per_second = (
425
+ args.gradient_accumulation_steps * args.batch_size * args.world_size / step_time_m.val
426
+ )
427
+ samples_per_second_per_gpu = args.gradient_accumulation_steps * args.batch_size / step_time_m.val
428
+
429
+ wandb.log(
430
+ {
431
+ "data_time": data_time_m.avg,
432
+ "step_time": step_time_m.avg,
433
+ "samples_per_second": samples_per_second,
434
+ "samples_per_second_per_gpu": samples_per_second_per_gpu,
435
+ "lr": optimizer.param_groups[0]["lr"],
436
+ },
437
+ commit=False,
438
+ )
439
+ step_time_m.reset()
440
+ data_time_m.reset()
441
+
442
+ loss_log = {
443
+ "loss": loss.item(),
444
+ "loss_vision": loss_vision.item(),
445
+ "global_step": global_step,
446
+ }
447
+ if language_dataloader is not None:
448
+ loss_log["loss_lang"] = lang_loss.item()
449
+
450
+ wandb.log(loss_log, commit=True)
451
+
452
+ # Log loss to console
453
+ if ((num_steps + 1) % args.logging_steps == 0) and args.rank == 0:
454
+ print(
455
+ f"Step {num_steps+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete. Loss: {loss.item():.3f}"
456
+ )
457
+
458
+
459
+ if __name__ == "__main__":
460
+ main()
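The auto-resume logic above picks the newest checkpoint by the integer embedded in its filename rather than lexicographically. A tiny sketch with hypothetical paths:

checkpoint_list = ["run/checkpoint_2.pt", "run/checkpoint_10.pt", "run/checkpoint_1.pt"]
latest = sorted(checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0]))[-1]
print(latest)  # run/checkpoint_10.pt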
mmgpt/train/train_utils.py ADDED
@@ -0,0 +1,251 @@
1
+ """Modified from https://github.com/mlfoundations/open_flamingo"""
2
+ import time
3
+ from contextlib import suppress
4
+
5
+ import torch
6
+ from tqdm import tqdm
7
+
8
+
9
+ def get_cast_dtype(precision: str):
10
+ cast_dtype = None
11
+ if precision == "bf16":
12
+ cast_dtype = torch.bfloat16
13
+ elif precision == "fp16":
14
+ cast_dtype = torch.float16
15
+ return cast_dtype
16
+
17
+
18
+ def get_autocast(precision):
19
+ if precision == "amp":
20
+ return torch.cuda.amp.autocast
21
+ elif precision == "amp_bfloat16" or precision == "amp_bf16":
22
+ # amp_bfloat16 is more stable than amp float16 for clip training
23
+ return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
24
+ else:
25
+ return suppress
26
+
27
+
28
+ def train_one_epoch(
29
+ args,
30
+ model,
31
+ epoch,
32
+ laion_loader,
33
+ mmc4_loader,
34
+ tokenizer,
35
+ optimizer,
36
+ lr_scheduler,
37
+ device_id,
38
+ wandb,
39
+ ):
40
+ num_batches_per_epoch_laion = laion_loader.num_batches
41
+ num_batches_per_epoch_mmc4 = mmc4_loader.num_batches
42
+
43
+ assert (
44
+ num_batches_per_epoch_laion == num_batches_per_epoch_mmc4
45
+ ), "Number of batches in laion and mmc4 datasets must be the same"
46
+ num_batches_per_epoch = num_batches_per_epoch_mmc4
47
+ total_training_steps = num_batches_per_epoch * args.num_epochs
48
+
49
+ autocast = get_autocast(args.precision)
50
+ cast_dtype = get_cast_dtype(args.precision)
51
+
52
+ media_token_id = tokenizer("<image>", add_special_tokens=False)["input_ids"][-1]
53
+ endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)["input_ids"][-1]
54
+
55
+ model.train()
56
+
57
+ # setup logging
58
+ step_time_m = AverageMeter() # time for one optimizer step (> 1 batch if using gradient accum)
59
+ data_time_m = (
60
+ AverageMeter()
61
+ ) # avg time to load one batch of both C4 AND laion (= 1 batch regardless of gradient accum)
62
+ end = time.time()
63
+
64
+ # loop through dataloader
65
+ for num_steps, (batch_laion, batch_mmc4) in tqdm(
66
+ enumerate(zip(laion_loader, mmc4_loader)),
67
+ disable=args.rank != 0,
68
+ total=total_training_steps,
69
+ initial=(epoch * num_batches_per_epoch),
70
+ ):
71
+ data_time_m.update(time.time() - end)
72
+
73
+ global_step = num_steps + epoch * num_batches_per_epoch
74
+
75
+ #### LAION FORWARD PASS ####
76
+ images = batch_laion[0].to(device_id, dtype=cast_dtype, non_blocking=True).unsqueeze(1).unsqueeze(1)
77
+
78
+ input_ids = batch_laion[1][0].to(device_id, dtype=cast_dtype, non_blocking=True)
79
+ attention_mask = batch_laion[1][1].to(device_id, dtype=cast_dtype, non_blocking=True)
80
+
81
+ labels = input_ids.clone()
82
+ labels[labels == tokenizer.pad_token_id] = -100
83
+ labels[:, 0] = -100
84
+ labels[labels == media_token_id] = -100
85
+ labels = labels.to(device_id)  # .to() is not in-place; reassign so labels actually end up on the device
86
+
87
+ with autocast():
88
+ loss_laion = model(
89
+ vision_x=images,
90
+ lang_x=input_ids,
91
+ attention_mask=attention_mask,
92
+ labels=labels,
93
+ )[0]
94
+ divided_loss_laion = loss_laion / args.gradient_accumulation_steps
95
+
96
+ #### C4 FORWARD PASS ####
97
+ images = batch_mmc4[0].to(device_id, dtype=cast_dtype, non_blocking=True).unsqueeze(2)
98
+ input_ids = torch.stack([x[0] for x in batch_mmc4[1]]).squeeze(1)
99
+ attention_mask = torch.stack([x[1] for x in batch_mmc4[1]]).squeeze(1)
100
+
101
+ # NOTE: irena: expected shape of clip_text_input_ids / attention_mask is (N, I, max_seq_len)
102
+ labels = input_ids.clone()
103
+ labels[labels == tokenizer.pad_token_id] = -100
104
+ labels[:, 0] = -100
105
+
106
+ for i in range(labels.shape[0]):
107
+ # remove loss for any token before the first <image> token
108
+ label_idx = 0
109
+ while label_idx < labels.shape[1] and labels[i][label_idx] != media_token_id:
110
+ labels[i][label_idx] = -100
111
+ label_idx += 1
112
+
113
+ # get index of all endofchunk tokens in the sequence
114
+ endofchunk_idxs = torch.where(labels[i] == endofchunk_token_id)[0]
115
+ for endofchunk_idx in endofchunk_idxs:
116
+ token_idx = endofchunk_idx + 1
117
+ while token_idx < labels.shape[1] and labels[i][token_idx] != media_token_id:
118
+ labels[i][token_idx] = -100
119
+ token_idx += 1
120
+
121
+ labels[labels == media_token_id] = -100
122
+ labels.to(device_id)
123
+
124
+ with autocast():
125
+ loss_mmc4 = model(
126
+ vision_x=images,
127
+ lang_x=input_ids,
128
+ attention_mask=attention_mask,
129
+ labels=labels,
130
+ )[0]
131
+
132
+ # if loss is nan, skip this batch
133
+ if torch.isnan(loss_mmc4):
134
+ print("loss is nan, skipping this batch")
135
+ print("input_ids: ", tokenizer.batch_decode(input_ids))
136
+ print("labels: ", labels)
137
+ print("images: ", images)
138
+ optimizer.zero_grad()
139
+ continue
140
+
141
+ divided_loss_mmc4 = loss_mmc4 / args.gradient_accumulation_steps
142
+
143
+ #### BACKWARD PASS ####
144
+ loss = divided_loss_laion * args.loss_multiplier_laion + divided_loss_mmc4 * args.loss_multiplier_mmc4
145
+ loss.backward()
146
+
147
+ #### MASK GRADIENTS FOR EMBEDDINGS ####
148
+ # Note (anas): Do not apply weight decay to embeddings as it will break this function.
149
+ def mask_embedding(m):
150
+ if isinstance(m, torch.nn.Embedding) and m.weight.requires_grad:
151
+ zero_mask = torch.zeros_like(m.weight.grad)
152
+ zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
153
+ zero_mask[endofchunk_token_id] = torch.ones_like(zero_mask[endofchunk_token_id])
154
+ m.weight.grad = m.weight.grad * zero_mask
155
+
156
+ model.apply(mask_embedding)
157
+
158
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
159
+
160
+ # step optimizer and log
161
+ if (((num_steps + 1) % args.gradient_accumulation_steps) == 0) or (num_steps == num_batches_per_epoch - 1):
162
+ optimizer.step()
163
+ lr_scheduler.step()
164
+ optimizer.zero_grad()
165
+
166
+ # step time and reset end outside of rank 0
167
+ step_time_m.update(time.time() - end)
168
+ end = time.time()
169
+
170
+ if args.rank == 0 and args.report_to_wandb:
171
+ # compute within rank 0
172
+ laion_samples_per_second = (
173
+ args.gradient_accumulation_steps * args.batch_size_laion * args.world_size / step_time_m.val
174
+ )
175
+ laion_samples_per_second_per_gpu = (
176
+ args.gradient_accumulation_steps * args.batch_size_laion / step_time_m.val
177
+ )
178
+
179
+ c4_samples_per_second = (
180
+ args.gradient_accumulation_steps * args.batch_size_mmc4 * args.world_size / step_time_m.val
181
+ )
182
+ c4_samples_per_second_per_gpu = (
183
+ args.gradient_accumulation_steps * args.batch_size_mmc4 / step_time_m.val
184
+ )
185
+
186
+ wandb.log(
187
+ {
188
+ "data_time": data_time_m.avg,
189
+ "step_time": step_time_m.avg,
190
+ "laion_samples_per_second": laion_samples_per_second,
191
+ "laion_samples_per_second_per_gpu": laion_samples_per_second_per_gpu,
192
+ "c4_samples_per_second": c4_samples_per_second,
193
+ "c4_samples_per_second_per_gpu": c4_samples_per_second_per_gpu,
194
+ "lr": optimizer.param_groups[0]["lr"],
195
+ },
196
+ commit=False,
197
+ )
198
+ step_time_m.reset()
199
+ data_time_m.reset()
200
+
201
+ wandb.log(
202
+ {
203
+ "loss_laion": divided_loss_laion.item(),
204
+ "global_step": global_step,
205
+ },
206
+ commit=False,
207
+ )
208
+ wandb.log(
209
+ {"loss_mmc4": divided_loss_mmc4.item(), "global_step": global_step},
210
+ commit=True,
211
+ )
212
+
213
+ # Log loss to console
214
+ if ((num_steps + 1) % args.logging_steps == 0) and args.rank == 0:
215
+ print(
216
+ f"Step {num_steps+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete. Loss LAION: {loss_laion.item():.3f} // Loss MMC4: {loss_mmc4.item():.3f}"
217
+ )
218
+
219
+
220
+ def get_checkpoint(model: torch.nn.Module):
221
+ state_dict = model.state_dict()
222
+ parameters = {k: v for k, v in model.named_parameters()}
223
+ # remove duplicate parameters
224
+ duplicate_keys = set(state_dict.keys()) - set(parameters.keys())
225
+ for k in duplicate_keys:
226
+ del state_dict[k]
227
+ # remove non-grad parameters
228
+ for name, p in parameters.items():
229
+ if not p.requires_grad:
230
+ del state_dict[name]
231
+
232
+ return state_dict
233
+
234
+
235
+ class AverageMeter(object):
236
+ """Computes and stores the average and current value"""
237
+
238
+ def __init__(self):
239
+ self.reset()
240
+
241
+ def reset(self):
242
+ self.val = 0
243
+ self.avg = 0
244
+ self.sum = 0
245
+ self.count = 0
246
+
247
+ def update(self, val, n=1):
248
+ self.val = val
249
+ self.sum += val * n
250
+ self.count += n
251
+ self.avg = self.sum / self.count
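
A brief usage note on the helpers above: `get_autocast` returns a context-manager factory, `get_cast_dtype` returns an explicit dtype only for the pure `bf16`/`fp16` modes (and `None` for the amp modes), and `AverageMeter` tracks per-step timings. The sketch below shows how they are typically combined around a single forward/backward pass; it is illustrative only and not part of this commit. The import path `mmgpt.train.train_utils`, the toy `torch.nn.Linear` model, and the availability of a CUDA device are assumptions.

```python
# Minimal usage sketch (illustrative only, not part of this commit).
# Assumes the file above is importable as mmgpt.train.train_utils and that a
# CUDA device is available, since get_autocast wraps torch.cuda.amp.autocast.
import time

import torch

from mmgpt.train.train_utils import AverageMeter, get_autocast, get_cast_dtype

precision = "amp_bf16"
autocast = get_autocast(precision)      # factory for the autocast context manager
cast_dtype = get_cast_dtype(precision)  # None for the amp modes, so inputs stay fp32

model = torch.nn.Linear(16, 2).cuda()   # hypothetical stand-in for the Flamingo model
x = torch.randn(4, 16, device="cuda")
if cast_dtype is not None:              # only "bf16" / "fp16" trigger an explicit cast
    x = x.to(cast_dtype)

step_time_m = AverageMeter()
end = time.time()
with autocast():                        # mixed-precision forward pass
    loss = model(x).sum()
loss.backward()
step_time_m.update(time.time() - end)
print(f"step time: last={step_time_m.val:.4f}s avg={step_time_m.avg:.4f}s")

# get_checkpoint(model) from the same module keeps only the trainable,
# de-duplicated weights, e.g. torch.save(get_checkpoint(model), "ckpt.pt").
```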
requirements.txt ADDED
@@ -0,0 +1,20 @@
+ einops
+ einops-exts
+ transformers
+ peft
+ bigmodelvis
+ torch
+ torchvision
+ pillow
+ more-itertools
+ datasets
+ braceexpand
+ webdataset
+ wandb
+ nltk
+ scipy
+ inflection
+ sentencepiece
+ open_clip_torch
+ mmengine
+ gradio
setup.py ADDED
@@ -0,0 +1,50 @@
+ from pathlib import Path
+
+ from setuptools import find_packages, setup
+
+ if __name__ == "__main__":
+     with Path(Path(__file__).parent, "README.md").open(encoding="utf-8") as file:
+         long_description = file.read()
+
+     # TODO: This is a hack to get around the fact that we can't read the requirements.txt file, we should fix this.
+     # def _read_reqs(relpath):
+     #     fullpath = os.path.join(Path(__file__).parent, relpath)
+     #     with open(fullpath) as f:
+     #         return [
+     #             s.strip()
+     #             for s in f.readlines()
+     #             if (s.strip() and not s.startswith("#"))
+     #         ]
+
+     REQUIREMENTS = [
+         "einops",
+         "einops-exts",
+         "transformers",
+         "torch",
+         "torchvision",
+         "pillow",
+         "more-itertools",
+         "datasets",
+         "braceexpand",
+         "webdataset",
+         "wandb",
+         "nltk",
+         "scipy",
+         "inflection",
+         "sentencepiece",
+         "open_clip_torch",
+     ]
+
+     setup(
+         name="mmgpt",
+         packages=find_packages(),
+         include_package_data=True,
+         version="0.0.1",
+         license="Apache 2.0",
+         description="An open-source framework for multi-modality instruction fine-tuning",
+         long_description=long_description,
+         long_description_content_type="text/markdown",
+         data_files=[(".", ["README.md"])],
+         keywords=["machine learning"],
+         install_requires=REQUIREMENTS,
+     )
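
On the TODO above: one possible shape of the commented-out `_read_reqs` helper is sketched below, which would let `REQUIREMENTS` be derived from `requirements.txt` instead of being duplicated by hand. This is a suggestion only, not part of the commit, and it assumes `requirements.txt` is actually present next to `setup.py` at build time (which is the caveat the TODO points at).

```python
# Possible implementation of the commented-out _read_reqs helper (a sketch,
# not part of this commit). It assumes requirements.txt ships alongside
# setup.py when the package is built.
from pathlib import Path


def _read_reqs(relpath: str) -> list:
    """Return the non-empty, non-comment lines of a requirements file."""
    fullpath = Path(__file__).parent / relpath
    with open(fullpath, encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip() and not line.strip().startswith("#")]


# REQUIREMENTS = _read_reqs("requirements.txt")
```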