Spaces:
Running
on
Zero
Running
on
Zero
infinity1096
commited on
Commit
·
c8b42eb
1
Parent(s):
3991736
initial commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- .gitignore +152 -0
- LICENSE.txt +58 -0
- UniCeption/.gitignore +167 -0
- UniCeption/.pre-commit-config.yaml +18 -0
- UniCeption/.pylintrc +399 -0
- UniCeption/LICENSE +28 -0
- UniCeption/README.md +155 -0
- UniCeption/examples/models/cosmos/autoencoding.py +48 -0
- UniCeption/examples/models/cosmos/example.png +3 -0
- UniCeption/examples/models/cosmos/example_decoded.png +3 -0
- UniCeption/examples/models/dust3r/convert_dust3r_weights_to_uniception.py +331 -0
- UniCeption/examples/models/dust3r/dust3r.py +261 -0
- UniCeption/examples/models/dust3r/profile_dust3r.py +47 -0
- UniCeption/pyproject.toml +21 -0
- UniCeption/scripts/check_dependencies.py +49 -0
- UniCeption/scripts/download_checkpoints.py +48 -0
- UniCeption/scripts/install_croco_rope.py +62 -0
- UniCeption/scripts/prepare_offline_install.py +399 -0
- UniCeption/scripts/validate_installation.py +213 -0
- UniCeption/setup.py +188 -0
- UniCeption/tests/models/encoders/conftest.py +26 -0
- UniCeption/tests/models/encoders/test_encoders.py +204 -0
- UniCeption/tests/models/encoders/viz_image_encoders.py +294 -0
- UniCeption/tests/models/info_sharing/viz_mulit_view_cross_attn_transformers.py +337 -0
- UniCeption/uniception/__init__.py +0 -0
- UniCeption/uniception/models/encoders/README.md +129 -0
- UniCeption/uniception/models/encoders/__init__.py +235 -0
- UniCeption/uniception/models/encoders/base.py +157 -0
- UniCeption/uniception/models/encoders/cosmos.py +137 -0
- UniCeption/uniception/models/encoders/croco.py +457 -0
- UniCeption/uniception/models/encoders/dense_rep_encoder.py +344 -0
- UniCeption/uniception/models/encoders/dinov2.py +333 -0
- UniCeption/uniception/models/encoders/global_rep_encoder.py +115 -0
- UniCeption/uniception/models/encoders/image_normalizations.py +35 -0
- UniCeption/uniception/models/encoders/list.py +10 -0
- UniCeption/uniception/models/encoders/naradio.py +502 -0
- UniCeption/uniception/models/encoders/patch_embedder.py +235 -0
- UniCeption/uniception/models/encoders/radio.py +367 -0
- UniCeption/uniception/models/encoders/utils.py +86 -0
- UniCeption/uniception/models/factory/__init__.py +3 -0
- UniCeption/uniception/models/factory/dust3r.py +332 -0
- UniCeption/uniception/models/info_sharing/README.md +18 -0
- UniCeption/uniception/models/info_sharing/__init__.py +35 -0
- UniCeption/uniception/models/info_sharing/alternating_attention_transformer.py +944 -0
- UniCeption/uniception/models/info_sharing/base.py +116 -0
- UniCeption/uniception/models/info_sharing/cross_attention_transformer.py +582 -0
- UniCeption/uniception/models/info_sharing/diff_cross_attention_transformer.py +588 -0
- UniCeption/uniception/models/info_sharing/global_attention_transformer.py +1107 -0
- UniCeption/uniception/models/libs/__init__.py +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
uniflowmatch.egg-info/**
|
2 |
+
ufm_model_refine/**
|
3 |
+
ufm_model/**
|
4 |
+
/home/inf/UniFlowMatch/convert_old_ckpt.py
|
5 |
+
checkpoints/**
|
6 |
+
|
7 |
+
# Byte-compiled / optimized / DLL files
|
8 |
+
__pycache__/
|
9 |
+
**/__pycache__/
|
10 |
+
*.py[cod]
|
11 |
+
*$py.class
|
12 |
+
|
13 |
+
# C extensions
|
14 |
+
*.so
|
15 |
+
|
16 |
+
# Distribution / packaging
|
17 |
+
.Python
|
18 |
+
build/
|
19 |
+
develop-eggs/
|
20 |
+
dist/
|
21 |
+
downloads/
|
22 |
+
eggs/
|
23 |
+
.eggs/
|
24 |
+
lib/
|
25 |
+
lib64/
|
26 |
+
parts/
|
27 |
+
sdist/
|
28 |
+
var/
|
29 |
+
wheels/
|
30 |
+
pip-wheel-metadata/
|
31 |
+
share/python-wheels/
|
32 |
+
*.egg-info/
|
33 |
+
.installed.cfg
|
34 |
+
*.egg
|
35 |
+
MANIFEST
|
36 |
+
|
37 |
+
# PyInstaller
|
38 |
+
# Usually these files are written by a python script from a template
|
39 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
40 |
+
*.manifest
|
41 |
+
*.spec
|
42 |
+
|
43 |
+
# Installer logs
|
44 |
+
pip-log.txt
|
45 |
+
pip-delete-this-directory.txt
|
46 |
+
|
47 |
+
# Unit test / coverage reports
|
48 |
+
htmlcov/
|
49 |
+
.tox/
|
50 |
+
.nox/
|
51 |
+
.coverage
|
52 |
+
.coverage.*
|
53 |
+
.cache
|
54 |
+
nosetests.xml
|
55 |
+
coverage.xml
|
56 |
+
*.cover
|
57 |
+
*.py,cover
|
58 |
+
.hypothesis/
|
59 |
+
.pytest_cache/
|
60 |
+
cover/
|
61 |
+
|
62 |
+
# Translations
|
63 |
+
*.mo
|
64 |
+
*.pot
|
65 |
+
|
66 |
+
# Django stuff:
|
67 |
+
*.log
|
68 |
+
local_settings.py
|
69 |
+
db.sqlite3
|
70 |
+
db.sqlite3-journal
|
71 |
+
|
72 |
+
# Flask stuff:
|
73 |
+
instance/
|
74 |
+
.webassets-cache
|
75 |
+
|
76 |
+
# Scrapy stuff:
|
77 |
+
.scrapy
|
78 |
+
|
79 |
+
# Sphinx documentation
|
80 |
+
docs/_build/
|
81 |
+
|
82 |
+
# PyBuilder
|
83 |
+
target/
|
84 |
+
|
85 |
+
# Jupyter Notebook
|
86 |
+
.ipynb_checkpoints
|
87 |
+
|
88 |
+
# IPython
|
89 |
+
profile_default/
|
90 |
+
ipython_config.py
|
91 |
+
|
92 |
+
# pyenv
|
93 |
+
.python-version
|
94 |
+
|
95 |
+
# pipenv
|
96 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
97 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
98 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
99 |
+
# install all needed dependencies.
|
100 |
+
#Pipfile.lock
|
101 |
+
|
102 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
103 |
+
__pypackages__/
|
104 |
+
|
105 |
+
# Celery stuff
|
106 |
+
celerybeat-schedule
|
107 |
+
celerybeat.pid
|
108 |
+
|
109 |
+
# SageMath parsed files
|
110 |
+
*.sage.py
|
111 |
+
|
112 |
+
# Environments
|
113 |
+
.env
|
114 |
+
.venv
|
115 |
+
env/
|
116 |
+
venv/
|
117 |
+
ENV/
|
118 |
+
env.bak/
|
119 |
+
venv.bak/
|
120 |
+
|
121 |
+
# Spyder project settings
|
122 |
+
.spyderproject
|
123 |
+
.spyproject
|
124 |
+
|
125 |
+
# Rope project settings
|
126 |
+
.ropeproject
|
127 |
+
|
128 |
+
# mkdocs documentation
|
129 |
+
/site
|
130 |
+
|
131 |
+
# mypy
|
132 |
+
.mypy_cache/
|
133 |
+
.dmypy.json
|
134 |
+
dmypy.json
|
135 |
+
|
136 |
+
# Pyre type checker
|
137 |
+
.pyre/
|
138 |
+
|
139 |
+
# pytype static type analyzer
|
140 |
+
.pytype/
|
141 |
+
|
142 |
+
# Profiling data
|
143 |
+
.prof
|
144 |
+
|
145 |
+
# Folder specific to your needs
|
146 |
+
**/tmp/
|
147 |
+
**/outputs/skyseg.onnx
|
148 |
+
skyseg.onnx
|
149 |
+
|
150 |
+
# pixi environments
|
151 |
+
.pixi
|
152 |
+
*.egg-info
|
LICENSE.txt
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Attribution-NonCommercial 2.5 Generic
|
2 |
+
CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE LEGAL SERVICES. DISTRIBUTION OF THIS LICENSE DOES NOT CREATE AN ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES REGARDING THE INFORMATION PROVIDED, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM ITS USE.
|
3 |
+
License
|
4 |
+
|
5 |
+
THE WORK (AS DEFINED BELOW) IS PROVIDED UNDER THE TERMS OF THIS CREATIVE COMMONS PUBLIC LICENSE ("CCPL" OR "LICENSE"). THE WORK IS PROTECTED BY COPYRIGHT AND/OR OTHER APPLICABLE LAW. ANY USE OF THE WORK OTHER THAN AS AUTHORIZED UNDER THIS LICENSE OR COPYRIGHT LAW IS PROHIBITED.
|
6 |
+
|
7 |
+
BY EXERCISING ANY RIGHTS TO THE WORK PROVIDED HERE, YOU ACCEPT AND AGREE TO BE BOUND BY THE TERMS OF THIS LICENSE. THE LICENSOR GRANTS YOU THE RIGHTS CONTAINED HERE IN CONSIDERATION OF YOUR ACCEPTANCE OF SUCH TERMS AND CONDITIONS.
|
8 |
+
|
9 |
+
1. Definitions
|
10 |
+
|
11 |
+
"Collective Work" means a work, such as a periodical issue, anthology or encyclopedia, in which the Work in its entirety in unmodified form, along with a number of other contributions, constituting separate and independent works in themselves, are assembled into a collective whole. A work that constitutes a Collective Work will not be considered a Derivative Work (as defined below) for the purposes of this License.
|
12 |
+
"Derivative Work" means a work based upon the Work or upon the Work and other pre-existing works, such as a translation, musical arrangement, dramatization, fictionalization, motion picture version, sound recording, art reproduction, abridgment, condensation, or any other form in which the Work may be recast, transformed, or adapted, except that a work that constitutes a Collective Work will not be considered a Derivative Work for the purpose of this License. For the avoidance of doubt, where the Work is a musical composition or sound recording, the synchronization of the Work in timed-relation with a moving image ("synching") will be considered a Derivative Work for the purpose of this License.
|
13 |
+
"Licensor" means the individual or entity that offers the Work under the terms of this License.
|
14 |
+
"Original Author" means the individual or entity who created the Work.
|
15 |
+
"Work" means the copyrightable work of authorship offered under the terms of this License.
|
16 |
+
"You" means an individual or entity exercising rights under this License who has not previously violated the terms of this License with respect to the Work, or who has received express permission from the Licensor to exercise rights under this License despite a previous violation.
|
17 |
+
2. Fair Use Rights. Nothing in this license is intended to reduce, limit, or restrict any rights arising from fair use, first sale or other limitations on the exclusive rights of the copyright owner under copyright law or other applicable laws.
|
18 |
+
|
19 |
+
3. License Grant. Subject to the terms and conditions of this License, Licensor hereby grants You a worldwide, royalty-free, non-exclusive, perpetual (for the duration of the applicable copyright) license to exercise the rights in the Work as stated below:
|
20 |
+
|
21 |
+
to reproduce the Work, to incorporate the Work into one or more Collective Works, and to reproduce the Work as incorporated in the Collective Works;
|
22 |
+
to create and reproduce Derivative Works;
|
23 |
+
to distribute copies or phonorecords of, display publicly, perform publicly, and perform publicly by means of a digital audio transmission the Work including as incorporated in Collective Works;
|
24 |
+
to distribute copies or phonorecords of, display publicly, perform publicly, and perform publicly by means of a digital audio transmission Derivative Works;
|
25 |
+
The above rights may be exercised in all media and formats whether now known or hereafter devised. The above rights include the right to make such modifications as are technically necessary to exercise the rights in other media and formats. All rights not expressly granted by Licensor are hereby reserved, including but not limited to the rights set forth in Sections 4(d) and 4(e).
|
26 |
+
|
27 |
+
4. Restrictions. The license granted in Section 3 above is expressly made subject to and limited by the following restrictions:
|
28 |
+
|
29 |
+
You may distribute, publicly display, publicly perform, or publicly digitally perform the Work only under the terms of this License, and You must include a copy of, or the Uniform Resource Identifier for, this License with every copy or phonorecord of the Work You distribute, publicly display, publicly perform, or publicly digitally perform. You may not offer or impose any terms on the Work that alter or restrict the terms of this License or the recipients' exercise of the rights granted hereunder. You may not sublicense the Work. You must keep intact all notices that refer to this License and to the disclaimer of warranties. You may not distribute, publicly display, publicly perform, or publicly digitally perform the Work with any technological measures that control access or use of the Work in a manner inconsistent with the terms of this License Agreement. The above applies to the Work as incorporated in a Collective Work, but this does not require the Collective Work apart from the Work itself to be made subject to the terms of this License. If You create a Collective Work, upon notice from any Licensor You must, to the extent practicable, remove from the Collective Work any credit as required by clause 4(c), as requested. If You create a Derivative Work, upon notice from any Licensor You must, to the extent practicable, remove from the Derivative Work any credit as required by clause 4(c), as requested.
|
30 |
+
You may not exercise any of the rights granted to You in Section 3 above in any manner that is primarily intended for or directed toward commercial advantage or private monetary compensation. The exchange of the Work for other copyrighted works by means of digital file-sharing or otherwise shall not be considered to be intended for or directed toward commercial advantage or private monetary compensation, provided there is no payment of any monetary compensation in connection with the exchange of copyrighted works.
|
31 |
+
If you distribute, publicly display, publicly perform, or publicly digitally perform the Work or any Derivative Works or Collective Works, You must keep intact all copyright notices for the Work and provide, reasonable to the medium or means You are utilizing: (i) the name of Original Author (or pseudonym, if applicable) if supplied, and/or (ii) if the Original Author and/or Licensor designate another party or parties (e.g. a sponsor institute, publishing entity, journal) for attribution in Licensor's copyright notice, terms of service or by other reasonable means, the name of such party or parties; the title of the Work if supplied; to the extent reasonably practicable, the Uniform Resource Identifier, if any, that Licensor specifies to be associated with the Work, unless such URI does not refer to the copyright notice or licensing information for the Work; and in the case of a Derivative Work, a credit identifying the use of the Work in the Derivative Work (e.g., "French translation of the Work by Original Author," or "Screenplay based on original Work by Original Author"). Such credit may be implemented in any reasonable manner; provided, however, that in the case of a Derivative Work or Collective Work, at a minimum such credit will appear where any other comparable authorship credit appears and in a manner at least as prominent as such other comparable authorship credit.
|
32 |
+
For the avoidance of doubt, where the Work is a musical composition:
|
33 |
+
|
34 |
+
Performance Royalties Under Blanket Licenses . Licensor reserves the exclusive right to collect, whether individually or via a performance rights society (e.g. ASCAP, BMI, SESAC), royalties for the public performance or public digital performance (e.g. webcast) of the Work if that performance is primarily intended for or directed toward commercial advantage or private monetary compensation.
|
35 |
+
Mechanical Rights and Statutory Royalties . Licensor reserves the exclusive right to collect, whether individually or via a music rights agency or designated agent (e.g. Harry Fox Agency), royalties for any phonorecord You create from the Work ("cover version") and distribute, subject to the compulsory license created by 17 USC Section 115 of the US Copyright Act (or the equivalent in other jurisdictions), if Your distribution of such cover version is primarily intended for or directed toward commercial advantage or private monetary compensation.
|
36 |
+
Webcasting Rights and Statutory Royalties. For the avoidance of doubt, where the Work is a sound recording, Licensor reserves the exclusive right to collect, whether individually or via a performance-rights society (e.g. SoundExchange), royalties for the public digital performance (e.g. webcast) of the Work, subject to the compulsory license created by 17 USC Section 114 of the US Copyright Act (or the equivalent in other jurisdictions), if Your public digital performance is primarily intended for or directed toward commercial advantage or private monetary compensation.
|
37 |
+
5. Representations, Warranties and Disclaimer
|
38 |
+
|
39 |
+
UNLESS OTHERWISE MUTUALLY AGREED TO BY THE PARTIES IN WRITING, LICENSOR OFFERS THE WORK AS-IS AND MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE WORK, EXPRESS, IMPLIED, STATUTORY OR OTHERWISE, INCLUDING, WITHOUT LIMITATION, WARRANTIES OF TITLE, MERCHANTIBILITY, FITNESS FOR A PARTICULAR PURPOSE, NONINFRINGEMENT, OR THE ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OF ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE. SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OF IMPLIED WARRANTIES, SO SUCH EXCLUSION MAY NOT APPLY TO YOU.
|
40 |
+
|
41 |
+
6. Limitation on Liability. EXCEPT TO THE EXTENT REQUIRED BY APPLICABLE LAW, IN NO EVENT WILL LICENSOR BE LIABLE TO YOU ON ANY LEGAL THEORY FOR ANY SPECIAL, INCIDENTAL, CONSEQUENTIAL, PUNITIVE OR EXEMPLARY DAMAGES ARISING OUT OF THIS LICENSE OR THE USE OF THE WORK, EVEN IF LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
|
42 |
+
|
43 |
+
7. Termination
|
44 |
+
|
45 |
+
This License and the rights granted hereunder will terminate automatically upon any breach by You of the terms of this License. Individuals or entities who have received Derivative Works or Collective Works from You under this License, however, will not have their licenses terminated provided such individuals or entities remain in full compliance with those licenses. Sections 1, 2, 5, 6, 7, and 8 will survive any termination of this License.
|
46 |
+
Subject to the above terms and conditions, the license granted here is perpetual (for the duration of the applicable copyright in the Work). Notwithstanding the above, Licensor reserves the right to release the Work under different license terms or to stop distributing the Work at any time; provided, however that any such election will not serve to withdraw this License (or any other license that has been, or is required to be, granted under the terms of this License), and this License will continue in full force and effect unless terminated as stated above.
|
47 |
+
8. Miscellaneous
|
48 |
+
|
49 |
+
Each time You distribute or publicly digitally perform the Work or a Collective Work, the Licensor offers to the recipient a license to the Work on the same terms and conditions as the license granted to You under this License.
|
50 |
+
Each time You distribute or publicly digitally perform a Derivative Work, Licensor offers to the recipient a license to the original Work on the same terms and conditions as the license granted to You under this License.
|
51 |
+
If any provision of this License is invalid or unenforceable under applicable law, it shall not affect the validity or enforceability of the remainder of the terms of this License, and without further action by the parties to this agreement, such provision shall be reformed to the minimum extent necessary to make such provision valid and enforceable.
|
52 |
+
No term or provision of this License shall be deemed waived and no breach consented to unless such waiver or consent shall be in writing and signed by the party to be charged with such waiver or consent.
|
53 |
+
This License constitutes the entire agreement between the parties with respect to the Work licensed here. There are no understandings, agreements or representations with respect to the Work not specified here. Licensor shall not be bound by any additional provisions that may appear in any communication from You. This License may not be modified without the mutual written agreement of the Licensor and You.
|
54 |
+
Creative Commons is not a party to this License, and makes no warranty whatsoever in connection with the Work. Creative Commons will not be liable to You or any party on any legal theory for any damages whatsoever, including without limitation any general, special, incidental or consequential damages arising in connection to this license. Notwithstanding the foregoing two (2) sentences, if Creative Commons has expressly identified itself as the Licensor hereunder, it shall have all rights and obligations of Licensor.
|
55 |
+
|
56 |
+
Except for the limited purpose of indicating to the public that the Work is licensed under the CCPL, neither party will use the trademark "Creative Commons" or any related trademark or logo of Creative Commons without the prior written consent of Creative Commons. Any permitted use will be in compliance with Creative Commons' then-current trademark usage guidelines, as may be published on its website or otherwise made available upon request from time to time.
|
57 |
+
|
58 |
+
Creative Commons may be contacted at https://creativecommons.org/ .
|
UniCeption/.gitignore
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Local Folders
|
2 |
+
checkpoints/
|
3 |
+
local/
|
4 |
+
reference_data/
|
5 |
+
|
6 |
+
# Byte-compiled / optimized / DLL files
|
7 |
+
__pycache__/
|
8 |
+
*.py[cod]
|
9 |
+
*$py.class
|
10 |
+
|
11 |
+
# C extensions
|
12 |
+
*.so
|
13 |
+
|
14 |
+
# Distribution / packaging
|
15 |
+
.Python
|
16 |
+
build/
|
17 |
+
develop-eggs/
|
18 |
+
dist/
|
19 |
+
downloads/
|
20 |
+
eggs/
|
21 |
+
.eggs/
|
22 |
+
lib/
|
23 |
+
lib64/
|
24 |
+
parts/
|
25 |
+
sdist/
|
26 |
+
var/
|
27 |
+
wheels/
|
28 |
+
share/python-wheels/
|
29 |
+
*.egg-info/
|
30 |
+
.installed.cfg
|
31 |
+
*.egg
|
32 |
+
MANIFEST
|
33 |
+
|
34 |
+
# PyInstaller
|
35 |
+
# Usually these files are written by a python script from a template
|
36 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
37 |
+
*.manifest
|
38 |
+
*.spec
|
39 |
+
|
40 |
+
# Installer logs
|
41 |
+
pip-log.txt
|
42 |
+
pip-delete-this-directory.txt
|
43 |
+
|
44 |
+
# Unit test / coverage reports
|
45 |
+
htmlcov/
|
46 |
+
.tox/
|
47 |
+
.nox/
|
48 |
+
.coverage
|
49 |
+
.coverage.*
|
50 |
+
.cache
|
51 |
+
nosetests.xml
|
52 |
+
coverage.xml
|
53 |
+
*.cover
|
54 |
+
*.py,cover
|
55 |
+
.hypothesis/
|
56 |
+
.pytest_cache/
|
57 |
+
cover/
|
58 |
+
|
59 |
+
# Translations
|
60 |
+
*.mo
|
61 |
+
*.pot
|
62 |
+
|
63 |
+
# Django stuff:
|
64 |
+
*.log
|
65 |
+
local_settings.py
|
66 |
+
db.sqlite3
|
67 |
+
db.sqlite3-journal
|
68 |
+
|
69 |
+
# Flask stuff:
|
70 |
+
instance/
|
71 |
+
.webassets-cache
|
72 |
+
|
73 |
+
# Scrapy stuff:
|
74 |
+
.scrapy
|
75 |
+
|
76 |
+
# Sphinx documentation
|
77 |
+
docs/_build/
|
78 |
+
|
79 |
+
# PyBuilder
|
80 |
+
.pybuilder/
|
81 |
+
target/
|
82 |
+
|
83 |
+
# Jupyter Notebook
|
84 |
+
.ipynb_checkpoints
|
85 |
+
|
86 |
+
# IPython
|
87 |
+
profile_default/
|
88 |
+
ipython_config.py
|
89 |
+
|
90 |
+
# pyenv
|
91 |
+
# For a library or package, you might want to ignore these files since the code is
|
92 |
+
# intended to run in multiple environments; otherwise, check them in:
|
93 |
+
# .python-version
|
94 |
+
|
95 |
+
# pipenv
|
96 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
97 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
98 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
99 |
+
# install all needed dependencies.
|
100 |
+
#Pipfile.lock
|
101 |
+
|
102 |
+
# poetry
|
103 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
104 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
105 |
+
# commonly ignored for libraries.
|
106 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
107 |
+
#poetry.lock
|
108 |
+
|
109 |
+
# pdm
|
110 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
111 |
+
#pdm.lock
|
112 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
113 |
+
# in version control.
|
114 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
115 |
+
.pdm.toml
|
116 |
+
.pdm-python
|
117 |
+
.pdm-build/
|
118 |
+
|
119 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
120 |
+
__pypackages__/
|
121 |
+
|
122 |
+
# Celery stuff
|
123 |
+
celerybeat-schedule
|
124 |
+
celerybeat.pid
|
125 |
+
|
126 |
+
# SageMath parsed files
|
127 |
+
*.sage.py
|
128 |
+
|
129 |
+
# Environments
|
130 |
+
.env
|
131 |
+
.venv
|
132 |
+
env/
|
133 |
+
venv/
|
134 |
+
ENV/
|
135 |
+
env.bak/
|
136 |
+
venv.bak/
|
137 |
+
|
138 |
+
# Spyder project settings
|
139 |
+
.spyderproject
|
140 |
+
.spyproject
|
141 |
+
|
142 |
+
# Rope project settings
|
143 |
+
.ropeproject
|
144 |
+
|
145 |
+
# mkdocs documentation
|
146 |
+
/site
|
147 |
+
|
148 |
+
# mypy
|
149 |
+
.mypy_cache/
|
150 |
+
.dmypy.json
|
151 |
+
dmypy.json
|
152 |
+
|
153 |
+
# Pyre type checker
|
154 |
+
.pyre/
|
155 |
+
|
156 |
+
# pytype static type analyzer
|
157 |
+
.pytype/
|
158 |
+
|
159 |
+
# Cython debug symbols
|
160 |
+
cython_debug/
|
161 |
+
|
162 |
+
# PyCharm
|
163 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
164 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
165 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
166 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
167 |
+
#.idea/
|
UniCeption/.pre-commit-config.yaml
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# See https://pre-commit.com for more information
|
2 |
+
# See https://pre-commit.com/hooks.html for more hooks
|
3 |
+
default_language_version:
|
4 |
+
python: python3
|
5 |
+
repos:
|
6 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
7 |
+
rev: v3.2.0
|
8 |
+
hooks:
|
9 |
+
- id: trailing-whitespace
|
10 |
+
- id: end-of-file-fixer
|
11 |
+
- repo: https://github.com/pre-commit/mirrors-isort
|
12 |
+
rev: 'v5.10.1'
|
13 |
+
hooks:
|
14 |
+
- id: isort
|
15 |
+
- repo: https://github.com/psf/black
|
16 |
+
rev: '23.3.0'
|
17 |
+
hooks:
|
18 |
+
- id: black
|
UniCeption/.pylintrc
ADDED
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This Pylint rcfile contains a best-effort configuration to uphold the
|
2 |
+
# best-practices and style described in the Google Python style guide:
|
3 |
+
# https://google.github.io/styleguide/pyguide.html
|
4 |
+
#
|
5 |
+
# Its canonical open-source location is:
|
6 |
+
# https://google.github.io/styleguide/pylintrc
|
7 |
+
|
8 |
+
[MAIN]
|
9 |
+
|
10 |
+
# Files or directories to be skipped. They should be base names, not paths.
|
11 |
+
ignore=third_party
|
12 |
+
|
13 |
+
# Files or directories matching the regex patterns are skipped. The regex
|
14 |
+
# matches against base names, not paths.
|
15 |
+
ignore-patterns=
|
16 |
+
|
17 |
+
# Pickle collected data for later comparisons.
|
18 |
+
persistent=no
|
19 |
+
|
20 |
+
# List of plugins (as comma separated values of python modules names) to load,
|
21 |
+
# usually to register additional checkers.
|
22 |
+
load-plugins=
|
23 |
+
|
24 |
+
# Use multiple processes to speed up Pylint.
|
25 |
+
jobs=4
|
26 |
+
|
27 |
+
# Allow loading of arbitrary C extensions. Extensions are imported into the
|
28 |
+
# active Python interpreter and may run arbitrary code.
|
29 |
+
unsafe-load-any-extension=no
|
30 |
+
|
31 |
+
|
32 |
+
[MESSAGES CONTROL]
|
33 |
+
|
34 |
+
# Only show warnings with the listed confidence levels. Leave empty to show
|
35 |
+
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
|
36 |
+
confidence=
|
37 |
+
|
38 |
+
# Enable the message, report, category or checker with the given id(s). You can
|
39 |
+
# either give multiple identifier separated by comma (,) or put this option
|
40 |
+
# multiple time (only on the command line, not in the configuration file where
|
41 |
+
# it should appear only once). See also the "--disable" option for examples.
|
42 |
+
#enable=
|
43 |
+
|
44 |
+
# Disable the message, report, category or checker with the given id(s). You
|
45 |
+
# can either give multiple identifiers separated by comma (,) or put this
|
46 |
+
# option multiple times (only on the command line, not in the configuration
|
47 |
+
# file where it should appear only once).You can also use "--disable=all" to
|
48 |
+
# disable everything first and then reenable specific checks. For example, if
|
49 |
+
# you want to run only the similarities checker, you can use "--disable=all
|
50 |
+
# --enable=similarities". If you want to run only the classes checker, but have
|
51 |
+
# no Warning level messages displayed, use"--disable=all --enable=classes
|
52 |
+
# --disable=W"
|
53 |
+
disable=R,
|
54 |
+
abstract-method,
|
55 |
+
apply-builtin,
|
56 |
+
arguments-differ,
|
57 |
+
attribute-defined-outside-init,
|
58 |
+
backtick,
|
59 |
+
bad-option-value,
|
60 |
+
basestring-builtin,
|
61 |
+
buffer-builtin,
|
62 |
+
c-extension-no-member,
|
63 |
+
consider-using-enumerate,
|
64 |
+
cmp-builtin,
|
65 |
+
cmp-method,
|
66 |
+
coerce-builtin,
|
67 |
+
coerce-method,
|
68 |
+
delslice-method,
|
69 |
+
div-method,
|
70 |
+
eq-without-hash,
|
71 |
+
execfile-builtin,
|
72 |
+
file-builtin,
|
73 |
+
filter-builtin-not-iterating,
|
74 |
+
fixme,
|
75 |
+
getslice-method,
|
76 |
+
global-statement,
|
77 |
+
hex-method,
|
78 |
+
idiv-method,
|
79 |
+
implicit-str-concat,
|
80 |
+
import-error,
|
81 |
+
import-self,
|
82 |
+
import-star-module-level,
|
83 |
+
input-builtin,
|
84 |
+
intern-builtin,
|
85 |
+
invalid-str-codec,
|
86 |
+
locally-disabled,
|
87 |
+
long-builtin,
|
88 |
+
long-suffix,
|
89 |
+
map-builtin-not-iterating,
|
90 |
+
misplaced-comparison-constant,
|
91 |
+
missing-function-docstring,
|
92 |
+
metaclass-assignment,
|
93 |
+
next-method-called,
|
94 |
+
next-method-defined,
|
95 |
+
no-absolute-import,
|
96 |
+
no-init, # added
|
97 |
+
no-member,
|
98 |
+
no-name-in-module,
|
99 |
+
no-self-use,
|
100 |
+
nonzero-method,
|
101 |
+
oct-method,
|
102 |
+
old-division,
|
103 |
+
old-ne-operator,
|
104 |
+
old-octal-literal,
|
105 |
+
old-raise-syntax,
|
106 |
+
parameter-unpacking,
|
107 |
+
print-statement,
|
108 |
+
raising-string,
|
109 |
+
range-builtin-not-iterating,
|
110 |
+
raw_input-builtin,
|
111 |
+
rdiv-method,
|
112 |
+
reduce-builtin,
|
113 |
+
relative-import,
|
114 |
+
reload-builtin,
|
115 |
+
round-builtin,
|
116 |
+
setslice-method,
|
117 |
+
signature-differs,
|
118 |
+
standarderror-builtin,
|
119 |
+
suppressed-message,
|
120 |
+
sys-max-int,
|
121 |
+
trailing-newlines,
|
122 |
+
unichr-builtin,
|
123 |
+
unicode-builtin,
|
124 |
+
unnecessary-pass,
|
125 |
+
unpacking-in-except,
|
126 |
+
useless-else-on-loop,
|
127 |
+
useless-suppression,
|
128 |
+
using-cmp-argument,
|
129 |
+
wrong-import-order,
|
130 |
+
xrange-builtin,
|
131 |
+
zip-builtin-not-iterating,
|
132 |
+
|
133 |
+
|
134 |
+
[REPORTS]
|
135 |
+
|
136 |
+
# Set the output format. Available formats are text, parseable, colorized, msvs
|
137 |
+
# (visual studio) and html. You can also give a reporter class, eg
|
138 |
+
# mypackage.mymodule.MyReporterClass.
|
139 |
+
output-format=text
|
140 |
+
|
141 |
+
# Tells whether to display a full report or only the messages
|
142 |
+
reports=no
|
143 |
+
|
144 |
+
# Python expression which should return a note less than 10 (10 is the highest
|
145 |
+
# note). You have access to the variables errors warning, statement which
|
146 |
+
# respectively contain the number of errors / warnings messages and the total
|
147 |
+
# number of statements analyzed. This is used by the global evaluation report
|
148 |
+
# (RP0004).
|
149 |
+
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
|
150 |
+
|
151 |
+
# Template used to display messages. This is a python new-style format string
|
152 |
+
# used to format the message information. See doc for all details
|
153 |
+
#msg-template=
|
154 |
+
|
155 |
+
|
156 |
+
[BASIC]
|
157 |
+
|
158 |
+
# Good variable names which should always be accepted, separated by a comma
|
159 |
+
good-names=main,_
|
160 |
+
|
161 |
+
# Bad variable names which should always be refused, separated by a comma
|
162 |
+
bad-names=
|
163 |
+
|
164 |
+
# Colon-delimited sets of names that determine each other's naming style when
|
165 |
+
# the name regexes allow several styles.
|
166 |
+
name-group=
|
167 |
+
|
168 |
+
# Include a hint for the correct naming format with invalid-name
|
169 |
+
include-naming-hint=no
|
170 |
+
|
171 |
+
# List of decorators that produce properties, such as abc.abstractproperty. Add
|
172 |
+
# to this list to register other decorators that produce valid properties.
|
173 |
+
property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
|
174 |
+
|
175 |
+
# Regular expression matching correct function names
|
176 |
+
function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
|
177 |
+
|
178 |
+
# Regular expression matching correct variable names
|
179 |
+
variable-rgx=^[a-z][a-z0-9_]*$
|
180 |
+
|
181 |
+
# Regular expression matching correct constant names
|
182 |
+
const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
|
183 |
+
|
184 |
+
# Regular expression matching correct attribute names
|
185 |
+
attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
|
186 |
+
|
187 |
+
# Regular expression matching correct argument names
|
188 |
+
argument-rgx=^[a-z][a-z0-9_]*$
|
189 |
+
|
190 |
+
# Regular expression matching correct class attribute names
|
191 |
+
class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
|
192 |
+
|
193 |
+
# Regular expression matching correct inline iteration names
|
194 |
+
inlinevar-rgx=^[a-z][a-z0-9_]*$
|
195 |
+
|
196 |
+
# Regular expression matching correct class names
|
197 |
+
class-rgx=^_?[A-Z][a-zA-Z0-9]*$
|
198 |
+
|
199 |
+
# Regular expression matching correct module names
|
200 |
+
module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
|
201 |
+
|
202 |
+
# Regular expression matching correct method names
|
203 |
+
method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
|
204 |
+
|
205 |
+
# Regular expression which should only match function or class names that do
|
206 |
+
# not require a docstring.
|
207 |
+
no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
|
208 |
+
|
209 |
+
# Minimum line length for functions/classes that require docstrings, shorter
|
210 |
+
# ones are exempt.
|
211 |
+
docstring-min-length=12
|
212 |
+
|
213 |
+
|
214 |
+
[TYPECHECK]
|
215 |
+
|
216 |
+
# List of decorators that produce context managers, such as
|
217 |
+
# contextlib.contextmanager. Add to this list to register other decorators that
|
218 |
+
# produce valid context managers.
|
219 |
+
contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
|
220 |
+
|
221 |
+
# List of module names for which member attributes should not be checked
|
222 |
+
# (useful for modules/projects where namespaces are manipulated during runtime
|
223 |
+
# and thus existing member attributes cannot be deduced by static analysis. It
|
224 |
+
# supports qualified module names, as well as Unix pattern matching.
|
225 |
+
ignored-modules=
|
226 |
+
|
227 |
+
# List of class names for which member attributes should not be checked (useful
|
228 |
+
# for classes with dynamically set attributes). This supports the use of
|
229 |
+
# qualified names.
|
230 |
+
ignored-classes=optparse.Values,thread._local,_thread._local
|
231 |
+
|
232 |
+
# List of members which are set dynamically and missed by pylint inference
|
233 |
+
# system, and so shouldn't trigger E1101 when accessed. Python regular
|
234 |
+
# expressions are accepted.
|
235 |
+
generated-members=
|
236 |
+
|
237 |
+
|
238 |
+
[FORMAT]
|
239 |
+
|
240 |
+
# Maximum number of characters on a single line.
|
241 |
+
max-line-length=80
|
242 |
+
|
243 |
+
# TODO(https://github.com/pylint-dev/pylint/issues/3352): Direct pylint to exempt
|
244 |
+
# lines made too long by directives to pytype.
|
245 |
+
|
246 |
+
# Regexp for a line that is allowed to be longer than the limit.
|
247 |
+
ignore-long-lines=(?x)(
|
248 |
+
^\s*(\#\ )?<?https?://\S+>?$|
|
249 |
+
^\s*(from\s+\S+\s+)?import\s+.+$)
|
250 |
+
|
251 |
+
# Allow the body of an if to be on the same line as the test if there is no
|
252 |
+
# else.
|
253 |
+
single-line-if-stmt=yes
|
254 |
+
|
255 |
+
# Maximum number of lines in a module
|
256 |
+
max-module-lines=99999
|
257 |
+
|
258 |
+
# String used as indentation unit. The internal Google style guide mandates 2
|
259 |
+
# spaces. Google's externaly-published style guide says 4, consistent with
|
260 |
+
# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
|
261 |
+
# projects (like TensorFlow).
|
262 |
+
indent-string=' '
|
263 |
+
|
264 |
+
# Number of spaces of indent required inside a hanging or continued line.
|
265 |
+
indent-after-paren=4
|
266 |
+
|
267 |
+
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
|
268 |
+
expected-line-ending-format=
|
269 |
+
|
270 |
+
|
271 |
+
[MISCELLANEOUS]
|
272 |
+
|
273 |
+
# List of note tags to take in consideration, separated by a comma.
|
274 |
+
notes=TODO
|
275 |
+
|
276 |
+
|
277 |
+
[STRING]
|
278 |
+
|
279 |
+
# This flag controls whether inconsistent-quotes generates a warning when the
|
280 |
+
# character used as a quote delimiter is used inconsistently within a module.
|
281 |
+
check-quote-consistency=yes
|
282 |
+
|
283 |
+
|
284 |
+
[VARIABLES]
|
285 |
+
|
286 |
+
# Tells whether we should check for unused import in __init__ files.
|
287 |
+
init-import=no
|
288 |
+
|
289 |
+
# A regular expression matching the name of dummy variables (i.e. expectedly
|
290 |
+
# not used).
|
291 |
+
dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
|
292 |
+
|
293 |
+
# List of additional names supposed to be defined in builtins. Remember that
|
294 |
+
# you should avoid to define new builtins when possible.
|
295 |
+
additional-builtins=
|
296 |
+
|
297 |
+
# List of strings which can identify a callback function by name. A callback
|
298 |
+
# name must start or end with one of those strings.
|
299 |
+
callbacks=cb_,_cb
|
300 |
+
|
301 |
+
# List of qualified module names which can have objects that can redefine
|
302 |
+
# builtins.
|
303 |
+
redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
|
304 |
+
|
305 |
+
|
306 |
+
[LOGGING]
|
307 |
+
|
308 |
+
# Logging modules to check that the string format arguments are in logging
|
309 |
+
# function parameter format
|
310 |
+
logging-modules=logging,absl.logging,tensorflow.io.logging
|
311 |
+
|
312 |
+
|
313 |
+
[SIMILARITIES]
|
314 |
+
|
315 |
+
# Minimum lines number of a similarity.
|
316 |
+
min-similarity-lines=4
|
317 |
+
|
318 |
+
# Ignore comments when computing similarities.
|
319 |
+
ignore-comments=yes
|
320 |
+
|
321 |
+
# Ignore docstrings when computing similarities.
|
322 |
+
ignore-docstrings=yes
|
323 |
+
|
324 |
+
# Ignore imports when computing similarities.
|
325 |
+
ignore-imports=no
|
326 |
+
|
327 |
+
|
328 |
+
[SPELLING]
|
329 |
+
|
330 |
+
# Spelling dictionary name. Available dictionaries: none. To make it working
|
331 |
+
# install python-enchant package.
|
332 |
+
spelling-dict=
|
333 |
+
|
334 |
+
# List of comma separated words that should not be checked.
|
335 |
+
spelling-ignore-words=
|
336 |
+
|
337 |
+
# A path to a file that contains private dictionary; one word per line.
|
338 |
+
spelling-private-dict-file=
|
339 |
+
|
340 |
+
# Tells whether to store unknown words to indicated private dictionary in
|
341 |
+
# --spelling-private-dict-file option instead of raising a message.
|
342 |
+
spelling-store-unknown-words=no
|
343 |
+
|
344 |
+
|
345 |
+
[IMPORTS]
|
346 |
+
|
347 |
+
# Deprecated modules which should not be used, separated by a comma
|
348 |
+
deprecated-modules=regsub,
|
349 |
+
TERMIOS,
|
350 |
+
Bastion,
|
351 |
+
rexec,
|
352 |
+
sets
|
353 |
+
|
354 |
+
# Create a graph of every (i.e. internal and external) dependencies in the
|
355 |
+
# given file (report RP0402 must not be disabled)
|
356 |
+
import-graph=
|
357 |
+
|
358 |
+
# Create a graph of external dependencies in the given file (report RP0402 must
|
359 |
+
# not be disabled)
|
360 |
+
ext-import-graph=
|
361 |
+
|
362 |
+
# Create a graph of internal dependencies in the given file (report RP0402 must
|
363 |
+
# not be disabled)
|
364 |
+
int-import-graph=
|
365 |
+
|
366 |
+
# Force import order to recognize a module as part of the standard
|
367 |
+
# compatibility libraries.
|
368 |
+
known-standard-library=
|
369 |
+
|
370 |
+
# Force import order to recognize a module as part of a third party library.
|
371 |
+
known-third-party=enchant, absl
|
372 |
+
|
373 |
+
# Analyse import fallback blocks. This can be used to support both Python 2 and
|
374 |
+
# 3 compatible code, which means that the block might have code that exists
|
375 |
+
# only in one or another interpreter, leading to false positives when analysed.
|
376 |
+
analyse-fallback-blocks=no
|
377 |
+
|
378 |
+
|
379 |
+
[CLASSES]
|
380 |
+
|
381 |
+
# List of method names used to declare (i.e. assign) instance attributes.
|
382 |
+
defining-attr-methods=__init__,
|
383 |
+
__new__,
|
384 |
+
setUp
|
385 |
+
|
386 |
+
# List of member names, which should be excluded from the protected access
|
387 |
+
# warning.
|
388 |
+
exclude-protected=_asdict,
|
389 |
+
_fields,
|
390 |
+
_replace,
|
391 |
+
_source,
|
392 |
+
_make
|
393 |
+
|
394 |
+
# List of valid names for the first argument in a class method.
|
395 |
+
valid-classmethod-first-arg=cls,
|
396 |
+
class_
|
397 |
+
|
398 |
+
# List of valid names for the first argument in a metaclass class method.
|
399 |
+
valid-metaclass-classmethod-first-arg=mcs
|
UniCeption/LICENSE
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
BSD 3-Clause License
|
2 |
+
|
3 |
+
Copyright (c) 2024, AirLab Stacks
|
4 |
+
|
5 |
+
Redistribution and use in source and binary forms, with or without
|
6 |
+
modification, are permitted provided that the following conditions are met:
|
7 |
+
|
8 |
+
1. Redistributions of source code must retain the above copyright notice, this
|
9 |
+
list of conditions and the following disclaimer.
|
10 |
+
|
11 |
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
12 |
+
this list of conditions and the following disclaimer in the documentation
|
13 |
+
and/or other materials provided with the distribution.
|
14 |
+
|
15 |
+
3. Neither the name of the copyright holder nor the names of its
|
16 |
+
contributors may be used to endorse or promote products derived from
|
17 |
+
this software without specific prior written permission.
|
18 |
+
|
19 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
20 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
21 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
22 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
23 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
24 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
25 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
26 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
27 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
28 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
UniCeption/README.md
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# UniCeption
|
2 |
+
|
3 |
+
UniCeption houses modular building blocks for developing and training generalizable perception models for all things related to 3D, 4D, spatial AI and scene understanding.
|
4 |
+
It is designed to be flexible and extensible, allowing researchers to easily experiment with different architectures and configurations.
|
5 |
+
|
6 |
+
Please refer to the [Developer Guidelines](#developer-guidelines) for contributing to the project.
|
7 |
+
|
8 |
+
## Installation
|
9 |
+
|
10 |
+
Clone the repository to your local machine by running the following command:
|
11 |
+
|
12 |
+
```bash
|
13 |
+
git clone [email protected]:castacks/UniCeption.git
|
14 |
+
cd UniCeption
|
15 |
+
```
|
16 |
+
|
17 |
+
### Standard Installation
|
18 |
+
|
19 |
+
Install the `uniception` package in development mode by running the following commands:
|
20 |
+
|
21 |
+
```bash
|
22 |
+
# Please use Conda or Python Virtual Environment based on your preference
|
23 |
+
# For Conda Environment
|
24 |
+
conda create --name uniception python=3.12
|
25 |
+
conda activate uniception
|
26 |
+
# For Python Virtual Environment
|
27 |
+
virtualenv uniception
|
28 |
+
source uniception/bin/activate
|
29 |
+
|
30 |
+
# Install UniCeption with base dependencies (includes PyTorch)
|
31 |
+
pip install -e .
|
32 |
+
|
33 |
+
# Optional: Install with XFormers support
|
34 |
+
pip install -e ".[xformers]"
|
35 |
+
|
36 |
+
# Optional: Install with development tools
|
37 |
+
pip install -e ".[dev]"
|
38 |
+
|
39 |
+
# Optional: Install all optional dependencies
|
40 |
+
pip install -e ".[all]"
|
41 |
+
|
42 |
+
# Setup pre-commit hooks for development
|
43 |
+
pre-commit install
|
44 |
+
```
|
45 |
+
|
46 |
+
### Optional: CroCo RoPE Extension Installation
|
47 |
+
|
48 |
+
To use CroCo models with the custom RoPE kernel:
|
49 |
+
|
50 |
+
```bash
|
51 |
+
# Recommended: Use the console script
|
52 |
+
uniception-install-croco
|
53 |
+
|
54 |
+
# Alternative: Set environment variable during installation
|
55 |
+
INSTALL_CROCO_ROPE=true pip install -e .
|
56 |
+
|
57 |
+
# Manual compilation (if needed)
|
58 |
+
cd uniception/models/libs/croco/curope
|
59 |
+
python setup.py build_ext --inplace
|
60 |
+
cd ../../../../../
|
61 |
+
```
|
62 |
+
|
63 |
+
### Installation Validation and Dependency Checking
|
64 |
+
|
65 |
+
After installation, use these console scripts to validate your setup:
|
66 |
+
|
67 |
+
```bash
|
68 |
+
# Validate installation and check dependencies
|
69 |
+
uniception-validate
|
70 |
+
|
71 |
+
# Check which optional dependencies are available
|
72 |
+
uniception-check-deps
|
73 |
+
```
|
74 |
+
|
75 |
+
### Advanced Installation Options
|
76 |
+
|
77 |
+
#### Docker Installation (No Internet Access)
|
78 |
+
|
79 |
+
If you're working in a Docker container that already has Python dependencies installed but no internet access, you can install UniCeption in development mode without triggering network requests:
|
80 |
+
|
81 |
+
```bash
|
82 |
+
# Install only the package structure without dependencies
|
83 |
+
pip install -e . --no-deps
|
84 |
+
```
|
85 |
+
|
86 |
+
**Note:** This command assumes your Docker image already contains all required dependencies (PyTorch, etc.). Use `uniception-validate` after installation to verify all dependencies are available.
|
87 |
+
|
88 |
+
#### Offline Installation
|
89 |
+
|
90 |
+
For environments without internet access:
|
91 |
+
|
92 |
+
```bash
|
93 |
+
# 1. On a machine with internet access, prepare offline wheels
|
94 |
+
uniception-prepare-offline --output-dir offline_wheels --extras all
|
95 |
+
|
96 |
+
# 2. Copy the offline_wheels directory to your offline environment
|
97 |
+
# 3. Run the offline installation
|
98 |
+
cd offline_wheels
|
99 |
+
INSTALL_CROCO_ROPE=true INSTALL_XFORMERS=true ./install_offline.sh
|
100 |
+
```
|
101 |
+
|
102 |
+
#### Downloading Checkpoints
|
103 |
+
|
104 |
+
Download UniCeption format custom checkpoints:
|
105 |
+
|
106 |
+
```bash
|
107 |
+
# Download all available checkpoints
|
108 |
+
uniception-download-checkpoints
|
109 |
+
|
110 |
+
# Download specific folders only (e.g., encoders and prediction heads)
|
111 |
+
uniception-download-checkpoints --folders encoders prediction_heads
|
112 |
+
|
113 |
+
# Specify custom destination
|
114 |
+
uniception-download-checkpoints --destination /path/to/checkpoints
|
115 |
+
```
|
116 |
+
|
117 |
+
**Available options:**
|
118 |
+
- `--folders`: Specify which folders to download. Choices: `encoders`, `info_sharing`, `prediction_heads`, `examples` (default: all folders)
|
119 |
+
- `--destination`: Custom destination folder for downloaded checkpoints (default: current directory)
|
120 |
+
|
121 |
+
---
|
122 |
+
|
123 |
+
## Currently Supported Components
|
124 |
+
|
125 |
+
### Encoders
|
126 |
+
|
127 |
+
Please refer to the `uniception/models/encoders` directory for the supported encoders and documentation for adding new encoders. The supported encoders can be listed by running:
|
128 |
+
|
129 |
+
```bash
|
130 |
+
python3 -m uniception.models.encoders.list
|
131 |
+
```
|
132 |
+
|
133 |
+
---
|
134 |
+
|
135 |
+
## Information Sharing Blocks
|
136 |
+
|
137 |
+
Please refer to the `uniception/models/info_sharing` directory for the supported information sharing blocks.
|
138 |
+
|
139 |
+
---
|
140 |
+
|
141 |
+
## Prediction Heads
|
142 |
+
|
143 |
+
Please refer to the `uniception/models/prediction_heads` directory for the supported prediction heads.
|
144 |
+
|
145 |
+
---
|
146 |
+
|
147 |
+
## Developer Guidelines
|
148 |
+
|
149 |
+
Please follow these guidelines when contributing to UniCeption:
|
150 |
+
- **Code Style**: Follow the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html) for code style.
|
151 |
+
- **Documentation**: Add docstrings to all classes and methods.
|
152 |
+
- **Unit Tests**: Add necessary unit tests to the `tests` folder.
|
153 |
+
- **Linting**: Run `black` & `isort` on your code before committing. For example, you can run `black . && isort .`.
|
154 |
+
|
155 |
+
Please create a pull request for any changes you make, and ensure that all tests pass before merging.
|
UniCeption/examples/models/cosmos/autoencoding.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import torch
|
5 |
+
from matplotlib import pyplot as plt
|
6 |
+
|
7 |
+
from uniception.models.encoders.base import ViTEncoderInput
|
8 |
+
from uniception.models.encoders.cosmos import CosmosEncoder
|
9 |
+
from uniception.models.prediction_heads.cosmos import CosmosSingleChannel
|
10 |
+
|
11 |
+
base_path = os.path.dirname(os.path.abspath(__file__))
|
12 |
+
|
13 |
+
encoder = CosmosEncoder(
|
14 |
+
name="cosmos",
|
15 |
+
patch_size=8,
|
16 |
+
pretrained_checkpoint_path=os.path.join(
|
17 |
+
base_path, "../../../checkpoints/encoders/cosmos/Cosmos-Tokenizer-CI8x8/encoder.pth"
|
18 |
+
),
|
19 |
+
)
|
20 |
+
|
21 |
+
decoder = CosmosSingleChannel(
|
22 |
+
patch_size=8,
|
23 |
+
pretrained_checkpoint_path=os.path.join(base_path, "../../../checkpoints/prediction_heads/cosmos/decoder_8.pth"),
|
24 |
+
)
|
25 |
+
|
26 |
+
example_image = cv2.imread(os.path.join(base_path, "./example.png"))
|
27 |
+
example_image = cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB)
|
28 |
+
example_tensor = torch.tensor(example_image).permute(2, 0, 1).unsqueeze(0).float() / 255.0
|
29 |
+
example_tensor = example_tensor * 2.0 - 1.0 # Normalize to [-1, 1] according to the COSMOS Encoder
|
30 |
+
|
31 |
+
encoded_latent = encoder(ViTEncoderInput("cosmos", example_tensor)).features
|
32 |
+
|
33 |
+
decoded_image = decoder(encoded_latent)
|
34 |
+
decoded_image = (decoded_image + 1.0) / 2.0 # Denormalize to [0, 1] for visualization
|
35 |
+
|
36 |
+
# plot the original and decoded images
|
37 |
+
plt.figure(figsize=(10, 5))
|
38 |
+
plt.subplot(1, 2, 1)
|
39 |
+
plt.imshow(example_image)
|
40 |
+
plt.title("Original Image")
|
41 |
+
plt.axis("off")
|
42 |
+
|
43 |
+
plt.subplot(1, 2, 2)
|
44 |
+
plt.imshow(decoded_image.squeeze().detach().permute(1, 2, 0).cpu().numpy())
|
45 |
+
plt.title("Decoded Image")
|
46 |
+
plt.axis("off")
|
47 |
+
|
48 |
+
plt.savefig(os.path.join(base_path, "example_decoded.png"))
|
UniCeption/examples/models/cosmos/example.png
ADDED
![]() |
Git LFS Details
|
UniCeption/examples/models/cosmos/example_decoded.png
ADDED
![]() |
Git LFS Details
|
UniCeption/examples/models/dust3r/convert_dust3r_weights_to_uniception.py
ADDED
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This file extracts the cross-attention transformer & prediction head weights from dust3r checkpoints into uniception format.
|
3 |
+
|
4 |
+
Special Notice: dust3r have changed their released weights before/after CVPR, and
|
5 |
+
uniception uses the checkpoint BEFORE CVPR (they perform better). So please make sure you are not converting
|
6 |
+
the newly downloaded weights. Consult Yuchen and Nikhil on where to find the old weights.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import argparse
|
10 |
+
import os
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from torch import nn
|
14 |
+
|
15 |
+
from uniception.models.info_sharing.cross_attention_transformer import MultiViewCrossAttentionTransformerIFR
|
16 |
+
from uniception.models.prediction_heads.dpt import DPTFeature, DPTRegressionProcessor
|
17 |
+
from uniception.models.prediction_heads.linear import LinearFeature
|
18 |
+
|
19 |
+
|
20 |
+
def extract_cross_attention_weights(checkpoint_path, output_folder, output_filename):
|
21 |
+
"Extract the UniCeption format cross attention weights from the original CroCoV2/DUSt3R/MASt3R checkpoints."
|
22 |
+
# Load checkpoint
|
23 |
+
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
|
24 |
+
|
25 |
+
# Filter the relevant keys for the cross attention model and duplicate if necessary
|
26 |
+
filtered_checkpoint = checkpoint["model"]
|
27 |
+
filtered_checkpoint = {k: v for k, v in filtered_checkpoint.items() if "dec" in k}
|
28 |
+
duplicate_checkpoint = {}
|
29 |
+
if not any(k.startswith("dec_blocks2") for k in filtered_checkpoint):
|
30 |
+
print("Duplicating dec_blocks to dec_blocks2")
|
31 |
+
for key, value in filtered_checkpoint.items():
|
32 |
+
if key.startswith("dec_blocks"):
|
33 |
+
duplicate_checkpoint[key.replace("dec_blocks", "dec_blocks2")] = value
|
34 |
+
filtered_checkpoint = {**filtered_checkpoint, **duplicate_checkpoint}
|
35 |
+
new_checkpoint = {}
|
36 |
+
for k, v in filtered_checkpoint.items():
|
37 |
+
if "decoder_embed" in k:
|
38 |
+
new_key = k.replace("decoder_embed", "proj_embed")
|
39 |
+
new_checkpoint[new_key] = v
|
40 |
+
elif "dec_blocks." in k:
|
41 |
+
new_key = k.replace("dec_blocks.", "multi_view_branches.0.")
|
42 |
+
new_checkpoint[new_key] = v
|
43 |
+
elif "dec_blocks2." in k:
|
44 |
+
new_key = k.replace("dec_blocks2.", "multi_view_branches.1.")
|
45 |
+
new_checkpoint[new_key] = v
|
46 |
+
elif "dec_norm" in k:
|
47 |
+
new_key = k.replace("dec_norm", "norm")
|
48 |
+
new_checkpoint[new_key] = v
|
49 |
+
|
50 |
+
# Init model
|
51 |
+
model = MultiViewCrossAttentionTransformerIFR(
|
52 |
+
name="MV-CAT-IFR",
|
53 |
+
input_embed_dim=1024,
|
54 |
+
num_views=2,
|
55 |
+
indices=[5, 8],
|
56 |
+
norm_intermediate=False,
|
57 |
+
)
|
58 |
+
|
59 |
+
# Load new checkpoint
|
60 |
+
print(model.load_state_dict(new_checkpoint))
|
61 |
+
|
62 |
+
# Save the checkpoint
|
63 |
+
save_checkpoint = {}
|
64 |
+
save_checkpoint["model"] = model.state_dict()
|
65 |
+
os.makedirs(os.path.join(output_folder, "cross_attn_transformer"), exist_ok=True)
|
66 |
+
save_path = os.path.join(output_folder, "cross_attn_transformer", output_filename)
|
67 |
+
torch.save(save_checkpoint, save_path)
|
68 |
+
|
69 |
+
|
70 |
+
def extract_dust3r_dpt_checkpoints(checkpoint_path, output_folder, output_filename):
|
71 |
+
"Extract the UniCeption format DPT head weights from the original DUSt3R checkpoint."
|
72 |
+
source_ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
|
73 |
+
|
74 |
+
for head in ["head1", "head2"]:
|
75 |
+
# Extract head weights from the checkpoint
|
76 |
+
dpt_head_weights = {k: v for k, v in source_ckpt["model"].items() if k.startswith(f"downstream_{head}")}
|
77 |
+
dpt_head_weights = {k.replace(f"downstream_{head}.dpt.", ""): v for k, v in dpt_head_weights.items()}
|
78 |
+
dpt_feature_weights = {k: v for k, v in dpt_head_weights.items() if not (k.startswith("head"))}
|
79 |
+
|
80 |
+
# Construct the DPTFeature module and load the weights
|
81 |
+
dpt = DPTFeature(
|
82 |
+
patch_size=16,
|
83 |
+
hooks=[0, 1, 2, 3],
|
84 |
+
input_feature_dims=[1024, 768, 768, 768],
|
85 |
+
layer_dims=[96, 192, 384, 768],
|
86 |
+
feature_dim=256,
|
87 |
+
use_bn=False,
|
88 |
+
output_width_ratio=1,
|
89 |
+
)
|
90 |
+
|
91 |
+
dpt.load_state_dict(dpt_feature_weights, strict=True)
|
92 |
+
|
93 |
+
# Construct the dpt processor module and load the weights
|
94 |
+
dpt_processor_weights = {k.replace("head.", ""): v for k, v in dpt_head_weights.items() if k.startswith("head")}
|
95 |
+
|
96 |
+
# Replace the keys according to:
|
97 |
+
key_replace_dict = {
|
98 |
+
"0.weight": "conv1.weight",
|
99 |
+
"0.bias": "conv1.bias",
|
100 |
+
"2.weight": "conv2.0.weight",
|
101 |
+
"2.bias": "conv2.0.bias",
|
102 |
+
"4.weight": "conv2.2.weight",
|
103 |
+
"4.bias": "conv2.2.bias",
|
104 |
+
}
|
105 |
+
|
106 |
+
dpt_processor_weights = {key_replace_dict.get(k, k): v for k, v in dpt_processor_weights.items()}
|
107 |
+
|
108 |
+
dpt_reg_processor = DPTRegressionProcessor(input_feature_dim=256, output_dim=4, hidden_dims=[128, 128])
|
109 |
+
|
110 |
+
dpt_reg_processor.load_state_dict(dpt_processor_weights, strict=True)
|
111 |
+
|
112 |
+
# Save the state_dicts of the DPTFeature and DPTRegressionProcessor
|
113 |
+
dpt_feature_path = os.path.join(output_folder, "dpt_feature_head", output_filename + f"_feature_{head}.pth")
|
114 |
+
dpt_reg_processor_path = os.path.join(
|
115 |
+
output_folder, "dpt_reg_processor", output_filename + f"_reg_processor{head[-1]}.pth"
|
116 |
+
)
|
117 |
+
|
118 |
+
os.makedirs(os.path.dirname(dpt_feature_path), exist_ok=True)
|
119 |
+
os.makedirs(os.path.dirname(dpt_reg_processor_path), exist_ok=True)
|
120 |
+
|
121 |
+
torch.save({"model": dpt.state_dict()}, dpt_feature_path)
|
122 |
+
torch.save({"model": dpt_reg_processor.state_dict()}, dpt_reg_processor_path)
|
123 |
+
|
124 |
+
|
125 |
+
def extract_dust3r_linear_checkpoints(checkpoint_path, output_folder, output_filename):
|
126 |
+
"Extract the UniCeption format linear head weights from the original DUSt3R checkpoint."
|
127 |
+
test_linear_to_conv()
|
128 |
+
|
129 |
+
source_ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
|
130 |
+
|
131 |
+
for head in ["head1", "head2"]:
|
132 |
+
linear_head_params = {k: v for k, v in source_ckpt["model"].items() if k.startswith(f"downstream_{head}")}
|
133 |
+
linear_head_params = {k.replace(f"downstream_{head}.proj.", ""): v for k, v in linear_head_params.items()}
|
134 |
+
|
135 |
+
assert set(linear_head_params.keys()) == {"weight", "bias"}
|
136 |
+
|
137 |
+
input_feature_dim = 768
|
138 |
+
output_dim = 4
|
139 |
+
patch_size = 16
|
140 |
+
|
141 |
+
linear = nn.Linear(input_feature_dim, output_dim * patch_size * patch_size, bias=True)
|
142 |
+
linear.load_state_dict(linear_head_params, strict=True)
|
143 |
+
|
144 |
+
conv_layer = linear_to_conv2d(linear)
|
145 |
+
|
146 |
+
linear_feature = LinearFeature(input_feature_dim, 4, patch_size)
|
147 |
+
linear_feature.linear.load_state_dict(conv_layer.state_dict(), strict=True)
|
148 |
+
|
149 |
+
linear_feature_path = os.path.join(
|
150 |
+
output_folder, "linear_feature_head", output_filename + f"_feature_{head}.pth"
|
151 |
+
)
|
152 |
+
os.makedirs(os.path.dirname(linear_feature_path), exist_ok=True)
|
153 |
+
torch.save({"model": linear_feature.state_dict()}, linear_feature_path)
|
154 |
+
|
155 |
+
|
156 |
+
def extract_mast3r_dpt_checkpoints(checkpoint_path, output_folder, output_filename):
|
157 |
+
"Extract the UniCeption format DPT head weights from the original MASt3R checkpoint."
|
158 |
+
source_ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
|
159 |
+
|
160 |
+
for head in ["head1", "head2"]:
|
161 |
+
dpt_head = {k: v for k, v in source_ckpt["model"].items() if k.startswith(f"downstream_{head}")}
|
162 |
+
dpt_head = {k.replace(f"downstream_{head}.", ""): v for k, v in dpt_head.items()}
|
163 |
+
dpt_head = {k.replace("dpt.", ""): v for k, v in dpt_head.items()}
|
164 |
+
|
165 |
+
dpt_feature_weights = {
|
166 |
+
k: v for k, v in dpt_head.items() if not (k.startswith("head") or k.startswith("head_local_features"))
|
167 |
+
}
|
168 |
+
|
169 |
+
dpt = DPTFeature(
|
170 |
+
patch_size=16,
|
171 |
+
hooks=[0, 1, 2, 3],
|
172 |
+
input_feature_dims=[1024, 768, 768, 768],
|
173 |
+
layer_dims=[96, 192, 384, 768],
|
174 |
+
feature_dim=256,
|
175 |
+
use_bn=False,
|
176 |
+
output_width_ratio=1,
|
177 |
+
)
|
178 |
+
|
179 |
+
dpt.load_state_dict(dpt_feature_weights, strict=True)
|
180 |
+
|
181 |
+
dpt_processor_weights = {
|
182 |
+
k.replace("head.", ""): v
|
183 |
+
for k, v in dpt_head.items()
|
184 |
+
if (k.startswith("head") and not k.startswith("head_local_features"))
|
185 |
+
}
|
186 |
+
|
187 |
+
# Replace the keys according to:
|
188 |
+
key_replace_dict = {
|
189 |
+
"0.weight": "conv1.weight",
|
190 |
+
"0.bias": "conv1.bias",
|
191 |
+
"2.weight": "conv2.0.weight",
|
192 |
+
"2.bias": "conv2.0.bias",
|
193 |
+
"4.weight": "conv2.2.weight",
|
194 |
+
"4.bias": "conv2.2.bias",
|
195 |
+
}
|
196 |
+
|
197 |
+
dpt_processor_weights = {key_replace_dict.get(k, k): v for k, v in dpt_processor_weights.items()}
|
198 |
+
|
199 |
+
dpt_reg_processor = DPTRegressionProcessor(input_feature_dim=256, output_dim=4, hidden_dims=[128, 128])
|
200 |
+
|
201 |
+
dpt_reg_processor.load_state_dict(dpt_processor_weights, strict=True)
|
202 |
+
|
203 |
+
# Save the state_dicts of the DPTFeature and DPTRegressionProcessor
|
204 |
+
dpt_feature_path = os.path.join(output_folder, "dpt_feature_head", output_filename + f"_feature_{head}.pth")
|
205 |
+
dpt_reg_processor_path = os.path.join(
|
206 |
+
output_folder, "dpt_reg_processor", output_filename + f"_reg_processor{head[-1]}.pth"
|
207 |
+
)
|
208 |
+
|
209 |
+
os.makedirs(os.path.dirname(dpt_feature_path), exist_ok=True)
|
210 |
+
os.makedirs(os.path.dirname(dpt_reg_processor_path), exist_ok=True)
|
211 |
+
|
212 |
+
torch.save({"model": dpt.state_dict()}, dpt_feature_path)
|
213 |
+
torch.save({"model": dpt_reg_processor.state_dict()}, dpt_reg_processor_path)
|
214 |
+
|
215 |
+
|
216 |
+
def linear_to_conv2d(linear_layer):
|
217 |
+
"""
|
218 |
+
Converts a nn.Linear layer to an equivalent nn.Conv2d layer with a 1x1 kernel.
|
219 |
+
|
220 |
+
Parameters:
|
221 |
+
- linear_layer (nn.Linear): The Linear layer to convert.
|
222 |
+
|
223 |
+
Returns:
|
224 |
+
- conv_layer (nn.Conv2d): The equivalent Conv2d layer.
|
225 |
+
"""
|
226 |
+
# Extract in_features and out_features from the Linear layer
|
227 |
+
in_features = linear_layer.in_features
|
228 |
+
out_features = linear_layer.out_features
|
229 |
+
bias = linear_layer.bias is not None
|
230 |
+
|
231 |
+
# Create a Conv2d layer with a 1x1 kernel
|
232 |
+
conv_layer = nn.Conv2d(
|
233 |
+
in_channels=in_features, out_channels=out_features, kernel_size=1, stride=1, padding=0, bias=bias
|
234 |
+
)
|
235 |
+
|
236 |
+
# Reshape Linear weights to match Conv2d weights
|
237 |
+
conv_weight = linear_layer.weight.data.view(out_features, in_features, 1, 1).clone()
|
238 |
+
conv_layer.weight.data = conv_weight
|
239 |
+
|
240 |
+
# Copy bias if it exists
|
241 |
+
if bias:
|
242 |
+
conv_layer.bias.data = linear_layer.bias.data.clone()
|
243 |
+
|
244 |
+
return conv_layer
|
245 |
+
|
246 |
+
|
247 |
+
def test_linear_to_conv():
|
248 |
+
"Test the linear_to_conv2d function."
|
249 |
+
batch_size = 4
|
250 |
+
height = 16
|
251 |
+
width = 24
|
252 |
+
in_channels = 3
|
253 |
+
out_channels = 5
|
254 |
+
|
255 |
+
# Sample input tensor in BHWC format
|
256 |
+
x_linear = torch.randn(batch_size, height, width, in_channels)
|
257 |
+
|
258 |
+
# Define Linear layer
|
259 |
+
linear_layer = nn.Linear(in_channels, out_channels)
|
260 |
+
output_linear = linear_layer(x_linear)
|
261 |
+
|
262 |
+
# Transpose input tensor to BCHW format for Conv2d
|
263 |
+
x_conv = x_linear.permute(0, 3, 1, 2)
|
264 |
+
|
265 |
+
# Define Conv2d layer
|
266 |
+
conv_layer = linear_to_conv2d(linear_layer)
|
267 |
+
|
268 |
+
# Get Conv2d output and transpose back to BHWC format
|
269 |
+
output_conv = conv_layer(x_conv).permute(0, 2, 3, 1)
|
270 |
+
|
271 |
+
# Verify that outputs are the same
|
272 |
+
assert torch.allclose(output_linear, output_conv, atol=1e-6)
|
273 |
+
|
274 |
+
|
275 |
+
if __name__ == "__main__":
|
276 |
+
parser = argparse.ArgumentParser(description="Extract dust3r checkpoints to uniception format")
|
277 |
+
|
278 |
+
parser.add_argument(
|
279 |
+
"-dcf", "--dust3r_checkpoints_folder", type=str, required=True, help="Path to the dust3r checkpoints folder"
|
280 |
+
)
|
281 |
+
parser.add_argument("-of", "--output_folder", type=str, required=True, help="Path to the output folder")
|
282 |
+
|
283 |
+
args = parser.parse_args()
|
284 |
+
|
285 |
+
output_folder = args.output_folder
|
286 |
+
info_sharing_output_folder = os.path.join(output_folder, "info_sharing")
|
287 |
+
pred_head_output_folder = os.path.join(output_folder, "prediction_heads")
|
288 |
+
os.makedirs(output_folder, exist_ok=True)
|
289 |
+
os.makedirs(info_sharing_output_folder, exist_ok=True)
|
290 |
+
os.makedirs(pred_head_output_folder, exist_ok=True)
|
291 |
+
|
292 |
+
# Extract croco checkpoint
|
293 |
+
print("Extracting CroCo checkpoint...")
|
294 |
+
croco_ckpt_filepath = os.path.join(args.dust3r_checkpoints_folder, "CroCo_V2_ViTLarge_BaseDecoder.pth")
|
295 |
+
extract_cross_attention_weights(
|
296 |
+
croco_ckpt_filepath, info_sharing_output_folder, "Two_View_Cross_Attention_Transformer_CroCo.pth"
|
297 |
+
)
|
298 |
+
|
299 |
+
# Extract dust3r 224 linear checkpoint
|
300 |
+
print("Extracting DUSt3R 224 linear checkpoint...")
|
301 |
+
dust3r_ckpt_filepath = os.path.join(args.dust3r_checkpoints_folder, "DUSt3R_ViTLarge_BaseDecoder_224_linear.pth")
|
302 |
+
extract_cross_attention_weights(
|
303 |
+
dust3r_ckpt_filepath, info_sharing_output_folder, "Two_View_Cross_Attention_Transformer_DUSt3R_224_linear.pth"
|
304 |
+
)
|
305 |
+
extract_dust3r_linear_checkpoints(dust3r_ckpt_filepath, pred_head_output_folder, "DUSt3R_224_linear")
|
306 |
+
|
307 |
+
# Extract dust3r 512 linear checkpoint
|
308 |
+
print("Extracting DUSt3R 512 linear checkpoint...")
|
309 |
+
dust3r_ckpt_filepath = os.path.join(args.dust3r_checkpoints_folder, "DUSt3R_ViTLarge_BaseDecoder_512_linear.pth")
|
310 |
+
extract_cross_attention_weights(
|
311 |
+
dust3r_ckpt_filepath, info_sharing_output_folder, "Two_View_Cross_Attention_Transformer_DUSt3R_512_linear.pth"
|
312 |
+
)
|
313 |
+
extract_dust3r_linear_checkpoints(dust3r_ckpt_filepath, pred_head_output_folder, "DUSt3R_512_linear")
|
314 |
+
|
315 |
+
# Extract dust3r 512 dpt checkpoint
|
316 |
+
print("Extracting DUSt3R 512 dpt checkpoint...")
|
317 |
+
dust3r_ckpt_filepath = os.path.join(args.dust3r_checkpoints_folder, "DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth")
|
318 |
+
extract_cross_attention_weights(
|
319 |
+
dust3r_ckpt_filepath, info_sharing_output_folder, "Two_View_Cross_Attention_Transformer_DUSt3R_512_dpt.pth"
|
320 |
+
)
|
321 |
+
extract_dust3r_dpt_checkpoints(dust3r_ckpt_filepath, pred_head_output_folder, "DUSt3R_512_dpt")
|
322 |
+
|
323 |
+
# Extract mast3r 512 dpt checkpoint
|
324 |
+
print("Extracting MASt3R 512 dpt checkpoint...")
|
325 |
+
mast3r_ckpt_path = os.path.join(
|
326 |
+
args.dust3r_checkpoints_folder, "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth"
|
327 |
+
)
|
328 |
+
extract_cross_attention_weights(
|
329 |
+
mast3r_ckpt_path, info_sharing_output_folder, "Two_View_Cross_Attention_Transformer_MASt3R_512_dpt.pth"
|
330 |
+
)
|
331 |
+
extract_mast3r_dpt_checkpoints(mast3r_ckpt_path, pred_head_output_folder, "MASt3R_512_dpt")
|
UniCeption/examples/models/dust3r/dust3r.py
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Initalizing Pre-trained DUSt3R using UniCeption
|
3 |
+
"""
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
import os
|
7 |
+
from io import BytesIO
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import requests
|
11 |
+
import rerun as rr
|
12 |
+
import torch
|
13 |
+
from PIL import Image
|
14 |
+
|
15 |
+
from uniception.models.factory import DUSt3R
|
16 |
+
from uniception.utils.viz import script_add_rerun_args
|
17 |
+
|
18 |
+
|
19 |
+
def get_model_configurations_and_checkpoints():
|
20 |
+
"""
|
21 |
+
Get different DUSt3R model configurations and paths to refactored checkpoints.
|
22 |
+
|
23 |
+
Returns:
|
24 |
+
Tuple[List[str], dict]: A tuple containing the model configurations and paths to refactored checkpoints.
|
25 |
+
"""
|
26 |
+
# Initialize model configurations
|
27 |
+
model_configurations = ["dust3r_224_linear", "dust3r_512_linear", "dust3r_512_dpt", "dust3r_512_dpt_mast3r"]
|
28 |
+
|
29 |
+
# Get paths to pretrained checkpoints
|
30 |
+
current_file_path = os.path.abspath(__file__)
|
31 |
+
relative_checkpoint_path = os.path.join(os.path.dirname(current_file_path), "../../../checkpoints")
|
32 |
+
|
33 |
+
# Initialize model configurations
|
34 |
+
model_to_checkpoint_path = {
|
35 |
+
"dust3r_512_dpt": {
|
36 |
+
"encoder": f"{relative_checkpoint_path}/encoders/CroCo_Encoder_512_DUSt3R_dpt.pth",
|
37 |
+
"info_sharing": f"{relative_checkpoint_path}/info_sharing/cross_attn_transformer/Two_View_Cross_Attention_Transformer_DUSt3R_512_dpt.pth",
|
38 |
+
"feature_head": [
|
39 |
+
f"{relative_checkpoint_path}/prediction_heads/dpt_feature_head/DUSt3R_512_dpt_feature_head1.pth",
|
40 |
+
f"{relative_checkpoint_path}/prediction_heads/dpt_feature_head/DUSt3R_512_dpt_feature_head2.pth",
|
41 |
+
],
|
42 |
+
"regressor": [
|
43 |
+
f"{relative_checkpoint_path}/prediction_heads/dpt_reg_processor/DUSt3R_512_dpt_reg_processor1.pth",
|
44 |
+
f"{relative_checkpoint_path}/prediction_heads/dpt_reg_processor/DUSt3R_512_dpt_reg_processor2.pth",
|
45 |
+
],
|
46 |
+
"ckpt_path": f"{relative_checkpoint_path}/examples/original_dust3r/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth",
|
47 |
+
},
|
48 |
+
"dust3r_512_dpt_mast3r": {
|
49 |
+
"encoder": f"{relative_checkpoint_path}/encoders/CroCo_Encoder_512_MASt3R.pth",
|
50 |
+
"info_sharing": f"{relative_checkpoint_path}/info_sharing/cross_attn_transformer/Two_View_Cross_Attention_Transformer_MASt3R_512_dpt.pth",
|
51 |
+
"feature_head": [
|
52 |
+
f"{relative_checkpoint_path}/prediction_heads/dpt_feature_head/MASt3R_512_dpt_feature_head1.pth",
|
53 |
+
f"{relative_checkpoint_path}/prediction_heads/dpt_feature_head/MASt3R_512_dpt_feature_head2.pth",
|
54 |
+
],
|
55 |
+
"regressor": [
|
56 |
+
f"{relative_checkpoint_path}/prediction_heads/dpt_reg_processor/MASt3R_512_dpt_reg_processor1.pth",
|
57 |
+
f"{relative_checkpoint_path}/prediction_heads/dpt_reg_processor/MASt3R_512_dpt_reg_processor2.pth",
|
58 |
+
],
|
59 |
+
"ckpt_path": f"{relative_checkpoint_path}/examples/original_dust3r/DUSt3R_ViTLarge_BaseDecoder_512_dpt_mast3r.pth",
|
60 |
+
},
|
61 |
+
"dust3r_512_linear": {
|
62 |
+
"encoder": f"{relative_checkpoint_path}/encoders/CroCo_Encoder_512_DUSt3R_linear.pth",
|
63 |
+
"info_sharing": f"{relative_checkpoint_path}/info_sharing/cross_attn_transformer/Two_View_Cross_Attention_Transformer_DUSt3R_512_linear.pth",
|
64 |
+
"feature_head": [
|
65 |
+
f"{relative_checkpoint_path}/prediction_heads/linear_feature_head/DUSt3R_512_linear_feature_head1.pth",
|
66 |
+
f"{relative_checkpoint_path}/prediction_heads/linear_feature_head/DUSt3R_512_linear_feature_head2.pth",
|
67 |
+
],
|
68 |
+
"regressor": None,
|
69 |
+
"ckpt_path": f"{relative_checkpoint_path}/examples/original_dust3r/DUSt3R_ViTLarge_BaseDecoder_512_linear.pth",
|
70 |
+
},
|
71 |
+
"dust3r_224_linear": {
|
72 |
+
"encoder": f"{relative_checkpoint_path}/encoders/CroCo_Encoder_224_DUSt3R_linear.pth",
|
73 |
+
"info_sharing": f"{relative_checkpoint_path}/info_sharing/cross_attn_transformer/Two_View_Cross_Attention_Transformer_DUSt3R_224_linear.pth",
|
74 |
+
"feature_head": [
|
75 |
+
f"{relative_checkpoint_path}/prediction_heads/linear_feature_head/DUSt3R_224_linear_feature_head1.pth",
|
76 |
+
f"{relative_checkpoint_path}/prediction_heads/linear_feature_head/DUSt3R_224_linear_feature_head2.pth",
|
77 |
+
],
|
78 |
+
"regressor": None,
|
79 |
+
"ckpt_path": f"{relative_checkpoint_path}/examples/original_dust3r/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth",
|
80 |
+
},
|
81 |
+
}
|
82 |
+
return model_configurations, model_to_checkpoint_path
|
83 |
+
|
84 |
+
|
85 |
+
def get_parser():
|
86 |
+
"Argument parser for the script."
|
87 |
+
parser = argparse.ArgumentParser()
|
88 |
+
parser.add_argument("--viz", action="store_true")
|
89 |
+
|
90 |
+
return parser
|
91 |
+
|
92 |
+
|
93 |
+
if __name__ == "__main__":
|
94 |
+
# Parse arguments
|
95 |
+
parser = get_parser()
|
96 |
+
script_add_rerun_args(parser) # Options: --addr
|
97 |
+
args = parser.parse_args()
|
98 |
+
|
99 |
+
# Set up Rerun for visualization
|
100 |
+
if args.viz:
|
101 |
+
rr.script_setup(args, f"UniCeption_DUSt3R_Inference")
|
102 |
+
rr.set_time("stable_time", sequence=0)
|
103 |
+
|
104 |
+
# the reference data are collected under this setting.
|
105 |
+
# may use (False, "high") to test the relative error at TF32 precision
|
106 |
+
torch.backends.cuda.matmul.allow_tf32 = False
|
107 |
+
torch.set_float32_matmul_precision("highest")
|
108 |
+
|
109 |
+
# Get paths to pretrained checkpoints
|
110 |
+
current_file_path = os.path.abspath(__file__)
|
111 |
+
relative_checkpoint_path = os.path.join(os.path.dirname(current_file_path), "../../../checkpoints")
|
112 |
+
model_configurations, model_to_checkpoint_path = get_model_configurations_and_checkpoints()
|
113 |
+
|
114 |
+
MODEL_TO_VERIFICATION_PATH = {
|
115 |
+
"dust3r_512_dpt": {
|
116 |
+
"head_output": os.path.join(
|
117 |
+
os.path.dirname(current_file_path),
|
118 |
+
"../../../reference_data/dust3r_pre_cvpr",
|
119 |
+
"DUSt3R_512_dpt",
|
120 |
+
"03_head_output.npz",
|
121 |
+
)
|
122 |
+
},
|
123 |
+
"dust3r_512_dpt_mast3r": {
|
124 |
+
"head_output": os.path.join(
|
125 |
+
os.path.dirname(current_file_path),
|
126 |
+
"../../../reference_data/dust3r_pre_cvpr",
|
127 |
+
"MASt3R_512_dpt",
|
128 |
+
"03_head_output.npz",
|
129 |
+
)
|
130 |
+
},
|
131 |
+
"dust3r_512_linear": {
|
132 |
+
"head_output": os.path.join(
|
133 |
+
os.path.dirname(current_file_path),
|
134 |
+
"../../../reference_data/dust3r_pre_cvpr",
|
135 |
+
"DUSt3R_512_linear",
|
136 |
+
"03_head_output.npz",
|
137 |
+
)
|
138 |
+
},
|
139 |
+
"dust3r_224_linear": {
|
140 |
+
"head_output": os.path.join(
|
141 |
+
os.path.dirname(current_file_path),
|
142 |
+
"../../../reference_data/dust3r_pre_cvpr",
|
143 |
+
"DUSt3R_224_linear",
|
144 |
+
"03_head_output.npz",
|
145 |
+
)
|
146 |
+
},
|
147 |
+
}
|
148 |
+
|
149 |
+
# Test different DUSt3R models using UniCeption modules
|
150 |
+
for model_name in model_configurations:
|
151 |
+
dust3r_model = DUSt3R(
|
152 |
+
name=model_name,
|
153 |
+
img_size=(512, 512) if "512" in model_name else (224, 224),
|
154 |
+
patch_embed_cls="PatchEmbedDust3R",
|
155 |
+
pred_head_type="linear" if "linear" in model_name else "dpt",
|
156 |
+
pretrained_checkpoint_path=model_to_checkpoint_path[model_name]["ckpt_path"],
|
157 |
+
# pretrained_encoder_checkpoint_path=model_to_checkpoint_path[model_name]["encoder"],
|
158 |
+
# pretrained_info_sharing_checkpoint_path=model_to_checkpoint_path[model_name]["info_sharing"],
|
159 |
+
# pretrained_pred_head_checkpoint_paths=model_to_checkpoint_path[model_name]["feature_head"],
|
160 |
+
# pretrained_pred_head_regressor_checkpoint_paths=model_to_checkpoint_path[model_name]["regressor"],
|
161 |
+
# override_encoder_checkpoint_attributes=True,
|
162 |
+
)
|
163 |
+
print("DUSt3R model initialized successfully!")
|
164 |
+
|
165 |
+
# Initalize device
|
166 |
+
if torch.cuda.is_available():
|
167 |
+
device = "cuda:0"
|
168 |
+
else:
|
169 |
+
device = "cpu"
|
170 |
+
dust3r_model.to(device)
|
171 |
+
|
172 |
+
# Initalize two example images
|
173 |
+
img0_url = (
|
174 |
+
"https://raw.githubusercontent.com/naver/croco/d3d0ab2858d44bcad54e5bfc24f565983fbe18d9/assets/Chateau1.png"
|
175 |
+
)
|
176 |
+
img1_url = (
|
177 |
+
"https://raw.githubusercontent.com/naver/croco/d3d0ab2858d44bcad54e5bfc24f565983fbe18d9/assets/Chateau2.png"
|
178 |
+
)
|
179 |
+
response = requests.get(img0_url)
|
180 |
+
img0 = Image.open(BytesIO(response.content))
|
181 |
+
response = requests.get(img1_url)
|
182 |
+
img1 = Image.open(BytesIO(response.content))
|
183 |
+
img0_tensor = torch.from_numpy(np.array(img0))[..., :3].permute(2, 0, 1).unsqueeze(0).float() / 255
|
184 |
+
img1_tensor = torch.from_numpy(np.array(img1))[..., :3].permute(2, 0, 1).unsqueeze(0).float() / 255
|
185 |
+
|
186 |
+
# Normalize images according to DUSt3R's normalization
|
187 |
+
img0_tensor = (img0_tensor - 0.5) / 0.5
|
188 |
+
img1_tensor = (img1_tensor - 0.5) / 0.5
|
189 |
+
img_tensor = torch.cat((img0_tensor, img1_tensor), dim=0).to(device)
|
190 |
+
|
191 |
+
# Run a forward pass
|
192 |
+
view1 = {"img": img_tensor, "instance": [0, 1], "data_norm_type": "dust3r"}
|
193 |
+
view2 = {"img": view1["img"][[1, 0]].clone().to(device), "instance": [1, 0], "data_norm_type": "dust3r"}
|
194 |
+
|
195 |
+
res1, res2 = dust3r_model(view1, view2)
|
196 |
+
print("Forward pass completed successfully!")
|
197 |
+
|
198 |
+
# Automatically test the results against the reference result from vanilla dust3r code if they exist
|
199 |
+
reference_output_path = MODEL_TO_VERIFICATION_PATH[model_name]["head_output"]
|
200 |
+
if os.path.exists(reference_output_path):
|
201 |
+
reference_output_data = np.load(reference_output_path)
|
202 |
+
|
203 |
+
# Check against the reference output
|
204 |
+
check_dict = {
|
205 |
+
"head1_pts3d": (
|
206 |
+
res1["pts3d"].detach().cpu().numpy(),
|
207 |
+
reference_output_data["head1_pts3d"],
|
208 |
+
),
|
209 |
+
"head2_pts3d": (
|
210 |
+
res2["pts3d_in_other_view"].detach().cpu().numpy(),
|
211 |
+
reference_output_data["head2_pts3d"],
|
212 |
+
),
|
213 |
+
"head1_conf": (
|
214 |
+
res1["conf"].detach().squeeze(-1).cpu().numpy(),
|
215 |
+
reference_output_data["head1_conf"],
|
216 |
+
),
|
217 |
+
"head2_conf": (
|
218 |
+
res2["conf"].detach().squeeze(-1).cpu().numpy(),
|
219 |
+
reference_output_data["head2_conf"],
|
220 |
+
),
|
221 |
+
}
|
222 |
+
|
223 |
+
compute_abs_and_rel_error = lambda x, y: (np.abs(x - y).max(), np.linalg.norm(x - y) / np.linalg.norm(x))
|
224 |
+
|
225 |
+
print(f"===== Checking for {model_name} model =====")
|
226 |
+
for key, (output, reference) in check_dict.items():
|
227 |
+
abs_error, rel_error = compute_abs_and_rel_error(output, reference)
|
228 |
+
print(f"{key} abs_error: {abs_error}, rel_error: {rel_error}")
|
229 |
+
|
230 |
+
assert abs_error < 1e-2 and rel_error < 1e-3, f"Error in {key} output"
|
231 |
+
|
232 |
+
points1 = res1["pts3d"][0].detach().cpu().numpy()
|
233 |
+
points2 = res2["pts3d_in_other_view"][0].detach().cpu().numpy()
|
234 |
+
conf_mask1 = res1["conf"][0].squeeze(-1).detach().cpu().numpy() > 3.0
|
235 |
+
conf_mask2 = res2["conf"][0].squeeze(-1).detach().cpu().numpy() > 3.0
|
236 |
+
|
237 |
+
if args.viz:
|
238 |
+
rr.log(f"{model_name}", rr.ViewCoordinates.RDF, static=True)
|
239 |
+
filtered_pts3d1 = points1[conf_mask1]
|
240 |
+
filtered_pts3d1_colors = np.array(img0)[..., :3][conf_mask1] / 255
|
241 |
+
filtered_pts3d2 = points2[conf_mask2]
|
242 |
+
filtered_pts3d2_colors = np.array(img1)[..., :3][conf_mask2] / 255
|
243 |
+
rr.log(
|
244 |
+
f"{model_name}/view1",
|
245 |
+
rr.Points3D(
|
246 |
+
positions=filtered_pts3d1.reshape(-1, 3),
|
247 |
+
colors=filtered_pts3d1_colors.reshape(-1, 3),
|
248 |
+
),
|
249 |
+
)
|
250 |
+
rr.log(
|
251 |
+
f"{model_name}/view2",
|
252 |
+
rr.Points3D(
|
253 |
+
positions=filtered_pts3d2.reshape(-1, 3),
|
254 |
+
colors=filtered_pts3d2_colors.reshape(-1, 3),
|
255 |
+
),
|
256 |
+
)
|
257 |
+
print(
|
258 |
+
"Visualizations logged to Rerun: rerun+http://127.0.0.1:<rr-port>/proxy."
|
259 |
+
"For example, to spawn viewer: rerun --connect rerun+http://127.0.0.1:<rr-port>/proxy"
|
260 |
+
"Replace <rr-port> with the actual port."
|
261 |
+
)
|
UniCeption/examples/models/dust3r/profile_dust3r.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from dust3r import get_model_configurations_and_checkpoints
|
3 |
+
|
4 |
+
from uniception.models.factory import DUSt3R
|
5 |
+
from uniception.utils.profile import benchmark_torch_function
|
6 |
+
|
7 |
+
if __name__ == "__main__":
|
8 |
+
# Get model configurations and checkpoints
|
9 |
+
model_configurations, model_to_checkpoint_path = get_model_configurations_and_checkpoints()
|
10 |
+
|
11 |
+
# Test different DUSt3R models using UniCeption modules
|
12 |
+
for model_name in model_configurations:
|
13 |
+
dust3r_model = DUSt3R(
|
14 |
+
name=model_name,
|
15 |
+
img_size=(512, 512) if "512" in model_name else (224, 224),
|
16 |
+
patch_embed_cls="PatchEmbedDust3R",
|
17 |
+
pred_head_type="linear" if "linear" in model_name else "dpt",
|
18 |
+
pretrained_checkpoint_path=model_to_checkpoint_path[model_name]["ckpt_path"],
|
19 |
+
)
|
20 |
+
print(f"DUSt3R model ({model_name}) initialized successfully!")
|
21 |
+
|
22 |
+
# Initialize device
|
23 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
24 |
+
dust3r_model.to(device)
|
25 |
+
print(f"Running on {device}")
|
26 |
+
|
27 |
+
# Generate random input tensors
|
28 |
+
img_size = (512, 512) if "512" in model_name else (224, 224)
|
29 |
+
batch_sizes = [1, 2, 4, 8]
|
30 |
+
|
31 |
+
for batch_size in batch_sizes:
|
32 |
+
# Prepare input views
|
33 |
+
view1_instances = range(batch_size)
|
34 |
+
view1_img_tensor = torch.randn(batch_size, 3, *img_size).to(device)
|
35 |
+
view1 = {"img": view1_img_tensor, "instance": view1_instances, "data_norm_type": "dust3r"}
|
36 |
+
view2_instances = range(batch_size)
|
37 |
+
view2_instances = [id + batch_size for id in view2_instances]
|
38 |
+
view2_img_tensor = torch.randn(batch_size, 3, *img_size).to(device)
|
39 |
+
view2 = {"img": view2_img_tensor, "instance": view2_instances, "data_norm_type": "dust3r"}
|
40 |
+
|
41 |
+
# Benchmark the forward pass of the model
|
42 |
+
with torch.no_grad():
|
43 |
+
with torch.autocast("cuda", enabled=True):
|
44 |
+
execution_time = benchmark_torch_function(dust3r_model, view1, view2)
|
45 |
+
print(
|
46 |
+
f"\033[92mForward pass for {model_name}, batch size : {batch_size} completed in {execution_time:.3f} milliseconds\033[0m"
|
47 |
+
)
|
UniCeption/pyproject.toml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[tool.black]
|
2 |
+
line-length = 120
|
3 |
+
include = '\.pyi?$'
|
4 |
+
exclude = '''
|
5 |
+
/(
|
6 |
+
\.git
|
7 |
+
| \.hg
|
8 |
+
| \.mypy_cache
|
9 |
+
| \.tox
|
10 |
+
| \.venv
|
11 |
+
| _build
|
12 |
+
| buck-out
|
13 |
+
| build
|
14 |
+
| cuda
|
15 |
+
| dist
|
16 |
+
)/
|
17 |
+
'''
|
18 |
+
|
19 |
+
[tool.isort]
|
20 |
+
profile = "black"
|
21 |
+
line_length = 120
|
UniCeption/scripts/check_dependencies.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Console script to check UniCeption dependencies.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import sys
|
7 |
+
from pathlib import Path
|
8 |
+
|
9 |
+
# Add the parent directory to the path to import uniception
|
10 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
11 |
+
|
12 |
+
|
13 |
+
def check_dependencies():
|
14 |
+
"""Check if optional dependencies are available."""
|
15 |
+
try:
|
16 |
+
import torch
|
17 |
+
|
18 |
+
print(f"PyTorch version: {torch.__version__}")
|
19 |
+
if torch.cuda.is_available():
|
20 |
+
print(f"CUDA available: {torch.version.cuda}")
|
21 |
+
else:
|
22 |
+
print("CUDA not available")
|
23 |
+
except ImportError:
|
24 |
+
print("PyTorch not installed")
|
25 |
+
|
26 |
+
try:
|
27 |
+
import xformers
|
28 |
+
|
29 |
+
print(f"XFormers version: {xformers.__version__}")
|
30 |
+
except ImportError:
|
31 |
+
print("XFormers not installed")
|
32 |
+
|
33 |
+
try:
|
34 |
+
from uniception.models.libs.croco.curope import cuRoPE2D
|
35 |
+
|
36 |
+
print("CroCo RoPE extension available")
|
37 |
+
except ImportError:
|
38 |
+
print("CroCo RoPE extension not available")
|
39 |
+
|
40 |
+
|
41 |
+
def main():
|
42 |
+
"""Main entry point for the check dependencies script."""
|
43 |
+
print("Checking UniCeption Dependencies...")
|
44 |
+
print("=" * 40)
|
45 |
+
check_dependencies()
|
46 |
+
|
47 |
+
|
48 |
+
if __name__ == "__main__":
|
49 |
+
main()
|
UniCeption/scripts/download_checkpoints.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"Download the UniCeption format checkpoints from the AirLab Data Server"
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
|
6 |
+
from minio import Minio
|
7 |
+
from minio.error import S3Error
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
|
11 |
+
def main():
|
12 |
+
parser = argparse.ArgumentParser(description="Download UniCeption format checkpoints from AirLab Data Server")
|
13 |
+
parser.add_argument(
|
14 |
+
"--folders",
|
15 |
+
nargs="+",
|
16 |
+
default=["encoders", "info_sharing", "prediction_heads", "examples"],
|
17 |
+
help="List of folders to download (default: all folders). Choices: encoders, info_sharing, prediction_heads, examples",
|
18 |
+
)
|
19 |
+
parser.add_argument("--destination", type=str, default="./", help="Destination folder for downloaded checkpoints")
|
20 |
+
args = parser.parse_args()
|
21 |
+
|
22 |
+
access_key = "bT79gQYtfhpxFIitlpns"
|
23 |
+
secret_key = "g7mSvUJ5k2a9mKv9IbhwXmUQjQX52MLwulhW9ONO"
|
24 |
+
client = Minio("airlab-share-02.andrew.cmu.edu:9000", access_key=access_key, secret_key=secret_key, secure=True)
|
25 |
+
|
26 |
+
bucket_name = "uniception"
|
27 |
+
|
28 |
+
def download_folder(folder_name, bucket_name, client, destination_folder):
|
29 |
+
folder_name = f"checkpoints/{folder_name}/"
|
30 |
+
objects = client.list_objects(bucket_name, prefix=folder_name, recursive=True)
|
31 |
+
for obj in tqdm(objects, desc=f"Downloading {folder_name}"):
|
32 |
+
destination_file = os.path.join(destination_folder, obj.object_name)
|
33 |
+
if not os.path.exists(destination_file):
|
34 |
+
os.makedirs(os.path.dirname(destination_file), exist_ok=True)
|
35 |
+
try:
|
36 |
+
client.fget_object(bucket_name, obj.object_name, destination_file)
|
37 |
+
print(f"Downloaded {obj.object_name} to {destination_file}")
|
38 |
+
except S3Error as e:
|
39 |
+
print(f"Error downloading {obj.object_name}: {e}")
|
40 |
+
else:
|
41 |
+
print(f"File {destination_file} already exists. Skipping...")
|
42 |
+
|
43 |
+
for folder in args.folders:
|
44 |
+
download_folder(folder, bucket_name, client, args.destination)
|
45 |
+
|
46 |
+
|
47 |
+
if __name__ == "__main__":
|
48 |
+
main()
|
UniCeption/scripts/install_croco_rope.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Console script to install CroCo RoPE extension.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import os
|
7 |
+
import subprocess
|
8 |
+
import sys
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
|
12 |
+
def install_croco_rope():
|
13 |
+
"""Install CroCo RoPE extension."""
|
14 |
+
try:
|
15 |
+
# Find the project root (where setup.py is located)
|
16 |
+
script_dir = Path(__file__).parent
|
17 |
+
project_root = script_dir.parent
|
18 |
+
curope_path = project_root / "uniception" / "models" / "libs" / "croco" / "curope"
|
19 |
+
|
20 |
+
if curope_path.exists():
|
21 |
+
print("Installing CroCo RoPE extension...")
|
22 |
+
original_cwd = os.getcwd()
|
23 |
+
try:
|
24 |
+
os.chdir(curope_path)
|
25 |
+
subprocess.check_call([sys.executable, "setup.py", "build_ext", "--inplace"])
|
26 |
+
print("CroCo RoPE extension installed successfully!")
|
27 |
+
return True
|
28 |
+
except subprocess.CalledProcessError as e:
|
29 |
+
print(f"Warning: Failed to install CroCo RoPE extension: {e}")
|
30 |
+
print("You can install it later by running:")
|
31 |
+
print(f"cd {curope_path} && python setup.py build_ext --inplace")
|
32 |
+
return False
|
33 |
+
finally:
|
34 |
+
os.chdir(original_cwd)
|
35 |
+
else:
|
36 |
+
print("Warning: CroCo RoPE source code not found.")
|
37 |
+
print(f"Expected location: {curope_path}")
|
38 |
+
return False
|
39 |
+
except Exception as e:
|
40 |
+
print(f"Warning: Error during CroCo RoPE installation: {e}")
|
41 |
+
return False
|
42 |
+
|
43 |
+
|
44 |
+
def main():
|
45 |
+
"""Main entry point for the CroCo RoPE installation script."""
|
46 |
+
print("UniCeption CroCo RoPE Extension Installer")
|
47 |
+
print("=" * 45)
|
48 |
+
|
49 |
+
success = install_croco_rope()
|
50 |
+
|
51 |
+
if success:
|
52 |
+
print("\n✓ CroCo RoPE extension installation completed successfully!")
|
53 |
+
sys.exit(0)
|
54 |
+
else:
|
55 |
+
print("\n⚠ CroCo RoPE extension installation failed or skipped.")
|
56 |
+
print("This is typically due to missing CUDA development tools.")
|
57 |
+
print("The extension is optional and UniCeption will work without it.")
|
58 |
+
sys.exit(1)
|
59 |
+
|
60 |
+
|
61 |
+
if __name__ == "__main__":
|
62 |
+
main()
|
UniCeption/scripts/prepare_offline_install.py
ADDED
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Script to prepare dependencies for offline installation.
|
4 |
+
|
5 |
+
This script downloads all necessary wheel files for offline installation
|
6 |
+
of UniCeption in environments without internet access.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import argparse
|
10 |
+
import os
|
11 |
+
import subprocess
|
12 |
+
import sys
|
13 |
+
from pathlib import Path
|
14 |
+
|
15 |
+
|
16 |
+
def download_wheels(output_dir: Path, extras: list = None):
|
17 |
+
"""Download wheel files for offline installation."""
|
18 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
19 |
+
|
20 |
+
# Create temporary requirements files
|
21 |
+
temp_dir = output_dir / "temp"
|
22 |
+
temp_dir.mkdir(exist_ok=True)
|
23 |
+
|
24 |
+
try:
|
25 |
+
# Create requirements files
|
26 |
+
create_requirements_files(temp_dir, extras)
|
27 |
+
|
28 |
+
# Download base dependencies
|
29 |
+
base_cmd = [
|
30 |
+
sys.executable,
|
31 |
+
"-m",
|
32 |
+
"pip",
|
33 |
+
"download",
|
34 |
+
"--dest",
|
35 |
+
str(output_dir),
|
36 |
+
"-r",
|
37 |
+
str(temp_dir / "requirements-base.txt"),
|
38 |
+
]
|
39 |
+
|
40 |
+
print(f"Downloading base dependencies to {output_dir}...")
|
41 |
+
subprocess.check_call(base_cmd)
|
42 |
+
|
43 |
+
# Download optional dependencies if requested
|
44 |
+
if extras:
|
45 |
+
for extra in extras:
|
46 |
+
if extra == "all":
|
47 |
+
# Download all extras
|
48 |
+
for req_file in ["requirements-xformers.txt", "requirements-dev.txt"]:
|
49 |
+
if (temp_dir / req_file).exists():
|
50 |
+
cmd = [
|
51 |
+
sys.executable,
|
52 |
+
"-m",
|
53 |
+
"pip",
|
54 |
+
"download",
|
55 |
+
"--dest",
|
56 |
+
str(output_dir),
|
57 |
+
"-r",
|
58 |
+
str(temp_dir / req_file),
|
59 |
+
]
|
60 |
+
print(
|
61 |
+
f"Downloading {req_file.replace('requirements-', '').replace('.txt', '')} dependencies..."
|
62 |
+
)
|
63 |
+
try:
|
64 |
+
subprocess.check_call(cmd)
|
65 |
+
except subprocess.CalledProcessError as e:
|
66 |
+
print(f"Warning: Failed to download {extra} dependencies: {e}")
|
67 |
+
else:
|
68 |
+
req_file = temp_dir / f"requirements-{extra}.txt"
|
69 |
+
if req_file.exists():
|
70 |
+
cmd = [sys.executable, "-m", "pip", "download", "--dest", str(output_dir), "-r", str(req_file)]
|
71 |
+
print(f"Downloading {extra} dependencies...")
|
72 |
+
try:
|
73 |
+
subprocess.check_call(cmd)
|
74 |
+
except subprocess.CalledProcessError as e:
|
75 |
+
print(f"Warning: Failed to download {extra} dependencies: {e}")
|
76 |
+
|
77 |
+
# Create final offline installation files
|
78 |
+
create_offline_installation_files(output_dir)
|
79 |
+
|
80 |
+
print("Download completed successfully!")
|
81 |
+
|
82 |
+
except subprocess.CalledProcessError as e:
|
83 |
+
print(f"Error downloading wheels: {e}")
|
84 |
+
sys.exit(1)
|
85 |
+
finally:
|
86 |
+
# Clean up temporary files
|
87 |
+
import shutil
|
88 |
+
|
89 |
+
if temp_dir.exists():
|
90 |
+
shutil.rmtree(temp_dir)
|
91 |
+
|
92 |
+
|
93 |
+
def create_requirements_files(temp_dir: Path, extras: list = None):
|
94 |
+
"""Create temporary requirements files for downloading."""
|
95 |
+
|
96 |
+
# Base requirements (including PyTorch)
|
97 |
+
base_reqs = [
|
98 |
+
"numpy",
|
99 |
+
"torch",
|
100 |
+
"torchvision",
|
101 |
+
"torchaudio",
|
102 |
+
"timm",
|
103 |
+
"black",
|
104 |
+
"jaxtyping",
|
105 |
+
"matplotlib",
|
106 |
+
"Pillow",
|
107 |
+
"scikit-learn",
|
108 |
+
"einops",
|
109 |
+
"rerun-sdk",
|
110 |
+
"pre-commit",
|
111 |
+
"minio",
|
112 |
+
"pytest",
|
113 |
+
"isort",
|
114 |
+
]
|
115 |
+
|
116 |
+
# Write base requirements
|
117 |
+
with open(temp_dir / "requirements-base.txt", "w") as f:
|
118 |
+
for req in base_reqs:
|
119 |
+
f.write(f"{req}\n")
|
120 |
+
|
121 |
+
# XFormers requirements
|
122 |
+
with open(temp_dir / "requirements-xformers.txt", "w") as f:
|
123 |
+
f.write("xformers\n")
|
124 |
+
|
125 |
+
# Dev requirements
|
126 |
+
dev_reqs = [
|
127 |
+
"black",
|
128 |
+
"isort",
|
129 |
+
"pre-commit",
|
130 |
+
"pytest",
|
131 |
+
]
|
132 |
+
|
133 |
+
with open(temp_dir / "requirements-dev.txt", "w") as f:
|
134 |
+
for req in dev_reqs:
|
135 |
+
f.write(f"{req}\n")
|
136 |
+
|
137 |
+
|
138 |
+
def create_offline_installation_files(output_dir: Path):
|
139 |
+
"""Create requirements files and installation script for offline use."""
|
140 |
+
|
141 |
+
# Base requirements (including PyTorch)
|
142 |
+
base_reqs = [
|
143 |
+
"numpy",
|
144 |
+
"torch",
|
145 |
+
"torchvision",
|
146 |
+
"torchaudio",
|
147 |
+
"timm",
|
148 |
+
"black",
|
149 |
+
"jaxtyping",
|
150 |
+
"matplotlib",
|
151 |
+
"Pillow",
|
152 |
+
"scikit-learn",
|
153 |
+
"einops",
|
154 |
+
"rerun-sdk",
|
155 |
+
"pre-commit",
|
156 |
+
"minio",
|
157 |
+
"pytest",
|
158 |
+
"isort",
|
159 |
+
]
|
160 |
+
|
161 |
+
# Write base requirements
|
162 |
+
with open(output_dir / "requirements-base.txt", "w") as f:
|
163 |
+
for req in base_reqs:
|
164 |
+
f.write(f"{req}\n")
|
165 |
+
|
166 |
+
# XFormers requirements
|
167 |
+
with open(output_dir / "requirements-xformers.txt", "w") as f:
|
168 |
+
f.write("xformers\n")
|
169 |
+
|
170 |
+
# Dev requirements
|
171 |
+
dev_reqs = [
|
172 |
+
"black",
|
173 |
+
"isort",
|
174 |
+
"pre-commit",
|
175 |
+
"pytest",
|
176 |
+
]
|
177 |
+
|
178 |
+
with open(output_dir / "requirements-dev.txt", "w") as f:
|
179 |
+
for req in dev_reqs:
|
180 |
+
f.write(f"{req}\n")
|
181 |
+
|
182 |
+
# Create installation script
|
183 |
+
install_script = output_dir / "install_offline.sh"
|
184 |
+
with open(install_script, "w") as f:
|
185 |
+
f.write(
|
186 |
+
"""#!/bin/bash
|
187 |
+
# Offline installation script for UniCeption
|
188 |
+
|
189 |
+
set -e
|
190 |
+
|
191 |
+
echo "Installing UniCeption dependencies offline..."
|
192 |
+
|
193 |
+
# Check if we're in the right directory
|
194 |
+
if [ ! -f "requirements-base.txt" ]; then
|
195 |
+
echo "Error: requirements-base.txt not found. Please run this script from the offline_wheels directory."
|
196 |
+
exit 1
|
197 |
+
fi
|
198 |
+
|
199 |
+
# Install base dependencies (includes PyTorch)
|
200 |
+
echo "Installing base dependencies (including PyTorch)..."
|
201 |
+
pip install --no-index --find-links . -r requirements-base.txt
|
202 |
+
|
203 |
+
# Install XFormers if requested
|
204 |
+
if [ "$INSTALL_XFORMERS" = "true" ]; then
|
205 |
+
echo "Installing XFormers..."
|
206 |
+
pip install --no-index --find-links . -r requirements-xformers.txt
|
207 |
+
fi
|
208 |
+
|
209 |
+
# Install dev dependencies if requested
|
210 |
+
if [ "$INSTALL_DEV" = "true" ]; then
|
211 |
+
echo "Installing development dependencies..."
|
212 |
+
pip install --no-index --find-links . -r requirements-dev.txt
|
213 |
+
fi
|
214 |
+
|
215 |
+
# Navigate back to UniCeption directory and install the package
|
216 |
+
echo "Installing UniCeption package..."
|
217 |
+
cd ..
|
218 |
+
pip install --no-deps -e .
|
219 |
+
|
220 |
+
# Install CroCo RoPE extension if requested
|
221 |
+
if [ "$INSTALL_CROCO_ROPE" = "true" ]; then
|
222 |
+
echo "Installing CroCo RoPE extension..."
|
223 |
+
cd uniception/models/libs/croco/curope
|
224 |
+
python setup.py build_ext --inplace
|
225 |
+
cd -
|
226 |
+
fi
|
227 |
+
|
228 |
+
echo "Offline installation completed successfully!"
|
229 |
+
echo ""
|
230 |
+
echo "To verify installation, run:"
|
231 |
+
echo "python setup.py check_deps"
|
232 |
+
"""
|
233 |
+
)
|
234 |
+
|
235 |
+
# Make script executable
|
236 |
+
install_script.chmod(0o755)
|
237 |
+
|
238 |
+
# Create Windows batch script as well
|
239 |
+
install_bat = output_dir / "install_offline.bat"
|
240 |
+
with open(install_bat, "w") as f:
|
241 |
+
f.write(
|
242 |
+
"""@echo off
|
243 |
+
REM Offline installation script for UniCeption (Windows)
|
244 |
+
|
245 |
+
echo Installing UniCeption dependencies offline...
|
246 |
+
|
247 |
+
REM Check if we're in the right directory
|
248 |
+
if not exist "requirements-base.txt" (
|
249 |
+
echo Error: requirements-base.txt not found. Please run this script from the offline_wheels directory.
|
250 |
+
exit /b 1
|
251 |
+
)
|
252 |
+
|
253 |
+
REM Install base dependencies (includes PyTorch)
|
254 |
+
echo Installing base dependencies (including PyTorch)...
|
255 |
+
pip install --no-index --find-links . -r requirements-base.txt
|
256 |
+
|
257 |
+
REM Install XFormers if requested
|
258 |
+
if "%INSTALL_XFORMERS%"=="true" (
|
259 |
+
echo Installing XFormers...
|
260 |
+
pip install --no-index --find-links . -r requirements-xformers.txt
|
261 |
+
)
|
262 |
+
|
263 |
+
REM Install dev dependencies if requested
|
264 |
+
if "%INSTALL_DEV%"=="true" (
|
265 |
+
echo Installing development dependencies...
|
266 |
+
pip install --no-index --find-links . -r requirements-dev.txt
|
267 |
+
)
|
268 |
+
|
269 |
+
REM Navigate back to UniCeption directory and install the package
|
270 |
+
echo Installing UniCeption package...
|
271 |
+
cd ..
|
272 |
+
pip install --no-deps -e .
|
273 |
+
|
274 |
+
REM Install CroCo RoPE extension if requested
|
275 |
+
if "%INSTALL_CROCO_ROPE%"=="true" (
|
276 |
+
echo Installing CroCo RoPE extension...
|
277 |
+
cd uniception\\models\\libs\\croco\\curope
|
278 |
+
python setup.py build_ext --inplace
|
279 |
+
cd ..\\..\\..\\..\\..
|
280 |
+
)
|
281 |
+
|
282 |
+
echo Offline installation completed successfully!
|
283 |
+
echo.
|
284 |
+
echo To verify installation, run:
|
285 |
+
echo python setup.py check_deps
|
286 |
+
"""
|
287 |
+
)
|
288 |
+
|
289 |
+
# Create a README for offline installation
|
290 |
+
readme_file = output_dir / "README_OFFLINE.md"
|
291 |
+
with open(readme_file, "w") as f:
|
292 |
+
f.write(
|
293 |
+
"""# UniCeption Offline Installation
|
294 |
+
|
295 |
+
This directory contains all the necessary files for installing UniCeption without internet access.
|
296 |
+
|
297 |
+
## Files Included
|
298 |
+
|
299 |
+
- `requirements-base.txt` - Core dependencies (including PyTorch)
|
300 |
+
- `requirements-xformers.txt` - XFormers dependency
|
301 |
+
- `requirements-dev.txt` - Development dependencies
|
302 |
+
- `install_offline.sh` - Installation script for Unix/Linux/macOS
|
303 |
+
- `install_offline.bat` - Installation script for Windows
|
304 |
+
- `*.whl` files - Downloaded wheel packages
|
305 |
+
|
306 |
+
## Installation Instructions
|
307 |
+
|
308 |
+
### Unix/Linux/macOS
|
309 |
+
|
310 |
+
```bash
|
311 |
+
# Set environment variables for optional components
|
312 |
+
export INSTALL_XFORMERS=true # Install XFormers
|
313 |
+
export INSTALL_DEV=true # Install development tools
|
314 |
+
export INSTALL_CROCO_ROPE=true # Compile CroCo RoPE extension
|
315 |
+
|
316 |
+
# Run the installation script
|
317 |
+
./install_offline.sh
|
318 |
+
```
|
319 |
+
|
320 |
+
### Windows
|
321 |
+
|
322 |
+
```cmd
|
323 |
+
REM Set environment variables for optional components
|
324 |
+
set INSTALL_XFORMERS=true
|
325 |
+
set INSTALL_DEV=true
|
326 |
+
set INSTALL_CROCO_ROPE=true
|
327 |
+
|
328 |
+
REM Run the installation script
|
329 |
+
install_offline.bat
|
330 |
+
```
|
331 |
+
|
332 |
+
## Manual Installation
|
333 |
+
|
334 |
+
If the scripts don't work, you can install manually:
|
335 |
+
|
336 |
+
```bash
|
337 |
+
# Install base dependencies (includes PyTorch)
|
338 |
+
pip install --no-index --find-links . -r requirements-base.txt
|
339 |
+
|
340 |
+
# Install optional dependencies as needed
|
341 |
+
pip install --no-index --find-links . -r requirements-xformers.txt
|
342 |
+
pip install --no-index --find-links . -r requirements-dev.txt
|
343 |
+
|
344 |
+
# Install UniCeption package (from parent directory)
|
345 |
+
cd ..
|
346 |
+
pip install --no-deps -e .
|
347 |
+
|
348 |
+
# Compile CroCo RoPE extension (optional)
|
349 |
+
cd uniception/models/libs/croco/curope
|
350 |
+
python setup.py build_ext --inplace
|
351 |
+
```
|
352 |
+
|
353 |
+
## Verification
|
354 |
+
|
355 |
+
After installation, verify everything is working:
|
356 |
+
|
357 |
+
```bash
|
358 |
+
cd .. # Go back to UniCeption root directory
|
359 |
+
python setup.py check_deps
|
360 |
+
```
|
361 |
+
|
362 |
+
## Notes
|
363 |
+
|
364 |
+
- PyTorch, TorchVision, and TorchAudio are now included in the base requirements
|
365 |
+
- XFormers is optional and only needed for certain performance optimizations
|
366 |
+
- CroCo RoPE extension compilation requires a CUDA-enabled environment
|
367 |
+
"""
|
368 |
+
)
|
369 |
+
|
370 |
+
print(f"Created offline installation files in {output_dir}")
|
371 |
+
print("Files created:")
|
372 |
+
print(" - requirements-base.txt (includes PyTorch)")
|
373 |
+
print(" - requirements-xformers.txt")
|
374 |
+
print(" - requirements-dev.txt")
|
375 |
+
print(" - install_offline.sh (Unix/Linux/macOS)")
|
376 |
+
print(" - install_offline.bat (Windows)")
|
377 |
+
print(" - README_OFFLINE.md")
|
378 |
+
|
379 |
+
|
380 |
+
def create_offline_requirements(output_dir: Path):
|
381 |
+
"""Create requirements files for offline installation."""
|
382 |
+
# This function is now replaced by create_offline_installation_files
|
383 |
+
pass
|
384 |
+
|
385 |
+
|
386 |
+
def main():
|
387 |
+
parser = argparse.ArgumentParser(description="Prepare UniCeption for offline installation")
|
388 |
+
parser.add_argument(
|
389 |
+
"--output-dir", type=Path, default="offline_wheels", help="Directory to store downloaded wheels"
|
390 |
+
)
|
391 |
+
parser.add_argument("--extras", nargs="+", choices=["xformers", "dev", "all"], help="Extra dependencies to include")
|
392 |
+
|
393 |
+
args = parser.parse_args()
|
394 |
+
|
395 |
+
download_wheels(args.output_dir, args.extras)
|
396 |
+
|
397 |
+
|
398 |
+
if __name__ == "__main__":
|
399 |
+
main()
|
UniCeption/scripts/validate_installation.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Validation script for UniCeption installation.
|
4 |
+
|
5 |
+
This script validates that all components of UniCeption are correctly installed
|
6 |
+
and provides helpful diagnostics.
|
7 |
+
"""
|
8 |
+
|
9 |
+
import importlib
|
10 |
+
import sys
|
11 |
+
from pathlib import Path
|
12 |
+
|
13 |
+
|
14 |
+
def check_package_installation():
|
15 |
+
"""Check if UniCeption package is properly installed."""
|
16 |
+
try:
|
17 |
+
import uniception
|
18 |
+
|
19 |
+
print("✓ UniCeption package is installed")
|
20 |
+
|
21 |
+
# Check if we can import core modules
|
22 |
+
try:
|
23 |
+
from uniception.models.encoders import UniCeptionViTEncoderBase
|
24 |
+
|
25 |
+
print("✓ Core encoder modules are available")
|
26 |
+
except ImportError as e:
|
27 |
+
print(f"✗ Failed to import core encoder modules: {e}")
|
28 |
+
|
29 |
+
return True
|
30 |
+
except ImportError as e:
|
31 |
+
print(f"✗ UniCeption package not found: {e}")
|
32 |
+
return False
|
33 |
+
|
34 |
+
|
35 |
+
def check_dependencies():
|
36 |
+
"""Check optional dependencies."""
|
37 |
+
dependencies = {
|
38 |
+
"torch": "PyTorch",
|
39 |
+
"torchvision": "TorchVision",
|
40 |
+
"torchaudio": "TorchAudio",
|
41 |
+
"xformers": "XFormers",
|
42 |
+
"timm": "Timm (PyTorch Image Models)",
|
43 |
+
"einops": "Einops",
|
44 |
+
"matplotlib": "Matplotlib",
|
45 |
+
"numpy": "NumPy",
|
46 |
+
"PIL": "Pillow",
|
47 |
+
}
|
48 |
+
|
49 |
+
available = []
|
50 |
+
missing = []
|
51 |
+
|
52 |
+
for module, name in dependencies.items():
|
53 |
+
try:
|
54 |
+
mod = importlib.import_module(module)
|
55 |
+
version = getattr(mod, "__version__", "unknown")
|
56 |
+
available.append(f"✓ {name}: {version}")
|
57 |
+
except ImportError:
|
58 |
+
missing.append(f"✗ {name}: not installed")
|
59 |
+
|
60 |
+
print("\nDependency Status:")
|
61 |
+
for dep in available:
|
62 |
+
print(f" {dep}")
|
63 |
+
|
64 |
+
if missing:
|
65 |
+
print("\nMissing Dependencies:")
|
66 |
+
for dep in missing:
|
67 |
+
print(f" {dep}")
|
68 |
+
|
69 |
+
return len(missing) == 0
|
70 |
+
|
71 |
+
|
72 |
+
def check_cuda_support():
|
73 |
+
"""Check CUDA support."""
|
74 |
+
try:
|
75 |
+
import torch
|
76 |
+
|
77 |
+
if torch.cuda.is_available():
|
78 |
+
print(f"\n✓ CUDA is available")
|
79 |
+
print(f" CUDA version: {torch.version.cuda}")
|
80 |
+
print(f" Available devices: {torch.cuda.device_count()}")
|
81 |
+
for i in range(torch.cuda.device_count()):
|
82 |
+
print(f" Device {i}: {torch.cuda.get_device_name(i)}")
|
83 |
+
else:
|
84 |
+
print(f"\n⚠ CUDA is not available (CPU-only mode)")
|
85 |
+
except ImportError:
|
86 |
+
print(f"\n⚠ PyTorch not installed - cannot check CUDA support")
|
87 |
+
|
88 |
+
|
89 |
+
def check_croco_rope():
|
90 |
+
"""Check CroCo RoPE extension."""
|
91 |
+
try:
|
92 |
+
from uniception.models.libs.croco.curope import cuRoPE2D
|
93 |
+
|
94 |
+
print("\n✓ CroCo RoPE extension is available")
|
95 |
+
return True
|
96 |
+
except ImportError:
|
97 |
+
print("\n✗ CroCo RoPE extension not available")
|
98 |
+
print(" To install: cd uniception/models/libs/croco/curope && python setup.py build_ext --inplace")
|
99 |
+
return False
|
100 |
+
|
101 |
+
|
102 |
+
def check_model_availability():
|
103 |
+
"""Check if models can be loaded."""
|
104 |
+
try:
|
105 |
+
# Try to check if encoder modules are available
|
106 |
+
from uniception.models import encoders
|
107 |
+
|
108 |
+
print(f"\n✓ Encoder module is available")
|
109 |
+
|
110 |
+
# Try to run the encoder list command
|
111 |
+
try:
|
112 |
+
import subprocess
|
113 |
+
|
114 |
+
result = subprocess.run(
|
115 |
+
[sys.executable, "-m", "uniception.models.encoders.list"], capture_output=True, text=True, timeout=10
|
116 |
+
)
|
117 |
+
|
118 |
+
if result.returncode == 0:
|
119 |
+
lines = result.stdout.strip().split("\n")
|
120 |
+
encoder_count = len([line for line in lines if line.strip() and not line.startswith("Available")])
|
121 |
+
print(f"✓ Available encoders: {encoder_count}")
|
122 |
+
return True
|
123 |
+
else:
|
124 |
+
print(f"⚠ Encoder listing returned non-zero exit code: {result.returncode}")
|
125 |
+
return False
|
126 |
+
|
127 |
+
except subprocess.TimeoutExpired:
|
128 |
+
print(f"⚠ Encoder listing timed out")
|
129 |
+
return False
|
130 |
+
except Exception as e:
|
131 |
+
print(f"⚠ Could not run encoder listing: {e}")
|
132 |
+
return False
|
133 |
+
|
134 |
+
except Exception as e:
|
135 |
+
print(f"\n✗ Failed to access encoder modules: {e}")
|
136 |
+
return False
|
137 |
+
|
138 |
+
|
139 |
+
def check_file_structure():
|
140 |
+
"""Check if the project file structure is correct."""
|
141 |
+
base_path = Path(__file__).parent.parent
|
142 |
+
required_dirs = [
|
143 |
+
"uniception",
|
144 |
+
"uniception/models",
|
145 |
+
"uniception/models/encoders",
|
146 |
+
"uniception/models/info_sharing",
|
147 |
+
"uniception/models/prediction_heads",
|
148 |
+
"scripts",
|
149 |
+
"tests",
|
150 |
+
]
|
151 |
+
|
152 |
+
missing_dirs = []
|
153 |
+
for dir_path in required_dirs:
|
154 |
+
full_path = base_path / dir_path
|
155 |
+
if not full_path.exists():
|
156 |
+
missing_dirs.append(dir_path)
|
157 |
+
|
158 |
+
if missing_dirs:
|
159 |
+
print(f"\n✗ Missing directories:")
|
160 |
+
for dir_path in missing_dirs:
|
161 |
+
print(f" - {dir_path}")
|
162 |
+
return False
|
163 |
+
else:
|
164 |
+
print(f"\n✓ Project structure is correct")
|
165 |
+
return True
|
166 |
+
|
167 |
+
|
168 |
+
def main():
|
169 |
+
"""Run all validation checks."""
|
170 |
+
print("UniCeption Installation Validation")
|
171 |
+
print("=" * 40)
|
172 |
+
|
173 |
+
checks = [
|
174 |
+
("Package Installation", check_package_installation),
|
175 |
+
("Dependencies", check_dependencies),
|
176 |
+
("CUDA Support", check_cuda_support),
|
177 |
+
("CroCo RoPE Extension", check_croco_rope),
|
178 |
+
("Model Availability", check_model_availability),
|
179 |
+
("File Structure", check_file_structure),
|
180 |
+
]
|
181 |
+
|
182 |
+
results = []
|
183 |
+
for name, check_func in checks:
|
184 |
+
print(f"\nChecking {name}...")
|
185 |
+
try:
|
186 |
+
result = check_func()
|
187 |
+
results.append((name, result))
|
188 |
+
except Exception as e:
|
189 |
+
print(f"✗ Error during {name} check: {e}")
|
190 |
+
results.append((name, False))
|
191 |
+
|
192 |
+
# Summary
|
193 |
+
print("\n" + "=" * 40)
|
194 |
+
print("Validation Summary:")
|
195 |
+
passed = 0
|
196 |
+
for name, result in results:
|
197 |
+
status = "✓ PASS" if result else "✗ FAIL"
|
198 |
+
print(f" {name}: {status}")
|
199 |
+
if result:
|
200 |
+
passed += 1
|
201 |
+
|
202 |
+
print(f"\nOverall: {passed}/{len(results)} checks passed")
|
203 |
+
|
204 |
+
if passed == len(results):
|
205 |
+
print("🎉 All checks passed! UniCeption is ready to use.")
|
206 |
+
return 0
|
207 |
+
else:
|
208 |
+
print("⚠ Some checks failed. Please review the issues above.")
|
209 |
+
return 1
|
210 |
+
|
211 |
+
|
212 |
+
if __name__ == "__main__":
|
213 |
+
sys.exit(main())
|
UniCeption/setup.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Package installation setup."""
|
2 |
+
|
3 |
+
import os
|
4 |
+
import subprocess
|
5 |
+
import sys
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
from setuptools import find_packages, setup
|
9 |
+
from setuptools.command.develop import develop
|
10 |
+
from setuptools.command.install import install
|
11 |
+
|
12 |
+
|
13 |
+
def install_croco_rope():
|
14 |
+
"""Install CroCo RoPE extension."""
|
15 |
+
try:
|
16 |
+
curope_path = Path(__file__).parent / "uniception" / "models" / "libs" / "croco" / "curope"
|
17 |
+
if curope_path.exists():
|
18 |
+
print("Installing CroCo RoPE extension...")
|
19 |
+
original_cwd = os.getcwd()
|
20 |
+
try:
|
21 |
+
os.chdir(curope_path)
|
22 |
+
subprocess.check_call([sys.executable, "setup.py", "build_ext", "--inplace"])
|
23 |
+
print("CroCo RoPE extension installed successfully!")
|
24 |
+
return True
|
25 |
+
except subprocess.CalledProcessError as e:
|
26 |
+
print(f"Warning: Failed to install CroCo RoPE extension: {e}")
|
27 |
+
print("You can install it later by running:")
|
28 |
+
print(f"cd {curope_path} && python setup.py build_ext --inplace")
|
29 |
+
return False
|
30 |
+
finally:
|
31 |
+
os.chdir(original_cwd)
|
32 |
+
else:
|
33 |
+
print("Warning: CroCo RoPE source code not found.")
|
34 |
+
return False
|
35 |
+
except Exception as e:
|
36 |
+
print(f"Warning: Error during CroCo RoPE installation: {e}")
|
37 |
+
return False
|
38 |
+
|
39 |
+
|
40 |
+
def check_dependencies():
|
41 |
+
"""Check if optional dependencies are available."""
|
42 |
+
try:
|
43 |
+
import torch
|
44 |
+
|
45 |
+
print(f"PyTorch version: {torch.__version__}")
|
46 |
+
if torch.cuda.is_available():
|
47 |
+
print(f"CUDA available: {torch.version.cuda}")
|
48 |
+
else:
|
49 |
+
print("CUDA not available")
|
50 |
+
except ImportError:
|
51 |
+
print("PyTorch not installed")
|
52 |
+
|
53 |
+
try:
|
54 |
+
import xformers
|
55 |
+
|
56 |
+
print(f"XFormers version: {xformers.__version__}")
|
57 |
+
except ImportError:
|
58 |
+
print("XFormers not installed")
|
59 |
+
|
60 |
+
try:
|
61 |
+
from uniception.models.libs.croco.curope import cuRoPE2D
|
62 |
+
|
63 |
+
print("CroCo RoPE extension available")
|
64 |
+
except ImportError:
|
65 |
+
print("CroCo RoPE extension not available")
|
66 |
+
|
67 |
+
|
68 |
+
class CustomDevelopCommand(develop):
|
69 |
+
"""Custom development installation command."""
|
70 |
+
|
71 |
+
def run(self):
|
72 |
+
develop.run(self)
|
73 |
+
# Only install CroCo RoPE if explicitly requested
|
74 |
+
if os.getenv("INSTALL_CROCO_ROPE", "false").lower() in ("true", "1", "yes"):
|
75 |
+
install_croco_rope()
|
76 |
+
|
77 |
+
|
78 |
+
class CustomInstallCommand(install):
|
79 |
+
"""Custom installation command."""
|
80 |
+
|
81 |
+
def run(self):
|
82 |
+
install.run(self)
|
83 |
+
# Only install CroCo RoPE if explicitly requested
|
84 |
+
if os.getenv("INSTALL_CROCO_ROPE", "false").lower() in ("true", "1", "yes"):
|
85 |
+
install_croco_rope()
|
86 |
+
|
87 |
+
|
88 |
+
class CrocoInstallCommand(install):
|
89 |
+
"""Install command that includes CroCo RoPE extension."""
|
90 |
+
|
91 |
+
def run(self):
|
92 |
+
install.run(self)
|
93 |
+
install_croco_rope()
|
94 |
+
|
95 |
+
|
96 |
+
class CheckDependenciesCommand(install):
|
97 |
+
"""Command to check available dependencies."""
|
98 |
+
|
99 |
+
def run(self):
|
100 |
+
check_dependencies()
|
101 |
+
|
102 |
+
|
103 |
+
# Core dependencies (including PyTorch which is essential for UniCeption)
|
104 |
+
install_requires = [
|
105 |
+
"numpy",
|
106 |
+
"torch",
|
107 |
+
"torchvision",
|
108 |
+
"torchaudio",
|
109 |
+
"timm",
|
110 |
+
"black",
|
111 |
+
"jaxtyping",
|
112 |
+
"matplotlib",
|
113 |
+
"Pillow",
|
114 |
+
"scikit-learn",
|
115 |
+
"einops",
|
116 |
+
"rerun-sdk",
|
117 |
+
"pre-commit",
|
118 |
+
"minio",
|
119 |
+
"pytest",
|
120 |
+
"isort",
|
121 |
+
]
|
122 |
+
|
123 |
+
# Optional dependencies
|
124 |
+
extras_require = {
|
125 |
+
"xformers": [
|
126 |
+
"xformers", # Will be installed from PyTorch wheel index
|
127 |
+
],
|
128 |
+
"dev": [
|
129 |
+
"black",
|
130 |
+
"isort",
|
131 |
+
"pre-commit",
|
132 |
+
"pytest",
|
133 |
+
],
|
134 |
+
"minimal": [
|
135 |
+
# Minimal dependencies for basic functionality without PyTorch
|
136 |
+
"numpy",
|
137 |
+
"matplotlib",
|
138 |
+
"Pillow",
|
139 |
+
"scikit-learn",
|
140 |
+
"einops",
|
141 |
+
],
|
142 |
+
}
|
143 |
+
|
144 |
+
# All optional dependencies combined (excluding minimal since it's subset of install_requires)
|
145 |
+
extras_require["all"] = list(set(extras_require["xformers"] + extras_require["dev"]))
|
146 |
+
|
147 |
+
setup(
|
148 |
+
name="uniception",
|
149 |
+
version="0.1.0",
|
150 |
+
description="Generalizable Perception Stack for 3D, 4D, spatial AI and scene understanding",
|
151 |
+
long_description=open("README.md").read(),
|
152 |
+
long_description_content_type="text/markdown",
|
153 |
+
author="AirLab",
|
154 |
+
license="BSD Clause-3",
|
155 |
+
packages=find_packages(),
|
156 |
+
package_dir={"": "."},
|
157 |
+
include_package_data=True,
|
158 |
+
python_requires=">=3.10",
|
159 |
+
install_requires=install_requires,
|
160 |
+
extras_require=extras_require,
|
161 |
+
cmdclass={
|
162 |
+
"develop": CustomDevelopCommand,
|
163 |
+
"install": CustomInstallCommand,
|
164 |
+
"install_croco": CrocoInstallCommand,
|
165 |
+
"check_deps": CheckDependenciesCommand,
|
166 |
+
},
|
167 |
+
entry_points={
|
168 |
+
"console_scripts": [
|
169 |
+
"uniception-download-checkpoints=scripts.download_checkpoints:main",
|
170 |
+
"uniception-validate=scripts.validate_installation:main",
|
171 |
+
"uniception-prepare-offline=scripts.prepare_offline_install:main",
|
172 |
+
"uniception-check-deps=scripts.check_dependencies:main",
|
173 |
+
"uniception-install-croco=scripts.install_croco_rope:main",
|
174 |
+
],
|
175 |
+
},
|
176 |
+
classifiers=[
|
177 |
+
"Development Status :: 3 - Alpha",
|
178 |
+
"Intended Audience :: Developers",
|
179 |
+
"Intended Audience :: Science/Research",
|
180 |
+
"Programming Language :: Python :: 3",
|
181 |
+
"Programming Language :: Python :: 3.10",
|
182 |
+
"Programming Language :: Python :: 3.11",
|
183 |
+
"Programming Language :: Python :: 3.12",
|
184 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
185 |
+
"Topic :: Software Development :: Libraries :: Python Modules",
|
186 |
+
],
|
187 |
+
keywords="computer-vision, 3d-vision, spatial-ai, perception, deep-learning, pytorch",
|
188 |
+
)
|
UniCeption/tests/models/encoders/conftest.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytest
|
2 |
+
|
3 |
+
|
4 |
+
def pytest_addoption(parser):
|
5 |
+
# Add custom command-line options
|
6 |
+
parser.addoption("--encoder-name", action="store", default=None, help="Specify the encoder name to test")
|
7 |
+
|
8 |
+
parser.addoption(
|
9 |
+
"--device",
|
10 |
+
action="store",
|
11 |
+
default="cpu",
|
12 |
+
choices=["cpu", "gpu"],
|
13 |
+
help="Specify the device to use (default: cpu)",
|
14 |
+
)
|
15 |
+
|
16 |
+
|
17 |
+
@pytest.fixture
|
18 |
+
def encoder_name(request):
|
19 |
+
# Access the value of the custom option for encoder name
|
20 |
+
return request.config.getoption("--encoder-name")
|
21 |
+
|
22 |
+
|
23 |
+
@pytest.fixture
|
24 |
+
def device(request):
|
25 |
+
# Access the value of the custom option for device
|
26 |
+
return request.config.getoption("--device")
|
UniCeption/tests/models/encoders/test_encoders.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
from functools import lru_cache
|
4 |
+
from typing import Tuple
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import pytest
|
8 |
+
import requests
|
9 |
+
import torch
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
from uniception.models.encoders import *
|
13 |
+
from uniception.models.encoders.image_normalizations import *
|
14 |
+
|
15 |
+
|
16 |
+
@pytest.fixture(scope="module")
|
17 |
+
def norm_types():
|
18 |
+
return IMAGE_NORMALIZATION_DICT.keys()
|
19 |
+
|
20 |
+
|
21 |
+
@pytest.fixture(scope="module")
|
22 |
+
def encoders():
|
23 |
+
return [
|
24 |
+
"croco",
|
25 |
+
"dust3r_224",
|
26 |
+
"dust3r_512",
|
27 |
+
"dust3r_512_dpt",
|
28 |
+
"mast3r_512",
|
29 |
+
"dinov2_base",
|
30 |
+
"dinov2_large",
|
31 |
+
"dinov2_large_reg",
|
32 |
+
"dinov2_large_dav2",
|
33 |
+
"dinov2_giant",
|
34 |
+
"dinov2_giant_reg",
|
35 |
+
"radio_v2.5-b",
|
36 |
+
"radio_v2.5-l",
|
37 |
+
"e-radio_v2",
|
38 |
+
"naradio_v2.5-b",
|
39 |
+
"naradio_v2.5-l",
|
40 |
+
"cosmosx8",
|
41 |
+
"patch_embedder",
|
42 |
+
]
|
43 |
+
|
44 |
+
|
45 |
+
@pytest.fixture(scope="module")
|
46 |
+
def encoder_configs(encoders):
|
47 |
+
# Adjust the number of configs to match the number of encoders
|
48 |
+
return [{}] * len(encoders)
|
49 |
+
|
50 |
+
|
51 |
+
@pytest.fixture
|
52 |
+
def device(request):
|
53 |
+
# Access the value of the custom option for device
|
54 |
+
device_str = request.config.getoption("--device")
|
55 |
+
if device_str == "gpu" and torch.cuda.is_available():
|
56 |
+
device = torch.device("cuda") # Use the default CUDA device
|
57 |
+
else:
|
58 |
+
device = torch.device("cpu")
|
59 |
+
print(f"Using device: {device.type.upper()}")
|
60 |
+
return device
|
61 |
+
|
62 |
+
|
63 |
+
@pytest.fixture
|
64 |
+
def example_input(device):
|
65 |
+
@lru_cache(maxsize=3)
|
66 |
+
def _get_example_input(
|
67 |
+
image_size: Tuple[int, int],
|
68 |
+
image_norm_type: str = "dummy",
|
69 |
+
img_selection: int = 1,
|
70 |
+
return_viz_img: bool = False,
|
71 |
+
) -> torch.Tensor:
|
72 |
+
url = f"https://raw.githubusercontent.com/naver/croco/d3d0ab2858d44bcad54e5bfc24f565983fbe18d9/assets/Chateau{img_selection}.png"
|
73 |
+
image = Image.open(requests.get(url, stream=True).raw)
|
74 |
+
image = image.resize(image_size)
|
75 |
+
image = image.convert("RGB")
|
76 |
+
|
77 |
+
img = torch.from_numpy(np.array(image))
|
78 |
+
viz_img = img.clone()
|
79 |
+
|
80 |
+
# Normalize the image
|
81 |
+
image_normalization = IMAGE_NORMALIZATION_DICT[image_norm_type]
|
82 |
+
img_mean = image_normalization.mean
|
83 |
+
img_std = image_normalization.std
|
84 |
+
img = (img.float() / 255.0 - img_mean) / img_std
|
85 |
+
|
86 |
+
# Convert to BCHW format
|
87 |
+
img = img.permute(2, 0, 1).unsqueeze(0).to(device)
|
88 |
+
|
89 |
+
if return_viz_img:
|
90 |
+
return img, viz_img
|
91 |
+
else:
|
92 |
+
return img
|
93 |
+
|
94 |
+
return _get_example_input
|
95 |
+
|
96 |
+
|
97 |
+
def inference_encoder(encoder, encoder_input):
|
98 |
+
# Encoder expects a ViTEncoderInput object
|
99 |
+
return encoder(encoder_input).features
|
100 |
+
|
101 |
+
|
102 |
+
def test_make_dummy_encoder(device):
|
103 |
+
print(f"Testing Init of Dummy Encoder on {device.type.upper()}")
|
104 |
+
encoder = _make_encoder_test("dummy").to(device)
|
105 |
+
|
106 |
+
# Check if the encoder has parameters
|
107 |
+
try:
|
108 |
+
params = list(encoder.parameters())
|
109 |
+
if not params:
|
110 |
+
print("Warning: The encoder has no parameters.")
|
111 |
+
else:
|
112 |
+
# Verify if the model is on the right device
|
113 |
+
assert params[0].is_cuda == (device.type == "cuda")
|
114 |
+
|
115 |
+
except Exception as e:
|
116 |
+
print(f"Error: {e}")
|
117 |
+
assert False # Fail the test if any error occurs
|
118 |
+
|
119 |
+
assert encoder is not None
|
120 |
+
|
121 |
+
|
122 |
+
def test_all_encoder_basics(encoders, encoder_configs, norm_types, example_input, encoder_name, device):
|
123 |
+
if encoder_name:
|
124 |
+
encoders = [encoder_name] # Override default encoders with the one specified
|
125 |
+
|
126 |
+
for encoder_name, encoder_config in zip(encoders, encoder_configs):
|
127 |
+
print(f"Testing encoder: {encoder_name} on {device.type.upper()}")
|
128 |
+
|
129 |
+
encoder = _make_encoder_test(encoder_name, **encoder_config).to(device)
|
130 |
+
_check_baseclass_attribute(encoder, norm_types)
|
131 |
+
_check_norm_check_function(encoder)
|
132 |
+
|
133 |
+
if isinstance(encoder, UniCeptionViTEncoderBase):
|
134 |
+
_check_vit_encoder_attribute(encoder)
|
135 |
+
_test_vit_encoder_patch_size(encoder, example_input)
|
136 |
+
|
137 |
+
|
138 |
+
def _check_baseclass_attribute(encoder, norm_types):
|
139 |
+
assert hasattr(encoder, "name")
|
140 |
+
assert hasattr(encoder, "size")
|
141 |
+
assert hasattr(encoder, "data_norm_type")
|
142 |
+
|
143 |
+
assert isinstance(encoder.name, str)
|
144 |
+
assert isinstance(encoder.size, str) or encoder.size is None
|
145 |
+
assert isinstance(encoder.data_norm_type, str)
|
146 |
+
|
147 |
+
# Check if the data_norm_type is in the list of normalization types
|
148 |
+
assert encoder.data_norm_type in norm_types
|
149 |
+
|
150 |
+
|
151 |
+
def _check_norm_check_function(encoder):
|
152 |
+
assert hasattr(encoder, "_check_data_normalization_type")
|
153 |
+
|
154 |
+
encoder_notm_type = encoder.data_norm_type
|
155 |
+
|
156 |
+
try:
|
157 |
+
encoder._check_data_normalization_type(encoder_notm_type)
|
158 |
+
except AssertionError:
|
159 |
+
assert False
|
160 |
+
|
161 |
+
try:
|
162 |
+
encoder._check_data_normalization_type("some_nonexistent_norm_type")
|
163 |
+
assert False
|
164 |
+
except AssertionError:
|
165 |
+
pass
|
166 |
+
|
167 |
+
|
168 |
+
def _check_vit_encoder_attribute(encoder):
|
169 |
+
assert hasattr(encoder, "patch_size")
|
170 |
+
assert isinstance(encoder.patch_size, int)
|
171 |
+
assert encoder.patch_size > 0
|
172 |
+
|
173 |
+
|
174 |
+
def _test_vit_encoder_patch_size(encoder, example_input):
|
175 |
+
print(f"Testing {encoder.name} inference")
|
176 |
+
image_size = (14 * encoder.patch_size, 14 * encoder.patch_size)
|
177 |
+
|
178 |
+
img = example_input(image_size, encoder.data_norm_type)
|
179 |
+
# Create an instance of ViTEncoderInput with correct attributes
|
180 |
+
encoder_input = ViTEncoderInput(
|
181 |
+
data_norm_type=encoder.data_norm_type,
|
182 |
+
image=img,
|
183 |
+
)
|
184 |
+
|
185 |
+
encoder_output = inference_encoder(encoder, encoder_input)
|
186 |
+
|
187 |
+
assert isinstance(encoder_output, torch.Tensor)
|
188 |
+
assert encoder_output.shape[2] == 14
|
189 |
+
assert encoder_output.shape[3] == 14
|
190 |
+
|
191 |
+
|
192 |
+
@pytest.fixture(scope="session", autouse=True)
|
193 |
+
def seed_everything():
|
194 |
+
seed = 42
|
195 |
+
random.seed(seed)
|
196 |
+
os.environ["PYTHONHASHSEED"] = str(seed)
|
197 |
+
np.random.seed(seed)
|
198 |
+
torch.manual_seed(seed)
|
199 |
+
torch.backends.cudnn.deterministic = True
|
200 |
+
torch.backends.cudnn.benchmark = False
|
201 |
+
print(f"Seed set to: {seed} (type: {type(seed)})")
|
202 |
+
|
203 |
+
# Turn XFormers off for testing on CPU
|
204 |
+
os.environ["XFORMERS_DISABLED"] = "1"
|
UniCeption/tests/models/encoders/viz_image_encoders.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
PCA Visualization of UniCeption Image Encoders
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
from functools import lru_cache
|
8 |
+
from typing import Tuple
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import requests
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from matplotlib import pyplot as plt
|
15 |
+
from PIL import Image
|
16 |
+
from sklearn.decomposition import PCA
|
17 |
+
|
18 |
+
from uniception.models.encoders import *
|
19 |
+
from uniception.models.encoders.image_normalizations import *
|
20 |
+
|
21 |
+
|
22 |
+
class TestEncoders:
|
23 |
+
def __init__(self, pca_save_folder, *args, **kwargs):
|
24 |
+
super(TestEncoders, self).__init__(*args, **kwargs)
|
25 |
+
|
26 |
+
self.pca_save_folder = pca_save_folder
|
27 |
+
|
28 |
+
self.norm_types = IMAGE_NORMALIZATION_DICT.keys()
|
29 |
+
|
30 |
+
self.encoders = [
|
31 |
+
"croco",
|
32 |
+
"dust3r_224",
|
33 |
+
"dust3r_512",
|
34 |
+
"dust3r_512_dpt",
|
35 |
+
"mast3r_512",
|
36 |
+
"dinov2_large",
|
37 |
+
"dinov2_large_reg",
|
38 |
+
"dinov2_large_dav2",
|
39 |
+
"dinov2_giant",
|
40 |
+
"dinov2_giant_reg",
|
41 |
+
"radio_v2.5-b",
|
42 |
+
"radio_v2.5-l",
|
43 |
+
"e-radio_v2",
|
44 |
+
]
|
45 |
+
|
46 |
+
self.encoder_configs = [{}] * len(self.encoders)
|
47 |
+
|
48 |
+
def inference_encoder(self, encoder, input):
|
49 |
+
return encoder(input)
|
50 |
+
|
51 |
+
def visualize_all_encoders(self):
|
52 |
+
for encoder, encoder_config in zip(self.encoders, self.encoder_configs):
|
53 |
+
encoder = _make_encoder_test(encoder, **encoder_config)
|
54 |
+
self._visualize_encoder_features_consistency(encoder, (224, 224))
|
55 |
+
|
56 |
+
def _visualize_encoder_features(self, encoder, image_size: Tuple[int, int]):
|
57 |
+
img, viz_img = self._get_example_input(image_size, encoder.data_norm_type, return_viz_img=True)
|
58 |
+
# input and output of the encoder
|
59 |
+
encoder_input: ViTEncoderInput = ViTEncoderInput(
|
60 |
+
data_norm_type=encoder.data_norm_type,
|
61 |
+
image=img,
|
62 |
+
)
|
63 |
+
|
64 |
+
encoder_output = self.inference_encoder(encoder, encoder_input)
|
65 |
+
encoder_output = encoder_output.features
|
66 |
+
|
67 |
+
self.assertTrue(isinstance(encoder_output, torch.Tensor))
|
68 |
+
|
69 |
+
# visualize the features
|
70 |
+
pca_viz = get_pca_map(encoder_output.permute(0, 2, 3, 1), image_size, return_pca_stats=False)
|
71 |
+
|
72 |
+
# plot the input image and the PCA features
|
73 |
+
fig, axs = plt.subplots(1, 2, figsize=(12, 6))
|
74 |
+
axs[0].imshow(viz_img)
|
75 |
+
axs[0].set_title("Input Image")
|
76 |
+
axs[0].axis("off")
|
77 |
+
axs[1].imshow(pca_viz)
|
78 |
+
axs[1].set_title(f"PCA Features of {encoder.name}")
|
79 |
+
axs[1].axis("off")
|
80 |
+
plt.savefig(f"{self.pca_save_folder}/pca_{encoder.name}.png", bbox_inches="tight")
|
81 |
+
plt.close()
|
82 |
+
|
83 |
+
def _visualize_encoder_features_consistency(self, encoder, image_size: Tuple[int, int]):
|
84 |
+
img0, viz_img0 = self._get_example_input(
|
85 |
+
image_size, encoder.data_norm_type, img_selection=1, return_viz_img=True
|
86 |
+
)
|
87 |
+
img1, viz_img1 = self._get_example_input(
|
88 |
+
image_size, encoder.data_norm_type, img_selection=2, return_viz_img=True
|
89 |
+
)
|
90 |
+
# input and output of the encoder
|
91 |
+
encoder_input0: ViTEncoderInput = ViTEncoderInput(
|
92 |
+
data_norm_type=encoder.data_norm_type,
|
93 |
+
image=img0,
|
94 |
+
)
|
95 |
+
|
96 |
+
encoder_input1: ViTEncoderInput = ViTEncoderInput(
|
97 |
+
data_norm_type=encoder.data_norm_type,
|
98 |
+
image=img1,
|
99 |
+
)
|
100 |
+
|
101 |
+
encoder_output0 = self.inference_encoder(encoder, encoder_input0)
|
102 |
+
encoder_output0 = encoder_output0.features
|
103 |
+
|
104 |
+
encoder_output1 = self.inference_encoder(encoder, encoder_input1)
|
105 |
+
encoder_output1 = encoder_output1.features
|
106 |
+
|
107 |
+
# get a common PCA codec
|
108 |
+
cat_feats = torch.cat([encoder_output0, encoder_output1], dim=3)
|
109 |
+
|
110 |
+
pca_viz = get_pca_map(cat_feats.permute(0, 2, 3, 1), (image_size[0], image_size[1] * 2), return_pca_stats=True)
|
111 |
+
|
112 |
+
# concatenate the input images along the width dimension
|
113 |
+
cat_imgs = torch.cat([viz_img0, viz_img1], dim=1)
|
114 |
+
|
115 |
+
# plot the input image and the PCA features
|
116 |
+
fig, axs = plt.subplots(1, 2, figsize=(12, 6))
|
117 |
+
axs[0].imshow(cat_imgs)
|
118 |
+
axs[0].set_title("Input Images")
|
119 |
+
axs[0].axis("off")
|
120 |
+
axs[1].imshow(pca_viz[0])
|
121 |
+
axs[1].set_title(f"PCA Features of {encoder.name}")
|
122 |
+
axs[1].axis("off")
|
123 |
+
plt.savefig(f"{self.pca_save_folder}/multi_pca_{encoder.name}.png", bbox_inches="tight")
|
124 |
+
plt.close()
|
125 |
+
|
126 |
+
@lru_cache(maxsize=3)
|
127 |
+
def _get_example_input(
|
128 |
+
self,
|
129 |
+
image_size: Tuple[int, int],
|
130 |
+
image_norm_type: str = "dummy",
|
131 |
+
img_selection: int = 1,
|
132 |
+
return_viz_img: bool = False,
|
133 |
+
) -> torch.Tensor:
|
134 |
+
url = f"https://raw.githubusercontent.com/naver/croco/d3d0ab2858d44bcad54e5bfc24f565983fbe18d9/assets/Chateau{img_selection}.png"
|
135 |
+
image = Image.open(requests.get(url, stream=True).raw)
|
136 |
+
image = image.resize(image_size)
|
137 |
+
image = image.convert("RGB")
|
138 |
+
|
139 |
+
img = torch.from_numpy(np.array(image))
|
140 |
+
viz_img = img.clone()
|
141 |
+
|
142 |
+
# Normalize the images
|
143 |
+
image_normalization = IMAGE_NORMALIZATION_DICT[image_norm_type]
|
144 |
+
|
145 |
+
img_mean, img_std = image_normalization.mean, image_normalization.std
|
146 |
+
|
147 |
+
img = (img.float() / 255.0 - img_mean) / img_std
|
148 |
+
|
149 |
+
# convert to BCHW format
|
150 |
+
img = img.permute(2, 0, 1).unsqueeze(0)
|
151 |
+
|
152 |
+
if return_viz_img:
|
153 |
+
return img, viz_img
|
154 |
+
else:
|
155 |
+
return img
|
156 |
+
|
157 |
+
|
158 |
+
def render_pca_as_rgb(features):
|
159 |
+
"""
|
160 |
+
Perform PCA on the given feature tensor and render the first 3 principal components as RGB.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
features (torch.Tensor): Feature tensor of shape (B, C, H, W).
|
164 |
+
|
165 |
+
Returns:
|
166 |
+
np.ndarray: RGB image of shape (H, W, 3).
|
167 |
+
"""
|
168 |
+
# Ensure input is a 4D tensor
|
169 |
+
assert features.dim() == 4, "Input tensor must be 4D (B, C, H, W)"
|
170 |
+
|
171 |
+
B, C, H, W = features.shape
|
172 |
+
|
173 |
+
# Reshape the tensor to (B * H * W, C)
|
174 |
+
reshaped_features = features.permute(0, 2, 3, 1).contiguous().view(-1, C).cpu().numpy()
|
175 |
+
|
176 |
+
# Perform PCA
|
177 |
+
pca = PCA(n_components=3)
|
178 |
+
principal_components = pca.fit_transform(reshaped_features)
|
179 |
+
|
180 |
+
# Rescale the principal components to [0, 1]
|
181 |
+
principal_components = (principal_components - principal_components.min(axis=0)) / (
|
182 |
+
principal_components.max(axis=0) - principal_components.min(axis=0)
|
183 |
+
)
|
184 |
+
|
185 |
+
# Reshape the principal components to (B, H, W, 3)
|
186 |
+
principal_components = principal_components.reshape(B, H, W, 3)
|
187 |
+
|
188 |
+
# Convert the principal components to RGB image (take the first batch)
|
189 |
+
rgb_image = principal_components[0]
|
190 |
+
|
191 |
+
return rgb_image
|
192 |
+
|
193 |
+
|
194 |
+
def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False):
|
195 |
+
# features: (N, C)
|
196 |
+
# m: a hyperparam controlling how many std dev outside for outliers
|
197 |
+
assert len(features.shape) == 2, "features should be (N, C)"
|
198 |
+
reduction_mat = torch.pca_lowrank(features, q=3, niter=20)[2]
|
199 |
+
colors = features @ reduction_mat
|
200 |
+
if remove_first_component:
|
201 |
+
colors_min = colors.min(dim=0).values
|
202 |
+
colors_max = colors.max(dim=0).values
|
203 |
+
tmp_colors = (colors - colors_min) / (colors_max - colors_min)
|
204 |
+
fg_mask = tmp_colors[..., 0] < 0.2
|
205 |
+
reduction_mat = torch.pca_lowrank(features[fg_mask], q=3, niter=20)[2]
|
206 |
+
colors = features @ reduction_mat
|
207 |
+
else:
|
208 |
+
fg_mask = torch.ones_like(colors[:, 0]).bool()
|
209 |
+
d = torch.abs(colors[fg_mask] - torch.median(colors[fg_mask], dim=0).values)
|
210 |
+
mdev = torch.median(d, dim=0).values
|
211 |
+
s = d / mdev
|
212 |
+
try:
|
213 |
+
rins = colors[fg_mask][s[:, 0] < m, 0]
|
214 |
+
gins = colors[fg_mask][s[:, 1] < m, 1]
|
215 |
+
bins = colors[fg_mask][s[:, 2] < m, 2]
|
216 |
+
rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])
|
217 |
+
rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])
|
218 |
+
except:
|
219 |
+
rins = colors
|
220 |
+
gins = colors
|
221 |
+
bins = colors
|
222 |
+
rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])
|
223 |
+
rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])
|
224 |
+
|
225 |
+
return reduction_mat, rgb_min.to(reduction_mat), rgb_max.to(reduction_mat)
|
226 |
+
|
227 |
+
|
228 |
+
def get_pca_map(
|
229 |
+
feature_map: torch.Tensor,
|
230 |
+
img_size,
|
231 |
+
interpolation="bicubic",
|
232 |
+
return_pca_stats=False,
|
233 |
+
pca_stats=None,
|
234 |
+
):
|
235 |
+
"""
|
236 |
+
feature_map: (1, h, w, C) is the feature map of a single image.
|
237 |
+
"""
|
238 |
+
if feature_map.shape[0] != 1:
|
239 |
+
# make it (1, h, w, C)
|
240 |
+
feature_map = feature_map[None]
|
241 |
+
if pca_stats is None:
|
242 |
+
reduct_mat, color_min, color_max = get_robust_pca(feature_map.reshape(-1, feature_map.shape[-1]))
|
243 |
+
else:
|
244 |
+
reduct_mat, color_min, color_max = pca_stats
|
245 |
+
pca_color = feature_map @ reduct_mat
|
246 |
+
pca_color = (pca_color - color_min) / (color_max - color_min)
|
247 |
+
pca_color = pca_color.clamp(0, 1)
|
248 |
+
pca_color = F.interpolate(
|
249 |
+
pca_color.permute(0, 3, 1, 2),
|
250 |
+
size=img_size,
|
251 |
+
mode=interpolation,
|
252 |
+
).permute(0, 2, 3, 1)
|
253 |
+
pca_color = pca_color.detach().cpu().numpy().squeeze(0)
|
254 |
+
if return_pca_stats:
|
255 |
+
return pca_color, (reduct_mat, color_min, color_max)
|
256 |
+
return pca_color
|
257 |
+
|
258 |
+
|
259 |
+
def seed_everything(seed=42):
|
260 |
+
"""
|
261 |
+
Set the `seed` value for torch and numpy seeds. Also turns on
|
262 |
+
deterministic execution for cudnn.
|
263 |
+
|
264 |
+
Parameters:
|
265 |
+
- seed: A hashable seed value
|
266 |
+
"""
|
267 |
+
random.seed(seed)
|
268 |
+
os.environ["PYTHONHASHSEED"] = str(seed)
|
269 |
+
np.random.seed(seed)
|
270 |
+
torch.manual_seed(seed)
|
271 |
+
torch.backends.cudnn.deterministic = True
|
272 |
+
torch.backends.cudnn.benchmark = False
|
273 |
+
print(f"Seed set to: {seed} (type: {type(seed)})")
|
274 |
+
|
275 |
+
|
276 |
+
if __name__ == "__main__":
|
277 |
+
# Turn XFormers off for testing on CPU
|
278 |
+
os.environ["XFORMERS_DISABLED"] = "1"
|
279 |
+
|
280 |
+
# Seed everything for consistent testing
|
281 |
+
seed_everything()
|
282 |
+
|
283 |
+
# Create local directory for storing the PCA images
|
284 |
+
current_file_path = os.path.abspath(__file__)
|
285 |
+
relative_pca_image_folder = os.path.join(os.path.dirname(current_file_path), "../../../local/encoders/pca_images")
|
286 |
+
os.makedirs(relative_pca_image_folder, exist_ok=True)
|
287 |
+
|
288 |
+
# Initialize the test class
|
289 |
+
test = TestEncoders(pca_save_folder=relative_pca_image_folder)
|
290 |
+
|
291 |
+
# Visualize the PCA of all encoders
|
292 |
+
test.visualize_all_encoders()
|
293 |
+
|
294 |
+
print(f"The PCA visualizations of all encoders are saved successfully to {relative_pca_image_folder}!")
|
UniCeption/tests/models/info_sharing/viz_mulit_view_cross_attn_transformers.py
ADDED
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
PCA Visualization of UniCeption Image Encoders + Multi-View Cross Attention Transformers
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
from functools import lru_cache
|
8 |
+
from typing import Tuple
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import requests
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from matplotlib import pyplot as plt
|
15 |
+
from PIL import Image
|
16 |
+
from sklearn.decomposition import PCA
|
17 |
+
|
18 |
+
from uniception.models.encoders import *
|
19 |
+
from uniception.models.encoders.image_normalizations import *
|
20 |
+
from uniception.models.info_sharing.base import MultiViewTransformerInput
|
21 |
+
from uniception.models.info_sharing.cross_attention_transformer import MultiViewCrossAttentionTransformerIFR
|
22 |
+
from uniception.models.libs.croco.pos_embed import RoPE2D, get_2d_sincos_pos_embed
|
23 |
+
|
24 |
+
|
25 |
+
def _make_mv_cross_attention_transformer_test(model_str: str, **kwargs):
|
26 |
+
current_file_path = os.path.abspath(__file__)
|
27 |
+
relative_checkpoint_path = os.path.join(
|
28 |
+
os.path.dirname(current_file_path), "../../../checkpoints/info_sharing/cross_attn_transformer"
|
29 |
+
)
|
30 |
+
rope = RoPE2D(float(100))
|
31 |
+
if model_str == "croco":
|
32 |
+
return MultiViewCrossAttentionTransformerIFR(
|
33 |
+
name="croco_base_decoder",
|
34 |
+
input_embed_dim=1024,
|
35 |
+
num_views=2,
|
36 |
+
indices=[12 * 2 // 4, 12 * 3 // 4],
|
37 |
+
norm_intermediate=False,
|
38 |
+
custom_positional_encoding=rope,
|
39 |
+
pretrained_checkpoint_path=f"{relative_checkpoint_path}/Two_View_Cross_Attention_Transformer_CroCo.pth",
|
40 |
+
**kwargs,
|
41 |
+
)
|
42 |
+
elif model_str == "dust3r_224":
|
43 |
+
return MultiViewCrossAttentionTransformerIFR(
|
44 |
+
name="dust3r_224_base_decoder",
|
45 |
+
input_embed_dim=1024,
|
46 |
+
num_views=2,
|
47 |
+
indices=[12 * 2 // 4, 12 * 3 // 4],
|
48 |
+
norm_intermediate=False,
|
49 |
+
custom_positional_encoding=rope,
|
50 |
+
pretrained_checkpoint_path=f"{relative_checkpoint_path}/Two_View_Cross_Attention_Transformer_DUSt3R_224_linear.pth",
|
51 |
+
**kwargs,
|
52 |
+
)
|
53 |
+
elif model_str == "dust3r_512":
|
54 |
+
return MultiViewCrossAttentionTransformerIFR(
|
55 |
+
name="dust3r_512_base_decoder",
|
56 |
+
input_embed_dim=1024,
|
57 |
+
num_views=2,
|
58 |
+
indices=[12 * 2 // 4, 12 * 3 // 4],
|
59 |
+
norm_intermediate=False,
|
60 |
+
custom_positional_encoding=rope,
|
61 |
+
pretrained_checkpoint_path=f"{relative_checkpoint_path}/Two_View_Cross_Attention_Transformer_DUSt3R_512_linear.pth",
|
62 |
+
**kwargs,
|
63 |
+
)
|
64 |
+
elif model_str == "dust3r_512_dpt":
|
65 |
+
return MultiViewCrossAttentionTransformerIFR(
|
66 |
+
name="dust3r_512_dpt_base_decoder",
|
67 |
+
input_embed_dim=1024,
|
68 |
+
num_views=2,
|
69 |
+
indices=[12 * 2 // 4, 12 * 3 // 4],
|
70 |
+
norm_intermediate=False,
|
71 |
+
custom_positional_encoding=rope,
|
72 |
+
pretrained_checkpoint_path=f"{relative_checkpoint_path}/Two_View_Cross_Attention_Transformer_DUSt3R_512_dpt.pth",
|
73 |
+
**kwargs,
|
74 |
+
)
|
75 |
+
elif model_str == "mast3r_512":
|
76 |
+
return MultiViewCrossAttentionTransformerIFR(
|
77 |
+
name="mast3r_512_base_decoder",
|
78 |
+
input_embed_dim=1024,
|
79 |
+
num_views=2,
|
80 |
+
indices=[12 * 2 // 4, 12 * 3 // 4],
|
81 |
+
norm_intermediate=False,
|
82 |
+
custom_positional_encoding=rope,
|
83 |
+
pretrained_checkpoint_path=f"{relative_checkpoint_path}/Two_View_Cross_Attention_Transformer_MASt3R.pth",
|
84 |
+
**kwargs,
|
85 |
+
)
|
86 |
+
|
87 |
+
|
88 |
+
class TestMultiViewTransformers:
|
89 |
+
def __init__(self, pca_save_folder, *args, **kwargs):
|
90 |
+
super(TestMultiViewTransformers, self).__init__(*args, **kwargs)
|
91 |
+
|
92 |
+
self.pca_save_folder = pca_save_folder
|
93 |
+
|
94 |
+
self.norm_types = IMAGE_NORMALIZATION_DICT.keys()
|
95 |
+
|
96 |
+
self.models = [
|
97 |
+
"croco",
|
98 |
+
"dust3r_224",
|
99 |
+
"dust3r_512",
|
100 |
+
"dust3r_512_dpt",
|
101 |
+
"mast3r_512",
|
102 |
+
]
|
103 |
+
|
104 |
+
self.model_configs = [{}] * len(self.models)
|
105 |
+
|
106 |
+
def inference_encoder(self, encoder, input):
|
107 |
+
return encoder(input)
|
108 |
+
|
109 |
+
def inference_info_sharing(self, info_sharing, input):
|
110 |
+
return info_sharing(input)
|
111 |
+
|
112 |
+
def visualize_all_models(self):
|
113 |
+
for model, model_config in zip(self.models, self.model_configs):
|
114 |
+
encoder = _make_encoder_test(model, **model_config)
|
115 |
+
info_sharing = _make_mv_cross_attention_transformer_test(model, **model_config)
|
116 |
+
self._visualize_model_features_consistency(encoder, info_sharing, (224, 224))
|
117 |
+
|
118 |
+
def _visualize_model_features_consistency(self, encoder, info_sharing, image_size: Tuple[int, int]):
|
119 |
+
img0, viz_img0 = self._get_example_input(
|
120 |
+
image_size, encoder.data_norm_type, img_selection=1, return_viz_img=True
|
121 |
+
)
|
122 |
+
img1, viz_img1 = self._get_example_input(
|
123 |
+
image_size, encoder.data_norm_type, img_selection=2, return_viz_img=True
|
124 |
+
)
|
125 |
+
# input and output of the encoder
|
126 |
+
encoder_input0: ViTEncoderInput = ViTEncoderInput(
|
127 |
+
data_norm_type=encoder.data_norm_type,
|
128 |
+
image=img0,
|
129 |
+
)
|
130 |
+
|
131 |
+
encoder_input1: ViTEncoderInput = ViTEncoderInput(
|
132 |
+
data_norm_type=encoder.data_norm_type,
|
133 |
+
image=img1,
|
134 |
+
)
|
135 |
+
|
136 |
+
encoder_output0 = self.inference_encoder(encoder, encoder_input0)
|
137 |
+
encoder_output0 = encoder_output0.features
|
138 |
+
|
139 |
+
encoder_output1 = self.inference_encoder(encoder, encoder_input1)
|
140 |
+
encoder_output1 = encoder_output1.features
|
141 |
+
|
142 |
+
# pass the encoder outputs to the info sharing model
|
143 |
+
multi_view_features = [encoder_output0, encoder_output1]
|
144 |
+
info_sharing_input = MultiViewTransformerInput(features=multi_view_features)
|
145 |
+
info_sharing_output = self.inference_info_sharing(info_sharing, info_sharing_input)
|
146 |
+
final_layer_multi_view_features = info_sharing_output[0].features
|
147 |
+
|
148 |
+
# get a common PCA codec
|
149 |
+
cat_feats = torch.cat(final_layer_multi_view_features, dim=3)
|
150 |
+
|
151 |
+
pca_viz = get_pca_map(cat_feats.permute(0, 2, 3, 1), (image_size[0], image_size[1] * 2), return_pca_stats=True)
|
152 |
+
|
153 |
+
# concatenate the input images along the width dimension
|
154 |
+
cat_imgs = torch.cat([viz_img0, viz_img1], dim=1)
|
155 |
+
|
156 |
+
# plot the input image and the PCA features
|
157 |
+
fig, axs = plt.subplots(1, 2, figsize=(12, 6))
|
158 |
+
axs[0].imshow(cat_imgs)
|
159 |
+
axs[0].set_title("Input Images")
|
160 |
+
axs[0].axis("off")
|
161 |
+
axs[1].imshow(pca_viz[0])
|
162 |
+
axs[1].set_title(f"PCA Features of {encoder.name} + Base Decoder")
|
163 |
+
axs[1].axis("off")
|
164 |
+
plt.savefig(f"{self.pca_save_folder}/multi_pca_{encoder.name}.png", bbox_inches="tight")
|
165 |
+
plt.close()
|
166 |
+
|
167 |
+
@lru_cache(maxsize=3)
|
168 |
+
def _get_example_input(
|
169 |
+
self,
|
170 |
+
image_size: Tuple[int, int],
|
171 |
+
image_norm_type: str = "dummy",
|
172 |
+
img_selection: int = 1,
|
173 |
+
return_viz_img: bool = False,
|
174 |
+
) -> torch.Tensor:
|
175 |
+
url = f"https://raw.githubusercontent.com/naver/croco/d3d0ab2858d44bcad54e5bfc24f565983fbe18d9/assets/Chateau{img_selection}.png"
|
176 |
+
image = Image.open(requests.get(url, stream=True).raw)
|
177 |
+
image = image.resize(image_size)
|
178 |
+
image = image.convert("RGB")
|
179 |
+
|
180 |
+
img = torch.from_numpy(np.array(image))
|
181 |
+
viz_img = img.clone()
|
182 |
+
|
183 |
+
# Normalize the images
|
184 |
+
image_normalization = IMAGE_NORMALIZATION_DICT[image_norm_type]
|
185 |
+
|
186 |
+
img_mean, img_std = image_normalization.mean, image_normalization.std
|
187 |
+
|
188 |
+
img = (img.float() / 255.0 - img_mean) / img_std
|
189 |
+
|
190 |
+
# convert to BCHW format
|
191 |
+
img = img.permute(2, 0, 1).unsqueeze(0)
|
192 |
+
|
193 |
+
if return_viz_img:
|
194 |
+
return img, viz_img
|
195 |
+
else:
|
196 |
+
return img
|
197 |
+
|
198 |
+
|
199 |
+
def render_pca_as_rgb(features):
|
200 |
+
"""
|
201 |
+
Perform PCA on the given feature tensor and render the first 3 principal components as RGB.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
features (torch.Tensor): Feature tensor of shape (B, C, H, W).
|
205 |
+
|
206 |
+
Returns:
|
207 |
+
np.ndarray: RGB image of shape (H, W, 3).
|
208 |
+
"""
|
209 |
+
# Ensure input is a 4D tensor
|
210 |
+
assert features.dim() == 4, "Input tensor must be 4D (B, C, H, W)"
|
211 |
+
|
212 |
+
B, C, H, W = features.shape
|
213 |
+
|
214 |
+
# Reshape the tensor to (B * H * W, C)
|
215 |
+
reshaped_features = features.permute(0, 2, 3, 1).contiguous().view(-1, C).cpu().numpy()
|
216 |
+
|
217 |
+
# Perform PCA
|
218 |
+
pca = PCA(n_components=3)
|
219 |
+
principal_components = pca.fit_transform(reshaped_features)
|
220 |
+
|
221 |
+
# Rescale the principal components to [0, 1]
|
222 |
+
principal_components = (principal_components - principal_components.min(axis=0)) / (
|
223 |
+
principal_components.max(axis=0) - principal_components.min(axis=0)
|
224 |
+
)
|
225 |
+
|
226 |
+
# Reshape the principal components to (B, H, W, 3)
|
227 |
+
principal_components = principal_components.reshape(B, H, W, 3)
|
228 |
+
|
229 |
+
# Convert the principal components to RGB image (take the first batch)
|
230 |
+
rgb_image = principal_components[0]
|
231 |
+
|
232 |
+
return rgb_image
|
233 |
+
|
234 |
+
|
235 |
+
def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False):
|
236 |
+
# features: (N, C)
|
237 |
+
# m: a hyperparam controlling how many std dev outside for outliers
|
238 |
+
assert len(features.shape) == 2, "features should be (N, C)"
|
239 |
+
reduction_mat = torch.pca_lowrank(features, q=3, niter=20)[2]
|
240 |
+
colors = features @ reduction_mat
|
241 |
+
if remove_first_component:
|
242 |
+
colors_min = colors.min(dim=0).values
|
243 |
+
colors_max = colors.max(dim=0).values
|
244 |
+
tmp_colors = (colors - colors_min) / (colors_max - colors_min)
|
245 |
+
fg_mask = tmp_colors[..., 0] < 0.2
|
246 |
+
reduction_mat = torch.pca_lowrank(features[fg_mask], q=3, niter=20)[2]
|
247 |
+
colors = features @ reduction_mat
|
248 |
+
else:
|
249 |
+
fg_mask = torch.ones_like(colors[:, 0]).bool()
|
250 |
+
d = torch.abs(colors[fg_mask] - torch.median(colors[fg_mask], dim=0).values)
|
251 |
+
mdev = torch.median(d, dim=0).values
|
252 |
+
s = d / mdev
|
253 |
+
try:
|
254 |
+
rins = colors[fg_mask][s[:, 0] < m, 0]
|
255 |
+
gins = colors[fg_mask][s[:, 1] < m, 1]
|
256 |
+
bins = colors[fg_mask][s[:, 2] < m, 2]
|
257 |
+
rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])
|
258 |
+
rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])
|
259 |
+
except:
|
260 |
+
rins = colors
|
261 |
+
gins = colors
|
262 |
+
bins = colors
|
263 |
+
rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])
|
264 |
+
rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])
|
265 |
+
|
266 |
+
return reduction_mat, rgb_min.to(reduction_mat), rgb_max.to(reduction_mat)
|
267 |
+
|
268 |
+
|
269 |
+
def get_pca_map(
|
270 |
+
feature_map: torch.Tensor,
|
271 |
+
img_size,
|
272 |
+
interpolation="bicubic",
|
273 |
+
return_pca_stats=False,
|
274 |
+
pca_stats=None,
|
275 |
+
):
|
276 |
+
"""
|
277 |
+
feature_map: (1, h, w, C) is the feature map of a single image.
|
278 |
+
"""
|
279 |
+
if feature_map.shape[0] != 1:
|
280 |
+
# make it (1, h, w, C)
|
281 |
+
feature_map = feature_map[None]
|
282 |
+
if pca_stats is None:
|
283 |
+
reduct_mat, color_min, color_max = get_robust_pca(feature_map.reshape(-1, feature_map.shape[-1]))
|
284 |
+
else:
|
285 |
+
reduct_mat, color_min, color_max = pca_stats
|
286 |
+
pca_color = feature_map @ reduct_mat
|
287 |
+
pca_color = (pca_color - color_min) / (color_max - color_min)
|
288 |
+
pca_color = pca_color.clamp(0, 1)
|
289 |
+
pca_color = F.interpolate(
|
290 |
+
pca_color.permute(0, 3, 1, 2),
|
291 |
+
size=img_size,
|
292 |
+
mode=interpolation,
|
293 |
+
).permute(0, 2, 3, 1)
|
294 |
+
pca_color = pca_color.detach().cpu().numpy().squeeze(0)
|
295 |
+
if return_pca_stats:
|
296 |
+
return pca_color, (reduct_mat, color_min, color_max)
|
297 |
+
return pca_color
|
298 |
+
|
299 |
+
|
300 |
+
def seed_everything(seed=42):
|
301 |
+
"""
|
302 |
+
Set the `seed` value for torch and numpy seeds. Also turns on
|
303 |
+
deterministic execution for cudnn.
|
304 |
+
|
305 |
+
Parameters:
|
306 |
+
- seed: A hashable seed value
|
307 |
+
"""
|
308 |
+
random.seed(seed)
|
309 |
+
os.environ["PYTHONHASHSEED"] = str(seed)
|
310 |
+
np.random.seed(seed)
|
311 |
+
torch.manual_seed(seed)
|
312 |
+
torch.backends.cudnn.deterministic = True
|
313 |
+
torch.backends.cudnn.benchmark = False
|
314 |
+
print(f"Seed set to: {seed} (type: {type(seed)})")
|
315 |
+
|
316 |
+
|
317 |
+
if __name__ == "__main__":
|
318 |
+
# Turn XFormers off for testing on CPU
|
319 |
+
os.environ["XFORMERS_DISABLED"] = "1"
|
320 |
+
|
321 |
+
# Seed everything for consistent testing
|
322 |
+
seed_everything()
|
323 |
+
|
324 |
+
# Create local directory for storing the PCA images
|
325 |
+
current_file_path = os.path.abspath(__file__)
|
326 |
+
relative_pca_image_folder = os.path.join(
|
327 |
+
os.path.dirname(current_file_path), "../../../local/info_sharing/pca_images"
|
328 |
+
)
|
329 |
+
os.makedirs(relative_pca_image_folder, exist_ok=True)
|
330 |
+
|
331 |
+
# Initialize the test class
|
332 |
+
test = TestMultiViewTransformers(pca_save_folder=relative_pca_image_folder)
|
333 |
+
|
334 |
+
# Visualize the PCA of all models
|
335 |
+
test.visualize_all_models()
|
336 |
+
|
337 |
+
print(f"The PCA visualizations of all models are saved successfully to {relative_pca_image_folder}!")
|
UniCeption/uniception/__init__.py
ADDED
File without changes
|
UniCeption/uniception/models/encoders/README.md
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# UniCeption Encoders
|
2 |
+
|
3 |
+
## Currently Supported Encoders
|
4 |
+
|
5 |
+
### UniCeptionViTEncoderBase:
|
6 |
+
|
7 |
+
- `CroCoEncoder`
|
8 |
+
- `CroCoIntermediateFeatureReturner`
|
9 |
+
- `DINOv2Encoder`
|
10 |
+
- `DINOv2IntermediateFeatureReturner`
|
11 |
+
- `PatchEmbedder`
|
12 |
+
- `RADIOEncoder`
|
13 |
+
- `RADIOIntermediateFeatureReturner`
|
14 |
+
|
15 |
+
# Developer Guidelines for UniCeption Encoders
|
16 |
+
|
17 |
+
## Overview
|
18 |
+
|
19 |
+
This folder contains the implementation of various UniCeption encoders. Each encoder must adhere to a specific structure and follow certain guidelines to ensure consistency and compatibility across different projects.
|
20 |
+
|
21 |
+
## Directory Structure
|
22 |
+
|
23 |
+
The encoders and other necessary dependencies/tests for encoders are organized as follows:
|
24 |
+
```
|
25 |
+
uniception/
|
26 |
+
├── models/
|
27 |
+
│ ├── encoders/
|
28 |
+
│ │ ├── __init__.py
|
29 |
+
│ │ ├── base.py
|
30 |
+
│ │ ├── croco.py
|
31 |
+
│ │ ├── dinov2.py
|
32 |
+
│ │ ├── radio.py
|
33 |
+
│ │ ├── image_normalizations.py
|
34 |
+
│ └── ...
|
35 |
+
│ └── libs/
|
36 |
+
│ │ ├── external_dependency_folders/
|
37 |
+
| | | ├── external_dependency_files
|
38 |
+
tests/
|
39 |
+
├── models/
|
40 |
+
│ ├── encoders/
|
41 |
+
│ │ ├── test_encoders.py
|
42 |
+
│ │ ├── viz_image_encoders.py
|
43 |
+
│ │ └── ...
|
44 |
+
| └── ...
|
45 |
+
└── ...
|
46 |
+
```
|
47 |
+
|
48 |
+
## Adding a New Encoder
|
49 |
+
|
50 |
+
To add a new encoder, follow these steps:
|
51 |
+
|
52 |
+
1. **Create a New Encoder File**:
|
53 |
+
- Create a new file in the `encoders` directory, e.g., `new_encoder.py`.
|
54 |
+
- Define the new encoder class in this file, inheriting from `UniCeptionEncoderBase` or `UniCeptionViTEncoderBase`.
|
55 |
+
- Please look at the base class for the necessary attributes and methods to implement.
|
56 |
+
|
57 |
+
2. **Define Input Data Normalization**:
|
58 |
+
- Add the corresponding normalization for the encoder to respective normalization files, for example, image normalizations should be added to `image_normalizations.py`.
|
59 |
+
- Ensure the normalization is added to the dictionaries present in the files, for example, `IMAGE_NORMALIZATION_DICT`.
|
60 |
+
|
61 |
+
4. **Implement the Encoder Class**:
|
62 |
+
- Inherit from `UniCeptionEncoderBase` or `UniCeptionViTEncoderBase` or other UniCeption base classes.
|
63 |
+
- Implement the `forward` method.
|
64 |
+
- Ensure the encoder class has the necessary attributes and methods.
|
65 |
+
|
66 |
+
4. **Update `__init__.py`**:
|
67 |
+
- Import the new encoder class in `__init__.py`.
|
68 |
+
- Add the new encoder to the encoder configuration dictionary `ENCODER_CONFIGS` so that it can be instantiated via the encoder factory.
|
69 |
+
- Update the `_make_encoder_test` function to include the new encoder.
|
70 |
+
|
71 |
+
5. **Run Encoder Unit Tests**:
|
72 |
+
- Run `pytest -vs tests/models/encoders/test_encoders.py --encoder-name="<new_encoder>"` to test the basic expected functionality of UniCeption encoders.
|
73 |
+
- Also, add your new encoder to the list in the encoders() in `tests/models/encoders/test_encoders.py` so that it can be tested along with all the existing encoders.
|
74 |
+
- Optionally, for image encoders, the unit tests in `tests/models/encoders/viz_image_encoders.py` save PCA visualizations of the encoder outputs to the `local/pca_images` directory.
|
75 |
+
|
76 |
+
## Example Encoder Implementation
|
77 |
+
|
78 |
+
Here is an example of how to implement a new encoder:
|
79 |
+
|
80 |
+
```python
|
81 |
+
# new_encoder.py
|
82 |
+
import torch
|
83 |
+
from uniception.models.encoders.base import UniCeptionEncoderBase, EncoderInput, EncoderOutput
|
84 |
+
|
85 |
+
class NewEncoder(UniCeptionEncoderBase):
|
86 |
+
def __init__(self, name: str, data_norm_type: str, *args, **kwargs):
|
87 |
+
super().__init__(name=name, data_norm_type=data_norm_type, *args, **kwargs)
|
88 |
+
# Initialize encoder-specific layers and parameters here
|
89 |
+
|
90 |
+
def forward(self, encoder_input: EncoderInput) -> EncoderOutput:
|
91 |
+
self._check_data_normalization_type(encoder_input.data_norm_type)
|
92 |
+
# Implement the forward pass
|
93 |
+
return EncoderOutput()
|
94 |
+
```
|
95 |
+
|
96 |
+
## Example Normalization
|
97 |
+
|
98 |
+
Add the normalization for the new encoder, for example, to `image_normalizations.py`:
|
99 |
+
|
100 |
+
```python
|
101 |
+
# image_normalizations.py
|
102 |
+
IMAGE_NORMALIZATION_DICT = {
|
103 |
+
"dummy": ImageNormalization(mean=torch.tensor([0.0, 0.0, 0.0]), std=torch.tensor([1.0, 1.0, 1.0])),
|
104 |
+
"croco": ImageNormalization(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])),
|
105 |
+
"dust3r": ImageNormalization(mean=torch.tensor([0.5, 0.5, 0.5]), std=torch.tensor([0.5, 0.5, 0.5])),
|
106 |
+
"dinov2": ImageNormalization(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])),
|
107 |
+
"radio": ImageNormalization(mean=torch.tensor([0.0, 0.0, 0.0]), std=torch.tensor([1.0, 1.0, 1.0])),
|
108 |
+
"new_encoder": ImageNormalization(mean=torch.tensor([0.5, 0.5, 0.5]), std=torch.tensor([0.2, 0.2, 0.2])),
|
109 |
+
}
|
110 |
+
```
|
111 |
+
|
112 |
+
## Example Unit Testing
|
113 |
+
|
114 |
+
Add the new encoder to the encoder factory in `__init__.py` and the encoder list in `tests/models/encoders/test_encoders.py`. Additional tests can also be added as required.
|
115 |
+
|
116 |
+
Look at `tests/models/encoders/test_encoders.py` to see what tests are run.
|
117 |
+
|
118 |
+
Additionally, if the new encoder is an image encoder, you can add to the encoder list in `tests/models/encoders/viz_image_encoders.py` for saving PCA visualizations of the encoder outputs to the `local/pca_images` directory.
|
119 |
+
|
120 |
+
## Developer Guidelines
|
121 |
+
|
122 |
+
Please follow these guidelines when contributing to the UniCeption encoders:
|
123 |
+
- **Consistency**: Ensure that the new encoder follows the structure and naming conventions of existing encoders.
|
124 |
+
- **Code Style**: Follow the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html) for code style.
|
125 |
+
- **Documentation**: Add docstrings to all classes and methods.
|
126 |
+
- **Unit Tests**: Add necessary unit tests for the encoder class.
|
127 |
+
- **Linting**: Run `black` on your code before committing. For example, you can run `black uniception`.
|
128 |
+
|
129 |
+
## Happy Coding!
|
UniCeption/uniception/models/encoders/__init__.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Encoder Factory for UniCeption
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
|
7 |
+
from uniception.models.encoders.base import (
|
8 |
+
EncoderGlobalRepInput,
|
9 |
+
EncoderInput,
|
10 |
+
UniCeptionEncoderBase,
|
11 |
+
UniCeptionViTEncoderBase,
|
12 |
+
ViTEncoderInput,
|
13 |
+
ViTEncoderNonImageInput,
|
14 |
+
ViTEncoderOutput,
|
15 |
+
)
|
16 |
+
from uniception.models.encoders.cosmos import CosmosEncoder
|
17 |
+
from uniception.models.encoders.croco import CroCoEncoder, CroCoIntermediateFeatureReturner
|
18 |
+
from uniception.models.encoders.dense_rep_encoder import DenseRepresentationEncoder
|
19 |
+
from uniception.models.encoders.dinov2 import DINOv2Encoder, DINOv2IntermediateFeatureReturner
|
20 |
+
from uniception.models.encoders.global_rep_encoder import GlobalRepresentationEncoder
|
21 |
+
from uniception.models.encoders.naradio import NARADIOEncoder
|
22 |
+
from uniception.models.encoders.patch_embedder import PatchEmbedder
|
23 |
+
from uniception.models.encoders.radio import RADIOEncoder, RADIOIntermediateFeatureReturner
|
24 |
+
|
25 |
+
# Define encoder configurations
|
26 |
+
ENCODER_CONFIGS = {
|
27 |
+
"croco": {
|
28 |
+
"class": CroCoEncoder,
|
29 |
+
"intermediate_feature_returner_class": CroCoIntermediateFeatureReturner,
|
30 |
+
"supported_models": ["CroCov2", "DUSt3R", "MASt3R"],
|
31 |
+
},
|
32 |
+
"dense_rep_encoder": {
|
33 |
+
"class": DenseRepresentationEncoder,
|
34 |
+
"supported_models": ["Dense-Representation-Encoder"],
|
35 |
+
},
|
36 |
+
"dinov2": {
|
37 |
+
"class": DINOv2Encoder,
|
38 |
+
"intermediate_feature_returner_class": DINOv2IntermediateFeatureReturner,
|
39 |
+
"supported_models": ["DINOv2", "DINOv2-Registers", "DINOv2-Depth-Anythingv2"],
|
40 |
+
},
|
41 |
+
"global_rep_encoder": {
|
42 |
+
"class": GlobalRepresentationEncoder,
|
43 |
+
"supported_models": ["Global-Representation-Encoder"],
|
44 |
+
},
|
45 |
+
"patch_embedder": {
|
46 |
+
"class": PatchEmbedder,
|
47 |
+
"supported_models": ["Patch-Embedder"],
|
48 |
+
},
|
49 |
+
"radio": {
|
50 |
+
"class": RADIOEncoder,
|
51 |
+
"intermediate_feature_returner_class": RADIOIntermediateFeatureReturner,
|
52 |
+
"supported_models": ["RADIO", "E-RADIO"],
|
53 |
+
},
|
54 |
+
"cosmos": {
|
55 |
+
"class": CosmosEncoder,
|
56 |
+
"supported_models": ["Cosmos-Tokenizer CI8x8", "Cosmos-Tokenizer CI16x16"],
|
57 |
+
},
|
58 |
+
"naradio": {
|
59 |
+
"class": NARADIOEncoder,
|
60 |
+
"supported_models": ["RADIO"],
|
61 |
+
},
|
62 |
+
# Add other encoders here
|
63 |
+
}
|
64 |
+
|
65 |
+
|
66 |
+
def encoder_factory(encoder_str: str, **kwargs) -> UniCeptionEncoderBase:
|
67 |
+
"""
|
68 |
+
Encoder factory for UniCeption.
|
69 |
+
Please use python3 -m uniception.models.encoders.list to see available encoders.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
encoder_str (str): Name of the encoder to create.
|
73 |
+
**kwargs: Additional keyword arguments to pass to the encoder constructor.
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
UniCeptionEncoderBase: An instance of the specified encoder.
|
77 |
+
"""
|
78 |
+
if encoder_str not in ENCODER_CONFIGS:
|
79 |
+
raise ValueError(
|
80 |
+
f"Unknown encoder: {encoder_str}. For valid encoder_str options, please use python3 -m uniception.models.encoders.list"
|
81 |
+
)
|
82 |
+
|
83 |
+
encoder_config = ENCODER_CONFIGS[encoder_str]
|
84 |
+
encoder_class = encoder_config["class"]
|
85 |
+
|
86 |
+
return encoder_class(**kwargs)
|
87 |
+
|
88 |
+
|
89 |
+
def feature_returner_encoder_factory(encoder_str: str, **kwargs) -> UniCeptionEncoderBase:
|
90 |
+
"""
|
91 |
+
Factory for UniCeption Encoders with support for intermediate feature returning.
|
92 |
+
Please use python3 -m uniception.models.encoders.list to see available encoders.
|
93 |
+
|
94 |
+
Args:
|
95 |
+
encoder_str (str): Name of the encoder to create.
|
96 |
+
**kwargs: Additional keyword arguments to pass to the encoder constructor.
|
97 |
+
|
98 |
+
Returns:
|
99 |
+
UniCeptionEncoderBase: An instance of the specified encoder.
|
100 |
+
"""
|
101 |
+
if encoder_str not in ENCODER_CONFIGS:
|
102 |
+
raise ValueError(
|
103 |
+
f"Unknown encoder: {encoder_str}. For valid encoder_str options, please use python3 -m uniception.models.encoders.list"
|
104 |
+
)
|
105 |
+
|
106 |
+
encoder_config = ENCODER_CONFIGS[encoder_str]
|
107 |
+
encoder_class = encoder_config["intermediate_feature_returner_class"]
|
108 |
+
|
109 |
+
return encoder_class(**kwargs)
|
110 |
+
|
111 |
+
|
112 |
+
def get_available_encoders() -> list:
|
113 |
+
"""
|
114 |
+
Get a list of available encoders in UniCeption.
|
115 |
+
|
116 |
+
Returns:
|
117 |
+
list: A list of available encoder names.
|
118 |
+
"""
|
119 |
+
return list(ENCODER_CONFIGS.keys())
|
120 |
+
|
121 |
+
|
122 |
+
def print_available_encoder_models():
|
123 |
+
"""
|
124 |
+
Print the currently supported encoders in UniCeption.
|
125 |
+
"""
|
126 |
+
print("Currently Supported Encoders in UniCeption:\nFormat -> encoder_str: supported_models")
|
127 |
+
for encoder_name, config in ENCODER_CONFIGS.items():
|
128 |
+
print(f"{encoder_name}: {', '.join(config['supported_models'])}")
|
129 |
+
|
130 |
+
|
131 |
+
def _make_encoder_test(encoder_str: str, **kwargs) -> UniCeptionEncoderBase:
|
132 |
+
"Function to create encoders for testing purposes."
|
133 |
+
current_file_path = os.path.abspath(__file__)
|
134 |
+
relative_checkpoint_path = os.path.join(os.path.dirname(current_file_path), "../../../checkpoints/encoders")
|
135 |
+
if encoder_str == "dummy":
|
136 |
+
return UniCeptionEncoderBase(name="dummy", data_norm_type="dummy")
|
137 |
+
elif encoder_str == "croco":
|
138 |
+
return CroCoEncoder(
|
139 |
+
name="croco",
|
140 |
+
data_norm_type="croco",
|
141 |
+
pretrained_checkpoint_path=f"{relative_checkpoint_path}/CroCo_Encoder_224.pth",
|
142 |
+
patch_embed_cls="PatchEmbedCroCo",
|
143 |
+
)
|
144 |
+
elif encoder_str == "dust3r_224":
|
145 |
+
return CroCoEncoder(
|
146 |
+
name="dust3r_224",
|
147 |
+
data_norm_type="dust3r",
|
148 |
+
pretrained_checkpoint_path=f"{relative_checkpoint_path}/CroCo_Encoder_224_DUSt3R_linear.pth",
|
149 |
+
patch_embed_cls="PatchEmbedDust3R",
|
150 |
+
)
|
151 |
+
elif encoder_str == "dust3r_512":
|
152 |
+
return CroCoEncoder(
|
153 |
+
name="dust3r_512",
|
154 |
+
data_norm_type="dust3r",
|
155 |
+
pretrained_checkpoint_path=f"{relative_checkpoint_path}/CroCo_Encoder_512_DUSt3R_linear.pth",
|
156 |
+
patch_embed_cls="ManyAR_PatchEmbed",
|
157 |
+
img_size=(512, 512),
|
158 |
+
)
|
159 |
+
elif encoder_str == "dust3r_512_dpt":
|
160 |
+
return CroCoEncoder(
|
161 |
+
name="dust3r_512_dpt",
|
162 |
+
data_norm_type="dust3r",
|
163 |
+
pretrained_checkpoint_path=f"{relative_checkpoint_path}/CroCo_Encoder_512_DUSt3R_dpt.pth",
|
164 |
+
patch_embed_cls="ManyAR_PatchEmbed",
|
165 |
+
img_size=(512, 512),
|
166 |
+
)
|
167 |
+
elif encoder_str == "mast3r_512":
|
168 |
+
return CroCoEncoder(
|
169 |
+
name="mast3r_512",
|
170 |
+
data_norm_type="dust3r",
|
171 |
+
pretrained_checkpoint_path=f"{relative_checkpoint_path}/CroCo_Encoder_512_MASt3R.pth",
|
172 |
+
patch_embed_cls="ManyAR_PatchEmbed",
|
173 |
+
img_size=(512, 512),
|
174 |
+
)
|
175 |
+
elif "dinov2" in encoder_str:
|
176 |
+
size = encoder_str.split("_")[1]
|
177 |
+
size_single_cap_letter = size[0].upper()
|
178 |
+
if "reg" in encoder_str:
|
179 |
+
with_registers = True
|
180 |
+
pretrained_checkpoint_path = None
|
181 |
+
elif "dav2" in encoder_str:
|
182 |
+
with_registers = False
|
183 |
+
pretrained_checkpoint_path = (
|
184 |
+
f"{relative_checkpoint_path}/DINOv2_ViT{size_single_cap_letter}_DepthAnythingV2.pth"
|
185 |
+
)
|
186 |
+
else:
|
187 |
+
with_registers = False
|
188 |
+
pretrained_checkpoint_path = None
|
189 |
+
return DINOv2Encoder(
|
190 |
+
name=encoder_str.replace("_reg", ""),
|
191 |
+
size=size,
|
192 |
+
with_registers=with_registers,
|
193 |
+
pretrained_checkpoint_path=pretrained_checkpoint_path,
|
194 |
+
)
|
195 |
+
elif "naradio" in encoder_str:
|
196 |
+
return NARADIOEncoder(
|
197 |
+
name=encoder_str,
|
198 |
+
model_version=encoder_str.replace("na", ""),
|
199 |
+
)
|
200 |
+
elif "radio" in encoder_str:
|
201 |
+
if "e-radio" in encoder_str:
|
202 |
+
eradio_input_shape = (224, 224)
|
203 |
+
else:
|
204 |
+
eradio_input_shape = None
|
205 |
+
return RADIOEncoder(
|
206 |
+
name=encoder_str,
|
207 |
+
model_version=encoder_str,
|
208 |
+
eradio_input_shape=eradio_input_shape,
|
209 |
+
)
|
210 |
+
elif "cosmos" in encoder_str:
|
211 |
+
patch_size = int(encoder_str.split("x")[-1])
|
212 |
+
return CosmosEncoder(
|
213 |
+
name=encoder_str,
|
214 |
+
patch_size=patch_size,
|
215 |
+
pretrained_checkpoint_path=f"{relative_checkpoint_path}/Cosmos-Tokenizer-CI{patch_size}x{patch_size}/encoder.pth",
|
216 |
+
)
|
217 |
+
elif "patch_embedder" in encoder_str:
|
218 |
+
return PatchEmbedder(
|
219 |
+
name=encoder_str,
|
220 |
+
)
|
221 |
+
else:
|
222 |
+
raise ValueError(f"Unknown encoder: {encoder_str}")
|
223 |
+
|
224 |
+
|
225 |
+
__all__ = [
|
226 |
+
"encoder_factory",
|
227 |
+
"get_available_encoders",
|
228 |
+
"print_available_encoder_models",
|
229 |
+
"_make_encoder_test",
|
230 |
+
"UniCeptionEncoderBase",
|
231 |
+
"UniCeptionViTEncoderBase",
|
232 |
+
"EncoderInput",
|
233 |
+
"ViTEncoderInput",
|
234 |
+
"ViTEncoderOutput",
|
235 |
+
]
|
UniCeption/uniception/models/encoders/base.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Base Encoder Class for UniCeption
|
3 |
+
"""
|
4 |
+
|
5 |
+
from dataclasses import dataclass
|
6 |
+
from typing import Optional
|
7 |
+
|
8 |
+
import torch.nn as nn
|
9 |
+
from jaxtyping import Float
|
10 |
+
from torch import Tensor
|
11 |
+
from torch.utils.checkpoint import checkpoint
|
12 |
+
|
13 |
+
|
14 |
+
@dataclass
|
15 |
+
class EncoderInput:
|
16 |
+
"Data class for Encoder Input"
|
17 |
+
|
18 |
+
data_norm_type: str
|
19 |
+
# Add other fields that are required by the specific implementation of the encoder.
|
20 |
+
|
21 |
+
|
22 |
+
@dataclass
|
23 |
+
class EncoderOutput:
|
24 |
+
"Data class for Encoder Output"
|
25 |
+
|
26 |
+
pass
|
27 |
+
|
28 |
+
|
29 |
+
@dataclass
|
30 |
+
class EncoderGlobalRepInput:
|
31 |
+
"Data class for Encoder Global Representation Input"
|
32 |
+
|
33 |
+
data: Float[Tensor, "batch channel"]
|
34 |
+
|
35 |
+
|
36 |
+
@dataclass
|
37 |
+
class EncoderGlobalRepOutput:
|
38 |
+
"Data class for Encoder Global Representation Output"
|
39 |
+
|
40 |
+
features: Float[Tensor, "batch enc_embed_dim"]
|
41 |
+
|
42 |
+
|
43 |
+
class UniCeptionEncoderBase(nn.Module):
|
44 |
+
"Encoder Base Class for UniCeption"
|
45 |
+
|
46 |
+
def __init__(
|
47 |
+
self,
|
48 |
+
name: str,
|
49 |
+
data_norm_type: str,
|
50 |
+
size: Optional[str] = None,
|
51 |
+
*args,
|
52 |
+
**kwargs,
|
53 |
+
):
|
54 |
+
"""
|
55 |
+
Base class for all encoders in UniCeption.
|
56 |
+
"""
|
57 |
+
super().__init__(*args, **kwargs)
|
58 |
+
|
59 |
+
self.name: str = name
|
60 |
+
self.size: Optional[str] = size
|
61 |
+
|
62 |
+
self.data_norm_type: str = data_norm_type
|
63 |
+
|
64 |
+
def forward(
|
65 |
+
self,
|
66 |
+
encoder_input: EncoderInput,
|
67 |
+
) -> EncoderOutput:
|
68 |
+
"""
|
69 |
+
Forward interface for the UniCeption encoders.
|
70 |
+
|
71 |
+
We expect the "data_norm_type" field to be present in the encoder_input to check for normalization type.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
encoder_input (EncoderInput): Input to the encoder. We expect the following fields: "data_norm_type: str".
|
75 |
+
This is also includes the other fields that are required by the specific implementation of the encoder.
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
EncoderOutput: Output of the encoder.
|
79 |
+
"""
|
80 |
+
|
81 |
+
raise NotImplementedError
|
82 |
+
|
83 |
+
def _check_data_normalization_type(self, data_norm_type: str):
|
84 |
+
"""
|
85 |
+
Check if the input normalization type matches the encoder's expected input data normalization type.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
data_norm_type (str): Data normalization type.
|
89 |
+
|
90 |
+
Raises:
|
91 |
+
AssertionError: If the data normalization type does not match the encoder's expected input data normalization type.
|
92 |
+
"""
|
93 |
+
|
94 |
+
assert (
|
95 |
+
data_norm_type == self.data_norm_type
|
96 |
+
), f"Input normalization type {data_norm_type} does not match the encoder's normalization type {self.data_norm_type}."
|
97 |
+
|
98 |
+
|
99 |
+
@dataclass
|
100 |
+
class ViTEncoderInput(EncoderInput):
|
101 |
+
"Data class for Vision Transformer Encoder Input"
|
102 |
+
|
103 |
+
image: Float[Tensor, "batch channel height width"]
|
104 |
+
|
105 |
+
|
106 |
+
@dataclass
|
107 |
+
class ViTEncoderNonImageInput:
|
108 |
+
"Data class for Vision (2D-Grid) Transformer Encoder Non-Image Input"
|
109 |
+
|
110 |
+
data: Float[Tensor, "batch channel height width"]
|
111 |
+
|
112 |
+
|
113 |
+
@dataclass
|
114 |
+
class ViTEncoderOutput(EncoderOutput):
|
115 |
+
"Data class for Vision Transformer Encoder Output"
|
116 |
+
|
117 |
+
features: Float[Tensor, "batch enc_embed_dim feat_height feat_width"]
|
118 |
+
|
119 |
+
|
120 |
+
class UniCeptionViTEncoderBase(UniCeptionEncoderBase):
|
121 |
+
"Vision Transformer Encoder Base Class for UniCeption"
|
122 |
+
|
123 |
+
def __init__(
|
124 |
+
self,
|
125 |
+
patch_size: int,
|
126 |
+
gradient_checkpointing: bool = False,
|
127 |
+
*args,
|
128 |
+
**kwargs,
|
129 |
+
):
|
130 |
+
"""
|
131 |
+
Base class for all Vision Transformer encoders in UniCeption.
|
132 |
+
"""
|
133 |
+
super().__init__(*args, **kwargs)
|
134 |
+
|
135 |
+
self.patch_size = patch_size
|
136 |
+
self.gradient_checkpointing = gradient_checkpointing
|
137 |
+
|
138 |
+
def wrap_module_with_gradient_checkpointing(self, module: nn.Module):
|
139 |
+
"""
|
140 |
+
Wrapper for Gradient Checkpointing
|
141 |
+
References: https://github.com/microsoft/MoGe
|
142 |
+
"""
|
143 |
+
|
144 |
+
class _CheckpointingWrapper(module.__class__):
|
145 |
+
_restore_cls = module.__class__
|
146 |
+
|
147 |
+
def forward(self, *args, **kwargs):
|
148 |
+
return checkpoint(super().forward, *args, use_reentrant=False, **kwargs)
|
149 |
+
|
150 |
+
module.__class__ = _CheckpointingWrapper
|
151 |
+
return module
|
152 |
+
|
153 |
+
|
154 |
+
if __name__ == "__main__":
|
155 |
+
dummy_model = UniCeptionEncoderBase(name="name", data_norm_type="norm")
|
156 |
+
dummy_vit_model = UniCeptionViTEncoderBase(name="name", data_norm_type="norm", patch_size=16)
|
157 |
+
print("Dummy Base Encoders created successfully!")
|
UniCeption/uniception/models/encoders/cosmos.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Encoder Class for Cosmos
|
3 |
+
"""
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from uniception.models.encoders.base import UniCeptionViTEncoderBase, ViTEncoderInput, ViTEncoderOutput
|
8 |
+
from uniception.models.libs.cosmos_tokenizer.modules import ContinuousFormulation, EncoderType
|
9 |
+
from uniception.models.libs.cosmos_tokenizer.networks import TokenizerConfigs
|
10 |
+
|
11 |
+
|
12 |
+
class CosmosEncoder(UniCeptionViTEncoderBase):
|
13 |
+
"Uniception Cosmos Encoder"
|
14 |
+
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
name: str,
|
18 |
+
data_norm_type: str = "cosmos",
|
19 |
+
patch_size: int = 8,
|
20 |
+
pretrained_checkpoint_path: str = None,
|
21 |
+
*args,
|
22 |
+
**kwargs,
|
23 |
+
):
|
24 |
+
"""
|
25 |
+
Cosmos Encoder for extracting spatial features from images.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
name (str): Name of the encoder.
|
29 |
+
data_norm_type (str): Image normalization type. Default: "cosmos"
|
30 |
+
patch_size (int): Patch size for the encoder. Default: 8
|
31 |
+
pretrained_checkpoint_path (str): Path to the pretrained checkpoint. Default: None
|
32 |
+
"""
|
33 |
+
# Init the base class
|
34 |
+
super().__init__(name=name, data_norm_type=data_norm_type, patch_size=patch_size, *args, **kwargs)
|
35 |
+
|
36 |
+
# Init Cosmos Encoder sepecific attributes
|
37 |
+
tokenizer_config = TokenizerConfigs["CI"].value.copy()
|
38 |
+
tokenizer_config.update(dict(spatial_compression=self.patch_size))
|
39 |
+
|
40 |
+
z_factor = tokenizer_config["z_factor"]
|
41 |
+
z_channels = tokenizer_config["z_channels"]
|
42 |
+
latent_channels = tokenizer_config["latent_channels"]
|
43 |
+
encoder_name = kwargs.get("encoder", EncoderType.Default.name)
|
44 |
+
print(tokenizer_config)
|
45 |
+
del tokenizer_config["z_factor"]
|
46 |
+
del tokenizer_config["z_channels"]
|
47 |
+
del tokenizer_config["latent_channels"]
|
48 |
+
self.encoder = EncoderType[encoder_name].value(z_channels=z_factor * z_channels, **tokenizer_config)
|
49 |
+
self.quant_conv = torch.nn.Conv2d(z_factor * z_channels, z_factor * latent_channels, 1)
|
50 |
+
formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name)
|
51 |
+
self.distribution = ContinuousFormulation[formulation_name].value()
|
52 |
+
|
53 |
+
# Load the pretrained checkpoint
|
54 |
+
if pretrained_checkpoint_path is not None:
|
55 |
+
print(f"Loading custom pretrained Cosmos checkpoint from {pretrained_checkpoint_path}")
|
56 |
+
ckpt = torch.load(pretrained_checkpoint_path, weights_only=False)
|
57 |
+
print(self.load_state_dict(ckpt["model"]))
|
58 |
+
|
59 |
+
def encode(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor]:
|
60 |
+
"""Encodes an image into a latent embedding or code.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
input_tensor: The input tensor Bx3xHxW layout, range [-1..1].
|
64 |
+
Returns:
|
65 |
+
For continuous image (CI) tokenizer, the tuple contains:
|
66 |
+
- The latent embedding, Bx16x(h)x(w), where the compression
|
67 |
+
rate is (H/h x W/w), and channel dimension of 16.
|
68 |
+
For discrete image (DI) tokenizer, the tuple contains:
|
69 |
+
- The indices, Bx(h)x(w), from a codebook of size 64K, which
|
70 |
+
corresponds to FSQ levels of (8,8,8,5,5,5).
|
71 |
+
- The discrete code, Bx6x(h)x(w), where the compression rate is
|
72 |
+
again (H/h x W/w), and channel dimension of 6.
|
73 |
+
"""
|
74 |
+
x = self.encoder(input_tensor)
|
75 |
+
x = self.quant_conv(x)
|
76 |
+
output_latent = self.distribution(x)
|
77 |
+
|
78 |
+
if isinstance(output_latent, torch.Tensor):
|
79 |
+
return output_latent
|
80 |
+
return output_latent[:-1]
|
81 |
+
|
82 |
+
def forward(self, encoder_input: ViTEncoderInput) -> ViTEncoderOutput:
|
83 |
+
"""
|
84 |
+
Cosmos Encoder Forward Pass
|
85 |
+
|
86 |
+
Args:
|
87 |
+
encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor.
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
ViTEncoderOutput: Output data from the encoder.
|
91 |
+
"""
|
92 |
+
# Check image normalization type
|
93 |
+
self._check_data_normalization_type(encoder_input.data_norm_type)
|
94 |
+
|
95 |
+
# Check the dtype and shape of the input image
|
96 |
+
assert isinstance(encoder_input.image, torch.Tensor), "Input must be a torch.Tensor"
|
97 |
+
assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)"
|
98 |
+
batch_size, channels, height, width = encoder_input.image.shape
|
99 |
+
assert channels == 3, "Input must have 3 channels"
|
100 |
+
assert (
|
101 |
+
height % self.patch_size == 0 and width % self.patch_size == 0
|
102 |
+
), f"Input shape must be divisible by patch size: {self.patch_size}"
|
103 |
+
|
104 |
+
# Extract the features from the DINOv2 model
|
105 |
+
features = self.encode(encoder_input.image)[0].contiguous()
|
106 |
+
|
107 |
+
return ViTEncoderOutput(features=features)
|
108 |
+
|
109 |
+
|
110 |
+
if __name__ == "__main__":
|
111 |
+
|
112 |
+
# initialize different variants of the Cosmos Encoder, untrained
|
113 |
+
for is_continuous in [True]:
|
114 |
+
for patch_size in [8, 16]:
|
115 |
+
encoder = CosmosEncoder(name="cosmos", patch_size=patch_size)
|
116 |
+
|
117 |
+
# # initialize from trained checkpoint, with/without jit inference capability
|
118 |
+
PRETRAINED_JIT_CHECKPOINTS = {
|
119 |
+
("CI", 8): "../../../checkpoints/encoders/cosmos/Cosmos-Tokenizer-CI8x8/encoder.pth",
|
120 |
+
("CI", 16): "../../../checkpoints/encoders/cosmos/Cosmos-Tokenizer-CI16x16/encoder.pth",
|
121 |
+
}
|
122 |
+
|
123 |
+
for patch_size in [8, 16]:
|
124 |
+
|
125 |
+
encoder = CosmosEncoder(
|
126 |
+
name="cosmos",
|
127 |
+
patch_size=patch_size,
|
128 |
+
pretrained_checkpoint_path=PRETRAINED_JIT_CHECKPOINTS[("CI", patch_size)],
|
129 |
+
)
|
130 |
+
|
131 |
+
# example inference
|
132 |
+
dummy_image = torch.randn(1, 3, 256, 256).cuda()
|
133 |
+
|
134 |
+
encoder_input = ViTEncoderInput(data_norm_type="cosmos", image=dummy_image)
|
135 |
+
|
136 |
+
encoder = encoder.cuda()
|
137 |
+
encoder_output = encoder(encoder_input)
|
UniCeption/uniception/models/encoders/croco.py
ADDED
@@ -0,0 +1,457 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Encoder Class for CroCo & DUSt3R
|
3 |
+
"""
|
4 |
+
|
5 |
+
from functools import partial
|
6 |
+
from typing import Callable, List, Optional, Tuple, Union
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
from uniception.models.encoders.base import UniCeptionViTEncoderBase, ViTEncoderInput, ViTEncoderOutput
|
12 |
+
from uniception.models.libs.croco.blocks import Block
|
13 |
+
from uniception.models.libs.croco.patch_embed import get_patch_embed
|
14 |
+
from uniception.models.libs.croco.pos_embed import RoPE2D
|
15 |
+
from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner, feature_take_indices
|
16 |
+
|
17 |
+
|
18 |
+
class CroCoEncoder(UniCeptionViTEncoderBase):
|
19 |
+
"UniCeption CroCov2 Encoder"
|
20 |
+
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
name: str,
|
24 |
+
data_norm_type: str,
|
25 |
+
patch_embed_cls: str = "PatchEmbedDust3R",
|
26 |
+
img_size: Union[int, Tuple[int, int]] = (224, 224),
|
27 |
+
patch_size: int = 16,
|
28 |
+
enc_embed_dim: int = 1024,
|
29 |
+
enc_depth: int = 24,
|
30 |
+
enc_num_heads: int = 16,
|
31 |
+
mlp_ratio: int = 4,
|
32 |
+
norm_layer: Callable = partial(nn.LayerNorm, eps=1e-6),
|
33 |
+
pos_embed: str = "RoPE100",
|
34 |
+
pretrained_checkpoint_path: str = None,
|
35 |
+
override_checkpoint_attributes: bool = False,
|
36 |
+
*args,
|
37 |
+
**kwargs,
|
38 |
+
):
|
39 |
+
"""
|
40 |
+
References: https://github.com/naver/dust3r, https://github.com/naver/croco
|
41 |
+
|
42 |
+
Args:
|
43 |
+
name (str): Name of the encoder.
|
44 |
+
data_norm_type (str): Input data normalization type.
|
45 |
+
patch_embed_cls (str, optional): The class to use for patch embedding.
|
46 |
+
Defaults to 'PatchEmbedDust3R'. Options: ['PatchEmbedCroCo', 'PatchEmbedDust3R', 'ManyAR_PatchEmbed'].
|
47 |
+
img_size (int, optional): The size of the input image. Defaults to 224.
|
48 |
+
patch_size (int, optional): The size of the patches to divide the image into. Defaults to 16.
|
49 |
+
enc_embed_dim (int, optional): The dimension of the encoder's embedding. Defaults to 768.
|
50 |
+
enc_depth (int, optional): The number of encoder layers/transformer blocks. Defaults to 12.
|
51 |
+
enc_num_heads (int, optional): The number of encoder heads. Defaults to 12.
|
52 |
+
mlp_ratio (int, optional): The MLP ratio used for the CroCo encoder transformer. Defaults to 4.
|
53 |
+
norm_layer (nn.Module, optional): The normalization layer to use in the transformer. Defaults to nn.LayerNorm with eps=1e-6.
|
54 |
+
pos_embed (str, optional): Positional Embedding. Defaults to 'RoPE100'. Options: ['RoPEfreq'].
|
55 |
+
pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. Defaults to None.
|
56 |
+
"""
|
57 |
+
# Init the base class
|
58 |
+
super().__init__(
|
59 |
+
name=name,
|
60 |
+
data_norm_type=data_norm_type,
|
61 |
+
patch_size=patch_size,
|
62 |
+
*args,
|
63 |
+
**kwargs,
|
64 |
+
)
|
65 |
+
|
66 |
+
# Init the CroCo Encoder specific attributes
|
67 |
+
self.patch_embed_cls = patch_embed_cls
|
68 |
+
self.img_size = img_size
|
69 |
+
self.enc_embed_dim = enc_embed_dim
|
70 |
+
self.enc_depth = enc_depth
|
71 |
+
self.enc_num_heads = enc_num_heads
|
72 |
+
self.mlp_ratio = mlp_ratio
|
73 |
+
self.norm_layer = norm_layer
|
74 |
+
self.pretrained_checkpoint_path = pretrained_checkpoint_path
|
75 |
+
self.override_checkpoint_attributes = override_checkpoint_attributes
|
76 |
+
|
77 |
+
# Init the positional embedding
|
78 |
+
self.pos_embed = pos_embed
|
79 |
+
if pos_embed.startswith("RoPE"): # eg RoPE100
|
80 |
+
self.enc_pos_embed = None # nothing to add in the encoder with RoPE
|
81 |
+
self.dec_pos_embed = None # nothing to add in the decoder with RoPE
|
82 |
+
if RoPE2D is None:
|
83 |
+
raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions")
|
84 |
+
freq = float(pos_embed[len("RoPE") :])
|
85 |
+
self.rope = RoPE2D(freq=freq)
|
86 |
+
else:
|
87 |
+
raise NotImplementedError("Unknown pos_embed " + pos_embed)
|
88 |
+
|
89 |
+
# Init the patch embedding
|
90 |
+
self._set_patch_embed(img_size, patch_size, enc_embed_dim)
|
91 |
+
|
92 |
+
# Init the encoder
|
93 |
+
self._set_encoder(enc_depth, enc_embed_dim, enc_num_heads, mlp_ratio, norm_layer, self.rope)
|
94 |
+
|
95 |
+
# Initialize random weights
|
96 |
+
self.initialize_weights()
|
97 |
+
|
98 |
+
# Load the pretrained CroCo checkpoint if provided
|
99 |
+
if pretrained_checkpoint_path:
|
100 |
+
print(f"Loading pretrained CroCo checkpoint from {pretrained_checkpoint_path}")
|
101 |
+
ckpt = torch.load(pretrained_checkpoint_path, weights_only=False)
|
102 |
+
print(self.load_state_dict(ckpt["model"]))
|
103 |
+
if not override_checkpoint_attributes:
|
104 |
+
ckpt_data_norm_type = ckpt["data_norm_type"]
|
105 |
+
ckpt_patch_embed_cls = ckpt["patch_embed_cls"]
|
106 |
+
assert (
|
107 |
+
data_norm_type == ckpt_data_norm_type
|
108 |
+
), f"Data normalization type {data_norm_type} does not match the checkpoint {ckpt_data_norm_type}."
|
109 |
+
assert (
|
110 |
+
patch_embed_cls == ckpt_patch_embed_cls
|
111 |
+
), f"Patch embedding class {patch_embed_cls} does not match the checkpoint {ckpt_patch_embed_cls}."
|
112 |
+
else:
|
113 |
+
print("No pretrained checkpoint provided. Randomly initializing the CroCo encoder.")
|
114 |
+
|
115 |
+
def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
|
116 |
+
"Set the patch embedding scheme"
|
117 |
+
self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim)
|
118 |
+
|
119 |
+
def _set_encoder(self, enc_depth, enc_embed_dim, enc_num_heads, mlp_ratio, norm_layer, rope):
|
120 |
+
"Set the encoder"
|
121 |
+
self.enc_blocks = nn.ModuleList(
|
122 |
+
[
|
123 |
+
Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, rope=rope)
|
124 |
+
for _ in range(enc_depth)
|
125 |
+
]
|
126 |
+
)
|
127 |
+
self.enc_norm = norm_layer(enc_embed_dim)
|
128 |
+
|
129 |
+
def initialize_weights(self):
|
130 |
+
"Initialize the weights of the patch embedding and the transformer encoder"
|
131 |
+
# Patch embedding
|
132 |
+
self.patch_embed._init_weights()
|
133 |
+
# Linears and layer norms
|
134 |
+
self.apply(self._init_weights)
|
135 |
+
|
136 |
+
def _init_weights(self, m):
|
137 |
+
"Initialize the transformer encoder weights"
|
138 |
+
if isinstance(m, nn.Linear):
|
139 |
+
# We use xavier_uniform following official JAX ViT:
|
140 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
141 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
142 |
+
nn.init.constant_(m.bias, 0)
|
143 |
+
elif isinstance(m, nn.LayerNorm):
|
144 |
+
nn.init.constant_(m.bias, 0)
|
145 |
+
nn.init.constant_(m.weight, 1.0)
|
146 |
+
|
147 |
+
def forward(self, encoder_input: ViTEncoderInput) -> ViTEncoderOutput:
|
148 |
+
"""
|
149 |
+
CroCov2 Encoder Forward Pass
|
150 |
+
|
151 |
+
Args:
|
152 |
+
encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor.
|
153 |
+
|
154 |
+
Returns:
|
155 |
+
ViTEncoderOutput: Output data from the encoder.
|
156 |
+
"""
|
157 |
+
# Check image normalization type
|
158 |
+
self._check_data_normalization_type(encoder_input.data_norm_type)
|
159 |
+
|
160 |
+
# Get the true shape of the image for landscape/portrait mode check in patch embedding
|
161 |
+
batch_size, _, height, width = encoder_input.image.shape
|
162 |
+
if hasattr(encoder_input, "true_shape"):
|
163 |
+
true_shape = encoder_input.true_shape
|
164 |
+
else:
|
165 |
+
true_shape = torch.tensor([height, width])[None].repeat(batch_size, 1)
|
166 |
+
|
167 |
+
# Embed the image into patches
|
168 |
+
features, pos = self.patch_embed(encoder_input.image, true_shape=true_shape)
|
169 |
+
|
170 |
+
# Now apply the transformer encoder and normalization
|
171 |
+
for blk in self.enc_blocks:
|
172 |
+
features = blk(features, pos)
|
173 |
+
features = self.enc_norm(features)
|
174 |
+
|
175 |
+
# Resize the features to the expected shape
|
176 |
+
# (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size)
|
177 |
+
features = features.permute(0, 2, 1)
|
178 |
+
features = features.reshape(
|
179 |
+
-1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size
|
180 |
+
).contiguous()
|
181 |
+
|
182 |
+
return ViTEncoderOutput(features=features)
|
183 |
+
|
184 |
+
|
185 |
+
class CroCoIntermediateFeatureReturner(CroCoEncoder, IntermediateFeatureReturner):
|
186 |
+
"Intermediate Feature Returner for UniCeption CroCo Encoder"
|
187 |
+
|
188 |
+
def __init__(
|
189 |
+
self,
|
190 |
+
name: str,
|
191 |
+
data_norm_type: str,
|
192 |
+
patch_embed_cls: str = "PatchEmbedDust3R",
|
193 |
+
img_size: Union[int, Tuple[int, int]] = (224, 224),
|
194 |
+
patch_size: int = 16,
|
195 |
+
enc_embed_dim: int = 1024,
|
196 |
+
enc_depth: int = 24,
|
197 |
+
enc_num_heads: int = 16,
|
198 |
+
mlp_ratio: int = 4,
|
199 |
+
norm_layer: Callable = partial(nn.LayerNorm, eps=1e-6),
|
200 |
+
pos_embed: str = "RoPE100",
|
201 |
+
pretrained_checkpoint_path: str = None,
|
202 |
+
indices: Optional[Union[int, List[int]]] = None,
|
203 |
+
norm_intermediate: bool = True,
|
204 |
+
stop_early: bool = False,
|
205 |
+
intermediates_only: bool = True,
|
206 |
+
*args,
|
207 |
+
**kwargs,
|
208 |
+
):
|
209 |
+
"""
|
210 |
+
Intermediate Feature Returner for the CroCo Encoder.
|
211 |
+
|
212 |
+
Args:
|
213 |
+
name (str): Name of the encoder.
|
214 |
+
data_norm_type (str): Input data normalization type.
|
215 |
+
patch_embed_cls (str, optional): The class to use for patch embedding.
|
216 |
+
Defaults to 'PatchEmbedDust3R'. Options: ['PatchEmbedCroCo', 'PatchEmbedDust3R', 'ManyAR_PatchEmbed'].
|
217 |
+
img_size (int, optional): The size of the input image. Defaults to 224.
|
218 |
+
patch_size (int, optional): The size of the patches to divide the image into. Defaults to 16.
|
219 |
+
enc_embed_dim (int, optional): The dimension of the encoder's embedding. Defaults to 768.
|
220 |
+
enc_depth (int, optional): The number of encoder layers/transformer blocks. Defaults to 12.
|
221 |
+
enc_num_heads (int, optional): The number of encoder heads. Defaults to 12.
|
222 |
+
mlp_ratio (int, optional): The MLP ratio used for the CroCo encoder transformer. Defaults to 4.
|
223 |
+
norm_layer (nn.Module, optional): The normalization layer to use in the transformer. Defaults to nn.LayerNorm with eps=1e-6.
|
224 |
+
pos_embed (str, optional): Positional Embedding. Defaults to 'RoPE100'. Options: ['cosine', 'RoPE100'].
|
225 |
+
pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. Defaults to None.
|
226 |
+
indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. Defaults to None. Options:
|
227 |
+
- None: Return all intermediate layers.
|
228 |
+
- int: Return the last n layers.
|
229 |
+
- List[int]: Return the intermediate layers at the specified indices.
|
230 |
+
norm_intermediate (bool, optional): Whether to normalize the intermediate features. Defaults to True.
|
231 |
+
stop_early (bool, optional): Whether to stop early. Defaults to False.
|
232 |
+
intermediates_only (bool, optional): Whether to return only the intermediate features. Defaults to True.
|
233 |
+
"""
|
234 |
+
# Init the base classes
|
235 |
+
CroCoEncoder.__init__(
|
236 |
+
self,
|
237 |
+
name=name,
|
238 |
+
data_norm_type=data_norm_type,
|
239 |
+
patch_embed_cls=patch_embed_cls,
|
240 |
+
img_size=img_size,
|
241 |
+
patch_size=patch_size,
|
242 |
+
enc_embed_dim=enc_embed_dim,
|
243 |
+
enc_depth=enc_depth,
|
244 |
+
enc_num_heads=enc_num_heads,
|
245 |
+
mlp_ratio=mlp_ratio,
|
246 |
+
norm_layer=norm_layer,
|
247 |
+
pos_embed=pos_embed,
|
248 |
+
pretrained_checkpoint_path=pretrained_checkpoint_path,
|
249 |
+
*args,
|
250 |
+
**kwargs,
|
251 |
+
)
|
252 |
+
IntermediateFeatureReturner.__init__(
|
253 |
+
self,
|
254 |
+
indices=indices,
|
255 |
+
norm_intermediate=norm_intermediate,
|
256 |
+
stop_early=stop_early,
|
257 |
+
intermediates_only=intermediates_only,
|
258 |
+
)
|
259 |
+
|
260 |
+
def forward(
|
261 |
+
self, encoder_input: ViTEncoderInput
|
262 |
+
) -> Union[List[ViTEncoderOutput], Tuple[ViTEncoderOutput, List[ViTEncoderOutput]]]:
|
263 |
+
"""
|
264 |
+
CroCov2 Encoder Forward Pass with Intermediate Feature Return
|
265 |
+
|
266 |
+
Args:
|
267 |
+
encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor.
|
268 |
+
|
269 |
+
Returns:
|
270 |
+
Union[List[ViTEncoderOutput], Tuple[ViTEncoderOutput, List[ViTEncoderOutput]]]: Output data from the encoder.
|
271 |
+
If `intermediates_only` is True, returns a list of intermediate features.
|
272 |
+
Otherwise, returns a tuple with the final features and a list of intermediate features.
|
273 |
+
"""
|
274 |
+
# Check image normalization type
|
275 |
+
self._check_data_normalization_type(encoder_input.data_norm_type)
|
276 |
+
|
277 |
+
# Get the true shape of the image for landscape/portrait mode check in patch embedding
|
278 |
+
batch_size, _, height, width = encoder_input.image.shape
|
279 |
+
if hasattr(encoder_input, "true_shape"):
|
280 |
+
true_shape = encoder_input.true_shape
|
281 |
+
else:
|
282 |
+
true_shape = torch.tensor([height, width])[None].repeat(batch_size, 1)
|
283 |
+
|
284 |
+
# Embed the image into patches
|
285 |
+
features, pos = self.patch_embed(encoder_input.image, true_shape=true_shape)
|
286 |
+
|
287 |
+
# Get indices of the intermediate features to return
|
288 |
+
intermediate_features = []
|
289 |
+
take_indices, max_index = feature_take_indices(len(self.enc_blocks), self.indices)
|
290 |
+
|
291 |
+
# Get the blocks based on early stopping
|
292 |
+
if torch.jit.is_scripting() or not self.stop_early: # can't slice blocks in torchscript
|
293 |
+
blocks = self.enc_blocks
|
294 |
+
else:
|
295 |
+
blocks = self.enc_blocks[: max_index + 1]
|
296 |
+
|
297 |
+
# Now apply the transformer encoder and normalization
|
298 |
+
for blk_idx, blk in enumerate(blocks):
|
299 |
+
features = blk(features, pos)
|
300 |
+
if blk_idx in take_indices:
|
301 |
+
# Normalize intermediates with final norm layer if enabled
|
302 |
+
intermediate_features.append(self.enc_norm(features) if self.norm_intermediate else features)
|
303 |
+
|
304 |
+
# Reshape the intermediate features and convert to ViTEncoderOutput class
|
305 |
+
intermediate_features = [
|
306 |
+
intermediate.permute(0, 2, 1)
|
307 |
+
.reshape(-1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size)
|
308 |
+
.contiguous()
|
309 |
+
for intermediate in intermediate_features
|
310 |
+
]
|
311 |
+
intermediate_features = [ViTEncoderOutput(features=intermediate) for intermediate in intermediate_features]
|
312 |
+
|
313 |
+
# Return only the intermediate features if enabled
|
314 |
+
if self.intermediates_only:
|
315 |
+
return intermediate_features
|
316 |
+
|
317 |
+
# Normalize and reshape the final features
|
318 |
+
features = self.enc_norm(features)
|
319 |
+
# Resize the features to the expected shape
|
320 |
+
# (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size)
|
321 |
+
features = features.permute(0, 2, 1)
|
322 |
+
features = features.reshape(
|
323 |
+
-1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size
|
324 |
+
).contiguous()
|
325 |
+
final_features = ViTEncoderOutput(features=features)
|
326 |
+
|
327 |
+
return final_features, intermediate_features
|
328 |
+
|
329 |
+
|
330 |
+
if __name__ == "__main__":
|
331 |
+
# Init the pre-trained CroCo Encoder
|
332 |
+
pretrained_checkpoint_path = "../../../checkpoints/encoders/CroCo_Encoder_224.pth"
|
333 |
+
croco_encoder = CroCoEncoder(
|
334 |
+
name="croco",
|
335 |
+
data_norm_type="croco",
|
336 |
+
pretrained_checkpoint_path=pretrained_checkpoint_path,
|
337 |
+
patch_embed_cls="PatchEmbedCroCo",
|
338 |
+
)
|
339 |
+
|
340 |
+
# Init the pre-trained DUSt3R CroCo Encoder
|
341 |
+
pretrained_checkpoint_path = "../../../checkpoints/encoders/CroCo_Encoder_224_DUSt3R_linear.pth"
|
342 |
+
dust3r_encoder = CroCoEncoder(
|
343 |
+
name="dust3r_224",
|
344 |
+
data_norm_type="dust3r",
|
345 |
+
pretrained_checkpoint_path=pretrained_checkpoint_path,
|
346 |
+
patch_embed_cls="PatchEmbedDust3R",
|
347 |
+
)
|
348 |
+
|
349 |
+
# Init the pre-trained DUSt3R 512 linear CroCo Encoder
|
350 |
+
pretrained_checkpoint_path = "../../../checkpoints/encoders/CroCo_Encoder_512_DUSt3R_linear.pth"
|
351 |
+
dust3r_encoder_512 = CroCoEncoder(
|
352 |
+
name="dust3r_512",
|
353 |
+
data_norm_type="dust3r",
|
354 |
+
pretrained_checkpoint_path=pretrained_checkpoint_path,
|
355 |
+
patch_embed_cls="ManyAR_PatchEmbed",
|
356 |
+
img_size=(512, 512),
|
357 |
+
)
|
358 |
+
|
359 |
+
# Init the pre-trained DUSt3R 512 DPT CroCo Encoder
|
360 |
+
pretrained_checkpoint_path = "../../../checkpoints/encoders/CroCo_Encoder_512_DUSt3R_dpt.pth"
|
361 |
+
dust3r_encoder_512_dpt = CroCoEncoder(
|
362 |
+
name="dust3r_512_dpt",
|
363 |
+
data_norm_type="dust3r",
|
364 |
+
pretrained_checkpoint_path=pretrained_checkpoint_path,
|
365 |
+
patch_embed_cls="ManyAR_PatchEmbed",
|
366 |
+
img_size=(512, 512),
|
367 |
+
)
|
368 |
+
|
369 |
+
# Init the MASt3R 512 CroCo Encoder
|
370 |
+
pretrained_checkpoint_path = "../../../checkpoints/encoders/CroCo_Encoder_512_MASt3R.pth"
|
371 |
+
mast3r_encoder_512 = CroCoEncoder(
|
372 |
+
name="mast3r_512",
|
373 |
+
data_norm_type="dust3r",
|
374 |
+
pretrained_checkpoint_path=pretrained_checkpoint_path,
|
375 |
+
patch_embed_cls="ManyAR_PatchEmbed",
|
376 |
+
img_size=(512, 512),
|
377 |
+
)
|
378 |
+
|
379 |
+
print("All CroCo & DUSt3R Encoders have been initialized successfully!")
|
380 |
+
|
381 |
+
# Intermediate Feature Returner Tests
|
382 |
+
print("Running Intermediate Feature Returner Tests...")
|
383 |
+
pretrained_checkpoint_path = "../../../checkpoints/encoders/CroCo_Encoder_512_DUSt3R_dpt.pth"
|
384 |
+
|
385 |
+
# Run the intermediate feature returner with last-n index
|
386 |
+
dust3r_intermediate_feature_returner = CroCoIntermediateFeatureReturner(
|
387 |
+
name="dust3r_512_dpt",
|
388 |
+
data_norm_type="dust3r",
|
389 |
+
pretrained_checkpoint_path=pretrained_checkpoint_path,
|
390 |
+
patch_embed_cls="ManyAR_PatchEmbed",
|
391 |
+
img_size=(512, 512),
|
392 |
+
indices=6, # Last 6 layers
|
393 |
+
)
|
394 |
+
dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="dust3r")
|
395 |
+
output = dust3r_intermediate_feature_returner(dummy_input)
|
396 |
+
assert isinstance(output, list), "Output must be a list of intermediate features"
|
397 |
+
assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
|
398 |
+
assert len(output) == 6, "Output must have length of intermediate features equal to the number of indices"
|
399 |
+
|
400 |
+
# Run the intermediate feature returner with specific indices
|
401 |
+
dust3r_intermediate_feature_returner = CroCoIntermediateFeatureReturner(
|
402 |
+
name="dust3r_512_dpt",
|
403 |
+
data_norm_type="dust3r",
|
404 |
+
pretrained_checkpoint_path=pretrained_checkpoint_path,
|
405 |
+
patch_embed_cls="ManyAR_PatchEmbed",
|
406 |
+
img_size=(512, 512),
|
407 |
+
indices=[0, 2, 4, 6], # Specific layers
|
408 |
+
)
|
409 |
+
dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="dust3r")
|
410 |
+
output = dust3r_intermediate_feature_returner(dummy_input)
|
411 |
+
assert isinstance(output, list), "Output must be a list of intermediate features"
|
412 |
+
assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
|
413 |
+
assert len(output) == 4, "Output must have length of intermediate features equal to the number of indices"
|
414 |
+
|
415 |
+
# Test the normalizing of intermediate features
|
416 |
+
dust3r_intermediate_feature_returner = CroCoIntermediateFeatureReturner(
|
417 |
+
name="dust3r_512_dpt",
|
418 |
+
data_norm_type="dust3r",
|
419 |
+
pretrained_checkpoint_path=pretrained_checkpoint_path,
|
420 |
+
patch_embed_cls="ManyAR_PatchEmbed",
|
421 |
+
img_size=(512, 512),
|
422 |
+
indices=[-1],
|
423 |
+
norm_intermediate=False,
|
424 |
+
intermediates_only=False,
|
425 |
+
)
|
426 |
+
dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="dust3r")
|
427 |
+
output = dust3r_intermediate_feature_returner(dummy_input)
|
428 |
+
assert isinstance(output, tuple), "Output must be a tuple with final features and intermediate features"
|
429 |
+
assert isinstance(output[0], ViTEncoderOutput), "First element of output must be the final features"
|
430 |
+
assert isinstance(output[1], list), "Second element of output must be a list of intermediate features"
|
431 |
+
assert isinstance(output[1][0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
|
432 |
+
if not isinstance(dust3r_intermediate_feature_returner.enc_norm, torch.nn.Identity):
|
433 |
+
assert not torch.equal(
|
434 |
+
output[0].features, output[1][0].features
|
435 |
+
), "Final features and intermediate features must be different"
|
436 |
+
|
437 |
+
dust3r_intermediate_feature_returner = CroCoIntermediateFeatureReturner(
|
438 |
+
name="dust3r_512_dpt",
|
439 |
+
data_norm_type="dust3r",
|
440 |
+
pretrained_checkpoint_path=pretrained_checkpoint_path,
|
441 |
+
patch_embed_cls="ManyAR_PatchEmbed",
|
442 |
+
img_size=(512, 512),
|
443 |
+
indices=[-1],
|
444 |
+
norm_intermediate=True,
|
445 |
+
intermediates_only=False,
|
446 |
+
)
|
447 |
+
dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="dust3r")
|
448 |
+
output = dust3r_intermediate_feature_returner(dummy_input)
|
449 |
+
assert isinstance(output, tuple), "Output must be a tuple with final features and intermediate features"
|
450 |
+
assert isinstance(output[0], ViTEncoderOutput), "First element of output must be the final features"
|
451 |
+
assert isinstance(output[1], list), "Second element of output must be a list of intermediate features"
|
452 |
+
assert isinstance(output[1][0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
|
453 |
+
assert torch.equal(
|
454 |
+
output[0].features, output[1][0].features
|
455 |
+
), "Final features and intermediate features must be same"
|
456 |
+
|
457 |
+
print("All Intermediate Feature Returner Tests have passed successfully!")
|
UniCeption/uniception/models/encoders/dense_rep_encoder.py
ADDED
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Encoder class for Dense Representation Encoder
|
3 |
+
"""
|
4 |
+
|
5 |
+
import math
|
6 |
+
from functools import partial
|
7 |
+
from typing import Callable, List, Optional, Tuple, Type, Union
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
from torch.nn.init import trunc_normal_
|
13 |
+
|
14 |
+
from uniception.models.encoders.base import (
|
15 |
+
UniCeptionViTEncoderBase,
|
16 |
+
ViTEncoderInput,
|
17 |
+
ViTEncoderNonImageInput,
|
18 |
+
ViTEncoderOutput,
|
19 |
+
)
|
20 |
+
|
21 |
+
|
22 |
+
def make_2tuple(x):
|
23 |
+
if isinstance(x, tuple):
|
24 |
+
assert len(x) == 2
|
25 |
+
return x
|
26 |
+
|
27 |
+
assert isinstance(x, int)
|
28 |
+
return (x, x)
|
29 |
+
|
30 |
+
|
31 |
+
class ResidualBlock(nn.Module):
|
32 |
+
"Redidual block for Dense Representation Encoder"
|
33 |
+
|
34 |
+
def __init__(self, in_channels: int, out_channels: int, act_layer: Type[nn.Module] = nn.GELU):
|
35 |
+
super(ResidualBlock, self).__init__()
|
36 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
37 |
+
self.act = act_layer()
|
38 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
39 |
+
self.shortcut = (
|
40 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
41 |
+
if in_channels != out_channels
|
42 |
+
else nn.Identity()
|
43 |
+
)
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
identity = self.shortcut(x)
|
47 |
+
out = self.conv1(x)
|
48 |
+
out = self.act(out)
|
49 |
+
out = self.conv2(out)
|
50 |
+
out += identity
|
51 |
+
|
52 |
+
return self.act(out)
|
53 |
+
|
54 |
+
|
55 |
+
class DenseRepresentationEncoder(UniCeptionViTEncoderBase):
|
56 |
+
"UniCeption Dense Representation Encoder"
|
57 |
+
|
58 |
+
def __init__(
|
59 |
+
self,
|
60 |
+
name: str,
|
61 |
+
in_chans: int = 3,
|
62 |
+
enc_embed_dim: int = 1024,
|
63 |
+
apply_pe: bool = True,
|
64 |
+
input_size_for_pe: Union[int, Tuple[int, int]] = 518,
|
65 |
+
patch_size: int = 14,
|
66 |
+
intermediate_dims: List[int] = [588, 768, 1024],
|
67 |
+
data_norm_type: str = "dense_rep_encoder",
|
68 |
+
act_layer: Type[nn.Module] = nn.GELU,
|
69 |
+
norm_layer: Optional[Callable] = partial(nn.LayerNorm, eps=1e-6),
|
70 |
+
post_pe_norm_layer: Optional[Callable] = partial(nn.LayerNorm, eps=1e-6),
|
71 |
+
interpolate_antialias: bool = False,
|
72 |
+
interpolate_offset: float = 0.1,
|
73 |
+
pretrained_checkpoint_path: str = None,
|
74 |
+
*args,
|
75 |
+
**kwargs,
|
76 |
+
):
|
77 |
+
"""
|
78 |
+
Dense Representation Encoder for extracting patch-wise features from a spatial input of size (B, C, H, W).
|
79 |
+
Uses a convolution based patchify followed by some residual blocks.
|
80 |
+
Also applies positional encoding with interpolation to the patch-wise features if required.
|
81 |
+
|
82 |
+
Args:
|
83 |
+
in_chans (int): Number of input channels.
|
84 |
+
enc_embed_dim (int): Embedding dimension of the encoder.
|
85 |
+
apply_pe (bool): Whether to apply positional encoding.
|
86 |
+
input_size_for_pe (Union[int, Tuple[int, int]]): Input size for positional encoding.
|
87 |
+
patch_size (int): Patch size of the encoder.
|
88 |
+
intermediate_dims (List[int]): Intermediate dimensions of the encoder.
|
89 |
+
data_norm_type (str): Data normalization type. (Used for checking if the input images are normalized correctly.)
|
90 |
+
act_layer (Type[nn.Module]): Activation layer.
|
91 |
+
norm_layer (Optional[Callable]): Normalization layer.
|
92 |
+
post_pe_norm_layer (Optional[Callable]): Normalization layer after positional encoding.
|
93 |
+
interpolate_antialias (bool): Whether to apply antialiasing in interpolation.
|
94 |
+
interpolate_offset (float): Offset for interpolation.
|
95 |
+
pretrained_checkpoint_path (str): Path to pretrained checkpoint.
|
96 |
+
"""
|
97 |
+
# Init the base class
|
98 |
+
super().__init__(
|
99 |
+
name=name,
|
100 |
+
data_norm_type=data_norm_type,
|
101 |
+
patch_size=patch_size,
|
102 |
+
*args,
|
103 |
+
**kwargs,
|
104 |
+
)
|
105 |
+
|
106 |
+
# Init the specific attributes
|
107 |
+
self.in_chans = in_chans
|
108 |
+
self.enc_embed_dim = enc_embed_dim
|
109 |
+
self.intermediate_dims = intermediate_dims
|
110 |
+
self.apply_pe = apply_pe
|
111 |
+
|
112 |
+
# Initialize the encoder with a pixel unshuffle and conv projection to patchify the input
|
113 |
+
self.unshuffle = nn.PixelUnshuffle(self.patch_size)
|
114 |
+
self.conv_in = nn.Conv2d(self.in_chans * (self.patch_size**2), self.intermediate_dims[0], 3, 1, 1)
|
115 |
+
|
116 |
+
# Add residual blocks
|
117 |
+
layers = []
|
118 |
+
for intermediate_idx in range(len(self.intermediate_dims) - 1):
|
119 |
+
layers.append(
|
120 |
+
ResidualBlock(
|
121 |
+
in_channels=self.intermediate_dims[intermediate_idx],
|
122 |
+
out_channels=self.intermediate_dims[intermediate_idx + 1],
|
123 |
+
act_layer=act_layer,
|
124 |
+
)
|
125 |
+
)
|
126 |
+
|
127 |
+
# Final projection to match encoder embeddings dim
|
128 |
+
layers.append(
|
129 |
+
nn.Conv2d(
|
130 |
+
in_channels=self.intermediate_dims[-1],
|
131 |
+
out_channels=self.enc_embed_dim,
|
132 |
+
kernel_size=1,
|
133 |
+
stride=1,
|
134 |
+
padding=0,
|
135 |
+
)
|
136 |
+
)
|
137 |
+
self.encoder = nn.Sequential(*layers)
|
138 |
+
|
139 |
+
# Init norm layer after encoder if required
|
140 |
+
self.norm_layer = norm_layer(enc_embed_dim) if norm_layer else nn.Identity()
|
141 |
+
if isinstance(self.norm_layer, nn.LayerNorm):
|
142 |
+
nn.init.constant_(self.norm_layer.bias, 0)
|
143 |
+
nn.init.constant_(self.norm_layer.weight, 1.0)
|
144 |
+
|
145 |
+
if self.apply_pe:
|
146 |
+
# Init the patch resolution details required for positional encoding
|
147 |
+
patch_HW = make_2tuple(patch_size)
|
148 |
+
self.input_size_for_pe = make_2tuple(input_size_for_pe)
|
149 |
+
self.patches_resolution = (
|
150 |
+
self.input_size_for_pe[0] // patch_HW[0],
|
151 |
+
self.input_size_for_pe[1] // patch_HW[1],
|
152 |
+
)
|
153 |
+
self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
|
154 |
+
|
155 |
+
# Init the sinusodial positional encodings
|
156 |
+
self.register_buffer(
|
157 |
+
"pos_embed",
|
158 |
+
self._get_sinusoid_encoding_table(self.num_patches, self.enc_embed_dim, 70007),
|
159 |
+
)
|
160 |
+
self.interpolate_antialias = interpolate_antialias
|
161 |
+
self.interpolate_offset = interpolate_offset
|
162 |
+
|
163 |
+
# Init the norm layer after positional encoding if required
|
164 |
+
self.post_pe_norm = post_pe_norm_layer(enc_embed_dim) if post_pe_norm_layer else nn.Identity()
|
165 |
+
if isinstance(self.post_pe_norm, nn.LayerNorm):
|
166 |
+
nn.init.constant_(self.post_pe_norm.bias, 0)
|
167 |
+
nn.init.constant_(self.post_pe_norm.weight, 1.0)
|
168 |
+
|
169 |
+
# Load the pretrained checkpoint if provided
|
170 |
+
self.pretrained_checkpoint_path = pretrained_checkpoint_path
|
171 |
+
if self.pretrained_checkpoint_path:
|
172 |
+
print(
|
173 |
+
f"Loading custom pretrained Dense Representation Encoder checkpoint from {self.pretrained_checkpoint_path} ..."
|
174 |
+
)
|
175 |
+
ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
|
176 |
+
print(self.load_state_dict(ckpt["model"]))
|
177 |
+
|
178 |
+
def _get_sinusoid_encoding_table(self, n_position, d_hid, base):
|
179 |
+
"Sinusoid position encoding table"
|
180 |
+
|
181 |
+
def get_position_angle_vec(position):
|
182 |
+
return [position / np.power(base, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
|
183 |
+
|
184 |
+
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
185 |
+
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])
|
186 |
+
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])
|
187 |
+
|
188 |
+
return torch.FloatTensor(sinusoid_table)
|
189 |
+
|
190 |
+
def interpolate_pos_encoding(self, features, height, width):
|
191 |
+
"""
|
192 |
+
Interpolate the positional encoding to the expected size.
|
193 |
+
|
194 |
+
Args:
|
195 |
+
features (torch.Tensor): Input tensor of shape (B, N, C).
|
196 |
+
height (int, float): Height of the input tensor.
|
197 |
+
width (int, float): Width of the input tensor.
|
198 |
+
|
199 |
+
Returns:
|
200 |
+
torch.Tensor: Interpolated positional encoding tensor of shape (1, N, C).
|
201 |
+
"""
|
202 |
+
previous_dtype = features.dtype
|
203 |
+
npatch = features.shape[1]
|
204 |
+
N = self.pos_embed.unsqueeze(0).shape[1]
|
205 |
+
if npatch == N and height == width:
|
206 |
+
return self.pos_embed.unsqueeze(0)
|
207 |
+
patch_pos_embed = self.pos_embed.unsqueeze(0).float()
|
208 |
+
dim = features.shape[-1]
|
209 |
+
height0 = height // self.patch_size
|
210 |
+
width0 = width // self.patch_size
|
211 |
+
M = int(math.sqrt(N)) # Recover the number of patches in each dimension
|
212 |
+
assert N == M * M
|
213 |
+
kwargs = {}
|
214 |
+
if self.interpolate_offset:
|
215 |
+
# Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
|
216 |
+
# Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
|
217 |
+
sh = float(height0 + self.interpolate_offset) / M
|
218 |
+
sw = float(width0 + self.interpolate_offset) / M
|
219 |
+
kwargs["scale_factor"] = (sh, sw)
|
220 |
+
else:
|
221 |
+
# Simply specify an output size instead of a scale factor
|
222 |
+
kwargs["size"] = (height0, width0)
|
223 |
+
patch_pos_embed = nn.functional.interpolate(
|
224 |
+
patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
|
225 |
+
mode="bicubic",
|
226 |
+
antialias=self.interpolate_antialias,
|
227 |
+
**kwargs,
|
228 |
+
)
|
229 |
+
assert (height0, width0) == patch_pos_embed.shape[-2:]
|
230 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
231 |
+
|
232 |
+
return patch_pos_embed.to(previous_dtype)
|
233 |
+
|
234 |
+
def forward(self, encoder_input: Union[ViTEncoderInput, ViTEncoderNonImageInput]) -> ViTEncoderOutput:
|
235 |
+
"""
|
236 |
+
Dense Representation Encoder Forward Pass
|
237 |
+
|
238 |
+
Args:
|
239 |
+
encoder_input (Union[ViTEncoderInput, ViTEncoderNonImageInput]): Input data for the encoder.
|
240 |
+
If input type is ViTEncoderInput, input data must contain image normalization type and normalized image tensor.
|
241 |
+
If input type is ViTEncoderNonImageInput, input data must contain a tensor of size (B, C, H, W).
|
242 |
+
|
243 |
+
Returns:
|
244 |
+
ViTEncoderOutput: Output data from the encoder.
|
245 |
+
"""
|
246 |
+
# Get the input data and verify normalization if the input type is ViTEncoderInput
|
247 |
+
if isinstance(encoder_input, ViTEncoderInput):
|
248 |
+
self._check_data_normalization_type(encoder_input.data_norm_type)
|
249 |
+
input_data = encoder_input.image
|
250 |
+
elif isinstance(encoder_input, ViTEncoderNonImageInput):
|
251 |
+
input_data = encoder_input.data
|
252 |
+
else:
|
253 |
+
raise ValueError("Unsupported input type for Dense Representation Encoder.")
|
254 |
+
|
255 |
+
# Check the dtype and shape of the input
|
256 |
+
assert isinstance(input_data, torch.Tensor), "Input must be a torch.Tensor"
|
257 |
+
assert input_data.ndim == 4, "Input must be of shape (B, C, H, W)"
|
258 |
+
assert input_data.shape[1] == self.in_chans, f"Input channels must be {self.in_chans}"
|
259 |
+
batch_size, channels, height, width = input_data.shape
|
260 |
+
assert (
|
261 |
+
height % self.patch_size == 0 and width % self.patch_size == 0
|
262 |
+
), f"Input shape must be divisible by patch size: {self.patch_size}"
|
263 |
+
|
264 |
+
# Encode the dense representation
|
265 |
+
features = self.unshuffle(input_data)
|
266 |
+
features = self.conv_in(features)
|
267 |
+
features = self.encoder(features)
|
268 |
+
features = features.flatten(2).transpose(
|
269 |
+
1, 2
|
270 |
+
) # (B, E, H / Patch_Size, W / Patch_Size) -> (B, H / Patch_Size * W / Patch_Size, E)
|
271 |
+
features = self.norm_layer(features) # Normalize the features after patch encoding
|
272 |
+
|
273 |
+
# Apply positional encoding if required
|
274 |
+
if self.apply_pe:
|
275 |
+
features = features + self.interpolate_pos_encoding(
|
276 |
+
features, height, width
|
277 |
+
) # (B, H / Patch_Size * W / Patch_Size, E)
|
278 |
+
features = self.post_pe_norm(features) # Normalize the features after positional encoding
|
279 |
+
|
280 |
+
# Resize the features to the expected shape
|
281 |
+
# (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size)
|
282 |
+
features = features.permute(0, 2, 1)
|
283 |
+
features = features.reshape(
|
284 |
+
-1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size
|
285 |
+
).contiguous()
|
286 |
+
|
287 |
+
return ViTEncoderOutput(features=features)
|
288 |
+
|
289 |
+
|
290 |
+
if __name__ == "__main__":
|
291 |
+
# Init Dense Representation Encoder for images as input
|
292 |
+
patch_embedder = DenseRepresentationEncoder(
|
293 |
+
name="dense_rep_encoder",
|
294 |
+
data_norm_type="dense_rep_encoder",
|
295 |
+
input_size_for_pe=518,
|
296 |
+
patch_size=14,
|
297 |
+
in_chans=3,
|
298 |
+
enc_embed_dim=1024,
|
299 |
+
apply_pe=False,
|
300 |
+
)
|
301 |
+
|
302 |
+
# Test dummy image input
|
303 |
+
dummy_image = torch.randn(1, 3, 518, 518)
|
304 |
+
patch_embedder_output = patch_embedder(ViTEncoderInput(data_norm_type="dense_rep_encoder", image=dummy_image))
|
305 |
+
assert patch_embedder_output.features.shape == (
|
306 |
+
1,
|
307 |
+
1024,
|
308 |
+
37,
|
309 |
+
37,
|
310 |
+
), "Output features must have shape (1, 1024, 37, 37)"
|
311 |
+
|
312 |
+
# Init Dense Representation Encoder for non-image data as input
|
313 |
+
patch_embedder = DenseRepresentationEncoder(
|
314 |
+
name="dense_rep_encoder",
|
315 |
+
data_norm_type="dense_rep_encoder",
|
316 |
+
input_size_for_pe=518,
|
317 |
+
patch_size=14,
|
318 |
+
in_chans=6,
|
319 |
+
enc_embed_dim=1024,
|
320 |
+
)
|
321 |
+
|
322 |
+
# Init Dense Representation Encoder for single channel input
|
323 |
+
patch_embedder = DenseRepresentationEncoder(
|
324 |
+
name="dense_rep_encoder",
|
325 |
+
data_norm_type="dense_rep_encoder",
|
326 |
+
input_size_for_pe=518,
|
327 |
+
patch_size=14,
|
328 |
+
in_chans=1,
|
329 |
+
enc_embed_dim=1024,
|
330 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
331 |
+
apply_pe=True,
|
332 |
+
)
|
333 |
+
|
334 |
+
# Test dummy non-image input
|
335 |
+
dummy_image = torch.randn(1, 1, 980, 980)
|
336 |
+
patch_embedder_output = patch_embedder(ViTEncoderNonImageInput(data=dummy_image))
|
337 |
+
assert patch_embedder_output.features.shape == (
|
338 |
+
1,
|
339 |
+
1024,
|
340 |
+
70,
|
341 |
+
70,
|
342 |
+
), "Output features must have shape (1, 1024, 70, 70)"
|
343 |
+
|
344 |
+
print("All variants of Dense Representation Encoder have been initialized successfully!")
|
UniCeption/uniception/models/encoders/dinov2.py
ADDED
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Encoder Class for DINOv2
|
3 |
+
"""
|
4 |
+
|
5 |
+
from typing import List, Optional, Union
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
from uniception.models.encoders.base import UniCeptionViTEncoderBase, ViTEncoderInput, ViTEncoderOutput
|
12 |
+
from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner
|
13 |
+
|
14 |
+
|
15 |
+
class DINOv2Encoder(UniCeptionViTEncoderBase):
|
16 |
+
"UniCeption DINOv2 Encoder"
|
17 |
+
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
name: str,
|
21 |
+
data_norm_type: str = "dinov2",
|
22 |
+
patch_size: int = 14,
|
23 |
+
size: str = "large",
|
24 |
+
with_registers: bool = False,
|
25 |
+
pretrained_checkpoint_path: str = None,
|
26 |
+
torch_hub_force_reload: bool = False,
|
27 |
+
gradient_checkpointing: bool = False,
|
28 |
+
keep_first_n_layers: Optional[int] = None,
|
29 |
+
use_pytorch_sdpa=True,
|
30 |
+
*args,
|
31 |
+
**kwargs,
|
32 |
+
):
|
33 |
+
"""
|
34 |
+
DINOv2 Encoder for extracting spatial features from images.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
name (str): Name of the encoder.
|
38 |
+
data_norm_type (str): Image normalization type. Default: "dinov2"
|
39 |
+
patch_size (int): Patch size for the encoder. Default: 14
|
40 |
+
size (str): Size variant of the DINOv2 model. Options: ["small", "base", "large", "giant"]. Default: "large"
|
41 |
+
with_registers (bool): Whether to use the DINOv2 model with registers. Default: False
|
42 |
+
pretrained_checkpoint_path (str): Path to the pretrained checkpoint if using custom trained version of DINOv2. Default: None
|
43 |
+
torch_hub_force_reload (bool): Whether to force reload the model from torch hub. Default: False
|
44 |
+
gradient_checkpointing (bool): Whether to use gradient checkpointing to save GPU memory during backward call. Default: False
|
45 |
+
keep_first_n_layers (Optional[int]): If specified, only the first n layers of the model will be kept. Default: None
|
46 |
+
use_pytorch_sdpa (bool): Whether to use PyTorch native SDPA for attention layers. Default: True
|
47 |
+
"""
|
48 |
+
# Init the base class
|
49 |
+
name = name if not with_registers else f"{name}_reg"
|
50 |
+
super().__init__(
|
51 |
+
name=name,
|
52 |
+
data_norm_type=data_norm_type,
|
53 |
+
patch_size=patch_size,
|
54 |
+
gradient_checkpointing=gradient_checkpointing,
|
55 |
+
*args,
|
56 |
+
**kwargs,
|
57 |
+
)
|
58 |
+
|
59 |
+
# Init the DINOv2 Encoder specific attributes
|
60 |
+
self.version = size
|
61 |
+
self.with_registers = with_registers
|
62 |
+
self.enc_embed_dim = {"small": 384, "base": 768, "large": 1024, "giant": 1536}[self.version]
|
63 |
+
|
64 |
+
# Define DINOv2 model factory
|
65 |
+
DINO_MODELS = {
|
66 |
+
# No registers
|
67 |
+
False: {
|
68 |
+
"small": "dinov2_vits14",
|
69 |
+
"base": "dinov2_vitb14",
|
70 |
+
"large": "dinov2_vitl14",
|
71 |
+
"giant": "dinov2_vitg14",
|
72 |
+
},
|
73 |
+
# With registers
|
74 |
+
True: {
|
75 |
+
"small": "dinov2_vits14_reg",
|
76 |
+
"base": "dinov2_vitb14_reg",
|
77 |
+
"large": "dinov2_vitl14_reg",
|
78 |
+
"giant": "dinov2_vitg14_reg",
|
79 |
+
},
|
80 |
+
}
|
81 |
+
|
82 |
+
# Load the pretrained DINOv2 model from torch hub
|
83 |
+
print(f"Loading pretrained {DINO_MODELS[self.with_registers][self.version]} from torch hub")
|
84 |
+
try: # Requires internet access
|
85 |
+
self.model = torch.hub.load(
|
86 |
+
"facebookresearch/dinov2",
|
87 |
+
DINO_MODELS[self.with_registers][self.version],
|
88 |
+
force_reload=torch_hub_force_reload,
|
89 |
+
)
|
90 |
+
except: # Load from cache
|
91 |
+
self.model = torch.hub.load("facebookresearch/dinov2", DINO_MODELS[self.with_registers][self.version])
|
92 |
+
|
93 |
+
del (
|
94 |
+
self.model.mask_token
|
95 |
+
) # This parameter is unused in producing patch features, and will lead to unused parameters
|
96 |
+
|
97 |
+
# Keep only the first n layers of the model if keep_first_n_layers is specified
|
98 |
+
if keep_first_n_layers is not None:
|
99 |
+
self.model.blocks = nn.ModuleList(self.model.blocks[:keep_first_n_layers])
|
100 |
+
|
101 |
+
# Use Native Torch SDPA for attention layers if specified (instead of DINOv2's XFormers)
|
102 |
+
if use_pytorch_sdpa:
|
103 |
+
self.enable_pytorch_native_sdpa()
|
104 |
+
|
105 |
+
# Wrap the transformer blocks with support for gradient checkpointing if required
|
106 |
+
if self.gradient_checkpointing:
|
107 |
+
for i in range(len(self.model.blocks)):
|
108 |
+
self.model.blocks[i] = self.wrap_module_with_gradient_checkpointing(self.model.blocks[i])
|
109 |
+
|
110 |
+
# Load the custom pretrained checkpoint if provided
|
111 |
+
if pretrained_checkpoint_path:
|
112 |
+
print(f"Loading custom pretrained DINOv2 checkpoint from {pretrained_checkpoint_path}")
|
113 |
+
ckpt = torch.load(pretrained_checkpoint_path, weights_only=False)
|
114 |
+
print(self.load_state_dict(ckpt["model"]))
|
115 |
+
|
116 |
+
def enable_pytorch_native_sdpa(self):
|
117 |
+
"Enable PyTorch native SDPA for attention layers"
|
118 |
+
for i in range(len(self.model.blocks)):
|
119 |
+
self.model.blocks[i].attn = self.wrap_dinov2_attention_with_sdpa(self.model.blocks[i].attn)
|
120 |
+
|
121 |
+
def wrap_dinov2_attention_with_sdpa(self, module: nn.Module):
|
122 |
+
"Wrap DINOv2 attention module with PyTorch native SDPA"
|
123 |
+
assert torch.__version__ >= "2.0", "SDPA requires PyTorch 2.0 or later"
|
124 |
+
|
125 |
+
class _AttentionWrapper(module.__class__):
|
126 |
+
"SDPA Attention Wrapper Class"
|
127 |
+
|
128 |
+
def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
|
129 |
+
B, N, C = x.shape
|
130 |
+
qkv = (
|
131 |
+
self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
132 |
+
) # (3, B, H, N, C // H)
|
133 |
+
|
134 |
+
q, k, v = torch.unbind(qkv, 0) # (B, H, N, C // H)
|
135 |
+
|
136 |
+
x = F.scaled_dot_product_attention(q, k, v, attn_bias)
|
137 |
+
x = x.permute(0, 2, 1, 3).reshape(B, N, C)
|
138 |
+
|
139 |
+
x = self.proj(x)
|
140 |
+
x = self.proj_drop(x)
|
141 |
+
return x
|
142 |
+
|
143 |
+
module.__class__ = _AttentionWrapper
|
144 |
+
return module
|
145 |
+
|
146 |
+
def forward(self, encoder_input: ViTEncoderInput) -> ViTEncoderOutput:
|
147 |
+
"""
|
148 |
+
DINOv2 Encoder Forward Pass
|
149 |
+
|
150 |
+
Args:
|
151 |
+
encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor.
|
152 |
+
|
153 |
+
Returns:
|
154 |
+
ViTEncoderOutput: Output data from the encoder.
|
155 |
+
"""
|
156 |
+
# Check image normalization type
|
157 |
+
self._check_data_normalization_type(encoder_input.data_norm_type)
|
158 |
+
|
159 |
+
# Check the dtype and shape of the input image
|
160 |
+
assert isinstance(encoder_input.image, torch.Tensor), "Input must be a torch.Tensor"
|
161 |
+
assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)"
|
162 |
+
batch_size, channels, height, width = encoder_input.image.shape
|
163 |
+
assert channels == 3, "Input must have 3 channels"
|
164 |
+
assert (
|
165 |
+
height % self.patch_size == 0 and width % self.patch_size == 0
|
166 |
+
), f"Input shape must be divisible by patch size: {self.patch_size}"
|
167 |
+
|
168 |
+
# Extract the features from the DINOv2 model
|
169 |
+
features = self.model.forward_features(encoder_input.image)["x_norm_patchtokens"]
|
170 |
+
|
171 |
+
# Resize the features to the expected shape
|
172 |
+
# (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size)
|
173 |
+
features = features.permute(0, 2, 1)
|
174 |
+
features = features.reshape(
|
175 |
+
-1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size
|
176 |
+
).contiguous()
|
177 |
+
|
178 |
+
return ViTEncoderOutput(features=features)
|
179 |
+
|
180 |
+
|
181 |
+
class DINOv2IntermediateFeatureReturner(DINOv2Encoder, IntermediateFeatureReturner):
|
182 |
+
"Intermediate Feature Returner for UniCeption DINOv2 Encoder"
|
183 |
+
|
184 |
+
def __init__(
|
185 |
+
self,
|
186 |
+
name: str,
|
187 |
+
data_norm_type: str = "dinov2",
|
188 |
+
patch_size: int = 14,
|
189 |
+
size: str = "large",
|
190 |
+
with_registers: bool = False,
|
191 |
+
pretrained_checkpoint_path: str = None,
|
192 |
+
indices: Optional[Union[int, List[int]]] = 1,
|
193 |
+
keep_first_n_layers: Optional[int] = None,
|
194 |
+
norm_intermediate: bool = True,
|
195 |
+
*args,
|
196 |
+
**kwargs,
|
197 |
+
):
|
198 |
+
"""
|
199 |
+
DINOv2 Encoder for extracting spatial features from images.
|
200 |
+
|
201 |
+
Args:
|
202 |
+
name (str): Name of the encoder.
|
203 |
+
data_norm_type (str): Image normalization type. Default: "dinov2"
|
204 |
+
patch_size (int): Patch size for the encoder. Default: 14
|
205 |
+
size (str): Size variant of the DINOv2 model. Options: ["small", "base", "large", "giant"]
|
206 |
+
with_registers (bool): Whether to use the DINOv2 model with registers.
|
207 |
+
pretrained_checkpoint_path (str): Path to the pretrained checkpoint if using custom trained version of DINOv2.
|
208 |
+
indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. Defaults to 1. Options:
|
209 |
+
- int: Return the last n layers.
|
210 |
+
- List[int]: Return the intermediate layers at the specified indices.
|
211 |
+
keep_first_n_layers (Optional[int], optional): If specified, only the first n layers of the model will be kept. Defaults to None.
|
212 |
+
norm_intermediate (bool, optional): Whether to normalize the intermediate features. Defaults to True.
|
213 |
+
"""
|
214 |
+
# Init the base classes
|
215 |
+
DINOv2Encoder.__init__(
|
216 |
+
self,
|
217 |
+
name=name,
|
218 |
+
data_norm_type=data_norm_type,
|
219 |
+
patch_size=patch_size,
|
220 |
+
size=size,
|
221 |
+
with_registers=with_registers,
|
222 |
+
keep_first_n_layers=keep_first_n_layers,
|
223 |
+
pretrained_checkpoint_path=pretrained_checkpoint_path,
|
224 |
+
*args,
|
225 |
+
**kwargs,
|
226 |
+
)
|
227 |
+
IntermediateFeatureReturner.__init__(
|
228 |
+
self,
|
229 |
+
indices=indices,
|
230 |
+
norm_intermediate=norm_intermediate,
|
231 |
+
)
|
232 |
+
|
233 |
+
def forward(self, encoder_input: ViTEncoderInput) -> List[ViTEncoderOutput]:
|
234 |
+
"""
|
235 |
+
DINOv2 Encoder Forward Pass with Intermediate Feature Return
|
236 |
+
|
237 |
+
Args:
|
238 |
+
encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor.
|
239 |
+
|
240 |
+
Returns:
|
241 |
+
List[ViTEncoderOutput]: Output data from the encoder. Returns a list of intermediate features.
|
242 |
+
"""
|
243 |
+
# Check image normalization type
|
244 |
+
self._check_data_normalization_type(encoder_input.data_norm_type)
|
245 |
+
|
246 |
+
# Check the dtype and shape of the input image
|
247 |
+
assert isinstance(encoder_input.image, torch.Tensor), "Input must be a torch.Tensor"
|
248 |
+
assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)"
|
249 |
+
batch_size, channels, height, width = encoder_input.image.shape
|
250 |
+
assert channels == 3, "Input must have 3 channels"
|
251 |
+
assert (
|
252 |
+
height % self.patch_size == 0 and width % self.patch_size == 0
|
253 |
+
), f"Input shape must be divisible by patch size: {self.patch_size}"
|
254 |
+
|
255 |
+
if self.indices is None:
|
256 |
+
self.indices = range(len(self.model.blocks))
|
257 |
+
|
258 |
+
# Extract the intermediate features from the DINOv2 model
|
259 |
+
intermediate_features = self.model.get_intermediate_layers(
|
260 |
+
encoder_input.image, n=self.indices, reshape=True, norm=self.norm_intermediate
|
261 |
+
)
|
262 |
+
|
263 |
+
# Convert the intermediate features to a list of ViTEncoderOutput
|
264 |
+
intermediate_features = [ViTEncoderOutput(features=features) for features in intermediate_features]
|
265 |
+
|
266 |
+
return intermediate_features
|
267 |
+
|
268 |
+
|
269 |
+
if __name__ == "__main__":
|
270 |
+
# Init different variants of DINOv2
|
271 |
+
for size in ["small", "base", "large", "giant"]:
|
272 |
+
for with_registers in [False, True]:
|
273 |
+
name = f"dinov2_{size}"
|
274 |
+
dinov2_encoder = DINOv2Encoder(name=name, size=size, with_registers=with_registers)
|
275 |
+
|
276 |
+
# Init the custom pretrained DINOv2 encoders
|
277 |
+
for size in ["small", "base", "large"]:
|
278 |
+
pretrained_checkpoints_dict = {
|
279 |
+
"small": "../../../checkpoints/encoders/DINOv2_ViTS_DepthAnythingV2.pth",
|
280 |
+
"base": "../../../checkpoints/encoders/DINOv2_ViTB_DepthAnythingV2.pth",
|
281 |
+
"large": "../../../checkpoints/encoders/DINOv2_ViTL_DepthAnythingV2.pth",
|
282 |
+
}
|
283 |
+
name = f"dinov2_dav2_{size}"
|
284 |
+
dinov2_encoder = DINOv2Encoder(
|
285 |
+
name=name, size=size, with_registers=False, pretrained_checkpoint_path=pretrained_checkpoints_dict[size]
|
286 |
+
)
|
287 |
+
|
288 |
+
print("All DINOv2 Encoders have been initialized successfully!")
|
289 |
+
|
290 |
+
# Intermediate Feature Returner Tests
|
291 |
+
print("Running Intermediate Feature Returner Tests...")
|
292 |
+
|
293 |
+
# Run the intermediate feature returner with last-n index
|
294 |
+
dinov2_intermediate_feature_returner = DINOv2IntermediateFeatureReturner(
|
295 |
+
name="dinov2_base", size="base", indices=6
|
296 |
+
) # Last 6 layers
|
297 |
+
dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="dinov2")
|
298 |
+
output = dinov2_intermediate_feature_returner(dummy_input)
|
299 |
+
assert isinstance(output, list), "Output must be a list of intermediate features"
|
300 |
+
assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
|
301 |
+
assert len(output) == 6, "Output must have length of intermediate features equal to the number of indices"
|
302 |
+
|
303 |
+
# Run the intermediate feature returner with specific indices
|
304 |
+
dinov2_intermediate_feature_returner = DINOv2IntermediateFeatureReturner(
|
305 |
+
name="dinov2_base", size="base", indices=[0, 2, 4, 6]
|
306 |
+
) # Specific layers
|
307 |
+
dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="dinov2")
|
308 |
+
output = dinov2_intermediate_feature_returner(dummy_input)
|
309 |
+
assert isinstance(output, list), "Output must be a list of intermediate features"
|
310 |
+
assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
|
311 |
+
assert len(output) == 4, "Output must have length of intermediate features equal to the number of indices"
|
312 |
+
|
313 |
+
print("All Intermediate Feature Returner Tests have passed successfully!")
|
314 |
+
|
315 |
+
from uniception.models.encoders.utils import profile_encoder
|
316 |
+
|
317 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
318 |
+
torch.backends.cudnn.allow_tf32 = True
|
319 |
+
|
320 |
+
# Profile the DINOv2 Encoder
|
321 |
+
dinov2_encoder = DINOv2Encoder(
|
322 |
+
name="dinov2_large", size="large", use_pytorch_sdpa=True, gradient_checkpointing=True, keep_first_n_layers=12
|
323 |
+
).cuda()
|
324 |
+
dummy_input = ViTEncoderInput(image=torch.randn(24, 3, 560, 420).cuda(), data_norm_type="dinov2")
|
325 |
+
|
326 |
+
class Profiler:
|
327 |
+
@profile_encoder(num_warmup=3, num_runs=20, autocast_precision="bfloat16", use_compile=True, dynamic=False)
|
328 |
+
def run_fn(self):
|
329 |
+
output = dinov2_encoder(dummy_input)
|
330 |
+
return output
|
331 |
+
|
332 |
+
profiler = Profiler()
|
333 |
+
profiler.run_fn()
|
UniCeption/uniception/models/encoders/global_rep_encoder.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Encoder class for Global Representation Encoder
|
3 |
+
"""
|
4 |
+
|
5 |
+
from functools import partial
|
6 |
+
from typing import Callable, List, Optional, Type, Union
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
from uniception.models.encoders.base import EncoderGlobalRepInput, EncoderGlobalRepOutput
|
12 |
+
|
13 |
+
|
14 |
+
class GlobalRepresentationEncoder(nn.Module):
|
15 |
+
"UniCeption Global Representation Encoder"
|
16 |
+
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
name: str,
|
20 |
+
in_chans: int = 3,
|
21 |
+
enc_embed_dim: int = 1024,
|
22 |
+
intermediate_dims: List[int] = [128, 256, 512],
|
23 |
+
act_layer: Type[nn.Module] = nn.GELU,
|
24 |
+
norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial(nn.LayerNorm, eps=1e-6),
|
25 |
+
pretrained_checkpoint_path: Optional[str] = None,
|
26 |
+
*args,
|
27 |
+
**kwargs,
|
28 |
+
):
|
29 |
+
"""
|
30 |
+
Global Representation Encoder for projecting a global representation to a desired latent dimension.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
name (str): Name of the Encoder.
|
34 |
+
in_chans (int): Number of input channels.
|
35 |
+
enc_embed_dim (int): Embedding dimension of the encoder.
|
36 |
+
intermediate_dims (List[int]): List of intermediate dimensions of the encoder.
|
37 |
+
act_layer (Type[nn.Module]): Activation layer to use in the encoder.
|
38 |
+
norm_layer (Union[Type[nn.Module], Callable[..., nn.Module]]): Final normalization layer to use in the encoder.
|
39 |
+
pretrained_checkpoint_path (Optional[str]): Path to pretrained checkpoint. (default: None)
|
40 |
+
"""
|
41 |
+
super().__init__(*args, **kwargs)
|
42 |
+
|
43 |
+
# Initialize the attributes
|
44 |
+
self.name = name
|
45 |
+
self.in_chans = in_chans
|
46 |
+
self.enc_embed_dim = enc_embed_dim
|
47 |
+
self.intermediate_dims = intermediate_dims
|
48 |
+
self.pretrained_checkpoint_path = pretrained_checkpoint_path
|
49 |
+
|
50 |
+
# Init the activation layer
|
51 |
+
self.act_layer = act_layer()
|
52 |
+
|
53 |
+
# Initialize the encoder
|
54 |
+
self.encoder = nn.Sequential(
|
55 |
+
nn.Linear(self.in_chans, self.intermediate_dims[0]),
|
56 |
+
self.act_layer,
|
57 |
+
)
|
58 |
+
for intermediate_idx in range(1, len(self.intermediate_dims)):
|
59 |
+
self.encoder = nn.Sequential(
|
60 |
+
self.encoder,
|
61 |
+
nn.Linear(self.intermediate_dims[intermediate_idx - 1], self.intermediate_dims[intermediate_idx]),
|
62 |
+
self.act_layer,
|
63 |
+
)
|
64 |
+
self.encoder = nn.Sequential(
|
65 |
+
self.encoder,
|
66 |
+
nn.Linear(self.intermediate_dims[-1], self.enc_embed_dim),
|
67 |
+
)
|
68 |
+
|
69 |
+
# Init weights of the final norm layer
|
70 |
+
self.norm_layer = norm_layer(enc_embed_dim) if norm_layer else nn.Identity()
|
71 |
+
if isinstance(self.norm_layer, nn.LayerNorm):
|
72 |
+
nn.init.constant_(self.norm_layer.bias, 0)
|
73 |
+
nn.init.constant_(self.norm_layer.weight, 1.0)
|
74 |
+
|
75 |
+
# Load pretrained weights if provided
|
76 |
+
if self.pretrained_checkpoint_path is not None:
|
77 |
+
print(
|
78 |
+
f"Loading pretrained Global Representation Encoder checkpoint from {self.pretrained_checkpoint_path} ..."
|
79 |
+
)
|
80 |
+
ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
|
81 |
+
print(self.load_state_dict(ckpt["model"]))
|
82 |
+
|
83 |
+
def forward(self, encoder_input: EncoderGlobalRepInput) -> EncoderGlobalRepOutput:
|
84 |
+
"""
|
85 |
+
Global Representation Encoder Forward Pass
|
86 |
+
|
87 |
+
Args:
|
88 |
+
encoder_input (EncoderGlobalRepInput): Input data for the encoder.
|
89 |
+
The provided data must contain a tensor of size (B, C).
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
EncoderGlobalRepOutput: Output features from the encoder.
|
93 |
+
"""
|
94 |
+
# Get the input data and verify the shape of the input
|
95 |
+
input_data = encoder_input.data
|
96 |
+
assert input_data.ndim == 2, "Input data must have shape (B, C)"
|
97 |
+
assert input_data.shape[1] == self.in_chans, f"Input data must have {self.in_chans} channels"
|
98 |
+
|
99 |
+
# Encode the global representation
|
100 |
+
features = self.encoder(input_data)
|
101 |
+
|
102 |
+
# Normalize the output
|
103 |
+
features = self.norm_layer(features)
|
104 |
+
|
105 |
+
return EncoderGlobalRepOutput(features=features)
|
106 |
+
|
107 |
+
|
108 |
+
if __name__ == "__main__":
|
109 |
+
dummy_model = GlobalRepresentationEncoder(
|
110 |
+
name="dummy", in_chans=3, enc_embed_dim=1024, intermediate_dims=[128, 256, 512]
|
111 |
+
)
|
112 |
+
dummy_input = EncoderGlobalRepInput(data=torch.randn(4, 3))
|
113 |
+
dummy_output = dummy_model(dummy_input)
|
114 |
+
assert dummy_output.features.shape == (4, 1024), "Output features must have shape (B, 1024)"
|
115 |
+
print("Global Representation Encoder has been initialized successfully!")
|
UniCeption/uniception/models/encoders/image_normalizations.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Image normalizations for the different UniCeption image encoders.
|
3 |
+
Image encoders defined in UniCeption must have their corresponding image normalization defined here.
|
4 |
+
"""
|
5 |
+
|
6 |
+
from dataclasses import dataclass
|
7 |
+
|
8 |
+
import torch
|
9 |
+
|
10 |
+
|
11 |
+
@dataclass
|
12 |
+
class ImageNormalization:
|
13 |
+
mean: torch.Tensor
|
14 |
+
std: torch.Tensor
|
15 |
+
|
16 |
+
|
17 |
+
IMAGE_NORMALIZATION_DICT = {
|
18 |
+
"dummy": ImageNormalization(mean=torch.tensor([0.0, 0.0, 0.0]), std=torch.tensor([1.0, 1.0, 1.0])),
|
19 |
+
"croco": ImageNormalization(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])),
|
20 |
+
"dust3r": ImageNormalization(mean=torch.tensor([0.5, 0.5, 0.5]), std=torch.tensor([0.5, 0.5, 0.5])),
|
21 |
+
"dinov2": ImageNormalization(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])),
|
22 |
+
"identity": ImageNormalization(mean=torch.tensor([0.0, 0.0, 0.0]), std=torch.tensor([1.0, 1.0, 1.0])),
|
23 |
+
"patch_embedder": ImageNormalization(
|
24 |
+
mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])
|
25 |
+
),
|
26 |
+
"radio": ImageNormalization(mean=torch.tensor([0.0, 0.0, 0.0]), std=torch.tensor([1.0, 1.0, 1.0])),
|
27 |
+
"sea_raft": ImageNormalization(
|
28 |
+
mean=torch.tensor([0.0, 0.0, 0.0]), std=torch.tensor([1.0, 1.0, 1.0]) / 255
|
29 |
+
), # Sea-RAFT uses 0-255 in FP32
|
30 |
+
"unimatch": ImageNormalization(
|
31 |
+
mean=torch.tensor([0.0, 0.0, 0.0]), std=torch.tensor([1.0, 1.0, 1.0]) / 255
|
32 |
+
), # UniMatch uses 0-255 in FP32
|
33 |
+
"roma": ImageNormalization(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])),
|
34 |
+
"cosmos": ImageNormalization(mean=torch.tensor([0.0, 0.0, 0.0]), std=torch.tensor([0.5, 0.5, 0.5])),
|
35 |
+
}
|
UniCeption/uniception/models/encoders/list.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
List available UniCeption encoders.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
from uniception.models.encoders import print_available_encoder_models
|
8 |
+
|
9 |
+
if __name__ == "__main__":
|
10 |
+
print_available_encoder_models()
|
UniCeption/uniception/models/encoders/naradio.py
ADDED
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Encoder Class for NARADIO (RayFronts)
|
3 |
+
"""
|
4 |
+
|
5 |
+
import math
|
6 |
+
from typing import List, Optional, Tuple, Union
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch.nn.attention.flex_attention import flex_attention
|
12 |
+
|
13 |
+
from uniception.models.encoders.base import UniCeptionViTEncoderBase, ViTEncoderInput, ViTEncoderOutput
|
14 |
+
from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner
|
15 |
+
|
16 |
+
|
17 |
+
class GaussKernelAttn(nn.Module):
|
18 |
+
"""Implementation of Gaussian Kernel based Attention using FlexAttention"""
|
19 |
+
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
orig_attn,
|
23 |
+
gauss_std: float,
|
24 |
+
dim: int,
|
25 |
+
qk_norm: bool = False,
|
26 |
+
num_prefix_tokens: int = 8,
|
27 |
+
patch_size: int = 16,
|
28 |
+
) -> None:
|
29 |
+
super().__init__()
|
30 |
+
num_heads = orig_attn.num_heads
|
31 |
+
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
32 |
+
self.num_heads = num_heads
|
33 |
+
self.head_dim = dim // num_heads
|
34 |
+
self.scale = self.head_dim**-0.5
|
35 |
+
|
36 |
+
self.addition_cache = dict()
|
37 |
+
self.input_resolution = None # to be set when calling forward
|
38 |
+
self.gauss_std = gauss_std
|
39 |
+
self.patch_size = patch_size
|
40 |
+
|
41 |
+
self.qkv = orig_attn.qkv
|
42 |
+
self.q_norm = orig_attn.q_norm if qk_norm else nn.Identity()
|
43 |
+
self.k_norm = orig_attn.k_norm if qk_norm else nn.Identity()
|
44 |
+
self.attn_drop = orig_attn.attn_drop
|
45 |
+
self.proj = orig_attn.proj
|
46 |
+
self.proj_drop = orig_attn.proj_drop
|
47 |
+
self.num_prefix_tokens = num_prefix_tokens
|
48 |
+
|
49 |
+
@staticmethod
|
50 |
+
def gaussian_window(dim1, dim2, std=7.0):
|
51 |
+
constant = 1 / (std * math.sqrt(2))
|
52 |
+
ks = list()
|
53 |
+
for dim in [dim1, dim2]:
|
54 |
+
start = -(dim - 1) / 2.0
|
55 |
+
k = torch.linspace(start=start * constant, end=(start + (dim - 1)) * constant, steps=dim, dtype=torch.float)
|
56 |
+
ks.append(k)
|
57 |
+
dist_square_to_mu = (torch.stack(torch.meshgrid(*ks, indexing="ij")) ** 2).sum(0)
|
58 |
+
|
59 |
+
return torch.exp(-dist_square_to_mu)
|
60 |
+
|
61 |
+
@staticmethod
|
62 |
+
def get_attention_addition(dim1, dim2, window, num_prefix_tokens=8):
|
63 |
+
m = torch.einsum("ij,kl->ijkl", torch.eye(dim1), torch.eye(dim2))
|
64 |
+
m = m.permute((0, 3, 1, 2)).contiguous()
|
65 |
+
out = F.conv2d(m.view(-1, dim1, dim2).unsqueeze(1), window.unsqueeze(0).unsqueeze(1), padding="same").squeeze(1)
|
66 |
+
|
67 |
+
out = out.view(dim1 * dim2, dim1 * dim2)
|
68 |
+
if num_prefix_tokens > 0:
|
69 |
+
v_adjusted = torch.vstack([torch.zeros((num_prefix_tokens, dim1 * dim2)), out])
|
70 |
+
out = torch.hstack([torch.zeros((dim1 * dim2 + num_prefix_tokens, num_prefix_tokens)), v_adjusted])
|
71 |
+
|
72 |
+
return out
|
73 |
+
|
74 |
+
def prepare_gaussian_addition(self, n_patches, device):
|
75 |
+
"""Prepare the Gaussian addition matrix for the current input"""
|
76 |
+
# Check if we have a cached addition matrix for these dimensions
|
77 |
+
if n_patches not in self.addition_cache:
|
78 |
+
window_size = [side * 2 - 1 for side in n_patches]
|
79 |
+
window = self.gaussian_window(*window_size, std=self.gauss_std)
|
80 |
+
addition = self.get_attention_addition(*n_patches, window, self.num_prefix_tokens).to(device)
|
81 |
+
|
82 |
+
# Cache the addition matrix
|
83 |
+
self.addition_cache[n_patches] = addition
|
84 |
+
|
85 |
+
# Return the cached addition matrix
|
86 |
+
return self.addition_cache[n_patches]
|
87 |
+
|
88 |
+
def gauss_score_mod(self, score, b, h, q_idx, kv_idx, addition):
|
89 |
+
"""Score modification function for FlexAttention"""
|
90 |
+
# Adding the precomputed Gaussian pattern to the attention score
|
91 |
+
return score + addition[q_idx, kv_idx]
|
92 |
+
|
93 |
+
def set_input_resolution(self, input_resolution: Tuple[int, int]):
|
94 |
+
"""Set the input resolution for the Gaussian attention window"""
|
95 |
+
self.input_resolution = input_resolution
|
96 |
+
|
97 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
98 |
+
B, N, C = x.shape
|
99 |
+
assert self.input_resolution is not None, "input_resolution must be set before forward pass"
|
100 |
+
h, w = self.input_resolution
|
101 |
+
n_patches = (w // self.patch_size, h // self.patch_size)
|
102 |
+
|
103 |
+
qkv = self.qkv(x)
|
104 |
+
q, k, v = qkv.chunk(3, dim=-1)
|
105 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
106 |
+
|
107 |
+
q = q.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
|
108 |
+
k = k.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
|
109 |
+
v = v.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
|
110 |
+
|
111 |
+
addition = self.prepare_gaussian_addition(n_patches, device=x.device)
|
112 |
+
|
113 |
+
# Create a score_mod function with the current addition matrix
|
114 |
+
score_mod = lambda score, b, h, q_idx, kv_idx: self.gauss_score_mod(score, b, h, q_idx, kv_idx, addition)
|
115 |
+
|
116 |
+
# Use FlexAttention
|
117 |
+
attn_output = flex_attention(q, k, v, score_mod=score_mod)
|
118 |
+
|
119 |
+
# Reshape output and apply projection
|
120 |
+
attn_output = attn_output.transpose(1, 2).reshape(B, N, C)
|
121 |
+
attn_output = self.proj(attn_output)
|
122 |
+
attn_output = self.proj_drop(attn_output)
|
123 |
+
|
124 |
+
return attn_output
|
125 |
+
|
126 |
+
|
127 |
+
class NARADIOEncoder(UniCeptionViTEncoderBase):
|
128 |
+
"""
|
129 |
+
UniCeption NARADIO (RayFronts) Encoder based on NACLIP & RADIO
|
130 |
+
|
131 |
+
The model modifies the attention of the last layer of RADIO following NACLIP,
|
132 |
+
thereby improving the spatial patch features.
|
133 |
+
"""
|
134 |
+
|
135 |
+
def __init__(
|
136 |
+
self,
|
137 |
+
name: str,
|
138 |
+
data_norm_type: str = "radio",
|
139 |
+
patch_size: int = 16,
|
140 |
+
model_version: str = "radio_v2.5-l",
|
141 |
+
gauss_std: float = 7.0,
|
142 |
+
pretrained_checkpoint_path: str = None,
|
143 |
+
eradio_input_shape: Optional[tuple] = None,
|
144 |
+
torch_hub_force_reload: bool = False,
|
145 |
+
keep_first_n_layers: Optional[int] = None,
|
146 |
+
*args,
|
147 |
+
**kwargs,
|
148 |
+
):
|
149 |
+
"""
|
150 |
+
NARADIO Encoder for extracting spatial features from images.
|
151 |
+
|
152 |
+
Args:
|
153 |
+
name (str): Name of the encoder.
|
154 |
+
data_norm_type (str): Image normalization type. Default: "radio"
|
155 |
+
patch_size (int): Patch size for the encoder. Default: 16
|
156 |
+
model_version (str): Version of the RADIO model to load. Default: "radio_v2.5-l"
|
157 |
+
gauss_std: Standard deviation of the gaussian kernel. Default: 7.0
|
158 |
+
pretrained_checkpoint_path (str): Path to the pretrained checkpoint if using custom trained version of RADIO. Default: None
|
159 |
+
eradio_input_shape (tuple): Input shape (height, width) for E-RADIO models. Default: None
|
160 |
+
torch_hub_force_reload (bool): Whether to force reload the model from torch hub. Default: False
|
161 |
+
keep_first_n_layers (Optional[int]): Number of layers to keep from the pretrained model. Default: None
|
162 |
+
"""
|
163 |
+
# Init the base class
|
164 |
+
super().__init__(
|
165 |
+
name=name,
|
166 |
+
data_norm_type=data_norm_type,
|
167 |
+
patch_size=patch_size,
|
168 |
+
*args,
|
169 |
+
**kwargs,
|
170 |
+
)
|
171 |
+
|
172 |
+
# Init the RADIO Encoder specific attributes
|
173 |
+
self.model_version = model_version
|
174 |
+
self.enc_embed_dim = {
|
175 |
+
"radio_v2.5-b": 768,
|
176 |
+
"radio_v2.5-l": 1024,
|
177 |
+
"radio_v2.5-h": 1280,
|
178 |
+
"radio_v2.5-g": 1536,
|
179 |
+
"e-radio_v2": 1536,
|
180 |
+
}[self.model_version]
|
181 |
+
|
182 |
+
if self.model_version == "radio_v2.5-g":
|
183 |
+
assert patch_size == 14, "Patch size must be 14 for RADIO v2.5-g"
|
184 |
+
else:
|
185 |
+
assert patch_size == 16, "Patch size must be 16 for all other versions of RADIO"
|
186 |
+
|
187 |
+
# Load the pretrained RADIO model from torch hub
|
188 |
+
print(f"Loading pretrained {self.model_version} from torch hub")
|
189 |
+
try: # Requires internet access
|
190 |
+
self.model = torch.hub.load(
|
191 |
+
"NVlabs/RADIO",
|
192 |
+
"radio_model",
|
193 |
+
version=self.model_version,
|
194 |
+
progress=True,
|
195 |
+
skip_validation=True,
|
196 |
+
force_reload=torch_hub_force_reload,
|
197 |
+
)
|
198 |
+
except: # Load from cache
|
199 |
+
self.model = torch.hub.load(
|
200 |
+
"NVlabs/RADIO",
|
201 |
+
"radio_model",
|
202 |
+
version=self.model_version,
|
203 |
+
progress=True,
|
204 |
+
skip_validation=True,
|
205 |
+
)
|
206 |
+
|
207 |
+
# Delete the excess blocks if keep_first_n_layers is specified
|
208 |
+
if keep_first_n_layers is not None:
|
209 |
+
assert keep_first_n_layers < len(
|
210 |
+
self.model.model.blocks
|
211 |
+
), "keep_first_n_layers must be less than the number of blocks"
|
212 |
+
print(f"Keeping only the first {keep_first_n_layers} layers of the model")
|
213 |
+
self.model.model.blocks = torch.nn.ModuleList(self.model.model.blocks[:keep_first_n_layers])
|
214 |
+
|
215 |
+
# Set the optimal window size for E-RADIO models
|
216 |
+
if "e-radio" in self.model_version:
|
217 |
+
assert eradio_input_shape is not None, "Input shape (height, width) must be provided for E-RADIO models"
|
218 |
+
self.model.model.set_optimal_window_size(eradio_input_shape)
|
219 |
+
|
220 |
+
# Load the custom pretrained checkpoint if provided
|
221 |
+
if pretrained_checkpoint_path is not None:
|
222 |
+
print(f"Loading custom pretrained NARADIO checkpoint from {pretrained_checkpoint_path}")
|
223 |
+
ckpt = torch.load(pretrained_checkpoint_path, weights_only=False)
|
224 |
+
print(self.load_state_dict(ckpt["model"]))
|
225 |
+
|
226 |
+
# Replace the attention of the last ViT block with the Gaussian Kernel based attention
|
227 |
+
self.model.model.blocks[-1] = GaussKernelAttn(
|
228 |
+
self.model.model.blocks[-1].attn,
|
229 |
+
gauss_std,
|
230 |
+
dim=self.enc_embed_dim,
|
231 |
+
num_prefix_tokens=self.model.num_summary_tokens,
|
232 |
+
patch_size=self.patch_size,
|
233 |
+
)
|
234 |
+
|
235 |
+
def forward(self, encoder_input: ViTEncoderInput) -> ViTEncoderOutput:
|
236 |
+
"""
|
237 |
+
NARADIO Encoder Forward Pass
|
238 |
+
|
239 |
+
Args:
|
240 |
+
encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor.
|
241 |
+
|
242 |
+
Returns:
|
243 |
+
ViTEncoderOutput: Output data from the encoder.
|
244 |
+
"""
|
245 |
+
# Check image normalization type
|
246 |
+
self._check_data_normalization_type(encoder_input.data_norm_type)
|
247 |
+
|
248 |
+
# Check the dtype and shape of the input image
|
249 |
+
assert isinstance(encoder_input.image, torch.Tensor), "Input must be a torch.Tensor"
|
250 |
+
assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)"
|
251 |
+
batch_size, channels, height, width = encoder_input.image.shape
|
252 |
+
assert channels == 3, "Input must have 3 channels"
|
253 |
+
assert (
|
254 |
+
height % self.patch_size == 0 and width % self.patch_size == 0
|
255 |
+
), f"Input shape must be divisible by patch size: {self.patch_size}"
|
256 |
+
|
257 |
+
# Set input resolution for Gaussian attention
|
258 |
+
self.model.model.blocks[-1].set_input_resolution((height, width))
|
259 |
+
|
260 |
+
# Forward pass throught the RADIO encoder
|
261 |
+
summary, features = self.model(encoder_input.image)
|
262 |
+
|
263 |
+
# Resize the features to the expected shape
|
264 |
+
# (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size)
|
265 |
+
features = features.permute(0, 2, 1)
|
266 |
+
features = features.reshape(
|
267 |
+
-1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size
|
268 |
+
).contiguous()
|
269 |
+
|
270 |
+
return ViTEncoderOutput(features=features)
|
271 |
+
|
272 |
+
|
273 |
+
class NARADIOIntermediateFeatureReturner(NARADIOEncoder, IntermediateFeatureReturner):
|
274 |
+
"Intermediate Feature Returner for UniCeption NARADIO Encoder"
|
275 |
+
|
276 |
+
def __init__(
|
277 |
+
self,
|
278 |
+
name: str,
|
279 |
+
data_norm_type: str = "radio",
|
280 |
+
patch_size: int = 16,
|
281 |
+
model_version: str = "radio_v2.5-l",
|
282 |
+
gauss_std: float = 7.0,
|
283 |
+
pretrained_checkpoint_path: str = None,
|
284 |
+
eradio_input_shape: Optional[tuple] = None,
|
285 |
+
indices: Union[int, List[int]] = [-1],
|
286 |
+
norm_intermediate: bool = True,
|
287 |
+
stop_early: bool = False,
|
288 |
+
intermediates_only: bool = True,
|
289 |
+
feature_adaptor: Optional[str] = None,
|
290 |
+
keep_first_n_layers: Optional[int] = None,
|
291 |
+
*args,
|
292 |
+
**kwargs,
|
293 |
+
):
|
294 |
+
"""
|
295 |
+
Intermediate Feature Returner for the NARADIO Encoder.
|
296 |
+
|
297 |
+
Args:
|
298 |
+
name (str): Name of the encoder.
|
299 |
+
data_norm_type (str): Image normalization type. Default: "radio"
|
300 |
+
patch_size (int): Patch size for the encoder. Default: 16
|
301 |
+
model_version (str): Version of the RADIO model to load. Default: "radio_v2.5-l"
|
302 |
+
gauss_std (float): Standard deviation of the gaussian kernel. Default: 7.0
|
303 |
+
pretrained_checkpoint_path (str): Path to the pretrained checkpoint if using custom trained version of RADIO.
|
304 |
+
eradio_input_shape (tuple): Input shape (height, width) for E-RADIO models. Default: None
|
305 |
+
indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. Defaults to [-1]. Options:
|
306 |
+
- int: Return the last n layers.
|
307 |
+
- List[int]: Return the intermediate layers at the specified indices.
|
308 |
+
norm_intermediate (bool, optional): Whether to normalize the intermediate features. Defaults to True.
|
309 |
+
stop_early (bool, optional): Whether to stop early. Defaults to False.
|
310 |
+
intermediates_only (bool, optional): Whether to return only the intermediate features. Defaults to True.
|
311 |
+
feature_adaptor (Optional[str], optional): Feature adaptor to use. Defaults to None. Currently supported: "dino_v2".
|
312 |
+
keep_first_n_layers (Optional[int], optional): Number of layers to keep from the pretrained model. Defaults to None.
|
313 |
+
"""
|
314 |
+
# Init the base classes
|
315 |
+
NARADIOEncoder.__init__(
|
316 |
+
self,
|
317 |
+
name=name,
|
318 |
+
data_norm_type=data_norm_type,
|
319 |
+
patch_size=patch_size,
|
320 |
+
model_version=model_version,
|
321 |
+
gauss_std=gauss_std,
|
322 |
+
pretrained_checkpoint_path=pretrained_checkpoint_path,
|
323 |
+
eradio_input_shape=eradio_input_shape,
|
324 |
+
keep_first_n_layers=keep_first_n_layers,
|
325 |
+
*args,
|
326 |
+
**kwargs,
|
327 |
+
)
|
328 |
+
IntermediateFeatureReturner.__init__(
|
329 |
+
self,
|
330 |
+
indices=indices,
|
331 |
+
norm_intermediate=norm_intermediate,
|
332 |
+
stop_early=stop_early,
|
333 |
+
intermediates_only=intermediates_only,
|
334 |
+
)
|
335 |
+
|
336 |
+
# Convert indices to absolute indices if indices is None
|
337 |
+
if self.indices is None:
|
338 |
+
self.indices = list(range(len(self.model.model.blocks)))
|
339 |
+
|
340 |
+
self.feature_adaptor = feature_adaptor
|
341 |
+
if self.feature_adaptor is None:
|
342 |
+
pass
|
343 |
+
elif self.feature_adaptor == "dino_v2":
|
344 |
+
# Initialize a dummy radio encoder with the adaptor setting
|
345 |
+
dummy_model = torch.hub.load(
|
346 |
+
"NVlabs/RADIO",
|
347 |
+
"radio_model",
|
348 |
+
version=self.model_version,
|
349 |
+
progress=True,
|
350 |
+
skip_validation=True,
|
351 |
+
adaptor_names="dino_v2",
|
352 |
+
)
|
353 |
+
|
354 |
+
# Extract its feature converter weights
|
355 |
+
self.spatial_feature_converter = dummy_model.adaptors["dino_v2"].feat_mlp
|
356 |
+
|
357 |
+
# Update the embedding dimension because the features have been projected
|
358 |
+
self.enc_embed_dim = self.spatial_feature_converter.final[-1].out_features
|
359 |
+
|
360 |
+
del dummy_model
|
361 |
+
else:
|
362 |
+
raise ValueError("Unsupported feature adaptor. Supported: dino_v2")
|
363 |
+
|
364 |
+
def forward(
|
365 |
+
self, encoder_input: ViTEncoderInput
|
366 |
+
) -> Union[List[ViTEncoderOutput], Tuple[ViTEncoderOutput, List[ViTEncoderOutput]]]:
|
367 |
+
"""
|
368 |
+
NARADIO Encoder Forward Pass with Intermediate Feature Return
|
369 |
+
|
370 |
+
Args:
|
371 |
+
encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor.
|
372 |
+
|
373 |
+
Returns:
|
374 |
+
Union[List[ViTEncoderOutput], Tuple[ViTEncoderOutput, List[ViTEncoderOutput]]]: Output data from the encoder.
|
375 |
+
If `intermediates_only` is True, returns a list of intermediate features.
|
376 |
+
Otherwise, returns a tuple with the final features and a list of intermediate features.
|
377 |
+
"""
|
378 |
+
# Check image normalization type
|
379 |
+
self._check_data_normalization_type(encoder_input.data_norm_type)
|
380 |
+
|
381 |
+
# Check the dtype and shape of the input image
|
382 |
+
assert isinstance(encoder_input.image, torch.Tensor), "Input must be a torch.Tensor"
|
383 |
+
assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)"
|
384 |
+
batch_size, channels, height, width = encoder_input.image.shape
|
385 |
+
assert channels == 3, "Input must have 3 channels"
|
386 |
+
assert (
|
387 |
+
height % self.patch_size == 0 and width % self.patch_size == 0
|
388 |
+
), f"Input shape must be divisible by patch size: {self.patch_size}"
|
389 |
+
|
390 |
+
# Set input resolution for Gaussian attention
|
391 |
+
self.model.model.blocks[-1].set_input_resolution((height, width))
|
392 |
+
|
393 |
+
# Extract the final features and intermediate features accordingly
|
394 |
+
model_outputs = self.model.forward_intermediates(
|
395 |
+
encoder_input.image,
|
396 |
+
indices=self.indices,
|
397 |
+
return_prefix_tokens=False,
|
398 |
+
norm=self.norm_intermediate,
|
399 |
+
stop_early=self.stop_early,
|
400 |
+
output_fmt="NLC",
|
401 |
+
intermediates_only=self.intermediates_only,
|
402 |
+
)
|
403 |
+
|
404 |
+
# Extract the final features and intermediate features accordingly
|
405 |
+
final_features, intermediate_features = None, None
|
406 |
+
if self.intermediates_only:
|
407 |
+
intermediate_features = model_outputs
|
408 |
+
else:
|
409 |
+
final_features = model_outputs[0].features.contiguous()
|
410 |
+
intermediate_features = model_outputs[1]
|
411 |
+
|
412 |
+
# Optionally convert the features using the feature adaptor
|
413 |
+
Hp, Wp = height // self.patch_size, width // self.patch_size
|
414 |
+
|
415 |
+
# Convert final features
|
416 |
+
if final_features is not None:
|
417 |
+
if self.feature_adaptor is not None:
|
418 |
+
final_features = self.spatial_feature_converter(final_features)
|
419 |
+
|
420 |
+
# Convert to BCHW and package
|
421 |
+
final_features = final_features.view(batch_size, Hp, Wp, -1).permute(0, 3, 1, 2)
|
422 |
+
final_features = ViTEncoderOutput(features=final_features)
|
423 |
+
|
424 |
+
# Convert intermediate features
|
425 |
+
if intermediate_features is not None:
|
426 |
+
num_intermediate = len(intermediate_features)
|
427 |
+
all_intermediate_feats_tensor = torch.cat(intermediate_features, dim=0)
|
428 |
+
if self.feature_adaptor is not None:
|
429 |
+
all_intermediate_feats_tensor = self.spatial_feature_converter(all_intermediate_feats_tensor)
|
430 |
+
# Convert to BCHW
|
431 |
+
all_intermediate_feats_tensor = all_intermediate_feats_tensor.view(
|
432 |
+
num_intermediate * batch_size, Hp, Wp, -1
|
433 |
+
).permute(0, 3, 1, 2)
|
434 |
+
all_intermediate_feats = torch.chunk(all_intermediate_feats_tensor, num_intermediate, dim=0)
|
435 |
+
intermediate_features = [ViTEncoderOutput(features=x) for x in all_intermediate_feats]
|
436 |
+
|
437 |
+
# Return the final features and intermediate features accordingly
|
438 |
+
if self.intermediates_only:
|
439 |
+
return intermediate_features
|
440 |
+
else:
|
441 |
+
return final_features, intermediate_features
|
442 |
+
|
443 |
+
|
444 |
+
if __name__ == "__main__":
|
445 |
+
# Init different versions of the RADIO Encoder
|
446 |
+
for model_version in ["radio_v2.5-b", "radio_v2.5-l"]:
|
447 |
+
naradio_encoder = NARADIOEncoder(name="NARADIOv2.5", model_version=model_version)
|
448 |
+
|
449 |
+
print("All NARADIO Encoders have been initialized successfully!")
|
450 |
+
|
451 |
+
# Intermediate Feature Returner Tests
|
452 |
+
print("Running Intermediate Feature Returner Tests...")
|
453 |
+
|
454 |
+
# Run the intermediate feature returner with last-n index
|
455 |
+
naradio_intermediate_feature_returner = NARADIOIntermediateFeatureReturner(
|
456 |
+
name="NARADIOv2.5", model_version="radio_v2.5-b", indices=6
|
457 |
+
) # Last 6 layers
|
458 |
+
dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="radio")
|
459 |
+
output = naradio_intermediate_feature_returner(dummy_input)
|
460 |
+
assert isinstance(output, list), "Output must be a list of intermediate features"
|
461 |
+
assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
|
462 |
+
assert len(output) == 6, "Output must have length of intermediate features equal to the number of indices"
|
463 |
+
|
464 |
+
# Run the intermediate feature returner with specific indices
|
465 |
+
naradio_intermediate_feature_returner = NARADIOIntermediateFeatureReturner(
|
466 |
+
name="NARADIOv2.5", model_version="radio_v2.5-b", indices=[0, 2, 4, 6]
|
467 |
+
) # Specific layers
|
468 |
+
dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="radio")
|
469 |
+
output = naradio_intermediate_feature_returner(dummy_input)
|
470 |
+
assert isinstance(output, list), "Output must be a list of intermediate features"
|
471 |
+
assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
|
472 |
+
assert len(output) == 4, "Output must have length of intermediate features equal to the number of indices"
|
473 |
+
|
474 |
+
# Test the normalizing of intermediate features
|
475 |
+
naradio_intermediate_feature_returner = NARADIOIntermediateFeatureReturner(
|
476 |
+
name="NARADIOv2.5", model_version="radio_v2.5-b", norm_intermediate=False, intermediates_only=False
|
477 |
+
) # Do not normalize
|
478 |
+
dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="radio")
|
479 |
+
output = naradio_intermediate_feature_returner(dummy_input)
|
480 |
+
assert isinstance(output, tuple), "Output must be a tuple with final features and intermediate features"
|
481 |
+
assert isinstance(output[0], ViTEncoderOutput), "First element of output must be the final features"
|
482 |
+
assert isinstance(output[1], list), "Second element of output must be a list of intermediate features"
|
483 |
+
assert isinstance(output[1][0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
|
484 |
+
if not isinstance(naradio_intermediate_feature_returner.model.model.norm, torch.nn.Identity):
|
485 |
+
assert not torch.equal(
|
486 |
+
output[0].features, output[1][0].features
|
487 |
+
), "Final features and intermediate features must be different"
|
488 |
+
|
489 |
+
naradio_intermediate_feature_returner = NARADIOIntermediateFeatureReturner(
|
490 |
+
name="NARADIOv2.5", model_version="radio_v2.5-b", norm_intermediate=True, intermediates_only=False
|
491 |
+
)
|
492 |
+
dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="radio")
|
493 |
+
output = naradio_intermediate_feature_returner(dummy_input)
|
494 |
+
assert isinstance(output, tuple), "Output must be a tuple with final features and intermediate features"
|
495 |
+
assert isinstance(output[0], ViTEncoderOutput), "First element of output must be the final features"
|
496 |
+
assert isinstance(output[1], list), "Second element of output must be a list of intermediate features"
|
497 |
+
assert isinstance(output[1][0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
|
498 |
+
assert torch.equal(
|
499 |
+
output[0].features, output[1][0].features
|
500 |
+
), "Final features and intermediate features must be same"
|
501 |
+
|
502 |
+
print("All Intermediate Feature Returner Tests have passed successfully!")
|
UniCeption/uniception/models/encoders/patch_embedder.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Encoder class for Patch Embedder
|
3 |
+
"""
|
4 |
+
|
5 |
+
import math
|
6 |
+
from functools import partial
|
7 |
+
from typing import Callable, Optional, Tuple, Union
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
from torch.nn.init import trunc_normal_
|
12 |
+
|
13 |
+
from uniception.models.encoders.base import (
|
14 |
+
UniCeptionViTEncoderBase,
|
15 |
+
ViTEncoderInput,
|
16 |
+
ViTEncoderNonImageInput,
|
17 |
+
ViTEncoderOutput,
|
18 |
+
)
|
19 |
+
|
20 |
+
|
21 |
+
def make_2tuple(x):
|
22 |
+
if isinstance(x, tuple):
|
23 |
+
assert len(x) == 2
|
24 |
+
return x
|
25 |
+
|
26 |
+
assert isinstance(x, int)
|
27 |
+
return (x, x)
|
28 |
+
|
29 |
+
|
30 |
+
class PatchEmbedder(UniCeptionViTEncoderBase):
|
31 |
+
"UniCeption Patch Embedder"
|
32 |
+
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
name: str,
|
36 |
+
data_norm_type: str = "patch_embedder",
|
37 |
+
input_size: Union[int, Tuple[int, int]] = 518,
|
38 |
+
patch_size: int = 14,
|
39 |
+
in_chans: int = 3,
|
40 |
+
enc_embed_dim: int = 1024,
|
41 |
+
norm_layer: Optional[Callable] = None,
|
42 |
+
post_pe_norm_layer: Optional[Callable] = partial(nn.LayerNorm, eps=1e-6),
|
43 |
+
interpolate_antialias: bool = False,
|
44 |
+
interpolate_offset: float = 0.1,
|
45 |
+
pretrained_checkpoint_path: str = None,
|
46 |
+
*args,
|
47 |
+
**kwargs,
|
48 |
+
):
|
49 |
+
"""
|
50 |
+
Patch Encoder for extracting patch-wise features from a spatial input of size (B, C, H, W).
|
51 |
+
Learnable positional encoding is also applied to the patch-wise features.
|
52 |
+
"""
|
53 |
+
# Init the base class
|
54 |
+
super().__init__(
|
55 |
+
name=name,
|
56 |
+
data_norm_type=data_norm_type,
|
57 |
+
patch_size=patch_size,
|
58 |
+
*args,
|
59 |
+
**kwargs,
|
60 |
+
)
|
61 |
+
|
62 |
+
# Init the Patch Embedder specific attributes
|
63 |
+
patch_HW = make_2tuple(patch_size)
|
64 |
+
self.input_size = make_2tuple(input_size)
|
65 |
+
self.patches_resolution = (self.input_size[0] // patch_HW[0], self.input_size[1] // patch_HW[1])
|
66 |
+
self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
|
67 |
+
self.in_chans = in_chans
|
68 |
+
self.enc_embed_dim = enc_embed_dim
|
69 |
+
|
70 |
+
# Init the Patch Embedder layers
|
71 |
+
self.proj = nn.Conv2d(in_chans, enc_embed_dim, kernel_size=patch_HW, stride=patch_HW)
|
72 |
+
self.norm = norm_layer(enc_embed_dim) if norm_layer else nn.Identity()
|
73 |
+
|
74 |
+
# Init the learnable positional encodings
|
75 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, self.enc_embed_dim))
|
76 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
77 |
+
self.interpolate_antialias = interpolate_antialias
|
78 |
+
self.interpolate_offset = interpolate_offset
|
79 |
+
|
80 |
+
# Init the norm layer after positional encoding
|
81 |
+
self.post_pe_norm = post_pe_norm_layer(enc_embed_dim) if post_pe_norm_layer else nn.Identity()
|
82 |
+
|
83 |
+
# Load the pretrained checkpoint if provided
|
84 |
+
self.pretrained_checkpoint_path = pretrained_checkpoint_path
|
85 |
+
if self.pretrained_checkpoint_path:
|
86 |
+
print(f"Loading custom pretrained Patch Embedder checkpoint from {self.pretrained_checkpoint_path} ...")
|
87 |
+
ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
|
88 |
+
print(self.load_state_dict(ckpt["model"]))
|
89 |
+
|
90 |
+
def interpolate_pos_encoding(self, features, height, width):
|
91 |
+
"""
|
92 |
+
Interpolate the positional encoding to the expected size.
|
93 |
+
|
94 |
+
Args:
|
95 |
+
features (torch.Tensor): Input tensor of shape (B, N, C).
|
96 |
+
height (int, float): Height of the input tensor.
|
97 |
+
width (int, float): Width of the input tensor.
|
98 |
+
|
99 |
+
Returns:
|
100 |
+
torch.Tensor: Interpolated positional encoding tensor of shape (1, N, C).
|
101 |
+
"""
|
102 |
+
previous_dtype = features.dtype
|
103 |
+
npatch = features.shape[1]
|
104 |
+
N = self.pos_embed.shape[1]
|
105 |
+
if npatch == N and height == width:
|
106 |
+
return self.pos_embed
|
107 |
+
patch_pos_embed = self.pos_embed.float()
|
108 |
+
dim = features.shape[-1]
|
109 |
+
height0 = height // self.patch_size
|
110 |
+
width0 = width // self.patch_size
|
111 |
+
M = int(math.sqrt(N)) # Recover the number of patches in each dimension
|
112 |
+
assert N == M * M
|
113 |
+
kwargs = {}
|
114 |
+
if self.interpolate_offset:
|
115 |
+
# Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
|
116 |
+
# Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
|
117 |
+
sh = float(height0 + self.interpolate_offset) / M
|
118 |
+
sw = float(width0 + self.interpolate_offset) / M
|
119 |
+
kwargs["scale_factor"] = (sh, sw)
|
120 |
+
else:
|
121 |
+
# Simply specify an output size instead of a scale factor
|
122 |
+
kwargs["size"] = (height0, width0)
|
123 |
+
patch_pos_embed = nn.functional.interpolate(
|
124 |
+
patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
|
125 |
+
mode="bicubic",
|
126 |
+
antialias=self.interpolate_antialias,
|
127 |
+
**kwargs,
|
128 |
+
)
|
129 |
+
assert (height0, width0) == patch_pos_embed.shape[-2:]
|
130 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
131 |
+
|
132 |
+
return patch_pos_embed.to(previous_dtype)
|
133 |
+
|
134 |
+
def forward(self, encoder_input: Union[ViTEncoderInput, ViTEncoderNonImageInput]) -> ViTEncoderOutput:
|
135 |
+
"""
|
136 |
+
Patch Embedder Forward Pass
|
137 |
+
|
138 |
+
Args:
|
139 |
+
encoder_input (Union[ViTEncoderInput, ViTEncoderNonImageInput]): Input data for the encoder.
|
140 |
+
If input type is ViTEncoderInput, input data must contain image normalization type and normalized image tensor.
|
141 |
+
If input type is ViTEncoderNonImageInput, input data must contain a tensor of size (B, C, H, W).
|
142 |
+
|
143 |
+
Returns:
|
144 |
+
ViTEncoderOutput: Output data from the encoder.
|
145 |
+
"""
|
146 |
+
# Get the input data and verify normalization if the input type is ViTEncoderInput
|
147 |
+
if isinstance(encoder_input, ViTEncoderInput):
|
148 |
+
self._check_data_normalization_type(encoder_input.data_norm_type)
|
149 |
+
input_data = encoder_input.image
|
150 |
+
elif isinstance(encoder_input, ViTEncoderNonImageInput):
|
151 |
+
input_data = encoder_input.data
|
152 |
+
else:
|
153 |
+
raise ValueError("Unsupported input type for Patch Embedder.")
|
154 |
+
|
155 |
+
# Check the dtype and shape of the input
|
156 |
+
assert isinstance(input_data, torch.Tensor), "Input must be a torch.Tensor"
|
157 |
+
assert input_data.ndim == 4, "Input must be of shape (B, C, H, W)"
|
158 |
+
batch_size, channels, height, width = input_data.shape
|
159 |
+
assert (
|
160 |
+
height % self.patch_size == 0 and width % self.patch_size == 0
|
161 |
+
), f"Input shape must be divisible by patch size: {self.patch_size}"
|
162 |
+
|
163 |
+
# Patchify the input data and project into expected latent space
|
164 |
+
features = self.proj(input_data) # (B, C, H, W) -> (B, E, H / Patch_Size, W / Patch_Size)
|
165 |
+
features = features.flatten(2).transpose(
|
166 |
+
1, 2
|
167 |
+
) # (B, E, H / Patch_Size, W / Patch_Size) -> (B, H / Patch_Size * W / Patch_Size, E)
|
168 |
+
features = self.norm(features) # Normalize the features after patch embedding
|
169 |
+
features = features + self.interpolate_pos_encoding(
|
170 |
+
features, height, width
|
171 |
+
) # (B, H / Patch_Size * W / Patch_Size, E)
|
172 |
+
features = self.post_pe_norm(features) # Normalize the features after positional encoding
|
173 |
+
|
174 |
+
# Resize the features to the expected shape
|
175 |
+
# (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size)
|
176 |
+
features = features.permute(0, 2, 1)
|
177 |
+
features = features.reshape(
|
178 |
+
-1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size
|
179 |
+
).contiguous()
|
180 |
+
|
181 |
+
return ViTEncoderOutput(features=features)
|
182 |
+
|
183 |
+
|
184 |
+
if __name__ == "__main__":
|
185 |
+
# Init Patch Embedder for images as input
|
186 |
+
patch_embedder = PatchEmbedder(
|
187 |
+
name="patch_embedder",
|
188 |
+
data_norm_type="patch_embedder",
|
189 |
+
input_size=518,
|
190 |
+
patch_size=14,
|
191 |
+
in_chans=3,
|
192 |
+
enc_embed_dim=1024,
|
193 |
+
)
|
194 |
+
|
195 |
+
# Test dummy image input
|
196 |
+
dummy_image = torch.randn(1, 3, 518, 518)
|
197 |
+
patch_embedder_output = patch_embedder(ViTEncoderInput(data_norm_type="patch_embedder", image=dummy_image))
|
198 |
+
assert patch_embedder_output.features.shape == (
|
199 |
+
1,
|
200 |
+
1024,
|
201 |
+
37,
|
202 |
+
37,
|
203 |
+
), "Output features must have shape (1, 1024, 37, 37)"
|
204 |
+
|
205 |
+
# Init Patch Embedder for non-image data as input
|
206 |
+
patch_embedder = PatchEmbedder(
|
207 |
+
name="patch_embedder",
|
208 |
+
data_norm_type="patch_embedder",
|
209 |
+
input_size=518,
|
210 |
+
patch_size=14,
|
211 |
+
in_chans=6,
|
212 |
+
enc_embed_dim=1024,
|
213 |
+
)
|
214 |
+
|
215 |
+
# Init Patch Embedder for single channel input
|
216 |
+
patch_embedder = PatchEmbedder(
|
217 |
+
name="patch_embedder",
|
218 |
+
data_norm_type="patch_embedder",
|
219 |
+
input_size=518,
|
220 |
+
patch_size=14,
|
221 |
+
in_chans=1,
|
222 |
+
enc_embed_dim=1024,
|
223 |
+
)
|
224 |
+
|
225 |
+
# Test dummy non-image input
|
226 |
+
dummy_image = torch.randn(1, 1, 518, 518)
|
227 |
+
patch_embedder_output = patch_embedder(ViTEncoderNonImageInput(data=dummy_image))
|
228 |
+
assert patch_embedder_output.features.shape == (
|
229 |
+
1,
|
230 |
+
1024,
|
231 |
+
37,
|
232 |
+
37,
|
233 |
+
), "Output features must have shape (1, 1024, 37, 37)"
|
234 |
+
|
235 |
+
print("All variants of Patch Embedder have been initialized successfully!")
|
UniCeption/uniception/models/encoders/radio.py
ADDED
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Encoder Class for RADIO (Nvidia)
|
3 |
+
"""
|
4 |
+
|
5 |
+
from typing import List, Optional, Tuple, Union
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from uniception.models.encoders.base import UniCeptionViTEncoderBase, ViTEncoderInput, ViTEncoderOutput
|
10 |
+
from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner
|
11 |
+
|
12 |
+
|
13 |
+
class RADIOEncoder(UniCeptionViTEncoderBase):
|
14 |
+
"UniCeption RADIO Encoder"
|
15 |
+
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
name: str,
|
19 |
+
data_norm_type: str = "radio",
|
20 |
+
patch_size: int = 16,
|
21 |
+
model_version: str = "radio_v2.5-l",
|
22 |
+
pretrained_checkpoint_path: str = None,
|
23 |
+
eradio_input_shape: Optional[tuple] = None,
|
24 |
+
torch_hub_force_reload: bool = False,
|
25 |
+
keep_first_n_layers: Optional[int] = None,
|
26 |
+
*args,
|
27 |
+
**kwargs,
|
28 |
+
):
|
29 |
+
"""
|
30 |
+
RADIO Encoder for extracting spatial features from images.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
name (str): Name of the encoder.
|
34 |
+
data_norm_type (str): Image normalization type. Default: "radio"
|
35 |
+
patch_size (int): Patch size for the encoder. Default: 16
|
36 |
+
model_version (str): Version of the RADIO model to load. Default: "radio_v2.5-l"
|
37 |
+
pretrained_checkpoint_path (str): Path to the pretrained checkpoint if using custom trained version of RADIO. Default: None
|
38 |
+
eradio_input_shape (tuple): Input shape (height, width) for E-RADIO models. Default: None
|
39 |
+
torch_hub_force_reload (bool): Whether to force reload the model from torch hub. Default: False
|
40 |
+
keep_first_n_layers (Optional[int]): Number of layers to keep from the pretrained model. Default: None
|
41 |
+
"""
|
42 |
+
# Init the base class
|
43 |
+
super().__init__(
|
44 |
+
name=name,
|
45 |
+
data_norm_type=data_norm_type,
|
46 |
+
patch_size=patch_size,
|
47 |
+
*args,
|
48 |
+
**kwargs,
|
49 |
+
)
|
50 |
+
|
51 |
+
# Init the RADIO Encoder specific attributes
|
52 |
+
self.model_version = model_version
|
53 |
+
self.enc_embed_dim = {
|
54 |
+
"radio_v2.5-b": 768,
|
55 |
+
"radio_v2.5-l": 1024,
|
56 |
+
"radio_v2.5-h": 1280,
|
57 |
+
"radio_v2.5-g": 1536,
|
58 |
+
"e-radio_v2": 1536,
|
59 |
+
}[self.model_version]
|
60 |
+
|
61 |
+
if self.model_version == "radio_v2.5-g":
|
62 |
+
assert patch_size == 14, "Patch size must be 14 for RADIO v2.5-g"
|
63 |
+
else:
|
64 |
+
assert patch_size == 16, "Patch size must be 16 for all other versions of RADIO"
|
65 |
+
|
66 |
+
# Load the pretrained RADIO model from torch hub
|
67 |
+
print(f"Loading pretrained {self.model_version} from torch hub")
|
68 |
+
try: # Requires internet access
|
69 |
+
self.model = torch.hub.load(
|
70 |
+
"NVlabs/RADIO",
|
71 |
+
"radio_model",
|
72 |
+
version=self.model_version,
|
73 |
+
progress=True,
|
74 |
+
skip_validation=True,
|
75 |
+
force_reload=torch_hub_force_reload,
|
76 |
+
)
|
77 |
+
except: # Load from cache
|
78 |
+
self.model = torch.hub.load(
|
79 |
+
"NVlabs/RADIO",
|
80 |
+
"radio_model",
|
81 |
+
version=self.model_version,
|
82 |
+
progress=True,
|
83 |
+
skip_validation=True,
|
84 |
+
)
|
85 |
+
|
86 |
+
# Delete the excess blocks if keep_first_n_layers is specified
|
87 |
+
if keep_first_n_layers is not None:
|
88 |
+
assert keep_first_n_layers < len(
|
89 |
+
self.model.model.blocks
|
90 |
+
), "keep_first_n_layers must be less than the number of blocks"
|
91 |
+
print(f"Keeping only the first {keep_first_n_layers} layers of the model")
|
92 |
+
self.model.model.blocks = torch.nn.ModuleList(self.model.model.blocks[:keep_first_n_layers])
|
93 |
+
|
94 |
+
# Set the optimal window size for E-RADIO models
|
95 |
+
if "e-radio" in self.model_version:
|
96 |
+
assert eradio_input_shape is not None, "Input shape (height, width) must be provided for E-RADIO models"
|
97 |
+
self.model.model.set_optimal_window_size(eradio_input_shape)
|
98 |
+
|
99 |
+
# Load the custom pretrained checkpoint if provided
|
100 |
+
if pretrained_checkpoint_path is not None:
|
101 |
+
print(f"Loading custom pretrained RADIO checkpoint from {pretrained_checkpoint_path}")
|
102 |
+
ckpt = torch.load(pretrained_checkpoint_path, weights_only=False)
|
103 |
+
print(self.load_state_dict(ckpt["model"]))
|
104 |
+
|
105 |
+
def forward(self, encoder_input: ViTEncoderInput) -> ViTEncoderOutput:
|
106 |
+
"""
|
107 |
+
RADIO Encoder Forward Pass
|
108 |
+
|
109 |
+
Args:
|
110 |
+
encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor.
|
111 |
+
|
112 |
+
Returns:
|
113 |
+
ViTEncoderOutput: Output data from the encoder.
|
114 |
+
"""
|
115 |
+
# Check image normalization type
|
116 |
+
self._check_data_normalization_type(encoder_input.data_norm_type)
|
117 |
+
|
118 |
+
# Check the dtype and shape of the input image
|
119 |
+
assert isinstance(encoder_input.image, torch.Tensor), "Input must be a torch.Tensor"
|
120 |
+
assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)"
|
121 |
+
batch_size, channels, height, width = encoder_input.image.shape
|
122 |
+
assert channels == 3, "Input must have 3 channels"
|
123 |
+
assert (
|
124 |
+
height % self.patch_size == 0 and width % self.patch_size == 0
|
125 |
+
), f"Input shape must be divisible by patch size: {self.patch_size}"
|
126 |
+
|
127 |
+
# Forward pass throught the RADIO encoder
|
128 |
+
summary, features = self.model(encoder_input.image)
|
129 |
+
|
130 |
+
# Resize the features to the expected shape
|
131 |
+
# (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size)
|
132 |
+
features = features.permute(0, 2, 1)
|
133 |
+
features = features.reshape(
|
134 |
+
-1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size
|
135 |
+
).contiguous()
|
136 |
+
|
137 |
+
return ViTEncoderOutput(features=features)
|
138 |
+
|
139 |
+
|
140 |
+
class RADIOIntermediateFeatureReturner(RADIOEncoder, IntermediateFeatureReturner):
|
141 |
+
"Intermediate Feature Returner for UniCeption RADIO Encoder"
|
142 |
+
|
143 |
+
def __init__(
|
144 |
+
self,
|
145 |
+
name: str,
|
146 |
+
data_norm_type: str = "radio",
|
147 |
+
patch_size: int = 16,
|
148 |
+
model_version: str = "radio_v2.5-l",
|
149 |
+
pretrained_checkpoint_path: str = None,
|
150 |
+
eradio_input_shape: Optional[tuple] = None,
|
151 |
+
indices: Union[int, List[int]] = [-1],
|
152 |
+
norm_intermediate: bool = True,
|
153 |
+
stop_early: bool = False,
|
154 |
+
intermediates_only: bool = True,
|
155 |
+
feature_adaptor: Optional[str] = None,
|
156 |
+
keep_first_n_layers: Optional[int] = None,
|
157 |
+
*args,
|
158 |
+
**kwargs,
|
159 |
+
):
|
160 |
+
"""
|
161 |
+
Intermediate Feature Returner for the RADIO Encoder.
|
162 |
+
|
163 |
+
Args:
|
164 |
+
name (str): Name of the encoder.
|
165 |
+
data_norm_type (str): Image normalization type. Default: "radio"
|
166 |
+
patch_size (int): Patch size for the encoder. Default: 16
|
167 |
+
model_version (str): Version of the RADIO model to load. Default: "radio_v2.5-l"
|
168 |
+
pretrained_checkpoint_path (str): Path to the pretrained checkpoint if using custom trained version of RADIO.
|
169 |
+
eradio_input_shape (tuple): Input shape (height, width) for E-RADIO models. Default: None
|
170 |
+
indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. Defaults to [-1]. Options:
|
171 |
+
- int: Return the last n layers.
|
172 |
+
- List[int]: Return the intermediate layers at the specified indices.
|
173 |
+
norm_intermediate (bool, optional): Whether to normalize the intermediate features. Defaults to True.
|
174 |
+
stop_early (bool, optional): Whether to stop early. Defaults to False.
|
175 |
+
intermediates_only (bool, optional): Whether to return only the intermediate features. Defaults to True.
|
176 |
+
feature_adaptor (Optional[str], optional): Feature adaptor to use. Defaults to None. Currently supported: "dino_v2".
|
177 |
+
keep_first_n_layers (Optional[int], optional): Number of layers to keep from the pretrained model. Defaults to None.
|
178 |
+
"""
|
179 |
+
# Init the base classes
|
180 |
+
RADIOEncoder.__init__(
|
181 |
+
self,
|
182 |
+
name=name,
|
183 |
+
data_norm_type=data_norm_type,
|
184 |
+
patch_size=patch_size,
|
185 |
+
model_version=model_version,
|
186 |
+
pretrained_checkpoint_path=pretrained_checkpoint_path,
|
187 |
+
eradio_input_shape=eradio_input_shape,
|
188 |
+
keep_first_n_layers=keep_first_n_layers,
|
189 |
+
*args,
|
190 |
+
**kwargs,
|
191 |
+
)
|
192 |
+
IntermediateFeatureReturner.__init__(
|
193 |
+
self,
|
194 |
+
indices=indices,
|
195 |
+
norm_intermediate=norm_intermediate,
|
196 |
+
stop_early=stop_early,
|
197 |
+
intermediates_only=intermediates_only,
|
198 |
+
)
|
199 |
+
|
200 |
+
# Convert indices to absolute indices if indices is None
|
201 |
+
if self.indices is None:
|
202 |
+
self.indices = list(range(len(self.model.model.blocks)))
|
203 |
+
|
204 |
+
self.feature_adaptor = feature_adaptor
|
205 |
+
if self.feature_adaptor is None:
|
206 |
+
pass
|
207 |
+
elif self.feature_adaptor == "dino_v2":
|
208 |
+
# Initialize a dummy radio encoder with the adaptor setting
|
209 |
+
dummy_model = torch.hub.load(
|
210 |
+
"NVlabs/RADIO",
|
211 |
+
"radio_model",
|
212 |
+
version=self.model_version,
|
213 |
+
progress=True,
|
214 |
+
skip_validation=True,
|
215 |
+
adaptor_names="dino_v2",
|
216 |
+
)
|
217 |
+
|
218 |
+
# Extract its feature converter weights
|
219 |
+
self.spatial_feature_converter = dummy_model.adaptors["dino_v2"].feat_mlp
|
220 |
+
|
221 |
+
# Update the embedding dimension because the features have been projected
|
222 |
+
self.enc_embed_dim = self.spatial_feature_converter.final[-1].out_features
|
223 |
+
|
224 |
+
del dummy_model
|
225 |
+
else:
|
226 |
+
raise ValueError("Unsupported feature adaptor. Supported: dino_v2")
|
227 |
+
|
228 |
+
def forward(
|
229 |
+
self, encoder_input: ViTEncoderInput
|
230 |
+
) -> Union[List[ViTEncoderOutput], Tuple[ViTEncoderOutput, List[ViTEncoderOutput]]]:
|
231 |
+
"""
|
232 |
+
RADIO Encoder Forward Pass with Intermediate Feature Return
|
233 |
+
|
234 |
+
Args:
|
235 |
+
encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor.
|
236 |
+
|
237 |
+
Returns:
|
238 |
+
Union[List[ViTEncoderOutput], Tuple[ViTEncoderOutput, List[ViTEncoderOutput]]]: Output data from the encoder.
|
239 |
+
If `intermediates_only` is True, returns a list of intermediate features.
|
240 |
+
Otherwise, returns a tuple with the final features and a list of intermediate features.
|
241 |
+
"""
|
242 |
+
# Check image normalization type
|
243 |
+
self._check_data_normalization_type(encoder_input.data_norm_type)
|
244 |
+
|
245 |
+
# Check the dtype and shape of the input image
|
246 |
+
assert isinstance(encoder_input.image, torch.Tensor), "Input must be a torch.Tensor"
|
247 |
+
assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)"
|
248 |
+
batch_size, channels, height, width = encoder_input.image.shape
|
249 |
+
assert channels == 3, "Input must have 3 channels"
|
250 |
+
assert (
|
251 |
+
height % self.patch_size == 0 and width % self.patch_size == 0
|
252 |
+
), f"Input shape must be divisible by patch size: {self.patch_size}"
|
253 |
+
|
254 |
+
# Extract the final features and intermediate features accordingly
|
255 |
+
model_outputs = self.model.forward_intermediates(
|
256 |
+
encoder_input.image,
|
257 |
+
indices=self.indices,
|
258 |
+
return_prefix_tokens=False,
|
259 |
+
norm=self.norm_intermediate,
|
260 |
+
stop_early=self.stop_early,
|
261 |
+
output_fmt="NLC",
|
262 |
+
intermediates_only=self.intermediates_only,
|
263 |
+
)
|
264 |
+
|
265 |
+
# Extract the final features and intermediate features accordingly
|
266 |
+
final_features, intermediate_features = None, None
|
267 |
+
if self.intermediates_only:
|
268 |
+
intermediate_features = model_outputs
|
269 |
+
else:
|
270 |
+
final_features = model_outputs[0].features.contiguous()
|
271 |
+
intermediate_features = model_outputs[1]
|
272 |
+
|
273 |
+
# Optionally convert the features using the feature adaptor
|
274 |
+
Hp, Wp = height // self.patch_size, width // self.patch_size
|
275 |
+
|
276 |
+
# Convert final features
|
277 |
+
if final_features is not None:
|
278 |
+
if self.feature_adaptor is not None:
|
279 |
+
final_features = self.spatial_feature_converter(final_features)
|
280 |
+
|
281 |
+
# Convert to BCHW and package
|
282 |
+
final_features = final_features.view(batch_size, Hp, Wp, -1).permute(0, 3, 1, 2)
|
283 |
+
final_features = ViTEncoderOutput(features=final_features)
|
284 |
+
|
285 |
+
# Convert intermediate features
|
286 |
+
if intermediate_features is not None:
|
287 |
+
num_intermediate = len(intermediate_features)
|
288 |
+
all_intermediate_feats_tensor = torch.cat(intermediate_features, dim=0)
|
289 |
+
if self.feature_adaptor is not None:
|
290 |
+
all_intermediate_feats_tensor = self.spatial_feature_converter(all_intermediate_feats_tensor)
|
291 |
+
# Convert to BCHW
|
292 |
+
all_intermediate_feats_tensor = all_intermediate_feats_tensor.view(
|
293 |
+
num_intermediate * batch_size, Hp, Wp, -1
|
294 |
+
).permute(0, 3, 1, 2)
|
295 |
+
all_intermediate_feats = torch.chunk(all_intermediate_feats_tensor, num_intermediate, dim=0)
|
296 |
+
intermediate_features = [ViTEncoderOutput(features=x) for x in all_intermediate_feats]
|
297 |
+
|
298 |
+
# Return the final features and intermediate features accordingly
|
299 |
+
if self.intermediates_only:
|
300 |
+
return intermediate_features
|
301 |
+
else:
|
302 |
+
return final_features, intermediate_features
|
303 |
+
|
304 |
+
|
305 |
+
if __name__ == "__main__":
|
306 |
+
# Init different versions of the RADIO Encoder
|
307 |
+
for model_version in ["radio_v2.5-b", "radio_v2.5-l"]:
|
308 |
+
radio_encoder = RADIOEncoder(name="RADIOv2.5", model_version=model_version)
|
309 |
+
|
310 |
+
# Init the E-RADIO Encoder
|
311 |
+
eradio_input_shape = (512, 512)
|
312 |
+
eradio_encoder = RADIOEncoder(name="E-RADIO", model_version="e-radio_v2", eradio_input_shape=eradio_input_shape)
|
313 |
+
|
314 |
+
print("All RADIO Encoders have been initialized successfully!")
|
315 |
+
|
316 |
+
# Intermediate Feature Returner Tests
|
317 |
+
print("Running Intermediate Feature Returner Tests...")
|
318 |
+
|
319 |
+
# Run the intermediate feature returner with last-n index
|
320 |
+
radio_intermediate_feature_returner = RADIOIntermediateFeatureReturner(
|
321 |
+
name="RADIOv2.5", model_version="radio_v2.5-b", indices=6
|
322 |
+
) # Last 6 layers
|
323 |
+
dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="radio")
|
324 |
+
output = radio_intermediate_feature_returner(dummy_input)
|
325 |
+
assert isinstance(output, list), "Output must be a list of intermediate features"
|
326 |
+
assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
|
327 |
+
assert len(output) == 6, "Output must have length of intermediate features equal to the number of indices"
|
328 |
+
|
329 |
+
# Run the intermediate feature returner with specific indices
|
330 |
+
radio_intermediate_feature_returner = RADIOIntermediateFeatureReturner(
|
331 |
+
name="RADIOv2.5", model_version="radio_v2.5-b", indices=[0, 2, 4, 6]
|
332 |
+
) # Specific layers
|
333 |
+
dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="radio")
|
334 |
+
output = radio_intermediate_feature_returner(dummy_input)
|
335 |
+
assert isinstance(output, list), "Output must be a list of intermediate features"
|
336 |
+
assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
|
337 |
+
assert len(output) == 4, "Output must have length of intermediate features equal to the number of indices"
|
338 |
+
|
339 |
+
# Test the normalizing of intermediate features
|
340 |
+
radio_intermediate_feature_returner = RADIOIntermediateFeatureReturner(
|
341 |
+
name="RADIOv2.5", model_version="radio_v2.5-b", norm_intermediate=False, intermediates_only=False
|
342 |
+
) # Do not normalize
|
343 |
+
dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="radio")
|
344 |
+
output = radio_intermediate_feature_returner(dummy_input)
|
345 |
+
assert isinstance(output, tuple), "Output must be a tuple with final features and intermediate features"
|
346 |
+
assert isinstance(output[0], ViTEncoderOutput), "First element of output must be the final features"
|
347 |
+
assert isinstance(output[1], list), "Second element of output must be a list of intermediate features"
|
348 |
+
assert isinstance(output[1][0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
|
349 |
+
if not isinstance(radio_intermediate_feature_returner.model.model.norm, torch.nn.Identity):
|
350 |
+
assert not torch.equal(
|
351 |
+
output[0].features, output[1][0].features
|
352 |
+
), "Final features and intermediate features must be different"
|
353 |
+
|
354 |
+
radio_intermediate_feature_returner = RADIOIntermediateFeatureReturner(
|
355 |
+
name="RADIOv2.5", model_version="radio_v2.5-b", norm_intermediate=True, intermediates_only=False
|
356 |
+
)
|
357 |
+
dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="radio")
|
358 |
+
output = radio_intermediate_feature_returner(dummy_input)
|
359 |
+
assert isinstance(output, tuple), "Output must be a tuple with final features and intermediate features"
|
360 |
+
assert isinstance(output[0], ViTEncoderOutput), "First element of output must be the final features"
|
361 |
+
assert isinstance(output[1], list), "Second element of output must be a list of intermediate features"
|
362 |
+
assert isinstance(output[1][0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
|
363 |
+
assert torch.equal(
|
364 |
+
output[0].features, output[1][0].features
|
365 |
+
), "Final features and intermediate features must be same"
|
366 |
+
|
367 |
+
print("All Intermediate Feature Returner Tests have passed successfully!")
|
UniCeption/uniception/models/encoders/utils.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Utility functions for UniCeption Encoders.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import functools
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
|
10 |
+
|
11 |
+
def profile_encoder(num_warmup=3, num_runs=20, autocast_precision="float16", use_compile=False, dynamic=True):
|
12 |
+
def decorator(func):
|
13 |
+
@functools.wraps(func)
|
14 |
+
def wrapper(self, *args, **kwargs):
|
15 |
+
device = "cuda"
|
16 |
+
autocast_dtype = getattr(torch, autocast_precision)
|
17 |
+
|
18 |
+
# Compile the model if requested
|
19 |
+
if use_compile:
|
20 |
+
compiled_func = torch.compile(func, dynamic=dynamic, mode="max-autotune")
|
21 |
+
else:
|
22 |
+
compiled_func = func
|
23 |
+
|
24 |
+
with torch.autocast("cuda", dtype=autocast_dtype):
|
25 |
+
# Warm-up runs
|
26 |
+
for _ in range(num_warmup):
|
27 |
+
output = compiled_func(self, *args, **kwargs)
|
28 |
+
if isinstance(output, torch.Tensor):
|
29 |
+
output.sum().backward()
|
30 |
+
else:
|
31 |
+
output.features.sum().backward()
|
32 |
+
torch.cuda.synchronize()
|
33 |
+
|
34 |
+
# Clear memory cache
|
35 |
+
torch.cuda.empty_cache()
|
36 |
+
|
37 |
+
# Lists to store results
|
38 |
+
forward_times, backward_times, memory_usages = [], [], []
|
39 |
+
|
40 |
+
for _ in range(num_runs):
|
41 |
+
start_event = torch.cuda.Event(enable_timing=True)
|
42 |
+
end_event = torch.cuda.Event(enable_timing=True)
|
43 |
+
|
44 |
+
torch.cuda.reset_peak_memory_stats()
|
45 |
+
memory_before = torch.cuda.max_memory_allocated(device)
|
46 |
+
|
47 |
+
# Forward pass
|
48 |
+
start_event.record()
|
49 |
+
output = compiled_func(self, *args, **kwargs)
|
50 |
+
end_event.record()
|
51 |
+
torch.cuda.synchronize()
|
52 |
+
forward_times.append(start_event.elapsed_time(end_event))
|
53 |
+
|
54 |
+
# Backward pass
|
55 |
+
start_event.record()
|
56 |
+
if isinstance(output, torch.Tensor):
|
57 |
+
output.sum().backward()
|
58 |
+
else:
|
59 |
+
output.features.sum().backward()
|
60 |
+
end_event.record()
|
61 |
+
torch.cuda.synchronize()
|
62 |
+
backward_times.append(start_event.elapsed_time(end_event))
|
63 |
+
|
64 |
+
memory_after = torch.cuda.max_memory_allocated(device)
|
65 |
+
memory_usages.append((memory_after - memory_before) / 1e6) # Convert to MB
|
66 |
+
|
67 |
+
# Compute mean and standard deviation
|
68 |
+
fwd_mean, fwd_std = np.mean(forward_times), np.std(forward_times)
|
69 |
+
bwd_mean, bwd_std = np.mean(backward_times), np.std(backward_times)
|
70 |
+
mem_mean, mem_std = np.mean(memory_usages), np.std(memory_usages)
|
71 |
+
|
72 |
+
compile_status = (
|
73 |
+
"with torch.compile (dynamic=True)"
|
74 |
+
if use_compile and dynamic
|
75 |
+
else "with torch.compile (dynamic=False)" if use_compile else "without torch.compile"
|
76 |
+
)
|
77 |
+
print(f"Profiling results {compile_status}:")
|
78 |
+
print(f"Forward Pass Time: {fwd_mean:.2f} ± {fwd_std:.2f} ms")
|
79 |
+
print(f"Backward Pass Time: {bwd_mean:.2f} ± {bwd_std:.2f} ms")
|
80 |
+
print(f"Peak GPU Memory Usage: {mem_mean:.2f} ± {mem_std:.2f} MB")
|
81 |
+
|
82 |
+
return output
|
83 |
+
|
84 |
+
return wrapper
|
85 |
+
|
86 |
+
return decorator
|
UniCeption/uniception/models/factory/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from uniception.models.factory.dust3r import DUSt3R
|
2 |
+
|
3 |
+
__all__ = ["DUSt3R"]
|
UniCeption/uniception/models/factory/dust3r.py
ADDED
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from uniception.models.encoders import ViTEncoderInput
|
7 |
+
from uniception.models.encoders.croco import CroCoEncoder
|
8 |
+
from uniception.models.encoders.image_normalizations import IMAGE_NORMALIZATION_DICT
|
9 |
+
from uniception.models.info_sharing.base import MultiViewTransformerInput
|
10 |
+
from uniception.models.info_sharing.cross_attention_transformer import (
|
11 |
+
MultiViewCrossAttentionTransformer,
|
12 |
+
MultiViewCrossAttentionTransformerIFR,
|
13 |
+
)
|
14 |
+
from uniception.models.libs.croco.pos_embed import RoPE2D, get_2d_sincos_pos_embed
|
15 |
+
from uniception.models.prediction_heads.adaptors import PointMapWithConfidenceAdaptor
|
16 |
+
from uniception.models.prediction_heads.base import AdaptorInput, PredictionHeadInput, PredictionHeadLayeredInput
|
17 |
+
from uniception.models.prediction_heads.dpt import DPTFeature, DPTRegressionProcessor
|
18 |
+
from uniception.models.prediction_heads.linear import LinearFeature
|
19 |
+
|
20 |
+
|
21 |
+
def is_symmetrized(gt1, gt2):
|
22 |
+
"Function to check if input pairs are symmetrized, i.e., (a, b) and (b, a) always exist in the input"
|
23 |
+
x = gt1["instance"]
|
24 |
+
y = gt2["instance"]
|
25 |
+
if len(x) == len(y) and len(x) == 1:
|
26 |
+
return False # special case of batchsize 1
|
27 |
+
ok = True
|
28 |
+
for i in range(0, len(x), 2):
|
29 |
+
ok = ok and (x[i] == y[i + 1]) and (x[i + 1] == y[i])
|
30 |
+
return ok
|
31 |
+
|
32 |
+
|
33 |
+
def interleave(tensor1, tensor2):
|
34 |
+
"Interleave two tensors along the first dimension (used to avoid redundant encoding for symmetrized pairs)"
|
35 |
+
res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1)
|
36 |
+
res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1)
|
37 |
+
return res1, res2
|
38 |
+
|
39 |
+
|
40 |
+
class DUSt3R(nn.Module):
|
41 |
+
"DUSt3R defined with UniCeption Modules"
|
42 |
+
|
43 |
+
def __init__(
|
44 |
+
self,
|
45 |
+
name: str,
|
46 |
+
data_norm_type: str = "dust3r",
|
47 |
+
img_size: tuple = (224, 224),
|
48 |
+
patch_embed_cls: str = "PatchEmbedDust3R",
|
49 |
+
pred_head_type: str = "linear",
|
50 |
+
pred_head_output_dim: int = 4,
|
51 |
+
pred_head_feature_dim: int = 256,
|
52 |
+
depth_mode: Tuple[str, float, float] = ("exp", -float("inf"), float("inf")),
|
53 |
+
conf_mode: Tuple[str, float, float] = ("exp", 1, float("inf")),
|
54 |
+
pos_embed: str = "RoPE100",
|
55 |
+
pretrained_checkpoint_path: str = None,
|
56 |
+
pretrained_encoder_checkpoint_path: str = None,
|
57 |
+
pretrained_info_sharing_checkpoint_path: str = None,
|
58 |
+
pretrained_pred_head_checkpoint_paths: List[str] = [None, None],
|
59 |
+
pretrained_pred_head_regressor_checkpoint_paths: List[str] = [None, None],
|
60 |
+
override_encoder_checkpoint_attributes: bool = False,
|
61 |
+
*args,
|
62 |
+
**kwargs,
|
63 |
+
):
|
64 |
+
"""
|
65 |
+
Two-view model containing siamese encoders followed by a two-view cross-attention transformer and respective downstream heads.
|
66 |
+
The goal is to output scene representation directly, both images in view1's frame (hence the asymmetry).
|
67 |
+
|
68 |
+
Args:
|
69 |
+
name (str): Name of the model.
|
70 |
+
data_norm_type (str): Type of data normalization. (default: "dust3r")
|
71 |
+
img_size (tuple): Size of input images. (default: (224, 224))
|
72 |
+
patch_embed_cls (str): Class for patch embedding. (default: "PatchEmbedDust3R"). Options:
|
73 |
+
- "PatchEmbedDust3R"
|
74 |
+
- "ManyAR_PatchEmbed"
|
75 |
+
pred_head_type (str): Type of prediction head. (default: "linear"). Options:
|
76 |
+
- "linear"
|
77 |
+
- "dpt"
|
78 |
+
pred_head_output_dim (int): Output dimension of prediction head. (default: 4)
|
79 |
+
pred_head_feature_dim (int): Feature dimension of prediction head. (default: 256)
|
80 |
+
depth_mode (Tuple[str, float, float]): Depth mode settings (mode=['linear', 'square', 'exp'], vmin, vmax). (default: ('exp', -inf, inf))
|
81 |
+
conf_mode (Tuple[str, float, float]): Confidence mode settings (mode=['linear', 'square', 'exp'], vmin, vmax). (default: ('exp', 1, inf))
|
82 |
+
pos_embed (str): Position embedding type. (default: 'RoPE100')
|
83 |
+
landscape_only (bool): Run downstream head only in landscape orientation. (default: True)
|
84 |
+
pretrained_checkpoint_path (str): Path to pretrained checkpoint. (default: None)
|
85 |
+
pretrained_encoder_checkpoint_path (str): Path to pretrained encoder checkpoint. (default: None)
|
86 |
+
pretrained_info_sharing_checkpoint_path (str): Path to pretrained info_sharing checkpoint. (default: None)
|
87 |
+
pretrained_pred_head_checkpoint_paths (List[str]): Paths to pretrained prediction head checkpoints. (default: None)
|
88 |
+
pretrained_pred_head_regressor_checkpoint_paths (List[str]): Paths to pretrained prediction head regressor checkpoints. (default: None)
|
89 |
+
override_encoder_checkpoint_attributes (bool): Whether to override encoder checkpoint attributes. (default: False)
|
90 |
+
"""
|
91 |
+
super().__init__(*args, **kwargs)
|
92 |
+
|
93 |
+
# Initalize the attributes
|
94 |
+
self.name = name
|
95 |
+
self.data_norm_type = data_norm_type
|
96 |
+
self.img_size = img_size
|
97 |
+
self.patch_embed_cls = patch_embed_cls
|
98 |
+
self.pred_head_type = pred_head_type
|
99 |
+
self.pred_head_output_dim = pred_head_output_dim
|
100 |
+
self.depth_mode = depth_mode
|
101 |
+
self.conf_mode = conf_mode
|
102 |
+
self.pos_embed = pos_embed
|
103 |
+
self.pretrained_checkpoint_path = pretrained_checkpoint_path
|
104 |
+
self.pretrained_encoder_checkpoint_path = pretrained_encoder_checkpoint_path
|
105 |
+
self.pretrained_info_sharing_checkpoint_path = pretrained_info_sharing_checkpoint_path
|
106 |
+
self.pretrained_pred_head_checkpoint_paths = pretrained_pred_head_checkpoint_paths
|
107 |
+
self.pretrained_pred_head_regressor_checkpoint_paths = pretrained_pred_head_regressor_checkpoint_paths
|
108 |
+
self.override_encoder_checkpoint_attributes = override_encoder_checkpoint_attributes
|
109 |
+
|
110 |
+
# Initialize RoPE for the CroCo Encoder & Two-View Cross Attention Transformer
|
111 |
+
freq = float(pos_embed[len("RoPE") :])
|
112 |
+
self.rope = RoPE2D(freq=freq)
|
113 |
+
|
114 |
+
# Initialize Encoder
|
115 |
+
self.encoder = CroCoEncoder(
|
116 |
+
name=name,
|
117 |
+
data_norm_type=data_norm_type,
|
118 |
+
patch_embed_cls=patch_embed_cls,
|
119 |
+
img_size=img_size,
|
120 |
+
pretrained_checkpoint_path=pretrained_encoder_checkpoint_path,
|
121 |
+
override_checkpoint_attributes=override_encoder_checkpoint_attributes,
|
122 |
+
)
|
123 |
+
|
124 |
+
# Initialize Multi-View Cross Attention Transformer
|
125 |
+
if self.pred_head_type == "linear":
|
126 |
+
# Returns only normalized last layer features
|
127 |
+
self.info_sharing = MultiViewCrossAttentionTransformer(
|
128 |
+
name="base_info_sharing",
|
129 |
+
input_embed_dim=self.encoder.enc_embed_dim,
|
130 |
+
num_views=2,
|
131 |
+
custom_positional_encoding=self.rope,
|
132 |
+
pretrained_checkpoint_path=pretrained_info_sharing_checkpoint_path,
|
133 |
+
)
|
134 |
+
elif self.pred_head_type == "dpt":
|
135 |
+
# Returns intermediate features and normalized last layer features
|
136 |
+
self.info_sharing = MultiViewCrossAttentionTransformerIFR(
|
137 |
+
name="base_info_sharing",
|
138 |
+
input_embed_dim=self.encoder.enc_embed_dim,
|
139 |
+
num_views=2,
|
140 |
+
indices=[5, 8],
|
141 |
+
norm_intermediate=False,
|
142 |
+
custom_positional_encoding=self.rope,
|
143 |
+
pretrained_checkpoint_path=pretrained_info_sharing_checkpoint_path,
|
144 |
+
)
|
145 |
+
else:
|
146 |
+
raise ValueError(f"Invalid prediction head type: {pred_head_type}. Must be 'linear' or 'dpt'.")
|
147 |
+
|
148 |
+
# Initialize Prediction Heads
|
149 |
+
if pred_head_type == "linear":
|
150 |
+
# Initialize Prediction Head 1
|
151 |
+
self.head1 = LinearFeature(
|
152 |
+
input_feature_dim=self.info_sharing.dim,
|
153 |
+
output_dim=pred_head_output_dim,
|
154 |
+
patch_size=self.encoder.patch_size,
|
155 |
+
pretrained_checkpoint_path=pretrained_pred_head_checkpoint_paths[0],
|
156 |
+
)
|
157 |
+
# Initialize Prediction Head 2
|
158 |
+
self.head2 = LinearFeature(
|
159 |
+
input_feature_dim=self.info_sharing.dim,
|
160 |
+
output_dim=pred_head_output_dim,
|
161 |
+
patch_size=self.encoder.patch_size,
|
162 |
+
pretrained_checkpoint_path=pretrained_pred_head_checkpoint_paths[1],
|
163 |
+
)
|
164 |
+
elif pred_head_type == "dpt":
|
165 |
+
# Initialze Predction Head 1
|
166 |
+
self.dpt_feature_head1 = DPTFeature(
|
167 |
+
patch_size=self.encoder.patch_size,
|
168 |
+
hooks=[0, 1, 2, 3],
|
169 |
+
input_feature_dims=[self.encoder.enc_embed_dim] + [self.info_sharing.dim] * 3,
|
170 |
+
feature_dim=pred_head_feature_dim,
|
171 |
+
pretrained_checkpoint_path=pretrained_pred_head_checkpoint_paths[0],
|
172 |
+
)
|
173 |
+
self.dpt_regressor_head1 = DPTRegressionProcessor(
|
174 |
+
input_feature_dim=pred_head_feature_dim,
|
175 |
+
output_dim=pred_head_output_dim,
|
176 |
+
pretrained_checkpoint_path=pretrained_pred_head_regressor_checkpoint_paths[0],
|
177 |
+
)
|
178 |
+
self.head1 = nn.Sequential(self.dpt_feature_head1, self.dpt_regressor_head1)
|
179 |
+
# Initialize Prediction Head 2
|
180 |
+
self.dpt_feature_head2 = DPTFeature(
|
181 |
+
patch_size=self.encoder.patch_size,
|
182 |
+
hooks=[0, 1, 2, 3],
|
183 |
+
input_feature_dims=[self.encoder.enc_embed_dim] + [self.info_sharing.dim] * 3,
|
184 |
+
feature_dim=pred_head_feature_dim,
|
185 |
+
pretrained_checkpoint_path=pretrained_pred_head_checkpoint_paths[1],
|
186 |
+
)
|
187 |
+
self.dpt_regressor_head2 = DPTRegressionProcessor(
|
188 |
+
input_feature_dim=pred_head_feature_dim,
|
189 |
+
output_dim=pred_head_output_dim,
|
190 |
+
pretrained_checkpoint_path=pretrained_pred_head_regressor_checkpoint_paths[1],
|
191 |
+
)
|
192 |
+
self.head2 = nn.Sequential(self.dpt_feature_head2, self.dpt_regressor_head2)
|
193 |
+
|
194 |
+
# Initialize Final Output Adaptor
|
195 |
+
self.adaptor = PointMapWithConfidenceAdaptor(
|
196 |
+
name="pointmap",
|
197 |
+
pointmap_mode=depth_mode[0],
|
198 |
+
pointmap_vmin=depth_mode[1],
|
199 |
+
pointmap_vmax=depth_mode[2],
|
200 |
+
confidence_type=conf_mode[0],
|
201 |
+
confidence_vmin=conf_mode[1],
|
202 |
+
confidence_vmax=conf_mode[2],
|
203 |
+
)
|
204 |
+
|
205 |
+
# Load pretrained weights
|
206 |
+
if self.pretrained_checkpoint_path is not None:
|
207 |
+
print(f"Loading pretrained DUSt3R weights from {self.pretrained_checkpoint_path} ...")
|
208 |
+
ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
|
209 |
+
print(self.load_state_dict(ckpt["model"]))
|
210 |
+
|
211 |
+
def _encode_image_pairs(self, img1, img2, data_norm_type):
|
212 |
+
"Encode two different batches of images (each batch can have different image shape)"
|
213 |
+
if img1.shape[-2:] == img2.shape[-2:]:
|
214 |
+
encoder_input = ViTEncoderInput(image=torch.cat((img1, img2), dim=0), data_norm_type=data_norm_type)
|
215 |
+
encoder_output = self.encoder(encoder_input)
|
216 |
+
out, out2 = encoder_output.features.chunk(2, dim=0)
|
217 |
+
else:
|
218 |
+
encoder_input = ViTEncoderInput(image=img1, data_norm_type=data_norm_type)
|
219 |
+
out = self.encoder(encoder_input)
|
220 |
+
out = out.features
|
221 |
+
encoder_input2 = ViTEncoderInput(image=img2)
|
222 |
+
out2 = self.encoder(encoder_input2)
|
223 |
+
out2 = out2.features
|
224 |
+
|
225 |
+
return out, out2
|
226 |
+
|
227 |
+
def _encode_symmetrized(self, view1, view2):
|
228 |
+
"Encode image pairs accounting for symmetrization, i.e., (a, b) and (b, a) always exist in the input"
|
229 |
+
img1 = view1["img"]
|
230 |
+
img2 = view2["img"]
|
231 |
+
if is_symmetrized(view1, view2):
|
232 |
+
# Computing half of forward pass'
|
233 |
+
feat1, feat2 = self._encode_image_pairs(img1[::2], img2[::2], data_norm_type=view1["data_norm_type"])
|
234 |
+
feat1, feat2 = interleave(feat1, feat2)
|
235 |
+
else:
|
236 |
+
feat1, feat2 = self._encode_image_pairs(img1, img2, data_norm_type=view1["data_norm_type"])
|
237 |
+
|
238 |
+
return feat1, feat2
|
239 |
+
|
240 |
+
def _downstream_head(self, head_num, decout, img_shape):
|
241 |
+
"Run the respective prediction heads"
|
242 |
+
head = getattr(self, f"head{head_num}")
|
243 |
+
if self.pred_head_type == "linear":
|
244 |
+
head_input = PredictionHeadInput(last_feature=decout[f"{head_num}"])
|
245 |
+
elif self.pred_head_type == "dpt":
|
246 |
+
head_input = PredictionHeadLayeredInput(list_features=decout[f"{head_num}"], target_output_shape=img_shape)
|
247 |
+
|
248 |
+
return head(head_input)
|
249 |
+
|
250 |
+
def forward(self, view1, view2):
|
251 |
+
"""
|
252 |
+
Forward pass for DUSt3R performing the following operations:
|
253 |
+
1. Encodes the two input views (images).
|
254 |
+
2. Combines the encoded features using a two-view cross-attention transformer.
|
255 |
+
3. Passes the combined features through the respective prediction heads.
|
256 |
+
4. Returns the processed final outputs for both views.
|
257 |
+
|
258 |
+
Args:
|
259 |
+
view1 (dict): Dictionary containing the first view's images and instance information.
|
260 |
+
"img" is a required key and value is a tensor of shape (B, C, H, W).
|
261 |
+
view2 (dict): Dictionary containing the second view's images and instance information.
|
262 |
+
"img" is a required key and value is a tensor of shape (B, C, H, W).
|
263 |
+
|
264 |
+
Returns:
|
265 |
+
Tuple[dict, dict]: A tuple containing the final outputs for both views.
|
266 |
+
"""
|
267 |
+
# Get input shapes
|
268 |
+
_, _, height1, width1 = view1["img"].shape
|
269 |
+
_, _, height2, width2 = view2["img"].shape
|
270 |
+
shape1 = (int(height1), int(width1))
|
271 |
+
shape2 = (int(height2), int(width2))
|
272 |
+
|
273 |
+
# Encode the two images --> Each feat output: BCHW features (batch_size, feature_dim, feature_height, feature_width)
|
274 |
+
feat1, feat2 = self._encode_symmetrized(view1, view2)
|
275 |
+
|
276 |
+
# Combine all images into view-centric representation
|
277 |
+
info_sharing_input = MultiViewTransformerInput(features=[feat1, feat2])
|
278 |
+
if self.pred_head_type == "linear":
|
279 |
+
final_info_sharing_multi_view_feat = self.info_sharing(info_sharing_input)
|
280 |
+
elif self.pred_head_type == "dpt":
|
281 |
+
final_info_sharing_multi_view_feat, intermediate_info_sharing_multi_view_feat = self.info_sharing(
|
282 |
+
info_sharing_input
|
283 |
+
)
|
284 |
+
|
285 |
+
if self.pred_head_type == "linear":
|
286 |
+
# Define feature dictionary for linear head
|
287 |
+
info_sharing_outputs = {
|
288 |
+
"1": final_info_sharing_multi_view_feat.features[0].float(),
|
289 |
+
"2": final_info_sharing_multi_view_feat.features[1].float(),
|
290 |
+
}
|
291 |
+
elif self.pred_head_type == "dpt":
|
292 |
+
# Define feature dictionary for DPT head
|
293 |
+
info_sharing_outputs = {
|
294 |
+
"1": [
|
295 |
+
feat1.float(),
|
296 |
+
intermediate_info_sharing_multi_view_feat[0].features[0].float(),
|
297 |
+
intermediate_info_sharing_multi_view_feat[1].features[0].float(),
|
298 |
+
final_info_sharing_multi_view_feat.features[0].float(),
|
299 |
+
],
|
300 |
+
"2": [
|
301 |
+
feat2.float(),
|
302 |
+
intermediate_info_sharing_multi_view_feat[0].features[1].float(),
|
303 |
+
intermediate_info_sharing_multi_view_feat[1].features[1].float(),
|
304 |
+
final_info_sharing_multi_view_feat.features[1].float(),
|
305 |
+
],
|
306 |
+
}
|
307 |
+
|
308 |
+
# Downstream task prediction
|
309 |
+
with torch.autocast("cuda", enabled=False):
|
310 |
+
# Prediction heads
|
311 |
+
head_output1 = self._downstream_head(1, info_sharing_outputs, shape1)
|
312 |
+
head_output2 = self._downstream_head(2, info_sharing_outputs, shape2)
|
313 |
+
|
314 |
+
# Post-process outputs
|
315 |
+
final_output1 = self.adaptor(
|
316 |
+
AdaptorInput(adaptor_feature=head_output1.decoded_channels, output_shape_hw=shape1)
|
317 |
+
)
|
318 |
+
final_output2 = self.adaptor(
|
319 |
+
AdaptorInput(adaptor_feature=head_output2.decoded_channels, output_shape_hw=shape2)
|
320 |
+
)
|
321 |
+
|
322 |
+
# Convert outputs to dictionary
|
323 |
+
res1 = {
|
324 |
+
"pts3d": final_output1.value.permute(0, 2, 3, 1).contiguous(),
|
325 |
+
"conf": final_output1.confidence.permute(0, 2, 3, 1).contiguous(),
|
326 |
+
}
|
327 |
+
res2 = {
|
328 |
+
"pts3d_in_other_view": final_output2.value.permute(0, 2, 3, 1).contiguous(),
|
329 |
+
"conf": final_output2.confidence.permute(0, 2, 3, 1).contiguous(),
|
330 |
+
}
|
331 |
+
|
332 |
+
return res1, res2
|
UniCeption/uniception/models/info_sharing/README.md
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# UniCeption Information Sharing Blocks
|
2 |
+
|
3 |
+
## Currently Supported Information Sharing Architectures
|
4 |
+
|
5 |
+
### UniCeptionInfoSharingBase:
|
6 |
+
|
7 |
+
- `MultiViewCrossAttentionTransformer`
|
8 |
+
- `MultiViewCrossAttentionTransformerIFR`
|
9 |
+
- `MultiViewGlobalAttentionTransformer`
|
10 |
+
- `MultiViewGlobalAttentionTransformerIFR`
|
11 |
+
- `MultiViewAlternatingAttentionTransformer`
|
12 |
+
- `MultiViewAlternatingAttentionTransformerIFR`
|
13 |
+
|
14 |
+
## Developer Guidelines
|
15 |
+
|
16 |
+
Please follow the main UniCeption developer guidelines described in `README.md` when contributing to the information sharing blocks. Make sure to test your different implementations and add necessary unit tests.
|
17 |
+
|
18 |
+
## Happy Coding!
|
UniCeption/uniception/models/info_sharing/__init__.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from uniception.models.info_sharing.alternating_attention_transformer import (
|
2 |
+
MultiViewAlternatingAttentionTransformer,
|
3 |
+
MultiViewAlternatingAttentionTransformerIFR,
|
4 |
+
)
|
5 |
+
from uniception.models.info_sharing.cross_attention_transformer import (
|
6 |
+
MultiViewCrossAttentionTransformer,
|
7 |
+
MultiViewCrossAttentionTransformerIFR,
|
8 |
+
MultiViewTransformerInput,
|
9 |
+
)
|
10 |
+
from uniception.models.info_sharing.diff_cross_attention_transformer import (
|
11 |
+
DifferentialMultiViewCrossAttentionTransformer,
|
12 |
+
DifferentialMultiViewCrossAttentionTransformerIFR,
|
13 |
+
)
|
14 |
+
from uniception.models.info_sharing.global_attention_transformer import (
|
15 |
+
MultiViewGlobalAttentionTransformer,
|
16 |
+
MultiViewGlobalAttentionTransformerIFR,
|
17 |
+
)
|
18 |
+
|
19 |
+
INFO_SHARING_CLASSES = {
|
20 |
+
"cross_attention": (MultiViewCrossAttentionTransformer, MultiViewCrossAttentionTransformerIFR),
|
21 |
+
"diff_cross_attention": (
|
22 |
+
DifferentialMultiViewCrossAttentionTransformer,
|
23 |
+
DifferentialMultiViewCrossAttentionTransformerIFR,
|
24 |
+
),
|
25 |
+
"alternating_attention": (
|
26 |
+
MultiViewAlternatingAttentionTransformer,
|
27 |
+
MultiViewAlternatingAttentionTransformerIFR,
|
28 |
+
),
|
29 |
+
"global_attention": (
|
30 |
+
MultiViewGlobalAttentionTransformer,
|
31 |
+
MultiViewGlobalAttentionTransformerIFR,
|
32 |
+
),
|
33 |
+
}
|
34 |
+
|
35 |
+
__all__ = ["INFO_SHARING_CLASSES", "MultiViewTransformerInput"]
|
UniCeption/uniception/models/info_sharing/alternating_attention_transformer.py
ADDED
@@ -0,0 +1,944 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
UniCeption Alternating-Attention Transformer for Information Sharing
|
3 |
+
"""
|
4 |
+
|
5 |
+
from functools import partial
|
6 |
+
from typing import Callable, List, Optional, Tuple, Type, Union
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
|
12 |
+
from uniception.models.info_sharing.base import (
|
13 |
+
MultiViewTransformerInput,
|
14 |
+
MultiViewTransformerOutput,
|
15 |
+
UniCeptionInfoSharingBase,
|
16 |
+
)
|
17 |
+
from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner, feature_take_indices
|
18 |
+
from uniception.models.utils.positional_encoding import PositionGetter
|
19 |
+
from uniception.models.utils.transformer_blocks import Mlp, SelfAttentionBlock
|
20 |
+
|
21 |
+
|
22 |
+
class MultiViewAlternatingAttentionTransformer(UniCeptionInfoSharingBase):
|
23 |
+
"UniCeption Multi-View Alternating-Attention Transformer for information sharing across image features from different views."
|
24 |
+
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
name: str,
|
28 |
+
input_embed_dim: int,
|
29 |
+
use_pe_for_non_reference_views: bool = False,
|
30 |
+
max_num_views_for_pe: int = 1000,
|
31 |
+
use_rand_idx_pe_for_non_reference_views: bool = True,
|
32 |
+
size: Optional[str] = None,
|
33 |
+
depth: int = 12,
|
34 |
+
dim: int = 768,
|
35 |
+
num_heads: int = 12,
|
36 |
+
mlp_ratio: float = 4.0,
|
37 |
+
qkv_bias: bool = True,
|
38 |
+
qk_norm: bool = False,
|
39 |
+
proj_drop: float = 0.0,
|
40 |
+
attn_drop: float = 0.0,
|
41 |
+
init_values: Optional[float] = None,
|
42 |
+
drop_path: float = 0.0,
|
43 |
+
act_layer: Type[nn.Module] = nn.GELU,
|
44 |
+
norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial(nn.LayerNorm, eps=1e-6),
|
45 |
+
mlp_layer: Type[nn.Module] = Mlp,
|
46 |
+
custom_positional_encoding: Optional[Callable] = None,
|
47 |
+
pretrained_checkpoint_path: Optional[str] = None,
|
48 |
+
gradient_checkpointing: bool = False,
|
49 |
+
*args,
|
50 |
+
**kwargs,
|
51 |
+
):
|
52 |
+
"""
|
53 |
+
Initialize the Multi-View Alternating-Attention Transformer for information sharing across image features from different views.
|
54 |
+
Alternates between global and frame-level attention.
|
55 |
+
|
56 |
+
Args:
|
57 |
+
input_embed_dim (int): Dimension of input embeddings.
|
58 |
+
use_pe_for_non_reference_views (bool): Whether to use view positional encoding for input non-referenec views. (default: False)
|
59 |
+
max_num_views_for_pe (int): Maximum number of views for positional encoding. (default: 1000)
|
60 |
+
use_rand_idx_pe_for_non_reference_views (bool): Whether to use random index positional encoding for non-reference views. (default: True)
|
61 |
+
size (str): String to indicate interpretable size of the transformer (for e.g., base, large, ...). (default: None)
|
62 |
+
depth (int): Number of transformer layers. (default: 12, base size)
|
63 |
+
dim (int): Dimension of the transformer. (default: 768, base size)
|
64 |
+
num_heads (int): Number of attention heads. (default: 12, base size)
|
65 |
+
mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.)
|
66 |
+
qkv_bias (bool): Whether to include bias in qkv projection (default: True)
|
67 |
+
qk_norm (bool): Whether to normalize q and k (default: False)
|
68 |
+
proj_drop (float): Dropout rate for output (default: 0.)
|
69 |
+
attn_drop (float): Dropout rate for attention weights (default: 0.)
|
70 |
+
init_values (float): Initial value for LayerScale gamma (default: None)
|
71 |
+
drop_path (float): Dropout rate for stochastic depth (default: 0.)
|
72 |
+
act_layer (nn.Module): Activation layer (default: nn.GELU)
|
73 |
+
norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm)
|
74 |
+
mlp_layer (nn.Module): MLP layer (default: Mlp)
|
75 |
+
custom_positional_encoding (Callable): Custom positional encoding function (default: None)
|
76 |
+
pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None)
|
77 |
+
gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False)
|
78 |
+
"""
|
79 |
+
# Initialize the base class
|
80 |
+
super().__init__(name=name, size=size, *args, **kwargs)
|
81 |
+
|
82 |
+
# Initialize the specific attributes of the transformer
|
83 |
+
self.input_embed_dim = input_embed_dim
|
84 |
+
self.use_pe_for_non_reference_views = use_pe_for_non_reference_views
|
85 |
+
self.max_num_views_for_pe = max_num_views_for_pe
|
86 |
+
self.use_rand_idx_pe_for_non_reference_views = use_rand_idx_pe_for_non_reference_views
|
87 |
+
self.depth = depth
|
88 |
+
self.dim = dim
|
89 |
+
self.num_heads = num_heads
|
90 |
+
self.mlp_ratio = mlp_ratio
|
91 |
+
self.qkv_bias = qkv_bias
|
92 |
+
self.qk_norm = qk_norm
|
93 |
+
self.proj_drop = proj_drop
|
94 |
+
self.attn_drop = attn_drop
|
95 |
+
self.init_values = init_values
|
96 |
+
self.drop_path = drop_path
|
97 |
+
self.act_layer = act_layer
|
98 |
+
self.norm_layer = norm_layer
|
99 |
+
self.mlp_layer = mlp_layer
|
100 |
+
self.custom_positional_encoding = custom_positional_encoding
|
101 |
+
self.pretrained_checkpoint_path = pretrained_checkpoint_path
|
102 |
+
self.gradient_checkpointing = gradient_checkpointing
|
103 |
+
|
104 |
+
# Initialize the projection layer for input embeddings
|
105 |
+
if self.input_embed_dim != self.dim:
|
106 |
+
self.proj_embed = nn.Linear(self.input_embed_dim, self.dim, bias=True)
|
107 |
+
else:
|
108 |
+
self.proj_embed = nn.Identity()
|
109 |
+
|
110 |
+
# Initialize the self-attention blocks which ingest all views at once
|
111 |
+
self.self_attention_blocks = nn.ModuleList(
|
112 |
+
[
|
113 |
+
SelfAttentionBlock(
|
114 |
+
dim=self.dim,
|
115 |
+
num_heads=self.num_heads,
|
116 |
+
mlp_ratio=self.mlp_ratio,
|
117 |
+
qkv_bias=self.qkv_bias,
|
118 |
+
qk_norm=self.qk_norm,
|
119 |
+
proj_drop=self.proj_drop,
|
120 |
+
attn_drop=self.attn_drop,
|
121 |
+
init_values=self.init_values,
|
122 |
+
drop_path=self.drop_path,
|
123 |
+
act_layer=self.act_layer,
|
124 |
+
norm_layer=self.norm_layer,
|
125 |
+
mlp_layer=self.mlp_layer,
|
126 |
+
custom_positional_encoding=self.custom_positional_encoding,
|
127 |
+
)
|
128 |
+
for _ in range(self.depth)
|
129 |
+
]
|
130 |
+
)
|
131 |
+
|
132 |
+
# Initialize the final normalization layer
|
133 |
+
self.norm = self.norm_layer(self.dim)
|
134 |
+
|
135 |
+
# Initialize the position getter for patch positions if required
|
136 |
+
if self.custom_positional_encoding is not None:
|
137 |
+
self.position_getter = PositionGetter()
|
138 |
+
|
139 |
+
if self.use_pe_for_non_reference_views:
|
140 |
+
# Initialize the positional encoding table for the different views
|
141 |
+
self.register_buffer(
|
142 |
+
"view_pos_table",
|
143 |
+
self._get_sinusoid_encoding_table(self.max_num_views_for_pe, self.dim, 10000),
|
144 |
+
)
|
145 |
+
else:
|
146 |
+
# Initialize the positional encoding table for the reference view
|
147 |
+
self.register_buffer(
|
148 |
+
"view_pos_table",
|
149 |
+
self._get_sinusoid_encoding_table(1, self.dim, 10000),
|
150 |
+
)
|
151 |
+
|
152 |
+
# Initialize random weights
|
153 |
+
self.initialize_weights()
|
154 |
+
|
155 |
+
# Apply gradient checkpointing if enabled
|
156 |
+
if self.gradient_checkpointing:
|
157 |
+
for i, block in enumerate(self.self_attention_blocks):
|
158 |
+
self.self_attention_blocks[i] = self.wrap_module_with_gradient_checkpointing(block)
|
159 |
+
|
160 |
+
# Load pretrained weights if provided
|
161 |
+
if self.pretrained_checkpoint_path is not None:
|
162 |
+
print(
|
163 |
+
f"Loading pretrained multi-view Alternating-Attention transformer weights from {self.pretrained_checkpoint_path} ..."
|
164 |
+
)
|
165 |
+
ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
|
166 |
+
print(self.load_state_dict(ckpt["model"]))
|
167 |
+
|
168 |
+
def _get_sinusoid_encoding_table(self, n_position, d_hid, base):
|
169 |
+
"Sinusoid position encoding table"
|
170 |
+
|
171 |
+
def get_position_angle_vec(position):
|
172 |
+
return [position / np.power(base, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
|
173 |
+
|
174 |
+
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
175 |
+
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])
|
176 |
+
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])
|
177 |
+
|
178 |
+
return torch.FloatTensor(sinusoid_table)
|
179 |
+
|
180 |
+
def initialize_weights(self):
|
181 |
+
"Initialize weights of the transformer."
|
182 |
+
# Linears and layer norms
|
183 |
+
self.apply(self._init_weights)
|
184 |
+
|
185 |
+
def _init_weights(self, m):
|
186 |
+
"Initialize the transformer linear and layer norm weights."
|
187 |
+
if isinstance(m, nn.Linear):
|
188 |
+
# We use xavier_uniform following official JAX ViT:
|
189 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
190 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
191 |
+
nn.init.constant_(m.bias, 0)
|
192 |
+
elif isinstance(m, nn.LayerNorm):
|
193 |
+
nn.init.constant_(m.bias, 0)
|
194 |
+
nn.init.constant_(m.weight, 1.0)
|
195 |
+
|
196 |
+
def forward(
|
197 |
+
self,
|
198 |
+
model_input: MultiViewTransformerInput,
|
199 |
+
) -> MultiViewTransformerOutput:
|
200 |
+
"""
|
201 |
+
Forward interface for the Multi-View Alternating-Attention Transformer.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
model_input (MultiViewTransformerInput): Input to the model.
|
205 |
+
Expects the features to be a list of size (batch, input_embed_dim, height, width),
|
206 |
+
where each entry corresponds to a different view.
|
207 |
+
Optionally, the input can also include additional_input_tokens (e.g., class token, registers, pose tokens, scale token)
|
208 |
+
which are appended to the token set from the multi-view features. The tokens are of size (batch, input_embed_dim, num_of_additional_tokens).
|
209 |
+
|
210 |
+
Returns:
|
211 |
+
MultiViewTransformerOutput: Output of the model post information sharing.
|
212 |
+
"""
|
213 |
+
# Check that the number of views matches the input and the features are of expected shape
|
214 |
+
if self.use_pe_for_non_reference_views:
|
215 |
+
assert (
|
216 |
+
len(model_input.features) <= self.max_num_views_for_pe
|
217 |
+
), f"Expected less than {self.max_num_views_for_pe} views, got {len(model_input.features)}"
|
218 |
+
assert all(
|
219 |
+
view_features.shape[1] == self.input_embed_dim for view_features in model_input.features
|
220 |
+
), f"All views must have input dimension {self.input_embed_dim}"
|
221 |
+
assert all(
|
222 |
+
view_features.ndim == 4 for view_features in model_input.features
|
223 |
+
), "All views must have 4 dimensions (N, C, H, W)"
|
224 |
+
|
225 |
+
# Initialize the multi-view features from the model input and number of views for current input
|
226 |
+
multi_view_features = model_input.features
|
227 |
+
num_of_views = len(multi_view_features)
|
228 |
+
batch_size, _, height, width = multi_view_features[0].shape
|
229 |
+
num_of_tokens_per_view = height * width
|
230 |
+
|
231 |
+
# Stack the multi-view features (N, C, H, W) to (N, V, C, H, W) (assumes all V views have same shape)
|
232 |
+
multi_view_features = torch.stack(multi_view_features, dim=1)
|
233 |
+
|
234 |
+
# Resize the multi-view features from NVCHW to NLC, where L = V * H * W
|
235 |
+
multi_view_features = multi_view_features.permute(0, 1, 3, 4, 2) # (N, V, H, W, C)
|
236 |
+
multi_view_features = multi_view_features.reshape(
|
237 |
+
batch_size, num_of_views * height * width, self.input_embed_dim
|
238 |
+
).contiguous()
|
239 |
+
|
240 |
+
# Process additional input tokens if provided
|
241 |
+
if model_input.additional_input_tokens is not None:
|
242 |
+
|
243 |
+
additional_tokens = model_input.additional_input_tokens
|
244 |
+
assert additional_tokens.ndim == 3, "Additional tokens must have 3 dimensions (N, C, T)"
|
245 |
+
assert (
|
246 |
+
additional_tokens.shape[1] == self.input_embed_dim
|
247 |
+
), f"Additional tokens must have input dimension {self.input_embed_dim}"
|
248 |
+
assert additional_tokens.shape[0] == batch_size, "Batch size mismatch for additional tokens"
|
249 |
+
|
250 |
+
# Reshape to channel-last format for transformer processing
|
251 |
+
additional_tokens = additional_tokens.permute(0, 2, 1).contiguous() # (N, C, T) -> (N, T, C)
|
252 |
+
|
253 |
+
# Concatenate the additional tokens to the multi-view features
|
254 |
+
multi_view_features = torch.cat([multi_view_features, additional_tokens], dim=1)
|
255 |
+
|
256 |
+
# Project input features to the transformer dimension
|
257 |
+
multi_view_features = self.proj_embed(multi_view_features)
|
258 |
+
|
259 |
+
# Create patch positions for each view if custom positional encoding is used
|
260 |
+
if self.custom_positional_encoding is not None:
|
261 |
+
multi_view_positions = [
|
262 |
+
self.position_getter(batch_size, height, width, multi_view_features.device)
|
263 |
+
] * num_of_views # List of length V, where each tensor is (N, H * W, C)
|
264 |
+
multi_view_positions = torch.cat(multi_view_positions, dim=1) # (N, V * H * W, C)
|
265 |
+
else:
|
266 |
+
multi_view_positions = [None] * num_of_views
|
267 |
+
|
268 |
+
# Add None positions for additional tokens if they exist
|
269 |
+
if model_input.additional_input_tokens is not None:
|
270 |
+
|
271 |
+
additional_tokens_positions = [None] * model_input.additional_input_tokens.shape[1]
|
272 |
+
multi_view_positions = multi_view_positions + additional_tokens_positions
|
273 |
+
|
274 |
+
# Add positional encoding for reference view (idx 0)
|
275 |
+
ref_view_pe = self.view_pos_table[0].clone().detach()
|
276 |
+
ref_view_pe = ref_view_pe.reshape((1, 1, self.dim))
|
277 |
+
ref_view_pe = ref_view_pe.repeat(batch_size, num_of_tokens_per_view, 1)
|
278 |
+
ref_view_features = multi_view_features[:, :num_of_tokens_per_view, :]
|
279 |
+
ref_view_features = ref_view_features + ref_view_pe
|
280 |
+
|
281 |
+
if self.use_pe_for_non_reference_views:
|
282 |
+
# Add positional encoding for non-reference views (sequential indices starting from idx 1 or random indices which are uniformly sampled)
|
283 |
+
if self.use_rand_idx_pe_for_non_reference_views:
|
284 |
+
non_ref_view_pe_indices = torch.randint(low=1, high=self.max_num_views_for_pe, size=(num_of_views - 1,))
|
285 |
+
else:
|
286 |
+
non_ref_view_pe_indices = torch.arange(1, num_of_views)
|
287 |
+
non_ref_view_pe = self.view_pos_table[non_ref_view_pe_indices].clone().detach()
|
288 |
+
non_ref_view_pe = non_ref_view_pe.reshape((1, num_of_views - 1, self.dim))
|
289 |
+
non_ref_view_pe = non_ref_view_pe.repeat_interleave(num_of_tokens_per_view, dim=1)
|
290 |
+
non_ref_view_pe = non_ref_view_pe.repeat(batch_size, 1, 1)
|
291 |
+
non_ref_view_features = multi_view_features[
|
292 |
+
:, num_of_tokens_per_view : num_of_views * num_of_tokens_per_view, :
|
293 |
+
]
|
294 |
+
non_ref_view_features = non_ref_view_features + non_ref_view_pe
|
295 |
+
else:
|
296 |
+
non_ref_view_features = multi_view_features[
|
297 |
+
:, num_of_tokens_per_view : num_of_views * num_of_tokens_per_view, :
|
298 |
+
]
|
299 |
+
|
300 |
+
# Concatenate the reference and non-reference view features
|
301 |
+
# Handle additional tokens (no view-based positional encoding for them)
|
302 |
+
if model_input.additional_input_tokens is not None:
|
303 |
+
|
304 |
+
additional_features = multi_view_features[:, num_of_views * num_of_tokens_per_view :, :]
|
305 |
+
multi_view_features = torch.cat([ref_view_features, non_ref_view_features, additional_features], dim=1)
|
306 |
+
else:
|
307 |
+
multi_view_features = torch.cat([ref_view_features, non_ref_view_features], dim=1)
|
308 |
+
|
309 |
+
# Loop over the depth of the transformer
|
310 |
+
for depth_idx in range(self.depth):
|
311 |
+
if depth_idx % 2 == 0:
|
312 |
+
# Apply the self-attention block and update the multi-view features
|
313 |
+
# Global attention across all views
|
314 |
+
multi_view_features = self.self_attention_blocks[depth_idx](multi_view_features, multi_view_positions)
|
315 |
+
else:
|
316 |
+
# Handle additional tokens separately for frame-level attention
|
317 |
+
additional_features = None
|
318 |
+
additional_positions = None
|
319 |
+
if model_input.additional_input_tokens is not None:
|
320 |
+
|
321 |
+
# Extract additional token features
|
322 |
+
additional_features = multi_view_features[:, num_of_views * num_of_tokens_per_view :, :]
|
323 |
+
# Keep only view features for frame-level attention
|
324 |
+
multi_view_features = multi_view_features[:, : num_of_views * num_of_tokens_per_view, :]
|
325 |
+
|
326 |
+
# Handle positions for additional tokens if custom positional encoding is used
|
327 |
+
if self.custom_positional_encoding is not None:
|
328 |
+
additional_positions = multi_view_positions[:, num_of_views * num_of_tokens_per_view :, :]
|
329 |
+
multi_view_positions = multi_view_positions[:, : num_of_views * num_of_tokens_per_view, :]
|
330 |
+
|
331 |
+
# Reshape the multi-view features from (N, V * H * W, C) to (N * V, H * W, C)
|
332 |
+
multi_view_features = multi_view_features.reshape(
|
333 |
+
batch_size * num_of_views, num_of_tokens_per_view, self.dim
|
334 |
+
).contiguous() # (N * V, H * W, C)
|
335 |
+
if multi_view_positions[0] is not None:
|
336 |
+
multi_view_positions = multi_view_positions.reshape(
|
337 |
+
batch_size * num_of_views, num_of_tokens_per_view, 2
|
338 |
+
).contiguous() # (N * V, H * W, C)
|
339 |
+
|
340 |
+
# Apply the self-attention block and update the multi-view features
|
341 |
+
# Frame-level attention within each view
|
342 |
+
multi_view_features = self.self_attention_blocks[depth_idx](multi_view_features, multi_view_positions)
|
343 |
+
|
344 |
+
# Reshape the multi-view features from (N * V, H * W, C) back to (N, V * H * W, C)
|
345 |
+
multi_view_features = multi_view_features.reshape(
|
346 |
+
batch_size, num_of_views * num_of_tokens_per_view, self.dim
|
347 |
+
).contiguous() # (N, V * H * W, C)
|
348 |
+
if multi_view_positions[0] is not None:
|
349 |
+
multi_view_positions = multi_view_positions.reshape(
|
350 |
+
batch_size, num_of_views * num_of_tokens_per_view, 2
|
351 |
+
).contiguous() # (N, V * H * W, C)
|
352 |
+
|
353 |
+
# Reattach additional tokens if they exist
|
354 |
+
if additional_features is not None:
|
355 |
+
multi_view_features = torch.cat([multi_view_features, additional_features], dim=1)
|
356 |
+
# Reattach positions for additional tokens if they exist
|
357 |
+
if additional_positions is not None:
|
358 |
+
multi_view_positions = torch.cat([multi_view_positions, additional_positions], dim=1)
|
359 |
+
|
360 |
+
# Normalize the output features
|
361 |
+
output_multi_view_features = self.norm(multi_view_features)
|
362 |
+
|
363 |
+
# Extract only the view features (excluding additional tokens)
|
364 |
+
view_features = output_multi_view_features[:, : num_of_views * num_of_tokens_per_view, :]
|
365 |
+
|
366 |
+
# Reshape the output multi-view features (N, V * H * W, C) back to (N, V, C, H, W)
|
367 |
+
view_features = view_features.reshape(batch_size, num_of_views, height, width, self.dim) # (N, V, H, W, C)
|
368 |
+
view_features = view_features.permute(0, 1, 4, 2, 3).contiguous() # (N, V, C, H, W)
|
369 |
+
|
370 |
+
# Split the output multi-view features into separate views
|
371 |
+
view_features = view_features.split(1, dim=1)
|
372 |
+
view_features = [output_view_features.squeeze(dim=1) for output_view_features in view_features]
|
373 |
+
|
374 |
+
# Extract and return additional token features if provided
|
375 |
+
if model_input.additional_input_tokens is not None:
|
376 |
+
|
377 |
+
additional_token_features = output_multi_view_features[:, num_of_views * num_of_tokens_per_view :, :]
|
378 |
+
additional_token_features = additional_token_features.permute(0, 2, 1).contiguous() # (N, C, T)
|
379 |
+
return MultiViewTransformerOutput(
|
380 |
+
features=view_features, additional_token_features=additional_token_features
|
381 |
+
)
|
382 |
+
else:
|
383 |
+
return MultiViewTransformerOutput(features=view_features)
|
384 |
+
|
385 |
+
|
386 |
+
class MultiViewAlternatingAttentionTransformerIFR(
|
387 |
+
MultiViewAlternatingAttentionTransformer, IntermediateFeatureReturner
|
388 |
+
):
|
389 |
+
"Intermediate Feature Returner for UniCeption Multi-View Alternating-Attention Transformer"
|
390 |
+
|
391 |
+
def __init__(
|
392 |
+
self,
|
393 |
+
name: str,
|
394 |
+
input_embed_dim: int,
|
395 |
+
use_pe_for_non_reference_views: bool = False,
|
396 |
+
max_num_views_for_pe: int = 1000,
|
397 |
+
use_rand_idx_pe_for_non_reference_views: bool = True,
|
398 |
+
size: Optional[str] = None,
|
399 |
+
depth: int = 12,
|
400 |
+
dim: int = 768,
|
401 |
+
num_heads: int = 12,
|
402 |
+
mlp_ratio: float = 4.0,
|
403 |
+
qkv_bias: bool = True,
|
404 |
+
qk_norm: bool = False,
|
405 |
+
proj_drop: float = 0.0,
|
406 |
+
attn_drop: float = 0.0,
|
407 |
+
init_values: Optional[float] = None,
|
408 |
+
drop_path: float = 0.0,
|
409 |
+
act_layer: nn.Module = nn.GELU,
|
410 |
+
norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
|
411 |
+
mlp_layer: nn.Module = Mlp,
|
412 |
+
custom_positional_encoding: Callable = None,
|
413 |
+
pretrained_checkpoint_path: str = None,
|
414 |
+
indices: Optional[Union[int, List[int]]] = None,
|
415 |
+
norm_intermediate: bool = True,
|
416 |
+
intermediates_only: bool = False,
|
417 |
+
gradient_checkpointing: bool = False,
|
418 |
+
*args,
|
419 |
+
**kwargs,
|
420 |
+
):
|
421 |
+
"""
|
422 |
+
Initialize the Multi-View Alternating-Attention Transformer for information sharing across image features from different views.
|
423 |
+
Extends the base class to return intermediate features.
|
424 |
+
|
425 |
+
Args:
|
426 |
+
input_embed_dim (int): Dimension of input embeddings.
|
427 |
+
use_pe_for_non_reference_views (bool): Whether to use view positional encoding for input non-referenec views. (default: False)
|
428 |
+
max_num_views_for_pe (int): Maximum number of views for positional encoding. (default: 1000)
|
429 |
+
use_rand_idx_pe_for_non_reference_views (bool): Whether to use random index positional encoding for non-reference views. (default: True)
|
430 |
+
use_rand_idx_pe_for_non_reference_views (bool): Whether to use random index positional encoding for non-reference views.
|
431 |
+
size (str): String to indicate interpretable size of the transformer (for e.g., base, large, ...). (default: None)
|
432 |
+
depth (int): Number of transformer layers. (default: 12, base size)
|
433 |
+
dim (int): Dimension of the transformer. (default: 768, base size)
|
434 |
+
num_heads (int): Number of attention heads. (default: 12, base size)
|
435 |
+
mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.)
|
436 |
+
qkv_bias (bool): Whether to include bias in qkv projection (default: False)
|
437 |
+
qk_norm (bool): Whether to normalize q and k (default: False)
|
438 |
+
proj_drop (float): Dropout rate for output (default: 0.)
|
439 |
+
attn_drop (float): Dropout rate for attention weights (default: 0.)
|
440 |
+
init_values (float): Initial value for LayerScale gamma (default: None)
|
441 |
+
drop_path (float): Dropout rate for stochastic depth (default: 0.)
|
442 |
+
act_layer (nn.Module): Activation layer (default: nn.GELU)
|
443 |
+
norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm)
|
444 |
+
mlp_layer (nn.Module): MLP layer (default: Mlp)
|
445 |
+
custom_positional_encoding (Callable): Custom positional encoding function (default: None)
|
446 |
+
pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None)
|
447 |
+
indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. (default: None) Options:
|
448 |
+
- None: Return all intermediate layers.
|
449 |
+
- int: Return the last n layers.
|
450 |
+
- List[int]: Return the intermediate layers at the specified indices.
|
451 |
+
norm_intermediate (bool, optional): Whether to normalize the intermediate features. (default: True)
|
452 |
+
intermediates_only (bool, optional): Whether to return only the intermediate features. (default: False)
|
453 |
+
gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False)
|
454 |
+
"""
|
455 |
+
# Init the base classes
|
456 |
+
MultiViewAlternatingAttentionTransformer.__init__(
|
457 |
+
self,
|
458 |
+
name=name,
|
459 |
+
input_embed_dim=input_embed_dim,
|
460 |
+
use_pe_for_non_reference_views=use_pe_for_non_reference_views,
|
461 |
+
max_num_views_for_pe=max_num_views_for_pe,
|
462 |
+
use_rand_idx_pe_for_non_reference_views=use_rand_idx_pe_for_non_reference_views,
|
463 |
+
size=size,
|
464 |
+
depth=depth,
|
465 |
+
dim=dim,
|
466 |
+
num_heads=num_heads,
|
467 |
+
mlp_ratio=mlp_ratio,
|
468 |
+
qkv_bias=qkv_bias,
|
469 |
+
qk_norm=qk_norm,
|
470 |
+
proj_drop=proj_drop,
|
471 |
+
attn_drop=attn_drop,
|
472 |
+
init_values=init_values,
|
473 |
+
drop_path=drop_path,
|
474 |
+
act_layer=act_layer,
|
475 |
+
norm_layer=norm_layer,
|
476 |
+
mlp_layer=mlp_layer,
|
477 |
+
custom_positional_encoding=custom_positional_encoding,
|
478 |
+
pretrained_checkpoint_path=pretrained_checkpoint_path,
|
479 |
+
gradient_checkpointing=gradient_checkpointing,
|
480 |
+
*args,
|
481 |
+
**kwargs,
|
482 |
+
)
|
483 |
+
IntermediateFeatureReturner.__init__(
|
484 |
+
self,
|
485 |
+
indices=indices,
|
486 |
+
norm_intermediate=norm_intermediate,
|
487 |
+
intermediates_only=intermediates_only,
|
488 |
+
)
|
489 |
+
|
490 |
+
def forward(
|
491 |
+
self,
|
492 |
+
model_input: MultiViewTransformerInput,
|
493 |
+
) -> Union[
|
494 |
+
List[MultiViewTransformerOutput],
|
495 |
+
Tuple[MultiViewTransformerOutput, List[MultiViewTransformerOutput]],
|
496 |
+
]:
|
497 |
+
"""
|
498 |
+
Forward interface for the Multi-View Alternating-Attention Transformer with Intermediate Feature Return.
|
499 |
+
|
500 |
+
Args:
|
501 |
+
model_input (MultiViewTransformerInput): Input to the model.
|
502 |
+
Expects the features to be a list of size (batch, input_embed_dim, height, width),
|
503 |
+
where each entry corresponds to a different view.
|
504 |
+
Optionally, the input can also include additional_input_tokens (e.g., class token, registers, pose tokens, scale token)
|
505 |
+
which are appended to the token set from the multi-view features. The tokens are of size (batch, input_embed_dim, num_of_additional_tokens).
|
506 |
+
|
507 |
+
Returns:
|
508 |
+
Union[List[MultiViewTransformerOutput], Tuple[MultiViewTransformerOutput, List[MultiViewTransformerOutput]]]:
|
509 |
+
Output of the model post information sharing.
|
510 |
+
If intermediates_only is True, returns a list of intermediate outputs.
|
511 |
+
If intermediates_only is False, returns a tuple of final output and a list of intermediate outputs.
|
512 |
+
"""
|
513 |
+
# Check that the number of views matches the input and the features are of expected shape
|
514 |
+
if self.use_pe_for_non_reference_views:
|
515 |
+
assert (
|
516 |
+
len(model_input.features) <= self.max_num_views_for_pe
|
517 |
+
), f"Expected less than {self.max_num_views_for_pe} views, got {len(model_input.features)}"
|
518 |
+
assert all(
|
519 |
+
view_features.shape[1] == self.input_embed_dim for view_features in model_input.features
|
520 |
+
), f"All views must have input dimension {self.input_embed_dim}"
|
521 |
+
assert all(
|
522 |
+
view_features.ndim == 4 for view_features in model_input.features
|
523 |
+
), "All views must have 4 dimensions (N, C, H, W)"
|
524 |
+
|
525 |
+
# Get the indices of the intermediate features to return
|
526 |
+
intermediate_multi_view_features = []
|
527 |
+
take_indices, _ = feature_take_indices(self.depth, self.indices)
|
528 |
+
|
529 |
+
# Initialize the multi-view features from the model input and number of views for current input
|
530 |
+
multi_view_features = model_input.features
|
531 |
+
num_of_views = len(multi_view_features)
|
532 |
+
batch_size, _, height, width = multi_view_features[0].shape
|
533 |
+
num_of_tokens_per_view = height * width
|
534 |
+
|
535 |
+
# Stack the multi-view features (N, C, H, W) to (N, V, C, H, W) (assumes all V views have same shape)
|
536 |
+
multi_view_features = torch.stack(multi_view_features, dim=1)
|
537 |
+
|
538 |
+
# Resize the multi-view features from NVCHW to NLC, where L = V * H * W
|
539 |
+
multi_view_features = multi_view_features.permute(0, 1, 3, 4, 2) # (N, V, H, W, C)
|
540 |
+
multi_view_features = multi_view_features.reshape(
|
541 |
+
batch_size, num_of_views * height * width, self.input_embed_dim
|
542 |
+
).contiguous()
|
543 |
+
|
544 |
+
# Process additional input tokens if provided
|
545 |
+
if model_input.additional_input_tokens is not None:
|
546 |
+
|
547 |
+
additional_tokens = model_input.additional_input_tokens
|
548 |
+
assert additional_tokens.ndim == 3, "Additional tokens must have 3 dimensions (N, C, T)"
|
549 |
+
assert (
|
550 |
+
additional_tokens.shape[1] == self.input_embed_dim
|
551 |
+
), f"Additional tokens must have input dimension {self.input_embed_dim}"
|
552 |
+
assert additional_tokens.shape[0] == batch_size, "Batch size mismatch for additional tokens"
|
553 |
+
|
554 |
+
# Reshape to channel-last format for transformer processing
|
555 |
+
additional_tokens = additional_tokens.permute(0, 2, 1).contiguous() # (N, C, T) -> (N, T, C)
|
556 |
+
|
557 |
+
# Concatenate the additional tokens to the multi-view features
|
558 |
+
multi_view_features = torch.cat([multi_view_features, additional_tokens], dim=1)
|
559 |
+
|
560 |
+
# Project input features to the transformer dimension
|
561 |
+
multi_view_features = self.proj_embed(multi_view_features)
|
562 |
+
|
563 |
+
# Create patch positions for each view if custom positional encoding is used
|
564 |
+
if self.custom_positional_encoding is not None:
|
565 |
+
multi_view_positions = [
|
566 |
+
self.position_getter(batch_size, height, width, multi_view_features.device)
|
567 |
+
] * num_of_views # List of length V, where each tensor is (N, H * W, C)
|
568 |
+
multi_view_positions = torch.cat(multi_view_positions, dim=1) # (N, V * H * W, C)
|
569 |
+
else:
|
570 |
+
multi_view_positions = [None] * num_of_views
|
571 |
+
|
572 |
+
# Add None positions for additional tokens if they exist
|
573 |
+
if model_input.additional_input_tokens is not None:
|
574 |
+
|
575 |
+
additional_tokens_positions = [None] * model_input.additional_input_tokens.shape[1]
|
576 |
+
multi_view_positions = multi_view_positions + additional_tokens_positions
|
577 |
+
|
578 |
+
# Add positional encoding for reference view (idx 0)
|
579 |
+
ref_view_pe = self.view_pos_table[0].clone().detach()
|
580 |
+
ref_view_pe = ref_view_pe.reshape((1, 1, self.dim))
|
581 |
+
ref_view_pe = ref_view_pe.repeat(batch_size, num_of_tokens_per_view, 1)
|
582 |
+
ref_view_features = multi_view_features[:, :num_of_tokens_per_view, :]
|
583 |
+
ref_view_features = ref_view_features + ref_view_pe
|
584 |
+
|
585 |
+
if self.use_pe_for_non_reference_views:
|
586 |
+
# Add positional encoding for non-reference views (sequential indices starting from idx 1 or random indices which are uniformly sampled)
|
587 |
+
if self.use_rand_idx_pe_for_non_reference_views:
|
588 |
+
non_ref_view_pe_indices = torch.randint(low=1, high=self.max_num_views_for_pe, size=(num_of_views - 1,))
|
589 |
+
else:
|
590 |
+
non_ref_view_pe_indices = torch.arange(1, num_of_views)
|
591 |
+
non_ref_view_pe = self.view_pos_table[non_ref_view_pe_indices].clone().detach()
|
592 |
+
non_ref_view_pe = non_ref_view_pe.reshape((1, num_of_views - 1, self.dim))
|
593 |
+
non_ref_view_pe = non_ref_view_pe.repeat_interleave(num_of_tokens_per_view, dim=1)
|
594 |
+
non_ref_view_pe = non_ref_view_pe.repeat(batch_size, 1, 1)
|
595 |
+
non_ref_view_features = multi_view_features[
|
596 |
+
:, num_of_tokens_per_view : num_of_views * num_of_tokens_per_view, :
|
597 |
+
]
|
598 |
+
non_ref_view_features = non_ref_view_features + non_ref_view_pe
|
599 |
+
else:
|
600 |
+
non_ref_view_features = multi_view_features[
|
601 |
+
:, num_of_tokens_per_view : num_of_views * num_of_tokens_per_view, :
|
602 |
+
]
|
603 |
+
|
604 |
+
# Concatenate the reference and non-reference view features
|
605 |
+
# Handle additional tokens (no view-based positional encoding for them)
|
606 |
+
if model_input.additional_input_tokens is not None:
|
607 |
+
|
608 |
+
additional_features = multi_view_features[:, num_of_views * num_of_tokens_per_view :, :]
|
609 |
+
multi_view_features = torch.cat([ref_view_features, non_ref_view_features, additional_features], dim=1)
|
610 |
+
else:
|
611 |
+
multi_view_features = torch.cat([ref_view_features, non_ref_view_features], dim=1)
|
612 |
+
|
613 |
+
# Loop over the depth of the transformer
|
614 |
+
for depth_idx in range(self.depth):
|
615 |
+
if depth_idx % 2 == 0:
|
616 |
+
# Apply the self-attention block and update the multi-view features
|
617 |
+
# Global attention across all views
|
618 |
+
multi_view_features = self.self_attention_blocks[depth_idx](multi_view_features, multi_view_positions)
|
619 |
+
else:
|
620 |
+
# Handle additional tokens separately for frame-level attention
|
621 |
+
additional_features = None
|
622 |
+
additional_positions = None
|
623 |
+
if model_input.additional_input_tokens is not None:
|
624 |
+
|
625 |
+
# Extract additional token features
|
626 |
+
additional_features = multi_view_features[:, num_of_views * num_of_tokens_per_view :, :]
|
627 |
+
# Keep only view features for frame-level attention
|
628 |
+
multi_view_features = multi_view_features[:, : num_of_views * num_of_tokens_per_view, :]
|
629 |
+
|
630 |
+
# Handle positions for additional tokens if custom positional encoding is used
|
631 |
+
if self.custom_positional_encoding is not None:
|
632 |
+
additional_positions = multi_view_positions[:, num_of_views * num_of_tokens_per_view :, :]
|
633 |
+
multi_view_positions = multi_view_positions[:, : num_of_views * num_of_tokens_per_view, :]
|
634 |
+
|
635 |
+
# Reshape the multi-view features from (N, V * H * W, C) to (N * V, H * W, C)
|
636 |
+
multi_view_features = multi_view_features.reshape(
|
637 |
+
batch_size * num_of_views, num_of_tokens_per_view, self.dim
|
638 |
+
).contiguous() # (N * V, H * W, C)
|
639 |
+
if multi_view_positions[0] is not None:
|
640 |
+
multi_view_positions = multi_view_positions.reshape(
|
641 |
+
batch_size * num_of_views, num_of_tokens_per_view, 2
|
642 |
+
).contiguous() # (N * V, H * W, C)
|
643 |
+
|
644 |
+
# Apply the self-attention block and update the multi-view features
|
645 |
+
# Frame-level attention within each view
|
646 |
+
multi_view_features = self.self_attention_blocks[depth_idx](multi_view_features, multi_view_positions)
|
647 |
+
|
648 |
+
# Reshape the multi-view features from (N * V, H * W, C) back to (N, V * H * W, C)
|
649 |
+
multi_view_features = multi_view_features.reshape(
|
650 |
+
batch_size, num_of_views * num_of_tokens_per_view, self.dim
|
651 |
+
).contiguous() # (N, V * H * W, C)
|
652 |
+
if multi_view_positions[0] is not None:
|
653 |
+
multi_view_positions = multi_view_positions.reshape(
|
654 |
+
batch_size, num_of_views * num_of_tokens_per_view, 2
|
655 |
+
).contiguous() # (N, V * H * W, C)
|
656 |
+
|
657 |
+
# Reattach additional tokens if they exist
|
658 |
+
if additional_features is not None:
|
659 |
+
multi_view_features = torch.cat([multi_view_features, additional_features], dim=1)
|
660 |
+
# Reattach positions for additional tokens if they exist
|
661 |
+
if additional_positions is not None:
|
662 |
+
multi_view_positions = torch.cat([multi_view_positions, additional_positions], dim=1)
|
663 |
+
if depth_idx in take_indices:
|
664 |
+
# Normalize the intermediate features with final norm layer if enabled
|
665 |
+
intermediate_multi_view_features.append(
|
666 |
+
self.norm(multi_view_features) if self.norm_intermediate else multi_view_features
|
667 |
+
)
|
668 |
+
|
669 |
+
# Reshape the intermediate features and convert to MultiViewTransformerOutput class
|
670 |
+
for idx in range(len(intermediate_multi_view_features)):
|
671 |
+
# Get the current intermediate features
|
672 |
+
current_features = intermediate_multi_view_features[idx]
|
673 |
+
|
674 |
+
# Extract additional token features if provided
|
675 |
+
additional_token_features = None
|
676 |
+
if model_input.additional_input_tokens is not None:
|
677 |
+
|
678 |
+
additional_token_features = current_features[:, num_of_views * num_of_tokens_per_view :, :]
|
679 |
+
additional_token_features = additional_token_features.permute(0, 2, 1).contiguous() # (N, C, T)
|
680 |
+
# Only keep the view features for reshaping
|
681 |
+
current_features = current_features[:, : num_of_views * num_of_tokens_per_view, :]
|
682 |
+
|
683 |
+
# Reshape the intermediate multi-view features (N, V * H * W, C) back to (N, V, C, H, W)
|
684 |
+
current_features = current_features.reshape(
|
685 |
+
batch_size, num_of_views, height, width, self.dim
|
686 |
+
) # (N, V, H, W, C)
|
687 |
+
current_features = current_features.permute(0, 1, 4, 2, 3).contiguous() # (N, V, C, H, W)
|
688 |
+
|
689 |
+
# Split the intermediate multi-view features into separate views
|
690 |
+
current_features = current_features.split(1, dim=1)
|
691 |
+
current_features = [
|
692 |
+
intermediate_view_features.squeeze(dim=1) for intermediate_view_features in current_features
|
693 |
+
]
|
694 |
+
|
695 |
+
intermediate_multi_view_features[idx] = MultiViewTransformerOutput(
|
696 |
+
features=current_features, additional_token_features=additional_token_features
|
697 |
+
)
|
698 |
+
|
699 |
+
# Return only the intermediate features if enabled
|
700 |
+
if self.intermediates_only:
|
701 |
+
return intermediate_multi_view_features
|
702 |
+
|
703 |
+
# Normalize the output features
|
704 |
+
output_multi_view_features = self.norm(multi_view_features)
|
705 |
+
|
706 |
+
# Extract view features (excluding additional tokens)
|
707 |
+
additional_token_features = None
|
708 |
+
if model_input.additional_input_tokens is not None:
|
709 |
+
|
710 |
+
additional_token_features = output_multi_view_features[:, num_of_views * num_of_tokens_per_view :, :]
|
711 |
+
additional_token_features = additional_token_features.permute(0, 2, 1).contiguous() # (N, C, T)
|
712 |
+
view_features = output_multi_view_features[:, : num_of_views * num_of_tokens_per_view, :]
|
713 |
+
else:
|
714 |
+
view_features = output_multi_view_features
|
715 |
+
|
716 |
+
# Reshape the output multi-view features (N, V * H * W, C) back to (N, V, C, H, W)
|
717 |
+
view_features = view_features.reshape(batch_size, num_of_views, height, width, self.dim) # (N, V, H, W, C)
|
718 |
+
view_features = view_features.permute(0, 1, 4, 2, 3).contiguous() # (N, V, C, H, W)
|
719 |
+
|
720 |
+
# Split the output multi-view features into separate views
|
721 |
+
view_features = view_features.split(1, dim=1)
|
722 |
+
view_features = [output_view_features.squeeze(dim=1) for output_view_features in view_features]
|
723 |
+
|
724 |
+
output_multi_view_features = MultiViewTransformerOutput(
|
725 |
+
features=view_features, additional_token_features=additional_token_features
|
726 |
+
)
|
727 |
+
|
728 |
+
return output_multi_view_features, intermediate_multi_view_features
|
729 |
+
|
730 |
+
|
731 |
+
def dummy_positional_encoding(x, xpos):
|
732 |
+
"Dummy function for positional encoding of tokens"
|
733 |
+
x = x
|
734 |
+
xpos = xpos
|
735 |
+
return x
|
736 |
+
|
737 |
+
|
738 |
+
def test_reshape_for_frame_attention():
|
739 |
+
"Test the reshape function for frame-level attention in the Alternating Attention Transformer"
|
740 |
+
batch_size = 2
|
741 |
+
num_of_views = 3
|
742 |
+
height = width = 2
|
743 |
+
dim = 4
|
744 |
+
num_of_tokens_per_view = height * width
|
745 |
+
|
746 |
+
# Create tensor with recognizable pattern
|
747 |
+
x = torch.zeros(batch_size, num_of_views * num_of_tokens_per_view, dim)
|
748 |
+
for b in range(batch_size):
|
749 |
+
for v in range(num_of_views):
|
750 |
+
for h in range(height):
|
751 |
+
for w in range(width):
|
752 |
+
token_idx = v * num_of_tokens_per_view + h * width + w
|
753 |
+
x[b, token_idx] = torch.tensor([b, v, h, w])
|
754 |
+
|
755 |
+
# Apply reshape
|
756 |
+
reshaped = x.reshape(batch_size * num_of_views, num_of_tokens_per_view, dim).contiguous()
|
757 |
+
|
758 |
+
# Verify shape
|
759 |
+
assert reshaped.shape == (batch_size * num_of_views, num_of_tokens_per_view, dim)
|
760 |
+
|
761 |
+
# Verify content (check a few values)
|
762 |
+
for b in range(batch_size):
|
763 |
+
for v in range(num_of_views):
|
764 |
+
for h in range(height):
|
765 |
+
for w in range(width):
|
766 |
+
batch_view_idx = b * num_of_views + v
|
767 |
+
token_idx = h * width + w
|
768 |
+
expected = torch.tensor([b, v, h, w])
|
769 |
+
assert torch.all(reshaped[batch_view_idx, token_idx] == expected)
|
770 |
+
|
771 |
+
# Verify reshape back works
|
772 |
+
back_to_original = reshaped.reshape(batch_size, num_of_views * num_of_tokens_per_view, dim)
|
773 |
+
assert torch.all(x == back_to_original)
|
774 |
+
|
775 |
+
print("Reshape test passed!")
|
776 |
+
|
777 |
+
|
778 |
+
if __name__ == "__main__":
|
779 |
+
# Unit test the reshape logic used for frame-level attention
|
780 |
+
test_reshape_for_frame_attention()
|
781 |
+
|
782 |
+
# Init multi-view alternating-attention transformer with no custom positional encoding and run a forward pass
|
783 |
+
for num_views in [2, 3, 4]:
|
784 |
+
print(f"Testing MultiViewAlternatingAttentionTransformer with {num_views} views ...")
|
785 |
+
# No positional encoding for non-reference views
|
786 |
+
model = MultiViewAlternatingAttentionTransformer(
|
787 |
+
name="MV-AAT",
|
788 |
+
input_embed_dim=1024,
|
789 |
+
)
|
790 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
|
791 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
792 |
+
model_output = model(model_input)
|
793 |
+
assert len(model_output.features) == num_views
|
794 |
+
assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features)
|
795 |
+
# Sequential idx based positional encoding
|
796 |
+
model = MultiViewAlternatingAttentionTransformer(
|
797 |
+
name="MV-AAT",
|
798 |
+
input_embed_dim=1024,
|
799 |
+
use_pe_for_non_reference_views=True,
|
800 |
+
max_num_views_for_pe=1000,
|
801 |
+
use_rand_idx_pe_for_non_reference_views=False,
|
802 |
+
)
|
803 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
|
804 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
805 |
+
model_output = model(model_input)
|
806 |
+
assert len(model_output.features) == num_views
|
807 |
+
assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features)
|
808 |
+
# Random idx based positional encoding
|
809 |
+
model = MultiViewAlternatingAttentionTransformer(
|
810 |
+
name="MV-AAT",
|
811 |
+
input_embed_dim=1024,
|
812 |
+
use_pe_for_non_reference_views=True,
|
813 |
+
max_num_views_for_pe=1000,
|
814 |
+
use_rand_idx_pe_for_non_reference_views=True,
|
815 |
+
)
|
816 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
|
817 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
818 |
+
model_output = model(model_input)
|
819 |
+
assert len(model_output.features) == num_views
|
820 |
+
assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features)
|
821 |
+
|
822 |
+
# Init multi-view alternating-attention transformer with custom positional encoding and run a forward pass
|
823 |
+
for num_views in [2, 3, 4]:
|
824 |
+
print(
|
825 |
+
f"Testing MultiViewAlternatingAttentionTransformer with {num_views} views and custom positional encoding ..."
|
826 |
+
)
|
827 |
+
model = MultiViewAlternatingAttentionTransformer(
|
828 |
+
name="MV-AAT",
|
829 |
+
input_embed_dim=1024,
|
830 |
+
custom_positional_encoding=dummy_positional_encoding,
|
831 |
+
)
|
832 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
|
833 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
834 |
+
model_output = model(model_input)
|
835 |
+
assert len(model_output.features) == num_views
|
836 |
+
assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features)
|
837 |
+
|
838 |
+
print("All multi-view alternating-attention transformers initialized and tested successfully!")
|
839 |
+
|
840 |
+
# Intermediate Feature Returner Tests
|
841 |
+
print("Running Intermediate Feature Returner Tests ...")
|
842 |
+
|
843 |
+
# Run the intermediate feature returner with last-n index
|
844 |
+
model_intermediate_feature_returner = MultiViewAlternatingAttentionTransformerIFR(
|
845 |
+
name="MV-AAT-IFR",
|
846 |
+
input_embed_dim=1024,
|
847 |
+
indices=6, # Last 6 layers
|
848 |
+
)
|
849 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
|
850 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
851 |
+
output = model_intermediate_feature_returner(model_input)
|
852 |
+
assert isinstance(output, tuple)
|
853 |
+
assert isinstance(output[0], MultiViewTransformerOutput)
|
854 |
+
assert len(output[1]) == 6
|
855 |
+
assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1])
|
856 |
+
assert len(output[1][0].features) == 2
|
857 |
+
|
858 |
+
# Run the intermediate feature returner with specific indices
|
859 |
+
model_intermediate_feature_returner = MultiViewAlternatingAttentionTransformerIFR(
|
860 |
+
name="MV-AAT-IFR",
|
861 |
+
input_embed_dim=1024,
|
862 |
+
indices=[0, 2, 4, 6], # Specific indices
|
863 |
+
)
|
864 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
|
865 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
866 |
+
output = model_intermediate_feature_returner(model_input)
|
867 |
+
assert isinstance(output, tuple)
|
868 |
+
assert isinstance(output[0], MultiViewTransformerOutput)
|
869 |
+
assert len(output[1]) == 4
|
870 |
+
assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1])
|
871 |
+
assert len(output[1][0].features) == 2
|
872 |
+
|
873 |
+
# Test the normalizing of intermediate features
|
874 |
+
model_intermediate_feature_returner = MultiViewAlternatingAttentionTransformerIFR(
|
875 |
+
name="MV-AAT-IFR",
|
876 |
+
input_embed_dim=1024,
|
877 |
+
indices=[-1], # Last layer
|
878 |
+
norm_intermediate=False, # Disable normalization
|
879 |
+
)
|
880 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
|
881 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
882 |
+
output = model_intermediate_feature_returner(model_input)
|
883 |
+
for view_idx in range(2):
|
884 |
+
assert not torch.equal(
|
885 |
+
output[0].features[view_idx], output[1][-1].features[view_idx]
|
886 |
+
), "Final features and intermediate features (last layer) must be different."
|
887 |
+
|
888 |
+
model_intermediate_feature_returner = MultiViewAlternatingAttentionTransformerIFR(
|
889 |
+
name="MV-AAT-IFR",
|
890 |
+
input_embed_dim=1024,
|
891 |
+
indices=[-1], # Last layer
|
892 |
+
norm_intermediate=True,
|
893 |
+
)
|
894 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
|
895 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
896 |
+
output = model_intermediate_feature_returner(model_input)
|
897 |
+
for view_idx in range(2):
|
898 |
+
assert torch.equal(
|
899 |
+
output[0].features[view_idx], output[1][-1].features[view_idx]
|
900 |
+
), "Final features and intermediate features (last layer) must be same."
|
901 |
+
|
902 |
+
print("All Intermediate Feature Returner Tests passed!")
|
903 |
+
|
904 |
+
# Test additonal input tokens for MultiViewAlternatingAttentionTransformer
|
905 |
+
print("Testing MultiViewAlternatingAttentionTransformer with additional input tokens ...")
|
906 |
+
model = MultiViewAlternatingAttentionTransformer(
|
907 |
+
name="MV-AAT",
|
908 |
+
input_embed_dim=1024,
|
909 |
+
)
|
910 |
+
num_views = 2
|
911 |
+
num_additional_tokens = 5
|
912 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
|
913 |
+
additional_tokens = torch.rand(1, 1024, num_additional_tokens)
|
914 |
+
model_input = MultiViewTransformerInput(features=model_input, additional_input_tokens=additional_tokens)
|
915 |
+
model_output = model(model_input)
|
916 |
+
assert len(model_output.features) == num_views
|
917 |
+
assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features)
|
918 |
+
assert model_output.additional_token_features is not None
|
919 |
+
assert model_output.additional_token_features.shape == (1, model.dim, num_additional_tokens)
|
920 |
+
|
921 |
+
# Test additonal input tokens for MultiViewAlternatingAttentionTransformerIFR
|
922 |
+
print("Testing MultiViewAlternatingAttentionTransformerIFR with additional input tokens ...")
|
923 |
+
model_ifr = MultiViewAlternatingAttentionTransformerIFR(
|
924 |
+
name="MV-AAT-IFR",
|
925 |
+
input_embed_dim=1024,
|
926 |
+
indices=[0, 2, 4],
|
927 |
+
)
|
928 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
|
929 |
+
additional_tokens = torch.rand(1, 1024, num_additional_tokens)
|
930 |
+
model_input = MultiViewTransformerInput(features=model_input, additional_input_tokens=additional_tokens)
|
931 |
+
output = model_ifr(model_input)
|
932 |
+
assert isinstance(output, tuple)
|
933 |
+
assert isinstance(output[0], MultiViewTransformerOutput)
|
934 |
+
assert output[0].additional_token_features is not None
|
935 |
+
assert output[0].additional_token_features.shape == (1, model_ifr.dim, num_additional_tokens)
|
936 |
+
assert len(output[1]) == 3
|
937 |
+
assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1])
|
938 |
+
assert all(intermediate.additional_token_features is not None for intermediate in output[1])
|
939 |
+
assert all(
|
940 |
+
intermediate.additional_token_features.shape == (1, model_ifr.dim, num_additional_tokens)
|
941 |
+
for intermediate in output[1]
|
942 |
+
)
|
943 |
+
|
944 |
+
print("All tests using additional input tokens passed!")
|
UniCeption/uniception/models/info_sharing/base.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Base Information Sharing Class for UniCeption
|
3 |
+
"""
|
4 |
+
|
5 |
+
from dataclasses import dataclass
|
6 |
+
from typing import List, Optional
|
7 |
+
|
8 |
+
import torch.nn as nn
|
9 |
+
from jaxtyping import Float
|
10 |
+
from torch import Tensor
|
11 |
+
from torch.utils.checkpoint import checkpoint
|
12 |
+
|
13 |
+
|
14 |
+
@dataclass
|
15 |
+
class InfoSharingInput:
|
16 |
+
pass
|
17 |
+
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class InfoSharingOutput:
|
21 |
+
pass
|
22 |
+
|
23 |
+
|
24 |
+
class UniCeptionInfoSharingBase(nn.Module):
|
25 |
+
"Information Sharing Base Class for UniCeption"
|
26 |
+
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
name: str,
|
30 |
+
size: Optional[str] = None,
|
31 |
+
*args,
|
32 |
+
**kwargs,
|
33 |
+
):
|
34 |
+
"""
|
35 |
+
Base class for all models in UniCeption.
|
36 |
+
"""
|
37 |
+
super().__init__(*args, **kwargs)
|
38 |
+
|
39 |
+
self.name: str = name
|
40 |
+
self.size: Optional[str] = size
|
41 |
+
|
42 |
+
def forward(
|
43 |
+
self,
|
44 |
+
model_input: InfoSharingInput,
|
45 |
+
) -> InfoSharingOutput:
|
46 |
+
"""
|
47 |
+
Forward interface for the UniCeption information sharing models.
|
48 |
+
|
49 |
+
Args:
|
50 |
+
model_input (InfoSharingInput): Input to the model.
|
51 |
+
This is also includes the other fields that are required by the specific implementation of the model.
|
52 |
+
|
53 |
+
Returns:
|
54 |
+
InfoSharingOutput: Output of the model.
|
55 |
+
"""
|
56 |
+
|
57 |
+
raise NotImplementedError
|
58 |
+
|
59 |
+
def wrap_module_with_gradient_checkpointing(self, module: nn.Module):
|
60 |
+
"""
|
61 |
+
Wrapper for Gradient Checkpointing
|
62 |
+
"""
|
63 |
+
|
64 |
+
class _CheckpointingWrapper(module.__class__):
|
65 |
+
_restore_cls = module.__class__
|
66 |
+
|
67 |
+
def forward(self, *args, **kwargs):
|
68 |
+
return checkpoint(super().forward, *args, use_reentrant=False, **kwargs)
|
69 |
+
|
70 |
+
module.__class__ = _CheckpointingWrapper
|
71 |
+
return module
|
72 |
+
|
73 |
+
|
74 |
+
@dataclass
|
75 |
+
class MultiViewTransformerInput(InfoSharingInput):
|
76 |
+
"""
|
77 |
+
Input class for Multi-View Transformer.
|
78 |
+
"""
|
79 |
+
|
80 |
+
features: List[Float[Tensor, "batch input_embed_dim feat_height feat_width"]]
|
81 |
+
additional_input_tokens: Optional[Float[Tensor, "batch input_embed_dim num_additional_tokens"]] = None
|
82 |
+
|
83 |
+
|
84 |
+
@dataclass
|
85 |
+
class MultiViewTransformerOutput(InfoSharingOutput):
|
86 |
+
"""
|
87 |
+
Output class for Multi-View Transformer.
|
88 |
+
"""
|
89 |
+
|
90 |
+
features: List[Float[Tensor, "batch transformer_embed_dim feat_height feat_width"]]
|
91 |
+
additional_token_features: Optional[Float[Tensor, "batch transformer_embed_dim num_additional_tokens"]] = None
|
92 |
+
|
93 |
+
|
94 |
+
@dataclass
|
95 |
+
class MultiSetTransformerInput(InfoSharingInput):
|
96 |
+
"""
|
97 |
+
Input class for Multi-Set Transformer.
|
98 |
+
"""
|
99 |
+
|
100 |
+
features: List[Float[Tensor, "batch input_embed_dim num_tokens"]]
|
101 |
+
additional_input_tokens: Optional[Float[Tensor, "batch input_embed_dim num_additional_tokens"]] = None
|
102 |
+
|
103 |
+
|
104 |
+
@dataclass
|
105 |
+
class MultiSetTransformerOutput(InfoSharingOutput):
|
106 |
+
"""
|
107 |
+
Output class for Multi-Set Transformer.
|
108 |
+
"""
|
109 |
+
|
110 |
+
features: List[Float[Tensor, "batch transformer_embed_dim num_tokens"]]
|
111 |
+
additional_token_features: Optional[Float[Tensor, "batch transformer_embed_dim num_additional_tokens"]] = None
|
112 |
+
|
113 |
+
|
114 |
+
if __name__ == "__main__":
|
115 |
+
dummy_model = UniCeptionInfoSharingBase(name="dummy")
|
116 |
+
print("Dummy Base InfoSharing model created successfully!")
|
UniCeption/uniception/models/info_sharing/cross_attention_transformer.py
ADDED
@@ -0,0 +1,582 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
UniCeption Cross-Attention Transformer for Information Sharing
|
3 |
+
"""
|
4 |
+
|
5 |
+
from copy import deepcopy
|
6 |
+
from functools import partial
|
7 |
+
from typing import Callable, List, Optional, Tuple, Type, Union
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
|
12 |
+
from uniception.models.info_sharing.base import (
|
13 |
+
MultiViewTransformerInput,
|
14 |
+
MultiViewTransformerOutput,
|
15 |
+
UniCeptionInfoSharingBase,
|
16 |
+
)
|
17 |
+
from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner, feature_take_indices
|
18 |
+
from uniception.models.utils.positional_encoding import PositionGetter
|
19 |
+
from uniception.models.utils.transformer_blocks import CrossAttentionBlock, Mlp
|
20 |
+
|
21 |
+
|
22 |
+
class MultiViewCrossAttentionTransformer(UniCeptionInfoSharingBase):
|
23 |
+
"UniCeption Multi-View Cross-Attention Transformer for information sharing across image features from different views."
|
24 |
+
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
name: str,
|
28 |
+
input_embed_dim: int,
|
29 |
+
num_views: int,
|
30 |
+
size: Optional[str] = None,
|
31 |
+
depth: int = 12,
|
32 |
+
dim: int = 768,
|
33 |
+
num_heads: int = 12,
|
34 |
+
mlp_ratio: float = 4.0,
|
35 |
+
qkv_bias: bool = True,
|
36 |
+
qk_norm: bool = False,
|
37 |
+
proj_drop: float = 0.0,
|
38 |
+
attn_drop: float = 0.0,
|
39 |
+
init_values: Optional[float] = None,
|
40 |
+
drop_path: float = 0.0,
|
41 |
+
act_layer: Type[nn.Module] = nn.GELU,
|
42 |
+
norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial(nn.LayerNorm, eps=1e-6),
|
43 |
+
mlp_layer: Type[nn.Module] = Mlp,
|
44 |
+
custom_positional_encoding: Optional[Callable] = None,
|
45 |
+
norm_cross_tokens: bool = True,
|
46 |
+
pretrained_checkpoint_path: Optional[str] = None,
|
47 |
+
gradient_checkpointing: bool = False,
|
48 |
+
*args,
|
49 |
+
**kwargs,
|
50 |
+
):
|
51 |
+
"""
|
52 |
+
Initialize the Multi-View Cross-Attention Transformer for information sharing across image features from different views.
|
53 |
+
Creates a cross-attention transformer with multiple branches for each view.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
input_embed_dim (int): Dimension of input embeddings.
|
57 |
+
num_views (int): Number of views (input feature sets).
|
58 |
+
size (str): String to indicate interpretable size of the transformer (for e.g., base, large, ...). (default: None)
|
59 |
+
depth (int): Number of transformer layers. (default: 12, base size)
|
60 |
+
dim (int): Dimension of the transformer. (default: 768, base size)
|
61 |
+
num_heads (int): Number of attention heads. (default: 12, base size)
|
62 |
+
mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.)
|
63 |
+
qkv_bias (bool): Whether to include bias in qkv projection (default: True)
|
64 |
+
qk_norm (bool): Whether to normalize q and k (default: False)
|
65 |
+
proj_drop (float): Dropout rate for output (default: 0.)
|
66 |
+
attn_drop (float): Dropout rate for attention weights (default: 0.)
|
67 |
+
init_values (float): Initial value for LayerScale gamma (default: None)
|
68 |
+
drop_path (float): Dropout rate for stochastic depth (default: 0.)
|
69 |
+
act_layer (nn.Module): Activation layer (default: nn.GELU)
|
70 |
+
norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm)
|
71 |
+
mlp_layer (nn.Module): MLP layer (default: Mlp)
|
72 |
+
custom_positional_encoding (Callable): Custom positional encoding function (default: None)
|
73 |
+
norm_cross_tokens (bool): Whether to normalize cross tokens (default: True)
|
74 |
+
pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None)
|
75 |
+
gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False)
|
76 |
+
"""
|
77 |
+
# Initialize the base class
|
78 |
+
super().__init__(name=name, size=size, *args, **kwargs)
|
79 |
+
|
80 |
+
# Initialize the specific attributes of the transformer
|
81 |
+
self.input_embed_dim = input_embed_dim
|
82 |
+
self.num_views = num_views
|
83 |
+
self.depth = depth
|
84 |
+
self.dim = dim
|
85 |
+
self.num_heads = num_heads
|
86 |
+
self.mlp_ratio = mlp_ratio
|
87 |
+
self.qkv_bias = qkv_bias
|
88 |
+
self.qk_norm = qk_norm
|
89 |
+
self.proj_drop = proj_drop
|
90 |
+
self.attn_drop = attn_drop
|
91 |
+
self.init_values = init_values
|
92 |
+
self.drop_path = drop_path
|
93 |
+
self.act_layer = act_layer
|
94 |
+
self.norm_layer = norm_layer
|
95 |
+
self.mlp_layer = mlp_layer
|
96 |
+
self.custom_positional_encoding = custom_positional_encoding
|
97 |
+
self.norm_cross_tokens = norm_cross_tokens
|
98 |
+
self.pretrained_checkpoint_path = pretrained_checkpoint_path
|
99 |
+
self.gradient_checkpointing = gradient_checkpointing
|
100 |
+
|
101 |
+
# Initialize the projection layer for input embeddings
|
102 |
+
if self.input_embed_dim != self.dim:
|
103 |
+
self.proj_embed = nn.Linear(self.input_embed_dim, self.dim, bias=True)
|
104 |
+
else:
|
105 |
+
self.proj_embed = nn.Identity()
|
106 |
+
|
107 |
+
# Initialize the cross-attention blocks for a single view
|
108 |
+
cross_attention_blocks = nn.ModuleList(
|
109 |
+
[
|
110 |
+
CrossAttentionBlock(
|
111 |
+
dim=self.dim,
|
112 |
+
num_heads=self.num_heads,
|
113 |
+
mlp_ratio=self.mlp_ratio,
|
114 |
+
qkv_bias=self.qkv_bias,
|
115 |
+
qk_norm=self.qk_norm,
|
116 |
+
proj_drop=self.proj_drop,
|
117 |
+
attn_drop=self.attn_drop,
|
118 |
+
init_values=self.init_values,
|
119 |
+
drop_path=self.drop_path,
|
120 |
+
act_layer=self.act_layer,
|
121 |
+
norm_layer=self.norm_layer,
|
122 |
+
mlp_layer=self.mlp_layer,
|
123 |
+
custom_positional_encoding=self.custom_positional_encoding,
|
124 |
+
norm_cross_tokens=self.norm_cross_tokens,
|
125 |
+
)
|
126 |
+
for _ in range(self.depth)
|
127 |
+
]
|
128 |
+
)
|
129 |
+
|
130 |
+
# Copy the cross-attention blocks for all other views
|
131 |
+
self.multi_view_branches = nn.ModuleList([cross_attention_blocks])
|
132 |
+
for _ in range(1, self.num_views):
|
133 |
+
self.multi_view_branches.append(deepcopy(cross_attention_blocks))
|
134 |
+
|
135 |
+
# Initialize the final normalization layer
|
136 |
+
self.norm = self.norm_layer(self.dim)
|
137 |
+
|
138 |
+
# Initialize the position getter for patch positions if required
|
139 |
+
if self.custom_positional_encoding is not None:
|
140 |
+
self.position_getter = PositionGetter()
|
141 |
+
|
142 |
+
# Initialize random weights
|
143 |
+
self.initialize_weights()
|
144 |
+
|
145 |
+
# Apply gradient checkpointing if enabled
|
146 |
+
if self.gradient_checkpointing:
|
147 |
+
for i, block in enumerate(self.cross_attention_blocks):
|
148 |
+
self.cross_attention_blocks[i] = self.wrap_module_with_gradient_checkpointing(block)
|
149 |
+
|
150 |
+
# Load pretrained weights if provided
|
151 |
+
if self.pretrained_checkpoint_path is not None:
|
152 |
+
print(
|
153 |
+
f"Loading pretrained multi-view cross-attention transformer weights from {self.pretrained_checkpoint_path} ..."
|
154 |
+
)
|
155 |
+
ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
|
156 |
+
print(self.load_state_dict(ckpt["model"]))
|
157 |
+
|
158 |
+
def initialize_weights(self):
|
159 |
+
"Initialize weights of the transformer."
|
160 |
+
# Linears and layer norms
|
161 |
+
self.apply(self._init_weights)
|
162 |
+
|
163 |
+
def _init_weights(self, m):
|
164 |
+
"Initialize the transformer linear and layer norm weights."
|
165 |
+
if isinstance(m, nn.Linear):
|
166 |
+
# We use xavier_uniform following official JAX ViT:
|
167 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
168 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
169 |
+
nn.init.constant_(m.bias, 0)
|
170 |
+
elif isinstance(m, nn.LayerNorm):
|
171 |
+
nn.init.constant_(m.bias, 0)
|
172 |
+
nn.init.constant_(m.weight, 1.0)
|
173 |
+
|
174 |
+
def forward(
|
175 |
+
self,
|
176 |
+
model_input: MultiViewTransformerInput,
|
177 |
+
) -> MultiViewTransformerOutput:
|
178 |
+
"""
|
179 |
+
Forward interface for the Multi-View Cross-Attention Transformer.
|
180 |
+
|
181 |
+
Args:
|
182 |
+
model_input (MultiViewTransformerInput): Input to the model.
|
183 |
+
Expects the features to be a list of size (batch, input_embed_dim, height, width),
|
184 |
+
where each entry corresponds to a different view.
|
185 |
+
|
186 |
+
Returns:
|
187 |
+
MultiViewTransformerOutput: Output of the model post information sharing.
|
188 |
+
"""
|
189 |
+
# Check that the number of views matches the input and the features are of expected shape
|
190 |
+
assert (
|
191 |
+
len(model_input.features) == self.num_views
|
192 |
+
), f"Expected {self.num_views} views, got {len(model_input.features)}"
|
193 |
+
assert all(
|
194 |
+
view_features.shape[1] == self.input_embed_dim for view_features in model_input.features
|
195 |
+
), f"All views must have input dimension {self.input_embed_dim}"
|
196 |
+
assert all(
|
197 |
+
view_features.ndim == 4 for view_features in model_input.features
|
198 |
+
), "All views must have 4 dimensions (N, C, H, W)"
|
199 |
+
|
200 |
+
# Initialize the multi-view features from the model input
|
201 |
+
multi_view_features = model_input.features
|
202 |
+
|
203 |
+
# Resize the multi-view features from NCHW to NLC
|
204 |
+
batch_size, _, height, width = multi_view_features[0].shape
|
205 |
+
multi_view_features = [
|
206 |
+
view_features.permute(0, 2, 3, 1).reshape(batch_size, height * width, self.input_embed_dim).contiguous()
|
207 |
+
for view_features in multi_view_features
|
208 |
+
]
|
209 |
+
|
210 |
+
# Create patch positions for each view if custom positional encoding is used
|
211 |
+
if self.custom_positional_encoding is not None:
|
212 |
+
multi_view_positions = [
|
213 |
+
self.position_getter(batch_size, height, width, view_features.device)
|
214 |
+
for view_features in multi_view_features
|
215 |
+
]
|
216 |
+
else:
|
217 |
+
multi_view_positions = [None] * self.num_views
|
218 |
+
|
219 |
+
# Project input features to the transformer dimension
|
220 |
+
multi_view_features = [self.proj_embed(view_features) for view_features in multi_view_features]
|
221 |
+
|
222 |
+
# Pass through each view's cross-attention blocks
|
223 |
+
# Loop over the depth of the transformer
|
224 |
+
for depth_idx in range(self.depth):
|
225 |
+
updated_multi_view_features = []
|
226 |
+
# Loop over each view
|
227 |
+
for view_idx, view_features in enumerate(multi_view_features):
|
228 |
+
# Get all the other views
|
229 |
+
other_views_features = [multi_view_features[i] for i in range(self.num_views) if i != view_idx]
|
230 |
+
# Concatenate all the tokens from the other views
|
231 |
+
other_views_features = torch.cat(other_views_features, dim=1)
|
232 |
+
# Get the positions for the current view
|
233 |
+
view_positions = multi_view_positions[view_idx]
|
234 |
+
# Get the positions for all other views
|
235 |
+
other_views_positions = (
|
236 |
+
torch.cat([multi_view_positions[i] for i in range(self.num_views) if i != view_idx], dim=1)
|
237 |
+
if view_positions is not None
|
238 |
+
else None
|
239 |
+
)
|
240 |
+
# Apply the cross-attention block and update the multi-view features
|
241 |
+
updated_view_features = self.multi_view_branches[view_idx][depth_idx](
|
242 |
+
view_features, other_views_features, view_positions, other_views_positions
|
243 |
+
)
|
244 |
+
# Keep track of the updated view features
|
245 |
+
updated_multi_view_features.append(updated_view_features)
|
246 |
+
# Update the multi-view features for the next depth
|
247 |
+
multi_view_features = updated_multi_view_features
|
248 |
+
|
249 |
+
# Normalize the output features
|
250 |
+
output_multi_view_features = [self.norm(view_features) for view_features in multi_view_features]
|
251 |
+
|
252 |
+
# Resize the output multi-view features back to NCHW
|
253 |
+
output_multi_view_features = [
|
254 |
+
view_features.reshape(batch_size, height, width, self.dim).permute(0, 3, 1, 2).contiguous()
|
255 |
+
for view_features in output_multi_view_features
|
256 |
+
]
|
257 |
+
|
258 |
+
return MultiViewTransformerOutput(features=output_multi_view_features)
|
259 |
+
|
260 |
+
|
261 |
+
class MultiViewCrossAttentionTransformerIFR(MultiViewCrossAttentionTransformer, IntermediateFeatureReturner):
|
262 |
+
"Intermediate Feature Returner for UniCeption Multi-View Cross-Attention Transformer"
|
263 |
+
|
264 |
+
def __init__(
|
265 |
+
self,
|
266 |
+
name: str,
|
267 |
+
input_embed_dim: int,
|
268 |
+
num_views: int,
|
269 |
+
size: Optional[str] = None,
|
270 |
+
depth: int = 12,
|
271 |
+
dim: int = 768,
|
272 |
+
num_heads: int = 12,
|
273 |
+
mlp_ratio: float = 4.0,
|
274 |
+
qkv_bias: bool = True,
|
275 |
+
qk_norm: bool = False,
|
276 |
+
proj_drop: float = 0.0,
|
277 |
+
attn_drop: float = 0.0,
|
278 |
+
init_values: Optional[float] = None,
|
279 |
+
drop_path: float = 0.0,
|
280 |
+
act_layer: nn.Module = nn.GELU,
|
281 |
+
norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
|
282 |
+
mlp_layer: nn.Module = Mlp,
|
283 |
+
custom_positional_encoding: Callable = None,
|
284 |
+
norm_cross_tokens: bool = True,
|
285 |
+
pretrained_checkpoint_path: str = None,
|
286 |
+
indices: Optional[Union[int, List[int]]] = None,
|
287 |
+
norm_intermediate: bool = True,
|
288 |
+
intermediates_only: bool = False,
|
289 |
+
gradient_checkpointing: bool = False,
|
290 |
+
*args,
|
291 |
+
**kwargs,
|
292 |
+
):
|
293 |
+
"""
|
294 |
+
Initialize the Multi-View Cross-Attention Transformer for information sharing across image features from different views.
|
295 |
+
Creates a cross-attention transformer with multiple branches for each view.
|
296 |
+
Extends the base class to return intermediate features.
|
297 |
+
|
298 |
+
Args:
|
299 |
+
input_embed_dim (int): Dimension of input embeddings.
|
300 |
+
num_views (int): Number of views (input feature sets).
|
301 |
+
size (str): String to indicate interpretable size of the transformer (for e.g., base, large, ...). (default: None)
|
302 |
+
depth (int): Number of transformer layers. (default: 12, base size)
|
303 |
+
dim (int): Dimension of the transformer. (default: 768, base size)
|
304 |
+
num_heads (int): Number of attention heads. (default: 12, base size)
|
305 |
+
mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.)
|
306 |
+
qkv_bias (bool): Whether to include bias in qkv projection (default: True)
|
307 |
+
qk_norm (bool): Whether to normalize q and k (default: False)
|
308 |
+
proj_drop (float): Dropout rate for output (default: 0.)
|
309 |
+
attn_drop (float): Dropout rate for attention weights (default: 0.)
|
310 |
+
init_values (float): Initial value for LayerScale gamma (default: None)
|
311 |
+
drop_path (float): Dropout rate for stochastic depth (default: 0.)
|
312 |
+
act_layer (nn.Module): Activation layer (default: nn.GELU)
|
313 |
+
norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm)
|
314 |
+
mlp_layer (nn.Module): MLP layer (default: Mlp)
|
315 |
+
custom_positional_encoding (Callable): Custom positional encoding function (default: None)
|
316 |
+
norm_cross_tokens (bool): Whether to normalize cross tokens (default: True)
|
317 |
+
pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None)
|
318 |
+
indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. (default: None) Options:
|
319 |
+
- None: Return all intermediate layers.
|
320 |
+
- int: Return the last n layers.
|
321 |
+
- List[int]: Return the intermediate layers at the specified indices.
|
322 |
+
norm_intermediate (bool, optional): Whether to normalize the intermediate features. (default: True)
|
323 |
+
intermediates_only (bool, optional): Whether to return only the intermediate features. (default: False)
|
324 |
+
gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False)
|
325 |
+
"""
|
326 |
+
# Init the base classes
|
327 |
+
MultiViewCrossAttentionTransformer.__init__(
|
328 |
+
self,
|
329 |
+
name=name,
|
330 |
+
input_embed_dim=input_embed_dim,
|
331 |
+
num_views=num_views,
|
332 |
+
size=size,
|
333 |
+
depth=depth,
|
334 |
+
dim=dim,
|
335 |
+
num_heads=num_heads,
|
336 |
+
mlp_ratio=mlp_ratio,
|
337 |
+
qkv_bias=qkv_bias,
|
338 |
+
qk_norm=qk_norm,
|
339 |
+
proj_drop=proj_drop,
|
340 |
+
attn_drop=attn_drop,
|
341 |
+
init_values=init_values,
|
342 |
+
drop_path=drop_path,
|
343 |
+
act_layer=act_layer,
|
344 |
+
norm_layer=norm_layer,
|
345 |
+
mlp_layer=mlp_layer,
|
346 |
+
custom_positional_encoding=custom_positional_encoding,
|
347 |
+
norm_cross_tokens=norm_cross_tokens,
|
348 |
+
pretrained_checkpoint_path=pretrained_checkpoint_path,
|
349 |
+
gradient_checkpointing=gradient_checkpointing,
|
350 |
+
*args,
|
351 |
+
**kwargs,
|
352 |
+
)
|
353 |
+
IntermediateFeatureReturner.__init__(
|
354 |
+
self,
|
355 |
+
indices=indices,
|
356 |
+
norm_intermediate=norm_intermediate,
|
357 |
+
intermediates_only=intermediates_only,
|
358 |
+
)
|
359 |
+
|
360 |
+
def forward(
|
361 |
+
self,
|
362 |
+
model_input: MultiViewTransformerInput,
|
363 |
+
) -> Union[
|
364 |
+
List[MultiViewTransformerOutput],
|
365 |
+
Tuple[MultiViewTransformerOutput, List[MultiViewTransformerOutput]],
|
366 |
+
]:
|
367 |
+
"""
|
368 |
+
Forward interface for the Multi-View Cross-Attention Transformer with Intermediate Feature Return.
|
369 |
+
|
370 |
+
Args:
|
371 |
+
model_input (MultiViewTransformerInput): Input to the model.
|
372 |
+
Expects the features to be a list of size (batch, input_embed_dim, height, width),
|
373 |
+
where each entry corresponds to a different view.
|
374 |
+
|
375 |
+
Returns:
|
376 |
+
Union[List[MultiViewTransformerOutput], Tuple[MultiViewTransformerOutput, List[MultiViewTransformerOutput]]]:
|
377 |
+
Output of the model post information sharing.
|
378 |
+
If intermediates_only is True, returns a list of intermediate outputs.
|
379 |
+
If intermediates_only is False, returns a tuple of final output and a list of intermediate outputs.
|
380 |
+
"""
|
381 |
+
# Check that the number of views matches the input and the features are of expected shape
|
382 |
+
assert (
|
383 |
+
len(model_input.features) == self.num_views
|
384 |
+
), f"Expected {self.num_views} views, got {len(model_input.features)}"
|
385 |
+
assert all(
|
386 |
+
view_features.shape[1] == self.input_embed_dim for view_features in model_input.features
|
387 |
+
), f"All views must have input dimension {self.input_embed_dim}"
|
388 |
+
assert all(
|
389 |
+
view_features.ndim == 4 for view_features in model_input.features
|
390 |
+
), "All views must have 4 dimensions (N, C, H, W)"
|
391 |
+
|
392 |
+
# Get the indices of the intermediate features to return
|
393 |
+
intermediate_multi_view_features = []
|
394 |
+
take_indices, _ = feature_take_indices(self.depth, self.indices)
|
395 |
+
|
396 |
+
# Initialize the multi-view features from the model input
|
397 |
+
multi_view_features = model_input.features
|
398 |
+
|
399 |
+
# Resize the multi-view features from NCHW to NLC
|
400 |
+
batch_size, _, height, width = multi_view_features[0].shape
|
401 |
+
multi_view_features = [
|
402 |
+
view_features.permute(0, 2, 3, 1).reshape(batch_size, height * width, self.input_embed_dim).contiguous()
|
403 |
+
for view_features in multi_view_features
|
404 |
+
]
|
405 |
+
|
406 |
+
# Create patch positions for each view if custom positional encoding is used
|
407 |
+
if self.custom_positional_encoding is not None:
|
408 |
+
multi_view_positions = [
|
409 |
+
self.position_getter(batch_size, height, width, view_features.device)
|
410 |
+
for view_features in multi_view_features
|
411 |
+
]
|
412 |
+
else:
|
413 |
+
multi_view_positions = [None] * self.num_views
|
414 |
+
|
415 |
+
# Project input features to the transformer dimension
|
416 |
+
multi_view_features = [self.proj_embed(view_features) for view_features in multi_view_features]
|
417 |
+
|
418 |
+
# Pass through each view's cross-attention blocks
|
419 |
+
# Loop over the depth of the transformer
|
420 |
+
for depth_idx in range(self.depth):
|
421 |
+
updated_multi_view_features = []
|
422 |
+
# Loop over each view
|
423 |
+
for view_idx, view_features in enumerate(multi_view_features):
|
424 |
+
# Get all the other views
|
425 |
+
other_views_features = [multi_view_features[i] for i in range(self.num_views) if i != view_idx]
|
426 |
+
# Concatenate all the tokens from the other views
|
427 |
+
other_views_features = torch.cat(other_views_features, dim=1)
|
428 |
+
# Get the positions for the current view
|
429 |
+
view_positions = multi_view_positions[view_idx]
|
430 |
+
# Get the positions for all other views
|
431 |
+
other_views_positions = (
|
432 |
+
torch.cat([multi_view_positions[i] for i in range(self.num_views) if i != view_idx], dim=1)
|
433 |
+
if view_positions is not None
|
434 |
+
else None
|
435 |
+
)
|
436 |
+
# Apply the cross-attention block and update the multi-view features
|
437 |
+
updated_view_features = self.multi_view_branches[view_idx][depth_idx](
|
438 |
+
view_features, other_views_features, view_positions, other_views_positions
|
439 |
+
)
|
440 |
+
# Keep track of the updated view features
|
441 |
+
updated_multi_view_features.append(updated_view_features)
|
442 |
+
# Update the multi-view features for the next depth
|
443 |
+
multi_view_features = updated_multi_view_features
|
444 |
+
# Append the intermediate features if required
|
445 |
+
if depth_idx in take_indices:
|
446 |
+
# Normalize the intermediate features with final norm layer if enabled
|
447 |
+
intermediate_multi_view_features.append(
|
448 |
+
[self.norm(view_features) for view_features in multi_view_features]
|
449 |
+
if self.norm_intermediate
|
450 |
+
else multi_view_features
|
451 |
+
)
|
452 |
+
|
453 |
+
# Reshape the intermediate features and convert to MultiViewTransformerOutput class
|
454 |
+
for idx in range(len(intermediate_multi_view_features)):
|
455 |
+
intermediate_multi_view_features[idx] = [
|
456 |
+
view_features.reshape(batch_size, height, width, self.dim).permute(0, 3, 1, 2).contiguous()
|
457 |
+
for view_features in intermediate_multi_view_features[idx]
|
458 |
+
]
|
459 |
+
intermediate_multi_view_features[idx] = MultiViewTransformerOutput(
|
460 |
+
features=intermediate_multi_view_features[idx]
|
461 |
+
)
|
462 |
+
|
463 |
+
# Return only the intermediate features if enabled
|
464 |
+
if self.intermediates_only:
|
465 |
+
return intermediate_multi_view_features
|
466 |
+
|
467 |
+
# Normalize the output features
|
468 |
+
output_multi_view_features = [self.norm(view_features) for view_features in multi_view_features]
|
469 |
+
|
470 |
+
# Resize the output multi-view features back to NCHW
|
471 |
+
output_multi_view_features = [
|
472 |
+
view_features.reshape(batch_size, height, width, self.dim).permute(0, 3, 1, 2).contiguous()
|
473 |
+
for view_features in output_multi_view_features
|
474 |
+
]
|
475 |
+
|
476 |
+
output_multi_view_features = MultiViewTransformerOutput(features=output_multi_view_features)
|
477 |
+
|
478 |
+
return output_multi_view_features, intermediate_multi_view_features
|
479 |
+
|
480 |
+
|
481 |
+
def dummy_positional_encoding(x, xpos):
|
482 |
+
"Dummy function for positional encoding of tokens"
|
483 |
+
x = x
|
484 |
+
xpos = xpos
|
485 |
+
return x
|
486 |
+
|
487 |
+
|
488 |
+
if __name__ == "__main__":
|
489 |
+
# Init multi-view cross-attention transformer with no custom positional encoding and run a forward pass
|
490 |
+
for num_views in [2, 3, 4]:
|
491 |
+
print(f"Testing MultiViewCrossAttentionTransformer with {num_views} views ...")
|
492 |
+
model = MultiViewCrossAttentionTransformer(name="MV-CAT", input_embed_dim=1024, num_views=num_views)
|
493 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
|
494 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
495 |
+
model_output = model(model_input)
|
496 |
+
assert len(model_output.features) == num_views
|
497 |
+
assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features)
|
498 |
+
|
499 |
+
# Init multi-view cross-attention transformer with custom positional encoding and run a forward pass
|
500 |
+
for num_views in [2, 3, 4]:
|
501 |
+
print(f"Testing MultiViewCrossAttentionTransformer with {num_views} views and custom positional encoding ...")
|
502 |
+
model = MultiViewCrossAttentionTransformer(
|
503 |
+
name="MV-CAT",
|
504 |
+
input_embed_dim=1024,
|
505 |
+
num_views=num_views,
|
506 |
+
custom_positional_encoding=dummy_positional_encoding,
|
507 |
+
)
|
508 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
|
509 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
510 |
+
model_output = model(model_input)
|
511 |
+
assert len(model_output.features) == num_views
|
512 |
+
assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features)
|
513 |
+
|
514 |
+
print("All multi-view cross-attention transformers initialized and tested successfully!")
|
515 |
+
|
516 |
+
# Intermediate Feature Returner Tests
|
517 |
+
print("Running Intermediate Feature Returner Tests ...")
|
518 |
+
|
519 |
+
# Run the intermediate feature returner with last-n index
|
520 |
+
model_intermediate_feature_returner = MultiViewCrossAttentionTransformerIFR(
|
521 |
+
name="MV-CAT-IFR",
|
522 |
+
input_embed_dim=1024,
|
523 |
+
num_views=2,
|
524 |
+
indices=6, # Last 6 layers
|
525 |
+
)
|
526 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
|
527 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
528 |
+
output = model_intermediate_feature_returner(model_input)
|
529 |
+
assert isinstance(output, tuple)
|
530 |
+
assert isinstance(output[0], MultiViewTransformerOutput)
|
531 |
+
assert len(output[1]) == 6
|
532 |
+
assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1])
|
533 |
+
assert len(output[1][0].features) == 2
|
534 |
+
|
535 |
+
# Run the intermediate feature returner with specific indices
|
536 |
+
model_intermediate_feature_returner = MultiViewCrossAttentionTransformerIFR(
|
537 |
+
name="MV-CAT-IFR",
|
538 |
+
input_embed_dim=1024,
|
539 |
+
num_views=2,
|
540 |
+
indices=[0, 2, 4, 6], # Specific indices
|
541 |
+
)
|
542 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
|
543 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
544 |
+
output = model_intermediate_feature_returner(model_input)
|
545 |
+
assert isinstance(output, tuple)
|
546 |
+
assert isinstance(output[0], MultiViewTransformerOutput)
|
547 |
+
assert len(output[1]) == 4
|
548 |
+
assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1])
|
549 |
+
assert len(output[1][0].features) == 2
|
550 |
+
|
551 |
+
# Test the normalizing of intermediate features
|
552 |
+
model_intermediate_feature_returner = MultiViewCrossAttentionTransformerIFR(
|
553 |
+
name="MV-CAT-IFR",
|
554 |
+
input_embed_dim=1024,
|
555 |
+
num_views=2,
|
556 |
+
indices=[-1], # Last layer
|
557 |
+
norm_intermediate=False, # Disable normalization
|
558 |
+
)
|
559 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
|
560 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
561 |
+
output = model_intermediate_feature_returner(model_input)
|
562 |
+
for view_idx in range(2):
|
563 |
+
assert not torch.equal(
|
564 |
+
output[0].features[view_idx], output[1][-1].features[view_idx]
|
565 |
+
), "Final features and intermediate features (last layer) must be different."
|
566 |
+
|
567 |
+
model_intermediate_feature_returner = MultiViewCrossAttentionTransformerIFR(
|
568 |
+
name="MV-CAT-IFR",
|
569 |
+
input_embed_dim=1024,
|
570 |
+
num_views=2,
|
571 |
+
indices=[-1], # Last layer
|
572 |
+
norm_intermediate=True,
|
573 |
+
)
|
574 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
|
575 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
576 |
+
output = model_intermediate_feature_returner(model_input)
|
577 |
+
for view_idx in range(2):
|
578 |
+
assert torch.equal(
|
579 |
+
output[0].features[view_idx], output[1][-1].features[view_idx]
|
580 |
+
), "Final features and intermediate features (last layer) must be same."
|
581 |
+
|
582 |
+
print("All Intermediate Feature Returner Tests passed!")
|
UniCeption/uniception/models/info_sharing/diff_cross_attention_transformer.py
ADDED
@@ -0,0 +1,588 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
UniCeption Cross-Attention Transformer for Information Sharing
|
3 |
+
"""
|
4 |
+
|
5 |
+
from copy import deepcopy
|
6 |
+
from functools import partial
|
7 |
+
from typing import Callable, List, Optional, Tuple, Type, Union
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
|
12 |
+
from uniception.models.info_sharing.base import UniCeptionInfoSharingBase
|
13 |
+
from uniception.models.info_sharing.cross_attention_transformer import (
|
14 |
+
MultiViewTransformerInput,
|
15 |
+
MultiViewTransformerOutput,
|
16 |
+
PositionGetter,
|
17 |
+
)
|
18 |
+
from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner, feature_take_indices
|
19 |
+
from uniception.models.utils.transformer_blocks import DiffCrossAttentionBlock, Mlp
|
20 |
+
|
21 |
+
|
22 |
+
class DifferentialMultiViewCrossAttentionTransformer(UniCeptionInfoSharingBase):
|
23 |
+
"UniCeption Multi-View Cross-Attention Transformer for information sharing across image features from different views."
|
24 |
+
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
name: str,
|
28 |
+
input_embed_dim: int,
|
29 |
+
num_views: int,
|
30 |
+
size: Optional[str] = None,
|
31 |
+
depth: int = 12,
|
32 |
+
dim: int = 768,
|
33 |
+
num_heads: int = 12,
|
34 |
+
mlp_ratio: float = 4.0,
|
35 |
+
qkv_bias: bool = True,
|
36 |
+
qk_norm: bool = False,
|
37 |
+
proj_drop: float = 0.0,
|
38 |
+
attn_drop: float = 0.0,
|
39 |
+
init_values: Optional[float] = None,
|
40 |
+
drop_path: float = 0.0,
|
41 |
+
act_layer: Type[nn.Module] = nn.GELU,
|
42 |
+
norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial(nn.LayerNorm, eps=1e-6),
|
43 |
+
mlp_layer: Type[nn.Module] = Mlp,
|
44 |
+
custom_positional_encoding: Optional[Callable] = None,
|
45 |
+
norm_cross_tokens: bool = True,
|
46 |
+
pretrained_checkpoint_path: Optional[str] = None,
|
47 |
+
gradient_checkpointing: bool = False,
|
48 |
+
*args,
|
49 |
+
**kwargs,
|
50 |
+
):
|
51 |
+
"""
|
52 |
+
Initialize the Multi-View Cross-Attention Transformer for information sharing across image features from different views.
|
53 |
+
Creates a cross-attention transformer with multiple branches for each view.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
input_embed_dim (int): Dimension of input embeddings.
|
57 |
+
num_views (int): Number of views (input feature sets).
|
58 |
+
depth (int): Number of transformer layers. (default: 12, base size)
|
59 |
+
dim (int): Dimension of the transformer. (default: 768, base size)
|
60 |
+
num_heads (int): Number of attention heads. (default: 12, base size)
|
61 |
+
mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.)
|
62 |
+
qkv_bias (bool): Whether to include bias in qkv projection (default: False)
|
63 |
+
qk_norm (bool): Whether to normalize q and k (default: False)
|
64 |
+
proj_drop (float): Dropout rate for output (default: 0.)
|
65 |
+
attn_drop (float): Dropout rate for attention weights (default: 0.)
|
66 |
+
init_values (float): Initial value for LayerScale gamma (default: None)
|
67 |
+
drop_path (float): Dropout rate for stochastic depth (default: 0.)
|
68 |
+
act_layer (nn.Module): Activation layer (default: nn.GELU)
|
69 |
+
norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm)
|
70 |
+
mlp_layer (nn.Module): MLP layer (default: Mlp)
|
71 |
+
custom_positional_encoding (Callable): Custom positional encoding function (default: None)
|
72 |
+
norm_cross_tokens (bool): Whether to normalize cross tokens (default: True)
|
73 |
+
pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None)
|
74 |
+
gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False)
|
75 |
+
"""
|
76 |
+
# Initialize the base class
|
77 |
+
super().__init__(name=name, size=size, *args, **kwargs)
|
78 |
+
|
79 |
+
# Initialize the specific attributes of the transformer
|
80 |
+
self.input_embed_dim = input_embed_dim
|
81 |
+
self.num_views = num_views
|
82 |
+
self.depth = depth
|
83 |
+
self.dim = dim
|
84 |
+
self.num_heads = num_heads
|
85 |
+
self.mlp_ratio = mlp_ratio
|
86 |
+
self.qkv_bias = qkv_bias
|
87 |
+
self.qk_norm = qk_norm
|
88 |
+
self.proj_drop = proj_drop
|
89 |
+
self.attn_drop = attn_drop
|
90 |
+
self.init_values = init_values
|
91 |
+
self.drop_path = drop_path
|
92 |
+
self.act_layer = act_layer
|
93 |
+
self.norm_layer = norm_layer
|
94 |
+
self.mlp_layer = mlp_layer
|
95 |
+
self.custom_positional_encoding = custom_positional_encoding
|
96 |
+
self.norm_cross_tokens = norm_cross_tokens
|
97 |
+
self.pretrained_checkpoint_path = pretrained_checkpoint_path
|
98 |
+
self.gradient_checkpointing = gradient_checkpointing
|
99 |
+
|
100 |
+
# Initialize the projection layer for input embeddings
|
101 |
+
if self.input_embed_dim != self.dim:
|
102 |
+
self.proj_embed = nn.Linear(self.input_embed_dim, self.dim, bias=True)
|
103 |
+
else:
|
104 |
+
self.proj_embed = nn.Identity()
|
105 |
+
|
106 |
+
# Initialize the cross-attention blocks for a single view
|
107 |
+
assert num_heads % 2 == 0, "Number of heads must be divisible by 2 for differential cross-attention."
|
108 |
+
cross_attention_blocks = nn.ModuleList(
|
109 |
+
[
|
110 |
+
DiffCrossAttentionBlock(
|
111 |
+
depth=i,
|
112 |
+
dim=self.dim,
|
113 |
+
num_heads=self.num_heads // 2,
|
114 |
+
mlp_ratio=self.mlp_ratio,
|
115 |
+
qkv_bias=self.qkv_bias,
|
116 |
+
qk_norm=self.qk_norm,
|
117 |
+
proj_drop=self.proj_drop,
|
118 |
+
attn_drop=self.attn_drop,
|
119 |
+
init_values=self.init_values,
|
120 |
+
drop_path=self.drop_path,
|
121 |
+
act_layer=self.act_layer,
|
122 |
+
norm_layer=self.norm_layer,
|
123 |
+
mlp_layer=self.mlp_layer,
|
124 |
+
custom_positional_encoding=self.custom_positional_encoding,
|
125 |
+
norm_cross_tokens=self.norm_cross_tokens,
|
126 |
+
)
|
127 |
+
for i in range(self.depth)
|
128 |
+
]
|
129 |
+
)
|
130 |
+
|
131 |
+
# Copy the cross-attention blocks for all other views
|
132 |
+
self.multi_view_branches = nn.ModuleList([cross_attention_blocks])
|
133 |
+
for _ in range(1, self.num_views):
|
134 |
+
self.multi_view_branches.append(deepcopy(cross_attention_blocks))
|
135 |
+
|
136 |
+
# Initialize the final normalization layer
|
137 |
+
self.norm = self.norm_layer(self.dim)
|
138 |
+
|
139 |
+
# Initialize the position getter for patch positions if required
|
140 |
+
if self.custom_positional_encoding is not None:
|
141 |
+
self.position_getter = PositionGetter()
|
142 |
+
|
143 |
+
# Initialize random weights
|
144 |
+
self.initialize_weights()
|
145 |
+
|
146 |
+
# Apply gradient checkpointing if enabled
|
147 |
+
if self.gradient_checkpointing:
|
148 |
+
for i, block in enumerate(self.cross_attention_blocks):
|
149 |
+
self.cross_attention_blocks[i] = self.wrap_module_with_gradient_checkpointing(block)
|
150 |
+
|
151 |
+
# Load pretrained weights if provided
|
152 |
+
if self.pretrained_checkpoint_path is not None:
|
153 |
+
print(
|
154 |
+
f"Loading pretrained multi-view cross-attention transformer weights from {self.pretrained_checkpoint_path} ..."
|
155 |
+
)
|
156 |
+
ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
|
157 |
+
print(self.load_state_dict(ckpt["model"]))
|
158 |
+
|
159 |
+
def initialize_weights(self):
|
160 |
+
"Initialize weights of the transformer."
|
161 |
+
# Linears and layer norms
|
162 |
+
self.apply(self._init_weights)
|
163 |
+
|
164 |
+
def _init_weights(self, m):
|
165 |
+
"Initialize the transformer linear and layer norm weights."
|
166 |
+
if isinstance(m, nn.Linear):
|
167 |
+
# We use xavier_uniform following official JAX ViT:
|
168 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
169 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
170 |
+
nn.init.constant_(m.bias, 0)
|
171 |
+
elif isinstance(m, nn.LayerNorm):
|
172 |
+
nn.init.constant_(m.bias, 0)
|
173 |
+
nn.init.constant_(m.weight, 1.0)
|
174 |
+
|
175 |
+
def forward(
|
176 |
+
self,
|
177 |
+
model_input: MultiViewTransformerInput,
|
178 |
+
) -> MultiViewTransformerOutput:
|
179 |
+
"""
|
180 |
+
Forward interface for the Multi-View Cross-Attention Transformer.
|
181 |
+
|
182 |
+
Args:
|
183 |
+
model_input (MultiViewTransformerInput): Input to the model.
|
184 |
+
Expects the features to be a list of size (batch, input_embed_dim, height, width),
|
185 |
+
where each entry corresponds to a different view.
|
186 |
+
|
187 |
+
Returns:
|
188 |
+
MultiViewTransformerOutput: Output of the model post information sharing.
|
189 |
+
"""
|
190 |
+
# Check that the number of views matches the input and the features are of expected shape
|
191 |
+
assert (
|
192 |
+
len(model_input.features) == self.num_views
|
193 |
+
), f"Expected {self.num_views} views, got {len(model_input.features)}"
|
194 |
+
assert all(
|
195 |
+
view_features.shape[1] == self.input_embed_dim for view_features in model_input.features
|
196 |
+
), f"All views must have input dimension {self.input_embed_dim}"
|
197 |
+
assert all(
|
198 |
+
view_features.ndim == 4 for view_features in model_input.features
|
199 |
+
), "All views must have 4 dimensions (N, C, H, W)"
|
200 |
+
|
201 |
+
# Initialize the multi-view features from the model input
|
202 |
+
multi_view_features = model_input.features
|
203 |
+
|
204 |
+
# Resize the multi-view features from NCHW to NLC
|
205 |
+
batch_size, _, height, width = multi_view_features[0].shape
|
206 |
+
multi_view_features = [
|
207 |
+
view_features.permute(0, 2, 3, 1).reshape(batch_size, height * width, self.input_embed_dim).contiguous()
|
208 |
+
for view_features in multi_view_features
|
209 |
+
]
|
210 |
+
|
211 |
+
# Create patch positions for each view if custom positional encoding is used
|
212 |
+
if self.custom_positional_encoding is not None:
|
213 |
+
multi_view_positions = [
|
214 |
+
self.position_getter(batch_size, height, width, view_features.device)
|
215 |
+
for view_features in multi_view_features
|
216 |
+
]
|
217 |
+
else:
|
218 |
+
multi_view_positions = [None] * self.num_views
|
219 |
+
|
220 |
+
# Project input features to the transformer dimension
|
221 |
+
multi_view_features = [self.proj_embed(view_features) for view_features in multi_view_features]
|
222 |
+
|
223 |
+
# Pass through each view's cross-attention blocks
|
224 |
+
# Loop over the depth of the transformer
|
225 |
+
for depth_idx in range(self.depth):
|
226 |
+
updated_multi_view_features = []
|
227 |
+
# Loop over each view
|
228 |
+
for view_idx, view_features in enumerate(multi_view_features):
|
229 |
+
# Get all the other views
|
230 |
+
other_views_features = [multi_view_features[i] for i in range(self.num_views) if i != view_idx]
|
231 |
+
# Concatenate all the tokens from the other views
|
232 |
+
other_views_features = torch.cat(other_views_features, dim=1)
|
233 |
+
# Get the positions for the current view
|
234 |
+
view_positions = multi_view_positions[view_idx]
|
235 |
+
# Get the positions for all other views
|
236 |
+
other_views_positions = (
|
237 |
+
torch.cat([multi_view_positions[i] for i in range(self.num_views) if i != view_idx], dim=1)
|
238 |
+
if view_positions is not None
|
239 |
+
else None
|
240 |
+
)
|
241 |
+
# Apply the cross-attention block and update the multi-view features
|
242 |
+
updated_view_features = self.multi_view_branches[view_idx][depth_idx](
|
243 |
+
view_features, other_views_features, view_positions, other_views_positions
|
244 |
+
)
|
245 |
+
# Keep track of the updated view features
|
246 |
+
updated_multi_view_features.append(updated_view_features)
|
247 |
+
# Update the multi-view features for the next depth
|
248 |
+
multi_view_features = updated_multi_view_features
|
249 |
+
|
250 |
+
# Normalize the output features
|
251 |
+
output_multi_view_features = [self.norm(view_features) for view_features in multi_view_features]
|
252 |
+
|
253 |
+
# Resize the output multi-view features back to NCHW
|
254 |
+
output_multi_view_features = [
|
255 |
+
view_features.reshape(batch_size, height, width, self.dim).permute(0, 3, 1, 2).contiguous()
|
256 |
+
for view_features in output_multi_view_features
|
257 |
+
]
|
258 |
+
|
259 |
+
return MultiViewTransformerOutput(features=output_multi_view_features)
|
260 |
+
|
261 |
+
|
262 |
+
class DifferentialMultiViewCrossAttentionTransformerIFR(
|
263 |
+
DifferentialMultiViewCrossAttentionTransformer, IntermediateFeatureReturner
|
264 |
+
):
|
265 |
+
"Intermediate Feature Returner for UniCeption Multi-View Cross-Attention Transformer"
|
266 |
+
|
267 |
+
def __init__(
|
268 |
+
self,
|
269 |
+
name: str,
|
270 |
+
input_embed_dim: int,
|
271 |
+
num_views: int,
|
272 |
+
size: Optional[str] = None,
|
273 |
+
depth: int = 12,
|
274 |
+
dim: int = 768,
|
275 |
+
num_heads: int = 12,
|
276 |
+
mlp_ratio: float = 4.0,
|
277 |
+
qkv_bias: bool = True,
|
278 |
+
qk_norm: bool = False,
|
279 |
+
proj_drop: float = 0.0,
|
280 |
+
attn_drop: float = 0.0,
|
281 |
+
init_values: Optional[float] = None,
|
282 |
+
drop_path: float = 0.0,
|
283 |
+
act_layer: nn.Module = nn.GELU,
|
284 |
+
norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
|
285 |
+
mlp_layer: nn.Module = Mlp,
|
286 |
+
custom_positional_encoding: Callable = None,
|
287 |
+
norm_cross_tokens: bool = True,
|
288 |
+
pretrained_checkpoint_path: str = None,
|
289 |
+
indices: Optional[Union[int, List[int]]] = None,
|
290 |
+
norm_intermediate: bool = True,
|
291 |
+
intermediates_only: bool = False,
|
292 |
+
gradient_checkpointing: bool = False,
|
293 |
+
*args,
|
294 |
+
**kwargs,
|
295 |
+
):
|
296 |
+
"""
|
297 |
+
Initialize the Multi-View Cross-Attention Transformer for information sharing across image features from different views.
|
298 |
+
Creates a cross-attention transformer with multiple branches for each view.
|
299 |
+
Extends the base class to return intermediate features.
|
300 |
+
|
301 |
+
Args:
|
302 |
+
input_embed_dim (int): Dimension of input embeddings.
|
303 |
+
num_views (int): Number of views (input feature sets).
|
304 |
+
depth (int): Number of transformer layers. (default: 12, base size)
|
305 |
+
dim (int): Dimension of the transformer. (default: 768, base size)
|
306 |
+
num_heads (int): Number of attention heads. (default: 12, base size)
|
307 |
+
mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.)
|
308 |
+
qkv_bias (bool): Whether to include bias in qkv projection (default: False)
|
309 |
+
qk_norm (bool): Whether to normalize q and k (default: False)
|
310 |
+
proj_drop (float): Dropout rate for output (default: 0.)
|
311 |
+
attn_drop (float): Dropout rate for attention weights (default: 0.)
|
312 |
+
init_values (float): Initial value for LayerScale gamma (default: None)
|
313 |
+
drop_path (float): Dropout rate for stochastic depth (default: 0.)
|
314 |
+
act_layer (nn.Module): Activation layer (default: nn.GELU)
|
315 |
+
norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm)
|
316 |
+
mlp_layer (nn.Module): MLP layer (default: Mlp)
|
317 |
+
custom_positional_encoding (Callable): Custom positional encoding function (default: None)
|
318 |
+
norm_cross_tokens (bool): Whether to normalize cross tokens (default: True)
|
319 |
+
pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None)
|
320 |
+
indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. (default: None) Options:
|
321 |
+
- None: Return all intermediate layers.
|
322 |
+
- int: Return the last n layers.
|
323 |
+
- List[int]: Return the intermediate layers at the specified indices.
|
324 |
+
norm_intermediate (bool, optional): Whether to normalize the intermediate features. (default: True)
|
325 |
+
intermediates_only (bool, optional): Whether to return only the intermediate features. (default: False)
|
326 |
+
gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False)
|
327 |
+
"""
|
328 |
+
# Init the base classes
|
329 |
+
DifferentialMultiViewCrossAttentionTransformer.__init__(
|
330 |
+
self,
|
331 |
+
name=name,
|
332 |
+
input_embed_dim=input_embed_dim,
|
333 |
+
num_views=num_views,
|
334 |
+
size=size,
|
335 |
+
depth=depth,
|
336 |
+
dim=dim,
|
337 |
+
num_heads=num_heads,
|
338 |
+
mlp_ratio=mlp_ratio,
|
339 |
+
qkv_bias=qkv_bias,
|
340 |
+
qk_norm=qk_norm,
|
341 |
+
proj_drop=proj_drop,
|
342 |
+
attn_drop=attn_drop,
|
343 |
+
init_values=init_values,
|
344 |
+
drop_path=drop_path,
|
345 |
+
act_layer=act_layer,
|
346 |
+
norm_layer=norm_layer,
|
347 |
+
mlp_layer=mlp_layer,
|
348 |
+
custom_positional_encoding=custom_positional_encoding,
|
349 |
+
norm_cross_tokens=norm_cross_tokens,
|
350 |
+
pretrained_checkpoint_path=pretrained_checkpoint_path,
|
351 |
+
gradient_checkpointing=gradient_checkpointing,
|
352 |
+
*args,
|
353 |
+
**kwargs,
|
354 |
+
)
|
355 |
+
IntermediateFeatureReturner.__init__(
|
356 |
+
self,
|
357 |
+
indices=indices,
|
358 |
+
norm_intermediate=norm_intermediate,
|
359 |
+
intermediates_only=intermediates_only,
|
360 |
+
)
|
361 |
+
|
362 |
+
def forward(
|
363 |
+
self,
|
364 |
+
model_input: MultiViewTransformerInput,
|
365 |
+
) -> Union[
|
366 |
+
List[MultiViewTransformerOutput],
|
367 |
+
Tuple[MultiViewTransformerOutput, List[MultiViewTransformerOutput]],
|
368 |
+
]:
|
369 |
+
"""
|
370 |
+
Forward interface for the Multi-View Cross-Attention Transformer with Intermediate Feature Return.
|
371 |
+
|
372 |
+
Args:
|
373 |
+
model_input (MultiViewTransformerInput): Input to the model.
|
374 |
+
Expects the features to be a list of size (batch, input_embed_dim, height, width),
|
375 |
+
where each entry corresponds to a different view.
|
376 |
+
|
377 |
+
Returns:
|
378 |
+
Union[List[MultiViewTransformerOutput], Tuple[MultiViewTransformerOutput, List[MultiViewTransformerOutput]]]:
|
379 |
+
Output of the model post information sharing.
|
380 |
+
If intermediates_only is True, returns a list of intermediate outputs.
|
381 |
+
If intermediates_only is False, returns a tuple of final output and a list of intermediate outputs.
|
382 |
+
"""
|
383 |
+
# Check that the number of views matches the input and the features are of expected shape
|
384 |
+
assert (
|
385 |
+
len(model_input.features) == self.num_views
|
386 |
+
), f"Expected {self.num_views} views, got {len(model_input.features)}"
|
387 |
+
assert all(
|
388 |
+
view_features.shape[1] == self.input_embed_dim for view_features in model_input.features
|
389 |
+
), f"All views must have input dimension {self.input_embed_dim}"
|
390 |
+
assert all(
|
391 |
+
view_features.ndim == 4 for view_features in model_input.features
|
392 |
+
), "All views must have 4 dimensions (N, C, H, W)"
|
393 |
+
|
394 |
+
# Get the indices of the intermediate features to return
|
395 |
+
intermediate_multi_view_features = []
|
396 |
+
take_indices, _ = feature_take_indices(self.depth, self.indices)
|
397 |
+
|
398 |
+
# Initialize the multi-view features from the model input
|
399 |
+
multi_view_features = model_input.features
|
400 |
+
|
401 |
+
# Resize the multi-view features from NCHW to NLC
|
402 |
+
batch_size, _, height, width = multi_view_features[0].shape
|
403 |
+
multi_view_features = [
|
404 |
+
view_features.permute(0, 2, 3, 1).reshape(batch_size, height * width, self.input_embed_dim).contiguous()
|
405 |
+
for view_features in multi_view_features
|
406 |
+
]
|
407 |
+
|
408 |
+
# Create patch positions for each view if custom positional encoding is used
|
409 |
+
if self.custom_positional_encoding is not None:
|
410 |
+
multi_view_positions = [
|
411 |
+
self.position_getter(batch_size, height, width, view_features.device)
|
412 |
+
for view_features in multi_view_features
|
413 |
+
]
|
414 |
+
else:
|
415 |
+
multi_view_positions = [None] * self.num_views
|
416 |
+
|
417 |
+
# Project input features to the transformer dimension
|
418 |
+
multi_view_features = [self.proj_embed(view_features) for view_features in multi_view_features]
|
419 |
+
|
420 |
+
# Pass through each view's cross-attention blocks
|
421 |
+
# Loop over the depth of the transformer
|
422 |
+
for depth_idx in range(self.depth):
|
423 |
+
updated_multi_view_features = []
|
424 |
+
# Loop over each view
|
425 |
+
for view_idx, view_features in enumerate(multi_view_features):
|
426 |
+
# Get all the other views
|
427 |
+
other_views_features = [multi_view_features[i] for i in range(self.num_views) if i != view_idx]
|
428 |
+
# Concatenate all the tokens from the other views
|
429 |
+
other_views_features = torch.cat(other_views_features, dim=1)
|
430 |
+
# Get the positions for the current view
|
431 |
+
view_positions = multi_view_positions[view_idx]
|
432 |
+
# Get the positions for all other views
|
433 |
+
other_views_positions = (
|
434 |
+
torch.cat([multi_view_positions[i] for i in range(self.num_views) if i != view_idx], dim=1)
|
435 |
+
if view_positions is not None
|
436 |
+
else None
|
437 |
+
)
|
438 |
+
# Apply the cross-attention block and update the multi-view features
|
439 |
+
updated_view_features = self.multi_view_branches[view_idx][depth_idx](
|
440 |
+
view_features, other_views_features, view_positions, other_views_positions
|
441 |
+
)
|
442 |
+
# Keep track of the updated view features
|
443 |
+
updated_multi_view_features.append(updated_view_features)
|
444 |
+
# Update the multi-view features for the next depth
|
445 |
+
multi_view_features = updated_multi_view_features
|
446 |
+
# Append the intermediate features if required
|
447 |
+
if depth_idx in take_indices:
|
448 |
+
# Normalize the intermediate features with final norm layer if enabled
|
449 |
+
intermediate_multi_view_features.append(
|
450 |
+
[self.norm(view_features) for view_features in multi_view_features]
|
451 |
+
if self.norm_intermediate
|
452 |
+
else multi_view_features
|
453 |
+
)
|
454 |
+
|
455 |
+
# Reshape the intermediate features and convert to MultiViewTransformerOutput class
|
456 |
+
for idx in range(len(intermediate_multi_view_features)):
|
457 |
+
intermediate_multi_view_features[idx] = [
|
458 |
+
view_features.reshape(batch_size, height, width, self.dim).permute(0, 3, 1, 2).contiguous()
|
459 |
+
for view_features in intermediate_multi_view_features[idx]
|
460 |
+
]
|
461 |
+
intermediate_multi_view_features[idx] = MultiViewTransformerOutput(
|
462 |
+
features=intermediate_multi_view_features[idx]
|
463 |
+
)
|
464 |
+
|
465 |
+
# Return only the intermediate features if enabled
|
466 |
+
if self.intermediates_only:
|
467 |
+
return intermediate_multi_view_features
|
468 |
+
|
469 |
+
# Normalize the output features
|
470 |
+
output_multi_view_features = [self.norm(view_features) for view_features in multi_view_features]
|
471 |
+
|
472 |
+
# Resize the output multi-view features back to NCHW
|
473 |
+
output_multi_view_features = [
|
474 |
+
view_features.reshape(batch_size, height, width, self.dim).permute(0, 3, 1, 2).contiguous()
|
475 |
+
for view_features in output_multi_view_features
|
476 |
+
]
|
477 |
+
|
478 |
+
output_multi_view_features = MultiViewTransformerOutput(features=output_multi_view_features)
|
479 |
+
|
480 |
+
return output_multi_view_features, intermediate_multi_view_features
|
481 |
+
|
482 |
+
|
483 |
+
def dummy_positional_encoding(x, xpos):
|
484 |
+
"Dummy function for positional encoding of tokens"
|
485 |
+
x = x
|
486 |
+
xpos = xpos
|
487 |
+
return x
|
488 |
+
|
489 |
+
|
490 |
+
if __name__ == "__main__":
|
491 |
+
# Init multi-view cross-attention transformer with no custom positional encoding and run a forward pass
|
492 |
+
for num_views in [2, 3, 4]:
|
493 |
+
print(f"Testing MultiViewCrossAttentionTransformer with {num_views} views ...")
|
494 |
+
model = DifferentialMultiViewCrossAttentionTransformer(
|
495 |
+
name="MV-DCAT", input_embed_dim=1024, num_views=num_views
|
496 |
+
)
|
497 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
|
498 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
499 |
+
model_output = model(model_input)
|
500 |
+
assert len(model_output.features) == num_views
|
501 |
+
assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features)
|
502 |
+
|
503 |
+
# Init multi-view cross-attention transformer with custom positional encoding and run a forward pass
|
504 |
+
for num_views in [2, 3, 4]:
|
505 |
+
print(
|
506 |
+
f"Testing Differential MultiViewCrossAttentionTransformer with {num_views} views and custom positional encoding ..."
|
507 |
+
)
|
508 |
+
model = DifferentialMultiViewCrossAttentionTransformer(
|
509 |
+
name="MV-DCAT",
|
510 |
+
input_embed_dim=1024,
|
511 |
+
num_views=num_views,
|
512 |
+
custom_positional_encoding=dummy_positional_encoding,
|
513 |
+
)
|
514 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
|
515 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
516 |
+
model_output = model(model_input)
|
517 |
+
assert len(model_output.features) == num_views
|
518 |
+
assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features)
|
519 |
+
|
520 |
+
print("All multi-view cross-attention transformers initialized and tested successfully!")
|
521 |
+
|
522 |
+
# Intermediate Feature Returner Tests
|
523 |
+
print("Running Intermediate Feature Returner Tests ...")
|
524 |
+
|
525 |
+
# Run the intermediate feature returner with last-n index
|
526 |
+
model_intermediate_feature_returner = DifferentialMultiViewCrossAttentionTransformerIFR(
|
527 |
+
name="MV-DCAT-IFR",
|
528 |
+
input_embed_dim=1024,
|
529 |
+
num_views=2,
|
530 |
+
indices=6, # Last 6 layers
|
531 |
+
)
|
532 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
|
533 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
534 |
+
output = model_intermediate_feature_returner(model_input)
|
535 |
+
assert isinstance(output, tuple)
|
536 |
+
assert isinstance(output[0], MultiViewTransformerOutput)
|
537 |
+
assert len(output[1]) == 6
|
538 |
+
assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1])
|
539 |
+
assert len(output[1][0].features) == 2
|
540 |
+
|
541 |
+
# Run the intermediate feature returner with specific indices
|
542 |
+
model_intermediate_feature_returner = DifferentialMultiViewCrossAttentionTransformerIFR(
|
543 |
+
name="MV-DCAT-IFR",
|
544 |
+
input_embed_dim=1024,
|
545 |
+
num_views=2,
|
546 |
+
indices=[0, 2, 4, 6], # Specific indices
|
547 |
+
)
|
548 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
|
549 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
550 |
+
output = model_intermediate_feature_returner(model_input)
|
551 |
+
assert isinstance(output, tuple)
|
552 |
+
assert isinstance(output[0], MultiViewTransformerOutput)
|
553 |
+
assert len(output[1]) == 4
|
554 |
+
assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1])
|
555 |
+
assert len(output[1][0].features) == 2
|
556 |
+
|
557 |
+
# Test the normalizing of intermediate features
|
558 |
+
model_intermediate_feature_returner = DifferentialMultiViewCrossAttentionTransformerIFR(
|
559 |
+
name="MV-DCAT-IFR",
|
560 |
+
input_embed_dim=1024,
|
561 |
+
num_views=2,
|
562 |
+
indices=[-1], # Last layer
|
563 |
+
norm_intermediate=False, # Disable normalization
|
564 |
+
)
|
565 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
|
566 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
567 |
+
output = model_intermediate_feature_returner(model_input)
|
568 |
+
for view_idx in range(2):
|
569 |
+
assert not torch.equal(
|
570 |
+
output[0].features[view_idx], output[1][-1].features[view_idx]
|
571 |
+
), "Final features and intermediate features (last layer) must be different."
|
572 |
+
|
573 |
+
model_intermediate_feature_returner = DifferentialMultiViewCrossAttentionTransformerIFR(
|
574 |
+
name="MV-DCAT-IFR",
|
575 |
+
input_embed_dim=1024,
|
576 |
+
num_views=2,
|
577 |
+
indices=[-1], # Last layer
|
578 |
+
norm_intermediate=True,
|
579 |
+
)
|
580 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
|
581 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
582 |
+
output = model_intermediate_feature_returner(model_input)
|
583 |
+
for view_idx in range(2):
|
584 |
+
assert torch.equal(
|
585 |
+
output[0].features[view_idx], output[1][-1].features[view_idx]
|
586 |
+
), "Final features and intermediate features (last layer) must be same."
|
587 |
+
|
588 |
+
print("All Intermediate Feature Returner Tests passed!")
|
UniCeption/uniception/models/info_sharing/global_attention_transformer.py
ADDED
@@ -0,0 +1,1107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
UniCeption Global-Attention Transformer for Information Sharing
|
3 |
+
"""
|
4 |
+
|
5 |
+
from functools import partial
|
6 |
+
from typing import Callable, List, Optional, Tuple, Type, Union
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
|
12 |
+
from uniception.models.info_sharing.base import (
|
13 |
+
MultiSetTransformerInput,
|
14 |
+
MultiSetTransformerOutput,
|
15 |
+
MultiViewTransformerInput,
|
16 |
+
MultiViewTransformerOutput,
|
17 |
+
UniCeptionInfoSharingBase,
|
18 |
+
)
|
19 |
+
from uniception.models.libs.croco.pos_embed import RoPE2D
|
20 |
+
from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner, feature_take_indices
|
21 |
+
from uniception.models.utils.positional_encoding import PositionGetter
|
22 |
+
from uniception.models.utils.transformer_blocks import Mlp, SelfAttentionBlock
|
23 |
+
|
24 |
+
|
25 |
+
class MultiViewGlobalAttentionTransformer(UniCeptionInfoSharingBase):
|
26 |
+
"UniCeption Multi-View Global-Attention Transformer for information sharing across image features from different views."
|
27 |
+
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
name: str,
|
31 |
+
input_embed_dim: int,
|
32 |
+
max_num_views: int,
|
33 |
+
use_rand_idx_pe_for_non_reference_views: bool,
|
34 |
+
size: Optional[str] = None,
|
35 |
+
depth: int = 12,
|
36 |
+
dim: int = 768,
|
37 |
+
num_heads: int = 12,
|
38 |
+
mlp_ratio: float = 4.0,
|
39 |
+
qkv_bias: bool = True,
|
40 |
+
qk_norm: bool = False,
|
41 |
+
proj_drop: float = 0.0,
|
42 |
+
attn_drop: float = 0.0,
|
43 |
+
init_values: Optional[float] = None,
|
44 |
+
drop_path: float = 0.0,
|
45 |
+
act_layer: Type[nn.Module] = nn.GELU,
|
46 |
+
norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial(nn.LayerNorm, eps=1e-6),
|
47 |
+
mlp_layer: Type[nn.Module] = Mlp,
|
48 |
+
custom_positional_encoding: Optional[Union[str, Callable]] = None,
|
49 |
+
pretrained_checkpoint_path: Optional[str] = None,
|
50 |
+
gradient_checkpointing: bool = False,
|
51 |
+
*args,
|
52 |
+
**kwargs,
|
53 |
+
):
|
54 |
+
"""
|
55 |
+
Initialize the Multi-View Global-Attention Transformer for information sharing across image features from different views.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
input_embed_dim (int): Dimension of input embeddings.
|
59 |
+
max_num_views (int): Maximum number of views for positional encoding.
|
60 |
+
use_rand_idx_pe_for_non_reference_views (bool): Whether to use random index positional encoding for non-reference views.
|
61 |
+
size (str): String to indicate interpretable size of the transformer (for e.g., base, large, ...). (default: None)
|
62 |
+
depth (int): Number of transformer layers. (default: 12, base size)
|
63 |
+
dim (int): Dimension of the transformer. (default: 768, base size)
|
64 |
+
num_heads (int): Number of attention heads. (default: 12, base size)
|
65 |
+
mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.)
|
66 |
+
qkv_bias (bool): Whether to include bias in qkv projection (default: True)
|
67 |
+
qk_norm (bool): Whether to normalize q and k (default: False)
|
68 |
+
proj_drop (float): Dropout rate for output (default: 0.)
|
69 |
+
attn_drop (float): Dropout rate for attention weights (default: 0.)
|
70 |
+
init_values (float): Initial value for LayerScale gamma (default: None)
|
71 |
+
drop_path (float): Dropout rate for stochastic depth (default: 0.)
|
72 |
+
act_layer (nn.Module): Activation layer (default: nn.GELU)
|
73 |
+
norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm)
|
74 |
+
mlp_layer (nn.Module): MLP layer (default: Mlp)
|
75 |
+
custom_positional_encoding (Callable): Custom positional encoding function (default: None)
|
76 |
+
pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None)
|
77 |
+
gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False)
|
78 |
+
"""
|
79 |
+
# Initialize the base class
|
80 |
+
super().__init__(name=name, size=size, *args, **kwargs)
|
81 |
+
|
82 |
+
# Initialize the specific attributes of the transformer
|
83 |
+
self.input_embed_dim = input_embed_dim
|
84 |
+
self.max_num_views = max_num_views
|
85 |
+
self.use_rand_idx_pe_for_non_reference_views = use_rand_idx_pe_for_non_reference_views
|
86 |
+
self.depth = depth
|
87 |
+
self.dim = dim
|
88 |
+
self.num_heads = num_heads
|
89 |
+
self.mlp_ratio = mlp_ratio
|
90 |
+
self.qkv_bias = qkv_bias
|
91 |
+
self.qk_norm = qk_norm
|
92 |
+
self.proj_drop = proj_drop
|
93 |
+
self.attn_drop = attn_drop
|
94 |
+
self.init_values = init_values
|
95 |
+
self.drop_path = drop_path
|
96 |
+
self.act_layer = act_layer
|
97 |
+
self.norm_layer = norm_layer
|
98 |
+
self.mlp_layer = mlp_layer
|
99 |
+
self.custom_positional_encoding = custom_positional_encoding
|
100 |
+
self.pretrained_checkpoint_path = pretrained_checkpoint_path
|
101 |
+
self.gradient_checkpointing = gradient_checkpointing
|
102 |
+
|
103 |
+
# Initialize the projection layer for input embeddings
|
104 |
+
if self.input_embed_dim != self.dim:
|
105 |
+
self.proj_embed = nn.Linear(self.input_embed_dim, self.dim, bias=True)
|
106 |
+
else:
|
107 |
+
self.proj_embed = nn.Identity()
|
108 |
+
|
109 |
+
# Initialize custom position encodings
|
110 |
+
if self.custom_positional_encoding is not None and isinstance(self.custom_positional_encoding, str):
|
111 |
+
if self.custom_positional_encoding == "rope":
|
112 |
+
self.rope = RoPE2D(freq=100.0, F0=1.0)
|
113 |
+
self.custom_positional_encoding = self.rope
|
114 |
+
else:
|
115 |
+
raise ValueError(f"Unknown custom positional encoding: {self.custom_positional_encoding}")
|
116 |
+
|
117 |
+
# Initialize the self-attention blocks which ingest all views at once
|
118 |
+
self.self_attention_blocks = nn.ModuleList(
|
119 |
+
[
|
120 |
+
SelfAttentionBlock(
|
121 |
+
dim=self.dim,
|
122 |
+
num_heads=self.num_heads,
|
123 |
+
mlp_ratio=self.mlp_ratio,
|
124 |
+
qkv_bias=self.qkv_bias,
|
125 |
+
qk_norm=self.qk_norm,
|
126 |
+
proj_drop=self.proj_drop,
|
127 |
+
attn_drop=self.attn_drop,
|
128 |
+
init_values=self.init_values,
|
129 |
+
drop_path=self.drop_path,
|
130 |
+
act_layer=self.act_layer,
|
131 |
+
norm_layer=self.norm_layer,
|
132 |
+
mlp_layer=self.mlp_layer,
|
133 |
+
custom_positional_encoding=self.custom_positional_encoding,
|
134 |
+
)
|
135 |
+
for _ in range(self.depth)
|
136 |
+
]
|
137 |
+
)
|
138 |
+
|
139 |
+
# Initialize the final normalization layer
|
140 |
+
self.norm = self.norm_layer(self.dim)
|
141 |
+
|
142 |
+
# Initialize the position getter for patch positions if required
|
143 |
+
if self.custom_positional_encoding is not None:
|
144 |
+
self.position_getter = PositionGetter()
|
145 |
+
|
146 |
+
# Initialize the positional encoding table for the different views
|
147 |
+
self.register_buffer(
|
148 |
+
"view_pos_table",
|
149 |
+
self._get_sinusoid_encoding_table(self.max_num_views, self.dim, 10000),
|
150 |
+
)
|
151 |
+
|
152 |
+
# Initialize random weights
|
153 |
+
self.initialize_weights()
|
154 |
+
|
155 |
+
# Load pretrained weights if provided
|
156 |
+
if self.pretrained_checkpoint_path is not None:
|
157 |
+
print(
|
158 |
+
f"Loading pretrained multi-view global-attention transformer weights from {self.pretrained_checkpoint_path} ..."
|
159 |
+
)
|
160 |
+
ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
|
161 |
+
print(self.load_state_dict(ckpt["model"]))
|
162 |
+
|
163 |
+
# Apply gradient checkpointing if enabled
|
164 |
+
if self.gradient_checkpointing:
|
165 |
+
for i, block in enumerate(self.self_attention_blocks):
|
166 |
+
self.self_attention_blocks[i] = self.wrap_module_with_gradient_checkpointing(block)
|
167 |
+
|
168 |
+
def _get_sinusoid_encoding_table(self, n_position, d_hid, base):
|
169 |
+
"Sinusoid position encoding table"
|
170 |
+
|
171 |
+
def get_position_angle_vec(position):
|
172 |
+
return [position / np.power(base, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
|
173 |
+
|
174 |
+
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
175 |
+
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])
|
176 |
+
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])
|
177 |
+
|
178 |
+
return torch.FloatTensor(sinusoid_table)
|
179 |
+
|
180 |
+
def initialize_weights(self):
|
181 |
+
"Initialize weights of the transformer."
|
182 |
+
# Linears and layer norms
|
183 |
+
self.apply(self._init_weights)
|
184 |
+
|
185 |
+
def _init_weights(self, m):
|
186 |
+
"Initialize the transformer linear and layer norm weights."
|
187 |
+
if isinstance(m, nn.Linear):
|
188 |
+
# We use xavier_uniform following official JAX ViT:
|
189 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
190 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
191 |
+
nn.init.constant_(m.bias, 0)
|
192 |
+
elif isinstance(m, nn.LayerNorm):
|
193 |
+
nn.init.constant_(m.bias, 0)
|
194 |
+
nn.init.constant_(m.weight, 1.0)
|
195 |
+
|
196 |
+
def forward(
|
197 |
+
self,
|
198 |
+
model_input: MultiViewTransformerInput,
|
199 |
+
) -> MultiViewTransformerOutput:
|
200 |
+
"""
|
201 |
+
Forward interface for the Multi-View Global-Attention Transformer.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
model_input (MultiViewTransformerInput): Input to the model.
|
205 |
+
Expects the features to be a list of size (batch, input_embed_dim, height, width),
|
206 |
+
where each entry corresponds to a different view.
|
207 |
+
Optionally, the input can also include additional_input_tokens (e.g., class token, registers, pose tokens, scale token)
|
208 |
+
which are appended to the token set from the multi-view features. The tokens are of size (batch, input_embed_dim, num_of_additional_tokens).
|
209 |
+
|
210 |
+
Returns:
|
211 |
+
MultiViewTransformerOutput: Output of the model post information sharing.
|
212 |
+
"""
|
213 |
+
# Check that the number of views matches the input and the features are of expected shape
|
214 |
+
assert (
|
215 |
+
len(model_input.features) <= self.max_num_views
|
216 |
+
), f"Expected less than {self.max_num_views} views, got {len(model_input.features)}"
|
217 |
+
assert all(
|
218 |
+
curr_view_features.shape[1] == self.input_embed_dim for curr_view_features in model_input.features
|
219 |
+
), f"All views must have input dimension {self.input_embed_dim}"
|
220 |
+
assert all(
|
221 |
+
curr_view_features.ndim == 4 for curr_view_features in model_input.features
|
222 |
+
), "All views must have 4 dimensions (N, C, H, W)"
|
223 |
+
|
224 |
+
# Initialize the multi-view features from the model input and number of views for current input
|
225 |
+
multi_view_features = model_input.features
|
226 |
+
num_of_views = len(multi_view_features)
|
227 |
+
batch_size, _, height, width = multi_view_features[0].shape
|
228 |
+
num_of_tokens_per_view = height * width
|
229 |
+
|
230 |
+
# Stack the multi-view features (N, C, H, W) to (N, V, C, H, W) (assumes all V views have same shape)
|
231 |
+
multi_view_features = torch.stack(multi_view_features, dim=1)
|
232 |
+
|
233 |
+
# Resize the multi-view features from NVCHW to NLC, where L = V * H * W
|
234 |
+
multi_view_features = multi_view_features.permute(0, 1, 3, 4, 2) # (N, V, H, W, C)
|
235 |
+
multi_view_features = multi_view_features.reshape(
|
236 |
+
batch_size, num_of_views * height * width, self.input_embed_dim
|
237 |
+
).contiguous()
|
238 |
+
|
239 |
+
# Process additional input tokens if provided
|
240 |
+
if model_input.additional_input_tokens is not None:
|
241 |
+
additional_tokens = model_input.additional_input_tokens
|
242 |
+
assert additional_tokens.ndim == 3, "Additional tokens must have 3 dimensions (N, C, T)"
|
243 |
+
assert (
|
244 |
+
additional_tokens.shape[1] == self.input_embed_dim
|
245 |
+
), f"Additional tokens must have input dimension {self.input_embed_dim}"
|
246 |
+
assert additional_tokens.shape[0] == batch_size, "Batch size mismatch for additional tokens"
|
247 |
+
|
248 |
+
# Reshape to channel-last format for transformer processing
|
249 |
+
additional_tokens = additional_tokens.permute(0, 2, 1).contiguous() # (N, C, T) -> (N, T, C)
|
250 |
+
|
251 |
+
# Concatenate the additional tokens to the multi-view features
|
252 |
+
multi_view_features = torch.cat([multi_view_features, additional_tokens], dim=1)
|
253 |
+
|
254 |
+
# Project input features to the transformer dimension
|
255 |
+
multi_view_features = self.proj_embed(multi_view_features)
|
256 |
+
|
257 |
+
# Create patch positions for each view if custom positional encoding is used
|
258 |
+
if self.custom_positional_encoding is not None:
|
259 |
+
multi_view_positions = [
|
260 |
+
self.position_getter(batch_size, height, width, multi_view_features.device)
|
261 |
+
] * num_of_views # List of length V, where each tensor is (N, H * W, C)
|
262 |
+
multi_view_positions = torch.cat(multi_view_positions, dim=1) # (N, V * H * W, C)
|
263 |
+
else:
|
264 |
+
multi_view_positions = [None] * num_of_views
|
265 |
+
|
266 |
+
# Add None positions for additional tokens if they exist
|
267 |
+
if model_input.additional_input_tokens is not None:
|
268 |
+
additional_tokens_positions = [None] * model_input.additional_input_tokens.shape[1]
|
269 |
+
multi_view_positions = multi_view_positions + additional_tokens_positions
|
270 |
+
|
271 |
+
# Add positional encoding for reference view (idx 0)
|
272 |
+
ref_view_pe = self.view_pos_table[0].clone().detach()
|
273 |
+
ref_view_pe = ref_view_pe.reshape((1, 1, self.dim))
|
274 |
+
ref_view_pe = ref_view_pe.repeat(batch_size, num_of_tokens_per_view, 1)
|
275 |
+
ref_view_features = multi_view_features[:, :num_of_tokens_per_view, :]
|
276 |
+
ref_view_features = ref_view_features + ref_view_pe
|
277 |
+
|
278 |
+
# Add positional encoding for non-reference views (sequential indices starting from idx 1 or random indices which are uniformly sampled)
|
279 |
+
if self.use_rand_idx_pe_for_non_reference_views:
|
280 |
+
non_ref_view_pe_indices = torch.randint(low=1, high=self.max_num_views, size=(num_of_views - 1,))
|
281 |
+
else:
|
282 |
+
non_ref_view_pe_indices = torch.arange(1, num_of_views)
|
283 |
+
non_ref_view_pe = self.view_pos_table[non_ref_view_pe_indices].clone().detach()
|
284 |
+
non_ref_view_pe = non_ref_view_pe.reshape((1, num_of_views - 1, self.dim))
|
285 |
+
non_ref_view_pe = non_ref_view_pe.repeat_interleave(num_of_tokens_per_view, dim=1)
|
286 |
+
non_ref_view_pe = non_ref_view_pe.repeat(batch_size, 1, 1)
|
287 |
+
non_ref_view_features = multi_view_features[
|
288 |
+
:, num_of_tokens_per_view : num_of_views * num_of_tokens_per_view, :
|
289 |
+
]
|
290 |
+
non_ref_view_features = non_ref_view_features + non_ref_view_pe
|
291 |
+
|
292 |
+
# Concatenate the reference and non-reference view features
|
293 |
+
# Handle additional tokens (no view-based positional encoding for them)
|
294 |
+
if model_input.additional_input_tokens is not None:
|
295 |
+
additional_features = multi_view_features[:, num_of_views * num_of_tokens_per_view :, :]
|
296 |
+
multi_view_features = torch.cat([ref_view_features, non_ref_view_features, additional_features], dim=1)
|
297 |
+
else:
|
298 |
+
multi_view_features = torch.cat([ref_view_features, non_ref_view_features], dim=1)
|
299 |
+
|
300 |
+
# Loop over the depth of the transformer
|
301 |
+
for depth_idx in range(self.depth):
|
302 |
+
# Apply the self-attention block and update the multi-view features
|
303 |
+
multi_view_features = self.self_attention_blocks[depth_idx](multi_view_features, multi_view_positions)
|
304 |
+
|
305 |
+
# Normalize the output features
|
306 |
+
output_multi_view_features = self.norm(multi_view_features)
|
307 |
+
|
308 |
+
# Extract only the view features (excluding additional tokens)
|
309 |
+
view_features = output_multi_view_features[:, : num_of_views * num_of_tokens_per_view, :]
|
310 |
+
|
311 |
+
# Reshape the output multi-view features (N, V * H * W, C) back to (N, V, C, H, W)
|
312 |
+
view_features = view_features.reshape(batch_size, num_of_views, height, width, self.dim) # (N, V, H, W, C)
|
313 |
+
view_features = view_features.permute(0, 1, 4, 2, 3).contiguous() # (N, V, C, H, W)
|
314 |
+
|
315 |
+
# Split the output multi-view features into separate views
|
316 |
+
view_features = view_features.split(1, dim=1)
|
317 |
+
view_features = [output_view_features.squeeze(dim=1) for output_view_features in view_features]
|
318 |
+
|
319 |
+
# Extract and return additional token features if provided
|
320 |
+
if model_input.additional_input_tokens is not None:
|
321 |
+
additional_token_features = output_multi_view_features[:, num_of_views * num_of_tokens_per_view :, :]
|
322 |
+
additional_token_features = additional_token_features.permute(0, 2, 1).contiguous() # (N, C, T)
|
323 |
+
return MultiViewTransformerOutput(
|
324 |
+
features=view_features, additional_token_features=additional_token_features
|
325 |
+
)
|
326 |
+
else:
|
327 |
+
return MultiViewTransformerOutput(features=view_features)
|
328 |
+
|
329 |
+
|
330 |
+
class MultiViewGlobalAttentionTransformerIFR(MultiViewGlobalAttentionTransformer, IntermediateFeatureReturner):
|
331 |
+
"Intermediate Feature Returner for UniCeption Multi-View Global-Attention Transformer"
|
332 |
+
|
333 |
+
def __init__(
|
334 |
+
self,
|
335 |
+
name: str,
|
336 |
+
input_embed_dim: int,
|
337 |
+
max_num_views: int,
|
338 |
+
use_rand_idx_pe_for_non_reference_views: bool,
|
339 |
+
size: Optional[str] = None,
|
340 |
+
depth: int = 12,
|
341 |
+
dim: int = 768,
|
342 |
+
num_heads: int = 12,
|
343 |
+
mlp_ratio: float = 4.0,
|
344 |
+
qkv_bias: bool = True,
|
345 |
+
qk_norm: bool = False,
|
346 |
+
proj_drop: float = 0.0,
|
347 |
+
attn_drop: float = 0.0,
|
348 |
+
init_values: Optional[float] = None,
|
349 |
+
drop_path: float = 0.0,
|
350 |
+
act_layer: nn.Module = nn.GELU,
|
351 |
+
norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
|
352 |
+
mlp_layer: nn.Module = Mlp,
|
353 |
+
custom_positional_encoding: Callable = None,
|
354 |
+
pretrained_checkpoint_path: str = None,
|
355 |
+
indices: Optional[Union[int, List[int]]] = None,
|
356 |
+
norm_intermediate: bool = True,
|
357 |
+
intermediates_only: bool = False,
|
358 |
+
gradient_checkpointing: bool = False,
|
359 |
+
*args,
|
360 |
+
**kwargs,
|
361 |
+
):
|
362 |
+
"""
|
363 |
+
Initialize the Multi-View Global-Attention Transformer for information sharing across image features from different views.
|
364 |
+
Extends the base class to return intermediate features.
|
365 |
+
|
366 |
+
Args:
|
367 |
+
input_embed_dim (int): Dimension of input embeddings.
|
368 |
+
max_num_views (int): Maximum number of views for positional encoding.
|
369 |
+
use_rand_idx_pe_for_non_reference_views (bool): Whether to use random index positional encoding for non-reference views.
|
370 |
+
size (str): String to indicate interpretable size of the transformer (for e.g., base, large, ...). (default: None)
|
371 |
+
depth (int): Number of transformer layers. (default: 12, base size)
|
372 |
+
dim (int): Dimension of the transformer. (default: 768, base size)
|
373 |
+
num_heads (int): Number of attention heads. (default: 12, base size)
|
374 |
+
mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.)
|
375 |
+
qkv_bias (bool): Whether to include bias in qkv projection (default: False)
|
376 |
+
qk_norm (bool): Whether to normalize q and k (default: False)
|
377 |
+
proj_drop (float): Dropout rate for output (default: 0.)
|
378 |
+
attn_drop (float): Dropout rate for attention weights (default: 0.)
|
379 |
+
init_values (float): Initial value for LayerScale gamma (default: None)
|
380 |
+
drop_path (float): Dropout rate for stochastic depth (default: 0.)
|
381 |
+
act_layer (nn.Module): Activation layer (default: nn.GELU)
|
382 |
+
norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm)
|
383 |
+
mlp_layer (nn.Module): MLP layer (default: Mlp)
|
384 |
+
custom_positional_encoding (Callable): Custom positional encoding function (default: None)
|
385 |
+
pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None)
|
386 |
+
indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. (default: None) Options:
|
387 |
+
- None: Return all intermediate layers.
|
388 |
+
- int: Return the last n layers.
|
389 |
+
- List[int]: Return the intermediate layers at the specified indices.
|
390 |
+
norm_intermediate (bool, optional): Whether to normalize the intermediate features. (default: True)
|
391 |
+
intermediates_only (bool, optional): Whether to return only the intermediate features. (default: False)
|
392 |
+
gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False)
|
393 |
+
"""
|
394 |
+
# Init the base classes
|
395 |
+
MultiViewGlobalAttentionTransformer.__init__(
|
396 |
+
self,
|
397 |
+
name=name,
|
398 |
+
input_embed_dim=input_embed_dim,
|
399 |
+
max_num_views=max_num_views,
|
400 |
+
use_rand_idx_pe_for_non_reference_views=use_rand_idx_pe_for_non_reference_views,
|
401 |
+
size=size,
|
402 |
+
depth=depth,
|
403 |
+
dim=dim,
|
404 |
+
num_heads=num_heads,
|
405 |
+
mlp_ratio=mlp_ratio,
|
406 |
+
qkv_bias=qkv_bias,
|
407 |
+
qk_norm=qk_norm,
|
408 |
+
proj_drop=proj_drop,
|
409 |
+
attn_drop=attn_drop,
|
410 |
+
init_values=init_values,
|
411 |
+
drop_path=drop_path,
|
412 |
+
act_layer=act_layer,
|
413 |
+
norm_layer=norm_layer,
|
414 |
+
mlp_layer=mlp_layer,
|
415 |
+
custom_positional_encoding=custom_positional_encoding,
|
416 |
+
pretrained_checkpoint_path=pretrained_checkpoint_path,
|
417 |
+
gradient_checkpointing=gradient_checkpointing,
|
418 |
+
*args,
|
419 |
+
**kwargs,
|
420 |
+
)
|
421 |
+
IntermediateFeatureReturner.__init__(
|
422 |
+
self,
|
423 |
+
indices=indices,
|
424 |
+
norm_intermediate=norm_intermediate,
|
425 |
+
intermediates_only=intermediates_only,
|
426 |
+
)
|
427 |
+
|
428 |
+
def forward(
|
429 |
+
self,
|
430 |
+
model_input: MultiViewTransformerInput,
|
431 |
+
) -> Union[
|
432 |
+
List[MultiViewTransformerOutput],
|
433 |
+
Tuple[MultiViewTransformerOutput, List[MultiViewTransformerOutput]],
|
434 |
+
]:
|
435 |
+
"""
|
436 |
+
Forward interface for the Multi-View Global-Attention Transformer with Intermediate Feature Return.
|
437 |
+
|
438 |
+
Args:
|
439 |
+
model_input (MultiViewTransformerInput): Input to the model.
|
440 |
+
Expects the features to be a list of size (batch, input_embed_dim, height, width),
|
441 |
+
where each entry corresponds to a different view.
|
442 |
+
Optionally, the input can also include additional_input_tokens (e.g., class token, registers, pose tokens, scale token)
|
443 |
+
which are appended to the token set from the multi-view features. The tokens are of size (batch, input_embed_dim, num_of_additional_tokens).
|
444 |
+
|
445 |
+
Returns:
|
446 |
+
Union[List[MultiViewTransformerOutput], Tuple[MultiViewTransformerOutput, List[MultiViewTransformerOutput]]]:
|
447 |
+
Output of the model post information sharing.
|
448 |
+
If intermediates_only is True, returns a list of intermediate outputs.
|
449 |
+
If intermediates_only is False, returns a tuple of final output and a list of intermediate outputs.
|
450 |
+
"""
|
451 |
+
# Check that the number of views matches the input and the features are of expected shape
|
452 |
+
assert (
|
453 |
+
len(model_input.features) <= self.max_num_views
|
454 |
+
), f"Expected {self.num_views} views, got {len(model_input.features)}"
|
455 |
+
assert all(
|
456 |
+
curr_view_features.shape[1] == self.input_embed_dim for curr_view_features in model_input.features
|
457 |
+
), f"All views must have input dimension {self.input_embed_dim}"
|
458 |
+
assert all(
|
459 |
+
curr_view_features.ndim == 4 for curr_view_features in model_input.features
|
460 |
+
), "All views must have 4 dimensions (N, C, H, W)"
|
461 |
+
|
462 |
+
# Get the indices of the intermediate features to return
|
463 |
+
intermediate_multi_view_features = []
|
464 |
+
take_indices, _ = feature_take_indices(self.depth, self.indices)
|
465 |
+
|
466 |
+
# Initialize the multi-view features from the model input and number of views for current input
|
467 |
+
multi_view_features = model_input.features
|
468 |
+
num_of_views = len(multi_view_features)
|
469 |
+
batch_size, _, height, width = multi_view_features[0].shape
|
470 |
+
num_of_tokens_per_view = height * width
|
471 |
+
|
472 |
+
# Stack the multi-view features (N, C, H, W) to (N, V, C, H, W) (assumes all V views have same shape)
|
473 |
+
multi_view_features = torch.stack(multi_view_features, dim=1)
|
474 |
+
|
475 |
+
# Resize the multi-view features from NVCHW to NLC, where L = V * H * W
|
476 |
+
multi_view_features = multi_view_features.permute(0, 1, 3, 4, 2) # (N, V, H, W, C)
|
477 |
+
multi_view_features = multi_view_features.reshape(
|
478 |
+
batch_size, num_of_views * height * width, self.input_embed_dim
|
479 |
+
).contiguous()
|
480 |
+
|
481 |
+
# Process additional input tokens if provided
|
482 |
+
if model_input.additional_input_tokens is not None:
|
483 |
+
additional_tokens = model_input.additional_input_tokens
|
484 |
+
assert additional_tokens.ndim == 3, "Additional tokens must have 3 dimensions (N, C, T)"
|
485 |
+
assert (
|
486 |
+
additional_tokens.shape[1] == self.input_embed_dim
|
487 |
+
), f"Additional tokens must have input dimension {self.input_embed_dim}"
|
488 |
+
assert additional_tokens.shape[0] == batch_size, "Batch size mismatch for additional tokens"
|
489 |
+
|
490 |
+
# Reshape to channel-last format for transformer processing
|
491 |
+
additional_tokens = additional_tokens.permute(0, 2, 1).contiguous() # (N, C, T) -> (N, T, C)
|
492 |
+
|
493 |
+
# Concatenate the additional tokens to the multi-view features
|
494 |
+
multi_view_features = torch.cat([multi_view_features, additional_tokens], dim=1)
|
495 |
+
|
496 |
+
# Project input features to the transformer dimension
|
497 |
+
multi_view_features = self.proj_embed(multi_view_features)
|
498 |
+
|
499 |
+
# Create patch positions for each view if custom positional encoding is used
|
500 |
+
if self.custom_positional_encoding is not None:
|
501 |
+
multi_view_positions = [
|
502 |
+
self.position_getter(batch_size, height, width, multi_view_features.device)
|
503 |
+
] * num_of_views # List of length V, where each tensor is (N, H * W, C)
|
504 |
+
multi_view_positions = torch.cat(multi_view_positions, dim=1) # (N, V * H * W, C)
|
505 |
+
else:
|
506 |
+
multi_view_positions = [None] * num_of_views
|
507 |
+
|
508 |
+
# Add None positions for additional tokens if they exist
|
509 |
+
if model_input.additional_input_tokens is not None:
|
510 |
+
additional_tokens_positions = [None] * model_input.additional_input_tokens.shape[1]
|
511 |
+
multi_view_positions = multi_view_positions + additional_tokens_positions
|
512 |
+
|
513 |
+
# Add positional encoding for reference view (idx 0)
|
514 |
+
ref_view_pe = self.view_pos_table[0].clone().detach()
|
515 |
+
ref_view_pe = ref_view_pe.reshape((1, 1, self.dim))
|
516 |
+
ref_view_pe = ref_view_pe.repeat(batch_size, num_of_tokens_per_view, 1)
|
517 |
+
ref_view_features = multi_view_features[:, :num_of_tokens_per_view, :]
|
518 |
+
ref_view_features = ref_view_features + ref_view_pe
|
519 |
+
|
520 |
+
# Add positional encoding for non-reference views (sequential indices starting from idx 1 or random indices which are uniformly sampled)
|
521 |
+
if self.use_rand_idx_pe_for_non_reference_views:
|
522 |
+
non_ref_view_pe_indices = torch.randint(low=1, high=self.max_num_views, size=(num_of_views - 1,))
|
523 |
+
else:
|
524 |
+
non_ref_view_pe_indices = torch.arange(1, num_of_views)
|
525 |
+
non_ref_view_pe = self.view_pos_table[non_ref_view_pe_indices].clone().detach()
|
526 |
+
non_ref_view_pe = non_ref_view_pe.reshape((1, num_of_views - 1, self.dim))
|
527 |
+
non_ref_view_pe = non_ref_view_pe.repeat_interleave(num_of_tokens_per_view, dim=1)
|
528 |
+
non_ref_view_pe = non_ref_view_pe.repeat(batch_size, 1, 1)
|
529 |
+
non_ref_view_features = multi_view_features[
|
530 |
+
:, num_of_tokens_per_view : num_of_views * num_of_tokens_per_view, :
|
531 |
+
]
|
532 |
+
non_ref_view_features = non_ref_view_features + non_ref_view_pe
|
533 |
+
|
534 |
+
# Concatenate the reference and non-reference view features
|
535 |
+
# Handle additional tokens (no view-based positional encoding for them)
|
536 |
+
if model_input.additional_input_tokens is not None:
|
537 |
+
additional_features = multi_view_features[:, num_of_views * num_of_tokens_per_view :, :]
|
538 |
+
multi_view_features = torch.cat([ref_view_features, non_ref_view_features, additional_features], dim=1)
|
539 |
+
else:
|
540 |
+
multi_view_features = torch.cat([ref_view_features, non_ref_view_features], dim=1)
|
541 |
+
|
542 |
+
# Loop over the depth of the transformer
|
543 |
+
for depth_idx in range(self.depth):
|
544 |
+
# Apply the self-attention block and update the multi-view features
|
545 |
+
multi_view_features = self.self_attention_blocks[depth_idx](multi_view_features, multi_view_positions)
|
546 |
+
if depth_idx in take_indices:
|
547 |
+
# Normalize the intermediate features with final norm layer if enabled
|
548 |
+
intermediate_multi_view_features.append(
|
549 |
+
self.norm(multi_view_features) if self.norm_intermediate else multi_view_features
|
550 |
+
)
|
551 |
+
|
552 |
+
# Reshape the intermediate features and convert to MultiViewTransformerOutput class
|
553 |
+
for idx in range(len(intermediate_multi_view_features)):
|
554 |
+
# Get the current intermediate features
|
555 |
+
current_features = intermediate_multi_view_features[idx]
|
556 |
+
|
557 |
+
# Extract additional token features if provided
|
558 |
+
additional_token_features = None
|
559 |
+
if model_input.additional_input_tokens is not None:
|
560 |
+
additional_token_features = current_features[:, num_of_views * num_of_tokens_per_view :, :]
|
561 |
+
additional_token_features = additional_token_features.permute(0, 2, 1).contiguous() # (N, C, T)
|
562 |
+
# Only keep the view features for reshaping
|
563 |
+
current_features = current_features[:, : num_of_views * num_of_tokens_per_view, :]
|
564 |
+
|
565 |
+
# Reshape the intermediate multi-view features (N, V * H * W, C) back to (N, V, C, H, W)
|
566 |
+
current_features = current_features.reshape(
|
567 |
+
batch_size, num_of_views, height, width, self.dim
|
568 |
+
) # (N, V, H, W, C)
|
569 |
+
current_features = current_features.permute(0, 1, 4, 2, 3).contiguous() # (N, V, C, H, W)
|
570 |
+
|
571 |
+
# Split the intermediate multi-view features into separate views
|
572 |
+
current_features = current_features.split(1, dim=1)
|
573 |
+
current_features = [
|
574 |
+
intermediate_view_features.squeeze(dim=1) for intermediate_view_features in current_features
|
575 |
+
]
|
576 |
+
|
577 |
+
intermediate_multi_view_features[idx] = MultiViewTransformerOutput(
|
578 |
+
features=current_features, additional_token_features=additional_token_features
|
579 |
+
)
|
580 |
+
|
581 |
+
# Return only the intermediate features if enabled
|
582 |
+
if self.intermediates_only:
|
583 |
+
return intermediate_multi_view_features
|
584 |
+
|
585 |
+
# Normalize the output features
|
586 |
+
output_multi_view_features = self.norm(multi_view_features)
|
587 |
+
|
588 |
+
# Extract view features (excluding additional tokens)
|
589 |
+
additional_token_features = None
|
590 |
+
if model_input.additional_input_tokens is not None:
|
591 |
+
additional_token_features = output_multi_view_features[:, num_of_views * num_of_tokens_per_view :, :]
|
592 |
+
additional_token_features = additional_token_features.permute(0, 2, 1).contiguous() # (N, C, T)
|
593 |
+
view_features = output_multi_view_features[:, : num_of_views * num_of_tokens_per_view, :]
|
594 |
+
else:
|
595 |
+
view_features = output_multi_view_features
|
596 |
+
|
597 |
+
# Reshape the output multi-view features (N, V * H * W, C) back to (N, V, C, H, W)
|
598 |
+
view_features = view_features.reshape(batch_size, num_of_views, height, width, self.dim) # (N, V, H, W, C)
|
599 |
+
view_features = view_features.permute(0, 1, 4, 2, 3).contiguous() # (N, V, C, H, W)
|
600 |
+
|
601 |
+
# Split the output multi-view features into separate views
|
602 |
+
view_features = view_features.split(1, dim=1)
|
603 |
+
view_features = [output_view_features.squeeze(dim=1) for output_view_features in view_features]
|
604 |
+
|
605 |
+
output_multi_view_features = MultiViewTransformerOutput(
|
606 |
+
features=view_features, additional_token_features=additional_token_features
|
607 |
+
)
|
608 |
+
|
609 |
+
return output_multi_view_features, intermediate_multi_view_features
|
610 |
+
|
611 |
+
|
612 |
+
class GlobalAttentionTransformer(UniCeptionInfoSharingBase):
|
613 |
+
"UniCeption Global-Attention Transformer for information sharing across different set of features."
|
614 |
+
|
615 |
+
def __init__(
|
616 |
+
self,
|
617 |
+
name: str,
|
618 |
+
input_embed_dim: int,
|
619 |
+
max_num_sets: int,
|
620 |
+
use_rand_idx_pe_for_non_reference_sets: bool,
|
621 |
+
size: Optional[str] = None,
|
622 |
+
depth: int = 12,
|
623 |
+
dim: int = 768,
|
624 |
+
num_heads: int = 12,
|
625 |
+
mlp_ratio: float = 4.0,
|
626 |
+
qkv_bias: bool = True,
|
627 |
+
qk_norm: bool = False,
|
628 |
+
proj_drop: float = 0.0,
|
629 |
+
attn_drop: float = 0.0,
|
630 |
+
init_values: Optional[float] = None,
|
631 |
+
drop_path: float = 0.0,
|
632 |
+
act_layer: Type[nn.Module] = nn.GELU,
|
633 |
+
norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial(nn.LayerNorm, eps=1e-6),
|
634 |
+
mlp_layer: Type[nn.Module] = Mlp,
|
635 |
+
pretrained_checkpoint_path: Optional[str] = None,
|
636 |
+
gradient_checkpointing: bool = False,
|
637 |
+
*args,
|
638 |
+
**kwargs,
|
639 |
+
):
|
640 |
+
"""
|
641 |
+
Initialize the Global-Attention Transformer for information sharing across features from different sets.
|
642 |
+
|
643 |
+
Args:
|
644 |
+
input_embed_dim (int): Dimension of input embeddings.
|
645 |
+
max_num_sets (int): Maximum number of sets for positional encoding.
|
646 |
+
use_rand_idx_pe_for_non_reference_sets (bool): Whether to use random index positional encoding for non-reference sets.
|
647 |
+
size (str): String to indicate interpretable size of the transformer (for e.g., base, large, ...). (default: None)
|
648 |
+
depth (int): Number of transformer layers. (default: 12, base size)
|
649 |
+
dim (int): Dimension of the transformer. (default: 768, base size)
|
650 |
+
num_heads (int): Number of attention heads. (default: 12, base size)
|
651 |
+
mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.)
|
652 |
+
qkv_bias (bool): Whether to include bias in qkv projection (default: True)
|
653 |
+
qk_norm (bool): Whether to normalize q and k (default: False)
|
654 |
+
proj_drop (float): Dropout rate for output (default: 0.)
|
655 |
+
attn_drop (float): Dropout rate for attention weights (default: 0.)
|
656 |
+
init_values (float): Initial value for LayerScale gamma (default: None)
|
657 |
+
drop_path (float): Dropout rate for stochastic depth (default: 0.)
|
658 |
+
act_layer (nn.Module): Activation layer (default: nn.GELU)
|
659 |
+
norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm)
|
660 |
+
mlp_layer (nn.Module): MLP layer (default: Mlp)
|
661 |
+
pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None)
|
662 |
+
gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False)
|
663 |
+
"""
|
664 |
+
# Initialize the base class
|
665 |
+
super().__init__(name=name, size=size, *args, **kwargs)
|
666 |
+
|
667 |
+
# Initialize the specific attributes of the transformer
|
668 |
+
self.input_embed_dim = input_embed_dim
|
669 |
+
self.max_num_sets = max_num_sets
|
670 |
+
self.use_rand_idx_pe_for_non_reference_sets = use_rand_idx_pe_for_non_reference_sets
|
671 |
+
self.depth = depth
|
672 |
+
self.dim = dim
|
673 |
+
self.num_heads = num_heads
|
674 |
+
self.mlp_ratio = mlp_ratio
|
675 |
+
self.qkv_bias = qkv_bias
|
676 |
+
self.qk_norm = qk_norm
|
677 |
+
self.proj_drop = proj_drop
|
678 |
+
self.attn_drop = attn_drop
|
679 |
+
self.init_values = init_values
|
680 |
+
self.drop_path = drop_path
|
681 |
+
self.act_layer = act_layer
|
682 |
+
self.norm_layer = norm_layer
|
683 |
+
self.mlp_layer = mlp_layer
|
684 |
+
self.pretrained_checkpoint_path = pretrained_checkpoint_path
|
685 |
+
self.gradient_checkpointing = gradient_checkpointing
|
686 |
+
|
687 |
+
# Initialize the projection layer for input embeddings
|
688 |
+
if self.input_embed_dim != self.dim:
|
689 |
+
self.proj_embed = nn.Linear(self.input_embed_dim, self.dim, bias=True)
|
690 |
+
else:
|
691 |
+
self.proj_embed = nn.Identity()
|
692 |
+
|
693 |
+
# Initialize the self-attention blocks which ingest all sets at once
|
694 |
+
self.self_attention_blocks = nn.ModuleList(
|
695 |
+
[
|
696 |
+
SelfAttentionBlock(
|
697 |
+
dim=self.dim,
|
698 |
+
num_heads=self.num_heads,
|
699 |
+
mlp_ratio=self.mlp_ratio,
|
700 |
+
qkv_bias=self.qkv_bias,
|
701 |
+
qk_norm=self.qk_norm,
|
702 |
+
proj_drop=self.proj_drop,
|
703 |
+
attn_drop=self.attn_drop,
|
704 |
+
init_values=self.init_values,
|
705 |
+
drop_path=self.drop_path,
|
706 |
+
act_layer=self.act_layer,
|
707 |
+
norm_layer=self.norm_layer,
|
708 |
+
mlp_layer=self.mlp_layer,
|
709 |
+
)
|
710 |
+
for _ in range(self.depth)
|
711 |
+
]
|
712 |
+
)
|
713 |
+
|
714 |
+
# Initialize the final normalization layer
|
715 |
+
self.norm = self.norm_layer(self.dim)
|
716 |
+
|
717 |
+
# Initialize the positional encoding table for the different sets
|
718 |
+
self.register_buffer(
|
719 |
+
"set_pos_table",
|
720 |
+
self._get_sinusoid_encoding_table(self.max_num_sets, self.dim, 10000),
|
721 |
+
)
|
722 |
+
|
723 |
+
# Initialize random weights
|
724 |
+
self.initialize_weights()
|
725 |
+
|
726 |
+
# Load pretrained weights if provided
|
727 |
+
if self.pretrained_checkpoint_path is not None:
|
728 |
+
print(f"Loading pretrained global-attention transformer weights from {self.pretrained_checkpoint_path} ...")
|
729 |
+
ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
|
730 |
+
print(self.load_state_dict(ckpt["model"]))
|
731 |
+
|
732 |
+
# Apply gradient checkpointing if enabled
|
733 |
+
if self.gradient_checkpointing:
|
734 |
+
for i, block in enumerate(self.self_attention_blocks):
|
735 |
+
self.self_attention_blocks[i] = self.wrap_module_with_gradient_checkpointing(block)
|
736 |
+
|
737 |
+
def _get_sinusoid_encoding_table(self, n_position, d_hid, base):
|
738 |
+
"Sinusoid position encoding table"
|
739 |
+
|
740 |
+
def get_position_angle_vec(position):
|
741 |
+
return [position / np.power(base, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
|
742 |
+
|
743 |
+
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
744 |
+
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])
|
745 |
+
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])
|
746 |
+
|
747 |
+
return torch.FloatTensor(sinusoid_table)
|
748 |
+
|
749 |
+
def initialize_weights(self):
|
750 |
+
"Initialize weights of the transformer."
|
751 |
+
# Linears and layer norms
|
752 |
+
self.apply(self._init_weights)
|
753 |
+
|
754 |
+
def _init_weights(self, m):
|
755 |
+
"Initialize the transformer linear and layer norm weights."
|
756 |
+
if isinstance(m, nn.Linear):
|
757 |
+
# We use xavier_uniform following official JAX ViT:
|
758 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
759 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
760 |
+
nn.init.constant_(m.bias, 0)
|
761 |
+
elif isinstance(m, nn.LayerNorm):
|
762 |
+
nn.init.constant_(m.bias, 0)
|
763 |
+
nn.init.constant_(m.weight, 1.0)
|
764 |
+
|
765 |
+
def forward(
|
766 |
+
self,
|
767 |
+
model_input: MultiSetTransformerInput,
|
768 |
+
) -> MultiSetTransformerOutput:
|
769 |
+
"""
|
770 |
+
Forward interface for the Multi-Set Global-Attention Transformer.
|
771 |
+
|
772 |
+
Args:
|
773 |
+
model_input (MultiSetTransformerInput): Input to the model.
|
774 |
+
Expects the features to be a list of size (batch, input_embed_dim, num_tokens),
|
775 |
+
where each entry corresponds to a different set of tokens and
|
776 |
+
the number of tokens can be different for each set.
|
777 |
+
Optionally, the input can also include additional_input_tokens (e.g., class token, registers, pose tokens, scale token)
|
778 |
+
which are appended to the token set from the multi-view features. The tokens are of size (batch, input_embed_dim, num_of_additional_tokens).
|
779 |
+
|
780 |
+
Returns:
|
781 |
+
MultiSetTransformerOutput: Output of the model post information sharing.
|
782 |
+
"""
|
783 |
+
# Check that the number of sets matches the input and the features are of expected shape
|
784 |
+
assert (
|
785 |
+
len(model_input.features) <= self.max_num_sets
|
786 |
+
), f"Expected less than {self.max_num_sets} sets, got {len(model_input.features)}"
|
787 |
+
assert all(
|
788 |
+
set_features.shape[1] == self.input_embed_dim for set_features in model_input.features
|
789 |
+
), f"All sets must have input dimension {self.input_embed_dim}"
|
790 |
+
assert all(
|
791 |
+
set_features.ndim == 3 for set_features in model_input.features
|
792 |
+
), "All sets must have 3 dimensions (N, C, T)"
|
793 |
+
|
794 |
+
# Initialize the multi-set features from the model input and number of sets for current input
|
795 |
+
multi_set_features = model_input.features
|
796 |
+
num_of_sets = len(multi_set_features)
|
797 |
+
batch_size, _, _ = multi_set_features[0].shape
|
798 |
+
num_of_tokens_per_set = [set_features.shape[2] for set_features in multi_set_features]
|
799 |
+
|
800 |
+
# Permute the multi-set features from (N, C, T) to (N, T, C)
|
801 |
+
multi_set_features = [set_features.permute(0, 2, 1).contiguous() for set_features in multi_set_features]
|
802 |
+
|
803 |
+
# Stack the multi-set features along the number of tokens dimension
|
804 |
+
multi_set_features = torch.cat(multi_set_features, dim=1)
|
805 |
+
|
806 |
+
# Process additional input tokens if provided
|
807 |
+
if model_input.additional_input_tokens is not None:
|
808 |
+
additional_tokens = model_input.additional_input_tokens
|
809 |
+
assert additional_tokens.ndim == 3, "Additional tokens must have 3 dimensions (N, C, T)"
|
810 |
+
assert (
|
811 |
+
additional_tokens.shape[1] == self.input_embed_dim
|
812 |
+
), f"Additional tokens must have input dimension {self.input_embed_dim}"
|
813 |
+
assert additional_tokens.shape[0] == batch_size, "Batch size mismatch for additional tokens"
|
814 |
+
|
815 |
+
# Reshape to channel-last format for transformer processing
|
816 |
+
additional_tokens = additional_tokens.permute(0, 2, 1).contiguous() # (N, C, T) -> (N, T, C)
|
817 |
+
|
818 |
+
# Concatenate the additional tokens to the multi-set features
|
819 |
+
multi_set_features = torch.cat([multi_set_features, additional_tokens], dim=1)
|
820 |
+
|
821 |
+
# Project input features to the transformer dimension
|
822 |
+
multi_set_features = self.proj_embed(multi_set_features)
|
823 |
+
|
824 |
+
# Create dummy patch positions for each set
|
825 |
+
multi_set_positions = [None] * num_of_sets
|
826 |
+
|
827 |
+
# Add positional encoding for reference set (idx 0)
|
828 |
+
ref_set_pe = self.set_pos_table[0].clone().detach()
|
829 |
+
ref_set_pe = ref_set_pe.reshape((1, 1, self.dim))
|
830 |
+
ref_set_pe = ref_set_pe.repeat(batch_size, num_of_tokens_per_set[0], 1)
|
831 |
+
ref_set_features = multi_set_features[:, : num_of_tokens_per_set[0], :]
|
832 |
+
ref_set_features = ref_set_features + ref_set_pe
|
833 |
+
|
834 |
+
# Add positional encoding for non-reference sets (sequential indices starting from idx 1 or random indices which are uniformly sampled)
|
835 |
+
if self.use_rand_idx_pe_for_non_reference_sets:
|
836 |
+
non_ref_set_pe_indices = torch.randint(low=1, high=self.max_num_sets, size=(num_of_sets - 1,))
|
837 |
+
else:
|
838 |
+
non_ref_set_pe_indices = torch.arange(1, num_of_sets)
|
839 |
+
non_ref_set_pe_list = []
|
840 |
+
for non_ref_set_idx in range(1, num_of_sets):
|
841 |
+
non_ref_set_pe_for_idx = self.set_pos_table[non_ref_set_pe_indices[non_ref_set_idx - 1]].clone().detach()
|
842 |
+
non_ref_set_pe_for_idx = non_ref_set_pe_for_idx.reshape((1, 1, self.dim))
|
843 |
+
non_ref_set_pe_for_idx = non_ref_set_pe_for_idx.repeat(
|
844 |
+
batch_size, num_of_tokens_per_set[non_ref_set_idx], 1
|
845 |
+
)
|
846 |
+
non_ref_set_pe_list.append(non_ref_set_pe_for_idx)
|
847 |
+
non_ref_set_pe = torch.cat(non_ref_set_pe_list, dim=1)
|
848 |
+
non_ref_set_features = multi_set_features[:, num_of_tokens_per_set[0] : sum(num_of_tokens_per_set), :]
|
849 |
+
non_ref_set_features = non_ref_set_features + non_ref_set_pe
|
850 |
+
|
851 |
+
# Concatenate the reference and non-reference set features
|
852 |
+
# Handle additional tokens (no set-based positional encoding for them)
|
853 |
+
if model_input.additional_input_tokens is not None:
|
854 |
+
additional_features = multi_set_features[:, sum(num_of_tokens_per_set) :, :]
|
855 |
+
multi_set_features = torch.cat([ref_set_features, non_ref_set_features, additional_features], dim=1)
|
856 |
+
else:
|
857 |
+
multi_set_features = torch.cat([ref_set_features, non_ref_set_features], dim=1)
|
858 |
+
|
859 |
+
# Add None positions for additional tokens if they exist
|
860 |
+
if model_input.additional_input_tokens is not None:
|
861 |
+
additional_tokens_positions = [None] * model_input.additional_input_tokens.shape[2]
|
862 |
+
multi_set_positions = multi_set_positions + additional_tokens_positions
|
863 |
+
|
864 |
+
# Loop over the depth of the transformer
|
865 |
+
for depth_idx in range(self.depth):
|
866 |
+
# Apply the self-attention block and update the multi-set features
|
867 |
+
multi_set_features = self.self_attention_blocks[depth_idx](multi_set_features, multi_set_positions)
|
868 |
+
|
869 |
+
# Normalize the output features
|
870 |
+
output_multi_set_features = self.norm(multi_set_features)
|
871 |
+
|
872 |
+
# Extract additional token features if provided
|
873 |
+
additional_token_features = None
|
874 |
+
if model_input.additional_input_tokens is not None:
|
875 |
+
additional_token_features = output_multi_set_features[:, sum(num_of_tokens_per_set) :, :]
|
876 |
+
additional_token_features = additional_token_features.permute(
|
877 |
+
0, 2, 1
|
878 |
+
).contiguous() # (N, T, C) -> (N, C, T)
|
879 |
+
# Only keep the set features for reshaping
|
880 |
+
output_multi_set_features = output_multi_set_features[:, : sum(num_of_tokens_per_set), :]
|
881 |
+
|
882 |
+
# Reshape the output multi-set features from (N, T, C) to (N, C, T)
|
883 |
+
output_multi_set_features = output_multi_set_features.permute(0, 2, 1).contiguous()
|
884 |
+
|
885 |
+
# Split the output multi-set features into separate sets using the list of number of tokens per set
|
886 |
+
output_multi_set_features = torch.split(output_multi_set_features, num_of_tokens_per_set, dim=2)
|
887 |
+
|
888 |
+
# Return the output multi-set features with additional token features if provided
|
889 |
+
return MultiSetTransformerOutput(
|
890 |
+
features=output_multi_set_features, additional_token_features=additional_token_features
|
891 |
+
)
|
892 |
+
|
893 |
+
|
894 |
+
def dummy_positional_encoding(x, xpos):
|
895 |
+
"Dummy function for positional encoding of tokens"
|
896 |
+
x = x
|
897 |
+
xpos = xpos
|
898 |
+
return x
|
899 |
+
|
900 |
+
|
901 |
+
if __name__ == "__main__":
|
902 |
+
# Init multi-view global-attention transformer with no custom positional encoding and run a forward pass
|
903 |
+
for num_views in [2, 3, 4]:
|
904 |
+
print(f"Testing MultiViewGlobalAttentionTransformer with {num_views} views ...")
|
905 |
+
# Sequential idx based positional encoding
|
906 |
+
model = MultiViewGlobalAttentionTransformer(
|
907 |
+
name="MV-GAT", input_embed_dim=1024, max_num_views=1000, use_rand_idx_pe_for_non_reference_views=False
|
908 |
+
)
|
909 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
|
910 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
911 |
+
model_output = model(model_input)
|
912 |
+
assert len(model_output.features) == num_views
|
913 |
+
assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features)
|
914 |
+
# Random idx based positional encoding
|
915 |
+
model = MultiViewGlobalAttentionTransformer(
|
916 |
+
name="MV-GAT", input_embed_dim=1024, max_num_views=1000, use_rand_idx_pe_for_non_reference_views=True
|
917 |
+
)
|
918 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
|
919 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
920 |
+
model_output = model(model_input)
|
921 |
+
assert len(model_output.features) == num_views
|
922 |
+
assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features)
|
923 |
+
|
924 |
+
# Init multi-view global-attention transformer with custom positional encoding and run a forward pass
|
925 |
+
for num_views in [2, 3, 4]:
|
926 |
+
print(f"Testing MultiViewGlobalAttentionTransformer with {num_views} views and custom positional encoding ...")
|
927 |
+
model = MultiViewGlobalAttentionTransformer(
|
928 |
+
name="MV-GAT",
|
929 |
+
input_embed_dim=1024,
|
930 |
+
max_num_views=1000,
|
931 |
+
use_rand_idx_pe_for_non_reference_views=True,
|
932 |
+
custom_positional_encoding=dummy_positional_encoding,
|
933 |
+
)
|
934 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
|
935 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
936 |
+
model_output = model(model_input)
|
937 |
+
assert len(model_output.features) == num_views
|
938 |
+
assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features)
|
939 |
+
|
940 |
+
print("All multi-view global-attention transformers initialized and tested successfully!")
|
941 |
+
|
942 |
+
# Intermediate Feature Returner Tests
|
943 |
+
print("Running Intermediate Feature Returner Tests ...")
|
944 |
+
|
945 |
+
# Run the intermediate feature returner with last-n index
|
946 |
+
model_intermediate_feature_returner = MultiViewGlobalAttentionTransformerIFR(
|
947 |
+
name="MV-GAT-IFR",
|
948 |
+
input_embed_dim=1024,
|
949 |
+
max_num_views=1000,
|
950 |
+
use_rand_idx_pe_for_non_reference_views=True,
|
951 |
+
indices=6, # Last 6 layers
|
952 |
+
)
|
953 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
|
954 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
955 |
+
output = model_intermediate_feature_returner(model_input)
|
956 |
+
assert isinstance(output, tuple)
|
957 |
+
assert isinstance(output[0], MultiViewTransformerOutput)
|
958 |
+
assert len(output[1]) == 6
|
959 |
+
assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1])
|
960 |
+
assert len(output[1][0].features) == 2
|
961 |
+
|
962 |
+
# Run the intermediate feature returner with specific indices
|
963 |
+
model_intermediate_feature_returner = MultiViewGlobalAttentionTransformerIFR(
|
964 |
+
name="MV-GAT-IFR",
|
965 |
+
input_embed_dim=1024,
|
966 |
+
max_num_views=1000,
|
967 |
+
use_rand_idx_pe_for_non_reference_views=True,
|
968 |
+
indices=[0, 2, 4, 6], # Specific indices
|
969 |
+
)
|
970 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
|
971 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
972 |
+
output = model_intermediate_feature_returner(model_input)
|
973 |
+
assert isinstance(output, tuple)
|
974 |
+
assert isinstance(output[0], MultiViewTransformerOutput)
|
975 |
+
assert len(output[1]) == 4
|
976 |
+
assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1])
|
977 |
+
assert len(output[1][0].features) == 2
|
978 |
+
|
979 |
+
# Test the normalizing of intermediate features
|
980 |
+
model_intermediate_feature_returner = MultiViewGlobalAttentionTransformerIFR(
|
981 |
+
name="MV-GAT-IFR",
|
982 |
+
input_embed_dim=1024,
|
983 |
+
max_num_views=1000,
|
984 |
+
use_rand_idx_pe_for_non_reference_views=True,
|
985 |
+
indices=[-1], # Last layer
|
986 |
+
norm_intermediate=False, # Disable normalization
|
987 |
+
)
|
988 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
|
989 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
990 |
+
output = model_intermediate_feature_returner(model_input)
|
991 |
+
for view_idx in range(2):
|
992 |
+
assert not torch.equal(
|
993 |
+
output[0].features[view_idx], output[1][-1].features[view_idx]
|
994 |
+
), "Final features and intermediate features (last layer) must be different."
|
995 |
+
|
996 |
+
model_intermediate_feature_returner = MultiViewGlobalAttentionTransformerIFR(
|
997 |
+
name="MV-GAT-IFR",
|
998 |
+
input_embed_dim=1024,
|
999 |
+
max_num_views=1000,
|
1000 |
+
use_rand_idx_pe_for_non_reference_views=True,
|
1001 |
+
indices=[-1], # Last layer
|
1002 |
+
norm_intermediate=True,
|
1003 |
+
)
|
1004 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
|
1005 |
+
model_input = MultiViewTransformerInput(features=model_input)
|
1006 |
+
output = model_intermediate_feature_returner(model_input)
|
1007 |
+
for view_idx in range(2):
|
1008 |
+
assert torch.equal(
|
1009 |
+
output[0].features[view_idx], output[1][-1].features[view_idx]
|
1010 |
+
), "Final features and intermediate features (last layer) must be same."
|
1011 |
+
|
1012 |
+
print("All Intermediate Feature Returner Tests passed!")
|
1013 |
+
|
1014 |
+
# Init multi-set global-attention transformer and run a forward pass with different number of sets and set token sizes
|
1015 |
+
import random
|
1016 |
+
|
1017 |
+
model = GlobalAttentionTransformer(
|
1018 |
+
name="GAT", input_embed_dim=1024, max_num_sets=3, use_rand_idx_pe_for_non_reference_sets=False
|
1019 |
+
)
|
1020 |
+
for num_sets in [2, 3]:
|
1021 |
+
print(f"Testing GlobalAttentionTransformer with {num_sets} sets ...")
|
1022 |
+
model_input = [torch.rand(1, 1024, random.randint(256, 513)) for _ in range(num_sets)]
|
1023 |
+
model_input = MultiSetTransformerInput(features=model_input)
|
1024 |
+
model_output = model(model_input)
|
1025 |
+
assert len(model_output.features) == num_sets
|
1026 |
+
for feat, rand_input in zip(model_output.features, model_input.features):
|
1027 |
+
assert feat.shape[2] == rand_input.shape[2]
|
1028 |
+
assert feat.shape[1] == model.dim
|
1029 |
+
assert feat.shape[0] == rand_input.shape[0]
|
1030 |
+
# Random idx based positional encoding
|
1031 |
+
model = GlobalAttentionTransformer(
|
1032 |
+
name="GAT", input_embed_dim=1024, max_num_sets=1000, use_rand_idx_pe_for_non_reference_sets=True
|
1033 |
+
)
|
1034 |
+
for num_sets in [2, 3, 4]:
|
1035 |
+
print(f"Testing GlobalAttentionTransformer with {num_sets} sets ...")
|
1036 |
+
model_input = [torch.rand(1, 1024, random.randint(256, 513)) for _ in range(num_sets)]
|
1037 |
+
model_input = MultiSetTransformerInput(features=model_input)
|
1038 |
+
model_output = model(model_input)
|
1039 |
+
assert len(model_output.features) == num_sets
|
1040 |
+
for feat, rand_input in zip(model_output.features, model_input.features):
|
1041 |
+
assert feat.shape[2] == rand_input.shape[2]
|
1042 |
+
assert feat.shape[1] == model.dim
|
1043 |
+
assert feat.shape[0] == rand_input.shape[0]
|
1044 |
+
|
1045 |
+
print("All Global Attention Transformer Tests passed!")
|
1046 |
+
|
1047 |
+
# Test additional input tokens for MultiViewGlobalAttentionTransformer
|
1048 |
+
print("Testing MultiViewGlobalAttentionTransformer with additional input tokens...")
|
1049 |
+
model = MultiViewGlobalAttentionTransformer(
|
1050 |
+
name="MV-GAT", input_embed_dim=1024, max_num_views=1000, use_rand_idx_pe_for_non_reference_views=False
|
1051 |
+
)
|
1052 |
+
num_views = 2
|
1053 |
+
num_additional_tokens = 5
|
1054 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
|
1055 |
+
additional_tokens = torch.rand(1, 1024, num_additional_tokens)
|
1056 |
+
model_input = MultiViewTransformerInput(features=model_input, additional_input_tokens=additional_tokens)
|
1057 |
+
model_output = model(model_input)
|
1058 |
+
assert len(model_output.features) == num_views
|
1059 |
+
assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features)
|
1060 |
+
assert model_output.additional_token_features is not None
|
1061 |
+
assert model_output.additional_token_features.shape == (1, model.dim, num_additional_tokens)
|
1062 |
+
|
1063 |
+
# Test additional input tokens for MultiViewGlobalAttentionTransformerIFR
|
1064 |
+
print("Testing MultiViewGlobalAttentionTransformerIFR with additional input tokens...")
|
1065 |
+
model_ifr = MultiViewGlobalAttentionTransformerIFR(
|
1066 |
+
name="MV-GAT-IFR",
|
1067 |
+
input_embed_dim=1024,
|
1068 |
+
max_num_views=1000,
|
1069 |
+
use_rand_idx_pe_for_non_reference_views=True,
|
1070 |
+
indices=[0, 2, 4],
|
1071 |
+
)
|
1072 |
+
model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
|
1073 |
+
additional_tokens = torch.rand(1, 1024, num_additional_tokens)
|
1074 |
+
model_input = MultiViewTransformerInput(features=model_input, additional_input_tokens=additional_tokens)
|
1075 |
+
output = model_ifr(model_input)
|
1076 |
+
assert isinstance(output, tuple)
|
1077 |
+
assert isinstance(output[0], MultiViewTransformerOutput)
|
1078 |
+
assert output[0].additional_token_features is not None
|
1079 |
+
assert output[0].additional_token_features.shape == (1, model_ifr.dim, num_additional_tokens)
|
1080 |
+
assert len(output[1]) == 3
|
1081 |
+
assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1])
|
1082 |
+
assert all(intermediate.additional_token_features is not None for intermediate in output[1])
|
1083 |
+
assert all(
|
1084 |
+
intermediate.additional_token_features.shape == (1, model_ifr.dim, num_additional_tokens)
|
1085 |
+
for intermediate in output[1]
|
1086 |
+
)
|
1087 |
+
|
1088 |
+
# Test additional input tokens for GlobalAttentionTransformer
|
1089 |
+
print("Testing GlobalAttentionTransformer with additional input tokens...")
|
1090 |
+
model = GlobalAttentionTransformer(
|
1091 |
+
name="GAT", input_embed_dim=1024, max_num_sets=1000, use_rand_idx_pe_for_non_reference_sets=False
|
1092 |
+
)
|
1093 |
+
num_sets = 3
|
1094 |
+
num_additional_tokens = 8
|
1095 |
+
model_input = [torch.rand(1, 1024, random.randint(256, 513)) for _ in range(num_sets)]
|
1096 |
+
additional_tokens = torch.rand(1, 1024, num_additional_tokens)
|
1097 |
+
model_input = MultiSetTransformerInput(features=model_input, additional_input_tokens=additional_tokens)
|
1098 |
+
model_output = model(model_input)
|
1099 |
+
assert len(model_output.features) == num_sets
|
1100 |
+
for feat, rand_input in zip(model_output.features, model_input.features):
|
1101 |
+
assert feat.shape[2] == rand_input.shape[2]
|
1102 |
+
assert feat.shape[1] == model.dim
|
1103 |
+
assert feat.shape[0] == rand_input.shape[0]
|
1104 |
+
assert model_output.additional_token_features is not None
|
1105 |
+
assert model_output.additional_token_features.shape == (1, model.dim, num_additional_tokens)
|
1106 |
+
|
1107 |
+
print("All tests using additional input tokens passed!")
|
UniCeption/uniception/models/libs/__init__.py
ADDED
File without changes
|