Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse files- .gitignore +148 -0
- .gradio/certificate.pem +31 -0
- LICENSE +204 -0
- README.md +196 -12
- README_zh-CN.md +184 -0
- app.py +379 -0
- configs/dataset_config.py +64 -0
- configs/lora_config.py +9 -0
- docs/images/demo_image.jpg +0 -0
- environment.yml +10 -0
- mmgpt/__init__.py +2 -0
- mmgpt/datasets/__init__.py +4 -0
- mmgpt/datasets/alpaca_gpt4_dataset.py +26 -0
- mmgpt/datasets/aokvqa_dataset.py +51 -0
- mmgpt/datasets/baize_dataset.py +86 -0
- mmgpt/datasets/builder.py +126 -0
- mmgpt/datasets/cc_sbu_align_dataset.py +107 -0
- mmgpt/datasets/clevr_dataset.py +74 -0
- mmgpt/datasets/coco_caption_dataset.py +119 -0
- mmgpt/datasets/dial_dataset.py +83 -0
- mmgpt/datasets/dolly_dataset.py +150 -0
- mmgpt/datasets/gqa_dataset.py +83 -0
- mmgpt/datasets/llava_dataset.py +18 -0
- mmgpt/datasets/nlvr_dataset.py +212 -0
- mmgpt/datasets/ocr_vqa_dataset.py +23 -0
- mmgpt/datasets/samplers/__init__.py +1 -0
- mmgpt/datasets/samplers/infinite_sampler.py +30 -0
- mmgpt/datasets/snli_ve_datasets.py +82 -0
- mmgpt/datasets/text_ocr_dataset.py +64 -0
- mmgpt/datasets/vqa_dataset.py +227 -0
- mmgpt/models/__init__.py +0 -0
- mmgpt/models/blip2/__init__.py +0 -0
- mmgpt/models/builder.py +74 -0
- mmgpt/models/open_flamingo/__init__.py +3 -0
- mmgpt/models/open_flamingo/builder.py +142 -0
- mmgpt/models/open_flamingo/flamingo.py +208 -0
- mmgpt/models/open_flamingo/flamingo_lm.py +131 -0
- mmgpt/models/open_flamingo/helpers.py +263 -0
- mmgpt/models/open_flamingo/utils.py +31 -0
- mmgpt/train/__init__.py +1 -0
- mmgpt/train/distributed.py +131 -0
- mmgpt/train/instruction_finetune.py +460 -0
- mmgpt/train/train_utils.py +251 -0
- requirements.txt +20 -0
- setup.py +50 -0
.gitignore
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.pt
|
2 |
+
|
3 |
+
wandb/
|
4 |
+
|
5 |
+
checkpoints/
|
6 |
+
tests/
|
7 |
+
|
8 |
+
# Byte-compiled / optimized / DLL files
|
9 |
+
__pycache__/
|
10 |
+
*.py[cod]
|
11 |
+
*$py.class
|
12 |
+
|
13 |
+
# C extensions
|
14 |
+
*.so
|
15 |
+
|
16 |
+
# Distribution / packaging
|
17 |
+
.Python
|
18 |
+
build/
|
19 |
+
develop-eggs/
|
20 |
+
dist/
|
21 |
+
downloads/
|
22 |
+
eggs/
|
23 |
+
.eggs/
|
24 |
+
lib/
|
25 |
+
lib64/
|
26 |
+
parts/
|
27 |
+
sdist/
|
28 |
+
var/
|
29 |
+
wheels/
|
30 |
+
pip-wheel-metadata/
|
31 |
+
share/python-wheels/
|
32 |
+
*.egg-info/
|
33 |
+
.installed.cfg
|
34 |
+
*.egg
|
35 |
+
MANIFEST
|
36 |
+
|
37 |
+
# PyInstaller
|
38 |
+
# Usually these files are written by a python script from a template
|
39 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
40 |
+
*.manifest
|
41 |
+
*.spec
|
42 |
+
|
43 |
+
# Installer logs
|
44 |
+
pip-log.txt
|
45 |
+
pip-delete-this-directory.txt
|
46 |
+
|
47 |
+
# Unit test / coverage reports
|
48 |
+
htmlcov/
|
49 |
+
.tox/
|
50 |
+
.nox/
|
51 |
+
.coverage
|
52 |
+
.coverage.*
|
53 |
+
.cache
|
54 |
+
nosetests.xml
|
55 |
+
coverage.xml
|
56 |
+
*.cover
|
57 |
+
*.py,cover
|
58 |
+
.hypothesis/
|
59 |
+
.pytest_cache/
|
60 |
+
|
61 |
+
# Translations
|
62 |
+
*.mo
|
63 |
+
*.pot
|
64 |
+
|
65 |
+
# Django stuff:
|
66 |
+
*.log
|
67 |
+
local_settings.py
|
68 |
+
db.sqlite3
|
69 |
+
db.sqlite3-journal
|
70 |
+
|
71 |
+
# Flask stuff:
|
72 |
+
instance/
|
73 |
+
.webassets-cache
|
74 |
+
|
75 |
+
# Scrapy stuff:
|
76 |
+
.scrapy
|
77 |
+
|
78 |
+
# Sphinx documentation
|
79 |
+
docs/_build/
|
80 |
+
|
81 |
+
# PyBuilder
|
82 |
+
target/
|
83 |
+
|
84 |
+
# Jupyter Notebook
|
85 |
+
.ipynb_checkpoints
|
86 |
+
|
87 |
+
# IPython
|
88 |
+
profile_default/
|
89 |
+
ipython_config.py
|
90 |
+
|
91 |
+
# pyenv
|
92 |
+
.python-version
|
93 |
+
|
94 |
+
# pipenv
|
95 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
96 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
97 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
98 |
+
# install all needed dependencies.
|
99 |
+
#Pipfile.lock
|
100 |
+
|
101 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
102 |
+
__pypackages__/
|
103 |
+
|
104 |
+
# Celery stuff
|
105 |
+
celerybeat-schedule
|
106 |
+
celerybeat.pid
|
107 |
+
|
108 |
+
# SageMath parsed files
|
109 |
+
*.sage.py
|
110 |
+
|
111 |
+
# Environments
|
112 |
+
.env
|
113 |
+
.venv
|
114 |
+
env/
|
115 |
+
venv/
|
116 |
+
ENV/
|
117 |
+
env.bak/
|
118 |
+
venv.bak/
|
119 |
+
|
120 |
+
# Pycharm project settings
|
121 |
+
.idea
|
122 |
+
|
123 |
+
# Spyder project settings
|
124 |
+
.spyderproject
|
125 |
+
.spyproject
|
126 |
+
|
127 |
+
# Rope project settings
|
128 |
+
.ropeproject
|
129 |
+
|
130 |
+
# mkdocs documentation
|
131 |
+
/site
|
132 |
+
|
133 |
+
# mypy
|
134 |
+
.mypy_cache/
|
135 |
+
.dmypy.json
|
136 |
+
dmypy.json
|
137 |
+
|
138 |
+
*.out
|
139 |
+
src/wandb
|
140 |
+
wandb
|
141 |
+
|
142 |
+
# Pyre type checker
|
143 |
+
.pyre/
|
144 |
+
|
145 |
+
# Training
|
146 |
+
batchscript*
|
147 |
+
work_dirs
|
148 |
+
data
|
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
-----BEGIN CERTIFICATE-----
|
2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
31 |
+
-----END CERTIFICATE-----
|
LICENSE
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright 2018-2023 OpenMMLab. All rights reserved.
|
2 |
+
|
3 |
+
Apache License
|
4 |
+
Version 2.0, January 2004
|
5 |
+
http://www.apache.org/licenses/
|
6 |
+
|
7 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
8 |
+
|
9 |
+
1. Definitions.
|
10 |
+
|
11 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
12 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
13 |
+
|
14 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
15 |
+
the copyright owner that is granting the License.
|
16 |
+
|
17 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
18 |
+
other entities that control, are controlled by, or are under common
|
19 |
+
control with that entity. For the purposes of this definition,
|
20 |
+
"control" means (i) the power, direct or indirect, to cause the
|
21 |
+
direction or management of such entity, whether by contract or
|
22 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
23 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
24 |
+
|
25 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
26 |
+
exercising permissions granted by this License.
|
27 |
+
|
28 |
+
"Source" form shall mean the preferred form for making modifications,
|
29 |
+
including but not limited to software source code, documentation
|
30 |
+
source, and configuration files.
|
31 |
+
|
32 |
+
"Object" form shall mean any form resulting from mechanical
|
33 |
+
transformation or translation of a Source form, including but
|
34 |
+
not limited to compiled object code, generated documentation,
|
35 |
+
and conversions to other media types.
|
36 |
+
|
37 |
+
"Work" shall mean the work of authorship, whether in Source or
|
38 |
+
Object form, made available under the License, as indicated by a
|
39 |
+
copyright notice that is included in or attached to the work
|
40 |
+
(an example is provided in the Appendix below).
|
41 |
+
|
42 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
43 |
+
form, that is based on (or derived from) the Work and for which the
|
44 |
+
editorial revisions, annotations, elaborations, or other modifications
|
45 |
+
represent, as a whole, an original work of authorship. For the purposes
|
46 |
+
of this License, Derivative Works shall not include works that remain
|
47 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
48 |
+
the Work and Derivative Works thereof.
|
49 |
+
|
50 |
+
"Contribution" shall mean any work of authorship, including
|
51 |
+
the original version of the Work and any modifications or additions
|
52 |
+
to that Work or Derivative Works thereof, that is intentionally
|
53 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
54 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
55 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
56 |
+
means any form of electronic, verbal, or written communication sent
|
57 |
+
to the Licensor or its representatives, including but not limited to
|
58 |
+
communication on electronic mailing lists, source code control systems,
|
59 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
60 |
+
Licensor for the purpose of discussing and improving the Work, but
|
61 |
+
excluding communication that is conspicuously marked or otherwise
|
62 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
63 |
+
|
64 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
65 |
+
on behalf of whom a Contribution has been received by Licensor and
|
66 |
+
subsequently incorporated within the Work.
|
67 |
+
|
68 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
69 |
+
this License, each Contributor hereby grants to You a perpetual,
|
70 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
71 |
+
copyright license to reproduce, prepare Derivative Works of,
|
72 |
+
publicly display, publicly perform, sublicense, and distribute the
|
73 |
+
Work and such Derivative Works in Source or Object form.
|
74 |
+
|
75 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
76 |
+
this License, each Contributor hereby grants to You a perpetual,
|
77 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
78 |
+
(except as stated in this section) patent license to make, have made,
|
79 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
80 |
+
where such license applies only to those patent claims licensable
|
81 |
+
by such Contributor that are necessarily infringed by their
|
82 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
83 |
+
with the Work to which such Contribution(s) was submitted. If You
|
84 |
+
institute patent litigation against any entity (including a
|
85 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
86 |
+
or a Contribution incorporated within the Work constitutes direct
|
87 |
+
or contributory patent infringement, then any patent licenses
|
88 |
+
granted to You under this License for that Work shall terminate
|
89 |
+
as of the date such litigation is filed.
|
90 |
+
|
91 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
92 |
+
Work or Derivative Works thereof in any medium, with or without
|
93 |
+
modifications, and in Source or Object form, provided that You
|
94 |
+
meet the following conditions:
|
95 |
+
|
96 |
+
(a) You must give any other recipients of the Work or
|
97 |
+
Derivative Works a copy of this License; and
|
98 |
+
|
99 |
+
(b) You must cause any modified files to carry prominent notices
|
100 |
+
stating that You changed the files; and
|
101 |
+
|
102 |
+
(c) You must retain, in the Source form of any Derivative Works
|
103 |
+
that You distribute, all copyright, patent, trademark, and
|
104 |
+
attribution notices from the Source form of the Work,
|
105 |
+
excluding those notices that do not pertain to any part of
|
106 |
+
the Derivative Works; and
|
107 |
+
|
108 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
109 |
+
distribution, then any Derivative Works that You distribute must
|
110 |
+
include a readable copy of the attribution notices contained
|
111 |
+
within such NOTICE file, excluding those notices that do not
|
112 |
+
pertain to any part of the Derivative Works, in at least one
|
113 |
+
of the following places: within a NOTICE text file distributed
|
114 |
+
as part of the Derivative Works; within the Source form or
|
115 |
+
documentation, if provided along with the Derivative Works; or,
|
116 |
+
within a display generated by the Derivative Works, if and
|
117 |
+
wherever such third-party notices normally appear. The contents
|
118 |
+
of the NOTICE file are for informational purposes only and
|
119 |
+
do not modify the License. You may add Your own attribution
|
120 |
+
notices within Derivative Works that You distribute, alongside
|
121 |
+
or as an addendum to the NOTICE text from the Work, provided
|
122 |
+
that such additional attribution notices cannot be construed
|
123 |
+
as modifying the License.
|
124 |
+
|
125 |
+
You may add Your own copyright statement to Your modifications and
|
126 |
+
may provide additional or different license terms and conditions
|
127 |
+
for use, reproduction, or distribution of Your modifications, or
|
128 |
+
for any such Derivative Works as a whole, provided Your use,
|
129 |
+
reproduction, and distribution of the Work otherwise complies with
|
130 |
+
the conditions stated in this License.
|
131 |
+
|
132 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
133 |
+
any Contribution intentionally submitted for inclusion in the Work
|
134 |
+
by You to the Licensor shall be under the terms and conditions of
|
135 |
+
this License, without any additional terms or conditions.
|
136 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
137 |
+
the terms of any separate license agreement you may have executed
|
138 |
+
with Licensor regarding such Contributions.
|
139 |
+
|
140 |
+
6. Trademarks. This License does not grant permission to use the trade
|
141 |
+
names, trademarks, service marks, or product names of the Licensor,
|
142 |
+
except as required for reasonable and customary use in describing the
|
143 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
144 |
+
|
145 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
146 |
+
agreed to in writing, Licensor provides the Work (and each
|
147 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
148 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
149 |
+
implied, including, without limitation, any warranties or conditions
|
150 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
151 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
152 |
+
appropriateness of using or redistributing the Work and assume any
|
153 |
+
risks associated with Your exercise of permissions under this License.
|
154 |
+
|
155 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
156 |
+
whether in tort (including negligence), contract, or otherwise,
|
157 |
+
unless required by applicable law (such as deliberate and grossly
|
158 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
159 |
+
liable to You for damages, including any direct, indirect, special,
|
160 |
+
incidental, or consequential damages of any character arising as a
|
161 |
+
result of this License or out of the use or inability to use the
|
162 |
+
Work (including but not limited to damages for loss of goodwill,
|
163 |
+
work stoppage, computer failure or malfunction, or any and all
|
164 |
+
other commercial damages or losses), even if such Contributor
|
165 |
+
has been advised of the possibility of such damages.
|
166 |
+
|
167 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
168 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
169 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
170 |
+
or other liability obligations and/or rights consistent with this
|
171 |
+
License. However, in accepting such obligations, You may act only
|
172 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
173 |
+
of any other Contributor, and only if You agree to indemnify,
|
174 |
+
defend, and hold each Contributor harmless for any liability
|
175 |
+
incurred by, or claims asserted against, such Contributor by reason
|
176 |
+
of your accepting any such warranty or additional liability.
|
177 |
+
|
178 |
+
END OF TERMS AND CONDITIONS
|
179 |
+
|
180 |
+
APPENDIX: How to apply the Apache License to your work.
|
181 |
+
|
182 |
+
To apply the Apache License to your work, attach the following
|
183 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
184 |
+
replaced with your own identifying information. (Don't include
|
185 |
+
the brackets!) The text should be enclosed in the appropriate
|
186 |
+
comment syntax for the file format. We also recommend that a
|
187 |
+
file or class name and description of purpose be included on the
|
188 |
+
same "printed page" as the copyright notice for easier
|
189 |
+
identification within third-party archives.
|
190 |
+
|
191 |
+
Copyright 2018-2023 OpenMMLab.
|
192 |
+
|
193 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
194 |
+
you may not use this file except in compliance with the License.
|
195 |
+
You may obtain a copy of the License at
|
196 |
+
|
197 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
198 |
+
|
199 |
+
Unless required by applicable law or agreed to in writing, software
|
200 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
201 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
202 |
+
See the License for the specific language governing permissions and
|
203 |
+
limitations under the License.
|
204 |
+
|
README.md
CHANGED
@@ -1,12 +1,196 @@
|
|
1 |
-
---
|
2 |
-
title: Multimodal
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Multimodal-GPT
|
3 |
+
app_file: app.py
|
4 |
+
sdk: gradio
|
5 |
+
sdk_version: 5.21.0
|
6 |
+
---
|
7 |
+
# 🤖 Multi-modal GPT
|
8 |
+
|
9 |
+
Train a multi-modal chatbot with visual and language instructions!
|
10 |
+
|
11 |
+
Based on the open-source multi-modal model [OpenFlamingo](https://github.com/mlfoundations/open_flamingo), we create various **visual instruction** data with open datasets, including VQA, Image Captioning, Visual Reasoning, Text OCR, and Visual Dialogue. Additionally, we also train the language model component of OpenFlamingo using only **language-only instruction** data.
|
12 |
+
|
13 |
+
The **joint training** of visual and language instructions effectively improves the performance of the model! For more details please refer to our [technical report](https://arxiv.org/abs/2305.04790).
|
14 |
+
|
15 |
+
Welcome to join us!
|
16 |
+
|
17 |
+
</div>
|
18 |
+
|
19 |
+
<div align="center">
|
20 |
+
|
21 |
+
English | [简体中文](README_zh-CN.md)
|
22 |
+
|
23 |
+
</div>
|
24 |
+
|
25 |
+
<div align="center">
|
26 |
+
<a href="https://openmmlab.medium.com/" style="text-decoration:none;">
|
27 |
+
<img src="https://user-images.githubusercontent.com/25839884/219255827-67c1a27f-f8c5-46a9-811d-5e57448c61d1.png" width="3%" alt="" /></a>
|
28 |
+
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
29 |
+
<a href="https://discord.com/channels/1037617289144569886/1046608014234370059" style="text-decoration:none;">
|
30 |
+
<img src="https://user-images.githubusercontent.com/25839884/218347213-c080267f-cbb6-443e-8532-8e1ed9a58ea9.png" width="3%" alt="" /></a>
|
31 |
+
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
32 |
+
<a href="https://twitter.com/OpenMMLab" style="text-decoration:none;">
|
33 |
+
<img src="https://user-images.githubusercontent.com/25839884/218346637-d30c8a0f-3eba-4699-8131-512fb06d46db.png" width="3%" alt="" /></a>
|
34 |
+
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
35 |
+
<a href="https://www.youtube.com/openmmlab" style="text-decoration:none;">
|
36 |
+
<img src="https://user-images.githubusercontent.com/25839884/218346691-ceb2116a-465a-40af-8424-9f30d2348ca9.png" width="3%" alt="" /></a>
|
37 |
+
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
38 |
+
<a href="https://space.bilibili.com/1293512903" style="text-decoration:none;">
|
39 |
+
<img src="https://user-images.githubusercontent.com/25839884/219026751-d7d14cce-a7c9-4e82-9942-8375fca65b99.png" width="3%" alt="" /></a>
|
40 |
+
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
41 |
+
<a href="https://www.zhihu.com/people/openmmlab" style="text-decoration:none;">
|
42 |
+
<img src="https://user-images.githubusercontent.com/25839884/219026120-ba71e48b-6e94-4bd4-b4e9-b7d175b5e362.png" width="3%" alt="" /></a>
|
43 |
+
</div>
|
44 |
+
|
45 |
+
## Features
|
46 |
+
|
47 |
+
- Support various vision and language instruction data
|
48 |
+
- Parameter efficient fine-tuning with LoRA
|
49 |
+
- Tuning vision and language at the same time, complement each other
|
50 |
+
|
51 |
+
|
52 |
+
## Installation
|
53 |
+
|
54 |
+
To install the package in an existing environment, run
|
55 |
+
|
56 |
+
```bash
|
57 |
+
git clone https://github.com/open-mmlab/Multimodal-GPT.git
|
58 |
+
cd Multimodal-GPT
|
59 |
+
pip install -r requirements.txt
|
60 |
+
pip install -v -e .
|
61 |
+
```
|
62 |
+
|
63 |
+
or create a new conda environment
|
64 |
+
|
65 |
+
```bash
|
66 |
+
conda env create -f environment.yml
|
67 |
+
```
|
68 |
+
|
69 |
+
|
70 |
+
## Launch Demo Locally
|
71 |
+
|
72 |
+
1. Download the pre-trained weights.
|
73 |
+
|
74 |
+
Use [this script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py) for converting LLaMA weights to Hugging Face format.
|
75 |
+
|
76 |
+
Download the OpenFlamingo pre-trained model from [openflamingo/OpenFlamingo-9B](https://huggingface.co/openflamingo/OpenFlamingo-9B).
|
77 |
+
|
78 |
+
Download our LoRA Weight from [here](https://download.openmmlab.com/mmgpt/v0/mmgpt-lora-v0-release.pt).
|
79 |
+
|
80 |
+
Then place these models in `checkpoints` folders like this:
|
81 |
+
|
82 |
+
```
|
83 |
+
checkpoints
|
84 |
+
├── llama-7b_hf
|
85 |
+
│ ├── config.json
|
86 |
+
│ ├── pytorch_model-00001-of-00002.bin
|
87 |
+
│ ├── ......
|
88 |
+
│ └── tokenizer.model
|
89 |
+
├── OpenFlamingo-9B
|
90 |
+
│ └──checkpoint.pt
|
91 |
+
├──mmgpt-lora-v0-release.pt
|
92 |
+
|
93 |
+
2. launch the gradio demo
|
94 |
+
|
95 |
+
```bash
|
96 |
+
python app.py
|
97 |
+
```
|
98 |
+
|
99 |
+
## Examples
|
100 |
+
|
101 |
+
### Recipe:
|
102 |
+

|
103 |
+
|
104 |
+
### Travel plan:
|
105 |
+

|
106 |
+
|
107 |
+
### Movie:
|
108 |
+

|
109 |
+
|
110 |
+
### Famous person:
|
111 |
+

|
112 |
+
|
113 |
+
|
114 |
+
## Fine-tuning
|
115 |
+
|
116 |
+
### Prepare datasets
|
117 |
+
|
118 |
+
1. [A-OKVQA](https://allenai.org/project/a-okvqa/home)
|
119 |
+
|
120 |
+
Download annotation from [this link](https://prior-datasets.s3.us-east-2.amazonaws.com/aokvqa/aokvqa_v1p0.tar.gz) and unzip to `data/aokvqa/annotations`.
|
121 |
+
|
122 |
+
It also requires images from coco dataset which can be downloaded from [here](https://cocodataset.org/#home).
|
123 |
+
|
124 |
+
2. [COCO Caption](https://cs.stanford.edu/people/karpathy/deepimagesent/)
|
125 |
+
|
126 |
+
Download from [this link](https://cs.stanford.edu/people/karpathy/deepimagesent/coco.zip) and unzip to `data/coco`.
|
127 |
+
|
128 |
+
It also requires images from coco dataset which can be downloaded from [here](https://cocodataset.org/#home).
|
129 |
+
|
130 |
+
3. [OCR VQA](https://ocr-vqa.github.io/)
|
131 |
+
|
132 |
+
Download from [this link](https://drive.google.com/drive/folders/1_GYPY5UkUy7HIcR0zq3ZCFgeZN7BAfm_?usp=sharing) and place in `data/OCR_VQA/`.
|
133 |
+
|
134 |
+
4. [LlaVA](https://llava-vl.github.io/)
|
135 |
+
|
136 |
+
Download from [liuhaotian/LLaVA-Instruct-150K](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K) and place in `data/llava/`.
|
137 |
+
|
138 |
+
It also requires images from coco dataset which can be downloaded from [here](https://cocodataset.org/#home).
|
139 |
+
|
140 |
+
5. [Mini-GPT4](https://minigpt-4.github.io/)
|
141 |
+
|
142 |
+
Download from [Vision-CAIR/cc_sbu_align](https://huggingface.co/datasets/Vision-CAIR/cc_sbu_align) and place in `data/cc_sbu_align/`.
|
143 |
+
|
144 |
+
6. [Dolly 15k](https://www.databricks.com/blog/2023/03/24/hello-dolly-democratizing-magic-chatgpt-open-models.html)
|
145 |
+
|
146 |
+
Download from [databricks/databricks-dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k) and place it in `data/dolly/databricks-dolly-15k.jsonl`.
|
147 |
+
|
148 |
+
7. [Alpaca GPT4](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
149 |
+
|
150 |
+
Download it from [this link](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM/raw/main/data/alpaca_gpt4_data.json) and place it in `data/alpaca_gpt4/alpaca_gpt4_data.json`.
|
151 |
+
|
152 |
+
You can also customize the data path in the [configs/dataset_config.py](configs/dataset_config.py).
|
153 |
+
|
154 |
+
8. [Baize](https://github.com/project-baize/baize-chatbot)
|
155 |
+
|
156 |
+
Download it from [this link](https://github.com/project-baize/baize-chatbot/blob/main/data/quora_chat_data.json) and place it in `data/baize/quora_chat_data.json`.
|
157 |
+
|
158 |
+
|
159 |
+
## Start training
|
160 |
+
|
161 |
+
```bash
|
162 |
+
torchrun --nproc_per_node=8 mmgpt/train/instruction_finetune.py \
|
163 |
+
--lm_path checkpoints/llama-7b_hf \
|
164 |
+
--tokenizer_path checkpoints/llama-7b_hf \
|
165 |
+
--pretrained_path checkpoints/OpenFlamingo-9B/checkpoint.pt \
|
166 |
+
--run_name train-my-gpt4 \
|
167 |
+
--learning_rate 1e-5 \
|
168 |
+
--lr_scheduler cosine \
|
169 |
+
--batch_size 1 \
|
170 |
+
--tuning_config configs/lora_config.py \
|
171 |
+
--dataset_config configs/dataset_config.py \
|
172 |
+
--report_to_wandb
|
173 |
+
```
|
174 |
+
|
175 |
+
|
176 |
+
## Acknowledgements
|
177 |
+
|
178 |
+
- [OpenFlamingo](https://github.com/mlfoundations/open_flamingo)
|
179 |
+
- [LAVIS](https://github.com/salesforce/LAVIS)
|
180 |
+
- [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca)
|
181 |
+
- [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4)
|
182 |
+
- [LLaVA](https://github.com/haotian-liu/LLaVA/tree/main)
|
183 |
+
- [Instruction Tuning with GPT-4](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
184 |
+
|
185 |
+
If you find our project useful for your research and applications, please cite using this BibTeX:
|
186 |
+
|
187 |
+
```bibtex
|
188 |
+
@misc{gong2023multimodalgpt,
|
189 |
+
title={MultiModal-GPT: A Vision and Language Model for Dialogue with Humans},
|
190 |
+
author={Tao Gong and Chengqi Lyu and Shilong Zhang and Yudong Wang and Miao Zheng and Qian Zhao and Kuikun Liu and Wenwei Zhang and Ping Luo and Kai Chen},
|
191 |
+
year={2023},
|
192 |
+
eprint={2305.04790},
|
193 |
+
archivePrefix={arXiv},
|
194 |
+
primaryClass={cs.CV}
|
195 |
+
}
|
196 |
+
```
|
README_zh-CN.md
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 🤖 Multi-modal GPT
|
2 |
+
|
3 |
+
使用视觉和语言指令训练一个多模态聊天机器人!
|
4 |
+
|
5 |
+
基于开源多模态模型 [OpenFlamingo](https://github.com/mlfoundations/open_flamingo),我们使用公开数据集创建了各种**视觉指令**数据,包括视觉问答、图像字幕、视觉推理、文本 OCR 和视觉对话。此外,我们还使用仅包含**语言指令**数据的语言模型组件进行了训练。
|
6 |
+
|
7 |
+
视觉和语言指令的**联合训练**有效提高了模型的性能!更多细节请参阅我们的[技术报告](https://arxiv.org/abs/2305.04790)。
|
8 |
+
|
9 |
+
欢迎加入我们!
|
10 |
+
|
11 |
+
</div>
|
12 |
+
|
13 |
+
<div align="center">
|
14 |
+
|
15 |
+
[English](README.md) | 简体中文
|
16 |
+
|
17 |
+
</div>
|
18 |
+
|
19 |
+
<div align="center">
|
20 |
+
<a href="https://openmmlab.medium.com/" style="text-decoration:none;">
|
21 |
+
<img src="https://user-images.githubusercontent.com/25839884/219255827-67c1a27f-f8c5-46a9-811d-5e57448c61d1.png" width="3%" alt="" /></a>
|
22 |
+
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
23 |
+
<a href="https://discord.com/channels/1037617289144569886/1046608014234370059" style="text-decoration:none;">
|
24 |
+
<img src="https://user-images.githubusercontent.com/25839884/218347213-c080267f-cbb6-443e-8532-8e1ed9a58ea9.png" width="3%" alt="" /></a>
|
25 |
+
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
26 |
+
<a href="https://twitter.com/OpenMMLab" style="text-decoration:none;">
|
27 |
+
<img src="https://user-images.githubusercontent.com/25839884/218346637-d30c8a0f-3eba-4699-8131-512fb06d46db.png" width="3%" alt="" /></a>
|
28 |
+
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
29 |
+
<a href="https://www.youtube.com/openmmlab" style="text-decoration:none;">
|
30 |
+
<img src="https://user-images.githubusercontent.com/25839884/218346691-ceb2116a-465a-40af-8424-9f30d2348ca9.png" width="3%" alt="" /></a>
|
31 |
+
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
32 |
+
<a href="https://space.bilibili.com/1293512903" style="text-decoration:none;">
|
33 |
+
<img src="https://user-images.githubusercontent.com/25839884/219026751-d7d14cce-a7c9-4e82-9942-8375fca65b99.png" width="3%" alt="" /></a>
|
34 |
+
<img src="https://user-images.githubusercontent.com/25839884/218346358-56cc8e2f-a2b8-487f-9088-32480cceabcf.png" width="3%" alt="" />
|
35 |
+
<a href="https://www.zhihu.com/people/openmmlab" style="text-decoration:none;">
|
36 |
+
<img src="https://user-images.githubusercontent.com/25839884/219026120-ba71e48b-6e94-4bd4-b4e9-b7d175b5e362.png" width="3%" alt="" /></a>
|
37 |
+
</div>
|
38 |
+
|
39 |
+
## 特性
|
40 |
+
|
41 |
+
- 支持各种视觉和语言指令数据
|
42 |
+
- 使用 LoRA 进行参数高效微调
|
43 |
+
- 同时调整视觉和语言,相互补充
|
44 |
+
|
45 |
+
## 安装
|
46 |
+
|
47 |
+
在一个已有环境中安装依赖包,运行以下指令
|
48 |
+
|
49 |
+
```bash
|
50 |
+
git clone https://github.com/open-mmlab/Multimodal-GPT.git
|
51 |
+
cd Multimodal-GPT
|
52 |
+
pip install -r requirements.txt
|
53 |
+
pip install -v -e .
|
54 |
+
```
|
55 |
+
|
56 |
+
或者创建一个新的 conda 环境
|
57 |
+
|
58 |
+
```bash
|
59 |
+
conda env create -f environment.yml
|
60 |
+
```
|
61 |
+
|
62 |
+
## Demo
|
63 |
+
|
64 |
+
1. 下载预训练权重
|
65 |
+
|
66 |
+
使用[这个脚本](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py)把 LLaMA 权重转换成 HuggingFace 格式。
|
67 |
+
|
68 |
+
从 [openflamingo/OpenFlamingo-9B](https://huggingface.co/openflamingo/OpenFlamingo-9B) 下载 OpenFlamingo 预训练模型。
|
69 |
+
|
70 |
+
从[这个链接](https://download.openmmlab.com/mmgpt/v0/mmgpt-lora-v0-release.pt) 下载我们的 LoRA 权重。
|
71 |
+
|
72 |
+
然后把所有模型权重放到 `checkpoints` 文件夹下,目录结构如下:
|
73 |
+
|
74 |
+
```
|
75 |
+
checkpoints
|
76 |
+
├── llama-7b_hf
|
77 |
+
│ ├── config.json
|
78 |
+
│ ├── pytorch_model-00001-of-00002.bin
|
79 |
+
│ ├── ......
|
80 |
+
│ └── tokenizer.model
|
81 |
+
├── OpenFlamingo-9B
|
82 |
+
│ └──checkpoint.pt
|
83 |
+
├──mmgpt-lora-v0-release.pt
|
84 |
+
|
85 |
+
2. 启动 gradio demo
|
86 |
+
|
87 |
+
```bash
|
88 |
+
python app.py
|
89 |
+
```
|
90 |
+
|
91 |
+
## 示例
|
92 |
+
|
93 |
+
### 菜单:
|
94 |
+

|
95 |
+
|
96 |
+
### 旅行计划:
|
97 |
+

|
98 |
+
|
99 |
+
### 电影:
|
100 |
+

|
101 |
+
|
102 |
+
### 名人:
|
103 |
+

|
104 |
+
|
105 |
+
|
106 |
+
## 微调 Fine-tuning
|
107 |
+
|
108 |
+
### 准备数据集
|
109 |
+
|
110 |
+
1. [A-OKVQA](https://allenai.org/project/a-okvqa/home)
|
111 |
+
|
112 |
+
从[这个链接](https://prior-datasets.s3.us-east-2.amazonaws.com/aokvqa/aokvqa_v1p0.tar.gz)下载标注,解压到 `data/aokvqa/annotations` 路径下。
|
113 |
+
|
114 |
+
同时还需要 coco 数据集的图像,可以从[这里](https://cocodataset.org/#home)下载。
|
115 |
+
|
116 |
+
2. [COCO Caption](https://cs.stanford.edu/people/karpathy/deepimagesent/)
|
117 |
+
|
118 |
+
从[这个链接](https://cs.stanford.edu/people/karpathy/deepimagesent/coco.zip),解压到 `data/coco` 路径下。
|
119 |
+
|
120 |
+
同时还需要 coco 数据集的图像,可以从[这里](https://cocodataset.org/#home)下载。
|
121 |
+
|
122 |
+
3. [OCR VQA](https://ocr-vqa.github.io/)
|
123 |
+
|
124 |
+
从 [这个链接](https://drive.google.com/drive/folders/1_GYPY5UkUy7HIcR0zq3ZCFgeZN7BAfm_?usp=sharing) 下载数据集,放到 `data/OCR_VQA/` 路径下。
|
125 |
+
|
126 |
+
4. [LlaVA](https://llava-vl.github.io/)
|
127 |
+
|
128 |
+
从 [liuhaotian/LLaVA-Instruct-150K](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K) 下载数据集,放到 `data/llava/` 路径下。
|
129 |
+
|
130 |
+
同时还需要 coco 数据集的图像,可以从[这里](https://cocodataset.org/#home)下载。
|
131 |
+
|
132 |
+
5. [Mini-GPT4](https://minigpt-4.github.io/)
|
133 |
+
|
134 |
+
从 [Vision-CAIR/cc_sbu_align](https://huggingface.co/datasets/Vision-CAIR/cc_sbu_align) 下载数据集,放到 `data/cc_sbu_align/` 路径下。
|
135 |
+
|
136 |
+
6. [Dolly 15k](https://www.databricks.com/blog/2023/03/24/hello-dolly-democratizing-magic-chatgpt-open-models.html)
|
137 |
+
|
138 |
+
从 [databricks/databricks-dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k) 下载数据集,放到 `data/dolly/databricks-dolly-15k.jsonl` 路径下。
|
139 |
+
|
140 |
+
7. [Alpaca GPT4](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
141 |
+
|
142 |
+
从[这个链接](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM/raw/main/data/alpaca_gpt4_data.json) 下载数据集,放到 `data/alpaca_gpt4/alpaca_gpt4_data.json` 路径下。
|
143 |
+
|
144 |
+
你也可以在 [configs/dataset_config.py](configs/dataset_config.py) 文件中自定义数据集路径。
|
145 |
+
|
146 |
+
|
147 |
+
## 开启训练
|
148 |
+
|
149 |
+
```bash
|
150 |
+
torchrun --nproc_per_node=8 mmgpt/train/instruction_finetune.py \
|
151 |
+
--lm_path checkpoints/llama-7b_hf \
|
152 |
+
--tokenizer_path checkpoints/llama-7b_hf \
|
153 |
+
--pretrained_path checkpoints/OpenFlamingo-9B/checkpoint.pt \
|
154 |
+
--run_name train-my-gpt4 \
|
155 |
+
--learning_rate 1e-5 \
|
156 |
+
--lr_scheduler cosine \
|
157 |
+
--batch_size 1 \
|
158 |
+
--tuning_config configs/lora_config.py \
|
159 |
+
--dataset_config configs/dataset_config.py \
|
160 |
+
--report_to_wandb
|
161 |
+
```
|
162 |
+
|
163 |
+
|
164 |
+
## 致谢
|
165 |
+
|
166 |
+
- [OpenFlamingo](https://github.com/mlfoundations/open_flamingo)
|
167 |
+
- [LAVIS](https://github.com/salesforce/LAVIS)
|
168 |
+
- [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca)
|
169 |
+
- [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4)
|
170 |
+
- [LLaVA](https://github.com/haotian-liu/LLaVA/tree/main)
|
171 |
+
- [Instruction Tuning with GPT-4](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
172 |
+
|
173 |
+
如果你觉得我们的项目对你的研究和应用有帮助,请用以下 BibTeX 进行引用
|
174 |
+
|
175 |
+
```bibtex
|
176 |
+
@misc{gong2023multimodalgpt,
|
177 |
+
title={MultiModal-GPT: A Vision and Language Model for Dialogue with Humans},
|
178 |
+
author={Tao Gong and Chengqi Lyu and Shilong Zhang and Yudong Wang and Miao Zheng and Qian Zhao and Kuikun Liu and Wenwei Zhang and Ping Luo and Kai Chen},
|
179 |
+
year={2023},
|
180 |
+
eprint={2305.04790},
|
181 |
+
archivePrefix={arXiv},
|
182 |
+
primaryClass={cs.CV}
|
183 |
+
}
|
184 |
+
```
|
app.py
ADDED
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
from mmgpt.models.builder import create_model_and_transforms
|
8 |
+
|
9 |
+
TEMPLATE = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
|
10 |
+
response_split = "### Response:"
|
11 |
+
|
12 |
+
|
13 |
+
class Inferencer:
|
14 |
+
|
15 |
+
def __init__(self, finetune_path, llama_path, open_flamingo_path):
|
16 |
+
ckpt = torch.load(finetune_path, map_location="cpu", weights_only=False)
|
17 |
+
if "model_state_dict" in ckpt:
|
18 |
+
state_dict = ckpt["model_state_dict"]
|
19 |
+
# remove the "module." prefix
|
20 |
+
state_dict = {
|
21 |
+
k[7:]: v
|
22 |
+
for k, v in state_dict.items() if k.startswith("module.")
|
23 |
+
}
|
24 |
+
else:
|
25 |
+
state_dict = ckpt
|
26 |
+
tuning_config = ckpt.get("tuning_config")
|
27 |
+
if tuning_config is None:
|
28 |
+
print("tuning_config not found in checkpoint")
|
29 |
+
else:
|
30 |
+
print("tuning_config found in checkpoint: ", tuning_config)
|
31 |
+
model, image_processor, tokenizer = create_model_and_transforms(
|
32 |
+
model_name="open_flamingo",
|
33 |
+
clip_vision_encoder_path="ViT-L-14",
|
34 |
+
clip_vision_encoder_pretrained="openai",
|
35 |
+
lang_encoder_path=llama_path,
|
36 |
+
tokenizer_path=llama_path,
|
37 |
+
pretrained_model_path=open_flamingo_path,
|
38 |
+
tuning_config=tuning_config,
|
39 |
+
)
|
40 |
+
model.load_state_dict(state_dict, strict=False)
|
41 |
+
model.half()
|
42 |
+
|
43 |
+
|
44 |
+
device = torch.device("cpu")
|
45 |
+
model = model.to(device)
|
46 |
+
|
47 |
+
# model = model.to("cuda")
|
48 |
+
model.eval()
|
49 |
+
tokenizer.padding_side = "left"
|
50 |
+
tokenizer.add_eos_token = False
|
51 |
+
self.model = model
|
52 |
+
self.image_processor = image_processor
|
53 |
+
self.tokenizer = tokenizer
|
54 |
+
|
55 |
+
def __call__(self, prompt, imgpaths, max_new_token, num_beams, temperature,
|
56 |
+
top_k, top_p, do_sample):
|
57 |
+
device = torch.device("cpu")
|
58 |
+
if len(imgpaths) > 1:
|
59 |
+
raise gr.Error(
|
60 |
+
"Current only support one image, please clear gallery and upload one image"
|
61 |
+
)
|
62 |
+
lang_x = self.tokenizer([prompt], return_tensors="pt")
|
63 |
+
if len(imgpaths) == 0 or imgpaths is None:
|
64 |
+
for layer in self.model.lang_encoder._get_decoder_layers():
|
65 |
+
layer.condition_only_lang_x(True)
|
66 |
+
output_ids = self.model.lang_encoder.generate(
|
67 |
+
input_ids=lang_x["input_ids"].to(device),
|
68 |
+
attention_mask=lang_x["attention_mask"].to(device),
|
69 |
+
max_new_tokens=max_new_token,
|
70 |
+
num_beams=num_beams,
|
71 |
+
temperature=temperature,
|
72 |
+
top_k=top_k,
|
73 |
+
top_p=top_p,
|
74 |
+
do_sample=do_sample,
|
75 |
+
)[0]
|
76 |
+
for layer in self.model.lang_encoder._get_decoder_layers():
|
77 |
+
layer.condition_only_lang_x(False)
|
78 |
+
else:
|
79 |
+
images = (Image.open(fp) for fp in imgpaths)
|
80 |
+
vision_x = [self.image_processor(im).unsqueeze(0) for im in images]
|
81 |
+
vision_x = torch.cat(vision_x, dim=0)
|
82 |
+
vision_x = vision_x.unsqueeze(1).unsqueeze(0).half()
|
83 |
+
|
84 |
+
output_ids = self.model.generate(
|
85 |
+
vision_x=vision_x.to(device),
|
86 |
+
lang_x=lang_x["input_ids"].to(device),
|
87 |
+
attention_mask=lang_x["attention_mask"].to(device),
|
88 |
+
max_new_tokens=max_new_token,
|
89 |
+
num_beams=num_beams,
|
90 |
+
temperature=temperature,
|
91 |
+
top_k=top_k,
|
92 |
+
top_p=top_p,
|
93 |
+
do_sample=do_sample,
|
94 |
+
)[0]
|
95 |
+
generated_text = self.tokenizer.decode(
|
96 |
+
output_ids, skip_special_tokens=True)
|
97 |
+
# print(generated_text)
|
98 |
+
result = generated_text.split(response_split)[-1].strip()
|
99 |
+
return result
|
100 |
+
|
101 |
+
|
102 |
+
class PromptGenerator:
|
103 |
+
|
104 |
+
def __init__(
|
105 |
+
self,
|
106 |
+
prompt_template=TEMPLATE,
|
107 |
+
ai_prefix="Response",
|
108 |
+
user_prefix="Instruction",
|
109 |
+
sep: str = "\n\n### ",
|
110 |
+
buffer_size=0,
|
111 |
+
):
|
112 |
+
self.all_history = list()
|
113 |
+
self.ai_prefix = ai_prefix
|
114 |
+
self.user_prefix = user_prefix
|
115 |
+
self.buffer_size = buffer_size
|
116 |
+
self.prompt_template = prompt_template
|
117 |
+
self.sep = sep
|
118 |
+
|
119 |
+
def add_message(self, role, message):
|
120 |
+
self.all_history.append([role, message])
|
121 |
+
|
122 |
+
def get_images(self):
|
123 |
+
img_list = list()
|
124 |
+
if self.buffer_size > 0:
|
125 |
+
all_history = self.all_history[-2 * (self.buffer_size + 1):]
|
126 |
+
elif self.buffer_size == 0:
|
127 |
+
all_history = self.all_history[-2:]
|
128 |
+
else:
|
129 |
+
all_history = self.all_history[:]
|
130 |
+
for his in all_history:
|
131 |
+
if type(his[-1]) == tuple:
|
132 |
+
img_list.append(his[-1][-1])
|
133 |
+
return img_list
|
134 |
+
|
135 |
+
def get_prompt(self):
|
136 |
+
format_dict = dict()
|
137 |
+
if "{user_prefix}" in self.prompt_template:
|
138 |
+
format_dict["user_prefix"] = self.user_prefix
|
139 |
+
if "{ai_prefix}" in self.prompt_template:
|
140 |
+
format_dict["ai_prefix"] = self.ai_prefix
|
141 |
+
prompt_template = self.prompt_template.format(**format_dict)
|
142 |
+
ret = prompt_template
|
143 |
+
if self.buffer_size > 0:
|
144 |
+
all_history = self.all_history[-2 * (self.buffer_size + 1):]
|
145 |
+
elif self.buffer_size == 0:
|
146 |
+
all_history = self.all_history[-2:]
|
147 |
+
else:
|
148 |
+
all_history = self.all_history[:]
|
149 |
+
context = []
|
150 |
+
have_image = False
|
151 |
+
for role, message in all_history[::-1]:
|
152 |
+
if message:
|
153 |
+
if type(message) is tuple and message[
|
154 |
+
1] is not None and not have_image:
|
155 |
+
message, _ = message
|
156 |
+
context.append(self.sep + "Image:\n<image>" + self.sep +
|
157 |
+
role + ":\n" + message)
|
158 |
+
else:
|
159 |
+
context.append(self.sep + role + ":\n" + message)
|
160 |
+
else:
|
161 |
+
context.append(self.sep + role + ":\n")
|
162 |
+
|
163 |
+
ret += "".join(context[::-1])
|
164 |
+
return ret
|
165 |
+
|
166 |
+
|
167 |
+
def to_gradio_chatbot(prompt_generator):
|
168 |
+
ret = []
|
169 |
+
for i, (role, msg) in enumerate(prompt_generator.all_history):
|
170 |
+
if i % 2 == 0:
|
171 |
+
if type(msg) is tuple:
|
172 |
+
import base64
|
173 |
+
from io import BytesIO
|
174 |
+
|
175 |
+
msg, image = msg
|
176 |
+
if type(image) is str:
|
177 |
+
from PIL import Image
|
178 |
+
|
179 |
+
image = Image.open(image)
|
180 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
181 |
+
aspect_ratio = max_hw / min_hw
|
182 |
+
max_len, min_len = 800, 400
|
183 |
+
shortest_edge = int(
|
184 |
+
min(max_len / aspect_ratio, min_len, min_hw))
|
185 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
186 |
+
H, W = image.size
|
187 |
+
if H > W:
|
188 |
+
H, W = longest_edge, shortest_edge
|
189 |
+
else:
|
190 |
+
H, W = shortest_edge, longest_edge
|
191 |
+
image = image.resize((H, W))
|
192 |
+
# image = image.resize((224, 224))
|
193 |
+
buffered = BytesIO()
|
194 |
+
image.save(buffered, format="JPEG")
|
195 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
196 |
+
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
|
197 |
+
msg = msg + img_str
|
198 |
+
ret.append([msg, None])
|
199 |
+
else:
|
200 |
+
ret[-1][-1] = msg
|
201 |
+
return ret
|
202 |
+
|
203 |
+
|
204 |
+
def bot(
|
205 |
+
text,
|
206 |
+
image,
|
207 |
+
state,
|
208 |
+
prompt,
|
209 |
+
ai_prefix,
|
210 |
+
user_prefix,
|
211 |
+
seperator,
|
212 |
+
history_buffer,
|
213 |
+
max_new_token,
|
214 |
+
num_beams,
|
215 |
+
temperature,
|
216 |
+
top_k,
|
217 |
+
top_p,
|
218 |
+
do_sample,
|
219 |
+
):
|
220 |
+
state.prompt_template = prompt
|
221 |
+
state.ai_prefix = ai_prefix
|
222 |
+
state.user_prefix = user_prefix
|
223 |
+
state.sep = seperator
|
224 |
+
state.buffer_size = history_buffer
|
225 |
+
if image:
|
226 |
+
state.add_message(user_prefix, (text, image))
|
227 |
+
else:
|
228 |
+
state.add_message(user_prefix, text)
|
229 |
+
state.add_message(ai_prefix, None)
|
230 |
+
inputs = state.get_prompt()
|
231 |
+
image_paths = state.get_images()[-1:]
|
232 |
+
|
233 |
+
inference_results = inferencer(inputs, image_paths, max_new_token,
|
234 |
+
num_beams, temperature, top_k, top_p,
|
235 |
+
do_sample)
|
236 |
+
state.all_history[-1][-1] = inference_results
|
237 |
+
memory_allocated = str(round(torch.cuda.memory_allocated() / 1024**3,
|
238 |
+
2)) + 'GB'
|
239 |
+
return state, to_gradio_chatbot(state), "", None, inputs, memory_allocated
|
240 |
+
|
241 |
+
|
242 |
+
def clear(state):
|
243 |
+
state.all_history = []
|
244 |
+
return state, to_gradio_chatbot(state), "", None, ""
|
245 |
+
|
246 |
+
|
247 |
+
title_markdown = ("""
|
248 |
+
# 🤖 Multi-modal GPT
|
249 |
+
[[Project]](https://github.com/open-mmlab/Multimodal-GPT.git)""")
|
250 |
+
|
251 |
+
|
252 |
+
def build_conversation_demo():
|
253 |
+
with gr.Blocks(title="Multi-modal GPT") as demo:
|
254 |
+
gr.Markdown(title_markdown)
|
255 |
+
|
256 |
+
state = gr.State(PromptGenerator())
|
257 |
+
with gr.Row():
|
258 |
+
with gr.Column(scale=3):
|
259 |
+
memory_allocated = gr.Textbox(
|
260 |
+
value=init_memory, label="Memory")
|
261 |
+
imagebox = gr.Image(type="filepath")
|
262 |
+
# TODO config parameters
|
263 |
+
with gr.Accordion(
|
264 |
+
"Parameters",
|
265 |
+
open=True,
|
266 |
+
):
|
267 |
+
max_new_token_bar = gr.Slider(
|
268 |
+
0, 1024, 512, label="max_new_token", step=1)
|
269 |
+
num_beams_bar = gr.Slider(
|
270 |
+
0.0, 10, 3, label="num_beams", step=1)
|
271 |
+
temperature_bar = gr.Slider(
|
272 |
+
0.0, 1.0, 1.0, label="temperature", step=0.01)
|
273 |
+
topk_bar = gr.Slider(0, 100, 20, label="top_k", step=1)
|
274 |
+
topp_bar = gr.Slider(0, 1.0, 1.0, label="top_p", step=0.01)
|
275 |
+
do_sample = gr.Checkbox(True, label="do_sample")
|
276 |
+
with gr.Accordion(
|
277 |
+
"Prompt",
|
278 |
+
open=False,
|
279 |
+
):
|
280 |
+
with gr.Row():
|
281 |
+
ai_prefix = gr.Text("Response", label="AI Prefix")
|
282 |
+
user_prefix = gr.Text(
|
283 |
+
"Instruction", label="User Prefix")
|
284 |
+
seperator = gr.Text("\n\n### ", label="Seperator")
|
285 |
+
history_buffer = gr.Slider(
|
286 |
+
-1, 10, -1, label="History buffer", step=1)
|
287 |
+
prompt = gr.Text(TEMPLATE, label="Prompt")
|
288 |
+
model_inputs = gr.Textbox(label="Actual inputs for Model")
|
289 |
+
|
290 |
+
with gr.Column(scale=6):
|
291 |
+
with gr.Row():
|
292 |
+
with gr.Column():
|
293 |
+
chatbot = gr.Chatbot(elem_id="chatbot", type="messages")
|
294 |
+
with gr.Row():
|
295 |
+
with gr.Column(scale=8):
|
296 |
+
textbox = gr.Textbox(
|
297 |
+
show_label=False,
|
298 |
+
placeholder="Enter text and press ENTER",
|
299 |
+
container=False)
|
300 |
+
submit_btn = gr.Button(value="Submit")
|
301 |
+
clear_btn = gr.Button(value="🗑️ Clear history")
|
302 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
303 |
+
gr.Examples(
|
304 |
+
examples=[
|
305 |
+
[
|
306 |
+
f"{cur_dir}/docs/images/demo_image.jpg",
|
307 |
+
"What is in this image?"
|
308 |
+
],
|
309 |
+
],
|
310 |
+
inputs=[imagebox, textbox],
|
311 |
+
)
|
312 |
+
textbox.submit(
|
313 |
+
bot,
|
314 |
+
[
|
315 |
+
textbox,
|
316 |
+
imagebox,
|
317 |
+
state,
|
318 |
+
prompt,
|
319 |
+
ai_prefix,
|
320 |
+
user_prefix,
|
321 |
+
seperator,
|
322 |
+
history_buffer,
|
323 |
+
max_new_token_bar,
|
324 |
+
num_beams_bar,
|
325 |
+
temperature_bar,
|
326 |
+
topk_bar,
|
327 |
+
topp_bar,
|
328 |
+
do_sample,
|
329 |
+
],
|
330 |
+
[
|
331 |
+
state, chatbot, textbox, imagebox, model_inputs,
|
332 |
+
memory_allocated
|
333 |
+
],
|
334 |
+
)
|
335 |
+
submit_btn.click(
|
336 |
+
bot,
|
337 |
+
[
|
338 |
+
textbox,
|
339 |
+
imagebox,
|
340 |
+
state,
|
341 |
+
prompt,
|
342 |
+
ai_prefix,
|
343 |
+
user_prefix,
|
344 |
+
seperator,
|
345 |
+
history_buffer,
|
346 |
+
max_new_token_bar,
|
347 |
+
num_beams_bar,
|
348 |
+
temperature_bar,
|
349 |
+
topk_bar,
|
350 |
+
topp_bar,
|
351 |
+
do_sample,
|
352 |
+
],
|
353 |
+
[
|
354 |
+
state, chatbot, textbox, imagebox, model_inputs,
|
355 |
+
memory_allocated
|
356 |
+
],
|
357 |
+
)
|
358 |
+
clear_btn.click(clear, [state],
|
359 |
+
[state, chatbot, textbox, imagebox, model_inputs])
|
360 |
+
return demo
|
361 |
+
|
362 |
+
|
363 |
+
if __name__ == "__main__":
|
364 |
+
llama_path = "checkpoints/llama-7b_hf"
|
365 |
+
open_flamingo_path = "checkpoints/OpenFlamingo-9B/checkpoint.pt"
|
366 |
+
finetune_path = "checkpoints/mmgpt-lora-v0-release.pt"
|
367 |
+
|
368 |
+
inferencer = Inferencer(
|
369 |
+
llama_path=llama_path,
|
370 |
+
open_flamingo_path=open_flamingo_path,
|
371 |
+
finetune_path=finetune_path)
|
372 |
+
init_memory = str(round(torch.cuda.memory_allocated() / 1024**3, 2)) + 'GB'
|
373 |
+
demo = build_conversation_demo()
|
374 |
+
demo.queue(max_size=3)
|
375 |
+
IP = "0.0.0.0"
|
376 |
+
PORT = 8997
|
377 |
+
demo.launch(server_name=IP, server_port=PORT, share=True)
|
378 |
+
|
379 |
+
|
configs/dataset_config.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
visual_datasets = [
|
2 |
+
dict(
|
3 |
+
type="llava",
|
4 |
+
vis_root="data/coco/train2017",
|
5 |
+
ann_paths=[
|
6 |
+
"data/llava/detail_23k.json",
|
7 |
+
"data/llava/complex_reasoning_77k.json",
|
8 |
+
],
|
9 |
+
),
|
10 |
+
dict(
|
11 |
+
type="llava_dial",
|
12 |
+
vis_root="data/coco/train2017",
|
13 |
+
ann_paths=[
|
14 |
+
"data/llava/conversation_58k.json",
|
15 |
+
],
|
16 |
+
),
|
17 |
+
dict(
|
18 |
+
type="aokvqa",
|
19 |
+
vis_root="data/coco/images",
|
20 |
+
ann_paths=[
|
21 |
+
"data/aokvqa/annotations/aokvqa_v1p0_train.json",
|
22 |
+
],
|
23 |
+
sample=5000,
|
24 |
+
),
|
25 |
+
dict(
|
26 |
+
type="minigpt4",
|
27 |
+
vis_root="data/cc_sbu_align/image",
|
28 |
+
ann_paths=[
|
29 |
+
"data/cc_sbu_align/filter_cap.json",
|
30 |
+
],
|
31 |
+
),
|
32 |
+
dict(
|
33 |
+
type="coco_caption",
|
34 |
+
vis_root="data/coco",
|
35 |
+
ann_paths=[
|
36 |
+
"data/coco/annotations/coco_karpathy_train_converted.json",
|
37 |
+
"data/coco/annotations/coco_karpathy_val.json",
|
38 |
+
],
|
39 |
+
sample=512,
|
40 |
+
),
|
41 |
+
dict(
|
42 |
+
type="ocr_vqa",
|
43 |
+
vis_root="data/OCR_VQA/image",
|
44 |
+
ann_paths=[
|
45 |
+
"data/OCR_VQA/downloaded_dataset.json",
|
46 |
+
],
|
47 |
+
sample=512,
|
48 |
+
),
|
49 |
+
]
|
50 |
+
|
51 |
+
language_datasets = [
|
52 |
+
dict(
|
53 |
+
type="dolly",
|
54 |
+
ann_path="data/dolly/databricks-dolly-15k.jsonl",
|
55 |
+
),
|
56 |
+
dict(
|
57 |
+
type="alpaca_gpt4",
|
58 |
+
ann_path="data/alpaca_gpt4/alpaca_gpt4_data.json",
|
59 |
+
),
|
60 |
+
dict(
|
61 |
+
type="baize",
|
62 |
+
ann_path="data/baize/quora_chat_data.json",
|
63 |
+
),
|
64 |
+
]
|
configs/lora_config.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
tuning_config = dict(
|
2 |
+
lora=True,
|
3 |
+
lora_target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "to_q", "to_kv", "to_out", "ff.1", "ff.3"],
|
4 |
+
lora_r=16,
|
5 |
+
lora_alpha=16,
|
6 |
+
lora_dropout=0.0,
|
7 |
+
vis=True,
|
8 |
+
unfrozen=[],
|
9 |
+
)
|
docs/images/demo_image.jpg
ADDED
![]() |
environment.yml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: mmgpt
|
2 |
+
channels:
|
3 |
+
- defaults
|
4 |
+
dependencies:
|
5 |
+
- python=3.9
|
6 |
+
- conda-forge::openjdk
|
7 |
+
- pip
|
8 |
+
- pip:
|
9 |
+
- -r requirements.txt
|
10 |
+
- -e .
|
mmgpt/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .models.builder import create_model_and_transforms
|
2 |
+
from .models.open_flamingo import Flamingo
|
mmgpt/datasets/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .builder import build_dataset # noqa: F401
|
2 |
+
from .dial_dataset import DialDataset # noqa: F401
|
3 |
+
from .samplers import InfiniteSampler # noqa: F401
|
4 |
+
from .vqa_dataset import VQADataset # noqa: F401
|
mmgpt/datasets/alpaca_gpt4_dataset.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
from mmgpt.datasets.dolly_dataset import DollyDataset
|
4 |
+
|
5 |
+
|
6 |
+
class AlpacaGPT4Dataset(DollyDataset):
|
7 |
+
"""
|
8 |
+
```json
|
9 |
+
[
|
10 |
+
{
|
11 |
+
"instruction": "Identify the odd one out.",
|
12 |
+
"input": "Twitter, Instagram, Telegram",
|
13 |
+
"output": "The odd one out is Telegram. Twitter and Instagram are social media platforms mainly for sharing information, images and videos while Telegram is a cloud-based instant messaging and voice-over-IP service."
|
14 |
+
},
|
15 |
+
]
|
16 |
+
"""
|
17 |
+
|
18 |
+
def load_annotation(self, ann_path):
|
19 |
+
self.annotation = json.load(open(ann_path, "r"))
|
20 |
+
|
21 |
+
def process_text(self, ann):
|
22 |
+
instruction = ann["instruction"]
|
23 |
+
input = ann["input"]
|
24 |
+
output = ann["output"]
|
25 |
+
instruction = self.prompter(instruction=instruction, input=input)
|
26 |
+
return dict(instruction=instruction, answer=output)
|
mmgpt/datasets/aokvqa_dataset.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
from .vqa_dataset import VQADataset
|
4 |
+
|
5 |
+
REASON_QUESTIONS = [
|
6 |
+
"Why?",
|
7 |
+
"Why is this?",
|
8 |
+
"And why?",
|
9 |
+
"What is the reason?",
|
10 |
+
"And can you tell me why?",
|
11 |
+
"Can you tell me why?",
|
12 |
+
"Can you tell me the reason?",
|
13 |
+
]
|
14 |
+
|
15 |
+
|
16 |
+
class AOKVQADataset(VQADataset):
|
17 |
+
def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs):
|
18 |
+
super().__init__(tokenizer, vis_processor, vis_root, ann_paths, **kwargs)
|
19 |
+
|
20 |
+
def process_text(self, ann):
|
21 |
+
question = ann["question"]
|
22 |
+
question = question + " " + random.choice(REASON_QUESTIONS)
|
23 |
+
|
24 |
+
choices = ann["choices"]
|
25 |
+
true_answer = choices[ann["correct_choice_idx"]]
|
26 |
+
answer = "The answer is " + true_answer + ". Because " + " ".join(ann["rationales"])
|
27 |
+
|
28 |
+
is_option = random.random() < self.option_prob and len(choices) > 1
|
29 |
+
if is_option:
|
30 |
+
instruction = self.prompter(question, choices)
|
31 |
+
else:
|
32 |
+
instruction = self.prompter(question)
|
33 |
+
|
34 |
+
instruction = self.prompter(question)
|
35 |
+
return dict(instruction=instruction, answer=answer)
|
36 |
+
|
37 |
+
|
38 |
+
def build_aokvqa_dataset(
|
39 |
+
tokenizer,
|
40 |
+
vis_processor,
|
41 |
+
vis_root="data/coco/images",
|
42 |
+
ann_paths=["data/aokvqa/annotations/aokvqa_v1p0_train.json"],
|
43 |
+
sample_image=False,
|
44 |
+
):
|
45 |
+
return AOKVQADataset(
|
46 |
+
tokenizer=tokenizer,
|
47 |
+
vis_processor=vis_processor,
|
48 |
+
vis_root=vis_root,
|
49 |
+
ann_paths=ann_paths,
|
50 |
+
sample_image=sample_image,
|
51 |
+
)
|
mmgpt/datasets/baize_dataset.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
from mmgpt.datasets.dolly_dataset import DollyDataset
|
4 |
+
|
5 |
+
|
6 |
+
TEMPLATE = {
|
7 |
+
"description": "Template used by Alpaca-LoRA.",
|
8 |
+
"prompt_choice": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{question}\n\n### Input:\n{options}\n\n### Response:\n",
|
9 |
+
"prompt_qa": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{question}\n\n### Response:\n",
|
10 |
+
"prompt_dial": "\n\n### Instruction:\n{question}\n\n### Response:\n",
|
11 |
+
"response_split": "### Response:",
|
12 |
+
}
|
13 |
+
|
14 |
+
class LangDialPrompter:
|
15 |
+
def __call__(self, question, options=None):
|
16 |
+
if options:
|
17 |
+
options = ", ".join(options)
|
18 |
+
res = TEMPLATE["prompt_choice"].format(image="<image>", question=question, options=options)
|
19 |
+
else:
|
20 |
+
res = TEMPLATE["prompt_dial"].format(question=question)
|
21 |
+
return res
|
22 |
+
|
23 |
+
def get_response(self, output: str) -> str:
|
24 |
+
return output.split(TEMPLATE["response_split"])[-1].strip()
|
25 |
+
|
26 |
+
class BaiZeDataset(DollyDataset):
|
27 |
+
"""
|
28 |
+
```json
|
29 |
+
[
|
30 |
+
{
|
31 |
+
"instruction": "Identify the odd one out.",
|
32 |
+
"input": "Twitter, Instagram, Telegram",
|
33 |
+
"output": "The odd one out is Telegram. Twitter and Instagram are social media platforms mainly for sharing information, images and videos while Telegram is a cloud-based instant messaging and voice-over-IP service."
|
34 |
+
},
|
35 |
+
]
|
36 |
+
"""
|
37 |
+
def __init__(self, *args, **kwargs):
|
38 |
+
super(BaiZeDataset, self).__init__(*args, **kwargs)
|
39 |
+
self.prompter = LangDialPrompter()
|
40 |
+
|
41 |
+
def load_annotation(self, ann_path):
|
42 |
+
self.annotation = json.load(open(ann_path, "r"))
|
43 |
+
|
44 |
+
def process_text(self, anns):
|
45 |
+
# TODO remove this
|
46 |
+
begin_string = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
|
47 |
+
convs = anns['input'].split("[|Human|] ")
|
48 |
+
conv_list = []
|
49 |
+
for conv_id, one_conv in enumerate(convs[1:-1]):
|
50 |
+
question, answer = one_conv.split("[|AI|] ")
|
51 |
+
question = question.replace("\n", "")
|
52 |
+
answer = answer.replace("\n", "")
|
53 |
+
instruction = self.prompter(question)
|
54 |
+
if conv_id == 0:
|
55 |
+
single_conv = dict(instruction=begin_string + instruction, answer=answer)
|
56 |
+
else:
|
57 |
+
single_conv = dict(instruction=instruction, answer=answer)
|
58 |
+
conv_list.append(single_conv)
|
59 |
+
return conv_list
|
60 |
+
|
61 |
+
def __getitem__(self, index):
|
62 |
+
ann = self.annotation[index]
|
63 |
+
text_list = self.process_text(ann)
|
64 |
+
res_list = []
|
65 |
+
for text in text_list:
|
66 |
+
single_res = self.tokenize(text)
|
67 |
+
single_res["instruction"] = text["instruction"]
|
68 |
+
single_res["answer"] = text["answer"]
|
69 |
+
res_list.append(single_res)
|
70 |
+
|
71 |
+
input_ids = []
|
72 |
+
attention_mask = []
|
73 |
+
labels = []
|
74 |
+
instruction = []
|
75 |
+
answer = []
|
76 |
+
for res in res_list:
|
77 |
+
input_ids.extend(res["input_ids"])
|
78 |
+
attention_mask.extend(res["attention_mask"])
|
79 |
+
labels.extend(res["labels"])
|
80 |
+
instruction.append(res["instruction"])
|
81 |
+
answer.append(res["answer"])
|
82 |
+
|
83 |
+
res = dict(
|
84 |
+
input_ids=input_ids, attention_mask=attention_mask, labels=labels, instruction=instruction, answer=answer
|
85 |
+
)
|
86 |
+
return res
|
mmgpt/datasets/builder.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from .alpaca_gpt4_dataset import AlpacaGPT4Dataset # noqa: F401
|
5 |
+
from .aokvqa_dataset import AOKVQADataset # noqa: F401
|
6 |
+
from .cc_sbu_align_dataset import CcSbuAlignDataset # noqa: F401
|
7 |
+
from .clevr_dataset import CLEVRDataset # noqa: F401
|
8 |
+
from .coco_caption_dataset import COCOCaptionDataset # noqa: F401
|
9 |
+
from .dial_dataset import DialDataset # noqa: F401
|
10 |
+
from .dolly_dataset import DollyDataset # noqa: F401
|
11 |
+
from .gqa_dataset import GQADataset # noqa: F401
|
12 |
+
from .llava_dataset import LlavaDataset # noqa: F401
|
13 |
+
from .nlvr_dataset import NLVRv1Dataset, NLVRv2Dataset # noqa: F401
|
14 |
+
from .ocr_vqa_dataset import OCRVQADataset # noqa: F401
|
15 |
+
from .snli_ve_datasets import SNLIVEDataset # noqa: F401
|
16 |
+
from .text_ocr_dataset import TextOCRDataset # noqa: F401
|
17 |
+
from .vqa_dataset import ConcatDataset, VQADataset # noqa: F401
|
18 |
+
from .baize_dataset import BaiZeDataset # noqa: F401
|
19 |
+
|
20 |
+
|
21 |
+
def build_dataset(dataset_config, **kwargs):
|
22 |
+
if isinstance(dataset_config, list):
|
23 |
+
datasets = [build_dataset(cfg, **kwargs) for cfg in dataset_config]
|
24 |
+
return ConcatDataset(datasets)
|
25 |
+
dataset_type = dataset_config.pop("type")
|
26 |
+
sample = dataset_config.pop("sample", -1)
|
27 |
+
if dataset_type == "llava":
|
28 |
+
dataset = LlavaDataset(
|
29 |
+
**dataset_config,
|
30 |
+
**kwargs,
|
31 |
+
)
|
32 |
+
elif dataset_type == "vqa":
|
33 |
+
dataset = VQADataset(
|
34 |
+
**dataset_config,
|
35 |
+
**kwargs,
|
36 |
+
)
|
37 |
+
elif dataset_type == "minigpt4":
|
38 |
+
dataset = CcSbuAlignDataset(
|
39 |
+
**dataset_config,
|
40 |
+
**kwargs,
|
41 |
+
)
|
42 |
+
elif dataset_type == "llava_dial":
|
43 |
+
dataset = DialDataset(
|
44 |
+
**dataset_config,
|
45 |
+
**kwargs,
|
46 |
+
)
|
47 |
+
elif dataset_type == "coco_dial":
|
48 |
+
dataset = DialDataset(
|
49 |
+
**dataset_config,
|
50 |
+
**kwargs,
|
51 |
+
)
|
52 |
+
elif dataset_type == "aokvqa":
|
53 |
+
dataset = AOKVQADataset(
|
54 |
+
**dataset_config,
|
55 |
+
**kwargs,
|
56 |
+
)
|
57 |
+
elif dataset_type == "okvqa":
|
58 |
+
dataset = VQADataset(
|
59 |
+
**dataset_config,
|
60 |
+
**kwargs,
|
61 |
+
)
|
62 |
+
elif dataset_type == "text_ocr":
|
63 |
+
dataset = TextOCRDataset(
|
64 |
+
**dataset_config,
|
65 |
+
**kwargs,
|
66 |
+
)
|
67 |
+
elif dataset_type == "ocr_vqa":
|
68 |
+
dataset = OCRVQADataset(
|
69 |
+
**dataset_config,
|
70 |
+
**kwargs,
|
71 |
+
)
|
72 |
+
elif dataset_type == "coco_caption":
|
73 |
+
dataset = COCOCaptionDataset(
|
74 |
+
**dataset_config,
|
75 |
+
**kwargs,
|
76 |
+
)
|
77 |
+
elif dataset_type == "gqa":
|
78 |
+
dataset = GQADataset(
|
79 |
+
**dataset_config,
|
80 |
+
**kwargs,
|
81 |
+
)
|
82 |
+
elif dataset_type == "clevr":
|
83 |
+
dataset = CLEVRDataset(
|
84 |
+
**dataset_config,
|
85 |
+
**kwargs,
|
86 |
+
)
|
87 |
+
elif dataset_type == "nlvrv1":
|
88 |
+
dataset = NLVRv1Dataset(
|
89 |
+
**dataset_config,
|
90 |
+
**kwargs,
|
91 |
+
)
|
92 |
+
elif dataset_type == "nlvrv2":
|
93 |
+
dataset = NLVRv2Dataset(
|
94 |
+
**dataset_config,
|
95 |
+
**kwargs,
|
96 |
+
)
|
97 |
+
elif dataset_type == "snlive":
|
98 |
+
dataset = SNLIVEDataset(
|
99 |
+
**dataset_config,
|
100 |
+
**kwargs,
|
101 |
+
)
|
102 |
+
elif dataset_type == "dolly":
|
103 |
+
dataset = DollyDataset(
|
104 |
+
**dataset_config,
|
105 |
+
**kwargs,
|
106 |
+
)
|
107 |
+
elif dataset_type == "alpaca_gpt4":
|
108 |
+
dataset = AlpacaGPT4Dataset(
|
109 |
+
**dataset_config,
|
110 |
+
**kwargs,
|
111 |
+
)
|
112 |
+
elif dataset_type == "baize":
|
113 |
+
dataset = BaiZeDataset(
|
114 |
+
**dataset_config,
|
115 |
+
**kwargs,
|
116 |
+
)
|
117 |
+
else:
|
118 |
+
raise NotImplementedError
|
119 |
+
|
120 |
+
if sample > 0:
|
121 |
+
random_indices = np.random.choice(len(dataset), min(sample, len(dataset)), replace=False)
|
122 |
+
subsample_dataset = torch.utils.data.Subset(dataset, random_indices)
|
123 |
+
subsample_dataset.collater = dataset.collater
|
124 |
+
return subsample_dataset
|
125 |
+
else:
|
126 |
+
return dataset
|
mmgpt/datasets/cc_sbu_align_dataset.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
from .vqa_dataset import VQADataset, VQAPrompter
|
8 |
+
|
9 |
+
QUESTIONS = [
|
10 |
+
"please describe the image",
|
11 |
+
"can you describe the image",
|
12 |
+
"Could you provide a description of the image?",
|
13 |
+
"What do you see in this image?",
|
14 |
+
"Share your thoughts on the content of the image.",
|
15 |
+
"Please narrate what's happening in the picture.",
|
16 |
+
"Can you give a brief explanation of the image?",
|
17 |
+
"Describe the main elements and details present in the image.",
|
18 |
+
"In your own words, what is depicted in the image?",
|
19 |
+
"Can you outline the key aspects of the image?",
|
20 |
+
"What are the most striking features in this image?",
|
21 |
+
"Please provide a summary of the image's content.",
|
22 |
+
"Describe the overall theme or concept captured in the image.",
|
23 |
+
"How would you explain the image's composition and focus?",
|
24 |
+
"What is the focal point or main subject of the image?",
|
25 |
+
"How do the different components of the image interact with each other?",
|
26 |
+
"What would be a fitting caption for this image?",
|
27 |
+
"Can you create a concise description that captures the essence of the image?",
|
28 |
+
"How would you briefly summarize the content of this image in a phrase or sentence?",
|
29 |
+
"Please provide a catchy and relevant caption for this picture.",
|
30 |
+
"If you were to give this image a title, what would it be?",
|
31 |
+
"Describe the image in one creative sentence.",
|
32 |
+
"Please suggest a memorable phrase that encapsulates the image's content.",
|
33 |
+
"What engaging phrase would best represent this image?",
|
34 |
+
"Can you create an expressive caption that highlights the main theme of the image?",
|
35 |
+
"How would you sum up the image's story for a caption?",
|
36 |
+
"Provide an eye-catching caption that conveys the image's core message.",
|
37 |
+
"If you were to give this image a headline, what would it say?",
|
38 |
+
"Can you craft a captivating caption that communicates the essence of the image?",
|
39 |
+
"How would you describe the image's content in a powerful caption?",
|
40 |
+
"Please provide an inventive title to summarize the scene depicted in the image.",
|
41 |
+
"Compose a concise and striking phrase that reflects the image's key elements.",
|
42 |
+
"If you were to create a caption for this image, what would it be?",
|
43 |
+
"Offer a compelling caption that highlights the central focus of the image.",
|
44 |
+
"Can you produce a unique caption that encapsulates the image's overall mood?",
|
45 |
+
"Please generate an attention-grabbing caption that would best illustrate the events captured in this image",
|
46 |
+
"How would you express the image's main idea in an impactful sentence?",
|
47 |
+
"Please create a vivid and concise title that conveys the essence of the picture.",
|
48 |
+
"Compose an imaginative caption that reflects the image's most striking features.",
|
49 |
+
"What memorable statement would best represent the scene illustrated in this image?",
|
50 |
+
"Draft an evocative caption that brings the image to life for the reader.",
|
51 |
+
"Can you suggest an insightful caption that highlights the underlying message of the image?",
|
52 |
+
"What engaging phrase would effectively convey the action or subject matter depicted in this picture?",
|
53 |
+
"How would you encapsulate the image's core theme in a concise and expressive manner?",
|
54 |
+
"Please provide a creative and impactful title that captures the spirit of the image.",
|
55 |
+
"Craft a captivating caption that showcases the image's most prominent attributes.",
|
56 |
+
"What intriguing statement would best sum up the scene presented in this image?",
|
57 |
+
"Develop a descriptive caption that paints a vivid picture for the viewer.",
|
58 |
+
"Can you give a detailed account of the image's contents?",
|
59 |
+
"What are the key elements and features visible in this image?",
|
60 |
+
"How would you narrate the events or actions depicted in the picture?",
|
61 |
+
"Please share your observations about the various components present in the image.",
|
62 |
+
"What is the overall theme or concept captured in this image? Can you describe it?",
|
63 |
+
]
|
64 |
+
|
65 |
+
|
66 |
+
class CcSbuAlignDataset(VQADataset):
|
67 |
+
def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, add_eos=True, ignore_instruction=True):
|
68 |
+
self.tokenizer = tokenizer
|
69 |
+
self.vis_root = vis_root
|
70 |
+
|
71 |
+
self.annotation = []
|
72 |
+
for ann_path in ann_paths:
|
73 |
+
self.annotation.extend(json.load(open(ann_path, "r"))["annotations"])
|
74 |
+
|
75 |
+
self.vis_processor = vis_processor
|
76 |
+
self.prompter = VQAPrompter()
|
77 |
+
self.add_eos = add_eos
|
78 |
+
self.ignore_instruction = ignore_instruction
|
79 |
+
|
80 |
+
def process_text(self, ann):
|
81 |
+
# random select a question
|
82 |
+
question = random.choice(QUESTIONS)
|
83 |
+
answer = ann["caption"]
|
84 |
+
instruction = self.prompter(question)
|
85 |
+
return dict(instruction=instruction, answer=answer)
|
86 |
+
|
87 |
+
def process_image(self, ann):
|
88 |
+
image_path = os.path.join(self.vis_root, ann["image_id"] + ".jpg")
|
89 |
+
image = Image.open(image_path).convert("RGB")
|
90 |
+
|
91 |
+
image = self.vis_processor(image)
|
92 |
+
return image
|
93 |
+
|
94 |
+
|
95 |
+
def build_ccsbualign_dataset(
|
96 |
+
tokenizer,
|
97 |
+
vis_processor,
|
98 |
+
vis_root="data/cc_sbu_align/image/",
|
99 |
+
ann_paths=["data/cc_sbu_align/filter_cap.json"],
|
100 |
+
**kwargs,
|
101 |
+
):
|
102 |
+
return CcSbuAlignDataset(
|
103 |
+
tokenizer=tokenizer,
|
104 |
+
vis_processor=vis_processor,
|
105 |
+
vis_root=vis_root,
|
106 |
+
ann_paths=ann_paths,
|
107 |
+
)
|
mmgpt/datasets/clevr_dataset.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
from collections import defaultdict
|
5 |
+
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
from .vqa_dataset import VQADataset
|
9 |
+
|
10 |
+
|
11 |
+
class CLEVRDataset(VQADataset):
|
12 |
+
"""Visual Reasoning Dataset. It also contains Dialog.
|
13 |
+
|
14 |
+
Note: The image is a little bit simple. with several objects and simple background.
|
15 |
+
"""
|
16 |
+
|
17 |
+
def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs):
|
18 |
+
super().__init__(tokenizer, vis_processor, vis_root, ann_paths=[], **kwargs)
|
19 |
+
|
20 |
+
self.annotation = self.load_annotations(ann_paths)
|
21 |
+
if self.sample_image:
|
22 |
+
print("randomly sample one annotation for each image")
|
23 |
+
self.annotation = self.parse_annotation(self.annotation)
|
24 |
+
self._add_instance_ids()
|
25 |
+
|
26 |
+
@staticmethod
|
27 |
+
def load_annotations(ann_paths):
|
28 |
+
annotation = []
|
29 |
+
for ann_path in ann_paths:
|
30 |
+
ann = json.load(open(ann_path, "r"))
|
31 |
+
annotation.extend(ann["questions"])
|
32 |
+
return annotation
|
33 |
+
|
34 |
+
def parse_annotation(self, annotation):
|
35 |
+
image_list = defaultdict(list)
|
36 |
+
for ann in annotation:
|
37 |
+
image_list[ann["image_filename"]].append(ann)
|
38 |
+
annotation = []
|
39 |
+
for ann_list in image_list.values():
|
40 |
+
annotation.append(random.choice(ann_list))
|
41 |
+
return annotation
|
42 |
+
|
43 |
+
def process_text(self, ann):
|
44 |
+
question = ann["question"]
|
45 |
+
answer = ann["answer"]
|
46 |
+
instruction = self.prompter(question)
|
47 |
+
return dict(instruction=instruction, answer=answer)
|
48 |
+
|
49 |
+
def process_image(self, ann):
|
50 |
+
split = ann["split"]
|
51 |
+
image_path = os.path.join(self.vis_root, split, ann["image_filename"])
|
52 |
+
image = Image.open(image_path).convert("RGB")
|
53 |
+
|
54 |
+
image = self.vis_processor(image)
|
55 |
+
return image
|
56 |
+
|
57 |
+
|
58 |
+
def build_clevr_dataset(
|
59 |
+
tokenizer,
|
60 |
+
vis_processor,
|
61 |
+
vis_root="data/clevr/CLEVR_v1.0/images",
|
62 |
+
ann_paths=[
|
63 |
+
"data/clevr/CLEVR_v1.0/questions/CLEVR_train_questions.json",
|
64 |
+
"data/clevr/CLEVR_v1.0/questions/CLEVR_val_questions.json",
|
65 |
+
],
|
66 |
+
sample_image=False,
|
67 |
+
):
|
68 |
+
return CLEVRDataset(
|
69 |
+
tokenizer=tokenizer,
|
70 |
+
vis_processor=vis_processor,
|
71 |
+
vis_root=vis_root,
|
72 |
+
ann_paths=ann_paths,
|
73 |
+
sample_image=sample_image,
|
74 |
+
)
|
mmgpt/datasets/coco_caption_dataset.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import json
|
9 |
+
import os
|
10 |
+
import random
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
from PIL import Image
|
14 |
+
from transformers import LlamaTokenizer
|
15 |
+
|
16 |
+
from .vqa_dataset import VQADataset
|
17 |
+
|
18 |
+
QUESTIONS = [
|
19 |
+
"please describe the image",
|
20 |
+
"can you describe the image",
|
21 |
+
"Could you provide a description of the image?",
|
22 |
+
"What do you see in this image?",
|
23 |
+
"Share your thoughts on the content of the image.",
|
24 |
+
"Please narrate what's happening in the picture.",
|
25 |
+
"Can you give a brief explanation of the image?",
|
26 |
+
"Describe the main elements and details present in the image.",
|
27 |
+
"In your own words, what is depicted in the image?",
|
28 |
+
"Can you outline the key aspects of the image?",
|
29 |
+
"What are the most striking features in this image?",
|
30 |
+
"Please provide a summary of the image's content.",
|
31 |
+
"Describe the overall theme or concept captured in the image.",
|
32 |
+
"How would you explain the image's composition and focus?",
|
33 |
+
"What is the focal point or main subject of the image?",
|
34 |
+
"How do the different components of the image interact with each other?",
|
35 |
+
"What would be a fitting caption for this image?",
|
36 |
+
"Can you create a concise description that captures the essence of the image?",
|
37 |
+
"How would you briefly summarize the content of this image in a phrase or sentence?",
|
38 |
+
"Please provide a catchy and relevant caption for this picture.",
|
39 |
+
"If you were to give this image a title, what would it be?",
|
40 |
+
"Describe the image in one creative sentence.",
|
41 |
+
"Please suggest a memorable phrase that encapsulates the image's content.",
|
42 |
+
"What engaging phrase would best represent this image?",
|
43 |
+
"Can you create an expressive caption that highlights the main theme of the image?",
|
44 |
+
"How would you sum up the image's story for a caption?",
|
45 |
+
"Provide an eye-catching caption that conveys the image's core message.",
|
46 |
+
"If you were to give this image a headline, what would it say?",
|
47 |
+
"Can you craft a captivating caption that communicates the essence of the image?",
|
48 |
+
"How would you describe the image's content in a powerful caption?",
|
49 |
+
"Please provide an inventive title to summarize the scene depicted in the image.",
|
50 |
+
"Compose a concise and striking phrase that reflects the image's key elements.",
|
51 |
+
"If you were to create a caption for this image, what would it be?",
|
52 |
+
"Offer a compelling caption that highlights the central focus of the image.",
|
53 |
+
"Can you produce a unique caption that encapsulates the image's overall mood?",
|
54 |
+
"Please generate an attention-grabbing caption that would best illustrate the events captured in this image",
|
55 |
+
"How would you express the image's main idea in an impactful sentence?",
|
56 |
+
"Please create a vivid and concise title that conveys the essence of the picture.",
|
57 |
+
"Compose an imaginative caption that reflects the image's most striking features.",
|
58 |
+
"What memorable statement would best represent the scene illustrated in this image?",
|
59 |
+
"Draft an evocative caption that brings the image to life for the reader.",
|
60 |
+
"Can you suggest an insightful caption that highlights the underlying message of the image?",
|
61 |
+
"What engaging phrase would effectively convey the action or subject matter depicted in this picture?",
|
62 |
+
"How would you encapsulate the image's core theme in a concise and expressive manner?",
|
63 |
+
"Please provide a creative and impactful title that captures the spirit of the image.",
|
64 |
+
"Craft a captivating caption that showcases the image's most prominent attributes.",
|
65 |
+
"What intriguing statement would best sum up the scene presented in this image?",
|
66 |
+
"Develop a descriptive caption that paints a vivid picture for the viewer.",
|
67 |
+
"Can you give a detailed account of the image's contents?",
|
68 |
+
"What are the key elements and features visible in this image?",
|
69 |
+
"How would you narrate the events or actions depicted in the picture?",
|
70 |
+
"Please share your observations about the various components present in the image.",
|
71 |
+
"What is the overall theme or concept captured in this image? Can you describe it?",
|
72 |
+
]
|
73 |
+
|
74 |
+
|
75 |
+
class COCOCaptionDataset(VQADataset):
|
76 |
+
def __init__(
|
77 |
+
self, tokenizer, vis_processor=None, vis_root=None, ann_paths=[], add_eos=True, ignore_instruction=True
|
78 |
+
):
|
79 |
+
"""
|
80 |
+
vis_root (string): Root directory of images (e.g. coco/images/)
|
81 |
+
ann_root (string): directory to store the annotation file
|
82 |
+
"""
|
83 |
+
self.tokenizer: LlamaTokenizer = tokenizer
|
84 |
+
self.vis_root = vis_root
|
85 |
+
|
86 |
+
self.annotation = []
|
87 |
+
for ann_path in ann_paths:
|
88 |
+
self.annotation.extend(json.load(open(ann_path, "r")))
|
89 |
+
|
90 |
+
self.vis_processor = vis_processor
|
91 |
+
|
92 |
+
instructions = []
|
93 |
+
for question in QUESTIONS:
|
94 |
+
# instruction = f"Below is a question about an image. Write a response to answer the question.\n\n### Image:\n<image>\n\n### Question:\n{question}\n\n### Answer:\n".format(
|
95 |
+
# question
|
96 |
+
# )
|
97 |
+
instruction = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Image:\n{image}\n\n### Instruction:\n{question}\n\n### Response:\n".format(
|
98 |
+
image="<image>", question=question
|
99 |
+
)
|
100 |
+
instructions.append(instruction)
|
101 |
+
self.instructions = instructions
|
102 |
+
self.add_eos = add_eos
|
103 |
+
self.ignore_instruction = ignore_instruction
|
104 |
+
|
105 |
+
def process_image(self, ann):
|
106 |
+
image_path = os.path.join(self.vis_root, ann["image"])
|
107 |
+
image = Image.open(image_path).convert("RGB")
|
108 |
+
|
109 |
+
image = self.vis_processor(image)
|
110 |
+
return image
|
111 |
+
|
112 |
+
def process_text(self, ann):
|
113 |
+
all_captions = ann["caption"]
|
114 |
+
if not isinstance(all_captions, list):
|
115 |
+
all_captions = [all_captions]
|
116 |
+
caption = random.choice(all_captions)
|
117 |
+
instruction = random.choice(self.instructions)
|
118 |
+
|
119 |
+
return dict(instruction=instruction, answer=caption)
|
mmgpt/datasets/dial_dataset.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .vqa_dataset import VQADataset
|
2 |
+
|
3 |
+
TEMPLATE = {
|
4 |
+
"description": "Template used by Alpaca-LoRA.",
|
5 |
+
# "prompt_choice": "Below is a multiple choice question about an image, along with answer options. Please choose the correct answer from these options.\n\n### Image:\n{image}\n\n### Question:\n{question}\n\n### Options:\n{options}\n\n### Answer:\n",
|
6 |
+
# "prompt_qa": "Below is a question about an image. Write a response to answer the question.\n\n### Image:\n{image}\n\n### Question:\n{question}\n\n### Answer:\n",
|
7 |
+
"prompt_choice": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Image:\n{image}\n\n### Instruction:\n{question}\n\n### Input:\n{options}\n\n### Response:\n",
|
8 |
+
"prompt_qa": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Image:\n{image}\n\n### Instruction:\n{question}\n\n### Response:\n",
|
9 |
+
"prompt_dial": "\n\n### Instruction:\n{question}\n\n### Response:\n",
|
10 |
+
"response_split": "### Response:",
|
11 |
+
}
|
12 |
+
|
13 |
+
|
14 |
+
class DialPrompter:
|
15 |
+
def __call__(self, question, options=None):
|
16 |
+
if options:
|
17 |
+
options = ", ".join(options)
|
18 |
+
res = TEMPLATE["prompt_choice"].format(image="<image>", question=question, options=options)
|
19 |
+
else:
|
20 |
+
res = TEMPLATE["prompt_dial"].format(question=question)
|
21 |
+
return res
|
22 |
+
|
23 |
+
def get_response(self, output: str) -> str:
|
24 |
+
return output.split(TEMPLATE["response_split"])[-1].strip()
|
25 |
+
|
26 |
+
|
27 |
+
class DialDataset(VQADataset):
|
28 |
+
def __init__(self, *args, **kwargs):
|
29 |
+
super(DialDataset, self).__init__(*args, **kwargs)
|
30 |
+
self.prompter = DialPrompter()
|
31 |
+
|
32 |
+
def _add_instance_ids(self, key="id"):
|
33 |
+
for idx, ann in enumerate(self.annotation):
|
34 |
+
ann[key] = str(idx)
|
35 |
+
|
36 |
+
def process_text(self, anns):
|
37 |
+
# TODO remove this
|
38 |
+
begin_string = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Image:\n{image}".format(
|
39 |
+
image="<image>"
|
40 |
+
)
|
41 |
+
num_convs = len(anns["conversations"]) // 2
|
42 |
+
conv_list = []
|
43 |
+
for conv_id in range(num_convs):
|
44 |
+
question = anns["conversations"][conv_id]["value"]
|
45 |
+
# remove '<image>' tag and '\n'
|
46 |
+
question = question.replace("<image>", "").replace("\n", "")
|
47 |
+
answer = anns["conversations"][conv_id + 1]["value"]
|
48 |
+
instruction = self.prompter(question)
|
49 |
+
if conv_id == 0:
|
50 |
+
single_conv = dict(instruction=begin_string + instruction, answer=answer)
|
51 |
+
else:
|
52 |
+
single_conv = dict(instruction=instruction, answer=answer)
|
53 |
+
conv_list.append(single_conv)
|
54 |
+
return conv_list
|
55 |
+
|
56 |
+
def __getitem__(self, index):
|
57 |
+
ann = self.annotation[index]
|
58 |
+
image = self.process_image(ann)
|
59 |
+
text_list = self.process_text(ann)
|
60 |
+
res_list = []
|
61 |
+
for text in text_list:
|
62 |
+
single_res = self.tokenize(text)
|
63 |
+
single_res["instruction"] = text["instruction"]
|
64 |
+
single_res["answer"] = text["answer"]
|
65 |
+
res_list.append(single_res)
|
66 |
+
|
67 |
+
input_ids = []
|
68 |
+
attention_mask = []
|
69 |
+
labels = []
|
70 |
+
instruction = []
|
71 |
+
answer = []
|
72 |
+
for res in res_list:
|
73 |
+
input_ids.extend(res["input_ids"])
|
74 |
+
attention_mask.extend(res["attention_mask"])
|
75 |
+
labels.extend(res["labels"])
|
76 |
+
instruction.extend(res["instruction"])
|
77 |
+
answer.extend(res["answer"])
|
78 |
+
|
79 |
+
res = dict(
|
80 |
+
input_ids=input_ids, attention_mask=attention_mask, labels=labels, instruction=instruction, answer=answer
|
81 |
+
)
|
82 |
+
res.update(image=image)
|
83 |
+
return res
|
mmgpt/datasets/dolly_dataset.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import json
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
from transformers import LlamaTokenizer
|
7 |
+
|
8 |
+
TEMPLATE = {
|
9 |
+
"description": "Template used by LLM.",
|
10 |
+
"prompt_no_input_format": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
|
11 |
+
"prompt_with_input_format": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
|
12 |
+
"response_split": "### Response:",
|
13 |
+
}
|
14 |
+
|
15 |
+
|
16 |
+
class LMPrompter:
|
17 |
+
def __call__(self, instruction, input=None):
|
18 |
+
if input is None or len(input) == 0:
|
19 |
+
return TEMPLATE["prompt_no_input_format"].format(instruction=instruction)
|
20 |
+
else:
|
21 |
+
return TEMPLATE["prompt_with_input_format"].format(instruction=instruction, input=input)
|
22 |
+
|
23 |
+
def get_response(self, output: str) -> str:
|
24 |
+
return output.split(TEMPLATE["response_split"])[-1].strip()
|
25 |
+
|
26 |
+
|
27 |
+
class DollyDataset(Dataset):
|
28 |
+
"""Each line of the annotation file is a json object with the following fields:
|
29 |
+
|
30 |
+
{
|
31 |
+
"instruction": "What is a dispersive prism?",
|
32 |
+
"context": "In optics, a dispersive prism is an optical prism that is used to disperse light, that is, to separate light into its spectral components (the colors of the rainbow). Different wavelengths (colors) of light will be deflected by the prism at different angles.[1] This is a result of the prism material's index of refraction varying with wavelength (dispersion). Generally, longer wavelengths (red) undergo a smaller deviation than shorter wavelengths (blue). The dispersion of white light into colors by a prism led Sir Isaac Newton to conclude that white light consisted of a mixture of different colors.",
|
33 |
+
"response": "A dispersive prism is an optical prism that disperses the light's different wavelengths at different angles. When white light is shined through a dispersive prism it will separate into the different colors of the rainbow.",
|
34 |
+
"category": "summarization"
|
35 |
+
}
|
36 |
+
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(self, tokenizer, ann_path: str, add_eos=True, ignore_instruction=True, **kwargs):
|
40 |
+
"""
|
41 |
+
ann_path (string): directory to store the annotation file
|
42 |
+
"""
|
43 |
+
assert tokenizer.add_eos_token is False, "tokenizer should not add eos token by default"
|
44 |
+
self.tokenizer: LlamaTokenizer = tokenizer
|
45 |
+
|
46 |
+
self.annotation = []
|
47 |
+
self.prompter = LMPrompter()
|
48 |
+
self.add_eos = add_eos
|
49 |
+
self.ignore_instruction = ignore_instruction
|
50 |
+
self.load_annotation(ann_path)
|
51 |
+
|
52 |
+
def load_annotation(self, ann_path):
|
53 |
+
self.annotation = []
|
54 |
+
for line in open(ann_path, "r").readlines():
|
55 |
+
self.annotation.append(json.loads(line))
|
56 |
+
|
57 |
+
def __len__(self):
|
58 |
+
return len(self.annotation)
|
59 |
+
|
60 |
+
def process_text(self, ann):
|
61 |
+
instruction = ann["instruction"]
|
62 |
+
context = ann["context"]
|
63 |
+
response = ann["response"]
|
64 |
+
instruction = self.prompter(instruction=instruction, input=context)
|
65 |
+
return dict(instruction=instruction, answer=response)
|
66 |
+
|
67 |
+
def tokenize(self, text):
|
68 |
+
res = self.tokenizer(
|
69 |
+
text["instruction"] + text["answer"],
|
70 |
+
return_tensors=None,
|
71 |
+
padding="do_not_pad",
|
72 |
+
truncation=True,
|
73 |
+
max_length=512,
|
74 |
+
)
|
75 |
+
|
76 |
+
# manually add eos token
|
77 |
+
if res["input_ids"][-1] != self.tokenizer.eos_token_id and len(res["input_ids"]) < 512 and self.add_eos:
|
78 |
+
res["input_ids"].append(self.tokenizer.eos_token_id)
|
79 |
+
res["attention_mask"].append(1)
|
80 |
+
labels = copy.deepcopy(res["input_ids"])
|
81 |
+
# ignore instruction_token
|
82 |
+
if self.ignore_instruction:
|
83 |
+
instruction_token = self.tokenizer(
|
84 |
+
text["instruction"], return_tensors=None, padding="do_not_pad", truncation=True, max_length=512
|
85 |
+
)
|
86 |
+
labels = [-100] * len(instruction_token["input_ids"]) + labels[len(instruction_token["input_ids"]) :]
|
87 |
+
|
88 |
+
res.update(labels=labels)
|
89 |
+
return res
|
90 |
+
|
91 |
+
def __getitem__(self, index):
|
92 |
+
ann = self.annotation[index]
|
93 |
+
text = self.process_text(ann)
|
94 |
+
res = self.tokenize(text)
|
95 |
+
res.update(text)
|
96 |
+
return res
|
97 |
+
|
98 |
+
def collater(self, samples):
|
99 |
+
question_list, answer_list, input_id_list, attention_mask_list, labels_list = [], [], [], [], []
|
100 |
+
|
101 |
+
for sample in samples:
|
102 |
+
question_list.append(sample["instruction"])
|
103 |
+
answer_list.append(sample["answer"])
|
104 |
+
input_id_list.append(sample["input_ids"])
|
105 |
+
attention_mask_list.append(sample["attention_mask"])
|
106 |
+
labels_list.append(sample["labels"])
|
107 |
+
|
108 |
+
# We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
|
109 |
+
# same length to return tensors.
|
110 |
+
max_label_length = max(len(l) for l in labels_list)
|
111 |
+
padding_side = self.tokenizer.padding_side
|
112 |
+
padded_labels = []
|
113 |
+
for l in labels_list:
|
114 |
+
remainder = [-100] * (max_label_length - len(l))
|
115 |
+
if isinstance(l, list):
|
116 |
+
l = l + remainder if padding_side == "right" else remainder + l
|
117 |
+
elif padding_side == "right":
|
118 |
+
l = np.concatenate([l, remainder]).astype(np.int64)
|
119 |
+
else:
|
120 |
+
l = np.concatenate([remainder, l]).astype(np.int64)
|
121 |
+
padded_labels.append(l)
|
122 |
+
|
123 |
+
padded_samples = self.tokenizer.pad(
|
124 |
+
{"input_ids": input_id_list, "attention_mask": attention_mask_list, "labels": padded_labels},
|
125 |
+
return_tensors="pt",
|
126 |
+
padding="longest",
|
127 |
+
)
|
128 |
+
|
129 |
+
labels = padded_samples["labels"]
|
130 |
+
labels[labels == self.tokenizer.pad_token_id] = -100
|
131 |
+
labels[:, 0] = -100
|
132 |
+
return {
|
133 |
+
"input_ids": padded_samples["input_ids"],
|
134 |
+
"attention_mask": padded_samples["attention_mask"],
|
135 |
+
"labels": labels,
|
136 |
+
"instruction": question_list,
|
137 |
+
"answer": answer_list,
|
138 |
+
}
|
139 |
+
|
140 |
+
|
141 |
+
def build_dolly_dataset(
|
142 |
+
tokenizer,
|
143 |
+
ann_path="data/dolly/databricks-dolly-15k.jsonl",
|
144 |
+
**kwargs,
|
145 |
+
):
|
146 |
+
return DollyDataset(
|
147 |
+
tokenizer=tokenizer,
|
148 |
+
ann_path=ann_path,
|
149 |
+
**kwargs,
|
150 |
+
)
|
mmgpt/datasets/gqa_dataset.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
from collections import defaultdict
|
5 |
+
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
from .vqa_dataset import VQADataset
|
9 |
+
|
10 |
+
|
11 |
+
class GQADataset(VQADataset):
|
12 |
+
"""Visual Reasoning Dataset."""
|
13 |
+
|
14 |
+
def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs):
|
15 |
+
super().__init__(tokenizer, vis_processor, vis_root, ann_paths=[], **kwargs)
|
16 |
+
|
17 |
+
self.annotation = self.load_annotations(ann_paths)
|
18 |
+
if self.sample_image:
|
19 |
+
print("randomly sample one annotation for each image")
|
20 |
+
self.annotation = self.parse_annotation(self.annotation)
|
21 |
+
self._add_instance_ids()
|
22 |
+
self.answer_prob = 1.0
|
23 |
+
|
24 |
+
@staticmethod
|
25 |
+
def load_annotations(ann_paths):
|
26 |
+
annotation = []
|
27 |
+
for ann_path in ann_paths:
|
28 |
+
ann = json.load(open(ann_path, "r"))
|
29 |
+
for k, v in ann.items():
|
30 |
+
v["question_id"] = k
|
31 |
+
annotation.append(v)
|
32 |
+
return annotation
|
33 |
+
|
34 |
+
def parse_annotation(self, annotation):
|
35 |
+
image_list = defaultdict(list)
|
36 |
+
for ann in annotation:
|
37 |
+
image_list[ann["imageId"]].append(ann)
|
38 |
+
annotation = []
|
39 |
+
for ann_list in image_list.values():
|
40 |
+
annotation.append(random.choice(ann_list))
|
41 |
+
return annotation
|
42 |
+
|
43 |
+
def process_text(self, ann):
|
44 |
+
question = ann["question"]
|
45 |
+
|
46 |
+
answer = ann["answer"]
|
47 |
+
full_answer = ann["fullAnswer"]
|
48 |
+
|
49 |
+
# TODO: check which one is better
|
50 |
+
# Random select answer or full_answer
|
51 |
+
if random.random() < self.answer_prob:
|
52 |
+
select_answer = full_answer
|
53 |
+
else:
|
54 |
+
select_answer = answer
|
55 |
+
|
56 |
+
instruction = self.prompter(question)
|
57 |
+
return dict(instruction=instruction, answer=select_answer)
|
58 |
+
|
59 |
+
def process_image(self, ann):
|
60 |
+
image_path = os.path.join(self.vis_root, ann["imageId"] + ".jpg")
|
61 |
+
image = Image.open(image_path).convert("RGB")
|
62 |
+
|
63 |
+
image = self.vis_processor(image)
|
64 |
+
return image
|
65 |
+
|
66 |
+
|
67 |
+
def build_gqa_dataset(
|
68 |
+
tokenizer,
|
69 |
+
vis_processor,
|
70 |
+
vis_root="data/gqa/images",
|
71 |
+
ann_paths=[
|
72 |
+
"data/gqa/questions/train_all_questions/train_all_questions_0.json",
|
73 |
+
"data/gqa/questions/val_all_questions.json",
|
74 |
+
],
|
75 |
+
sample_image=False,
|
76 |
+
):
|
77 |
+
return GQADataset(
|
78 |
+
tokenizer=tokenizer,
|
79 |
+
vis_processor=vis_processor,
|
80 |
+
vis_root=vis_root,
|
81 |
+
ann_paths=ann_paths,
|
82 |
+
sample_image=sample_image,
|
83 |
+
)
|
mmgpt/datasets/llava_dataset.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .vqa_dataset import VQADataset
|
2 |
+
|
3 |
+
|
4 |
+
class LlavaDataset(VQADataset):
|
5 |
+
def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs):
|
6 |
+
super().__init__(tokenizer, vis_processor, vis_root, ann_paths, **kwargs)
|
7 |
+
|
8 |
+
def _add_instance_ids(self, key="id"):
|
9 |
+
for idx, ann in enumerate(self.annotation):
|
10 |
+
ann[key] = str(idx)
|
11 |
+
|
12 |
+
def process_text(self, ann):
|
13 |
+
question = ann["conversations"][0]["value"]
|
14 |
+
# remove '<image>' tag and '\n'
|
15 |
+
question = question.replace("<image>", "").replace("\n", "")
|
16 |
+
answer = ann["conversations"][1]["value"]
|
17 |
+
instruction = self.prompter(question)
|
18 |
+
return dict(instruction=instruction, answer=answer)
|
mmgpt/datasets/nlvr_dataset.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
from collections import defaultdict
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
from .vqa_dataset import VQADataset
|
12 |
+
|
13 |
+
QUESTIONS = [
|
14 |
+
"Is this true?",
|
15 |
+
"Is this right?",
|
16 |
+
"Can you confirm this information?" "Do you agree with this statement?",
|
17 |
+
"Does this align with your understanding?",
|
18 |
+
"How do you interpret this information?",
|
19 |
+
"Does this align with your understanding?",
|
20 |
+
"Can you confirm this?",
|
21 |
+
"Is this statement correct?",
|
22 |
+
"Could you verify this information?",
|
23 |
+
"Do you agree with this?",
|
24 |
+
"Is this accurate?",
|
25 |
+
"Can you validate this claim?",
|
26 |
+
"Are these details valid?",
|
27 |
+
"Is this factually correct?",
|
28 |
+
"Is the following information correct?",
|
29 |
+
"Could you please verify this fact?",
|
30 |
+
"Do you agree with this assertion?",
|
31 |
+
"Are these details accurate?",
|
32 |
+
"Does this claim hold true?",
|
33 |
+
]
|
34 |
+
|
35 |
+
|
36 |
+
class NLVRv1Dataset(VQADataset):
|
37 |
+
"""Visual Reasoning Dataset."""
|
38 |
+
|
39 |
+
def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs):
|
40 |
+
super().__init__(tokenizer, vis_processor, vis_root, ann_paths=[], **kwargs)
|
41 |
+
|
42 |
+
self.annotation = self.load_annotations(ann_paths)
|
43 |
+
if self.sample_image:
|
44 |
+
print("randomly sample one annotation for each image")
|
45 |
+
self.annotation = self.parse_annotation(self.annotation)
|
46 |
+
self._add_instance_ids()
|
47 |
+
|
48 |
+
@staticmethod
|
49 |
+
def load_annotations(ann_paths):
|
50 |
+
annotation = []
|
51 |
+
for ann_path in ann_paths:
|
52 |
+
if "train.json" in ann_path:
|
53 |
+
split = "train"
|
54 |
+
elif "dev.json" in ann_path:
|
55 |
+
split = "dev"
|
56 |
+
elif "test.json" in ann_path:
|
57 |
+
split = "test"
|
58 |
+
else:
|
59 |
+
raise ValueError(f"Unknown split for {ann_path}")
|
60 |
+
|
61 |
+
with open(ann_path, "r") as f:
|
62 |
+
for line in f.readlines():
|
63 |
+
line = line.strip()
|
64 |
+
if len(line) != 0:
|
65 |
+
ann = json.loads(line)
|
66 |
+
ann["split"] = split
|
67 |
+
annotation.append(ann)
|
68 |
+
|
69 |
+
return annotation
|
70 |
+
|
71 |
+
def parse_annotation(self, annotation):
|
72 |
+
image_list = defaultdict(list)
|
73 |
+
for ann in annotation:
|
74 |
+
img_key = f"{ann['split']}-{ann['identifier']}"
|
75 |
+
image_list[img_key].append(ann)
|
76 |
+
annotation = []
|
77 |
+
for ann_list in image_list.values():
|
78 |
+
annotation.append(random.choice(ann_list))
|
79 |
+
return annotation
|
80 |
+
|
81 |
+
def process_text(self, ann):
|
82 |
+
question = ann["sentence"] + " " + random.choice(QUESTIONS)
|
83 |
+
true_answer = ann["label"]
|
84 |
+
|
85 |
+
if random.random() < self.option_prob:
|
86 |
+
instruction = self.prompter(question, ["true", "false"])
|
87 |
+
else:
|
88 |
+
instruction = self.prompter(question)
|
89 |
+
|
90 |
+
return dict(instruction=instruction, answer=true_answer)
|
91 |
+
|
92 |
+
def process_image(self, ann):
|
93 |
+
# each question have 6 images, we can random select one of them.
|
94 |
+
# TODO: check whether using all 6 images?
|
95 |
+
random_id = random.randint(0, 5)
|
96 |
+
image_name = f"{ann['split']}-{ann['identifier']}-{random_id}.png"
|
97 |
+
image_path = os.path.join(self.vis_root, ann["split"], "images", ann["directory"], image_name)
|
98 |
+
image = Image.open(image_path).convert("RGB")
|
99 |
+
|
100 |
+
image = self.vis_processor(image)
|
101 |
+
return image
|
102 |
+
|
103 |
+
|
104 |
+
class NLVRv2Dataset(VQADataset):
|
105 |
+
"""Visual Reasoning Dataset."""
|
106 |
+
|
107 |
+
def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs):
|
108 |
+
super().__init__(tokenizer, vis_processor, vis_root, ann_paths, **kwargs)
|
109 |
+
self.flip_prob = 0.5
|
110 |
+
|
111 |
+
def parse_annotation(self, annotation):
|
112 |
+
image_list = defaultdict(list)
|
113 |
+
for ann in annotation:
|
114 |
+
image_list[ann["images"][0]].append(ann)
|
115 |
+
# image_name_list = list(image_list.keys())
|
116 |
+
annotation = []
|
117 |
+
for ann_list in image_list.values():
|
118 |
+
annotation.append(random.choice(ann_list))
|
119 |
+
return annotation
|
120 |
+
|
121 |
+
def process_text(self, ann):
|
122 |
+
question = ann["sentence"] + " " + random.choice(QUESTIONS)
|
123 |
+
true_answer = ann["label"]
|
124 |
+
|
125 |
+
if random.random() < self.option_prob:
|
126 |
+
instruction = self.prompter(question, ["true", "false"])
|
127 |
+
else:
|
128 |
+
instruction = self.prompter(question)
|
129 |
+
|
130 |
+
return dict(instruction=instruction, answer=true_answer)
|
131 |
+
|
132 |
+
def process_image(self, ann):
|
133 |
+
image_0_path = os.path.join(self.vis_root, ann["images"][0])
|
134 |
+
image_1_path = os.path.join(self.vis_root, ann["images"][1])
|
135 |
+
|
136 |
+
image_0 = Image.open(image_0_path).convert("RGB")
|
137 |
+
image_1 = Image.open(image_1_path).convert("RGB")
|
138 |
+
image_0 = self.vis_processor(image_0)
|
139 |
+
image_1 = self.vis_processor(image_1)
|
140 |
+
return image_0, image_1
|
141 |
+
|
142 |
+
@staticmethod
|
143 |
+
def _flip(samples):
|
144 |
+
sentence = samples["sentence"]
|
145 |
+
image0, image1 = samples["image0"], samples["image1"]
|
146 |
+
|
147 |
+
if "left" not in sentence and "right" not in sentence:
|
148 |
+
if random.random() < 0.5:
|
149 |
+
image0, image1 = image1, image0
|
150 |
+
else:
|
151 |
+
if random.random() < 0.5:
|
152 |
+
sentence = sentence.replace("left", "[TEMP_TOKEN]")
|
153 |
+
sentence = sentence.replace("right", "left")
|
154 |
+
sentence = sentence.replace("[TEMP_TOKEN]", "right")
|
155 |
+
|
156 |
+
image0, image1 = image1, image0
|
157 |
+
|
158 |
+
samples["sentence"] = sentence
|
159 |
+
samples["image0"] = image0
|
160 |
+
samples["image1"] = image1
|
161 |
+
|
162 |
+
return samples
|
163 |
+
|
164 |
+
def __getitem__(self, index):
|
165 |
+
ann = copy.deepcopy(self.annotation[index])
|
166 |
+
image_0, image_1 = self.process_image(ann)
|
167 |
+
if random.random() < self.flip_prob:
|
168 |
+
samples = self._flip({"sentence": ann["sentence"], "image0": image_0, "image1": image_1})
|
169 |
+
image_0, image_1 = samples["image0"], samples["image1"]
|
170 |
+
ann["sentence"] = samples["sentence"]
|
171 |
+
# concat
|
172 |
+
# TODO: https://github.com/salesforce/LAVIS/blob/main/lavis/models/blip_models/blip_nlvr.py
|
173 |
+
# model logic need update if using nlvr2
|
174 |
+
image = torch.cat([image_0, image_1], dim=2)
|
175 |
+
image = F.interpolate(image[None, ...], size=(image_0.shape[1], image_0.shape[2]))[0]
|
176 |
+
text = self.process_text(ann)
|
177 |
+
res = self.tokenize(text)
|
178 |
+
res.update(image=image)
|
179 |
+
res.update(text)
|
180 |
+
return res
|
181 |
+
|
182 |
+
|
183 |
+
def build_nlvrv1_dataset(
|
184 |
+
tokenizer,
|
185 |
+
vis_processor,
|
186 |
+
vis_root="data/nlvr",
|
187 |
+
ann_paths=["data/nlvr//train/train.json"],
|
188 |
+
sample_image=False,
|
189 |
+
):
|
190 |
+
return NLVRv1Dataset(
|
191 |
+
tokenizer=tokenizer,
|
192 |
+
vis_processor=vis_processor,
|
193 |
+
vis_root=vis_root,
|
194 |
+
ann_paths=ann_paths,
|
195 |
+
sample_image=sample_image,
|
196 |
+
)
|
197 |
+
|
198 |
+
|
199 |
+
def build_nlvrv2_dataset(
|
200 |
+
tokenizer,
|
201 |
+
vis_processor,
|
202 |
+
vis_root="data/nlvr2",
|
203 |
+
ann_paths=["data/nlvr2/annotations/nlvr_train.json"],
|
204 |
+
sample_image=False,
|
205 |
+
):
|
206 |
+
return NLVRv2Dataset(
|
207 |
+
tokenizer=tokenizer,
|
208 |
+
vis_processor=vis_processor,
|
209 |
+
vis_root=vis_root,
|
210 |
+
ann_paths=ann_paths,
|
211 |
+
sample_image=sample_image,
|
212 |
+
)
|
mmgpt/datasets/ocr_vqa_dataset.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
from .vqa_dataset import VQADataset
|
7 |
+
|
8 |
+
|
9 |
+
class OCRVQADataset(VQADataset):
|
10 |
+
def process_image(self, ann):
|
11 |
+
image_path = os.path.join(self.vis_root, ann["filename"])
|
12 |
+
image = Image.open(image_path).convert("RGB")
|
13 |
+
|
14 |
+
image = self.vis_processor(image)
|
15 |
+
return image
|
16 |
+
|
17 |
+
def process_text(self, ann):
|
18 |
+
index = random.choice(list(range(len(ann["questions"]))))
|
19 |
+
question = ann["questions"][index]
|
20 |
+
answer = ann["answers"][index]
|
21 |
+
|
22 |
+
instruction = self.prompter(question)
|
23 |
+
return dict(instruction=instruction, answer=answer)
|
mmgpt/datasets/samplers/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .infinite_sampler import InfiniteSampler
|
mmgpt/datasets/samplers/infinite_sampler.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch.utils.data.sampler import Sampler
|
5 |
+
|
6 |
+
from mmgpt.train.distributed import world_info_from_env
|
7 |
+
|
8 |
+
|
9 |
+
class InfiniteSampler(Sampler):
|
10 |
+
def __init__(self, dataset: int, shuffle: bool = True, seed: int = 0):
|
11 |
+
self._size = len(dataset)
|
12 |
+
self._shuffle = shuffle
|
13 |
+
self._seed = int(seed)
|
14 |
+
_, rank, world_size = world_info_from_env()
|
15 |
+
|
16 |
+
self._rank = rank
|
17 |
+
self._world_size = world_size
|
18 |
+
|
19 |
+
def __iter__(self):
|
20 |
+
start = self._rank
|
21 |
+
yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)
|
22 |
+
|
23 |
+
def _infinite_indices(self):
|
24 |
+
g = torch.Generator()
|
25 |
+
g.manual_seed(self._seed)
|
26 |
+
while True:
|
27 |
+
if self._shuffle:
|
28 |
+
yield from torch.randperm(self._size, generator=g).tolist()
|
29 |
+
else:
|
30 |
+
yield from torch.arange(self._size).tolist()
|
mmgpt/datasets/snli_ve_datasets.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
from collections import defaultdict
|
5 |
+
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
from .vqa_dataset import VQADataset
|
9 |
+
|
10 |
+
QUESTIONS = [
|
11 |
+
"What do you think of the above sentence?",
|
12 |
+
"Can you confirm this statement?",
|
13 |
+
"How do you interpret the given information?",
|
14 |
+
"What is your opinion on this matter?",
|
15 |
+
"Could you provide your perspective on this statement?",
|
16 |
+
"How would you respond to the provided claim?",
|
17 |
+
"What are your thoughts regarding the mentioned subject?",
|
18 |
+
"Can you elaborate on this idea in English?",
|
19 |
+
"Do you have any insights or feedback on this topic?",
|
20 |
+
"What's your take on the given statement?",
|
21 |
+
"What is your perspective on the given statement?",
|
22 |
+
"How would you interpret this remark?",
|
23 |
+
"Could you please provide your opinion on this?",
|
24 |
+
"Can you share your understanding of the above point?",
|
25 |
+
"Would you mind elaborating on this topic?",
|
26 |
+
"What are your views about the given statement?",
|
27 |
+
"How do you feel about the presented information?",
|
28 |
+
"Could you provide your perspective on this?",
|
29 |
+
"What is your opinion regarding this statement?",
|
30 |
+
"Can you share your thoughts about the mentioned claim?",
|
31 |
+
"How would you interpret the above comment?",
|
32 |
+
"Would you mind sharing your insights on this issue?",
|
33 |
+
]
|
34 |
+
|
35 |
+
|
36 |
+
class SNLIVEDataset(VQADataset):
|
37 |
+
"""Visual Reasoning Dataset."""
|
38 |
+
|
39 |
+
def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs):
|
40 |
+
super().__init__(tokenizer, vis_processor, vis_root, ann_paths=[], **kwargs)
|
41 |
+
|
42 |
+
self.annotation = self.load_annotations(ann_paths)
|
43 |
+
if self.sample_image:
|
44 |
+
print("randomly sample one annotation for each image")
|
45 |
+
self.annotation = self.parse_annotation(self.annotation)
|
46 |
+
self._add_instance_ids()
|
47 |
+
|
48 |
+
@staticmethod
|
49 |
+
def load_annotations(ann_paths):
|
50 |
+
annotation = []
|
51 |
+
for ann_path in ann_paths:
|
52 |
+
with open(ann_path, "r") as f:
|
53 |
+
for line in f.readlines():
|
54 |
+
line = line.strip()
|
55 |
+
if len(line) != 0:
|
56 |
+
ann = json.loads(line)
|
57 |
+
annotation.append(ann)
|
58 |
+
return annotation
|
59 |
+
|
60 |
+
def parse_annotation(self, annotation):
|
61 |
+
image_list = defaultdict(list)
|
62 |
+
for ann in annotation:
|
63 |
+
image_list[ann["Flickr30K_ID"]].append(ann)
|
64 |
+
annotation = []
|
65 |
+
for ann_list in image_list.values():
|
66 |
+
annotation.append(random.choice(ann_list))
|
67 |
+
return annotation
|
68 |
+
|
69 |
+
def process_text(self, ann):
|
70 |
+
question = ann["sentence2"] + " " + random.choice(QUESTIONS)
|
71 |
+
answer = ann["gold_label"]
|
72 |
+
if random.random() < self.option_prob:
|
73 |
+
instruction = self.prompter(question, ["entailment", "neutral", "contradiction"])
|
74 |
+
else:
|
75 |
+
instruction = self.prompter(question)
|
76 |
+
return dict(instruction=instruction, answer=answer)
|
77 |
+
|
78 |
+
def process_image(self, ann):
|
79 |
+
image_path = os.path.join(self.vis_root, ann["Flickr30K_ID"] + ".jpg")
|
80 |
+
image = Image.open(image_path).convert("RGB")
|
81 |
+
image = self.vis_processor(image)
|
82 |
+
return image
|
mmgpt/datasets/text_ocr_dataset.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
from transformers import LlamaTokenizer
|
8 |
+
|
9 |
+
from .vqa_dataset import VQADataset, VQAPrompter
|
10 |
+
|
11 |
+
|
12 |
+
class TextOCRDataset(VQADataset):
|
13 |
+
def __init__(
|
14 |
+
self, tokenizer, vis_processor=None, vis_root=None, ann_paths=[], add_eos=True, ignore_instruction=True
|
15 |
+
):
|
16 |
+
"""
|
17 |
+
vis_root (string): Root directory of images (e.g. coco/images/)
|
18 |
+
ann_root (string): directory to store the annotation file
|
19 |
+
"""
|
20 |
+
assert tokenizer.add_eos_token is False, "tokenizer should not add eos token by default"
|
21 |
+
self.tokenizer: LlamaTokenizer = tokenizer
|
22 |
+
self.vis_root = vis_root
|
23 |
+
|
24 |
+
self.annotation = []
|
25 |
+
for ann_path in ann_paths:
|
26 |
+
self.annotation.extend(json.load(open(ann_path, "r"))["data"])
|
27 |
+
|
28 |
+
self.vis_processor = vis_processor
|
29 |
+
|
30 |
+
self._add_instance_ids()
|
31 |
+
self.option_prob = 0.5
|
32 |
+
self.prompter = VQAPrompter()
|
33 |
+
self.add_eos = add_eos
|
34 |
+
self.ignore_instruction = ignore_instruction
|
35 |
+
|
36 |
+
def process_image(self, ann):
|
37 |
+
image_path = os.path.join(self.vis_root, ann["image_id"] + ".jpg")
|
38 |
+
image = Image.open(image_path).convert("RGB")
|
39 |
+
|
40 |
+
image = self.vis_processor(image)
|
41 |
+
return image
|
42 |
+
|
43 |
+
def process_text(self, ann):
|
44 |
+
question = ann["question"]
|
45 |
+
|
46 |
+
answer_weight = {}
|
47 |
+
for answer in ann["answers"]:
|
48 |
+
if answer in answer_weight.keys():
|
49 |
+
answer_weight[answer] += 1 / len(ann["answers"])
|
50 |
+
else:
|
51 |
+
answer_weight[answer] = 1 / len(ann["answers"])
|
52 |
+
|
53 |
+
answers = list(answer_weight.keys())
|
54 |
+
weights = list(answer_weight.values())
|
55 |
+
|
56 |
+
# create instruction
|
57 |
+
true_answer = answers[np.argmax(weights)]
|
58 |
+
is_option = random.random() < self.option_prob and len(answers) > 1
|
59 |
+
if is_option:
|
60 |
+
instruction = self.prompter(question, answers)
|
61 |
+
else:
|
62 |
+
instruction = self.prompter(question)
|
63 |
+
|
64 |
+
return dict(instruction=instruction, answer=true_answer)
|
mmgpt/datasets/vqa_dataset.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
All rights reserved.
|
4 |
+
SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
"""
|
7 |
+
|
8 |
+
import copy
|
9 |
+
import json
|
10 |
+
import os
|
11 |
+
import random
|
12 |
+
from collections import defaultdict
|
13 |
+
from typing import Iterable
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
from PIL import Image
|
18 |
+
from torch.utils.data import ConcatDataset, Dataset
|
19 |
+
from torch.utils.data.dataloader import default_collate
|
20 |
+
from transformers import LlamaTokenizer
|
21 |
+
|
22 |
+
TEMPLATE = {
|
23 |
+
"description": "Template used by Alpaca-LoRA.",
|
24 |
+
# "prompt_choice": "Below is a multiple choice question about an image, along with answer options. Please choose the correct answer from these options.\n\n### Image:\n{image}\n\n### Question:\n{question}\n\n### Input:\n{options}\n\n### Answer:\n",
|
25 |
+
# "prompt_qa": "Below is a question about an image. Write a response to answer the question.\n\n### Image:\n{image}\n\n### Question:\n{question}\n\n### Answer:\n",
|
26 |
+
"prompt_choice": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Image:\n{image}\n\n### Instruction:\n{question}\n\n### Input:\n{options}\n\n### Response:\n",
|
27 |
+
"prompt_qa": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Image:\n{image}\n\n### Instruction:\n{question}\n\n### Response:\n",
|
28 |
+
"response_split": "### Response:",
|
29 |
+
}
|
30 |
+
|
31 |
+
|
32 |
+
class VQAPrompter:
|
33 |
+
def __call__(self, question, options=None):
|
34 |
+
if options:
|
35 |
+
options = ", ".join(options)
|
36 |
+
res = TEMPLATE["prompt_choice"].format(image="<image>", question=question, options=options)
|
37 |
+
else:
|
38 |
+
res = TEMPLATE["prompt_qa"].format(image="<image>", question=question)
|
39 |
+
return res
|
40 |
+
|
41 |
+
def get_response(self, output: str) -> str:
|
42 |
+
return output.split(TEMPLATE["response_split"])[-1].strip()
|
43 |
+
|
44 |
+
|
45 |
+
class VQADataset(Dataset):
|
46 |
+
def __init__(
|
47 |
+
self,
|
48 |
+
tokenizer,
|
49 |
+
vis_processor=None,
|
50 |
+
vis_root=None,
|
51 |
+
ann_paths=[],
|
52 |
+
add_eos=True,
|
53 |
+
ignore_instruction=True,
|
54 |
+
sample_image=False,
|
55 |
+
):
|
56 |
+
"""
|
57 |
+
vis_root (string): Root directory of images (e.g. coco/images/)
|
58 |
+
ann_root (string): directory to store the annotation file
|
59 |
+
"""
|
60 |
+
assert tokenizer.add_eos_token is False, "tokenizer should not add eos token by default"
|
61 |
+
self.tokenizer: LlamaTokenizer = tokenizer
|
62 |
+
self.vis_root = vis_root
|
63 |
+
|
64 |
+
self.annotation = []
|
65 |
+
for ann_path in ann_paths:
|
66 |
+
self.annotation.extend(json.load(open(ann_path, "r")))
|
67 |
+
|
68 |
+
self.sample_image = sample_image
|
69 |
+
if self.sample_image:
|
70 |
+
print("randomly sample one annotation for each image")
|
71 |
+
self.annotation = self.parse_annotation(self.annotation)
|
72 |
+
|
73 |
+
self.vis_processor = vis_processor
|
74 |
+
|
75 |
+
self._add_instance_ids()
|
76 |
+
self.option_prob = 0.5
|
77 |
+
self.prompter = VQAPrompter()
|
78 |
+
self.add_eos = add_eos
|
79 |
+
self.ignore_instruction = ignore_instruction
|
80 |
+
|
81 |
+
def parse_annotation(self, annotation):
|
82 |
+
image_list = defaultdict(list)
|
83 |
+
for ann in annotation:
|
84 |
+
image_list[ann["image"]].append(ann)
|
85 |
+
# image_name_list = list(image_list.keys())
|
86 |
+
annotation = []
|
87 |
+
for ann_list in image_list.values():
|
88 |
+
annotation.append(random.choice(ann_list))
|
89 |
+
return annotation
|
90 |
+
|
91 |
+
def __len__(self):
|
92 |
+
return len(self.annotation)
|
93 |
+
|
94 |
+
def _add_instance_ids(self, key="instance_id"):
|
95 |
+
for idx, ann in enumerate(self.annotation):
|
96 |
+
ann[key] = str(idx)
|
97 |
+
|
98 |
+
def process_image(self, ann):
|
99 |
+
image_path = os.path.join(self.vis_root, ann["image"])
|
100 |
+
image = Image.open(image_path).convert("RGB")
|
101 |
+
|
102 |
+
image = self.vis_processor(image)
|
103 |
+
return image
|
104 |
+
|
105 |
+
def process_text(self, ann):
|
106 |
+
question = ann["question"]
|
107 |
+
|
108 |
+
answer_weight = {}
|
109 |
+
for answer in ann["answer"]:
|
110 |
+
if answer in answer_weight.keys():
|
111 |
+
answer_weight[answer] += 1 / len(ann["answer"])
|
112 |
+
else:
|
113 |
+
answer_weight[answer] = 1 / len(ann["answer"])
|
114 |
+
|
115 |
+
answers = list(answer_weight.keys())
|
116 |
+
weights = list(answer_weight.values())
|
117 |
+
|
118 |
+
# create instruction
|
119 |
+
true_answer = answers[np.argmax(weights)]
|
120 |
+
is_option = random.random() < self.option_prob and len(answers) > 1
|
121 |
+
if is_option:
|
122 |
+
instruction = self.prompter(question, answers)
|
123 |
+
else:
|
124 |
+
instruction = self.prompter(question)
|
125 |
+
|
126 |
+
return dict(instruction=instruction, answer=true_answer)
|
127 |
+
|
128 |
+
def tokenize(self, text):
|
129 |
+
res = self.tokenizer(
|
130 |
+
text["instruction"] + text["answer"],
|
131 |
+
return_tensors=None,
|
132 |
+
padding="do_not_pad",
|
133 |
+
truncation=True,
|
134 |
+
max_length=512,
|
135 |
+
)
|
136 |
+
|
137 |
+
# manually add eos token
|
138 |
+
if res["input_ids"][-1] != self.tokenizer.eos_token_id and len(res["input_ids"]) < 512 and self.add_eos:
|
139 |
+
res["input_ids"].append(self.tokenizer.eos_token_id)
|
140 |
+
res["attention_mask"].append(1)
|
141 |
+
labels = copy.deepcopy(res["input_ids"])
|
142 |
+
# ignore instruction_token
|
143 |
+
if self.ignore_instruction:
|
144 |
+
instruction_token = self.tokenizer(
|
145 |
+
text["instruction"], return_tensors=None, padding="do_not_pad", truncation=True, max_length=512
|
146 |
+
)
|
147 |
+
labels = [-100] * len(instruction_token["input_ids"]) + labels[len(instruction_token["input_ids"]) :]
|
148 |
+
|
149 |
+
res.update(labels=labels)
|
150 |
+
return res
|
151 |
+
|
152 |
+
def __getitem__(self, index):
|
153 |
+
ann = self.annotation[index]
|
154 |
+
image = self.process_image(ann)
|
155 |
+
text = self.process_text(ann)
|
156 |
+
res = self.tokenize(text)
|
157 |
+
res.update(image=image)
|
158 |
+
res.update(text)
|
159 |
+
return res
|
160 |
+
|
161 |
+
def collater(self, samples):
|
162 |
+
image_list, question_list, answer_list, input_id_list, attention_mask_list, labels_list = [], [], [], [], [], []
|
163 |
+
|
164 |
+
for sample in samples:
|
165 |
+
image_list.append(sample["image"])
|
166 |
+
question_list.append(sample["instruction"])
|
167 |
+
answer_list.append(sample["answer"])
|
168 |
+
input_id_list.append(sample["input_ids"])
|
169 |
+
attention_mask_list.append(sample["attention_mask"])
|
170 |
+
labels_list.append(sample["labels"])
|
171 |
+
|
172 |
+
# We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
|
173 |
+
# same length to return tensors.
|
174 |
+
max_label_length = max(len(l) for l in labels_list)
|
175 |
+
padding_side = self.tokenizer.padding_side
|
176 |
+
padded_labels = []
|
177 |
+
for l in labels_list:
|
178 |
+
remainder = [-100] * (max_label_length - len(l))
|
179 |
+
if isinstance(l, list):
|
180 |
+
l = l + remainder if padding_side == "right" else remainder + l
|
181 |
+
elif padding_side == "right":
|
182 |
+
l = np.concatenate([l, remainder]).astype(np.int64)
|
183 |
+
else:
|
184 |
+
l = np.concatenate([remainder, l]).astype(np.int64)
|
185 |
+
padded_labels.append(l)
|
186 |
+
|
187 |
+
padded_samples = self.tokenizer.pad(
|
188 |
+
{"input_ids": input_id_list, "attention_mask": attention_mask_list, "labels": padded_labels},
|
189 |
+
return_tensors="pt",
|
190 |
+
padding="longest",
|
191 |
+
)
|
192 |
+
|
193 |
+
labels = padded_samples["labels"]
|
194 |
+
media_token_id = self.tokenizer("<image>", add_special_tokens=False)["input_ids"][-1]
|
195 |
+
labels[labels == self.tokenizer.pad_token_id] = -100
|
196 |
+
labels[:, 0] = -100
|
197 |
+
labels[labels == media_token_id] = -100
|
198 |
+
return {
|
199 |
+
"image": torch.stack(image_list, dim=0),
|
200 |
+
"input_ids": padded_samples["input_ids"],
|
201 |
+
"attention_mask": padded_samples["attention_mask"],
|
202 |
+
"labels": labels,
|
203 |
+
"instruction": question_list,
|
204 |
+
"answer": answer_list,
|
205 |
+
}
|
206 |
+
|
207 |
+
|
208 |
+
class ConcatDataset(ConcatDataset):
|
209 |
+
def __init__(self, datasets: Iterable[Dataset]) -> None:
|
210 |
+
super().__init__(datasets)
|
211 |
+
|
212 |
+
def collater(self, samples):
|
213 |
+
# TODO For now only supports datasets with same underlying collater implementations
|
214 |
+
|
215 |
+
all_keys = set()
|
216 |
+
for s in samples:
|
217 |
+
all_keys.update(s)
|
218 |
+
|
219 |
+
shared_keys = all_keys
|
220 |
+
for s in samples:
|
221 |
+
shared_keys = shared_keys & set(s.keys())
|
222 |
+
|
223 |
+
samples_shared_keys = []
|
224 |
+
for s in samples:
|
225 |
+
samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
|
226 |
+
|
227 |
+
return self.datasets[0].collater(samples_shared_keys)
|
mmgpt/models/__init__.py
ADDED
File without changes
|
mmgpt/models/blip2/__init__.py
ADDED
File without changes
|
mmgpt/models/builder.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .open_flamingo import create_model_and_transforms as create_open_flamingo_model_and_transforms
|
2 |
+
import torch.nn as nn
|
3 |
+
from transformers import LlamaTokenizer, LlamaForCausalLM
|
4 |
+
|
5 |
+
def create_model_and_transforms(
|
6 |
+
model_name: str,
|
7 |
+
clip_vision_encoder_path: str,
|
8 |
+
clip_vision_encoder_pretrained: str,
|
9 |
+
lang_encoder_path: str,
|
10 |
+
tokenizer_path: str,
|
11 |
+
tuning_config,
|
12 |
+
pretrained_model_path,
|
13 |
+
**kwargs,
|
14 |
+
):
|
15 |
+
if model_name == "open_flamingo":
|
16 |
+
return create_open_flamingo_model_and_transforms(
|
17 |
+
clip_vision_encoder_path=clip_vision_encoder_path,
|
18 |
+
clip_vision_encoder_pretrained=clip_vision_encoder_pretrained,
|
19 |
+
lang_encoder_path=lang_encoder_path,
|
20 |
+
tokenizer_path=tokenizer_path,
|
21 |
+
tuning_config=tuning_config,
|
22 |
+
pretrained_model_path=pretrained_model_path,
|
23 |
+
**kwargs,
|
24 |
+
)
|
25 |
+
# TODO: support BLIP2
|
26 |
+
else:
|
27 |
+
raise ValueError(f"Unknown model name: {model_name}")
|
28 |
+
|
29 |
+
# only for debugging
|
30 |
+
def create_toy_model_and_transforms(
|
31 |
+
model_name: str,
|
32 |
+
clip_vision_encoder_path: str,
|
33 |
+
clip_vision_encoder_pretrained: str,
|
34 |
+
lang_encoder_path: str,
|
35 |
+
tokenizer_path: str,
|
36 |
+
tuning_config,
|
37 |
+
pretrained_model_path,
|
38 |
+
**kwargs,
|
39 |
+
):
|
40 |
+
print("init toy vision encoder")
|
41 |
+
import torchvision
|
42 |
+
|
43 |
+
image_processor = torchvision.transforms.Compose(
|
44 |
+
[
|
45 |
+
torchvision.transforms.Resize((224, 224)),
|
46 |
+
torchvision.transforms.ToTensor(),
|
47 |
+
]
|
48 |
+
)
|
49 |
+
print("init tokenizer")
|
50 |
+
text_tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path)
|
51 |
+
# add Flamingo special tokens to the tokenizer
|
52 |
+
text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
|
53 |
+
if text_tokenizer.pad_token is None:
|
54 |
+
# Issue: GPT models don't have a pad token, which we use to
|
55 |
+
# modify labels for the loss.
|
56 |
+
text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
|
57 |
+
|
58 |
+
class ToyModel(nn.Module):
|
59 |
+
def __init__(self, *args, **kwargs):
|
60 |
+
super().__init__()
|
61 |
+
self.input_embeddings = nn.Embedding(38000, 512)
|
62 |
+
self.layer = nn.Linear(512, 512)
|
63 |
+
self.config = {"hidden_size": 512}
|
64 |
+
|
65 |
+
def forward(self, lang_x, **kwargs):
|
66 |
+
x = self.input_embeddings(lang_x)
|
67 |
+
x = self.layer(x)
|
68 |
+
loss = x.sum()
|
69 |
+
|
70 |
+
return (loss,)
|
71 |
+
|
72 |
+
model = ToyModel()
|
73 |
+
|
74 |
+
return model, image_processor, text_tokenizer
|
mmgpt/models/open_flamingo/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .builder import create_model_and_transforms
|
2 |
+
from .flamingo import Flamingo
|
3 |
+
from .flamingo_lm import FlamingoLMMixin
|
mmgpt/models/open_flamingo/builder.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Modified from https://github.com/mlfoundations/open_flamingo"""
|
2 |
+
import open_clip
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from bigmodelvis import Visualization
|
6 |
+
from peft import LoraConfig, get_peft_model
|
7 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer
|
8 |
+
|
9 |
+
from .flamingo import Flamingo
|
10 |
+
from .flamingo_lm import FlamingoLMMixin
|
11 |
+
from .utils import extend_instance
|
12 |
+
|
13 |
+
|
14 |
+
def create_model_and_transforms(
|
15 |
+
clip_vision_encoder_path: str,
|
16 |
+
clip_vision_encoder_pretrained: str,
|
17 |
+
lang_encoder_path: str,
|
18 |
+
tokenizer_path: str,
|
19 |
+
decoder_layers_attr_name: str = None,
|
20 |
+
pretrained_model_path: str = None,
|
21 |
+
tuning_config=None,
|
22 |
+
**flamingo_kwargs,
|
23 |
+
):
|
24 |
+
"""
|
25 |
+
Initialize a Flamingo model from a pretrained vision encoder and language encoder.
|
26 |
+
Appends special tokens to the tokenizer and freezes backbones.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32")
|
30 |
+
clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k")
|
31 |
+
lang_encoder_path (str): path to pretrained language encoder
|
32 |
+
tokenizer_path (str): path to pretrained tokenizer
|
33 |
+
decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
|
34 |
+
Returns:
|
35 |
+
Flamingo: Flamingo model from pretrained vision and language encoders
|
36 |
+
Image processor: Pipeline to preprocess input images
|
37 |
+
Tokenizer: A tokenizer for the language model
|
38 |
+
"""
|
39 |
+
print("init clip vision encoder")
|
40 |
+
vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
|
41 |
+
clip_vision_encoder_path, pretrained=clip_vision_encoder_pretrained
|
42 |
+
)
|
43 |
+
# set the vision encoder to output the visual features
|
44 |
+
vision_encoder.visual.output_tokens = True
|
45 |
+
print("init tokenizer")
|
46 |
+
text_tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path)
|
47 |
+
# add Flamingo special tokens to the tokenizer
|
48 |
+
text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
|
49 |
+
if text_tokenizer.pad_token is None:
|
50 |
+
# Issue: GPT models don't have a pad token, which we use to
|
51 |
+
# modify labels for the loss.
|
52 |
+
text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
|
53 |
+
text_tokenizer.bos_token_id = 1
|
54 |
+
text_tokenizer.eos_token_id = 2
|
55 |
+
|
56 |
+
print("init llama")
|
57 |
+
lang_encoder = LlamaForCausalLM.from_pretrained(lang_encoder_path)
|
58 |
+
extend_instance(lang_encoder, FlamingoLMMixin)
|
59 |
+
|
60 |
+
if decoder_layers_attr_name is None:
|
61 |
+
decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
|
62 |
+
lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
|
63 |
+
lang_encoder.resize_token_embeddings(len(text_tokenizer))
|
64 |
+
|
65 |
+
model = Flamingo(
|
66 |
+
vision_encoder,
|
67 |
+
lang_encoder,
|
68 |
+
text_tokenizer.encode("<|endofchunk|>")[-1],
|
69 |
+
text_tokenizer.encode("<image>")[-1],
|
70 |
+
vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["width"],
|
71 |
+
cross_attn_every_n_layers=4,
|
72 |
+
**flamingo_kwargs,
|
73 |
+
)
|
74 |
+
|
75 |
+
if pretrained_model_path is not None:
|
76 |
+
print(f"loading pretrained model from {pretrained_model_path}")
|
77 |
+
model.load_state_dict(torch.load(pretrained_model_path), strict=False)
|
78 |
+
|
79 |
+
# Freeze all parameters
|
80 |
+
model.requires_grad_(False)
|
81 |
+
assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
|
82 |
+
|
83 |
+
if tuning_config is not None:
|
84 |
+
model = prepare_model_for_tuning(model, tuning_config)
|
85 |
+
else:
|
86 |
+
raise ValueError("tuning_config must be provided")
|
87 |
+
|
88 |
+
print(
|
89 |
+
f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
|
90 |
+
)
|
91 |
+
|
92 |
+
return model, image_processor, text_tokenizer
|
93 |
+
|
94 |
+
|
95 |
+
def _infer_decoder_layers_attr_name(model):
|
96 |
+
for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
|
97 |
+
if k.lower() in model.__class__.__name__.lower():
|
98 |
+
return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]
|
99 |
+
|
100 |
+
raise ValueError(
|
101 |
+
f"We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
|
102 |
+
)
|
103 |
+
|
104 |
+
|
105 |
+
__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
|
106 |
+
"opt": "model.decoder.layers",
|
107 |
+
"gptneo": "transformer.h",
|
108 |
+
"gptj": "transformer.h",
|
109 |
+
"gpt-j": "transformer.h",
|
110 |
+
"pythia": "gpt_neox.layers",
|
111 |
+
"llama": "model.layers",
|
112 |
+
}
|
113 |
+
|
114 |
+
|
115 |
+
def prepare_model_for_tuning(model: nn.Module, config):
|
116 |
+
if config.lora:
|
117 |
+
lora_config = LoraConfig(
|
118 |
+
r=config.lora_r,
|
119 |
+
lora_alpha=config.lora_alpha,
|
120 |
+
target_modules=config.lora_target_modules,
|
121 |
+
lora_dropout=config.lora_dropout,
|
122 |
+
bias="none", # won't use bias currently
|
123 |
+
modules_to_save=[], # TODO: might be helpful if save partial model
|
124 |
+
task_type="CAUSAL_LM",
|
125 |
+
)
|
126 |
+
model.lang_encoder = get_peft_model(model.lang_encoder, peft_config=lora_config)
|
127 |
+
|
128 |
+
# manually unfreeze modules, we use a `substring` fashion mathcing
|
129 |
+
for name, param in model.named_parameters():
|
130 |
+
if any(substr in name for substr in config.unfrozen):
|
131 |
+
param.requires_grad = True
|
132 |
+
|
133 |
+
if config.vis and is_rank0():
|
134 |
+
Visualization(model).structure_graph()
|
135 |
+
return model
|
136 |
+
|
137 |
+
|
138 |
+
# temporary workaround, should use a common utils in the future
|
139 |
+
def is_rank0():
|
140 |
+
if not torch.distributed.is_initialized():
|
141 |
+
return True
|
142 |
+
return torch.distributed.get_rank() == 0
|
mmgpt/models/open_flamingo/flamingo.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Modified from https://github.com/mlfoundations/open_flamingo"""
|
2 |
+
import torch
|
3 |
+
from einops import rearrange
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
from .helpers import PerceiverResampler
|
7 |
+
|
8 |
+
|
9 |
+
class Flamingo(nn.Module):
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
vision_encoder: nn.Module,
|
13 |
+
lang_encoder: nn.Module,
|
14 |
+
eoc_token_id: int,
|
15 |
+
media_token_id: int,
|
16 |
+
vis_dim: int,
|
17 |
+
cross_attn_every_n_layers: int = 1,
|
18 |
+
use_media_placement_augmentation: bool = False,
|
19 |
+
):
|
20 |
+
"""
|
21 |
+
Args:
|
22 |
+
vision_encoder (nn.Module): HF CLIPModel
|
23 |
+
lang_encoder (nn.Module): HF causal language model
|
24 |
+
eoc_token_id (int): Token id for <|endofchunk|>
|
25 |
+
media_token_id (int): Token id for <image>
|
26 |
+
vis_dim (int): Dimension of the visual features.
|
27 |
+
Visual features are projected to match this shape along the last dimension.
|
28 |
+
cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1.
|
29 |
+
use_media_placement_augmentation (bool, optional): Whether to randomly assign images to the preceding or following text in training. Defaults to False.
|
30 |
+
"""
|
31 |
+
super().__init__()
|
32 |
+
self.eoc_token_id = eoc_token_id
|
33 |
+
self.media_token_id = media_token_id
|
34 |
+
self.use_media_placement_augmentation = use_media_placement_augmentation
|
35 |
+
self.vis_dim = vis_dim
|
36 |
+
self.vision_encoder = vision_encoder
|
37 |
+
self.perceiver = PerceiverResampler(dim=self.vis_dim)
|
38 |
+
self.lang_encoder = lang_encoder
|
39 |
+
self.lang_encoder.init_flamingo(
|
40 |
+
media_token_id=media_token_id,
|
41 |
+
vis_hidden_size=self.vis_dim,
|
42 |
+
cross_attn_every_n_layers=cross_attn_every_n_layers,
|
43 |
+
use_media_placement_augmentation=self.use_media_placement_augmentation,
|
44 |
+
)
|
45 |
+
|
46 |
+
def forward(
|
47 |
+
self,
|
48 |
+
vision_x: torch.Tensor,
|
49 |
+
lang_x: torch.Tensor,
|
50 |
+
attention_mask: torch.Tensor = None,
|
51 |
+
labels: torch.Tensor = None,
|
52 |
+
use_cached_vision_x: bool = False,
|
53 |
+
clear_conditioned_layers: bool = True,
|
54 |
+
past_key_values=None,
|
55 |
+
use_cache: bool = False,
|
56 |
+
):
|
57 |
+
"""
|
58 |
+
Forward pass of Flamingo.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
vision_x (torch.Tensor): Vision input
|
62 |
+
shape (B, T_img, F, C, H, W) with F=1
|
63 |
+
lang_x (torch.Tensor): Language input ids
|
64 |
+
shape (B, T_txt)
|
65 |
+
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
|
66 |
+
labels (torch.Tensor, optional): Labels. Defaults to None.
|
67 |
+
clear_conditioned_layers: if True, clear the conditioned layers
|
68 |
+
once the foward pass is completed. Set this to false if the
|
69 |
+
same set of images will be reused in another subsequent
|
70 |
+
forward pass.
|
71 |
+
past_key_values: pre-computed values to pass to language model.
|
72 |
+
See past_key_values documentation in Hugging Face
|
73 |
+
CausalLM models.
|
74 |
+
use_cache: whether to use cached key values. See use_cache
|
75 |
+
documentation in Hugging Face CausalLM models.
|
76 |
+
"""
|
77 |
+
if vision_x is None and use_cached_vision_x is False:
|
78 |
+
for layer in self.lang_encoder._get_decoder_layers():
|
79 |
+
layer.condition_only_lang_x(True)
|
80 |
+
output = self.lang_encoder(
|
81 |
+
input_ids=lang_x,
|
82 |
+
attention_mask=attention_mask,
|
83 |
+
labels=labels,
|
84 |
+
past_key_values=past_key_values,
|
85 |
+
use_cache=use_cache,
|
86 |
+
)
|
87 |
+
for layer in self.lang_encoder._get_decoder_layers():
|
88 |
+
layer.condition_only_lang_x(False)
|
89 |
+
return output
|
90 |
+
assert (
|
91 |
+
vision_x is not None
|
92 |
+
) or use_cached_vision_x, "Must provide either vision_x or use_cached_vision_x to True."
|
93 |
+
|
94 |
+
if use_cached_vision_x:
|
95 |
+
# Case: use cached; vision_x should be cached and other
|
96 |
+
# vision-related inputs should not be provided.
|
97 |
+
assert vision_x is None, "Expect vision_x to be None when use_cached_vision_x is True."
|
98 |
+
assert self.lang_encoder.is_conditioned()
|
99 |
+
|
100 |
+
else:
|
101 |
+
# Case: do not use caching (i.e. this is a standard forward pass);
|
102 |
+
self._encode_vision_x(vision_x=vision_x)
|
103 |
+
|
104 |
+
output = self.lang_encoder(
|
105 |
+
input_ids=lang_x,
|
106 |
+
attention_mask=attention_mask,
|
107 |
+
labels=labels,
|
108 |
+
past_key_values=past_key_values,
|
109 |
+
use_cache=use_cache,
|
110 |
+
)
|
111 |
+
|
112 |
+
if clear_conditioned_layers:
|
113 |
+
self.lang_encoder.clear_conditioned_layers()
|
114 |
+
|
115 |
+
return output
|
116 |
+
|
117 |
+
def generate(
|
118 |
+
self,
|
119 |
+
vision_x: torch.Tensor,
|
120 |
+
lang_x: torch.Tensor,
|
121 |
+
attention_mask: torch.Tensor = None,
|
122 |
+
num_beams=1,
|
123 |
+
max_new_tokens=None,
|
124 |
+
temperature=1.0,
|
125 |
+
top_k=0,
|
126 |
+
top_p=1.0,
|
127 |
+
no_repeat_ngram_size=0,
|
128 |
+
prefix_allowed_tokens_fn=None,
|
129 |
+
length_penalty=1.0,
|
130 |
+
num_return_sequences=1,
|
131 |
+
do_sample=False,
|
132 |
+
early_stopping=False,
|
133 |
+
):
|
134 |
+
"""
|
135 |
+
Generate text conditioned on vision and language inputs.
|
136 |
+
|
137 |
+
Args:
|
138 |
+
vision_x (torch.Tensor): Vision input
|
139 |
+
shape (B, T_img, F, C, H, W)
|
140 |
+
images in the same chunk are collated along T_img, and frames are collated along F
|
141 |
+
currently only F=1 is supported (single-frame videos)
|
142 |
+
lang_x (torch.Tensor): Language input
|
143 |
+
shape (B, T_txt)
|
144 |
+
max_length (int, optional): Maximum length of the output. Defaults to None.
|
145 |
+
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
|
146 |
+
num_beams (int, optional): Number of beams. Defaults to 1.
|
147 |
+
max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
|
148 |
+
temperature (float, optional): Temperature. Defaults to 1.0.
|
149 |
+
top_k (int, optional): Top k. Defaults to 0.
|
150 |
+
top_p (float, optional): Top p. Defaults to 1.0.
|
151 |
+
no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
|
152 |
+
length_penalty (float, optional): Length penalty. Defaults to 1.0.
|
153 |
+
num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
|
154 |
+
do_sample (bool, optional): Do sample. Defaults to False.
|
155 |
+
early_stopping (bool, optional): Early stopping. Defaults to False.
|
156 |
+
Returns:
|
157 |
+
torch.Tensor: lang_x with generated tokens appended to it
|
158 |
+
"""
|
159 |
+
if num_beams > 1:
|
160 |
+
vision_x = vision_x.repeat_interleave(num_beams, dim=0)
|
161 |
+
|
162 |
+
self._encode_vision_x(vision_x=vision_x)
|
163 |
+
|
164 |
+
output = self.lang_encoder.generate(
|
165 |
+
lang_x,
|
166 |
+
attention_mask=attention_mask,
|
167 |
+
# eos_token_id=self.eoc_token_id,
|
168 |
+
num_beams=num_beams,
|
169 |
+
max_new_tokens=max_new_tokens,
|
170 |
+
temperature=temperature,
|
171 |
+
top_k=top_k,
|
172 |
+
top_p=top_p,
|
173 |
+
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
174 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
175 |
+
length_penalty=length_penalty,
|
176 |
+
num_return_sequences=num_return_sequences,
|
177 |
+
do_sample=do_sample,
|
178 |
+
early_stopping=early_stopping,
|
179 |
+
)
|
180 |
+
|
181 |
+
self.lang_encoder.clear_conditioned_layers()
|
182 |
+
return output
|
183 |
+
|
184 |
+
def _encode_vision_x(self, vision_x: torch.Tensor):
|
185 |
+
"""
|
186 |
+
Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
|
187 |
+
Args:
|
188 |
+
vision_x (torch.Tensor): Vision input
|
189 |
+
shape (B, T_img, F, C, H, W)
|
190 |
+
Images in the same chunk are collated along T_img, and frames are collated along F
|
191 |
+
Currently only F=1 is supported (single-frame videos)
|
192 |
+
|
193 |
+
rearrange code based on https://github.com/dhansmair/flamingo-mini
|
194 |
+
"""
|
195 |
+
|
196 |
+
assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
|
197 |
+
b, T, F = vision_x.shape[:3]
|
198 |
+
assert F == 1, "Only single frame supported"
|
199 |
+
|
200 |
+
vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
|
201 |
+
with torch.no_grad():
|
202 |
+
vision_x = self.vision_encoder.visual(vision_x)[1]
|
203 |
+
vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
|
204 |
+
|
205 |
+
vision_x = self.perceiver(vision_x) # reshapes to (b, T, n, d)
|
206 |
+
|
207 |
+
for layer in self.lang_encoder._get_decoder_layers():
|
208 |
+
layer.condition_vis_x(vision_x)
|
mmgpt/models/open_flamingo/flamingo_lm.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Modified from https://github.com/mlfoundations/open_flamingo"""
|
2 |
+
import random
|
3 |
+
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from .helpers import GatedCrossAttentionBlock
|
7 |
+
from .utils import getattr_recursive, setattr_recursive
|
8 |
+
|
9 |
+
|
10 |
+
class FlamingoLayer(nn.Module):
|
11 |
+
def __init__(self, gated_cross_attn_layer, decoder_layer):
|
12 |
+
super().__init__()
|
13 |
+
self.gated_cross_attn_layer = gated_cross_attn_layer
|
14 |
+
self.decoder_layer = decoder_layer
|
15 |
+
self.vis_x = None
|
16 |
+
self.media_locations = None
|
17 |
+
self.only_lang_x = False
|
18 |
+
|
19 |
+
def is_conditioned(self) -> bool:
|
20 |
+
"""Check whether the layer is conditioned."""
|
21 |
+
return self.vis_x is not None
|
22 |
+
|
23 |
+
# Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/)
|
24 |
+
def condition_vis_x(self, vis_x):
|
25 |
+
self.vis_x = vis_x
|
26 |
+
|
27 |
+
def condition_only_lang_x(self, only_lang_x=False):
|
28 |
+
self.only_lang_x = only_lang_x
|
29 |
+
|
30 |
+
def condition_media_locations(self, media_locations):
|
31 |
+
self.media_locations = media_locations
|
32 |
+
|
33 |
+
def condition_attend_previous(self, attend_previous):
|
34 |
+
self.attend_previous = attend_previous
|
35 |
+
|
36 |
+
def forward(
|
37 |
+
self,
|
38 |
+
lang_x,
|
39 |
+
attention_mask=None,
|
40 |
+
**decoder_layer_kwargs,
|
41 |
+
):
|
42 |
+
if self.gated_cross_attn_layer is None or self.only_lang_x:
|
43 |
+
return self.decoder_layer(lang_x, attention_mask=attention_mask, **decoder_layer_kwargs)
|
44 |
+
|
45 |
+
if self.vis_x is None:
|
46 |
+
raise ValueError("vis_x must be conditioned before forward pass")
|
47 |
+
|
48 |
+
if self.media_locations is None:
|
49 |
+
raise ValueError("media_locations must be conditioned before forward pass")
|
50 |
+
|
51 |
+
lang_x = self.gated_cross_attn_layer(
|
52 |
+
lang_x,
|
53 |
+
self.vis_x,
|
54 |
+
media_locations=self.media_locations,
|
55 |
+
attend_previous=self.attend_previous,
|
56 |
+
)
|
57 |
+
lang_x = self.decoder_layer(lang_x, attention_mask=attention_mask, **decoder_layer_kwargs)
|
58 |
+
return lang_x
|
59 |
+
|
60 |
+
|
61 |
+
class FlamingoLMMixin(nn.Module):
|
62 |
+
"""
|
63 |
+
Mixin to add cross-attention layers to a language model.
|
64 |
+
"""
|
65 |
+
|
66 |
+
def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
|
67 |
+
self.decoder_layers_attr_name = decoder_layers_attr_name
|
68 |
+
|
69 |
+
def _get_decoder_layers(self):
|
70 |
+
return getattr_recursive(self, self.decoder_layers_attr_name)
|
71 |
+
|
72 |
+
def _set_decoder_layers(self, value):
|
73 |
+
setattr_recursive(self, self.decoder_layers_attr_name, value)
|
74 |
+
|
75 |
+
def init_flamingo(
|
76 |
+
self,
|
77 |
+
media_token_id,
|
78 |
+
vis_hidden_size,
|
79 |
+
cross_attn_every_n_layers,
|
80 |
+
use_media_placement_augmentation,
|
81 |
+
):
|
82 |
+
"""
|
83 |
+
Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations.
|
84 |
+
"""
|
85 |
+
|
86 |
+
self.gated_cross_attn_layers = nn.ModuleList(
|
87 |
+
[
|
88 |
+
GatedCrossAttentionBlock(dim=self.config.hidden_size, dim_visual=vis_hidden_size)
|
89 |
+
if (layer_idx + 1) % cross_attn_every_n_layers == 0
|
90 |
+
else None
|
91 |
+
for layer_idx, _ in enumerate(self._get_decoder_layers())
|
92 |
+
]
|
93 |
+
)
|
94 |
+
self._set_decoder_layers(
|
95 |
+
nn.ModuleList(
|
96 |
+
[
|
97 |
+
FlamingoLayer(gated_cross_attn_layer, decoder_layer)
|
98 |
+
for gated_cross_attn_layer, decoder_layer in zip(
|
99 |
+
self.gated_cross_attn_layers, self._get_decoder_layers()
|
100 |
+
)
|
101 |
+
]
|
102 |
+
)
|
103 |
+
)
|
104 |
+
self.media_token_id = media_token_id
|
105 |
+
self.use_media_placement_augmentation = use_media_placement_augmentation
|
106 |
+
self.initialized_flamingo = True
|
107 |
+
|
108 |
+
def forward(self, *input, **kwargs):
|
109 |
+
"""Condition the Flamingo layers on the media locations before forward()"""
|
110 |
+
if not self.initialized_flamingo:
|
111 |
+
raise ValueError("Flamingo layers are not initialized. Please call `init_flamingo` first.")
|
112 |
+
|
113 |
+
input_ids = kwargs["input_ids"] if "input_ids" in kwargs else input[0]
|
114 |
+
media_locations = input_ids == self.media_token_id
|
115 |
+
attend_previous = (random.random() < 0.5) if self.use_media_placement_augmentation else False
|
116 |
+
|
117 |
+
for layer in self.get_decoder().layers:
|
118 |
+
layer.condition_media_locations(media_locations)
|
119 |
+
layer.condition_attend_previous(attend_previous)
|
120 |
+
|
121 |
+
return super().forward(*input, **kwargs) # Call the other parent's forward method
|
122 |
+
|
123 |
+
def is_conditioned(self) -> bool:
|
124 |
+
"""Check whether all decoder layers are already conditioned."""
|
125 |
+
return all(l.is_conditioned() for l in self._get_decoder_layers())
|
126 |
+
|
127 |
+
def clear_conditioned_layers(self):
|
128 |
+
for layer in self._get_decoder_layers():
|
129 |
+
layer.condition_vis_x(None)
|
130 |
+
layer.condition_media_locations(None)
|
131 |
+
layer.condition_attend_previous(None)
|
mmgpt/models/open_flamingo/helpers.py
ADDED
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Taken from https://github.com/lucidrains/flamingo-pytorch
|
3 |
+
"""
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from einops import rearrange, repeat
|
7 |
+
from einops_exts import rearrange_many
|
8 |
+
from torch import einsum, nn
|
9 |
+
|
10 |
+
|
11 |
+
def exists(val):
|
12 |
+
return val is not None
|
13 |
+
|
14 |
+
|
15 |
+
def FeedForward(dim, mult=4):
|
16 |
+
inner_dim = int(dim * mult)
|
17 |
+
return nn.Sequential(
|
18 |
+
nn.LayerNorm(dim),
|
19 |
+
nn.Linear(dim, inner_dim, bias=False),
|
20 |
+
nn.GELU(),
|
21 |
+
nn.Linear(inner_dim, dim, bias=False),
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
class PerceiverAttention(nn.Module):
|
26 |
+
def __init__(self, *, dim, dim_head=64, heads=8):
|
27 |
+
super().__init__()
|
28 |
+
self.scale = dim_head**-0.5
|
29 |
+
self.heads = heads
|
30 |
+
inner_dim = dim_head * heads
|
31 |
+
|
32 |
+
self.norm_media = nn.LayerNorm(dim)
|
33 |
+
self.norm_latents = nn.LayerNorm(dim)
|
34 |
+
|
35 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
36 |
+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
37 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
38 |
+
|
39 |
+
def forward(self, x, latents):
|
40 |
+
"""
|
41 |
+
Args:
|
42 |
+
x (torch.Tensor): image features
|
43 |
+
shape (b, T, n1, D)
|
44 |
+
latent (torch.Tensor): latent features
|
45 |
+
shape (b, T, n2, D)
|
46 |
+
"""
|
47 |
+
x = self.norm_media(x)
|
48 |
+
latents = self.norm_latents(latents)
|
49 |
+
|
50 |
+
h = self.heads
|
51 |
+
|
52 |
+
q = self.to_q(latents)
|
53 |
+
kv_input = torch.cat((x, latents), dim=-2)
|
54 |
+
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
55 |
+
q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
|
56 |
+
q = q * self.scale
|
57 |
+
|
58 |
+
# attention
|
59 |
+
sim = einsum("... i d, ... j d -> ... i j", q, k)
|
60 |
+
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
|
61 |
+
attn = sim.softmax(dim=-1)
|
62 |
+
|
63 |
+
out = einsum("... i j, ... j d -> ... i d", attn, v)
|
64 |
+
out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
|
65 |
+
return self.to_out(out)
|
66 |
+
|
67 |
+
|
68 |
+
class PerceiverResampler(nn.Module):
|
69 |
+
def __init__(
|
70 |
+
self,
|
71 |
+
*,
|
72 |
+
dim,
|
73 |
+
depth=6,
|
74 |
+
dim_head=64,
|
75 |
+
heads=8,
|
76 |
+
num_latents=64,
|
77 |
+
max_num_media=None,
|
78 |
+
max_num_frames=None,
|
79 |
+
ff_mult=4,
|
80 |
+
):
|
81 |
+
super().__init__()
|
82 |
+
self.latents = nn.Parameter(torch.randn(num_latents, dim))
|
83 |
+
self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None
|
84 |
+
self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None
|
85 |
+
|
86 |
+
self.layers = nn.ModuleList([])
|
87 |
+
for _ in range(depth):
|
88 |
+
self.layers.append(
|
89 |
+
nn.ModuleList(
|
90 |
+
[
|
91 |
+
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
92 |
+
FeedForward(dim=dim, mult=ff_mult),
|
93 |
+
]
|
94 |
+
)
|
95 |
+
)
|
96 |
+
|
97 |
+
self.norm = nn.LayerNorm(dim)
|
98 |
+
|
99 |
+
def forward(self, x):
|
100 |
+
"""
|
101 |
+
Args:
|
102 |
+
x (torch.Tensor): image features
|
103 |
+
shape (b, T, F, v, D)
|
104 |
+
Returns:
|
105 |
+
shape (b, T, n, D) where n is self.num_latents
|
106 |
+
"""
|
107 |
+
b, T, F, v = x.shape[:4]
|
108 |
+
|
109 |
+
# frame and media time embeddings
|
110 |
+
if exists(self.frame_embs):
|
111 |
+
frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
|
112 |
+
x = x + frame_embs
|
113 |
+
x = rearrange(x, "b T F v d -> b T (F v) d") # flatten the frame and spatial dimensions
|
114 |
+
if exists(self.media_time_embs):
|
115 |
+
x = x + self.media_time_embs[:T]
|
116 |
+
|
117 |
+
# blocks
|
118 |
+
latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
|
119 |
+
for attn, ff in self.layers:
|
120 |
+
latents = attn(x, latents) + latents
|
121 |
+
latents = ff(latents) + latents
|
122 |
+
return self.norm(latents)
|
123 |
+
|
124 |
+
|
125 |
+
# gated cross attention
|
126 |
+
|
127 |
+
|
128 |
+
class MaskedCrossAttention(nn.Module):
|
129 |
+
def __init__(
|
130 |
+
self,
|
131 |
+
*,
|
132 |
+
dim,
|
133 |
+
dim_visual,
|
134 |
+
dim_head=64,
|
135 |
+
heads=8,
|
136 |
+
only_attend_immediate_media=True,
|
137 |
+
):
|
138 |
+
super().__init__()
|
139 |
+
self.scale = dim_head**-0.5
|
140 |
+
self.heads = heads
|
141 |
+
inner_dim = dim_head * heads
|
142 |
+
|
143 |
+
self.norm = nn.LayerNorm(dim)
|
144 |
+
|
145 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
146 |
+
self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
|
147 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
148 |
+
|
149 |
+
# whether for text to only attend to immediate preceding image, or all previous images
|
150 |
+
self.only_attend_immediate_media = only_attend_immediate_media
|
151 |
+
|
152 |
+
def forward(self, x, media, media_locations=None, attend_previous=True):
|
153 |
+
"""
|
154 |
+
Args:
|
155 |
+
x (torch.Tensor): text features
|
156 |
+
shape (B, T_txt, D_txt)
|
157 |
+
media (torch.Tensor): image features
|
158 |
+
shape (B, T_img, n, D_img) where n is the dim of the latents
|
159 |
+
media_locations: boolean mask identifying the media tokens in x
|
160 |
+
shape (B, T_txt)
|
161 |
+
attend_previous: bool
|
162 |
+
If false, ignores immediately preceding image and starts attending when following image
|
163 |
+
"""
|
164 |
+
_, T_img, n = media.shape[:3]
|
165 |
+
h = self.heads
|
166 |
+
|
167 |
+
x = self.norm(x)
|
168 |
+
|
169 |
+
q = self.to_q(x)
|
170 |
+
media = rearrange(media, "b t n d -> b (t n) d")
|
171 |
+
|
172 |
+
k, v = self.to_kv(media).chunk(2, dim=-1)
|
173 |
+
q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
|
174 |
+
|
175 |
+
q = q * self.scale
|
176 |
+
|
177 |
+
sim = einsum("... i d, ... j d -> ... i j", q, k)
|
178 |
+
|
179 |
+
if exists(media_locations):
|
180 |
+
# at each boolean of True, increment the time counter (relative to media time)
|
181 |
+
text_time = media_locations.cumsum(dim=-1)
|
182 |
+
media_time = torch.arange(T_img, device=x.device) + 1
|
183 |
+
|
184 |
+
if not attend_previous:
|
185 |
+
text_time[~media_locations] += 1
|
186 |
+
# make sure max is still the number of images in the sequence
|
187 |
+
text_time[
|
188 |
+
text_time
|
189 |
+
> repeat(
|
190 |
+
torch.count_nonzero(media_locations, dim=1),
|
191 |
+
"b -> b i",
|
192 |
+
i=text_time.shape[1],
|
193 |
+
)
|
194 |
+
] = 0
|
195 |
+
|
196 |
+
# text time must equal media time if only attending to most immediate image
|
197 |
+
# otherwise, as long as text time is greater than media time (if attending to all previous images / media)
|
198 |
+
mask_op = torch.eq if self.only_attend_immediate_media else torch.ge
|
199 |
+
|
200 |
+
text_to_media_mask = mask_op(
|
201 |
+
rearrange(text_time, "b i -> b 1 i 1"),
|
202 |
+
repeat(media_time, "j -> 1 1 1 (j n)", n=n),
|
203 |
+
)
|
204 |
+
sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
|
205 |
+
|
206 |
+
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
|
207 |
+
attn = sim.softmax(dim=-1)
|
208 |
+
|
209 |
+
if exists(media_locations) and self.only_attend_immediate_media:
|
210 |
+
# any text without a preceding media needs to have attention zeroed out
|
211 |
+
text_without_media_mask = text_time == 0
|
212 |
+
text_without_media_mask = rearrange(text_without_media_mask, "b i -> b 1 i 1")
|
213 |
+
attn = attn.masked_fill(text_without_media_mask, 0.0)
|
214 |
+
|
215 |
+
out = einsum("... i j, ... j d -> ... i d", attn, v)
|
216 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
217 |
+
return self.to_out(out)
|
218 |
+
|
219 |
+
|
220 |
+
class GatedCrossAttentionBlock(nn.Module):
|
221 |
+
def __init__(
|
222 |
+
self,
|
223 |
+
*,
|
224 |
+
dim,
|
225 |
+
dim_visual,
|
226 |
+
dim_head=64,
|
227 |
+
heads=8,
|
228 |
+
ff_mult=4,
|
229 |
+
only_attend_immediate_media=True,
|
230 |
+
):
|
231 |
+
super().__init__()
|
232 |
+
self.attn = MaskedCrossAttention(
|
233 |
+
dim=dim,
|
234 |
+
dim_visual=dim_visual,
|
235 |
+
dim_head=dim_head,
|
236 |
+
heads=heads,
|
237 |
+
only_attend_immediate_media=only_attend_immediate_media,
|
238 |
+
)
|
239 |
+
self.attn_gate = nn.Parameter(torch.tensor([0.0]))
|
240 |
+
|
241 |
+
self.ff = FeedForward(dim, mult=ff_mult)
|
242 |
+
self.ff_gate = nn.Parameter(torch.tensor([0.0]))
|
243 |
+
|
244 |
+
def forward(
|
245 |
+
self,
|
246 |
+
x,
|
247 |
+
media,
|
248 |
+
media_locations=None,
|
249 |
+
attend_previous=True,
|
250 |
+
):
|
251 |
+
x = (
|
252 |
+
self.attn(
|
253 |
+
x,
|
254 |
+
media,
|
255 |
+
media_locations=media_locations,
|
256 |
+
attend_previous=attend_previous,
|
257 |
+
)
|
258 |
+
* self.attn_gate.tanh()
|
259 |
+
+ x
|
260 |
+
)
|
261 |
+
x = self.ff(x) * self.ff_gate.tanh() + x
|
262 |
+
|
263 |
+
return x
|
mmgpt/models/open_flamingo/utils.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def extend_instance(obj, mixin):
|
2 |
+
"""Apply mixins to a class instance after creation"""
|
3 |
+
base_cls = obj.__class__
|
4 |
+
base_cls_name = obj.__class__.__name__
|
5 |
+
obj.__class__ = type(
|
6 |
+
base_cls_name, (mixin, base_cls), {}
|
7 |
+
) # mixin needs to go first for our forward() logic to work
|
8 |
+
|
9 |
+
|
10 |
+
def getattr_recursive(obj, att):
|
11 |
+
"""
|
12 |
+
Return nested attribute of obj
|
13 |
+
Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
|
14 |
+
"""
|
15 |
+
if att == "":
|
16 |
+
return obj
|
17 |
+
i = att.find(".")
|
18 |
+
if i < 0:
|
19 |
+
return getattr(obj, att)
|
20 |
+
else:
|
21 |
+
return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
|
22 |
+
|
23 |
+
|
24 |
+
def setattr_recursive(obj, att, val):
|
25 |
+
"""
|
26 |
+
Set nested attribute of obj
|
27 |
+
Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
|
28 |
+
"""
|
29 |
+
if "." in att:
|
30 |
+
obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
|
31 |
+
setattr(obj, att.split(".")[-1], val)
|
mmgpt/train/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
mmgpt/train/distributed.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Modified from https://github.com/mlfoundations/open_flamingo"""
|
2 |
+
import os
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
try:
|
7 |
+
import horovod.torch as hvd
|
8 |
+
except ImportError:
|
9 |
+
hvd = None
|
10 |
+
|
11 |
+
|
12 |
+
def is_global_master(args):
|
13 |
+
return args.rank == 0
|
14 |
+
|
15 |
+
|
16 |
+
def is_local_master(args):
|
17 |
+
return args.local_rank == 0
|
18 |
+
|
19 |
+
|
20 |
+
def is_master(args, local=False):
|
21 |
+
return is_local_master(args) if local else is_global_master(args)
|
22 |
+
|
23 |
+
|
24 |
+
def is_using_horovod():
|
25 |
+
# NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
|
26 |
+
# Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
|
27 |
+
ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
|
28 |
+
pmi_vars = ["PMI_RANK", "PMI_SIZE"]
|
29 |
+
if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]):
|
30 |
+
return True
|
31 |
+
else:
|
32 |
+
return False
|
33 |
+
|
34 |
+
|
35 |
+
def is_using_distributed():
|
36 |
+
if "WORLD_SIZE" in os.environ:
|
37 |
+
return int(os.environ["WORLD_SIZE"]) > 1
|
38 |
+
if "SLURM_NTASKS" in os.environ:
|
39 |
+
return int(os.environ["SLURM_NTASKS"]) > 1
|
40 |
+
return False
|
41 |
+
|
42 |
+
|
43 |
+
def world_info_from_env():
|
44 |
+
local_rank = 0
|
45 |
+
for v in (
|
46 |
+
"LOCAL_RANK",
|
47 |
+
"MPI_LOCALRANKID",
|
48 |
+
"SLURM_LOCALID",
|
49 |
+
"OMPI_COMM_WORLD_LOCAL_RANK",
|
50 |
+
):
|
51 |
+
if v in os.environ:
|
52 |
+
local_rank = int(os.environ[v])
|
53 |
+
break
|
54 |
+
global_rank = 0
|
55 |
+
for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"):
|
56 |
+
if v in os.environ:
|
57 |
+
global_rank = int(os.environ[v])
|
58 |
+
break
|
59 |
+
world_size = 1
|
60 |
+
for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"):
|
61 |
+
if v in os.environ:
|
62 |
+
world_size = int(os.environ[v])
|
63 |
+
break
|
64 |
+
|
65 |
+
return local_rank, global_rank, world_size
|
66 |
+
|
67 |
+
|
68 |
+
def init_distributed_device(args):
|
69 |
+
# Distributed training = training on more than one GPU.
|
70 |
+
# Works in both single and multi-node scenarios.
|
71 |
+
args.distributed = False
|
72 |
+
args.world_size = 1
|
73 |
+
args.rank = 0 # global rank
|
74 |
+
args.local_rank = 0
|
75 |
+
if args.horovod:
|
76 |
+
assert hvd is not None, "Horovod is not installed"
|
77 |
+
hvd.init()
|
78 |
+
args.local_rank = int(hvd.local_rank())
|
79 |
+
args.rank = hvd.rank()
|
80 |
+
args.world_size = hvd.size()
|
81 |
+
args.distributed = True
|
82 |
+
os.environ["LOCAL_RANK"] = str(args.local_rank)
|
83 |
+
os.environ["RANK"] = str(args.rank)
|
84 |
+
os.environ["WORLD_SIZE"] = str(args.world_size)
|
85 |
+
elif is_using_distributed():
|
86 |
+
if "SLURM_PROCID" in os.environ:
|
87 |
+
# DDP via SLURM
|
88 |
+
args.local_rank, args.rank, args.world_size = world_info_from_env()
|
89 |
+
# SLURM var -> torch.distributed vars in case needed
|
90 |
+
os.environ["LOCAL_RANK"] = str(args.local_rank)
|
91 |
+
os.environ["RANK"] = str(args.rank)
|
92 |
+
os.environ["WORLD_SIZE"] = str(args.world_size)
|
93 |
+
torch.distributed.init_process_group(
|
94 |
+
backend=args.dist_backend,
|
95 |
+
init_method=args.dist_url,
|
96 |
+
world_size=args.world_size,
|
97 |
+
rank=args.rank,
|
98 |
+
)
|
99 |
+
else:
|
100 |
+
# DDP via torchrun, torch.distributed.launch
|
101 |
+
args.local_rank, _, _ = world_info_from_env()
|
102 |
+
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url)
|
103 |
+
args.world_size = torch.distributed.get_world_size()
|
104 |
+
args.rank = torch.distributed.get_rank()
|
105 |
+
args.distributed = True
|
106 |
+
else:
|
107 |
+
# needed to run on single gpu
|
108 |
+
torch.distributed.init_process_group(
|
109 |
+
backend=args.dist_backend,
|
110 |
+
init_method=args.dist_url,
|
111 |
+
world_size=1,
|
112 |
+
rank=0,
|
113 |
+
)
|
114 |
+
|
115 |
+
if torch.cuda.is_available():
|
116 |
+
if args.distributed and not args.no_set_device_rank:
|
117 |
+
device = "cuda:%d" % args.local_rank
|
118 |
+
else:
|
119 |
+
device = "cuda:0"
|
120 |
+
torch.cuda.set_device(device)
|
121 |
+
else:
|
122 |
+
device = "cpu"
|
123 |
+
args.device = device
|
124 |
+
device = torch.device(device)
|
125 |
+
return device
|
126 |
+
|
127 |
+
|
128 |
+
def is_rank0():
|
129 |
+
if not torch.distributed.is_initialized():
|
130 |
+
return True
|
131 |
+
return torch.distributed.get_rank() == 0
|
mmgpt/train/instruction_finetune.py
ADDED
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Modified from https://github.com/mlfoundations/open_flamingo"""
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import copy
|
5 |
+
import glob
|
6 |
+
import os
|
7 |
+
import random
|
8 |
+
import time
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import wandb
|
13 |
+
from mmengine import Config
|
14 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
15 |
+
from torch.utils.data import DataLoader, DistributedSampler
|
16 |
+
from tqdm import tqdm
|
17 |
+
from transformers import (
|
18 |
+
get_constant_schedule_with_warmup,
|
19 |
+
get_cosine_schedule_with_warmup,
|
20 |
+
get_linear_schedule_with_warmup,
|
21 |
+
)
|
22 |
+
|
23 |
+
from mmgpt import create_model_and_transforms
|
24 |
+
from mmgpt.models.builder import create_toy_model_and_transforms
|
25 |
+
from mmgpt.datasets import InfiniteSampler, build_dataset
|
26 |
+
from mmgpt.train.distributed import init_distributed_device, world_info_from_env
|
27 |
+
from mmgpt.train.train_utils import AverageMeter, get_autocast, get_cast_dtype, get_checkpoint
|
28 |
+
|
29 |
+
|
30 |
+
def random_seed(seed=42, rank=0):
|
31 |
+
torch.manual_seed(seed + rank)
|
32 |
+
np.random.seed(seed + rank)
|
33 |
+
random.seed(seed + rank)
|
34 |
+
|
35 |
+
|
36 |
+
def main():
|
37 |
+
parser = argparse.ArgumentParser()
|
38 |
+
parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
|
39 |
+
parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
|
40 |
+
parser.add_argument("--lm_path", default="checkpoints/llama-7b_hf", type=str)
|
41 |
+
parser.add_argument(
|
42 |
+
"--tokenizer_path",
|
43 |
+
default="checkpoints/llama-7b_hf",
|
44 |
+
type=str,
|
45 |
+
help="path to tokenizer",
|
46 |
+
)
|
47 |
+
parser.add_argument(
|
48 |
+
"--pretrained_path",
|
49 |
+
default="checkpoints/OpenFlamingo-9B/checkpoint.pt",
|
50 |
+
type=str,
|
51 |
+
help="path to pretrained model",
|
52 |
+
)
|
53 |
+
parser.add_argument(
|
54 |
+
"--run_name",
|
55 |
+
type=str,
|
56 |
+
default="train-my-gpt4",
|
57 |
+
help="used to name saving directory and wandb run",
|
58 |
+
)
|
59 |
+
parser.add_argument("--use_media_placement_augmentation", action="store_true")
|
60 |
+
parser.add_argument("--offline", action="store_true")
|
61 |
+
parser.add_argument("--num_epochs", type=int, default=1)
|
62 |
+
parser.add_argument("--logging_steps", type=int, default=100, help="log loss every n steps")
|
63 |
+
# Sum of gradient optimization batch size
|
64 |
+
parser.add_argument(
|
65 |
+
"--resume_from_checkpoint",
|
66 |
+
type=str,
|
67 |
+
help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states",
|
68 |
+
default=None,
|
69 |
+
)
|
70 |
+
parser.add_argument(
|
71 |
+
"--delete_previous_checkpoint",
|
72 |
+
action="store_true",
|
73 |
+
help="delete previous checkpoint when saving new checkpoint",
|
74 |
+
)
|
75 |
+
parser.add_argument("--seed", type=int, default=42)
|
76 |
+
parser.add_argument("--learning_rate", default=1e-5, type=float)
|
77 |
+
parser.add_argument(
|
78 |
+
"--lr_scheduler",
|
79 |
+
default="constant",
|
80 |
+
type=str,
|
81 |
+
help="constant, linear, or cosine",
|
82 |
+
)
|
83 |
+
parser.add_argument("--warmup_steps", default=100, type=int)
|
84 |
+
parser.add_argument("--weight_decay", default=0.1, type=float)
|
85 |
+
parser.add_argument(
|
86 |
+
"--precision",
|
87 |
+
choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"],
|
88 |
+
default="amp",
|
89 |
+
help="Floating point precision.",
|
90 |
+
)
|
91 |
+
# data args
|
92 |
+
parser.add_argument("--workers", type=int, default=0)
|
93 |
+
parser.add_argument("--batch_size", type=int, default=1)
|
94 |
+
parser.add_argument("--dataset_config", type=str, default=None, help="path to dataset config file")
|
95 |
+
parser.add_argument("--gradient_accumulation_steps", type=int, default=16)
|
96 |
+
# Finetune config
|
97 |
+
parser.add_argument("--tuning_config", type=str, default=None, help="path to tuning config file")
|
98 |
+
# distributed training args
|
99 |
+
parser.add_argument(
|
100 |
+
"--dist-url",
|
101 |
+
default="env://",
|
102 |
+
type=str,
|
103 |
+
help="url used to set up distributed training",
|
104 |
+
)
|
105 |
+
parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
|
106 |
+
parser.add_argument(
|
107 |
+
"--horovod",
|
108 |
+
default=False,
|
109 |
+
action="store_true",
|
110 |
+
help="Use horovod for distributed training.",
|
111 |
+
)
|
112 |
+
parser.add_argument(
|
113 |
+
"--no-set-device-rank",
|
114 |
+
default=False,
|
115 |
+
action="store_true",
|
116 |
+
help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
|
117 |
+
)
|
118 |
+
# wandb args
|
119 |
+
parser.add_argument("--report_to_wandb", default=False, action="store_true")
|
120 |
+
parser.add_argument(
|
121 |
+
"--wandb_project",
|
122 |
+
type=str,
|
123 |
+
)
|
124 |
+
parser.add_argument(
|
125 |
+
"--wandb_entity",
|
126 |
+
type=str,
|
127 |
+
)
|
128 |
+
parser.add_argument(
|
129 |
+
"--save_checkpoints_to_wandb",
|
130 |
+
default=False,
|
131 |
+
action="store_true",
|
132 |
+
help="save checkpoints to wandb",
|
133 |
+
)
|
134 |
+
|
135 |
+
args = parser.parse_args()
|
136 |
+
|
137 |
+
if args.save_checkpoints_to_wandb and not args.report_to_wandb:
|
138 |
+
raise ValueError("save_checkpoints_to_wandb requires report_to_wandb")
|
139 |
+
|
140 |
+
if args.offline:
|
141 |
+
os.environ["WANDB_MODE"] = "offline"
|
142 |
+
os.environ["TRANSFORMERS_OFFLINE"] = "1"
|
143 |
+
|
144 |
+
args.local_rank, args.rank, args.world_size = world_info_from_env()
|
145 |
+
|
146 |
+
if args.rank == 0:
|
147 |
+
if not os.path.exists(args.run_name):
|
148 |
+
os.makedirs(args.run_name)
|
149 |
+
|
150 |
+
device_id = init_distributed_device(args)
|
151 |
+
|
152 |
+
random_seed(args.seed)
|
153 |
+
|
154 |
+
if args.tuning_config is not None:
|
155 |
+
tuning_config = Config.fromfile(args.tuning_config)
|
156 |
+
else:
|
157 |
+
raise ValueError("tuning_config must be specified")
|
158 |
+
|
159 |
+
model, image_processor, tokenizer = create_model_and_transforms(
|
160 |
+
model_name="open_flamingo",
|
161 |
+
clip_vision_encoder_path=args.vision_encoder_path,
|
162 |
+
clip_vision_encoder_pretrained=args.vision_encoder_pretrained,
|
163 |
+
lang_encoder_path=args.lm_path,
|
164 |
+
tokenizer_path=args.tokenizer_path if args.tokenizer_path else args.lm_path,
|
165 |
+
use_media_placement_augmentation=args.use_media_placement_augmentation,
|
166 |
+
pretrained_model_path=args.pretrained_path,
|
167 |
+
tuning_config=tuning_config.tuning_config,
|
168 |
+
)
|
169 |
+
|
170 |
+
if args.dataset_config is not None:
|
171 |
+
dataset_config = Config.fromfile(args.dataset_config)
|
172 |
+
else:
|
173 |
+
raise ValueError("dataset_config must be specified")
|
174 |
+
|
175 |
+
dataset = build_dataset(
|
176 |
+
dataset_config=dataset_config.visual_datasets,
|
177 |
+
vis_processor=image_processor,
|
178 |
+
tokenizer=tokenizer,
|
179 |
+
)
|
180 |
+
train_dataloader = DataLoader(
|
181 |
+
dataset,
|
182 |
+
batch_size=args.batch_size,
|
183 |
+
num_workers=args.workers,
|
184 |
+
sampler=DistributedSampler(dataset, shuffle=True, drop_last=True),
|
185 |
+
collate_fn=dataset.collater,
|
186 |
+
)
|
187 |
+
|
188 |
+
# build language dataset and dataloader for multi-modality training
|
189 |
+
if dataset_config.get('language_datasets') is not None and len(dataset_config.language_datasets) > 0:
|
190 |
+
lang_dataset = build_dataset(
|
191 |
+
dataset_config=dataset_config.language_datasets,
|
192 |
+
tokenizer=tokenizer,
|
193 |
+
)
|
194 |
+
lang_dataloader = DataLoader(
|
195 |
+
lang_dataset,
|
196 |
+
batch_size=args.batch_size,
|
197 |
+
num_workers=args.workers,
|
198 |
+
sampler=InfiniteSampler(lang_dataset, shuffle=True),
|
199 |
+
collate_fn=lang_dataset.collater,
|
200 |
+
)
|
201 |
+
lang_dataloader = iter(lang_dataloader)
|
202 |
+
else:
|
203 |
+
lang_dataloader = None
|
204 |
+
|
205 |
+
random_seed(args.seed, args.rank)
|
206 |
+
|
207 |
+
print(f"Start running training on rank {args.rank}.")
|
208 |
+
|
209 |
+
if args.rank == 0 and args.report_to_wandb:
|
210 |
+
wandb.init(
|
211 |
+
project=args.wandb_project,
|
212 |
+
entity=args.wandb_entity,
|
213 |
+
name=args.run_name,
|
214 |
+
config=vars(args),
|
215 |
+
)
|
216 |
+
|
217 |
+
device_id = args.rank % torch.cuda.device_count()
|
218 |
+
model = model.to(device_id)
|
219 |
+
|
220 |
+
ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters=True)
|
221 |
+
|
222 |
+
def get_grouped_params(model):
|
223 |
+
params_with_wd, params_without_wd = [], []
|
224 |
+
|
225 |
+
def apply_decay(x):
|
226 |
+
return (
|
227 |
+
"gated_cross_attn_layer" in x
|
228 |
+
and "ff_gate" not in x
|
229 |
+
and "attn_gate" not in x
|
230 |
+
and "norm" not in x
|
231 |
+
and "bias" not in x
|
232 |
+
)
|
233 |
+
|
234 |
+
for n, p in model.named_parameters():
|
235 |
+
# if p.requires_grad:
|
236 |
+
if apply_decay(n):
|
237 |
+
params_with_wd.append(p)
|
238 |
+
else:
|
239 |
+
params_without_wd.append(p)
|
240 |
+
|
241 |
+
return [
|
242 |
+
{"params": params_with_wd, "weight_decay": args.weight_decay},
|
243 |
+
{"params": params_without_wd, "weight_decay": 0.0},
|
244 |
+
]
|
245 |
+
|
246 |
+
optimizer = torch.optim.AdamW(get_grouped_params(ddp_model), lr=args.learning_rate)
|
247 |
+
|
248 |
+
total_training_steps = len(train_dataloader) * args.num_epochs
|
249 |
+
|
250 |
+
if args.rank == 0:
|
251 |
+
print(f"Total training steps: {total_training_steps}")
|
252 |
+
|
253 |
+
if args.lr_scheduler == "linear":
|
254 |
+
lr_scheduler = get_linear_schedule_with_warmup(
|
255 |
+
optimizer,
|
256 |
+
num_warmup_steps=args.warmup_steps,
|
257 |
+
num_training_steps=total_training_steps // args.gradient_accumulation_steps,
|
258 |
+
)
|
259 |
+
elif args.lr_scheduler == "cosine":
|
260 |
+
lr_scheduler = get_cosine_schedule_with_warmup(
|
261 |
+
optimizer,
|
262 |
+
num_warmup_steps=args.warmup_steps,
|
263 |
+
num_training_steps=total_training_steps // args.gradient_accumulation_steps,
|
264 |
+
)
|
265 |
+
else:
|
266 |
+
lr_scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps)
|
267 |
+
|
268 |
+
# check if a checkpoint exists for this run
|
269 |
+
if os.path.exists(f"{args.run_name}") and args.resume_from_checkpoint is None:
|
270 |
+
checkpoint_list = glob.glob(f"{args.run_name}/checkpoint_*.pt")
|
271 |
+
if len(checkpoint_list) == 0:
|
272 |
+
print(f"Found no checkpoints for run {args.run_name}.")
|
273 |
+
else:
|
274 |
+
args.resume_from_checkpoint = sorted(checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0]))[-1]
|
275 |
+
print(f"Found checkpoint {args.resume_from_checkpoint} for run {args.run_name}.")
|
276 |
+
|
277 |
+
resume_from_epoch = 0
|
278 |
+
if args.resume_from_checkpoint is not None:
|
279 |
+
if args.rank == 0:
|
280 |
+
print(f"Loading checkpoint from {args.resume_from_checkpoint}")
|
281 |
+
checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
|
282 |
+
ddp_model.load_state_dict(checkpoint["model_state_dict"], False)
|
283 |
+
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
284 |
+
lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
|
285 |
+
resume_from_epoch = checkpoint["epoch"] + 1
|
286 |
+
|
287 |
+
ddp_model.train()
|
288 |
+
|
289 |
+
for epoch in range(resume_from_epoch, args.num_epochs):
|
290 |
+
train_dataloader.sampler.set_epoch(epoch)
|
291 |
+
|
292 |
+
train_one_epoch(
|
293 |
+
args=args,
|
294 |
+
model=ddp_model,
|
295 |
+
epoch=epoch,
|
296 |
+
tokenizer=tokenizer,
|
297 |
+
optimizer=optimizer,
|
298 |
+
lr_scheduler=lr_scheduler,
|
299 |
+
train_dataloader=train_dataloader,
|
300 |
+
language_dataloader=lang_dataloader,
|
301 |
+
device_id=device_id,
|
302 |
+
wandb=wandb,
|
303 |
+
)
|
304 |
+
|
305 |
+
if args.rank == 0:
|
306 |
+
if not os.path.exists(args.run_name):
|
307 |
+
os.makedirs(args.run_name)
|
308 |
+
|
309 |
+
checkpoint_dict = {
|
310 |
+
"epoch": epoch,
|
311 |
+
"model_state_dict": get_checkpoint(ddp_model),
|
312 |
+
"optimizer_state_dict": optimizer.state_dict(),
|
313 |
+
"lr_scheduler_state_dict": lr_scheduler.state_dict(),
|
314 |
+
"tuning_config": tuning_config,
|
315 |
+
}
|
316 |
+
|
317 |
+
print(f"Saving checkpoint to {args.run_name}/checkpoint_{epoch}.pt")
|
318 |
+
torch.save(checkpoint_dict, f"{args.run_name}/checkpoint_{epoch}.pt")
|
319 |
+
if args.report_to_wandb and args.save_checkpoints_to_wandb:
|
320 |
+
wandb.save(f"{args.run_name}/checkpoint_{epoch}.pt")
|
321 |
+
|
322 |
+
if args.delete_previous_checkpoint:
|
323 |
+
if epoch > 0:
|
324 |
+
os.remove(f"{args.run_name}/checkpoint_{epoch-1}.pt")
|
325 |
+
if args.rank == 0:
|
326 |
+
torch.save(
|
327 |
+
{"model_state_dict": get_checkpoint(ddp_model.module), "tuning_config": tuning_config},
|
328 |
+
f"{args.run_name}/final_weights.pt",
|
329 |
+
)
|
330 |
+
if args.report_to_wandb and args.save_checkpoints_to_wandb:
|
331 |
+
wandb.save(f"{args.run_name}/final_weights.pt")
|
332 |
+
|
333 |
+
|
334 |
+
def train_one_epoch(
|
335 |
+
args,
|
336 |
+
model,
|
337 |
+
epoch,
|
338 |
+
train_dataloader,
|
339 |
+
language_dataloader,
|
340 |
+
tokenizer,
|
341 |
+
optimizer,
|
342 |
+
lr_scheduler,
|
343 |
+
device_id,
|
344 |
+
wandb,
|
345 |
+
):
|
346 |
+
num_batches_per_epoch = len(train_dataloader)
|
347 |
+
|
348 |
+
total_training_steps = num_batches_per_epoch * args.num_epochs
|
349 |
+
|
350 |
+
autocast = get_autocast(args.precision)
|
351 |
+
cast_dtype = get_cast_dtype(args.precision)
|
352 |
+
|
353 |
+
model.train()
|
354 |
+
|
355 |
+
# setup logging
|
356 |
+
step_time_m = AverageMeter() # time for one optimizer step (> 1 batch if using gradient accum)
|
357 |
+
data_time_m = (
|
358 |
+
AverageMeter()
|
359 |
+
) # avg time to load one batch of both C4 AND laion (= 1 batch regardless of gradient accum)
|
360 |
+
end = time.time()
|
361 |
+
|
362 |
+
# loop through dataloader
|
363 |
+
for num_steps, batch in tqdm(
|
364 |
+
enumerate(train_dataloader),
|
365 |
+
disable=args.rank != 0,
|
366 |
+
total=total_training_steps,
|
367 |
+
initial=(epoch * num_batches_per_epoch),
|
368 |
+
):
|
369 |
+
data_time_m.update(time.time() - end)
|
370 |
+
|
371 |
+
global_step = num_steps + epoch * num_batches_per_epoch
|
372 |
+
|
373 |
+
#### VISION FORWARD PASS ####
|
374 |
+
images = batch["image"].to(device_id, dtype=cast_dtype, non_blocking=True).unsqueeze(1).unsqueeze(1)
|
375 |
+
input_ids = batch["input_ids"].to(device_id, dtype=cast_dtype, non_blocking=True)
|
376 |
+
attention_mask = batch["attention_mask"].to(device_id, dtype=cast_dtype, non_blocking=True)
|
377 |
+
labels = batch["labels"].to(device_id, dtype=cast_dtype, non_blocking=True)
|
378 |
+
|
379 |
+
with autocast():
|
380 |
+
loss_batch = model(
|
381 |
+
vision_x=images,
|
382 |
+
lang_x=input_ids,
|
383 |
+
attention_mask=attention_mask,
|
384 |
+
labels=labels,
|
385 |
+
)[0]
|
386 |
+
loss = loss_batch / args.gradient_accumulation_steps
|
387 |
+
loss_vision = loss # for logging
|
388 |
+
|
389 |
+
#### BACKWARD PASS ####
|
390 |
+
loss.backward()
|
391 |
+
|
392 |
+
#### LANGUAGE FORWARD PASS ####
|
393 |
+
if language_dataloader is not None:
|
394 |
+
batch_lang = next(language_dataloader)
|
395 |
+
lang_input_ids = batch_lang["input_ids"].to(device_id, dtype=cast_dtype, non_blocking=True)
|
396 |
+
lang_attention_mask = batch_lang["attention_mask"].to(device_id, dtype=cast_dtype, non_blocking=True)
|
397 |
+
lang_labels = batch_lang["labels"].to(device_id, dtype=cast_dtype, non_blocking=True)
|
398 |
+
|
399 |
+
with autocast():
|
400 |
+
lang_loss_batch = model(
|
401 |
+
vision_x=None,
|
402 |
+
lang_x=lang_input_ids,
|
403 |
+
attention_mask=lang_attention_mask,
|
404 |
+
labels=lang_labels,
|
405 |
+
)[0]
|
406 |
+
lang_loss = lang_loss_batch / args.gradient_accumulation_steps
|
407 |
+
#### BACKWARD PASS ####
|
408 |
+
lang_loss.backward()
|
409 |
+
|
410 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
411 |
+
|
412 |
+
# step optimizer and log
|
413 |
+
if (((num_steps + 1) % args.gradient_accumulation_steps) == 0) or (num_steps == num_batches_per_epoch - 1):
|
414 |
+
optimizer.step()
|
415 |
+
lr_scheduler.step()
|
416 |
+
optimizer.zero_grad()
|
417 |
+
|
418 |
+
# step time and reset end outside of rank 0
|
419 |
+
step_time_m.update(time.time() - end)
|
420 |
+
end = time.time()
|
421 |
+
|
422 |
+
if args.rank == 0 and args.report_to_wandb:
|
423 |
+
# compute within rank 0
|
424 |
+
samples_per_second = (
|
425 |
+
args.gradient_accumulation_steps * args.batch_size * args.world_size / step_time_m.val
|
426 |
+
)
|
427 |
+
samples_per_second_per_gpu = args.gradient_accumulation_steps * args.batch_size / step_time_m.val
|
428 |
+
|
429 |
+
wandb.log(
|
430 |
+
{
|
431 |
+
"data_time": data_time_m.avg,
|
432 |
+
"step_time": step_time_m.avg,
|
433 |
+
"samples_per_second": samples_per_second,
|
434 |
+
"samples_per_second_per_gpu": samples_per_second_per_gpu,
|
435 |
+
"lr": optimizer.param_groups[0]["lr"],
|
436 |
+
},
|
437 |
+
commit=False,
|
438 |
+
)
|
439 |
+
step_time_m.reset()
|
440 |
+
data_time_m.reset()
|
441 |
+
|
442 |
+
loss_log = {
|
443 |
+
"loss": loss.item(),
|
444 |
+
"loss_vision": loss_vision.item(),
|
445 |
+
"global_step": global_step,
|
446 |
+
}
|
447 |
+
if language_dataloader is not None:
|
448 |
+
loss_log["loss_lang"] = lang_loss.item()
|
449 |
+
|
450 |
+
wandb.log(loss_log, commit=True)
|
451 |
+
|
452 |
+
# Log loss to console
|
453 |
+
if ((num_steps + 1) % args.logging_steps == 0) and args.rank == 0:
|
454 |
+
print(
|
455 |
+
f"Step {num_steps+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete. Loss: {loss.item():.3f}"
|
456 |
+
)
|
457 |
+
|
458 |
+
|
459 |
+
if __name__ == "__main__":
|
460 |
+
main()
|
mmgpt/train/train_utils.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Modified from https://github.com/mlfoundations/open_flamingo"""
|
2 |
+
import time
|
3 |
+
from contextlib import suppress
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
|
9 |
+
def get_cast_dtype(precision: str):
|
10 |
+
cast_dtype = None
|
11 |
+
if precision == "bf16":
|
12 |
+
cast_dtype = torch.bfloat16
|
13 |
+
elif precision == "fp16":
|
14 |
+
cast_dtype = torch.float16
|
15 |
+
return cast_dtype
|
16 |
+
|
17 |
+
|
18 |
+
def get_autocast(precision):
|
19 |
+
if precision == "amp":
|
20 |
+
return torch.cuda.amp.autocast
|
21 |
+
elif precision == "amp_bfloat16" or precision == "amp_bf16":
|
22 |
+
# amp_bfloat16 is more stable than amp float16 for clip training
|
23 |
+
return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
|
24 |
+
else:
|
25 |
+
return suppress
|
26 |
+
|
27 |
+
|
28 |
+
def train_one_epoch(
|
29 |
+
args,
|
30 |
+
model,
|
31 |
+
epoch,
|
32 |
+
laion_loader,
|
33 |
+
mmc4_loader,
|
34 |
+
tokenizer,
|
35 |
+
optimizer,
|
36 |
+
lr_scheduler,
|
37 |
+
device_id,
|
38 |
+
wandb,
|
39 |
+
):
|
40 |
+
num_batches_per_epoch_laion = laion_loader.num_batches
|
41 |
+
num_batches_per_epoch_mmc4 = mmc4_loader.num_batches
|
42 |
+
|
43 |
+
assert (
|
44 |
+
num_batches_per_epoch_laion == num_batches_per_epoch_mmc4
|
45 |
+
), "Number of batches in laion and mmc4 datasets must be the same"
|
46 |
+
num_batches_per_epoch = num_batches_per_epoch_mmc4
|
47 |
+
total_training_steps = num_batches_per_epoch * args.num_epochs
|
48 |
+
|
49 |
+
autocast = get_autocast(args.precision)
|
50 |
+
cast_dtype = get_cast_dtype(args.precision)
|
51 |
+
|
52 |
+
media_token_id = tokenizer("<image>", add_special_tokens=False)["input_ids"][-1]
|
53 |
+
endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)["input_ids"][-1]
|
54 |
+
|
55 |
+
model.train()
|
56 |
+
|
57 |
+
# setup logging
|
58 |
+
step_time_m = AverageMeter() # time for one optimizer step (> 1 batch if using gradient accum)
|
59 |
+
data_time_m = (
|
60 |
+
AverageMeter()
|
61 |
+
) # avg time to load one batch of both C4 AND laion (= 1 batch regardless of gradient accum)
|
62 |
+
end = time.time()
|
63 |
+
|
64 |
+
# loop through dataloader
|
65 |
+
for num_steps, (batch_laion, batch_mmc4) in tqdm(
|
66 |
+
enumerate(zip(laion_loader, mmc4_loader)),
|
67 |
+
disable=args.rank != 0,
|
68 |
+
total=total_training_steps,
|
69 |
+
initial=(epoch * num_batches_per_epoch),
|
70 |
+
):
|
71 |
+
data_time_m.update(time.time() - end)
|
72 |
+
|
73 |
+
global_step = num_steps + epoch * num_batches_per_epoch
|
74 |
+
|
75 |
+
#### LAION FORWARD PASS ####
|
76 |
+
images = batch_laion[0].to(device_id, dtype=cast_dtype, non_blocking=True).unsqueeze(1).unsqueeze(1)
|
77 |
+
|
78 |
+
input_ids = batch_laion[1][0].to(device_id, dtype=cast_dtype, non_blocking=True)
|
79 |
+
attention_mask = batch_laion[1][1].to(device_id, dtype=cast_dtype, non_blocking=True)
|
80 |
+
|
81 |
+
labels = input_ids.clone()
|
82 |
+
labels[labels == tokenizer.pad_token_id] = -100
|
83 |
+
labels[:, 0] = -100
|
84 |
+
labels[labels == media_token_id] = -100
|
85 |
+
labels.to(device_id)
|
86 |
+
|
87 |
+
with autocast():
|
88 |
+
loss_laion = model(
|
89 |
+
vision_x=images,
|
90 |
+
lang_x=input_ids,
|
91 |
+
attention_mask=attention_mask,
|
92 |
+
labels=labels,
|
93 |
+
)[0]
|
94 |
+
divided_loss_laion = loss_laion / args.gradient_accumulation_steps
|
95 |
+
|
96 |
+
#### C4 FORWARD PASS ####
|
97 |
+
images = batch_mmc4[0].to(device_id, dtype=cast_dtype, non_blocking=True).unsqueeze(2)
|
98 |
+
input_ids = torch.stack([x[0] for x in batch_mmc4[1]]).squeeze(1)
|
99 |
+
attention_mask = torch.stack([x[1] for x in batch_mmc4[1]]).squeeze(1)
|
100 |
+
|
101 |
+
# NOTE: irena: expected shape of clip_text_input_ids / attention_mask is (N, I, max_seq_len)
|
102 |
+
labels = input_ids.clone()
|
103 |
+
labels[labels == tokenizer.pad_token_id] = -100
|
104 |
+
labels[:, 0] = -100
|
105 |
+
|
106 |
+
for i in range(labels.shape[0]):
|
107 |
+
# remove loss for any token before the first <image> token
|
108 |
+
label_idx = 0
|
109 |
+
while label_idx < labels.shape[1] and labels[i][label_idx] != media_token_id:
|
110 |
+
labels[i][label_idx] = -100
|
111 |
+
label_idx += 1
|
112 |
+
|
113 |
+
# get index of all endofchunk tokens in the sequence
|
114 |
+
endofchunk_idxs = torch.where(labels[i] == endofchunk_token_id)[0]
|
115 |
+
for endofchunk_idx in endofchunk_idxs:
|
116 |
+
token_idx = endofchunk_idx + 1
|
117 |
+
while token_idx < labels.shape[1] and labels[i][token_idx] != media_token_id:
|
118 |
+
labels[i][token_idx] = -100
|
119 |
+
token_idx += 1
|
120 |
+
|
121 |
+
labels[labels == media_token_id] = -100
|
122 |
+
labels.to(device_id)
|
123 |
+
|
124 |
+
with autocast():
|
125 |
+
loss_mmc4 = model(
|
126 |
+
vision_x=images,
|
127 |
+
lang_x=input_ids,
|
128 |
+
attention_mask=attention_mask,
|
129 |
+
labels=labels,
|
130 |
+
)[0]
|
131 |
+
|
132 |
+
# if loss is nan, skip this batch
|
133 |
+
if torch.isnan(loss_mmc4):
|
134 |
+
print("loss is nan, skipping this batch")
|
135 |
+
print("input_ids: ", tokenizer.batch_decode(input_ids))
|
136 |
+
print("labels: ", labels)
|
137 |
+
print("images: ", images)
|
138 |
+
optimizer.zero_grad()
|
139 |
+
continue
|
140 |
+
|
141 |
+
divided_loss_mmc4 = loss_mmc4 / args.gradient_accumulation_steps
|
142 |
+
|
143 |
+
#### BACKWARD PASS ####
|
144 |
+
loss = divided_loss_laion * args.loss_multiplier_laion + divided_loss_mmc4 * args.loss_multiplier_mmc4
|
145 |
+
loss.backward()
|
146 |
+
|
147 |
+
#### MASK GRADIENTS FOR EMBEDDINGS ####
|
148 |
+
# Note (anas): Do not apply weight decay to embeddings as it will break this function.
|
149 |
+
def mask_embedding(m):
|
150 |
+
if isinstance(m, torch.nn.Embedding) and m.weight.requires_grad:
|
151 |
+
zero_mask = torch.zeros_like(m.weight.grad)
|
152 |
+
zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
|
153 |
+
zero_mask[endofchunk_token_id] = torch.ones_like(zero_mask[endofchunk_token_id])
|
154 |
+
m.weight.grad = m.weight.grad * zero_mask
|
155 |
+
|
156 |
+
model.apply(mask_embedding)
|
157 |
+
|
158 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
159 |
+
|
160 |
+
# step optimizer and log
|
161 |
+
if (((num_steps + 1) % args.gradient_accumulation_steps) == 0) or (num_steps == num_batches_per_epoch - 1):
|
162 |
+
optimizer.step()
|
163 |
+
lr_scheduler.step()
|
164 |
+
optimizer.zero_grad()
|
165 |
+
|
166 |
+
# step time and reset end outside of rank 0
|
167 |
+
step_time_m.update(time.time() - end)
|
168 |
+
end = time.time()
|
169 |
+
|
170 |
+
if args.rank == 0 and args.report_to_wandb:
|
171 |
+
# compute within rank 0
|
172 |
+
laion_samples_per_second = (
|
173 |
+
args.gradient_accumulation_steps * args.batch_size_laion * args.world_size / step_time_m.val
|
174 |
+
)
|
175 |
+
laion_samples_per_second_per_gpu = (
|
176 |
+
args.gradient_accumulation_steps * args.batch_size_laion / step_time_m.val
|
177 |
+
)
|
178 |
+
|
179 |
+
c4_samples_per_second = (
|
180 |
+
args.gradient_accumulation_steps * args.batch_size_mmc4 * args.world_size / step_time_m.val
|
181 |
+
)
|
182 |
+
c4_samples_per_second_per_gpu = (
|
183 |
+
args.gradient_accumulation_steps * args.batch_size_mmc4 / step_time_m.val
|
184 |
+
)
|
185 |
+
|
186 |
+
wandb.log(
|
187 |
+
{
|
188 |
+
"data_time": data_time_m.avg,
|
189 |
+
"step_time": step_time_m.avg,
|
190 |
+
"laion_samples_per_second": laion_samples_per_second,
|
191 |
+
"laion_samples_per_second_per_gpu": laion_samples_per_second_per_gpu,
|
192 |
+
"c4_samples_per_second": c4_samples_per_second,
|
193 |
+
"c4_samples_per_second_per_gpu": c4_samples_per_second_per_gpu,
|
194 |
+
"lr": optimizer.param_groups[0]["lr"],
|
195 |
+
},
|
196 |
+
commit=False,
|
197 |
+
)
|
198 |
+
step_time_m.reset()
|
199 |
+
data_time_m.reset()
|
200 |
+
|
201 |
+
wandb.log(
|
202 |
+
{
|
203 |
+
"loss_laion": divided_loss_laion.item(),
|
204 |
+
"global_step": global_step,
|
205 |
+
},
|
206 |
+
commit=False,
|
207 |
+
)
|
208 |
+
wandb.log(
|
209 |
+
{"loss_mmc4": divided_loss_mmc4.item(), "global_step": global_step},
|
210 |
+
commit=True,
|
211 |
+
)
|
212 |
+
|
213 |
+
# Log loss to console
|
214 |
+
if ((num_steps + 1) % args.logging_steps == 0) and args.rank == 0:
|
215 |
+
print(
|
216 |
+
f"Step {num_steps+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete. Loss LAION: {loss_laion.item():.3f} // Loss MMC4: {loss_mmc4.item():.3f}"
|
217 |
+
)
|
218 |
+
|
219 |
+
|
220 |
+
def get_checkpoint(model: torch.nn.Module):
|
221 |
+
state_dict = model.state_dict()
|
222 |
+
parameters = {k: v for k, v in model.named_parameters()}
|
223 |
+
# remove duplicate parameters
|
224 |
+
duplicate_keys = set(state_dict.keys()) - set(parameters.keys())
|
225 |
+
for k in duplicate_keys:
|
226 |
+
del state_dict[k]
|
227 |
+
# remove non-grad parameters
|
228 |
+
for name, p in parameters.items():
|
229 |
+
if not p.requires_grad:
|
230 |
+
del state_dict[name]
|
231 |
+
|
232 |
+
return state_dict
|
233 |
+
|
234 |
+
|
235 |
+
class AverageMeter(object):
|
236 |
+
"""Computes and stores the average and current value"""
|
237 |
+
|
238 |
+
def __init__(self):
|
239 |
+
self.reset()
|
240 |
+
|
241 |
+
def reset(self):
|
242 |
+
self.val = 0
|
243 |
+
self.avg = 0
|
244 |
+
self.sum = 0
|
245 |
+
self.count = 0
|
246 |
+
|
247 |
+
def update(self, val, n=1):
|
248 |
+
self.val = val
|
249 |
+
self.sum += val * n
|
250 |
+
self.count += n
|
251 |
+
self.avg = self.sum / self.count
|
requirements.txt
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
einops
|
2 |
+
einops-exts
|
3 |
+
transformers
|
4 |
+
peft
|
5 |
+
bigmodelvis
|
6 |
+
torch
|
7 |
+
torchvision
|
8 |
+
pillow
|
9 |
+
more-itertools
|
10 |
+
datasets
|
11 |
+
braceexpand
|
12 |
+
webdataset
|
13 |
+
wandb
|
14 |
+
nltk
|
15 |
+
scipy
|
16 |
+
inflection
|
17 |
+
sentencepiece
|
18 |
+
open_clip_torch
|
19 |
+
mmengine
|
20 |
+
gradio
|
setup.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
from setuptools import find_packages, setup
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
with Path(Path(__file__).parent, "README.md").open(encoding="utf-8") as file:
|
7 |
+
long_description = file.read()
|
8 |
+
|
9 |
+
# TODO: This is a hack to get around the fact that we can't read the requirements.txt file, we should fix this.
|
10 |
+
# def _read_reqs(relpath):
|
11 |
+
# fullpath = os.path.join(Path(__file__).parent, relpath)
|
12 |
+
# with open(fullpath) as f:
|
13 |
+
# return [
|
14 |
+
# s.strip()
|
15 |
+
# for s in f.readlines()
|
16 |
+
# if (s.strip() and not s.startswith("#"))
|
17 |
+
# ]
|
18 |
+
|
19 |
+
REQUIREMENTS = [
|
20 |
+
"einops",
|
21 |
+
"einops-exts",
|
22 |
+
"transformers",
|
23 |
+
"torch",
|
24 |
+
"torchvision",
|
25 |
+
"pillow",
|
26 |
+
"more-itertools",
|
27 |
+
"datasets",
|
28 |
+
"braceexpand",
|
29 |
+
"webdataset",
|
30 |
+
"wandb",
|
31 |
+
"nltk",
|
32 |
+
"scipy",
|
33 |
+
"inflection",
|
34 |
+
"sentencepiece",
|
35 |
+
"open_clip_torch",
|
36 |
+
]
|
37 |
+
|
38 |
+
setup(
|
39 |
+
name="mmgpt",
|
40 |
+
packages=find_packages(),
|
41 |
+
include_package_data=True,
|
42 |
+
version="0.0.1",
|
43 |
+
license="Apache 2.0",
|
44 |
+
description="An open-source framework for multi-modality instruction fine-tuning",
|
45 |
+
long_description=long_description,
|
46 |
+
long_description_content_type="text/markdown",
|
47 |
+
data_files=[(".", ["README.md"])],
|
48 |
+
keywords=["machine learning"],
|
49 |
+
install_requires=REQUIREMENTS,
|
50 |
+
)
|