infinity1096 committed on
Commit c8b42eb · 1 Parent(s): 3991736

initial commit

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. .gitignore +152 -0
  3. LICENSE.txt +58 -0
  4. UniCeption/.gitignore +167 -0
  5. UniCeption/.pre-commit-config.yaml +18 -0
  6. UniCeption/.pylintrc +399 -0
  7. UniCeption/LICENSE +28 -0
  8. UniCeption/README.md +155 -0
  9. UniCeption/examples/models/cosmos/autoencoding.py +48 -0
  10. UniCeption/examples/models/cosmos/example.png +3 -0
  11. UniCeption/examples/models/cosmos/example_decoded.png +3 -0
  12. UniCeption/examples/models/dust3r/convert_dust3r_weights_to_uniception.py +331 -0
  13. UniCeption/examples/models/dust3r/dust3r.py +261 -0
  14. UniCeption/examples/models/dust3r/profile_dust3r.py +47 -0
  15. UniCeption/pyproject.toml +21 -0
  16. UniCeption/scripts/check_dependencies.py +49 -0
  17. UniCeption/scripts/download_checkpoints.py +48 -0
  18. UniCeption/scripts/install_croco_rope.py +62 -0
  19. UniCeption/scripts/prepare_offline_install.py +399 -0
  20. UniCeption/scripts/validate_installation.py +213 -0
  21. UniCeption/setup.py +188 -0
  22. UniCeption/tests/models/encoders/conftest.py +26 -0
  23. UniCeption/tests/models/encoders/test_encoders.py +204 -0
  24. UniCeption/tests/models/encoders/viz_image_encoders.py +294 -0
  25. UniCeption/tests/models/info_sharing/viz_mulit_view_cross_attn_transformers.py +337 -0
  26. UniCeption/uniception/__init__.py +0 -0
  27. UniCeption/uniception/models/encoders/README.md +129 -0
  28. UniCeption/uniception/models/encoders/__init__.py +235 -0
  29. UniCeption/uniception/models/encoders/base.py +157 -0
  30. UniCeption/uniception/models/encoders/cosmos.py +137 -0
  31. UniCeption/uniception/models/encoders/croco.py +457 -0
  32. UniCeption/uniception/models/encoders/dense_rep_encoder.py +344 -0
  33. UniCeption/uniception/models/encoders/dinov2.py +333 -0
  34. UniCeption/uniception/models/encoders/global_rep_encoder.py +115 -0
  35. UniCeption/uniception/models/encoders/image_normalizations.py +35 -0
  36. UniCeption/uniception/models/encoders/list.py +10 -0
  37. UniCeption/uniception/models/encoders/naradio.py +502 -0
  38. UniCeption/uniception/models/encoders/patch_embedder.py +235 -0
  39. UniCeption/uniception/models/encoders/radio.py +367 -0
  40. UniCeption/uniception/models/encoders/utils.py +86 -0
  41. UniCeption/uniception/models/factory/__init__.py +3 -0
  42. UniCeption/uniception/models/factory/dust3r.py +332 -0
  43. UniCeption/uniception/models/info_sharing/README.md +18 -0
  44. UniCeption/uniception/models/info_sharing/__init__.py +35 -0
  45. UniCeption/uniception/models/info_sharing/alternating_attention_transformer.py +944 -0
  46. UniCeption/uniception/models/info_sharing/base.py +116 -0
  47. UniCeption/uniception/models/info_sharing/cross_attention_transformer.py +582 -0
  48. UniCeption/uniception/models/info_sharing/diff_cross_attention_transformer.py +588 -0
  49. UniCeption/uniception/models/info_sharing/global_attention_transformer.py +1107 -0
  50. UniCeption/uniception/models/libs/__init__.py +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
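The one-line change above registers PNG files with Git LFS, which is why the example images added later in this commit appear as LFS pointers rather than raw image data. As an illustration only (not part of the commit itself), an equivalent attribute line is normally produced with the standard `git lfs track` command:

```bash
# Illustration: `git lfs track` appends the matching filter/diff/merge
# attributes for *.png to .gitattributes
git lfs track "*.png"
git add .gitattributes
```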
.gitignore ADDED
@@ -0,0 +1,152 @@
+ uniflowmatch.egg-info/**
+ ufm_model_refine/**
+ ufm_model/**
+ /home/inf/UniFlowMatch/convert_old_ckpt.py
+ checkpoints/**
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ **/__pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Profiling data
+ .prof
+
+ # Folder specific to your needs
+ **/tmp/
+ **/outputs/skyseg.onnx
+ skyseg.onnx
+
+ # pixi environments
+ .pixi
+ *.egg-info
LICENSE.txt ADDED
@@ -0,0 +1,58 @@
1
+ Attribution-NonCommercial 2.5 Generic
2
+ CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE LEGAL SERVICES. DISTRIBUTION OF THIS LICENSE DOES NOT CREATE AN ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES REGARDING THE INFORMATION PROVIDED, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM ITS USE.
3
+ License
4
+
5
+ THE WORK (AS DEFINED BELOW) IS PROVIDED UNDER THE TERMS OF THIS CREATIVE COMMONS PUBLIC LICENSE ("CCPL" OR "LICENSE"). THE WORK IS PROTECTED BY COPYRIGHT AND/OR OTHER APPLICABLE LAW. ANY USE OF THE WORK OTHER THAN AS AUTHORIZED UNDER THIS LICENSE OR COPYRIGHT LAW IS PROHIBITED.
6
+
7
+ BY EXERCISING ANY RIGHTS TO THE WORK PROVIDED HERE, YOU ACCEPT AND AGREE TO BE BOUND BY THE TERMS OF THIS LICENSE. THE LICENSOR GRANTS YOU THE RIGHTS CONTAINED HERE IN CONSIDERATION OF YOUR ACCEPTANCE OF SUCH TERMS AND CONDITIONS.
8
+
9
+ 1. Definitions
10
+
11
+ "Collective Work" means a work, such as a periodical issue, anthology or encyclopedia, in which the Work in its entirety in unmodified form, along with a number of other contributions, constituting separate and independent works in themselves, are assembled into a collective whole. A work that constitutes a Collective Work will not be considered a Derivative Work (as defined below) for the purposes of this License.
12
+ "Derivative Work" means a work based upon the Work or upon the Work and other pre-existing works, such as a translation, musical arrangement, dramatization, fictionalization, motion picture version, sound recording, art reproduction, abridgment, condensation, or any other form in which the Work may be recast, transformed, or adapted, except that a work that constitutes a Collective Work will not be considered a Derivative Work for the purpose of this License. For the avoidance of doubt, where the Work is a musical composition or sound recording, the synchronization of the Work in timed-relation with a moving image ("synching") will be considered a Derivative Work for the purpose of this License.
13
+ "Licensor" means the individual or entity that offers the Work under the terms of this License.
14
+ "Original Author" means the individual or entity who created the Work.
15
+ "Work" means the copyrightable work of authorship offered under the terms of this License.
16
+ "You" means an individual or entity exercising rights under this License who has not previously violated the terms of this License with respect to the Work, or who has received express permission from the Licensor to exercise rights under this License despite a previous violation.
17
+ 2. Fair Use Rights. Nothing in this license is intended to reduce, limit, or restrict any rights arising from fair use, first sale or other limitations on the exclusive rights of the copyright owner under copyright law or other applicable laws.
18
+
19
+ 3. License Grant. Subject to the terms and conditions of this License, Licensor hereby grants You a worldwide, royalty-free, non-exclusive, perpetual (for the duration of the applicable copyright) license to exercise the rights in the Work as stated below:
20
+
21
+ to reproduce the Work, to incorporate the Work into one or more Collective Works, and to reproduce the Work as incorporated in the Collective Works;
22
+ to create and reproduce Derivative Works;
23
+ to distribute copies or phonorecords of, display publicly, perform publicly, and perform publicly by means of a digital audio transmission the Work including as incorporated in Collective Works;
24
+ to distribute copies or phonorecords of, display publicly, perform publicly, and perform publicly by means of a digital audio transmission Derivative Works;
25
+ The above rights may be exercised in all media and formats whether now known or hereafter devised. The above rights include the right to make such modifications as are technically necessary to exercise the rights in other media and formats. All rights not expressly granted by Licensor are hereby reserved, including but not limited to the rights set forth in Sections 4(d) and 4(e).
26
+
27
+ 4. Restrictions. The license granted in Section 3 above is expressly made subject to and limited by the following restrictions:
28
+
29
+ You may distribute, publicly display, publicly perform, or publicly digitally perform the Work only under the terms of this License, and You must include a copy of, or the Uniform Resource Identifier for, this License with every copy or phonorecord of the Work You distribute, publicly display, publicly perform, or publicly digitally perform. You may not offer or impose any terms on the Work that alter or restrict the terms of this License or the recipients' exercise of the rights granted hereunder. You may not sublicense the Work. You must keep intact all notices that refer to this License and to the disclaimer of warranties. You may not distribute, publicly display, publicly perform, or publicly digitally perform the Work with any technological measures that control access or use of the Work in a manner inconsistent with the terms of this License Agreement. The above applies to the Work as incorporated in a Collective Work, but this does not require the Collective Work apart from the Work itself to be made subject to the terms of this License. If You create a Collective Work, upon notice from any Licensor You must, to the extent practicable, remove from the Collective Work any credit as required by clause 4(c), as requested. If You create a Derivative Work, upon notice from any Licensor You must, to the extent practicable, remove from the Derivative Work any credit as required by clause 4(c), as requested.
30
+ You may not exercise any of the rights granted to You in Section 3 above in any manner that is primarily intended for or directed toward commercial advantage or private monetary compensation. The exchange of the Work for other copyrighted works by means of digital file-sharing or otherwise shall not be considered to be intended for or directed toward commercial advantage or private monetary compensation, provided there is no payment of any monetary compensation in connection with the exchange of copyrighted works.
31
+ If you distribute, publicly display, publicly perform, or publicly digitally perform the Work or any Derivative Works or Collective Works, You must keep intact all copyright notices for the Work and provide, reasonable to the medium or means You are utilizing: (i) the name of Original Author (or pseudonym, if applicable) if supplied, and/or (ii) if the Original Author and/or Licensor designate another party or parties (e.g. a sponsor institute, publishing entity, journal) for attribution in Licensor's copyright notice, terms of service or by other reasonable means, the name of such party or parties; the title of the Work if supplied; to the extent reasonably practicable, the Uniform Resource Identifier, if any, that Licensor specifies to be associated with the Work, unless such URI does not refer to the copyright notice or licensing information for the Work; and in the case of a Derivative Work, a credit identifying the use of the Work in the Derivative Work (e.g., "French translation of the Work by Original Author," or "Screenplay based on original Work by Original Author"). Such credit may be implemented in any reasonable manner; provided, however, that in the case of a Derivative Work or Collective Work, at a minimum such credit will appear where any other comparable authorship credit appears and in a manner at least as prominent as such other comparable authorship credit.
32
+ For the avoidance of doubt, where the Work is a musical composition:
33
+
34
+ Performance Royalties Under Blanket Licenses . Licensor reserves the exclusive right to collect, whether individually or via a performance rights society (e.g. ASCAP, BMI, SESAC), royalties for the public performance or public digital performance (e.g. webcast) of the Work if that performance is primarily intended for or directed toward commercial advantage or private monetary compensation.
35
+ Mechanical Rights and Statutory Royalties . Licensor reserves the exclusive right to collect, whether individually or via a music rights agency or designated agent (e.g. Harry Fox Agency), royalties for any phonorecord You create from the Work ("cover version") and distribute, subject to the compulsory license created by 17 USC Section 115 of the US Copyright Act (or the equivalent in other jurisdictions), if Your distribution of such cover version is primarily intended for or directed toward commercial advantage or private monetary compensation.
36
+ Webcasting Rights and Statutory Royalties. For the avoidance of doubt, where the Work is a sound recording, Licensor reserves the exclusive right to collect, whether individually or via a performance-rights society (e.g. SoundExchange), royalties for the public digital performance (e.g. webcast) of the Work, subject to the compulsory license created by 17 USC Section 114 of the US Copyright Act (or the equivalent in other jurisdictions), if Your public digital performance is primarily intended for or directed toward commercial advantage or private monetary compensation.
37
+ 5. Representations, Warranties and Disclaimer
38
+
39
+ UNLESS OTHERWISE MUTUALLY AGREED TO BY THE PARTIES IN WRITING, LICENSOR OFFERS THE WORK AS-IS AND MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE WORK, EXPRESS, IMPLIED, STATUTORY OR OTHERWISE, INCLUDING, WITHOUT LIMITATION, WARRANTIES OF TITLE, MERCHANTIBILITY, FITNESS FOR A PARTICULAR PURPOSE, NONINFRINGEMENT, OR THE ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OF ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE. SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OF IMPLIED WARRANTIES, SO SUCH EXCLUSION MAY NOT APPLY TO YOU.
40
+
41
+ 6. Limitation on Liability. EXCEPT TO THE EXTENT REQUIRED BY APPLICABLE LAW, IN NO EVENT WILL LICENSOR BE LIABLE TO YOU ON ANY LEGAL THEORY FOR ANY SPECIAL, INCIDENTAL, CONSEQUENTIAL, PUNITIVE OR EXEMPLARY DAMAGES ARISING OUT OF THIS LICENSE OR THE USE OF THE WORK, EVEN IF LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
42
+
43
+ 7. Termination
44
+
45
+ This License and the rights granted hereunder will terminate automatically upon any breach by You of the terms of this License. Individuals or entities who have received Derivative Works or Collective Works from You under this License, however, will not have their licenses terminated provided such individuals or entities remain in full compliance with those licenses. Sections 1, 2, 5, 6, 7, and 8 will survive any termination of this License.
46
+ Subject to the above terms and conditions, the license granted here is perpetual (for the duration of the applicable copyright in the Work). Notwithstanding the above, Licensor reserves the right to release the Work under different license terms or to stop distributing the Work at any time; provided, however that any such election will not serve to withdraw this License (or any other license that has been, or is required to be, granted under the terms of this License), and this License will continue in full force and effect unless terminated as stated above.
47
+ 8. Miscellaneous
48
+
49
+ Each time You distribute or publicly digitally perform the Work or a Collective Work, the Licensor offers to the recipient a license to the Work on the same terms and conditions as the license granted to You under this License.
50
+ Each time You distribute or publicly digitally perform a Derivative Work, Licensor offers to the recipient a license to the original Work on the same terms and conditions as the license granted to You under this License.
51
+ If any provision of this License is invalid or unenforceable under applicable law, it shall not affect the validity or enforceability of the remainder of the terms of this License, and without further action by the parties to this agreement, such provision shall be reformed to the minimum extent necessary to make such provision valid and enforceable.
52
+ No term or provision of this License shall be deemed waived and no breach consented to unless such waiver or consent shall be in writing and signed by the party to be charged with such waiver or consent.
53
+ This License constitutes the entire agreement between the parties with respect to the Work licensed here. There are no understandings, agreements or representations with respect to the Work not specified here. Licensor shall not be bound by any additional provisions that may appear in any communication from You. This License may not be modified without the mutual written agreement of the Licensor and You.
54
+ Creative Commons is not a party to this License, and makes no warranty whatsoever in connection with the Work. Creative Commons will not be liable to You or any party on any legal theory for any damages whatsoever, including without limitation any general, special, incidental or consequential damages arising in connection to this license. Notwithstanding the foregoing two (2) sentences, if Creative Commons has expressly identified itself as the Licensor hereunder, it shall have all rights and obligations of Licensor.
55
+
56
+ Except for the limited purpose of indicating to the public that the Work is licensed under the CCPL, neither party will use the trademark "Creative Commons" or any related trademark or logo of Creative Commons without the prior written consent of Creative Commons. Any permitted use will be in compliance with Creative Commons' then-current trademark usage guidelines, as may be published on its website or otherwise made available upon request from time to time.
57
+
58
+ Creative Commons may be contacted at https://creativecommons.org/ .
UniCeption/.gitignore ADDED
@@ -0,0 +1,167 @@
+ # Local Folders
+ checkpoints/
+ local/
+ reference_data/
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
UniCeption/.pre-commit-config.yaml ADDED
@@ -0,0 +1,18 @@
+ # See https://pre-commit.com for more information
+ # See https://pre-commit.com/hooks.html for more hooks
+ default_language_version:
+   python: python3
+ repos:
+   - repo: https://github.com/pre-commit/pre-commit-hooks
+     rev: v3.2.0
+     hooks:
+       - id: trailing-whitespace
+       - id: end-of-file-fixer
+   - repo: https://github.com/pre-commit/mirrors-isort
+     rev: 'v5.10.1'
+     hooks:
+       - id: isort
+   - repo: https://github.com/psf/black
+     rev: '23.3.0'
+     hooks:
+       - id: black
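The config above wires up whitespace fixers plus `isort` and `black`, matching the linting guidance in the README. As a short usage sketch using the standard pre-commit CLI (these are generic pre-commit commands, not scripts defined by this repository):

```bash
# Install the git hooks declared in .pre-commit-config.yaml
pre-commit install

# Run every hook once over the whole repository
pre-commit run --all-files
```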
UniCeption/.pylintrc ADDED
@@ -0,0 +1,399 @@
1
+ # This Pylint rcfile contains a best-effort configuration to uphold the
2
+ # best-practices and style described in the Google Python style guide:
3
+ # https://google.github.io/styleguide/pyguide.html
4
+ #
5
+ # Its canonical open-source location is:
6
+ # https://google.github.io/styleguide/pylintrc
7
+
8
+ [MAIN]
9
+
10
+ # Files or directories to be skipped. They should be base names, not paths.
11
+ ignore=third_party
12
+
13
+ # Files or directories matching the regex patterns are skipped. The regex
14
+ # matches against base names, not paths.
15
+ ignore-patterns=
16
+
17
+ # Pickle collected data for later comparisons.
18
+ persistent=no
19
+
20
+ # List of plugins (as comma separated values of python modules names) to load,
21
+ # usually to register additional checkers.
22
+ load-plugins=
23
+
24
+ # Use multiple processes to speed up Pylint.
25
+ jobs=4
26
+
27
+ # Allow loading of arbitrary C extensions. Extensions are imported into the
28
+ # active Python interpreter and may run arbitrary code.
29
+ unsafe-load-any-extension=no
30
+
31
+
32
+ [MESSAGES CONTROL]
33
+
34
+ # Only show warnings with the listed confidence levels. Leave empty to show
35
+ # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
36
+ confidence=
37
+
38
+ # Enable the message, report, category or checker with the given id(s). You can
39
+ # either give multiple identifier separated by comma (,) or put this option
40
+ # multiple time (only on the command line, not in the configuration file where
41
+ # it should appear only once). See also the "--disable" option for examples.
42
+ #enable=
43
+
44
+ # Disable the message, report, category or checker with the given id(s). You
45
+ # can either give multiple identifiers separated by comma (,) or put this
46
+ # option multiple times (only on the command line, not in the configuration
47
+ # file where it should appear only once).You can also use "--disable=all" to
48
+ # disable everything first and then reenable specific checks. For example, if
49
+ # you want to run only the similarities checker, you can use "--disable=all
50
+ # --enable=similarities". If you want to run only the classes checker, but have
51
+ # no Warning level messages displayed, use"--disable=all --enable=classes
52
+ # --disable=W"
53
+ disable=R,
54
+ abstract-method,
55
+ apply-builtin,
56
+ arguments-differ,
57
+ attribute-defined-outside-init,
58
+ backtick,
59
+ bad-option-value,
60
+ basestring-builtin,
61
+ buffer-builtin,
62
+ c-extension-no-member,
63
+ consider-using-enumerate,
64
+ cmp-builtin,
65
+ cmp-method,
66
+ coerce-builtin,
67
+ coerce-method,
68
+ delslice-method,
69
+ div-method,
70
+ eq-without-hash,
71
+ execfile-builtin,
72
+ file-builtin,
73
+ filter-builtin-not-iterating,
74
+ fixme,
75
+ getslice-method,
76
+ global-statement,
77
+ hex-method,
78
+ idiv-method,
79
+ implicit-str-concat,
80
+ import-error,
81
+ import-self,
82
+ import-star-module-level,
83
+ input-builtin,
84
+ intern-builtin,
85
+ invalid-str-codec,
86
+ locally-disabled,
87
+ long-builtin,
88
+ long-suffix,
89
+ map-builtin-not-iterating,
90
+ misplaced-comparison-constant,
91
+ missing-function-docstring,
92
+ metaclass-assignment,
93
+ next-method-called,
94
+ next-method-defined,
95
+ no-absolute-import,
96
+ no-init, # added
97
+ no-member,
98
+ no-name-in-module,
99
+ no-self-use,
100
+ nonzero-method,
101
+ oct-method,
102
+ old-division,
103
+ old-ne-operator,
104
+ old-octal-literal,
105
+ old-raise-syntax,
106
+ parameter-unpacking,
107
+ print-statement,
108
+ raising-string,
109
+ range-builtin-not-iterating,
110
+ raw_input-builtin,
111
+ rdiv-method,
112
+ reduce-builtin,
113
+ relative-import,
114
+ reload-builtin,
115
+ round-builtin,
116
+ setslice-method,
117
+ signature-differs,
118
+ standarderror-builtin,
119
+ suppressed-message,
120
+ sys-max-int,
121
+ trailing-newlines,
122
+ unichr-builtin,
123
+ unicode-builtin,
124
+ unnecessary-pass,
125
+ unpacking-in-except,
126
+ useless-else-on-loop,
127
+ useless-suppression,
128
+ using-cmp-argument,
129
+ wrong-import-order,
130
+ xrange-builtin,
131
+ zip-builtin-not-iterating,
132
+
133
+
134
+ [REPORTS]
135
+
136
+ # Set the output format. Available formats are text, parseable, colorized, msvs
137
+ # (visual studio) and html. You can also give a reporter class, eg
138
+ # mypackage.mymodule.MyReporterClass.
139
+ output-format=text
140
+
141
+ # Tells whether to display a full report or only the messages
142
+ reports=no
143
+
144
+ # Python expression which should return a note less than 10 (10 is the highest
145
+ # note). You have access to the variables errors warning, statement which
146
+ # respectively contain the number of errors / warnings messages and the total
147
+ # number of statements analyzed. This is used by the global evaluation report
148
+ # (RP0004).
149
+ evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
150
+
151
+ # Template used to display messages. This is a python new-style format string
152
+ # used to format the message information. See doc for all details
153
+ #msg-template=
154
+
155
+
156
+ [BASIC]
157
+
158
+ # Good variable names which should always be accepted, separated by a comma
159
+ good-names=main,_
160
+
161
+ # Bad variable names which should always be refused, separated by a comma
162
+ bad-names=
163
+
164
+ # Colon-delimited sets of names that determine each other's naming style when
165
+ # the name regexes allow several styles.
166
+ name-group=
167
+
168
+ # Include a hint for the correct naming format with invalid-name
169
+ include-naming-hint=no
170
+
171
+ # List of decorators that produce properties, such as abc.abstractproperty. Add
172
+ # to this list to register other decorators that produce valid properties.
173
+ property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
174
+
175
+ # Regular expression matching correct function names
176
+ function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
177
+
178
+ # Regular expression matching correct variable names
179
+ variable-rgx=^[a-z][a-z0-9_]*$
180
+
181
+ # Regular expression matching correct constant names
182
+ const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
183
+
184
+ # Regular expression matching correct attribute names
185
+ attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
186
+
187
+ # Regular expression matching correct argument names
188
+ argument-rgx=^[a-z][a-z0-9_]*$
189
+
190
+ # Regular expression matching correct class attribute names
191
+ class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
192
+
193
+ # Regular expression matching correct inline iteration names
194
+ inlinevar-rgx=^[a-z][a-z0-9_]*$
195
+
196
+ # Regular expression matching correct class names
197
+ class-rgx=^_?[A-Z][a-zA-Z0-9]*$
198
+
199
+ # Regular expression matching correct module names
200
+ module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
201
+
202
+ # Regular expression matching correct method names
203
+ method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
204
+
205
+ # Regular expression which should only match function or class names that do
206
+ # not require a docstring.
207
+ no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
208
+
209
+ # Minimum line length for functions/classes that require docstrings, shorter
210
+ # ones are exempt.
211
+ docstring-min-length=12
212
+
213
+
214
+ [TYPECHECK]
215
+
216
+ # List of decorators that produce context managers, such as
217
+ # contextlib.contextmanager. Add to this list to register other decorators that
218
+ # produce valid context managers.
219
+ contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
220
+
221
+ # List of module names for which member attributes should not be checked
222
+ # (useful for modules/projects where namespaces are manipulated during runtime
223
+ # and thus existing member attributes cannot be deduced by static analysis. It
224
+ # supports qualified module names, as well as Unix pattern matching.
225
+ ignored-modules=
226
+
227
+ # List of class names for which member attributes should not be checked (useful
228
+ # for classes with dynamically set attributes). This supports the use of
229
+ # qualified names.
230
+ ignored-classes=optparse.Values,thread._local,_thread._local
231
+
232
+ # List of members which are set dynamically and missed by pylint inference
233
+ # system, and so shouldn't trigger E1101 when accessed. Python regular
234
+ # expressions are accepted.
235
+ generated-members=
236
+
237
+
238
+ [FORMAT]
239
+
240
+ # Maximum number of characters on a single line.
241
+ max-line-length=80
242
+
243
+ # TODO(https://github.com/pylint-dev/pylint/issues/3352): Direct pylint to exempt
244
+ # lines made too long by directives to pytype.
245
+
246
+ # Regexp for a line that is allowed to be longer than the limit.
247
+ ignore-long-lines=(?x)(
248
+ ^\s*(\#\ )?<?https?://\S+>?$|
249
+ ^\s*(from\s+\S+\s+)?import\s+.+$)
250
+
251
+ # Allow the body of an if to be on the same line as the test if there is no
252
+ # else.
253
+ single-line-if-stmt=yes
254
+
255
+ # Maximum number of lines in a module
256
+ max-module-lines=99999
257
+
258
+ # String used as indentation unit. The internal Google style guide mandates 2
259
+ # spaces. Google's externaly-published style guide says 4, consistent with
260
+ # PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
261
+ # projects (like TensorFlow).
262
+ indent-string=' '
263
+
264
+ # Number of spaces of indent required inside a hanging or continued line.
265
+ indent-after-paren=4
266
+
267
+ # Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
268
+ expected-line-ending-format=
269
+
270
+
271
+ [MISCELLANEOUS]
272
+
273
+ # List of note tags to take in consideration, separated by a comma.
274
+ notes=TODO
275
+
276
+
277
+ [STRING]
278
+
279
+ # This flag controls whether inconsistent-quotes generates a warning when the
280
+ # character used as a quote delimiter is used inconsistently within a module.
281
+ check-quote-consistency=yes
282
+
283
+
284
+ [VARIABLES]
285
+
286
+ # Tells whether we should check for unused import in __init__ files.
287
+ init-import=no
288
+
289
+ # A regular expression matching the name of dummy variables (i.e. expectedly
290
+ # not used).
291
+ dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
292
+
293
+ # List of additional names supposed to be defined in builtins. Remember that
294
+ # you should avoid to define new builtins when possible.
295
+ additional-builtins=
296
+
297
+ # List of strings which can identify a callback function by name. A callback
298
+ # name must start or end with one of those strings.
299
+ callbacks=cb_,_cb
300
+
301
+ # List of qualified module names which can have objects that can redefine
302
+ # builtins.
303
+ redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
304
+
305
+
306
+ [LOGGING]
307
+
308
+ # Logging modules to check that the string format arguments are in logging
309
+ # function parameter format
310
+ logging-modules=logging,absl.logging,tensorflow.io.logging
311
+
312
+
313
+ [SIMILARITIES]
314
+
315
+ # Minimum lines number of a similarity.
316
+ min-similarity-lines=4
317
+
318
+ # Ignore comments when computing similarities.
319
+ ignore-comments=yes
320
+
321
+ # Ignore docstrings when computing similarities.
322
+ ignore-docstrings=yes
323
+
324
+ # Ignore imports when computing similarities.
325
+ ignore-imports=no
326
+
327
+
328
+ [SPELLING]
329
+
330
+ # Spelling dictionary name. Available dictionaries: none. To make it working
331
+ # install python-enchant package.
332
+ spelling-dict=
333
+
334
+ # List of comma separated words that should not be checked.
335
+ spelling-ignore-words=
336
+
337
+ # A path to a file that contains private dictionary; one word per line.
338
+ spelling-private-dict-file=
339
+
340
+ # Tells whether to store unknown words to indicated private dictionary in
341
+ # --spelling-private-dict-file option instead of raising a message.
342
+ spelling-store-unknown-words=no
343
+
344
+
345
+ [IMPORTS]
346
+
347
+ # Deprecated modules which should not be used, separated by a comma
348
+ deprecated-modules=regsub,
349
+ TERMIOS,
350
+ Bastion,
351
+ rexec,
352
+ sets
353
+
354
+ # Create a graph of every (i.e. internal and external) dependencies in the
355
+ # given file (report RP0402 must not be disabled)
356
+ import-graph=
357
+
358
+ # Create a graph of external dependencies in the given file (report RP0402 must
359
+ # not be disabled)
360
+ ext-import-graph=
361
+
362
+ # Create a graph of internal dependencies in the given file (report RP0402 must
363
+ # not be disabled)
364
+ int-import-graph=
365
+
366
+ # Force import order to recognize a module as part of the standard
367
+ # compatibility libraries.
368
+ known-standard-library=
369
+
370
+ # Force import order to recognize a module as part of a third party library.
371
+ known-third-party=enchant, absl
372
+
373
+ # Analyse import fallback blocks. This can be used to support both Python 2 and
374
+ # 3 compatible code, which means that the block might have code that exists
375
+ # only in one or another interpreter, leading to false positives when analysed.
376
+ analyse-fallback-blocks=no
377
+
378
+
379
+ [CLASSES]
380
+
381
+ # List of method names used to declare (i.e. assign) instance attributes.
382
+ defining-attr-methods=__init__,
383
+ __new__,
384
+ setUp
385
+
386
+ # List of member names, which should be excluded from the protected access
387
+ # warning.
388
+ exclude-protected=_asdict,
389
+ _fields,
390
+ _replace,
391
+ _source,
392
+ _make
393
+
394
+ # List of valid names for the first argument in a class method.
395
+ valid-classmethod-first-arg=cls,
396
+ class_
397
+
398
+ # List of valid names for the first argument in a metaclass class method.
399
+ valid-metaclass-classmethod-first-arg=mcs
UniCeption/LICENSE ADDED
@@ -0,0 +1,28 @@
+ BSD 3-Clause License
+
+ Copyright (c) 2024, AirLab Stacks
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice, this
+    list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright notice,
+    this list of conditions and the following disclaimer in the documentation
+    and/or other materials provided with the distribution.
+
+ 3. Neither the name of the copyright holder nor the names of its
+    contributors may be used to endorse or promote products derived from
+    this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
UniCeption/README.md ADDED
@@ -0,0 +1,155 @@
+ # UniCeption
+
+ UniCeption houses modular building blocks for developing and training generalizable perception models for all things related to 3D, 4D, spatial AI and scene understanding.
+ It is designed to be flexible and extensible, allowing researchers to easily experiment with different architectures and configurations.
+
+ Please refer to the [Developer Guidelines](#developer-guidelines) for contributing to the project.
+
+ ## Installation
+
+ Clone the repository to your local machine by running the following command:
+
+ ```bash
+ git clone [email protected]:castacks/UniCeption.git
+ cd UniCeption
+ ```
+
+ ### Standard Installation
+
+ Install the `uniception` package in development mode by running the following commands:
+
+ ```bash
+ # Please use Conda or Python Virtual Environment based on your preference
+ # For Conda Environment
+ conda create --name uniception python=3.12
+ conda activate uniception
+ # For Python Virtual Environment
+ virtualenv uniception
+ source uniception/bin/activate
+
+ # Install UniCeption with base dependencies (includes PyTorch)
+ pip install -e .
+
+ # Optional: Install with XFormers support
+ pip install -e ".[xformers]"
+
+ # Optional: Install with development tools
+ pip install -e ".[dev]"
+
+ # Optional: Install all optional dependencies
+ pip install -e ".[all]"
+
+ # Setup pre-commit hooks for development
+ pre-commit install
+ ```
+
+ ### Optional: CroCo RoPE Extension Installation
+
+ To use CroCo models with the custom RoPE kernel:
+
+ ```bash
+ # Recommended: Use the console script
+ uniception-install-croco
+
+ # Alternative: Set environment variable during installation
+ INSTALL_CROCO_ROPE=true pip install -e .
+
+ # Manual compilation (if needed)
+ cd uniception/models/libs/croco/curope
+ python setup.py build_ext --inplace
+ cd ../../../../../
+ ```
+
+ ### Installation Validation and Dependency Checking
+
+ After installation, use these console scripts to validate your setup:
+
+ ```bash
+ # Validate installation and check dependencies
+ uniception-validate
+
+ # Check which optional dependencies are available
+ uniception-check-deps
+ ```
+
+ ### Advanced Installation Options
+
+ #### Docker Installation (No Internet Access)
+
+ If you're working in a Docker container that already has Python dependencies installed but no internet access, you can install UniCeption in development mode without triggering network requests:
+
+ ```bash
+ # Install only the package structure without dependencies
+ pip install -e . --no-deps
+ ```
+
+ **Note:** This command assumes your Docker image already contains all required dependencies (PyTorch, etc.). Use `uniception-validate` after installation to verify all dependencies are available.
+
+ #### Offline Installation
+
+ For environments without internet access:
+
+ ```bash
+ # 1. On a machine with internet access, prepare offline wheels
+ uniception-prepare-offline --output-dir offline_wheels --extras all
+
+ # 2. Copy the offline_wheels directory to your offline environment
+ # 3. Run the offline installation
+ cd offline_wheels
+ INSTALL_CROCO_ROPE=true INSTALL_XFORMERS=true ./install_offline.sh
+ ```
+
+ #### Downloading Checkpoints
+
+ Download UniCeption format custom checkpoints:
+
+ ```bash
+ # Download all available checkpoints
+ uniception-download-checkpoints
+
+ # Download specific folders only (e.g., encoders and prediction heads)
+ uniception-download-checkpoints --folders encoders prediction_heads
+
+ # Specify custom destination
+ uniception-download-checkpoints --destination /path/to/checkpoints
+ ```
+
+ **Available options:**
+ - `--folders`: Specify which folders to download. Choices: `encoders`, `info_sharing`, `prediction_heads`, `examples` (default: all folders)
+ - `--destination`: Custom destination folder for downloaded checkpoints (default: current directory)
+
+ ---
+
+ ## Currently Supported Components
+
+ ### Encoders
+
+ Please refer to the `uniception/models/encoders` directory for the supported encoders and documentation for adding new encoders. The supported encoders can be listed by running:
+
+ ```bash
+ python3 -m uniception.models.encoders.list
+ ```
+
+ ---
+
+ ## Information Sharing Blocks
+
+ Please refer to the `uniception/models/info_sharing` directory for the supported information sharing blocks.
+
+ ---
+
+ ## Prediction Heads
+
+ Please refer to the `uniception/models/prediction_heads` directory for the supported prediction heads.
+
+ ---
+
+ ## Developer Guidelines
+
+ Please follow these guidelines when contributing to UniCeption:
+ - **Code Style**: Follow the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html) for code style.
+ - **Documentation**: Add docstrings to all classes and methods.
+ - **Unit Tests**: Add necessary unit tests to the `tests` folder.
+ - **Linting**: Run `black` & `isort` on your code before committing. For example, you can run `black . && isort .`.
+
+ Please create a pull request for any changes you make, and ensure that all tests pass before merging.
UniCeption/examples/models/cosmos/autoencoding.py ADDED
@@ -0,0 +1,48 @@
+ import os
+
+ import cv2
+ import torch
+ from matplotlib import pyplot as plt
+
+ from uniception.models.encoders.base import ViTEncoderInput
+ from uniception.models.encoders.cosmos import CosmosEncoder
+ from uniception.models.prediction_heads.cosmos import CosmosSingleChannel
+
+ base_path = os.path.dirname(os.path.abspath(__file__))
+
+ encoder = CosmosEncoder(
+     name="cosmos",
+     patch_size=8,
+     pretrained_checkpoint_path=os.path.join(
+         base_path, "../../../checkpoints/encoders/cosmos/Cosmos-Tokenizer-CI8x8/encoder.pth"
+     ),
+ )
+
+ decoder = CosmosSingleChannel(
+     patch_size=8,
+     pretrained_checkpoint_path=os.path.join(base_path, "../../../checkpoints/prediction_heads/cosmos/decoder_8.pth"),
+ )
+
+ example_image = cv2.imread(os.path.join(base_path, "./example.png"))
+ example_image = cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB)
+ example_tensor = torch.tensor(example_image).permute(2, 0, 1).unsqueeze(0).float() / 255.0
+ example_tensor = example_tensor * 2.0 - 1.0  # Normalize to [-1, 1] according to the COSMOS Encoder
+
+ encoded_latent = encoder(ViTEncoderInput("cosmos", example_tensor)).features
+
+ decoded_image = decoder(encoded_latent)
+ decoded_image = (decoded_image + 1.0) / 2.0  # Denormalize to [0, 1] for visualization
+
+ # plot the original and decoded images
+ plt.figure(figsize=(10, 5))
+ plt.subplot(1, 2, 1)
+ plt.imshow(example_image)
+ plt.title("Original Image")
+ plt.axis("off")
+
+ plt.subplot(1, 2, 2)
+ plt.imshow(decoded_image.squeeze().detach().permute(1, 2, 0).cpu().numpy())
+ plt.title("Decoded Image")
+ plt.axis("off")
+
+ plt.savefig(os.path.join(base_path, "example_decoded.png"))
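The script above resolves its checkpoint paths relative to its own location (`../../../checkpoints/...`, i.e. a `checkpoints/` folder at the UniCeption root). A plausible way to run it, assuming the README's checkpoint download script provides the COSMOS encoder and decoder weights under that folder (this invocation is a sketch, not documented in the commit):

```bash
# Assumption: run from the UniCeption root so checkpoints/ lands in the expected place
uniception-download-checkpoints --folders encoders prediction_heads
python examples/models/cosmos/autoencoding.py
```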
UniCeption/examples/models/cosmos/example.png ADDED

Git LFS Details

  • SHA256: 5e6ee5528f76e5c0794e2708d688877b0f06f2139a11e883a3832ad57f19f89c
  • Pointer size: 131 Bytes
  • Size of remote file: 711 kB
UniCeption/examples/models/cosmos/example_decoded.png ADDED

Git LFS Details

  • SHA256: f948b50b602260352e14fca5f51999f01bd98b8e167dd1595451418380eaed21
  • Pointer size: 131 Bytes
  • Size of remote file: 348 kB
UniCeption/examples/models/dust3r/convert_dust3r_weights_to_uniception.py ADDED
@@ -0,0 +1,331 @@
+ """
+ This file extracts the cross-attention transformer & prediction head weights from DUSt3R checkpoints into UniCeption format.
+
+ Special Notice: the DUSt3R authors changed their released weights before/after CVPR, and
+ UniCeption uses the checkpoints from BEFORE CVPR (they perform better). So please make sure you are not converting
+ the newly downloaded weights. Consult Yuchen and Nikhil on where to find the old weights.
+ """
+
+ import argparse
+ import os
+
+ import torch
+ from torch import nn
+
+ from uniception.models.info_sharing.cross_attention_transformer import MultiViewCrossAttentionTransformerIFR
+ from uniception.models.prediction_heads.dpt import DPTFeature, DPTRegressionProcessor
+ from uniception.models.prediction_heads.linear import LinearFeature
+
+
+ def extract_cross_attention_weights(checkpoint_path, output_folder, output_filename):
+     "Extract the UniCeption format cross attention weights from the original CroCoV2/DUSt3R/MASt3R checkpoints."
+     # Load checkpoint
+     checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
+
+     # Filter the relevant keys for the cross attention model and duplicate if necessary
+     filtered_checkpoint = checkpoint["model"]
+     filtered_checkpoint = {k: v for k, v in filtered_checkpoint.items() if "dec" in k}
+     duplicate_checkpoint = {}
+     if not any(k.startswith("dec_blocks2") for k in filtered_checkpoint):
+         print("Duplicating dec_blocks to dec_blocks2")
+         for key, value in filtered_checkpoint.items():
+             if key.startswith("dec_blocks"):
+                 duplicate_checkpoint[key.replace("dec_blocks", "dec_blocks2")] = value
+     filtered_checkpoint = {**filtered_checkpoint, **duplicate_checkpoint}
+     new_checkpoint = {}
+     for k, v in filtered_checkpoint.items():
+         if "decoder_embed" in k:
+             new_key = k.replace("decoder_embed", "proj_embed")
+             new_checkpoint[new_key] = v
+         elif "dec_blocks." in k:
+             new_key = k.replace("dec_blocks.", "multi_view_branches.0.")
+             new_checkpoint[new_key] = v
+         elif "dec_blocks2." in k:
+             new_key = k.replace("dec_blocks2.", "multi_view_branches.1.")
+             new_checkpoint[new_key] = v
+         elif "dec_norm" in k:
+             new_key = k.replace("dec_norm", "norm")
+             new_checkpoint[new_key] = v
+
+     # Init model
+     model = MultiViewCrossAttentionTransformerIFR(
+         name="MV-CAT-IFR",
+         input_embed_dim=1024,
+         num_views=2,
+         indices=[5, 8],
+         norm_intermediate=False,
+     )
+
+     # Load new checkpoint
+     print(model.load_state_dict(new_checkpoint))
+
+     # Save the checkpoint
+     save_checkpoint = {}
+     save_checkpoint["model"] = model.state_dict()
+     os.makedirs(os.path.join(output_folder, "cross_attn_transformer"), exist_ok=True)
+     save_path = os.path.join(output_folder, "cross_attn_transformer", output_filename)
+     torch.save(save_checkpoint, save_path)
+
+
+ def extract_dust3r_dpt_checkpoints(checkpoint_path, output_folder, output_filename):
+     "Extract the UniCeption format DPT head weights from the original DUSt3R checkpoint."
+     source_ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
+
+     for head in ["head1", "head2"]:
+         # Extract head weights from the checkpoint
+         dpt_head_weights = {k: v for k, v in source_ckpt["model"].items() if k.startswith(f"downstream_{head}")}
+         dpt_head_weights = {k.replace(f"downstream_{head}.dpt.", ""): v for k, v in dpt_head_weights.items()}
+         dpt_feature_weights = {k: v for k, v in dpt_head_weights.items() if not (k.startswith("head"))}
+
+         # Construct the DPTFeature module and load the weights
+         dpt = DPTFeature(
+             patch_size=16,
+             hooks=[0, 1, 2, 3],
+             input_feature_dims=[1024, 768, 768, 768],
+             layer_dims=[96, 192, 384, 768],
+             feature_dim=256,
+             use_bn=False,
+             output_width_ratio=1,
+         )
+
+         dpt.load_state_dict(dpt_feature_weights, strict=True)
+
+         # Construct the dpt processor module and load the weights
+         dpt_processor_weights = {k.replace("head.", ""): v for k, v in dpt_head_weights.items() if k.startswith("head")}
+
+         # Replace the keys according to:
+         key_replace_dict = {
+             "0.weight": "conv1.weight",
+             "0.bias": "conv1.bias",
+             "2.weight": "conv2.0.weight",
+             "2.bias": "conv2.0.bias",
+             "4.weight": "conv2.2.weight",
+             "4.bias": "conv2.2.bias",
+         }
+
+         dpt_processor_weights = {key_replace_dict.get(k, k): v for k, v in dpt_processor_weights.items()}
+
+         dpt_reg_processor = DPTRegressionProcessor(input_feature_dim=256, output_dim=4, hidden_dims=[128, 128])
+
+         dpt_reg_processor.load_state_dict(dpt_processor_weights, strict=True)
+
+         # Save the state_dicts of the DPTFeature and DPTRegressionProcessor
+         dpt_feature_path = os.path.join(output_folder, "dpt_feature_head", output_filename + f"_feature_{head}.pth")
+         dpt_reg_processor_path = os.path.join(
+             output_folder, "dpt_reg_processor", output_filename + f"_reg_processor{head[-1]}.pth"
+         )
+
+         os.makedirs(os.path.dirname(dpt_feature_path), exist_ok=True)
+         os.makedirs(os.path.dirname(dpt_reg_processor_path), exist_ok=True)
+
+         torch.save({"model": dpt.state_dict()}, dpt_feature_path)
+         torch.save({"model": dpt_reg_processor.state_dict()}, dpt_reg_processor_path)
+
+
+ def extract_dust3r_linear_checkpoints(checkpoint_path, output_folder, output_filename):
+     "Extract the UniCeption format linear head weights from the original DUSt3R checkpoint."
+     test_linear_to_conv()
+
+     source_ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
+
+     for head in ["head1", "head2"]:
+         linear_head_params = {k: v for k, v in source_ckpt["model"].items() if k.startswith(f"downstream_{head}")}
+         linear_head_params = {k.replace(f"downstream_{head}.proj.", ""): v for k, v in linear_head_params.items()}
+
+         assert set(linear_head_params.keys()) == {"weight", "bias"}
+
+         input_feature_dim = 768
+         output_dim = 4
+         patch_size = 16
+
+         linear = nn.Linear(input_feature_dim, output_dim * patch_size * patch_size, bias=True)
+         linear.load_state_dict(linear_head_params, strict=True)
+
+         conv_layer = linear_to_conv2d(linear)
+
+         linear_feature = LinearFeature(input_feature_dim, 4, patch_size)
+         linear_feature.linear.load_state_dict(conv_layer.state_dict(), strict=True)
+
+         linear_feature_path = os.path.join(
+             output_folder, "linear_feature_head", output_filename + f"_feature_{head}.pth"
+         )
+         os.makedirs(os.path.dirname(linear_feature_path), exist_ok=True)
+         torch.save({"model": linear_feature.state_dict()}, linear_feature_path)
+
+
+ def extract_mast3r_dpt_checkpoints(checkpoint_path, output_folder, output_filename):
+     "Extract the UniCeption format DPT head weights from the original MASt3R checkpoint."
+     source_ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
+
+     for head in ["head1", "head2"]:
+         dpt_head = {k: v for k, v in source_ckpt["model"].items() if k.startswith(f"downstream_{head}")}
+         dpt_head = {k.replace(f"downstream_{head}.", ""): v for k, v in dpt_head.items()}
+         dpt_head = {k.replace("dpt.", ""): v for k, v in dpt_head.items()}
+
+         dpt_feature_weights = {
+             k: v for k, v in dpt_head.items() if not (k.startswith("head") or k.startswith("head_local_features"))
+         }
+
+         dpt = DPTFeature(
+             patch_size=16,
+             hooks=[0, 1, 2, 3],
+             input_feature_dims=[1024, 768, 768, 768],
+             layer_dims=[96, 192, 384, 768],
+             feature_dim=256,
+             use_bn=False,
+             output_width_ratio=1,
+         )
+
+         dpt.load_state_dict(dpt_feature_weights, strict=True)
+
+         dpt_processor_weights = {
+             k.replace("head.", ""): v
+             for k, v in dpt_head.items()
+             if (k.startswith("head") and not k.startswith("head_local_features"))
+         }
+
+         # Replace the keys according to:
+         key_replace_dict = {
+             "0.weight": "conv1.weight",
+             "0.bias": "conv1.bias",
+             "2.weight": "conv2.0.weight",
+             "2.bias": "conv2.0.bias",
+             "4.weight": "conv2.2.weight",
+             "4.bias": "conv2.2.bias",
+         }
+
+         dpt_processor_weights = {key_replace_dict.get(k, k): v for k, v in dpt_processor_weights.items()}
+
+         dpt_reg_processor = DPTRegressionProcessor(input_feature_dim=256, output_dim=4, hidden_dims=[128, 128])
+
+         dpt_reg_processor.load_state_dict(dpt_processor_weights, strict=True)
+
+         # Save the state_dicts of the DPTFeature and DPTRegressionProcessor
204
+ dpt_feature_path = os.path.join(output_folder, "dpt_feature_head", output_filename + f"_feature_{head}.pth")
205
+ dpt_reg_processor_path = os.path.join(
206
+ output_folder, "dpt_reg_processor", output_filename + f"_reg_processor{head[-1]}.pth"
207
+ )
208
+
209
+ os.makedirs(os.path.dirname(dpt_feature_path), exist_ok=True)
210
+ os.makedirs(os.path.dirname(dpt_reg_processor_path), exist_ok=True)
211
+
212
+ torch.save({"model": dpt.state_dict()}, dpt_feature_path)
213
+ torch.save({"model": dpt_reg_processor.state_dict()}, dpt_reg_processor_path)
214
+
215
+
216
+ def linear_to_conv2d(linear_layer):
217
+ """
218
+ Converts a nn.Linear layer to an equivalent nn.Conv2d layer with a 1x1 kernel.
219
+
220
+ Parameters:
221
+ - linear_layer (nn.Linear): The Linear layer to convert.
222
+
223
+ Returns:
224
+ - conv_layer (nn.Conv2d): The equivalent Conv2d layer.
225
+ """
226
+ # Extract in_features and out_features from the Linear layer
227
+ in_features = linear_layer.in_features
228
+ out_features = linear_layer.out_features
229
+ bias = linear_layer.bias is not None
230
+
231
+ # Create a Conv2d layer with a 1x1 kernel
232
+ conv_layer = nn.Conv2d(
233
+ in_channels=in_features, out_channels=out_features, kernel_size=1, stride=1, padding=0, bias=bias
234
+ )
235
+
236
+ # Reshape Linear weights to match Conv2d weights
237
+ conv_weight = linear_layer.weight.data.view(out_features, in_features, 1, 1).clone()
238
+ conv_layer.weight.data = conv_weight
239
+
240
+ # Copy bias if it exists
241
+ if bias:
242
+ conv_layer.bias.data = linear_layer.bias.data.clone()
243
+
244
+ return conv_layer
245
+
246
+
247
+ def test_linear_to_conv():
248
+ "Test the linear_to_conv2d function."
249
+ batch_size = 4
250
+ height = 16
251
+ width = 24
252
+ in_channels = 3
253
+ out_channels = 5
254
+
255
+ # Sample input tensor in BHWC format
256
+ x_linear = torch.randn(batch_size, height, width, in_channels)
257
+
258
+ # Define Linear layer
259
+ linear_layer = nn.Linear(in_channels, out_channels)
260
+ output_linear = linear_layer(x_linear)
261
+
262
+ # Transpose input tensor to BCHW format for Conv2d
263
+ x_conv = x_linear.permute(0, 3, 1, 2)
264
+
265
+ # Define Conv2d layer
266
+ conv_layer = linear_to_conv2d(linear_layer)
267
+
268
+ # Get Conv2d output and transpose back to BHWC format
269
+ output_conv = conv_layer(x_conv).permute(0, 2, 3, 1)
270
+
271
+ # Verify that outputs are the same
272
+ assert torch.allclose(output_linear, output_conv, atol=1e-6)
273
+
274
+
275
+ if __name__ == "__main__":
276
+ parser = argparse.ArgumentParser(description="Extract dust3r checkpoints to uniception format")
277
+
278
+ parser.add_argument(
279
+ "-dcf", "--dust3r_checkpoints_folder", type=str, required=True, help="Path to the dust3r checkpoints folder"
280
+ )
281
+ parser.add_argument("-of", "--output_folder", type=str, required=True, help="Path to the output folder")
282
+
283
+ args = parser.parse_args()
284
+
285
+ output_folder = args.output_folder
286
+ info_sharing_output_folder = os.path.join(output_folder, "info_sharing")
287
+ pred_head_output_folder = os.path.join(output_folder, "prediction_heads")
288
+ os.makedirs(output_folder, exist_ok=True)
289
+ os.makedirs(info_sharing_output_folder, exist_ok=True)
290
+ os.makedirs(pred_head_output_folder, exist_ok=True)
291
+
292
+ # Extract croco checkpoint
293
+ print("Extracting CroCo checkpoint...")
294
+ croco_ckpt_filepath = os.path.join(args.dust3r_checkpoints_folder, "CroCo_V2_ViTLarge_BaseDecoder.pth")
295
+ extract_cross_attention_weights(
296
+ croco_ckpt_filepath, info_sharing_output_folder, "Two_View_Cross_Attention_Transformer_CroCo.pth"
297
+ )
298
+
299
+ # Extract dust3r 224 linear checkpoint
300
+ print("Extracting DUSt3R 224 linear checkpoint...")
301
+ dust3r_ckpt_filepath = os.path.join(args.dust3r_checkpoints_folder, "DUSt3R_ViTLarge_BaseDecoder_224_linear.pth")
302
+ extract_cross_attention_weights(
303
+ dust3r_ckpt_filepath, info_sharing_output_folder, "Two_View_Cross_Attention_Transformer_DUSt3R_224_linear.pth"
304
+ )
305
+ extract_dust3r_linear_checkpoints(dust3r_ckpt_filepath, pred_head_output_folder, "DUSt3R_224_linear")
306
+
307
+ # Extract dust3r 512 linear checkpoint
308
+ print("Extracting DUSt3R 512 linear checkpoint...")
309
+ dust3r_ckpt_filepath = os.path.join(args.dust3r_checkpoints_folder, "DUSt3R_ViTLarge_BaseDecoder_512_linear.pth")
310
+ extract_cross_attention_weights(
311
+ dust3r_ckpt_filepath, info_sharing_output_folder, "Two_View_Cross_Attention_Transformer_DUSt3R_512_linear.pth"
312
+ )
313
+ extract_dust3r_linear_checkpoints(dust3r_ckpt_filepath, pred_head_output_folder, "DUSt3R_512_linear")
314
+
315
+ # Extract dust3r 512 dpt checkpoint
316
+ print("Extracting DUSt3R 512 dpt checkpoint...")
317
+ dust3r_ckpt_filepath = os.path.join(args.dust3r_checkpoints_folder, "DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth")
318
+ extract_cross_attention_weights(
319
+ dust3r_ckpt_filepath, info_sharing_output_folder, "Two_View_Cross_Attention_Transformer_DUSt3R_512_dpt.pth"
320
+ )
321
+ extract_dust3r_dpt_checkpoints(dust3r_ckpt_filepath, pred_head_output_folder, "DUSt3R_512_dpt")
322
+
323
+ # Extract mast3r 512 dpt checkpoint
324
+ print("Extracting MASt3R 512 dpt checkpoint...")
325
+ mast3r_ckpt_path = os.path.join(
326
+ args.dust3r_checkpoints_folder, "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth"
327
+ )
328
+ extract_cross_attention_weights(
329
+ mast3r_ckpt_path, info_sharing_output_folder, "Two_View_Cross_Attention_Transformer_MASt3R_512_dpt.pth"
330
+ )
331
+ extract_mast3r_dpt_checkpoints(mast3r_ckpt_path, pred_head_output_folder, "MASt3R_512_dpt")
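For reference, the remapping in `extract_cross_attention_weights` above reduces to four string substitutions: `decoder_embed` → `proj_embed`, `dec_blocks.` → `multi_view_branches.0.`, the duplicated `dec_blocks2.` → `multi_view_branches.1.`, and `dec_norm` → `norm`. A minimal sketch of that rule on placeholder keys (the key names below are illustrative, not taken from a real checkpoint):

```python
# Minimal sketch of the DUSt3R -> UniCeption key remapping; keys are placeholders.
def remap_key(k: str) -> str:
    if "decoder_embed" in k:
        return k.replace("decoder_embed", "proj_embed")
    if "dec_blocks." in k:  # matches "dec_blocks.N...", not "dec_blocks2.N..."
        return k.replace("dec_blocks.", "multi_view_branches.0.")
    if "dec_blocks2." in k:
        return k.replace("dec_blocks2.", "multi_view_branches.1.")
    if "dec_norm" in k:
        return k.replace("dec_norm", "norm")
    return k

toy_keys = [
    "decoder_embed.weight",
    "dec_blocks.0.attn.qkv.weight",
    "dec_blocks2.0.attn.qkv.weight",
    "dec_norm.weight",
]
for k in toy_keys:
    print(f"{k} -> {remap_key(k)}")
# decoder_embed.weight -> proj_embed.weight
# dec_blocks.0.attn.qkv.weight -> multi_view_branches.0.0.attn.qkv.weight
# dec_blocks2.0.attn.qkv.weight -> multi_view_branches.1.0.attn.qkv.weight
# dec_norm.weight -> norm.weight
```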
UniCeption/examples/models/dust3r/dust3r.py ADDED
@@ -0,0 +1,261 @@
1
+ """
2
+ Initializing Pre-trained DUSt3R using UniCeption
3
+ """
4
+
5
+ import argparse
6
+ import os
7
+ from io import BytesIO
8
+
9
+ import numpy as np
10
+ import requests
11
+ import rerun as rr
12
+ import torch
13
+ from PIL import Image
14
+
15
+ from uniception.models.factory import DUSt3R
16
+ from uniception.utils.viz import script_add_rerun_args
17
+
18
+
19
+ def get_model_configurations_and_checkpoints():
20
+ """
21
+ Get different DUSt3R model configurations and paths to refactored checkpoints.
22
+
23
+ Returns:
24
+ Tuple[List[str], dict]: A tuple containing the model configurations and paths to refactored checkpoints.
25
+ """
26
+ # Initialize model configurations
27
+ model_configurations = ["dust3r_224_linear", "dust3r_512_linear", "dust3r_512_dpt", "dust3r_512_dpt_mast3r"]
28
+
29
+ # Get paths to pretrained checkpoints
30
+ current_file_path = os.path.abspath(__file__)
31
+ relative_checkpoint_path = os.path.join(os.path.dirname(current_file_path), "../../../checkpoints")
32
+
33
+ # Initialize model configurations
34
+ model_to_checkpoint_path = {
35
+ "dust3r_512_dpt": {
36
+ "encoder": f"{relative_checkpoint_path}/encoders/CroCo_Encoder_512_DUSt3R_dpt.pth",
37
+ "info_sharing": f"{relative_checkpoint_path}/info_sharing/cross_attn_transformer/Two_View_Cross_Attention_Transformer_DUSt3R_512_dpt.pth",
38
+ "feature_head": [
39
+ f"{relative_checkpoint_path}/prediction_heads/dpt_feature_head/DUSt3R_512_dpt_feature_head1.pth",
40
+ f"{relative_checkpoint_path}/prediction_heads/dpt_feature_head/DUSt3R_512_dpt_feature_head2.pth",
41
+ ],
42
+ "regressor": [
43
+ f"{relative_checkpoint_path}/prediction_heads/dpt_reg_processor/DUSt3R_512_dpt_reg_processor1.pth",
44
+ f"{relative_checkpoint_path}/prediction_heads/dpt_reg_processor/DUSt3R_512_dpt_reg_processor2.pth",
45
+ ],
46
+ "ckpt_path": f"{relative_checkpoint_path}/examples/original_dust3r/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth",
47
+ },
48
+ "dust3r_512_dpt_mast3r": {
49
+ "encoder": f"{relative_checkpoint_path}/encoders/CroCo_Encoder_512_MASt3R.pth",
50
+ "info_sharing": f"{relative_checkpoint_path}/info_sharing/cross_attn_transformer/Two_View_Cross_Attention_Transformer_MASt3R_512_dpt.pth",
51
+ "feature_head": [
52
+ f"{relative_checkpoint_path}/prediction_heads/dpt_feature_head/MASt3R_512_dpt_feature_head1.pth",
53
+ f"{relative_checkpoint_path}/prediction_heads/dpt_feature_head/MASt3R_512_dpt_feature_head2.pth",
54
+ ],
55
+ "regressor": [
56
+ f"{relative_checkpoint_path}/prediction_heads/dpt_reg_processor/MASt3R_512_dpt_reg_processor1.pth",
57
+ f"{relative_checkpoint_path}/prediction_heads/dpt_reg_processor/MASt3R_512_dpt_reg_processor2.pth",
58
+ ],
59
+ "ckpt_path": f"{relative_checkpoint_path}/examples/original_dust3r/DUSt3R_ViTLarge_BaseDecoder_512_dpt_mast3r.pth",
60
+ },
61
+ "dust3r_512_linear": {
62
+ "encoder": f"{relative_checkpoint_path}/encoders/CroCo_Encoder_512_DUSt3R_linear.pth",
63
+ "info_sharing": f"{relative_checkpoint_path}/info_sharing/cross_attn_transformer/Two_View_Cross_Attention_Transformer_DUSt3R_512_linear.pth",
64
+ "feature_head": [
65
+ f"{relative_checkpoint_path}/prediction_heads/linear_feature_head/DUSt3R_512_linear_feature_head1.pth",
66
+ f"{relative_checkpoint_path}/prediction_heads/linear_feature_head/DUSt3R_512_linear_feature_head2.pth",
67
+ ],
68
+ "regressor": None,
69
+ "ckpt_path": f"{relative_checkpoint_path}/examples/original_dust3r/DUSt3R_ViTLarge_BaseDecoder_512_linear.pth",
70
+ },
71
+ "dust3r_224_linear": {
72
+ "encoder": f"{relative_checkpoint_path}/encoders/CroCo_Encoder_224_DUSt3R_linear.pth",
73
+ "info_sharing": f"{relative_checkpoint_path}/info_sharing/cross_attn_transformer/Two_View_Cross_Attention_Transformer_DUSt3R_224_linear.pth",
74
+ "feature_head": [
75
+ f"{relative_checkpoint_path}/prediction_heads/linear_feature_head/DUSt3R_224_linear_feature_head1.pth",
76
+ f"{relative_checkpoint_path}/prediction_heads/linear_feature_head/DUSt3R_224_linear_feature_head2.pth",
77
+ ],
78
+ "regressor": None,
79
+ "ckpt_path": f"{relative_checkpoint_path}/examples/original_dust3r/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth",
80
+ },
81
+ }
82
+ return model_configurations, model_to_checkpoint_path
83
+
84
+
85
+ def get_parser():
86
+ "Argument parser for the script."
87
+ parser = argparse.ArgumentParser()
88
+ parser.add_argument("--viz", action="store_true")
89
+
90
+ return parser
91
+
92
+
93
+ if __name__ == "__main__":
94
+ # Parse arguments
95
+ parser = get_parser()
96
+ script_add_rerun_args(parser) # Options: --addr
97
+ args = parser.parse_args()
98
+
99
+ # Set up Rerun for visualization
100
+ if args.viz:
101
+ rr.script_setup(args, "UniCeption_DUSt3R_Inference")
102
+ rr.set_time("stable_time", sequence=0)
103
+
104
+ # the reference data are collected under this setting.
105
+ # may use (False, "high") to test the relative error at TF32 precision
106
+ torch.backends.cuda.matmul.allow_tf32 = False
107
+ torch.set_float32_matmul_precision("highest")
108
+
109
+ # Get paths to pretrained checkpoints
110
+ current_file_path = os.path.abspath(__file__)
111
+ relative_checkpoint_path = os.path.join(os.path.dirname(current_file_path), "../../../checkpoints")
112
+ model_configurations, model_to_checkpoint_path = get_model_configurations_and_checkpoints()
113
+
114
+ MODEL_TO_VERIFICATION_PATH = {
115
+ "dust3r_512_dpt": {
116
+ "head_output": os.path.join(
117
+ os.path.dirname(current_file_path),
118
+ "../../../reference_data/dust3r_pre_cvpr",
119
+ "DUSt3R_512_dpt",
120
+ "03_head_output.npz",
121
+ )
122
+ },
123
+ "dust3r_512_dpt_mast3r": {
124
+ "head_output": os.path.join(
125
+ os.path.dirname(current_file_path),
126
+ "../../../reference_data/dust3r_pre_cvpr",
127
+ "MASt3R_512_dpt",
128
+ "03_head_output.npz",
129
+ )
130
+ },
131
+ "dust3r_512_linear": {
132
+ "head_output": os.path.join(
133
+ os.path.dirname(current_file_path),
134
+ "../../../reference_data/dust3r_pre_cvpr",
135
+ "DUSt3R_512_linear",
136
+ "03_head_output.npz",
137
+ )
138
+ },
139
+ "dust3r_224_linear": {
140
+ "head_output": os.path.join(
141
+ os.path.dirname(current_file_path),
142
+ "../../../reference_data/dust3r_pre_cvpr",
143
+ "DUSt3R_224_linear",
144
+ "03_head_output.npz",
145
+ )
146
+ },
147
+ }
148
+
149
+ # Test different DUSt3R models using UniCeption modules
150
+ for model_name in model_configurations:
151
+ dust3r_model = DUSt3R(
152
+ name=model_name,
153
+ img_size=(512, 512) if "512" in model_name else (224, 224),
154
+ patch_embed_cls="PatchEmbedDust3R",
155
+ pred_head_type="linear" if "linear" in model_name else "dpt",
156
+ pretrained_checkpoint_path=model_to_checkpoint_path[model_name]["ckpt_path"],
157
+ # pretrained_encoder_checkpoint_path=model_to_checkpoint_path[model_name]["encoder"],
158
+ # pretrained_info_sharing_checkpoint_path=model_to_checkpoint_path[model_name]["info_sharing"],
159
+ # pretrained_pred_head_checkpoint_paths=model_to_checkpoint_path[model_name]["feature_head"],
160
+ # pretrained_pred_head_regressor_checkpoint_paths=model_to_checkpoint_path[model_name]["regressor"],
161
+ # override_encoder_checkpoint_attributes=True,
162
+ )
163
+ print("DUSt3R model initialized successfully!")
164
+
165
+ # Initialize device
166
+ if torch.cuda.is_available():
167
+ device = "cuda:0"
168
+ else:
169
+ device = "cpu"
170
+ dust3r_model.to(device)
171
+
172
+ # Initialize two example images
173
+ img0_url = (
174
+ "https://raw.githubusercontent.com/naver/croco/d3d0ab2858d44bcad54e5bfc24f565983fbe18d9/assets/Chateau1.png"
175
+ )
176
+ img1_url = (
177
+ "https://raw.githubusercontent.com/naver/croco/d3d0ab2858d44bcad54e5bfc24f565983fbe18d9/assets/Chateau2.png"
178
+ )
179
+ response = requests.get(img0_url)
180
+ img0 = Image.open(BytesIO(response.content))
181
+ response = requests.get(img1_url)
182
+ img1 = Image.open(BytesIO(response.content))
183
+ img0_tensor = torch.from_numpy(np.array(img0))[..., :3].permute(2, 0, 1).unsqueeze(0).float() / 255
184
+ img1_tensor = torch.from_numpy(np.array(img1))[..., :3].permute(2, 0, 1).unsqueeze(0).float() / 255
185
+
186
+ # Normalize images according to DUSt3R's normalization
187
+ img0_tensor = (img0_tensor - 0.5) / 0.5
188
+ img1_tensor = (img1_tensor - 0.5) / 0.5
189
+ img_tensor = torch.cat((img0_tensor, img1_tensor), dim=0).to(device)
190
+
191
+ # Run a forward pass
192
+ view1 = {"img": img_tensor, "instance": [0, 1], "data_norm_type": "dust3r"}
193
+ view2 = {"img": view1["img"][[1, 0]].clone().to(device), "instance": [1, 0], "data_norm_type": "dust3r"}
194
+
195
+ res1, res2 = dust3r_model(view1, view2)
196
+ print("Forward pass completed successfully!")
197
+
198
+ # Automatically test the results against the reference result from vanilla dust3r code if they exist
199
+ reference_output_path = MODEL_TO_VERIFICATION_PATH[model_name]["head_output"]
200
+ if os.path.exists(reference_output_path):
201
+ reference_output_data = np.load(reference_output_path)
202
+
203
+ # Check against the reference output
204
+ check_dict = {
205
+ "head1_pts3d": (
206
+ res1["pts3d"].detach().cpu().numpy(),
207
+ reference_output_data["head1_pts3d"],
208
+ ),
209
+ "head2_pts3d": (
210
+ res2["pts3d_in_other_view"].detach().cpu().numpy(),
211
+ reference_output_data["head2_pts3d"],
212
+ ),
213
+ "head1_conf": (
214
+ res1["conf"].detach().squeeze(-1).cpu().numpy(),
215
+ reference_output_data["head1_conf"],
216
+ ),
217
+ "head2_conf": (
218
+ res2["conf"].detach().squeeze(-1).cpu().numpy(),
219
+ reference_output_data["head2_conf"],
220
+ ),
221
+ }
222
+
223
+ compute_abs_and_rel_error = lambda x, y: (np.abs(x - y).max(), np.linalg.norm(x - y) / np.linalg.norm(x))
224
+
225
+ print(f"===== Checking for {model_name} model =====")
226
+ for key, (output, reference) in check_dict.items():
227
+ abs_error, rel_error = compute_abs_and_rel_error(output, reference)
228
+ print(f"{key} abs_error: {abs_error}, rel_error: {rel_error}")
229
+
230
+ assert abs_error < 1e-2 and rel_error < 1e-3, f"Error in {key} output"
231
+
232
+ points1 = res1["pts3d"][0].detach().cpu().numpy()
233
+ points2 = res2["pts3d_in_other_view"][0].detach().cpu().numpy()
234
+ conf_mask1 = res1["conf"][0].squeeze(-1).detach().cpu().numpy() > 3.0
235
+ conf_mask2 = res2["conf"][0].squeeze(-1).detach().cpu().numpy() > 3.0
236
+
237
+ if args.viz:
238
+ rr.log(f"{model_name}", rr.ViewCoordinates.RDF, static=True)
239
+ filtered_pts3d1 = points1[conf_mask1]
240
+ filtered_pts3d1_colors = np.array(img0)[..., :3][conf_mask1] / 255
241
+ filtered_pts3d2 = points2[conf_mask2]
242
+ filtered_pts3d2_colors = np.array(img1)[..., :3][conf_mask2] / 255
243
+ rr.log(
244
+ f"{model_name}/view1",
245
+ rr.Points3D(
246
+ positions=filtered_pts3d1.reshape(-1, 3),
247
+ colors=filtered_pts3d1_colors.reshape(-1, 3),
248
+ ),
249
+ )
250
+ rr.log(
251
+ f"{model_name}/view2",
252
+ rr.Points3D(
253
+ positions=filtered_pts3d2.reshape(-1, 3),
254
+ colors=filtered_pts3d2_colors.reshape(-1, 3),
255
+ ),
256
+ )
257
+ print(
258
+ "Visualizations logged to Rerun: rerun+http://127.0.0.1:<rr-port>/proxy."
259
+ "For example, to spawn viewer: rerun --connect rerun+http://127.0.0.1:<rr-port>/proxy"
260
+ "Replace <rr-port> with the actual port."
261
+ )
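To summarize the interface exercised above: each view is a dict holding an `img` tensor in (B, 3, H, W) normalized with DUSt3R statistics, an `instance` list, and a `data_norm_type` string, and the model returns per-view dicts with `pts3d` / `pts3d_in_other_view` and `conf`. A condensed sketch, assuming the 512 DPT checkpoint path from the configuration table above exists locally (inputs here are random tensors for shape illustration only):

```python
# Condensed two-view inference sketch; checkpoint path and inputs are placeholders.
import torch
from uniception.models.factory import DUSt3R

model = DUSt3R(
    name="dust3r_512_dpt",
    img_size=(512, 512),
    patch_embed_cls="PatchEmbedDust3R",
    pred_head_type="dpt",
    pretrained_checkpoint_path="checkpoints/examples/original_dust3r/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth",
).eval()

imgs = torch.randn(2, 3, 512, 512)  # stand-in for images already normalized as (x - 0.5) / 0.5
view1 = {"img": imgs, "instance": [0, 1], "data_norm_type": "dust3r"}
view2 = {"img": imgs[[1, 0]], "instance": [1, 0], "data_norm_type": "dust3r"}

with torch.no_grad():
    res1, res2 = model(view1, view2)

print(res1["pts3d"].shape, res1["conf"].shape)  # pointmap and confidence for view 1
print(res2["pts3d_in_other_view"].shape)        # view-2 points expressed in view-1's frame
```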
UniCeption/examples/models/dust3r/profile_dust3r.py ADDED
@@ -0,0 +1,47 @@
1
+ import torch
2
+ from dust3r import get_model_configurations_and_checkpoints
3
+
4
+ from uniception.models.factory import DUSt3R
5
+ from uniception.utils.profile import benchmark_torch_function
6
+
7
+ if __name__ == "__main__":
8
+ # Get model configurations and checkpoints
9
+ model_configurations, model_to_checkpoint_path = get_model_configurations_and_checkpoints()
10
+
11
+ # Test different DUSt3R models using UniCeption modules
12
+ for model_name in model_configurations:
13
+ dust3r_model = DUSt3R(
14
+ name=model_name,
15
+ img_size=(512, 512) if "512" in model_name else (224, 224),
16
+ patch_embed_cls="PatchEmbedDust3R",
17
+ pred_head_type="linear" if "linear" in model_name else "dpt",
18
+ pretrained_checkpoint_path=model_to_checkpoint_path[model_name]["ckpt_path"],
19
+ )
20
+ print(f"DUSt3R model ({model_name}) initialized successfully!")
21
+
22
+ # Initialize device
23
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
24
+ dust3r_model.to(device)
25
+ print(f"Running on {device}")
26
+
27
+ # Generate random input tensors
28
+ img_size = (512, 512) if "512" in model_name else (224, 224)
29
+ batch_sizes = [1, 2, 4, 8]
30
+
31
+ for batch_size in batch_sizes:
32
+ # Prepare input views
33
+ view1_instances = range(batch_size)
34
+ view1_img_tensor = torch.randn(batch_size, 3, *img_size).to(device)
35
+ view1 = {"img": view1_img_tensor, "instance": view1_instances, "data_norm_type": "dust3r"}
36
+ view2_instances = range(batch_size)
37
+ view2_instances = [id + batch_size for id in view2_instances]
38
+ view2_img_tensor = torch.randn(batch_size, 3, *img_size).to(device)
39
+ view2 = {"img": view2_img_tensor, "instance": view2_instances, "data_norm_type": "dust3r"}
40
+
41
+ # Benchmark the forward pass of the model
42
+ with torch.no_grad():
43
+ with torch.autocast("cuda", enabled=True):
44
+ execution_time = benchmark_torch_function(dust3r_model, view1, view2)
45
+ print(
46
+ f"\033[92mForward pass for {model_name}, batch size : {batch_size} completed in {execution_time:.3f} milliseconds\033[0m"
47
+ )
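`benchmark_torch_function` comes from `uniception.utils.profile`, which is not part of this file. For readers without access to that module, a rough sketch of what such a timing helper might look like is below; this is an assumption about its behavior (milliseconds per call), not the actual implementation:

```python
# Hypothetical sketch of a benchmark helper; uniception.utils.profile.benchmark_torch_function may differ.
import time
import torch

def benchmark_torch_function_sketch(fn, *args, warmup=3, iters=10, **kwargs):
    for _ in range(warmup):  # warm-up passes (allocator, autotuning, caches)
        fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            fn(*args, **kwargs)
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / iters  # milliseconds per call
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(*args, **kwargs)
    return (time.perf_counter() - t0) * 1000.0 / iters
```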
UniCeption/pyproject.toml ADDED
@@ -0,0 +1,21 @@
1
+ [tool.black]
2
+ line-length = 120
3
+ include = '\.pyi?$'
4
+ exclude = '''
5
+ /(
6
+ \.git
7
+ | \.hg
8
+ | \.mypy_cache
9
+ | \.tox
10
+ | \.venv
11
+ | _build
12
+ | buck-out
13
+ | build
14
+ | cuda
15
+ | dist
16
+ )/
17
+ '''
18
+
19
+ [tool.isort]
20
+ profile = "black"
21
+ line_length = 120
UniCeption/scripts/check_dependencies.py ADDED
@@ -0,0 +1,49 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Console script to check UniCeption dependencies.
4
+ """
5
+
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ # Add the parent directory to the path to import uniception
10
+ sys.path.insert(0, str(Path(__file__).parent.parent))
11
+
12
+
13
+ def check_dependencies():
14
+ """Check if optional dependencies are available."""
15
+ try:
16
+ import torch
17
+
18
+ print(f"PyTorch version: {torch.__version__}")
19
+ if torch.cuda.is_available():
20
+ print(f"CUDA available: {torch.version.cuda}")
21
+ else:
22
+ print("CUDA not available")
23
+ except ImportError:
24
+ print("PyTorch not installed")
25
+
26
+ try:
27
+ import xformers
28
+
29
+ print(f"XFormers version: {xformers.__version__}")
30
+ except ImportError:
31
+ print("XFormers not installed")
32
+
33
+ try:
34
+ from uniception.models.libs.croco.curope import cuRoPE2D
35
+
36
+ print("CroCo RoPE extension available")
37
+ except ImportError:
38
+ print("CroCo RoPE extension not available")
39
+
40
+
41
+ def main():
42
+ """Main entry point for the check dependencies script."""
43
+ print("Checking UniCeption Dependencies...")
44
+ print("=" * 40)
45
+ check_dependencies()
46
+
47
+
48
+ if __name__ == "__main__":
49
+ main()
UniCeption/scripts/download_checkpoints.py ADDED
@@ -0,0 +1,48 @@
1
+ "Download the UniCeption format checkpoints from the AirLab Data Server"
2
+
3
+ import argparse
4
+ import os
5
+
6
+ from minio import Minio
7
+ from minio.error import S3Error
8
+ from tqdm import tqdm
9
+
10
+
11
+ def main():
12
+ parser = argparse.ArgumentParser(description="Download UniCeption format checkpoints from AirLab Data Server")
13
+ parser.add_argument(
14
+ "--folders",
15
+ nargs="+",
16
+ default=["encoders", "info_sharing", "prediction_heads", "examples"],
17
+ help="List of folders to download (default: all folders). Choices: encoders, info_sharing, prediction_heads, examples",
18
+ )
19
+ parser.add_argument("--destination", type=str, default="./", help="Destination folder for downloaded checkpoints")
20
+ args = parser.parse_args()
21
+
22
+ access_key = "bT79gQYtfhpxFIitlpns"
23
+ secret_key = "g7mSvUJ5k2a9mKv9IbhwXmUQjQX52MLwulhW9ONO"
24
+ client = Minio("airlab-share-02.andrew.cmu.edu:9000", access_key=access_key, secret_key=secret_key, secure=True)
25
+
26
+ bucket_name = "uniception"
27
+
28
+ def download_folder(folder_name, bucket_name, client, destination_folder):
29
+ folder_name = f"checkpoints/{folder_name}/"
30
+ objects = client.list_objects(bucket_name, prefix=folder_name, recursive=True)
31
+ for obj in tqdm(objects, desc=f"Downloading {folder_name}"):
32
+ destination_file = os.path.join(destination_folder, obj.object_name)
33
+ if not os.path.exists(destination_file):
34
+ os.makedirs(os.path.dirname(destination_file), exist_ok=True)
35
+ try:
36
+ client.fget_object(bucket_name, obj.object_name, destination_file)
37
+ print(f"Downloaded {obj.object_name} to {destination_file}")
38
+ except S3Error as e:
39
+ print(f"Error downloading {obj.object_name}: {e}")
40
+ else:
41
+ print(f"File {destination_file} already exists. Skipping...")
42
+
43
+ for folder in args.folders:
44
+ download_folder(folder, bucket_name, client, args.destination)
45
+
46
+
47
+ if __name__ == "__main__":
48
+ main()
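After a successful download, the checkpoints land under `<destination>/checkpoints/` in the sub-folders referenced by the example and conversion scripts in this commit. A small sanity-check sketch (folder names are inferred from those scripts; adjust the root path to match your `--destination`):

```python
# Sanity-check sketch: confirm the expected checkpoint sub-folders exist after download.
from pathlib import Path

root = Path("./checkpoints")  # replace with your --destination plus "checkpoints"
expected_subdirs = [
    "encoders",
    "info_sharing/cross_attn_transformer",
    "prediction_heads/dpt_feature_head",
    "prediction_heads/dpt_reg_processor",
    "prediction_heads/linear_feature_head",
    "examples/original_dust3r",
]
for sub in expected_subdirs:
    path = root / sub
    print(f"{path}: {'ok' if path.is_dir() else 'missing'}")
```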
UniCeption/scripts/install_croco_rope.py ADDED
@@ -0,0 +1,62 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Console script to install CroCo RoPE extension.
4
+ """
5
+
6
+ import os
7
+ import subprocess
8
+ import sys
9
+ from pathlib import Path
10
+
11
+
12
+ def install_croco_rope():
13
+ """Install CroCo RoPE extension."""
14
+ try:
15
+ # Find the project root (where setup.py is located)
16
+ script_dir = Path(__file__).parent
17
+ project_root = script_dir.parent
18
+ curope_path = project_root / "uniception" / "models" / "libs" / "croco" / "curope"
19
+
20
+ if curope_path.exists():
21
+ print("Installing CroCo RoPE extension...")
22
+ original_cwd = os.getcwd()
23
+ try:
24
+ os.chdir(curope_path)
25
+ subprocess.check_call([sys.executable, "setup.py", "build_ext", "--inplace"])
26
+ print("CroCo RoPE extension installed successfully!")
27
+ return True
28
+ except subprocess.CalledProcessError as e:
29
+ print(f"Warning: Failed to install CroCo RoPE extension: {e}")
30
+ print("You can install it later by running:")
31
+ print(f"cd {curope_path} && python setup.py build_ext --inplace")
32
+ return False
33
+ finally:
34
+ os.chdir(original_cwd)
35
+ else:
36
+ print("Warning: CroCo RoPE source code not found.")
37
+ print(f"Expected location: {curope_path}")
38
+ return False
39
+ except Exception as e:
40
+ print(f"Warning: Error during CroCo RoPE installation: {e}")
41
+ return False
42
+
43
+
44
+ def main():
45
+ """Main entry point for the CroCo RoPE installation script."""
46
+ print("UniCeption CroCo RoPE Extension Installer")
47
+ print("=" * 45)
48
+
49
+ success = install_croco_rope()
50
+
51
+ if success:
52
+ print("\n✓ CroCo RoPE extension installation completed successfully!")
53
+ sys.exit(0)
54
+ else:
55
+ print("\n⚠ CroCo RoPE extension installation failed or skipped.")
56
+ print("This is typically due to missing CUDA development tools.")
57
+ print("The extension is optional and UniCeption will work without it.")
58
+ sys.exit(1)
59
+
60
+
61
+ if __name__ == "__main__":
62
+ main()
UniCeption/scripts/prepare_offline_install.py ADDED
@@ -0,0 +1,399 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Script to prepare dependencies for offline installation.
4
+
5
+ This script downloads all necessary wheel files for offline installation
6
+ of UniCeption in environments without internet access.
7
+ """
8
+
9
+ import argparse
10
+ import os
11
+ import subprocess
12
+ import sys
13
+ from pathlib import Path
14
+
15
+
16
+ def download_wheels(output_dir: Path, extras: list = None):
17
+ """Download wheel files for offline installation."""
18
+ output_dir.mkdir(parents=True, exist_ok=True)
19
+
20
+ # Create temporary requirements files
21
+ temp_dir = output_dir / "temp"
22
+ temp_dir.mkdir(exist_ok=True)
23
+
24
+ try:
25
+ # Create requirements files
26
+ create_requirements_files(temp_dir, extras)
27
+
28
+ # Download base dependencies
29
+ base_cmd = [
30
+ sys.executable,
31
+ "-m",
32
+ "pip",
33
+ "download",
34
+ "--dest",
35
+ str(output_dir),
36
+ "-r",
37
+ str(temp_dir / "requirements-base.txt"),
38
+ ]
39
+
40
+ print(f"Downloading base dependencies to {output_dir}...")
41
+ subprocess.check_call(base_cmd)
42
+
43
+ # Download optional dependencies if requested
44
+ if extras:
45
+ for extra in extras:
46
+ if extra == "all":
47
+ # Download all extras
48
+ for req_file in ["requirements-xformers.txt", "requirements-dev.txt"]:
49
+ if (temp_dir / req_file).exists():
50
+ cmd = [
51
+ sys.executable,
52
+ "-m",
53
+ "pip",
54
+ "download",
55
+ "--dest",
56
+ str(output_dir),
57
+ "-r",
58
+ str(temp_dir / req_file),
59
+ ]
60
+ print(
61
+ f"Downloading {req_file.replace('requirements-', '').replace('.txt', '')} dependencies..."
62
+ )
63
+ try:
64
+ subprocess.check_call(cmd)
65
+ except subprocess.CalledProcessError as e:
66
+ print(f"Warning: Failed to download {extra} dependencies: {e}")
67
+ else:
68
+ req_file = temp_dir / f"requirements-{extra}.txt"
69
+ if req_file.exists():
70
+ cmd = [sys.executable, "-m", "pip", "download", "--dest", str(output_dir), "-r", str(req_file)]
71
+ print(f"Downloading {extra} dependencies...")
72
+ try:
73
+ subprocess.check_call(cmd)
74
+ except subprocess.CalledProcessError as e:
75
+ print(f"Warning: Failed to download {extra} dependencies: {e}")
76
+
77
+ # Create final offline installation files
78
+ create_offline_installation_files(output_dir)
79
+
80
+ print("Download completed successfully!")
81
+
82
+ except subprocess.CalledProcessError as e:
83
+ print(f"Error downloading wheels: {e}")
84
+ sys.exit(1)
85
+ finally:
86
+ # Clean up temporary files
87
+ import shutil
88
+
89
+ if temp_dir.exists():
90
+ shutil.rmtree(temp_dir)
91
+
92
+
93
+ def create_requirements_files(temp_dir: Path, extras: list = None):
94
+ """Create temporary requirements files for downloading."""
95
+
96
+ # Base requirements (including PyTorch)
97
+ base_reqs = [
98
+ "numpy",
99
+ "torch",
100
+ "torchvision",
101
+ "torchaudio",
102
+ "timm",
103
+ "black",
104
+ "jaxtyping",
105
+ "matplotlib",
106
+ "Pillow",
107
+ "scikit-learn",
108
+ "einops",
109
+ "rerun-sdk",
110
+ "pre-commit",
111
+ "minio",
112
+ "pytest",
113
+ "isort",
114
+ ]
115
+
116
+ # Write base requirements
117
+ with open(temp_dir / "requirements-base.txt", "w") as f:
118
+ for req in base_reqs:
119
+ f.write(f"{req}\n")
120
+
121
+ # XFormers requirements
122
+ with open(temp_dir / "requirements-xformers.txt", "w") as f:
123
+ f.write("xformers\n")
124
+
125
+ # Dev requirements
126
+ dev_reqs = [
127
+ "black",
128
+ "isort",
129
+ "pre-commit",
130
+ "pytest",
131
+ ]
132
+
133
+ with open(temp_dir / "requirements-dev.txt", "w") as f:
134
+ for req in dev_reqs:
135
+ f.write(f"{req}\n")
136
+
137
+
138
+ def create_offline_installation_files(output_dir: Path):
139
+ """Create requirements files and installation script for offline use."""
140
+
141
+ # Base requirements (including PyTorch)
142
+ base_reqs = [
143
+ "numpy",
144
+ "torch",
145
+ "torchvision",
146
+ "torchaudio",
147
+ "timm",
148
+ "black",
149
+ "jaxtyping",
150
+ "matplotlib",
151
+ "Pillow",
152
+ "scikit-learn",
153
+ "einops",
154
+ "rerun-sdk",
155
+ "pre-commit",
156
+ "minio",
157
+ "pytest",
158
+ "isort",
159
+ ]
160
+
161
+ # Write base requirements
162
+ with open(output_dir / "requirements-base.txt", "w") as f:
163
+ for req in base_reqs:
164
+ f.write(f"{req}\n")
165
+
166
+ # XFormers requirements
167
+ with open(output_dir / "requirements-xformers.txt", "w") as f:
168
+ f.write("xformers\n")
169
+
170
+ # Dev requirements
171
+ dev_reqs = [
172
+ "black",
173
+ "isort",
174
+ "pre-commit",
175
+ "pytest",
176
+ ]
177
+
178
+ with open(output_dir / "requirements-dev.txt", "w") as f:
179
+ for req in dev_reqs:
180
+ f.write(f"{req}\n")
181
+
182
+ # Create installation script
183
+ install_script = output_dir / "install_offline.sh"
184
+ with open(install_script, "w") as f:
185
+ f.write(
186
+ """#!/bin/bash
187
+ # Offline installation script for UniCeption
188
+
189
+ set -e
190
+
191
+ echo "Installing UniCeption dependencies offline..."
192
+
193
+ # Check if we're in the right directory
194
+ if [ ! -f "requirements-base.txt" ]; then
195
+ echo "Error: requirements-base.txt not found. Please run this script from the offline_wheels directory."
196
+ exit 1
197
+ fi
198
+
199
+ # Install base dependencies (includes PyTorch)
200
+ echo "Installing base dependencies (including PyTorch)..."
201
+ pip install --no-index --find-links . -r requirements-base.txt
202
+
203
+ # Install XFormers if requested
204
+ if [ "$INSTALL_XFORMERS" = "true" ]; then
205
+ echo "Installing XFormers..."
206
+ pip install --no-index --find-links . -r requirements-xformers.txt
207
+ fi
208
+
209
+ # Install dev dependencies if requested
210
+ if [ "$INSTALL_DEV" = "true" ]; then
211
+ echo "Installing development dependencies..."
212
+ pip install --no-index --find-links . -r requirements-dev.txt
213
+ fi
214
+
215
+ # Navigate back to UniCeption directory and install the package
216
+ echo "Installing UniCeption package..."
217
+ cd ..
218
+ pip install --no-deps -e .
219
+
220
+ # Install CroCo RoPE extension if requested
221
+ if [ "$INSTALL_CROCO_ROPE" = "true" ]; then
222
+ echo "Installing CroCo RoPE extension..."
223
+ cd uniception/models/libs/croco/curope
224
+ python setup.py build_ext --inplace
225
+ cd -
226
+ fi
227
+
228
+ echo "Offline installation completed successfully!"
229
+ echo ""
230
+ echo "To verify installation, run:"
231
+ echo "python setup.py check_deps"
232
+ """
233
+ )
234
+
235
+ # Make script executable
236
+ install_script.chmod(0o755)
237
+
238
+ # Create Windows batch script as well
239
+ install_bat = output_dir / "install_offline.bat"
240
+ with open(install_bat, "w") as f:
241
+ f.write(
242
+ """@echo off
243
+ REM Offline installation script for UniCeption (Windows)
244
+
245
+ echo Installing UniCeption dependencies offline...
246
+
247
+ REM Check if we're in the right directory
248
+ if not exist "requirements-base.txt" (
249
+ echo Error: requirements-base.txt not found. Please run this script from the offline_wheels directory.
250
+ exit /b 1
251
+ )
252
+
253
+ REM Install base dependencies (includes PyTorch)
254
+ echo Installing base dependencies (including PyTorch)...
255
+ pip install --no-index --find-links . -r requirements-base.txt
256
+
257
+ REM Install XFormers if requested
258
+ if "%INSTALL_XFORMERS%"=="true" (
259
+ echo Installing XFormers...
260
+ pip install --no-index --find-links . -r requirements-xformers.txt
261
+ )
262
+
263
+ REM Install dev dependencies if requested
264
+ if "%INSTALL_DEV%"=="true" (
265
+ echo Installing development dependencies...
266
+ pip install --no-index --find-links . -r requirements-dev.txt
267
+ )
268
+
269
+ REM Navigate back to UniCeption directory and install the package
270
+ echo Installing UniCeption package...
271
+ cd ..
272
+ pip install --no-deps -e .
273
+
274
+ REM Install CroCo RoPE extension if requested
275
+ if "%INSTALL_CROCO_ROPE%"=="true" (
276
+ echo Installing CroCo RoPE extension...
277
+ cd uniception\\models\\libs\\croco\\curope
278
+ python setup.py build_ext --inplace
279
+ cd ..\\..\\..\\..\\..
280
+ )
281
+
282
+ echo Offline installation completed successfully!
283
+ echo.
284
+ echo To verify installation, run:
285
+ echo python setup.py check_deps
286
+ """
287
+ )
288
+
289
+ # Create a README for offline installation
290
+ readme_file = output_dir / "README_OFFLINE.md"
291
+ with open(readme_file, "w") as f:
292
+ f.write(
293
+ """# UniCeption Offline Installation
294
+
295
+ This directory contains all the necessary files for installing UniCeption without internet access.
296
+
297
+ ## Files Included
298
+
299
+ - `requirements-base.txt` - Core dependencies (including PyTorch)
300
+ - `requirements-xformers.txt` - XFormers dependency
301
+ - `requirements-dev.txt` - Development dependencies
302
+ - `install_offline.sh` - Installation script for Unix/Linux/macOS
303
+ - `install_offline.bat` - Installation script for Windows
304
+ - `*.whl` files - Downloaded wheel packages
305
+
306
+ ## Installation Instructions
307
+
308
+ ### Unix/Linux/macOS
309
+
310
+ ```bash
311
+ # Set environment variables for optional components
312
+ export INSTALL_XFORMERS=true # Install XFormers
313
+ export INSTALL_DEV=true # Install development tools
314
+ export INSTALL_CROCO_ROPE=true # Compile CroCo RoPE extension
315
+
316
+ # Run the installation script
317
+ ./install_offline.sh
318
+ ```
319
+
320
+ ### Windows
321
+
322
+ ```cmd
323
+ REM Set environment variables for optional components
324
+ set INSTALL_XFORMERS=true
325
+ set INSTALL_DEV=true
326
+ set INSTALL_CROCO_ROPE=true
327
+
328
+ REM Run the installation script
329
+ install_offline.bat
330
+ ```
331
+
332
+ ## Manual Installation
333
+
334
+ If the scripts don't work, you can install manually:
335
+
336
+ ```bash
337
+ # Install base dependencies (includes PyTorch)
338
+ pip install --no-index --find-links . -r requirements-base.txt
339
+
340
+ # Install optional dependencies as needed
341
+ pip install --no-index --find-links . -r requirements-xformers.txt
342
+ pip install --no-index --find-links . -r requirements-dev.txt
343
+
344
+ # Install UniCeption package (from parent directory)
345
+ cd ..
346
+ pip install --no-deps -e .
347
+
348
+ # Compile CroCo RoPE extension (optional)
349
+ cd uniception/models/libs/croco/curope
350
+ python setup.py build_ext --inplace
351
+ ```
352
+
353
+ ## Verification
354
+
355
+ After installation, verify everything is working:
356
+
357
+ ```bash
358
+ cd .. # Go back to UniCeption root directory
359
+ python setup.py check_deps
360
+ ```
361
+
362
+ ## Notes
363
+
364
+ - PyTorch, TorchVision, and TorchAudio are now included in the base requirements
365
+ - XFormers is optional and only needed for certain performance optimizations
366
+ - CroCo RoPE extension compilation requires a CUDA-enabled environment
367
+ """
368
+ )
369
+
370
+ print(f"Created offline installation files in {output_dir}")
371
+ print("Files created:")
372
+ print(" - requirements-base.txt (includes PyTorch)")
373
+ print(" - requirements-xformers.txt")
374
+ print(" - requirements-dev.txt")
375
+ print(" - install_offline.sh (Unix/Linux/macOS)")
376
+ print(" - install_offline.bat (Windows)")
377
+ print(" - README_OFFLINE.md")
378
+
379
+
380
+ def create_offline_requirements(output_dir: Path):
381
+ """Create requirements files for offline installation."""
382
+ # This function is now replaced by create_offline_installation_files
383
+ pass
384
+
385
+
386
+ def main():
387
+ parser = argparse.ArgumentParser(description="Prepare UniCeption for offline installation")
388
+ parser.add_argument(
389
+ "--output-dir", type=Path, default="offline_wheels", help="Directory to store downloaded wheels"
390
+ )
391
+ parser.add_argument("--extras", nargs="+", choices=["xformers", "dev", "all"], help="Extra dependencies to include")
392
+
393
+ args = parser.parse_args()
394
+
395
+ download_wheels(args.output_dir, args.extras)
396
+
397
+
398
+ if __name__ == "__main__":
399
+ main()
UniCeption/scripts/validate_installation.py ADDED
@@ -0,0 +1,213 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Validation script for UniCeption installation.
4
+
5
+ This script validates that all components of UniCeption are correctly installed
6
+ and provides helpful diagnostics.
7
+ """
8
+
9
+ import importlib
10
+ import sys
11
+ from pathlib import Path
12
+
13
+
14
+ def check_package_installation():
15
+ """Check if UniCeption package is properly installed."""
16
+ try:
17
+ import uniception
18
+
19
+ print("✓ UniCeption package is installed")
20
+
21
+ # Check if we can import core modules
22
+ try:
23
+ from uniception.models.encoders import UniCeptionViTEncoderBase
24
+
25
+ print("✓ Core encoder modules are available")
26
+ except ImportError as e:
27
+ print(f"✗ Failed to import core encoder modules: {e}")
28
+
29
+ return True
30
+ except ImportError as e:
31
+ print(f"✗ UniCeption package not found: {e}")
32
+ return False
33
+
34
+
35
+ def check_dependencies():
36
+ """Check optional dependencies."""
37
+ dependencies = {
38
+ "torch": "PyTorch",
39
+ "torchvision": "TorchVision",
40
+ "torchaudio": "TorchAudio",
41
+ "xformers": "XFormers",
42
+ "timm": "Timm (PyTorch Image Models)",
43
+ "einops": "Einops",
44
+ "matplotlib": "Matplotlib",
45
+ "numpy": "NumPy",
46
+ "PIL": "Pillow",
47
+ }
48
+
49
+ available = []
50
+ missing = []
51
+
52
+ for module, name in dependencies.items():
53
+ try:
54
+ mod = importlib.import_module(module)
55
+ version = getattr(mod, "__version__", "unknown")
56
+ available.append(f"✓ {name}: {version}")
57
+ except ImportError:
58
+ missing.append(f"✗ {name}: not installed")
59
+
60
+ print("\nDependency Status:")
61
+ for dep in available:
62
+ print(f" {dep}")
63
+
64
+ if missing:
65
+ print("\nMissing Dependencies:")
66
+ for dep in missing:
67
+ print(f" {dep}")
68
+
69
+ return len(missing) == 0
70
+
71
+
72
+ def check_cuda_support():
73
+ """Check CUDA support."""
74
+ try:
75
+ import torch
76
+
77
+ if torch.cuda.is_available():
78
+ print(f"\n✓ CUDA is available")
79
+ print(f" CUDA version: {torch.version.cuda}")
80
+ print(f" Available devices: {torch.cuda.device_count()}")
81
+ for i in range(torch.cuda.device_count()):
82
+ print(f" Device {i}: {torch.cuda.get_device_name(i)}")
83
+ else:
84
+ print(f"\n⚠ CUDA is not available (CPU-only mode)")
85
+ except ImportError:
86
+ print(f"\n⚠ PyTorch not installed - cannot check CUDA support")
87
+
88
+
89
+ def check_croco_rope():
90
+ """Check CroCo RoPE extension."""
91
+ try:
92
+ from uniception.models.libs.croco.curope import cuRoPE2D
93
+
94
+ print("\n✓ CroCo RoPE extension is available")
95
+ return True
96
+ except ImportError:
97
+ print("\n✗ CroCo RoPE extension not available")
98
+ print(" To install: cd uniception/models/libs/croco/curope && python setup.py build_ext --inplace")
99
+ return False
100
+
101
+
102
+ def check_model_availability():
103
+ """Check if models can be loaded."""
104
+ try:
105
+ # Try to check if encoder modules are available
106
+ from uniception.models import encoders
107
+
108
+ print(f"\n✓ Encoder module is available")
109
+
110
+ # Try to run the encoder list command
111
+ try:
112
+ import subprocess
113
+
114
+ result = subprocess.run(
115
+ [sys.executable, "-m", "uniception.models.encoders.list"], capture_output=True, text=True, timeout=10
116
+ )
117
+
118
+ if result.returncode == 0:
119
+ lines = result.stdout.strip().split("\n")
120
+ encoder_count = len([line for line in lines if line.strip() and not line.startswith("Available")])
121
+ print(f"✓ Available encoders: {encoder_count}")
122
+ return True
123
+ else:
124
+ print(f"⚠ Encoder listing returned non-zero exit code: {result.returncode}")
125
+ return False
126
+
127
+ except subprocess.TimeoutExpired:
128
+ print(f"⚠ Encoder listing timed out")
129
+ return False
130
+ except Exception as e:
131
+ print(f"⚠ Could not run encoder listing: {e}")
132
+ return False
133
+
134
+ except Exception as e:
135
+ print(f"\n✗ Failed to access encoder modules: {e}")
136
+ return False
137
+
138
+
139
+ def check_file_structure():
140
+ """Check if the project file structure is correct."""
141
+ base_path = Path(__file__).parent.parent
142
+ required_dirs = [
143
+ "uniception",
144
+ "uniception/models",
145
+ "uniception/models/encoders",
146
+ "uniception/models/info_sharing",
147
+ "uniception/models/prediction_heads",
148
+ "scripts",
149
+ "tests",
150
+ ]
151
+
152
+ missing_dirs = []
153
+ for dir_path in required_dirs:
154
+ full_path = base_path / dir_path
155
+ if not full_path.exists():
156
+ missing_dirs.append(dir_path)
157
+
158
+ if missing_dirs:
159
+ print(f"\n✗ Missing directories:")
160
+ for dir_path in missing_dirs:
161
+ print(f" - {dir_path}")
162
+ return False
163
+ else:
164
+ print(f"\n✓ Project structure is correct")
165
+ return True
166
+
167
+
168
+ def main():
169
+ """Run all validation checks."""
170
+ print("UniCeption Installation Validation")
171
+ print("=" * 40)
172
+
173
+ checks = [
174
+ ("Package Installation", check_package_installation),
175
+ ("Dependencies", check_dependencies),
176
+ ("CUDA Support", check_cuda_support),
177
+ ("CroCo RoPE Extension", check_croco_rope),
178
+ ("Model Availability", check_model_availability),
179
+ ("File Structure", check_file_structure),
180
+ ]
181
+
182
+ results = []
183
+ for name, check_func in checks:
184
+ print(f"\nChecking {name}...")
185
+ try:
186
+ result = check_func()
187
+ results.append((name, result))
188
+ except Exception as e:
189
+ print(f"✗ Error during {name} check: {e}")
190
+ results.append((name, False))
191
+
192
+ # Summary
193
+ print("\n" + "=" * 40)
194
+ print("Validation Summary:")
195
+ passed = 0
196
+ for name, result in results:
197
+ status = "✓ PASS" if result else "✗ FAIL"
198
+ print(f" {name}: {status}")
199
+ if result:
200
+ passed += 1
201
+
202
+ print(f"\nOverall: {passed}/{len(results)} checks passed")
203
+
204
+ if passed == len(results):
205
+ print("🎉 All checks passed! UniCeption is ready to use.")
206
+ return 0
207
+ else:
208
+ print("⚠ Some checks failed. Please review the issues above.")
209
+ return 1
210
+
211
+
212
+ if __name__ == "__main__":
213
+ sys.exit(main())
UniCeption/setup.py ADDED
@@ -0,0 +1,188 @@
1
+ """Package installation setup."""
2
+
3
+ import os
4
+ import subprocess
5
+ import sys
6
+ from pathlib import Path
7
+
8
+ from setuptools import find_packages, setup
9
+ from setuptools.command.develop import develop
10
+ from setuptools.command.install import install
11
+
12
+
13
+ def install_croco_rope():
14
+ """Install CroCo RoPE extension."""
15
+ try:
16
+ curope_path = Path(__file__).parent / "uniception" / "models" / "libs" / "croco" / "curope"
17
+ if curope_path.exists():
18
+ print("Installing CroCo RoPE extension...")
19
+ original_cwd = os.getcwd()
20
+ try:
21
+ os.chdir(curope_path)
22
+ subprocess.check_call([sys.executable, "setup.py", "build_ext", "--inplace"])
23
+ print("CroCo RoPE extension installed successfully!")
24
+ return True
25
+ except subprocess.CalledProcessError as e:
26
+ print(f"Warning: Failed to install CroCo RoPE extension: {e}")
27
+ print("You can install it later by running:")
28
+ print(f"cd {curope_path} && python setup.py build_ext --inplace")
29
+ return False
30
+ finally:
31
+ os.chdir(original_cwd)
32
+ else:
33
+ print("Warning: CroCo RoPE source code not found.")
34
+ return False
35
+ except Exception as e:
36
+ print(f"Warning: Error during CroCo RoPE installation: {e}")
37
+ return False
38
+
39
+
40
+ def check_dependencies():
41
+ """Check if optional dependencies are available."""
42
+ try:
43
+ import torch
44
+
45
+ print(f"PyTorch version: {torch.__version__}")
46
+ if torch.cuda.is_available():
47
+ print(f"CUDA available: {torch.version.cuda}")
48
+ else:
49
+ print("CUDA not available")
50
+ except ImportError:
51
+ print("PyTorch not installed")
52
+
53
+ try:
54
+ import xformers
55
+
56
+ print(f"XFormers version: {xformers.__version__}")
57
+ except ImportError:
58
+ print("XFormers not installed")
59
+
60
+ try:
61
+ from uniception.models.libs.croco.curope import cuRoPE2D
62
+
63
+ print("CroCo RoPE extension available")
64
+ except ImportError:
65
+ print("CroCo RoPE extension not available")
66
+
67
+
68
+ class CustomDevelopCommand(develop):
69
+ """Custom development installation command."""
70
+
71
+ def run(self):
72
+ develop.run(self)
73
+ # Only install CroCo RoPE if explicitly requested
74
+ if os.getenv("INSTALL_CROCO_ROPE", "false").lower() in ("true", "1", "yes"):
75
+ install_croco_rope()
76
+
77
+
78
+ class CustomInstallCommand(install):
79
+ """Custom installation command."""
80
+
81
+ def run(self):
82
+ install.run(self)
83
+ # Only install CroCo RoPE if explicitly requested
84
+ if os.getenv("INSTALL_CROCO_ROPE", "false").lower() in ("true", "1", "yes"):
85
+ install_croco_rope()
86
+
87
+
88
+ class CrocoInstallCommand(install):
89
+ """Install command that includes CroCo RoPE extension."""
90
+
91
+ def run(self):
92
+ install.run(self)
93
+ install_croco_rope()
94
+
95
+
96
+ class CheckDependenciesCommand(install):
97
+ """Command to check available dependencies."""
98
+
99
+ def run(self):
100
+ check_dependencies()
101
+
102
+
103
+ # Core dependencies (including PyTorch which is essential for UniCeption)
104
+ install_requires = [
105
+ "numpy",
106
+ "torch",
107
+ "torchvision",
108
+ "torchaudio",
109
+ "timm",
110
+ "black",
111
+ "jaxtyping",
112
+ "matplotlib",
113
+ "Pillow",
114
+ "scikit-learn",
115
+ "einops",
116
+ "rerun-sdk",
117
+ "pre-commit",
118
+ "minio",
119
+ "pytest",
120
+ "isort",
121
+ ]
122
+
123
+ # Optional dependencies
124
+ extras_require = {
125
+ "xformers": [
126
+ "xformers", # Will be installed from PyTorch wheel index
127
+ ],
128
+ "dev": [
129
+ "black",
130
+ "isort",
131
+ "pre-commit",
132
+ "pytest",
133
+ ],
134
+ "minimal": [
135
+ # Minimal dependencies for basic functionality without PyTorch
136
+ "numpy",
137
+ "matplotlib",
138
+ "Pillow",
139
+ "scikit-learn",
140
+ "einops",
141
+ ],
142
+ }
143
+
144
+ # All optional dependencies combined (excluding minimal since it's subset of install_requires)
145
+ extras_require["all"] = list(set(extras_require["xformers"] + extras_require["dev"]))
146
+
147
+ setup(
148
+ name="uniception",
149
+ version="0.1.0",
150
+ description="Generalizable Perception Stack for 3D, 4D, spatial AI and scene understanding",
151
+ long_description=open("README.md").read(),
152
+ long_description_content_type="text/markdown",
153
+ author="AirLab",
154
+ license="BSD 3-Clause",
155
+ packages=find_packages(),
156
+ package_dir={"": "."},
157
+ include_package_data=True,
158
+ python_requires=">=3.10",
159
+ install_requires=install_requires,
160
+ extras_require=extras_require,
161
+ cmdclass={
162
+ "develop": CustomDevelopCommand,
163
+ "install": CustomInstallCommand,
164
+ "install_croco": CrocoInstallCommand,
165
+ "check_deps": CheckDependenciesCommand,
166
+ },
167
+ entry_points={
168
+ "console_scripts": [
169
+ "uniception-download-checkpoints=scripts.download_checkpoints:main",
170
+ "uniception-validate=scripts.validate_installation:main",
171
+ "uniception-prepare-offline=scripts.prepare_offline_install:main",
172
+ "uniception-check-deps=scripts.check_dependencies:main",
173
+ "uniception-install-croco=scripts.install_croco_rope:main",
174
+ ],
175
+ },
176
+ classifiers=[
177
+ "Development Status :: 3 - Alpha",
178
+ "Intended Audience :: Developers",
179
+ "Intended Audience :: Science/Research",
180
+ "Programming Language :: Python :: 3",
181
+ "Programming Language :: Python :: 3.10",
182
+ "Programming Language :: Python :: 3.11",
183
+ "Programming Language :: Python :: 3.12",
184
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
185
+ "Topic :: Software Development :: Libraries :: Python Modules",
186
+ ],
187
+ keywords="computer-vision, 3d-vision, spatial-ai, perception, deep-learning, pytorch",
188
+ )
UniCeption/tests/models/encoders/conftest.py ADDED
@@ -0,0 +1,26 @@
1
+ import pytest
2
+
3
+
4
+ def pytest_addoption(parser):
5
+ # Add custom command-line options
6
+ parser.addoption("--encoder-name", action="store", default=None, help="Specify the encoder name to test")
7
+
8
+ parser.addoption(
9
+ "--device",
10
+ action="store",
11
+ default="cpu",
12
+ choices=["cpu", "gpu"],
13
+ help="Specify the device to use (default: cpu)",
14
+ )
15
+
16
+
17
+ @pytest.fixture
18
+ def encoder_name(request):
19
+ # Access the value of the custom option for encoder name
20
+ return request.config.getoption("--encoder-name")
21
+
22
+
23
+ @pytest.fixture
24
+ def device(request):
25
+ # Access the value of the custom option for device
26
+ return request.config.getoption("--device")
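These fixtures let the encoder tests be narrowed to one encoder and device from the command line, e.g. `pytest --encoder-name dinov2_base --device gpu`. A hypothetical programmatic invocation of the same run:

```python
# Hypothetical programmatic run of the encoder tests using the custom options above.
import pytest

pytest.main([
    "tests/models/encoders/test_encoders.py",
    "--encoder-name", "dinov2_base",  # restrict to a single encoder
    "--device", "gpu",                # or "cpu" (the default)
])
```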
UniCeption/tests/models/encoders/test_encoders.py ADDED
@@ -0,0 +1,204 @@
1
+ import os
2
+ import random
3
+ from functools import lru_cache
4
+ from typing import Tuple
5
+
6
+ import numpy as np
7
+ import pytest
8
+ import requests
9
+ import torch
10
+ from PIL import Image
11
+
12
+ from uniception.models.encoders import *
13
+ from uniception.models.encoders.image_normalizations import *
14
+
15
+
16
+ @pytest.fixture(scope="module")
17
+ def norm_types():
18
+ return IMAGE_NORMALIZATION_DICT.keys()
19
+
20
+
21
+ @pytest.fixture(scope="module")
22
+ def encoders():
23
+ return [
24
+ "croco",
25
+ "dust3r_224",
26
+ "dust3r_512",
27
+ "dust3r_512_dpt",
28
+ "mast3r_512",
29
+ "dinov2_base",
30
+ "dinov2_large",
31
+ "dinov2_large_reg",
32
+ "dinov2_large_dav2",
33
+ "dinov2_giant",
34
+ "dinov2_giant_reg",
35
+ "radio_v2.5-b",
36
+ "radio_v2.5-l",
37
+ "e-radio_v2",
38
+ "naradio_v2.5-b",
39
+ "naradio_v2.5-l",
40
+ "cosmosx8",
41
+ "patch_embedder",
42
+ ]
43
+
44
+
45
+ @pytest.fixture(scope="module")
46
+ def encoder_configs(encoders):
47
+ # Adjust the number of configs to match the number of encoders
48
+ return [{}] * len(encoders)
49
+
50
+
51
+ @pytest.fixture
52
+ def device(request):
53
+ # Access the value of the custom option for device
54
+ device_str = request.config.getoption("--device")
55
+ if device_str == "gpu" and torch.cuda.is_available():
56
+ device = torch.device("cuda") # Use the default CUDA device
57
+ else:
58
+ device = torch.device("cpu")
59
+ print(f"Using device: {device.type.upper()}")
60
+ return device
61
+
62
+
63
+ @pytest.fixture
64
+ def example_input(device):
65
+ @lru_cache(maxsize=3)
66
+ def _get_example_input(
67
+ image_size: Tuple[int, int],
68
+ image_norm_type: str = "dummy",
69
+ img_selection: int = 1,
70
+ return_viz_img: bool = False,
71
+ ) -> torch.Tensor:
72
+ url = f"https://raw.githubusercontent.com/naver/croco/d3d0ab2858d44bcad54e5bfc24f565983fbe18d9/assets/Chateau{img_selection}.png"
73
+ image = Image.open(requests.get(url, stream=True).raw)
74
+ image = image.resize(image_size)
75
+ image = image.convert("RGB")
76
+
77
+ img = torch.from_numpy(np.array(image))
78
+ viz_img = img.clone()
79
+
80
+ # Normalize the image
81
+ image_normalization = IMAGE_NORMALIZATION_DICT[image_norm_type]
82
+ img_mean = image_normalization.mean
83
+ img_std = image_normalization.std
84
+ img = (img.float() / 255.0 - img_mean) / img_std
85
+
86
+ # Convert to BCHW format
87
+ img = img.permute(2, 0, 1).unsqueeze(0).to(device)
88
+
89
+ if return_viz_img:
90
+ return img, viz_img
91
+ else:
92
+ return img
93
+
94
+ return _get_example_input
95
+
96
+
97
+ def inference_encoder(encoder, encoder_input):
98
+ # Encoder expects a ViTEncoderInput object
99
+ return encoder(encoder_input).features
100
+
101
+
102
+ def test_make_dummy_encoder(device):
103
+ print(f"Testing Init of Dummy Encoder on {device.type.upper()}")
104
+ encoder = _make_encoder_test("dummy").to(device)
105
+
106
+ # Check if the encoder has parameters
107
+ try:
108
+ params = list(encoder.parameters())
109
+ if not params:
110
+ print("Warning: The encoder has no parameters.")
111
+ else:
112
+ # Verify if the model is on the right device
113
+ assert params[0].is_cuda == (device.type == "cuda")
114
+
115
+ except Exception as e:
116
+ print(f"Error: {e}")
117
+ assert False # Fail the test if any error occurs
118
+
119
+ assert encoder is not None
120
+
121
+
122
+ def test_all_encoder_basics(encoders, encoder_configs, norm_types, example_input, encoder_name, device):
123
+ if encoder_name:
124
+ encoders = [encoder_name] # Override default encoders with the one specified
125
+
126
+ for encoder_name, encoder_config in zip(encoders, encoder_configs):
127
+ print(f"Testing encoder: {encoder_name} on {device.type.upper()}")
128
+
129
+ encoder = _make_encoder_test(encoder_name, **encoder_config).to(device)
130
+ _check_baseclass_attribute(encoder, norm_types)
131
+ _check_norm_check_function(encoder)
132
+
133
+ if isinstance(encoder, UniCeptionViTEncoderBase):
134
+ _check_vit_encoder_attribute(encoder)
135
+ _test_vit_encoder_patch_size(encoder, example_input)
136
+
137
+
138
+ def _check_baseclass_attribute(encoder, norm_types):
139
+ assert hasattr(encoder, "name")
140
+ assert hasattr(encoder, "size")
141
+ assert hasattr(encoder, "data_norm_type")
142
+
143
+ assert isinstance(encoder.name, str)
144
+ assert isinstance(encoder.size, str) or encoder.size is None
145
+ assert isinstance(encoder.data_norm_type, str)
146
+
147
+ # Check if the data_norm_type is in the list of normalization types
148
+ assert encoder.data_norm_type in norm_types
149
+
150
+
151
+ def _check_norm_check_function(encoder):
152
+ assert hasattr(encoder, "_check_data_normalization_type")
153
+
154
+ encoder_norm_type = encoder.data_norm_type
155
+
156
+ try:
157
+ encoder._check_data_normalization_type(encoder_norm_type)
158
+ except AssertionError:
159
+ assert False
160
+
161
+ try:
162
+ encoder._check_data_normalization_type("some_nonexistent_norm_type")
163
+ assert False
164
+ except AssertionError:
165
+ pass
166
+
167
+
168
+ def _check_vit_encoder_attribute(encoder):
169
+ assert hasattr(encoder, "patch_size")
170
+ assert isinstance(encoder.patch_size, int)
171
+ assert encoder.patch_size > 0
172
+
173
+
174
+ def _test_vit_encoder_patch_size(encoder, example_input):
175
+ print(f"Testing {encoder.name} inference")
176
+ image_size = (14 * encoder.patch_size, 14 * encoder.patch_size)
177
+
178
+ img = example_input(image_size, encoder.data_norm_type)
179
+ # Create an instance of ViTEncoderInput with correct attributes
180
+ encoder_input = ViTEncoderInput(
181
+ data_norm_type=encoder.data_norm_type,
182
+ image=img,
183
+ )
184
+
185
+ encoder_output = inference_encoder(encoder, encoder_input)
186
+
187
+ assert isinstance(encoder_output, torch.Tensor)
188
+ assert encoder_output.shape[2] == 14
189
+ assert encoder_output.shape[3] == 14
190
+
191
+
192
+ @pytest.fixture(scope="session", autouse=True)
193
+ def seed_everything():
194
+ seed = 42
195
+ random.seed(seed)
196
+ os.environ["PYTHONHASHSEED"] = str(seed)
197
+ np.random.seed(seed)
198
+ torch.manual_seed(seed)
199
+ torch.backends.cudnn.deterministic = True
200
+ torch.backends.cudnn.benchmark = False
201
+ print(f"Seed set to: {seed} (type: {type(seed)})")
202
+
203
+ # Turn XFormers off for testing on CPU
204
+ os.environ["XFORMERS_DISABLED"] = "1"
UniCeption/tests/models/encoders/viz_image_encoders.py ADDED
@@ -0,0 +1,294 @@
1
+ """
2
+ PCA Visualization of UniCeption Image Encoders
3
+ """
4
+
5
+ import os
6
+ import random
7
+ from functools import lru_cache
8
+ from typing import Tuple
9
+
10
+ import numpy as np
11
+ import requests
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from matplotlib import pyplot as plt
15
+ from PIL import Image
16
+ from sklearn.decomposition import PCA
17
+
18
+ from uniception.models.encoders import *
19
+ from uniception.models.encoders.image_normalizations import *
20
+
21
+
22
+ class TestEncoders:
23
+ def __init__(self, pca_save_folder, *args, **kwargs):
24
+ super(TestEncoders, self).__init__(*args, **kwargs)
25
+
26
+ self.pca_save_folder = pca_save_folder
27
+
28
+ self.norm_types = IMAGE_NORMALIZATION_DICT.keys()
29
+
30
+ self.encoders = [
31
+ "croco",
32
+ "dust3r_224",
33
+ "dust3r_512",
34
+ "dust3r_512_dpt",
35
+ "mast3r_512",
36
+ "dinov2_large",
37
+ "dinov2_large_reg",
38
+ "dinov2_large_dav2",
39
+ "dinov2_giant",
40
+ "dinov2_giant_reg",
41
+ "radio_v2.5-b",
42
+ "radio_v2.5-l",
43
+ "e-radio_v2",
44
+ ]
45
+
46
+ self.encoder_configs = [{}] * len(self.encoders)
47
+
48
+ def inference_encoder(self, encoder, input):
49
+ return encoder(input)
50
+
51
+ def visualize_all_encoders(self):
52
+ for encoder, encoder_config in zip(self.encoders, self.encoder_configs):
53
+ encoder = _make_encoder_test(encoder, **encoder_config)
54
+ self._visualize_encoder_features_consistency(encoder, (224, 224))
55
+
56
+ def _visualize_encoder_features(self, encoder, image_size: Tuple[int, int]):
57
+ img, viz_img = self._get_example_input(image_size, encoder.data_norm_type, return_viz_img=True)
58
+ # input and output of the encoder
59
+ encoder_input: ViTEncoderInput = ViTEncoderInput(
60
+ data_norm_type=encoder.data_norm_type,
61
+ image=img,
62
+ )
63
+
64
+ encoder_output = self.inference_encoder(encoder, encoder_input)
65
+ encoder_output = encoder_output.features
66
+
67
+ assert isinstance(encoder_output, torch.Tensor)
68
+
69
+ # visualize the features
70
+ pca_viz = get_pca_map(encoder_output.permute(0, 2, 3, 1), image_size, return_pca_stats=False)
71
+
72
+ # plot the input image and the PCA features
73
+ fig, axs = plt.subplots(1, 2, figsize=(12, 6))
74
+ axs[0].imshow(viz_img)
75
+ axs[0].set_title("Input Image")
76
+ axs[0].axis("off")
77
+ axs[1].imshow(pca_viz)
78
+ axs[1].set_title(f"PCA Features of {encoder.name}")
79
+ axs[1].axis("off")
80
+ plt.savefig(f"{self.pca_save_folder}/pca_{encoder.name}.png", bbox_inches="tight")
81
+ plt.close()
82
+
83
+ def _visualize_encoder_features_consistency(self, encoder, image_size: Tuple[int, int]):
84
+ img0, viz_img0 = self._get_example_input(
85
+ image_size, encoder.data_norm_type, img_selection=1, return_viz_img=True
86
+ )
87
+ img1, viz_img1 = self._get_example_input(
88
+ image_size, encoder.data_norm_type, img_selection=2, return_viz_img=True
89
+ )
90
+ # input and output of the encoder
91
+ encoder_input0: ViTEncoderInput = ViTEncoderInput(
92
+ data_norm_type=encoder.data_norm_type,
93
+ image=img0,
94
+ )
95
+
96
+ encoder_input1: ViTEncoderInput = ViTEncoderInput(
97
+ data_norm_type=encoder.data_norm_type,
98
+ image=img1,
99
+ )
100
+
101
+ encoder_output0 = self.inference_encoder(encoder, encoder_input0)
102
+ encoder_output0 = encoder_output0.features
103
+
104
+ encoder_output1 = self.inference_encoder(encoder, encoder_input1)
105
+ encoder_output1 = encoder_output1.features
106
+
107
+ # get a common PCA codec
108
+ cat_feats = torch.cat([encoder_output0, encoder_output1], dim=3)
109
+
110
+ pca_viz = get_pca_map(cat_feats.permute(0, 2, 3, 1), (image_size[0], image_size[1] * 2), return_pca_stats=True)
111
+
112
+ # concatenate the input images along the width dimension
113
+ cat_imgs = torch.cat([viz_img0, viz_img1], dim=1)
114
+
115
+ # plot the input image and the PCA features
116
+ fig, axs = plt.subplots(1, 2, figsize=(12, 6))
117
+ axs[0].imshow(cat_imgs)
118
+ axs[0].set_title("Input Images")
119
+ axs[0].axis("off")
120
+ axs[1].imshow(pca_viz[0])
121
+ axs[1].set_title(f"PCA Features of {encoder.name}")
122
+ axs[1].axis("off")
123
+ plt.savefig(f"{self.pca_save_folder}/multi_pca_{encoder.name}.png", bbox_inches="tight")
124
+ plt.close()
125
+
126
+ @lru_cache(maxsize=3)
127
+ def _get_example_input(
128
+ self,
129
+ image_size: Tuple[int, int],
130
+ image_norm_type: str = "dummy",
131
+ img_selection: int = 1,
132
+ return_viz_img: bool = False,
133
+ ) -> torch.Tensor:
134
+ url = f"https://raw.githubusercontent.com/naver/croco/d3d0ab2858d44bcad54e5bfc24f565983fbe18d9/assets/Chateau{img_selection}.png"
135
+ image = Image.open(requests.get(url, stream=True).raw)
136
+ image = image.resize(image_size)
137
+ image = image.convert("RGB")
138
+
139
+ img = torch.from_numpy(np.array(image))
140
+ viz_img = img.clone()
141
+
142
+ # Normalize the images
143
+ image_normalization = IMAGE_NORMALIZATION_DICT[image_norm_type]
144
+
145
+ img_mean, img_std = image_normalization.mean, image_normalization.std
146
+
147
+ img = (img.float() / 255.0 - img_mean) / img_std
148
+
149
+ # convert to BCHW format
150
+ img = img.permute(2, 0, 1).unsqueeze(0)
151
+
152
+ if return_viz_img:
153
+ return img, viz_img
154
+ else:
155
+ return img
156
+
157
+
158
+ def render_pca_as_rgb(features):
159
+ """
160
+ Perform PCA on the given feature tensor and render the first 3 principal components as RGB.
161
+
162
+ Args:
163
+ features (torch.Tensor): Feature tensor of shape (B, C, H, W).
164
+
165
+ Returns:
166
+ np.ndarray: RGB image of shape (H, W, 3).
167
+ """
168
+ # Ensure input is a 4D tensor
169
+ assert features.dim() == 4, "Input tensor must be 4D (B, C, H, W)"
170
+
171
+ B, C, H, W = features.shape
172
+
173
+ # Reshape the tensor to (B * H * W, C)
174
+ reshaped_features = features.permute(0, 2, 3, 1).contiguous().view(-1, C).cpu().numpy()
175
+
176
+ # Perform PCA
177
+ pca = PCA(n_components=3)
178
+ principal_components = pca.fit_transform(reshaped_features)
179
+
180
+ # Rescale the principal components to [0, 1]
181
+ principal_components = (principal_components - principal_components.min(axis=0)) / (
182
+ principal_components.max(axis=0) - principal_components.min(axis=0)
183
+ )
184
+
185
+ # Reshape the principal components to (B, H, W, 3)
186
+ principal_components = principal_components.reshape(B, H, W, 3)
187
+
188
+ # Convert the principal components to RGB image (take the first batch)
189
+ rgb_image = principal_components[0]
190
+
191
+ return rgb_image
192
+
193
+
194
+ def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False):
195
+ # features: (N, C)
196
+ # m: a hyperparam controlling how many std dev outside for outliers
197
+ assert len(features.shape) == 2, "features should be (N, C)"
198
+ reduction_mat = torch.pca_lowrank(features, q=3, niter=20)[2]
199
+ colors = features @ reduction_mat
200
+ if remove_first_component:
201
+ colors_min = colors.min(dim=0).values
202
+ colors_max = colors.max(dim=0).values
203
+ tmp_colors = (colors - colors_min) / (colors_max - colors_min)
204
+ fg_mask = tmp_colors[..., 0] < 0.2
205
+ reduction_mat = torch.pca_lowrank(features[fg_mask], q=3, niter=20)[2]
206
+ colors = features @ reduction_mat
207
+ else:
208
+ fg_mask = torch.ones_like(colors[:, 0]).bool()
209
+ d = torch.abs(colors[fg_mask] - torch.median(colors[fg_mask], dim=0).values)
210
+ mdev = torch.median(d, dim=0).values
211
+ s = d / mdev
212
+ try:
213
+ rins = colors[fg_mask][s[:, 0] < m, 0]
214
+ gins = colors[fg_mask][s[:, 1] < m, 1]
215
+ bins = colors[fg_mask][s[:, 2] < m, 2]
216
+ rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])
217
+ rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])
218
+ except:
219
+ rins = colors
220
+ gins = colors
221
+ bins = colors
222
+ rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])
223
+ rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])
224
+
225
+ return reduction_mat, rgb_min.to(reduction_mat), rgb_max.to(reduction_mat)
226
+
227
+
228
+ def get_pca_map(
229
+ feature_map: torch.Tensor,
230
+ img_size,
231
+ interpolation="bicubic",
232
+ return_pca_stats=False,
233
+ pca_stats=None,
234
+ ):
235
+ """
236
+ feature_map: (1, h, w, C) is the feature map of a single image.
237
+ """
238
+ if feature_map.shape[0] != 1:
239
+ # make it (1, h, w, C)
240
+ feature_map = feature_map[None]
241
+ if pca_stats is None:
242
+ reduct_mat, color_min, color_max = get_robust_pca(feature_map.reshape(-1, feature_map.shape[-1]))
243
+ else:
244
+ reduct_mat, color_min, color_max = pca_stats
245
+ pca_color = feature_map @ reduct_mat
246
+ pca_color = (pca_color - color_min) / (color_max - color_min)
247
+ pca_color = pca_color.clamp(0, 1)
248
+ pca_color = F.interpolate(
249
+ pca_color.permute(0, 3, 1, 2),
250
+ size=img_size,
251
+ mode=interpolation,
252
+ ).permute(0, 2, 3, 1)
253
+ pca_color = pca_color.detach().cpu().numpy().squeeze(0)
254
+ if return_pca_stats:
255
+ return pca_color, (reduct_mat, color_min, color_max)
256
+ return pca_color
257
+
258
+
259
+ def seed_everything(seed=42):
260
+ """
261
+ Set the `seed` value for torch and numpy seeds. Also turns on
262
+ deterministic execution for cudnn.
263
+
264
+ Parameters:
265
+ - seed: A hashable seed value
266
+ """
267
+ random.seed(seed)
268
+ os.environ["PYTHONHASHSEED"] = str(seed)
269
+ np.random.seed(seed)
270
+ torch.manual_seed(seed)
271
+ torch.backends.cudnn.deterministic = True
272
+ torch.backends.cudnn.benchmark = False
273
+ print(f"Seed set to: {seed} (type: {type(seed)})")
274
+
275
+
276
+ if __name__ == "__main__":
277
+ # Turn XFormers off for testing on CPU
278
+ os.environ["XFORMERS_DISABLED"] = "1"
279
+
280
+ # Seed everything for consistent testing
281
+ seed_everything()
282
+
283
+ # Create local directory for storing the PCA images
284
+ current_file_path = os.path.abspath(__file__)
285
+ relative_pca_image_folder = os.path.join(os.path.dirname(current_file_path), "../../../local/encoders/pca_images")
286
+ os.makedirs(relative_pca_image_folder, exist_ok=True)
287
+
288
+ # Initialize the test class
289
+ test = TestEncoders(pca_save_folder=relative_pca_image_folder)
290
+
291
+ # Visualize the PCA of all encoders
292
+ test.visualize_all_encoders()
293
+
294
+ print(f"The PCA visualizations of all encoders are saved successfully to {relative_pca_image_folder}!")
UniCeption/tests/models/info_sharing/viz_mulit_view_cross_attn_transformers.py ADDED
@@ -0,0 +1,337 @@
1
+ """
2
+ PCA Visualization of UniCeption Image Encoders + Multi-View Cross Attention Transformers
3
+ """
4
+
5
+ import os
6
+ import random
7
+ from functools import lru_cache
8
+ from typing import Tuple
9
+
10
+ import numpy as np
11
+ import requests
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from matplotlib import pyplot as plt
15
+ from PIL import Image
16
+ from sklearn.decomposition import PCA
17
+
18
+ from uniception.models.encoders import *
19
+ from uniception.models.encoders.image_normalizations import *
20
+ from uniception.models.info_sharing.base import MultiViewTransformerInput
21
+ from uniception.models.info_sharing.cross_attention_transformer import MultiViewCrossAttentionTransformerIFR
22
+ from uniception.models.libs.croco.pos_embed import RoPE2D, get_2d_sincos_pos_embed
23
+
24
+
25
+ def _make_mv_cross_attention_transformer_test(model_str: str, **kwargs):
26
+ current_file_path = os.path.abspath(__file__)
27
+ relative_checkpoint_path = os.path.join(
28
+ os.path.dirname(current_file_path), "../../../checkpoints/info_sharing/cross_attn_transformer"
29
+ )
30
+ rope = RoPE2D(float(100))
31
+ if model_str == "croco":
32
+ return MultiViewCrossAttentionTransformerIFR(
33
+ name="croco_base_decoder",
34
+ input_embed_dim=1024,
35
+ num_views=2,
36
+ indices=[12 * 2 // 4, 12 * 3 // 4],
37
+ norm_intermediate=False,
38
+ custom_positional_encoding=rope,
39
+ pretrained_checkpoint_path=f"{relative_checkpoint_path}/Two_View_Cross_Attention_Transformer_CroCo.pth",
40
+ **kwargs,
41
+ )
42
+ elif model_str == "dust3r_224":
43
+ return MultiViewCrossAttentionTransformerIFR(
44
+ name="dust3r_224_base_decoder",
45
+ input_embed_dim=1024,
46
+ num_views=2,
47
+ indices=[12 * 2 // 4, 12 * 3 // 4],
48
+ norm_intermediate=False,
49
+ custom_positional_encoding=rope,
50
+ pretrained_checkpoint_path=f"{relative_checkpoint_path}/Two_View_Cross_Attention_Transformer_DUSt3R_224_linear.pth",
51
+ **kwargs,
52
+ )
53
+ elif model_str == "dust3r_512":
54
+ return MultiViewCrossAttentionTransformerIFR(
55
+ name="dust3r_512_base_decoder",
56
+ input_embed_dim=1024,
57
+ num_views=2,
58
+ indices=[12 * 2 // 4, 12 * 3 // 4],
59
+ norm_intermediate=False,
60
+ custom_positional_encoding=rope,
61
+ pretrained_checkpoint_path=f"{relative_checkpoint_path}/Two_View_Cross_Attention_Transformer_DUSt3R_512_linear.pth",
62
+ **kwargs,
63
+ )
64
+ elif model_str == "dust3r_512_dpt":
65
+ return MultiViewCrossAttentionTransformerIFR(
66
+ name="dust3r_512_dpt_base_decoder",
67
+ input_embed_dim=1024,
68
+ num_views=2,
69
+ indices=[12 * 2 // 4, 12 * 3 // 4],
70
+ norm_intermediate=False,
71
+ custom_positional_encoding=rope,
72
+ pretrained_checkpoint_path=f"{relative_checkpoint_path}/Two_View_Cross_Attention_Transformer_DUSt3R_512_dpt.pth",
73
+ **kwargs,
74
+ )
75
+ elif model_str == "mast3r_512":
76
+ return MultiViewCrossAttentionTransformerIFR(
77
+ name="mast3r_512_base_decoder",
78
+ input_embed_dim=1024,
79
+ num_views=2,
80
+ indices=[12 * 2 // 4, 12 * 3 // 4],
81
+ norm_intermediate=False,
82
+ custom_positional_encoding=rope,
83
+ pretrained_checkpoint_path=f"{relative_checkpoint_path}/Two_View_Cross_Attention_Transformer_MASt3R.pth",
84
+ **kwargs,
85
+ )
86
+
87
+
88
+ class TestMultiViewTransformers:
89
+ def __init__(self, pca_save_folder, *args, **kwargs):
90
+ super(TestMultiViewTransformers, self).__init__(*args, **kwargs)
91
+
92
+ self.pca_save_folder = pca_save_folder
93
+
94
+ self.norm_types = IMAGE_NORMALIZATION_DICT.keys()
95
+
96
+ self.models = [
97
+ "croco",
98
+ "dust3r_224",
99
+ "dust3r_512",
100
+ "dust3r_512_dpt",
101
+ "mast3r_512",
102
+ ]
103
+
104
+ self.model_configs = [{}] * len(self.models)
105
+
106
+ def inference_encoder(self, encoder, input):
107
+ return encoder(input)
108
+
109
+ def inference_info_sharing(self, info_sharing, input):
110
+ return info_sharing(input)
111
+
112
+ def visualize_all_models(self):
113
+ for model, model_config in zip(self.models, self.model_configs):
114
+ encoder = _make_encoder_test(model, **model_config)
115
+ info_sharing = _make_mv_cross_attention_transformer_test(model, **model_config)
116
+ self._visualize_model_features_consistency(encoder, info_sharing, (224, 224))
117
+
118
+ def _visualize_model_features_consistency(self, encoder, info_sharing, image_size: Tuple[int, int]):
119
+ img0, viz_img0 = self._get_example_input(
120
+ image_size, encoder.data_norm_type, img_selection=1, return_viz_img=True
121
+ )
122
+ img1, viz_img1 = self._get_example_input(
123
+ image_size, encoder.data_norm_type, img_selection=2, return_viz_img=True
124
+ )
125
+ # input and output of the encoder
126
+ encoder_input0: ViTEncoderInput = ViTEncoderInput(
127
+ data_norm_type=encoder.data_norm_type,
128
+ image=img0,
129
+ )
130
+
131
+ encoder_input1: ViTEncoderInput = ViTEncoderInput(
132
+ data_norm_type=encoder.data_norm_type,
133
+ image=img1,
134
+ )
135
+
136
+ encoder_output0 = self.inference_encoder(encoder, encoder_input0)
137
+ encoder_output0 = encoder_output0.features
138
+
139
+ encoder_output1 = self.inference_encoder(encoder, encoder_input1)
140
+ encoder_output1 = encoder_output1.features
141
+
142
+ # pass the encoder outputs to the info sharing model
143
+ multi_view_features = [encoder_output0, encoder_output1]
144
+ info_sharing_input = MultiViewTransformerInput(features=multi_view_features)
145
+ info_sharing_output = self.inference_info_sharing(info_sharing, info_sharing_input)
146
+ final_layer_multi_view_features = info_sharing_output[0].features
147
+
148
+ # get a common PCA codec
149
+ cat_feats = torch.cat(final_layer_multi_view_features, dim=3)
150
+
151
+ pca_viz = get_pca_map(cat_feats.permute(0, 2, 3, 1), (image_size[0], image_size[1] * 2), return_pca_stats=True)
152
+
153
+ # concatenate the input images along the width dimension
154
+ cat_imgs = torch.cat([viz_img0, viz_img1], dim=1)
155
+
156
+ # plot the input image and the PCA features
157
+ fig, axs = plt.subplots(1, 2, figsize=(12, 6))
158
+ axs[0].imshow(cat_imgs)
159
+ axs[0].set_title("Input Images")
160
+ axs[0].axis("off")
161
+ axs[1].imshow(pca_viz[0])
162
+ axs[1].set_title(f"PCA Features of {encoder.name} + Base Decoder")
163
+ axs[1].axis("off")
164
+ plt.savefig(f"{self.pca_save_folder}/multi_pca_{encoder.name}.png", bbox_inches="tight")
165
+ plt.close()
166
+
167
+ @lru_cache(maxsize=3)
168
+ def _get_example_input(
169
+ self,
170
+ image_size: Tuple[int, int],
171
+ image_norm_type: str = "dummy",
172
+ img_selection: int = 1,
173
+ return_viz_img: bool = False,
174
+ ) -> torch.Tensor:
175
+ url = f"https://raw.githubusercontent.com/naver/croco/d3d0ab2858d44bcad54e5bfc24f565983fbe18d9/assets/Chateau{img_selection}.png"
176
+ image = Image.open(requests.get(url, stream=True).raw)
177
+ image = image.resize(image_size)
178
+ image = image.convert("RGB")
179
+
180
+ img = torch.from_numpy(np.array(image))
181
+ viz_img = img.clone()
182
+
183
+ # Normalize the images
184
+ image_normalization = IMAGE_NORMALIZATION_DICT[image_norm_type]
185
+
186
+ img_mean, img_std = image_normalization.mean, image_normalization.std
187
+
188
+ img = (img.float() / 255.0 - img_mean) / img_std
189
+
190
+ # convert to BCHW format
191
+ img = img.permute(2, 0, 1).unsqueeze(0)
192
+
193
+ if return_viz_img:
194
+ return img, viz_img
195
+ else:
196
+ return img
197
+
198
+
199
+ def render_pca_as_rgb(features):
200
+ """
201
+ Perform PCA on the given feature tensor and render the first 3 principal components as RGB.
202
+
203
+ Args:
204
+ features (torch.Tensor): Feature tensor of shape (B, C, H, W).
205
+
206
+ Returns:
207
+ np.ndarray: RGB image of shape (H, W, 3).
208
+ """
209
+ # Ensure input is a 4D tensor
210
+ assert features.dim() == 4, "Input tensor must be 4D (B, C, H, W)"
211
+
212
+ B, C, H, W = features.shape
213
+
214
+ # Reshape the tensor to (B * H * W, C)
215
+ reshaped_features = features.permute(0, 2, 3, 1).contiguous().view(-1, C).cpu().numpy()
216
+
217
+ # Perform PCA
218
+ pca = PCA(n_components=3)
219
+ principal_components = pca.fit_transform(reshaped_features)
220
+
221
+ # Rescale the principal components to [0, 1]
222
+ principal_components = (principal_components - principal_components.min(axis=0)) / (
223
+ principal_components.max(axis=0) - principal_components.min(axis=0)
224
+ )
225
+
226
+ # Reshape the principal components to (B, H, W, 3)
227
+ principal_components = principal_components.reshape(B, H, W, 3)
228
+
229
+ # Convert the principal components to RGB image (take the first batch)
230
+ rgb_image = principal_components[0]
231
+
232
+ return rgb_image
233
+
234
+
235
+ def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False):
236
+ # features: (N, C)
237
+ # m: a hyperparam controlling how many std dev outside for outliers
238
+ assert len(features.shape) == 2, "features should be (N, C)"
239
+ reduction_mat = torch.pca_lowrank(features, q=3, niter=20)[2]
240
+ colors = features @ reduction_mat
241
+ if remove_first_component:
242
+ colors_min = colors.min(dim=0).values
243
+ colors_max = colors.max(dim=0).values
244
+ tmp_colors = (colors - colors_min) / (colors_max - colors_min)
245
+ fg_mask = tmp_colors[..., 0] < 0.2
246
+ reduction_mat = torch.pca_lowrank(features[fg_mask], q=3, niter=20)[2]
247
+ colors = features @ reduction_mat
248
+ else:
249
+ fg_mask = torch.ones_like(colors[:, 0]).bool()
250
+ d = torch.abs(colors[fg_mask] - torch.median(colors[fg_mask], dim=0).values)
251
+ mdev = torch.median(d, dim=0).values
252
+ s = d / mdev
253
+ try:
254
+ rins = colors[fg_mask][s[:, 0] < m, 0]
255
+ gins = colors[fg_mask][s[:, 1] < m, 1]
256
+ bins = colors[fg_mask][s[:, 2] < m, 2]
257
+ rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])
258
+ rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])
259
+ except:
260
+ rins = colors
261
+ gins = colors
262
+ bins = colors
263
+ rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])
264
+ rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])
265
+
266
+ return reduction_mat, rgb_min.to(reduction_mat), rgb_max.to(reduction_mat)
267
+
268
+
269
+ def get_pca_map(
270
+ feature_map: torch.Tensor,
271
+ img_size,
272
+ interpolation="bicubic",
273
+ return_pca_stats=False,
274
+ pca_stats=None,
275
+ ):
276
+ """
277
+ feature_map: (1, h, w, C) is the feature map of a single image.
278
+ """
279
+ if feature_map.shape[0] != 1:
280
+ # make it (1, h, w, C)
281
+ feature_map = feature_map[None]
282
+ if pca_stats is None:
283
+ reduct_mat, color_min, color_max = get_robust_pca(feature_map.reshape(-1, feature_map.shape[-1]))
284
+ else:
285
+ reduct_mat, color_min, color_max = pca_stats
286
+ pca_color = feature_map @ reduct_mat
287
+ pca_color = (pca_color - color_min) / (color_max - color_min)
288
+ pca_color = pca_color.clamp(0, 1)
289
+ pca_color = F.interpolate(
290
+ pca_color.permute(0, 3, 1, 2),
291
+ size=img_size,
292
+ mode=interpolation,
293
+ ).permute(0, 2, 3, 1)
294
+ pca_color = pca_color.detach().cpu().numpy().squeeze(0)
295
+ if return_pca_stats:
296
+ return pca_color, (reduct_mat, color_min, color_max)
297
+ return pca_color
298
+
299
+
300
+ def seed_everything(seed=42):
301
+ """
302
+ Set the `seed` value for torch and numpy seeds. Also turns on
303
+ deterministic execution for cudnn.
304
+
305
+ Parameters:
306
+ - seed: A hashable seed value
307
+ """
308
+ random.seed(seed)
309
+ os.environ["PYTHONHASHSEED"] = str(seed)
310
+ np.random.seed(seed)
311
+ torch.manual_seed(seed)
312
+ torch.backends.cudnn.deterministic = True
313
+ torch.backends.cudnn.benchmark = False
314
+ print(f"Seed set to: {seed} (type: {type(seed)})")
315
+
316
+
317
+ if __name__ == "__main__":
318
+ # Turn XFormers off for testing on CPU
319
+ os.environ["XFORMERS_DISABLED"] = "1"
320
+
321
+ # Seed everything for consistent testing
322
+ seed_everything()
323
+
324
+ # Create local directory for storing the PCA images
325
+ current_file_path = os.path.abspath(__file__)
326
+ relative_pca_image_folder = os.path.join(
327
+ os.path.dirname(current_file_path), "../../../local/info_sharing/pca_images"
328
+ )
329
+ os.makedirs(relative_pca_image_folder, exist_ok=True)
330
+
331
+ # Initialize the test class
332
+ test = TestMultiViewTransformers(pca_save_folder=relative_pca_image_folder)
333
+
334
+ # Visualize the PCA of all models
335
+ test.visualize_all_models()
336
+
337
+ print(f"The PCA visualizations of all models are saved successfully to {relative_pca_image_folder}!")
UniCeption/uniception/__init__.py ADDED
File without changes
UniCeption/uniception/models/encoders/README.md ADDED
@@ -0,0 +1,129 @@
1
+ # UniCeption Encoders
2
+
3
+ ## Currently Supported Encoders
4
+
5
+ ### UniCeptionViTEncoderBase:
6
+
7
+ - `CroCoEncoder`
8
+ - `CroCoIntermediateFeatureReturner`
9
+ - `DINOv2Encoder`
10
+ - `DINOv2IntermediateFeatureReturner`
11
+ - `PatchEmbedder`
12
+ - `RADIOEncoder`
13
+ - `RADIOIntermediateFeatureReturner`
14
+
15
+ # Developer Guidelines for UniCeption Encoders
16
+
17
+ ## Overview
18
+
19
+ This folder contains the implementation of various UniCeption encoders. Each encoder must adhere to a specific structure and follow certain guidelines to ensure consistency and compatibility across different projects.
20
+
21
+ ## Directory Structure
22
+
23
+ The encoders and other necessary dependencies/tests for encoders are organized as follows:
24
+ ```
25
+ uniception/
26
+ ├── models/
27
+ │ ├── encoders/
28
+ │ │ ├── __init__.py
29
+ │ │ ├── base.py
30
+ │ │ ├── croco.py
31
+ │ │ ├── dinov2.py
32
+ │ │ ├── radio.py
33
+ │ │ ├── image_normalizations.py
34
+ │ └── ...
35
+ │ └── libs/
36
+ │ │ ├── external_dependency_folders/
37
+ │ │ │ ├── external_dependency_files
38
+ tests/
39
+ ├── models/
40
+ │ ├── encoders/
41
+ │ │ ├── test_encoders.py
42
+ │ │ ├── viz_image_encoders.py
43
+ │ │ └── ...
44
+ │ └── ...
45
+ └── ...
46
+ ```
47
+
48
+ ## Adding a New Encoder
49
+
50
+ To add a new encoder, follow these steps:
51
+
52
+ 1. **Create a New Encoder File**:
53
+ - Create a new file in the `encoders` directory, e.g., `new_encoder.py`.
54
+ - Define the new encoder class in this file, inheriting from `UniCeptionEncoderBase` or `UniCeptionViTEncoderBase`.
55
+ - Please look at the base class for the necessary attributes and methods to implement.
56
+
57
+ 2. **Define Input Data Normalization**:
58
+ - Add the corresponding normalization for the encoder to respective normalization files, for example, image normalizations should be added to `image_normalizations.py`.
59
+ - Ensure the normalization is added to the dictionaries present in the files, for example, `IMAGE_NORMALIZATION_DICT`.
60
+
61
+ 3. **Implement the Encoder Class**:
62
+ - Inherit from `UniCeptionEncoderBase` or `UniCeptionViTEncoderBase` or other UniCeption base classes.
63
+ - Implement the `forward` method.
64
+ - Ensure the encoder class has the necessary attributes and methods.
65
+
66
+ 4. **Update `__init__.py`**:
67
+ - Import the new encoder class in `__init__.py`.
68
+ - Add the new encoder to the encoder configuration dictionary `ENCODER_CONFIGS` so that it can be instantiated via the encoder factory (a registration sketch follows this list).
69
+ - Update the `_make_encoder_test` function to include the new encoder.
70
+
71
+ 5. **Run Encoder Unit Tests**:
72
+ - Run `pytest -vs tests/models/encoders/test_encoders.py --encoder-name="<new_encoder>"` to test the basic expected functionality of UniCeption encoders.
73
+ - Also, add your new encoder to the list returned by the `encoders()` fixture in `tests/models/encoders/test_encoders.py` so that it is tested along with all the existing encoders.
74
+ - Optionally, for image encoders, the visualization script `tests/models/encoders/viz_image_encoders.py` saves PCA visualizations of the encoder outputs to the `local/encoders/pca_images` directory.
75
+
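A hedged sketch of the `ENCODER_CONFIGS` registration mentioned in step 4; the `NewEncoder` import path and display name are hypothetical:

```python
# Hypothetical registration in uniception/models/encoders/__init__.py
from uniception.models.encoders.new_encoder import NewEncoder  # hypothetical module

ENCODER_CONFIGS["new_encoder"] = {
    "class": NewEncoder,
    "supported_models": ["New-Encoder"],
}
```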
76
+ ## Example Encoder Implementation
77
+
78
+ Here is an example of how to implement a new encoder:
79
+
80
+ ```python
81
+ # new_encoder.py
82
+ import torch
83
+ from uniception.models.encoders.base import UniCeptionEncoderBase, EncoderInput, EncoderOutput
84
+
85
+ class NewEncoder(UniCeptionEncoderBase):
86
+ def __init__(self, name: str, data_norm_type: str, *args, **kwargs):
87
+ super().__init__(name=name, data_norm_type=data_norm_type, *args, **kwargs)
88
+ # Initialize encoder-specific layers and parameters here
89
+
90
+ def forward(self, encoder_input: EncoderInput) -> EncoderOutput:
91
+ self._check_data_normalization_type(encoder_input.data_norm_type)
92
+ # Implement the forward pass
93
+ return EncoderOutput()
94
+ ```
95
+
96
+ ## Example Normalization
97
+
98
+ Add the normalization for the new encoder, for example, to `image_normalizations.py`:
99
+
100
+ ```python
101
+ # image_normalizations.py
102
+ IMAGE_NORMALIZATION_DICT = {
103
+ "dummy": ImageNormalization(mean=torch.tensor([0.0, 0.0, 0.0]), std=torch.tensor([1.0, 1.0, 1.0])),
104
+ "croco": ImageNormalization(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])),
105
+ "dust3r": ImageNormalization(mean=torch.tensor([0.5, 0.5, 0.5]), std=torch.tensor([0.5, 0.5, 0.5])),
106
+ "dinov2": ImageNormalization(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])),
107
+ "radio": ImageNormalization(mean=torch.tensor([0.0, 0.0, 0.0]), std=torch.tensor([1.0, 1.0, 1.0])),
108
+ "new_encoder": ImageNormalization(mean=torch.tensor([0.5, 0.5, 0.5]), std=torch.tensor([0.2, 0.2, 0.2])),
109
+ }
110
+ ```
111
+
112
+ ## Example Unit Testing
113
+
114
+ Add the new encoder to the encoder factory in `__init__.py` and the encoder list in `tests/models/encoders/test_encoders.py`. Additional tests can also be added as required.
115
+
116
+ Look at `tests/models/encoders/test_encoders.py` to see what tests are run.
117
+
118
+ Additionally, if the new encoder is an image encoder, you can add it to the encoder list in `tests/models/encoders/viz_image_encoders.py` to save PCA visualizations of the encoder outputs to the `local/encoders/pca_images` directory.
119
+
120
+ ## Developer Guidelines
121
+
122
+ Please follow these guidelines when contributing to the UniCeption encoders:
123
+ - **Consistency**: Ensure that the new encoder follows the structure and naming conventions of existing encoders.
124
+ - **Code Style**: Follow the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html) for code style.
125
+ - **Documentation**: Add docstrings to all classes and methods.
126
+ - **Unit Tests**: Add necessary unit tests for the encoder class.
127
+ - **Linting**: Run `black` on your code before committing. For example, you can run `black uniception`.
128
+
129
+ ## Happy Coding!
UniCeption/uniception/models/encoders/__init__.py ADDED
@@ -0,0 +1,235 @@
1
+ """
2
+ Encoder Factory for UniCeption
3
+ """
4
+
5
+ import os
6
+
7
+ from uniception.models.encoders.base import (
8
+ EncoderGlobalRepInput,
9
+ EncoderInput,
10
+ UniCeptionEncoderBase,
11
+ UniCeptionViTEncoderBase,
12
+ ViTEncoderInput,
13
+ ViTEncoderNonImageInput,
14
+ ViTEncoderOutput,
15
+ )
16
+ from uniception.models.encoders.cosmos import CosmosEncoder
17
+ from uniception.models.encoders.croco import CroCoEncoder, CroCoIntermediateFeatureReturner
18
+ from uniception.models.encoders.dense_rep_encoder import DenseRepresentationEncoder
19
+ from uniception.models.encoders.dinov2 import DINOv2Encoder, DINOv2IntermediateFeatureReturner
20
+ from uniception.models.encoders.global_rep_encoder import GlobalRepresentationEncoder
21
+ from uniception.models.encoders.naradio import NARADIOEncoder
22
+ from uniception.models.encoders.patch_embedder import PatchEmbedder
23
+ from uniception.models.encoders.radio import RADIOEncoder, RADIOIntermediateFeatureReturner
24
+
25
+ # Define encoder configurations
26
+ ENCODER_CONFIGS = {
27
+ "croco": {
28
+ "class": CroCoEncoder,
29
+ "intermediate_feature_returner_class": CroCoIntermediateFeatureReturner,
30
+ "supported_models": ["CroCov2", "DUSt3R", "MASt3R"],
31
+ },
32
+ "dense_rep_encoder": {
33
+ "class": DenseRepresentationEncoder,
34
+ "supported_models": ["Dense-Representation-Encoder"],
35
+ },
36
+ "dinov2": {
37
+ "class": DINOv2Encoder,
38
+ "intermediate_feature_returner_class": DINOv2IntermediateFeatureReturner,
39
+ "supported_models": ["DINOv2", "DINOv2-Registers", "DINOv2-Depth-Anythingv2"],
40
+ },
41
+ "global_rep_encoder": {
42
+ "class": GlobalRepresentationEncoder,
43
+ "supported_models": ["Global-Representation-Encoder"],
44
+ },
45
+ "patch_embedder": {
46
+ "class": PatchEmbedder,
47
+ "supported_models": ["Patch-Embedder"],
48
+ },
49
+ "radio": {
50
+ "class": RADIOEncoder,
51
+ "intermediate_feature_returner_class": RADIOIntermediateFeatureReturner,
52
+ "supported_models": ["RADIO", "E-RADIO"],
53
+ },
54
+ "cosmos": {
55
+ "class": CosmosEncoder,
56
+ "supported_models": ["Cosmos-Tokenizer CI8x8", "Cosmos-Tokenizer CI16x16"],
57
+ },
58
+ "naradio": {
59
+ "class": NARADIOEncoder,
60
+ "supported_models": ["RADIO"],
61
+ },
62
+ # Add other encoders here
63
+ }
64
+
65
+
66
+ def encoder_factory(encoder_str: str, **kwargs) -> UniCeptionEncoderBase:
67
+ """
68
+ Encoder factory for UniCeption.
69
+ Please use python3 -m uniception.models.encoders.list to see available encoders.
70
+
71
+ Args:
72
+ encoder_str (str): Name of the encoder to create.
73
+ **kwargs: Additional keyword arguments to pass to the encoder constructor.
74
+
75
+ Returns:
76
+ UniCeptionEncoderBase: An instance of the specified encoder.
77
+ """
78
+ if encoder_str not in ENCODER_CONFIGS:
79
+ raise ValueError(
80
+ f"Unknown encoder: {encoder_str}. For valid encoder_str options, please use python3 -m uniception.models.encoders.list"
81
+ )
82
+
83
+ encoder_config = ENCODER_CONFIGS[encoder_str]
84
+ encoder_class = encoder_config["class"]
85
+
86
+ return encoder_class(**kwargs)
87
+
88
+
89
+ def feature_returner_encoder_factory(encoder_str: str, **kwargs) -> UniCeptionEncoderBase:
90
+ """
91
+ Factory for UniCeption Encoders with support for intermediate feature returning.
92
+ Please use python3 -m uniception.models.encoders.list to see available encoders.
93
+
94
+ Args:
95
+ encoder_str (str): Name of the encoder to create.
96
+ **kwargs: Additional keyword arguments to pass to the encoder constructor.
97
+
98
+ Returns:
99
+ UniCeptionEncoderBase: An instance of the specified encoder.
100
+ """
101
+ if encoder_str not in ENCODER_CONFIGS:
102
+ raise ValueError(
103
+ f"Unknown encoder: {encoder_str}. For valid encoder_str options, please use python3 -m uniception.models.encoders.list"
104
+ )
105
+
106
+ encoder_config = ENCODER_CONFIGS[encoder_str]
107
+ encoder_class = encoder_config["intermediate_feature_returner_class"]
108
+
109
+ return encoder_class(**kwargs)
110
+
111
+
112
+ def get_available_encoders() -> list:
113
+ """
114
+ Get a list of available encoders in UniCeption.
115
+
116
+ Returns:
117
+ list: A list of available encoder names.
118
+ """
119
+ return list(ENCODER_CONFIGS.keys())
120
+
121
+
122
+ def print_available_encoder_models():
123
+ """
124
+ Print the currently supported encoders in UniCeption.
125
+ """
126
+ print("Currently Supported Encoders in UniCeption:\nFormat -> encoder_str: supported_models")
127
+ for encoder_name, config in ENCODER_CONFIGS.items():
128
+ print(f"{encoder_name}: {', '.join(config['supported_models'])}")
129
+
130
+
131
+ def _make_encoder_test(encoder_str: str, **kwargs) -> UniCeptionEncoderBase:
132
+ "Function to create encoders for testing purposes."
133
+ current_file_path = os.path.abspath(__file__)
134
+ relative_checkpoint_path = os.path.join(os.path.dirname(current_file_path), "../../../checkpoints/encoders")
135
+ if encoder_str == "dummy":
136
+ return UniCeptionEncoderBase(name="dummy", data_norm_type="dummy")
137
+ elif encoder_str == "croco":
138
+ return CroCoEncoder(
139
+ name="croco",
140
+ data_norm_type="croco",
141
+ pretrained_checkpoint_path=f"{relative_checkpoint_path}/CroCo_Encoder_224.pth",
142
+ patch_embed_cls="PatchEmbedCroCo",
143
+ )
144
+ elif encoder_str == "dust3r_224":
145
+ return CroCoEncoder(
146
+ name="dust3r_224",
147
+ data_norm_type="dust3r",
148
+ pretrained_checkpoint_path=f"{relative_checkpoint_path}/CroCo_Encoder_224_DUSt3R_linear.pth",
149
+ patch_embed_cls="PatchEmbedDust3R",
150
+ )
151
+ elif encoder_str == "dust3r_512":
152
+ return CroCoEncoder(
153
+ name="dust3r_512",
154
+ data_norm_type="dust3r",
155
+ pretrained_checkpoint_path=f"{relative_checkpoint_path}/CroCo_Encoder_512_DUSt3R_linear.pth",
156
+ patch_embed_cls="ManyAR_PatchEmbed",
157
+ img_size=(512, 512),
158
+ )
159
+ elif encoder_str == "dust3r_512_dpt":
160
+ return CroCoEncoder(
161
+ name="dust3r_512_dpt",
162
+ data_norm_type="dust3r",
163
+ pretrained_checkpoint_path=f"{relative_checkpoint_path}/CroCo_Encoder_512_DUSt3R_dpt.pth",
164
+ patch_embed_cls="ManyAR_PatchEmbed",
165
+ img_size=(512, 512),
166
+ )
167
+ elif encoder_str == "mast3r_512":
168
+ return CroCoEncoder(
169
+ name="mast3r_512",
170
+ data_norm_type="dust3r",
171
+ pretrained_checkpoint_path=f"{relative_checkpoint_path}/CroCo_Encoder_512_MASt3R.pth",
172
+ patch_embed_cls="ManyAR_PatchEmbed",
173
+ img_size=(512, 512),
174
+ )
175
+ elif "dinov2" in encoder_str:
176
+ size = encoder_str.split("_")[1]
177
+ size_single_cap_letter = size[0].upper()
178
+ if "reg" in encoder_str:
179
+ with_registers = True
180
+ pretrained_checkpoint_path = None
181
+ elif "dav2" in encoder_str:
182
+ with_registers = False
183
+ pretrained_checkpoint_path = (
184
+ f"{relative_checkpoint_path}/DINOv2_ViT{size_single_cap_letter}_DepthAnythingV2.pth"
185
+ )
186
+ else:
187
+ with_registers = False
188
+ pretrained_checkpoint_path = None
189
+ return DINOv2Encoder(
190
+ name=encoder_str.replace("_reg", ""),
191
+ size=size,
192
+ with_registers=with_registers,
193
+ pretrained_checkpoint_path=pretrained_checkpoint_path,
194
+ )
195
+ elif "naradio" in encoder_str:
196
+ return NARADIOEncoder(
197
+ name=encoder_str,
198
+ model_version=encoder_str.replace("na", ""),
199
+ )
200
+ elif "radio" in encoder_str:
201
+ if "e-radio" in encoder_str:
202
+ eradio_input_shape = (224, 224)
203
+ else:
204
+ eradio_input_shape = None
205
+ return RADIOEncoder(
206
+ name=encoder_str,
207
+ model_version=encoder_str,
208
+ eradio_input_shape=eradio_input_shape,
209
+ )
210
+ elif "cosmos" in encoder_str:
211
+ patch_size = int(encoder_str.split("x")[-1])
212
+ return CosmosEncoder(
213
+ name=encoder_str,
214
+ patch_size=patch_size,
215
+ pretrained_checkpoint_path=f"{relative_checkpoint_path}/Cosmos-Tokenizer-CI{patch_size}x{patch_size}/encoder.pth",
216
+ )
217
+ elif "patch_embedder" in encoder_str:
218
+ return PatchEmbedder(
219
+ name=encoder_str,
220
+ )
221
+ else:
222
+ raise ValueError(f"Unknown encoder: {encoder_str}")
223
+
224
+
225
+ __all__ = [
226
+ "encoder_factory",
227
+ "get_available_encoders",
228
+ "print_available_encoder_models",
229
+ "_make_encoder_test",
230
+ "UniCeptionEncoderBase",
231
+ "UniCeptionViTEncoderBase",
232
+ "EncoderInput",
233
+ "ViTEncoderInput",
234
+ "ViTEncoderOutput",
235
+ ]
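A hedged usage sketch for the factory above; the keyword arguments mirror the DINOv2 construction in `_make_encoder_test` and will differ for other encoders (building the encoder may download pretrained weights):

```python
from uniception.models.encoders import encoder_factory, print_available_encoder_models

print_available_encoder_models()

# Extra kwargs are forwarded to the selected encoder's constructor (here: DINOv2Encoder).
encoder = encoder_factory("dinov2", name="dinov2_large", size="large", with_registers=False)
print(type(encoder).__name__)  # DINOv2Encoder
```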
UniCeption/uniception/models/encoders/base.py ADDED
@@ -0,0 +1,157 @@
1
+ """
2
+ Base Encoder Class for UniCeption
3
+ """
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Optional
7
+
8
+ import torch.nn as nn
9
+ from jaxtyping import Float
10
+ from torch import Tensor
11
+ from torch.utils.checkpoint import checkpoint
12
+
13
+
14
+ @dataclass
15
+ class EncoderInput:
16
+ "Data class for Encoder Input"
17
+
18
+ data_norm_type: str
19
+ # Add other fields that are required by the specific implementation of the encoder.
20
+
21
+
22
+ @dataclass
23
+ class EncoderOutput:
24
+ "Data class for Encoder Output"
25
+
26
+ pass
27
+
28
+
29
+ @dataclass
30
+ class EncoderGlobalRepInput:
31
+ "Data class for Encoder Global Representation Input"
32
+
33
+ data: Float[Tensor, "batch channel"]
34
+
35
+
36
+ @dataclass
37
+ class EncoderGlobalRepOutput:
38
+ "Data class for Encoder Global Representation Output"
39
+
40
+ features: Float[Tensor, "batch enc_embed_dim"]
41
+
42
+
43
+ class UniCeptionEncoderBase(nn.Module):
44
+ "Encoder Base Class for UniCeption"
45
+
46
+ def __init__(
47
+ self,
48
+ name: str,
49
+ data_norm_type: str,
50
+ size: Optional[str] = None,
51
+ *args,
52
+ **kwargs,
53
+ ):
54
+ """
55
+ Base class for all encoders in UniCeption.
56
+ """
57
+ super().__init__(*args, **kwargs)
58
+
59
+ self.name: str = name
60
+ self.size: Optional[str] = size
61
+
62
+ self.data_norm_type: str = data_norm_type
63
+
64
+ def forward(
65
+ self,
66
+ encoder_input: EncoderInput,
67
+ ) -> EncoderOutput:
68
+ """
69
+ Forward interface for the UniCeption encoders.
70
+
71
+ We expect the "data_norm_type" field to be present in the encoder_input to check for normalization type.
72
+
73
+ Args:
74
+ encoder_input (EncoderInput): Input to the encoder. We expect the following fields: "data_norm_type: str".
75
+ This also includes the other fields that are required by the specific implementation of the encoder.
76
+
77
+ Returns:
78
+ EncoderOutput: Output of the encoder.
79
+ """
80
+
81
+ raise NotImplementedError
82
+
83
+ def _check_data_normalization_type(self, data_norm_type: str):
84
+ """
85
+ Check if the input normalization type matches the encoder's expected input data normalization type.
86
+
87
+ Args:
88
+ data_norm_type (str): Data normalization type.
89
+
90
+ Raises:
91
+ AssertionError: If the data normalization type does not match the encoder's expected input data normalization type.
92
+ """
93
+
94
+ assert (
95
+ data_norm_type == self.data_norm_type
96
+ ), f"Input normalization type {data_norm_type} does not match the encoder's normalization type {self.data_norm_type}."
97
+
98
+
99
+ @dataclass
100
+ class ViTEncoderInput(EncoderInput):
101
+ "Data class for Vision Transformer Encoder Input"
102
+
103
+ image: Float[Tensor, "batch channel height width"]
104
+
105
+
106
+ @dataclass
107
+ class ViTEncoderNonImageInput:
108
+ "Data class for Vision (2D-Grid) Transformer Encoder Non-Image Input"
109
+
110
+ data: Float[Tensor, "batch channel height width"]
111
+
112
+
113
+ @dataclass
114
+ class ViTEncoderOutput(EncoderOutput):
115
+ "Data class for Vision Transformer Encoder Output"
116
+
117
+ features: Float[Tensor, "batch enc_embed_dim feat_height feat_width"]
118
+
119
+
120
+ class UniCeptionViTEncoderBase(UniCeptionEncoderBase):
121
+ "Vision Transformer Encoder Base Class for UniCeption"
122
+
123
+ def __init__(
124
+ self,
125
+ patch_size: int,
126
+ gradient_checkpointing: bool = False,
127
+ *args,
128
+ **kwargs,
129
+ ):
130
+ """
131
+ Base class for all Vision Transformer encoders in UniCeption.
132
+ """
133
+ super().__init__(*args, **kwargs)
134
+
135
+ self.patch_size = patch_size
136
+ self.gradient_checkpointing = gradient_checkpointing
137
+
138
+ def wrap_module_with_gradient_checkpointing(self, module: nn.Module):
139
+ """
140
+ Wrapper for Gradient Checkpointing
141
+ References: https://github.com/microsoft/MoGe
142
+ """
143
+
144
+ class _CheckpointingWrapper(module.__class__):
145
+ _restore_cls = module.__class__
146
+
147
+ def forward(self, *args, **kwargs):
148
+ return checkpoint(super().forward, *args, use_reentrant=False, **kwargs)
149
+
150
+ module.__class__ = _CheckpointingWrapper
151
+ return module
152
+
153
+
154
+ if __name__ == "__main__":
155
+ dummy_model = UniCeptionEncoderBase(name="name", data_norm_type="norm")
156
+ dummy_vit_model = UniCeptionViTEncoderBase(name="name", data_norm_type="norm", patch_size=16)
157
+ print("Dummy Base Encoders created successfully!")
UniCeption/uniception/models/encoders/cosmos.py ADDED
@@ -0,0 +1,137 @@
1
+ """
2
+ Encoder Class for Cosmos
3
+ """
4
+
5
+ import torch
6
+
7
+ from uniception.models.encoders.base import UniCeptionViTEncoderBase, ViTEncoderInput, ViTEncoderOutput
8
+ from uniception.models.libs.cosmos_tokenizer.modules import ContinuousFormulation, EncoderType
9
+ from uniception.models.libs.cosmos_tokenizer.networks import TokenizerConfigs
10
+
11
+
12
+ class CosmosEncoder(UniCeptionViTEncoderBase):
13
+ "Uniception Cosmos Encoder"
14
+
15
+ def __init__(
16
+ self,
17
+ name: str,
18
+ data_norm_type: str = "cosmos",
19
+ patch_size: int = 8,
20
+ pretrained_checkpoint_path: str = None,
21
+ *args,
22
+ **kwargs,
23
+ ):
24
+ """
25
+ Cosmos Encoder for extracting spatial features from images.
26
+
27
+ Args:
28
+ name (str): Name of the encoder.
29
+ data_norm_type (str): Image normalization type. Default: "cosmos"
30
+ patch_size (int): Patch size for the encoder. Default: 8
31
+ pretrained_checkpoint_path (str): Path to the pretrained checkpoint. Default: None
32
+ """
33
+ # Init the base class
34
+ super().__init__(name=name, data_norm_type=data_norm_type, patch_size=patch_size, *args, **kwargs)
35
+
36
+ # Init Cosmos Encoder specific attributes
37
+ tokenizer_config = TokenizerConfigs["CI"].value.copy()
38
+ tokenizer_config.update(dict(spatial_compression=self.patch_size))
39
+
40
+ z_factor = tokenizer_config["z_factor"]
41
+ z_channels = tokenizer_config["z_channels"]
42
+ latent_channels = tokenizer_config["latent_channels"]
43
+ encoder_name = kwargs.get("encoder", EncoderType.Default.name)
44
+ print(tokenizer_config)
45
+ del tokenizer_config["z_factor"]
46
+ del tokenizer_config["z_channels"]
47
+ del tokenizer_config["latent_channels"]
48
+ self.encoder = EncoderType[encoder_name].value(z_channels=z_factor * z_channels, **tokenizer_config)
49
+ self.quant_conv = torch.nn.Conv2d(z_factor * z_channels, z_factor * latent_channels, 1)
50
+ formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name)
51
+ self.distribution = ContinuousFormulation[formulation_name].value()
52
+
53
+ # Load the pretrained checkpoint
54
+ if pretrained_checkpoint_path is not None:
55
+ print(f"Loading custom pretrained Cosmos checkpoint from {pretrained_checkpoint_path}")
56
+ ckpt = torch.load(pretrained_checkpoint_path, weights_only=False)
57
+ print(self.load_state_dict(ckpt["model"]))
58
+
59
+ def encode(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor]:
60
+ """Encodes an image into a latent embedding or code.
61
+
62
+ Args:
63
+ input_tensor: The input tensor Bx3xHxW layout, range [-1..1].
64
+ Returns:
65
+ For continuous image (CI) tokenizer, the tuple contains:
66
+ - The latent embedding, Bx16x(h)x(w), where the compression
67
+ rate is (H/h x W/w), and channel dimension of 16.
68
+ For discrete image (DI) tokenizer, the tuple contains:
69
+ - The indices, Bx(h)x(w), from a codebook of size 64K, which
70
+ corresponds to FSQ levels of (8,8,8,5,5,5).
71
+ - The discrete code, Bx6x(h)x(w), where the compression rate is
72
+ again (H/h x W/w), and channel dimension of 6.
73
+ """
74
+ x = self.encoder(input_tensor)
75
+ x = self.quant_conv(x)
76
+ output_latent = self.distribution(x)
77
+
78
+ if isinstance(output_latent, torch.Tensor):
79
+ return output_latent
80
+ return output_latent[:-1]
81
+
82
+ def forward(self, encoder_input: ViTEncoderInput) -> ViTEncoderOutput:
83
+ """
84
+ Cosmos Encoder Forward Pass
85
+
86
+ Args:
87
+ encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor.
88
+
89
+ Returns:
90
+ ViTEncoderOutput: Output data from the encoder.
91
+ """
92
+ # Check image normalization type
93
+ self._check_data_normalization_type(encoder_input.data_norm_type)
94
+
95
+ # Check the dtype and shape of the input image
96
+ assert isinstance(encoder_input.image, torch.Tensor), "Input must be a torch.Tensor"
97
+ assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)"
98
+ batch_size, channels, height, width = encoder_input.image.shape
99
+ assert channels == 3, "Input must have 3 channels"
100
+ assert (
101
+ height % self.patch_size == 0 and width % self.patch_size == 0
102
+ ), f"Input shape must be divisible by patch size: {self.patch_size}"
103
+
104
+ # Extract the features using the Cosmos encoder
105
+ features = self.encode(encoder_input.image)[0].contiguous()
106
+
107
+ return ViTEncoderOutput(features=features)
108
+
109
+
110
+ if __name__ == "__main__":
111
+
112
+ # initialize different variants of the Cosmos Encoder, untrained
113
+ for is_continuous in [True]:
114
+ for patch_size in [8, 16]:
115
+ encoder = CosmosEncoder(name="cosmos", patch_size=patch_size)
116
+
117
+ # # initialize from trained checkpoint, with/without jit inference capability
118
+ PRETRAINED_JIT_CHECKPOINTS = {
119
+ ("CI", 8): "../../../checkpoints/encoders/cosmos/Cosmos-Tokenizer-CI8x8/encoder.pth",
120
+ ("CI", 16): "../../../checkpoints/encoders/cosmos/Cosmos-Tokenizer-CI16x16/encoder.pth",
121
+ }
122
+
123
+ for patch_size in [8, 16]:
124
+
125
+ encoder = CosmosEncoder(
126
+ name="cosmos",
127
+ patch_size=patch_size,
128
+ pretrained_checkpoint_path=PRETRAINED_JIT_CHECKPOINTS[("CI", patch_size)],
129
+ )
130
+
131
+ # example inference
132
+ dummy_image = torch.randn(1, 3, 256, 256).cuda()
133
+
134
+ encoder_input = ViTEncoderInput(data_norm_type="cosmos", image=dummy_image)
135
+
136
+ encoder = encoder.cuda()
137
+ encoder_output = encoder(encoder_input)
UniCeption/uniception/models/encoders/croco.py ADDED
@@ -0,0 +1,457 @@
1
+ """
2
+ Encoder Class for CroCo & DUSt3R
3
+ """
4
+
5
+ from functools import partial
6
+ from typing import Callable, List, Optional, Tuple, Union
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from uniception.models.encoders.base import UniCeptionViTEncoderBase, ViTEncoderInput, ViTEncoderOutput
12
+ from uniception.models.libs.croco.blocks import Block
13
+ from uniception.models.libs.croco.patch_embed import get_patch_embed
14
+ from uniception.models.libs.croco.pos_embed import RoPE2D
15
+ from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner, feature_take_indices
16
+
17
+
18
+ class CroCoEncoder(UniCeptionViTEncoderBase):
19
+ "UniCeption CroCov2 Encoder"
20
+
21
+ def __init__(
22
+ self,
23
+ name: str,
24
+ data_norm_type: str,
25
+ patch_embed_cls: str = "PatchEmbedDust3R",
26
+ img_size: Union[int, Tuple[int, int]] = (224, 224),
27
+ patch_size: int = 16,
28
+ enc_embed_dim: int = 1024,
29
+ enc_depth: int = 24,
30
+ enc_num_heads: int = 16,
31
+ mlp_ratio: int = 4,
32
+ norm_layer: Callable = partial(nn.LayerNorm, eps=1e-6),
33
+ pos_embed: str = "RoPE100",
34
+ pretrained_checkpoint_path: str = None,
35
+ override_checkpoint_attributes: bool = False,
36
+ *args,
37
+ **kwargs,
38
+ ):
39
+ """
40
+ References: https://github.com/naver/dust3r, https://github.com/naver/croco
41
+
42
+ Args:
43
+ name (str): Name of the encoder.
44
+ data_norm_type (str): Input data normalization type.
45
+ patch_embed_cls (str, optional): The class to use for patch embedding.
46
+ Defaults to 'PatchEmbedDust3R'. Options: ['PatchEmbedCroCo', 'PatchEmbedDust3R', 'ManyAR_PatchEmbed'].
47
+ img_size (Union[int, Tuple[int, int]], optional): The size of the input image. Defaults to (224, 224).
48
+ patch_size (int, optional): The size of the patches to divide the image into. Defaults to 16.
49
+ enc_embed_dim (int, optional): The dimension of the encoder's embedding. Defaults to 1024.
50
+ enc_depth (int, optional): The number of encoder layers/transformer blocks. Defaults to 24.
51
+ enc_num_heads (int, optional): The number of encoder heads. Defaults to 16.
52
+ mlp_ratio (int, optional): The MLP ratio used for the CroCo encoder transformer. Defaults to 4.
53
+ norm_layer (nn.Module, optional): The normalization layer to use in the transformer. Defaults to nn.LayerNorm with eps=1e-6.
54
+ pos_embed (str, optional): Positional embedding. Defaults to 'RoPE100'. Format: 'RoPE<freq>', e.g., 'RoPE100'.
55
+ pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. Defaults to None.
56
+ """
57
+ # Init the base class
58
+ super().__init__(
59
+ name=name,
60
+ data_norm_type=data_norm_type,
61
+ patch_size=patch_size,
62
+ *args,
63
+ **kwargs,
64
+ )
65
+
66
+ # Init the CroCo Encoder specific attributes
67
+ self.patch_embed_cls = patch_embed_cls
68
+ self.img_size = img_size
69
+ self.enc_embed_dim = enc_embed_dim
70
+ self.enc_depth = enc_depth
71
+ self.enc_num_heads = enc_num_heads
72
+ self.mlp_ratio = mlp_ratio
73
+ self.norm_layer = norm_layer
74
+ self.pretrained_checkpoint_path = pretrained_checkpoint_path
75
+ self.override_checkpoint_attributes = override_checkpoint_attributes
76
+
77
+ # Init the positional embedding
78
+ self.pos_embed = pos_embed
79
+ if pos_embed.startswith("RoPE"): # eg RoPE100
80
+ self.enc_pos_embed = None # nothing to add in the encoder with RoPE
81
+ self.dec_pos_embed = None # nothing to add in the decoder with RoPE
82
+ if RoPE2D is None:
83
+ raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions")
84
+ freq = float(pos_embed[len("RoPE") :])
85
+ self.rope = RoPE2D(freq=freq)
86
+ else:
87
+ raise NotImplementedError("Unknown pos_embed " + pos_embed)
88
+
89
+ # Init the patch embedding
90
+ self._set_patch_embed(img_size, patch_size, enc_embed_dim)
91
+
92
+ # Init the encoder
93
+ self._set_encoder(enc_depth, enc_embed_dim, enc_num_heads, mlp_ratio, norm_layer, self.rope)
94
+
95
+ # Initialize random weights
96
+ self.initialize_weights()
97
+
98
+ # Load the pretrained CroCo checkpoint if provided
99
+ if pretrained_checkpoint_path:
100
+ print(f"Loading pretrained CroCo checkpoint from {pretrained_checkpoint_path}")
101
+ ckpt = torch.load(pretrained_checkpoint_path, weights_only=False)
102
+ print(self.load_state_dict(ckpt["model"]))
103
+ if not override_checkpoint_attributes:
104
+ ckpt_data_norm_type = ckpt["data_norm_type"]
105
+ ckpt_patch_embed_cls = ckpt["patch_embed_cls"]
106
+ assert (
107
+ data_norm_type == ckpt_data_norm_type
108
+ ), f"Data normalization type {data_norm_type} does not match the checkpoint {ckpt_data_norm_type}."
109
+ assert (
110
+ patch_embed_cls == ckpt_patch_embed_cls
111
+ ), f"Patch embedding class {patch_embed_cls} does not match the checkpoint {ckpt_patch_embed_cls}."
112
+ else:
113
+ print("No pretrained checkpoint provided. Randomly initializing the CroCo encoder.")
114
+
115
+ def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
116
+ "Set the patch embedding scheme"
117
+ self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim)
118
+
119
+ def _set_encoder(self, enc_depth, enc_embed_dim, enc_num_heads, mlp_ratio, norm_layer, rope):
120
+ "Set the encoder"
121
+ self.enc_blocks = nn.ModuleList(
122
+ [
123
+ Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, rope=rope)
124
+ for _ in range(enc_depth)
125
+ ]
126
+ )
127
+ self.enc_norm = norm_layer(enc_embed_dim)
128
+
129
+ def initialize_weights(self):
130
+ "Initialize the weights of the patch embedding and the transformer encoder"
131
+ # Patch embedding
132
+ self.patch_embed._init_weights()
133
+ # Linears and layer norms
134
+ self.apply(self._init_weights)
135
+
136
+ def _init_weights(self, m):
137
+ "Initialize the transformer encoder weights"
138
+ if isinstance(m, nn.Linear):
139
+ # We use xavier_uniform following official JAX ViT:
140
+ torch.nn.init.xavier_uniform_(m.weight)
141
+ if isinstance(m, nn.Linear) and m.bias is not None:
142
+ nn.init.constant_(m.bias, 0)
143
+ elif isinstance(m, nn.LayerNorm):
144
+ nn.init.constant_(m.bias, 0)
145
+ nn.init.constant_(m.weight, 1.0)
146
+
147
+ def forward(self, encoder_input: ViTEncoderInput) -> ViTEncoderOutput:
148
+ """
149
+ CroCov2 Encoder Forward Pass
150
+
151
+ Args:
152
+ encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor.
153
+
154
+ Returns:
155
+ ViTEncoderOutput: Output data from the encoder.
156
+ """
157
+ # Check image normalization type
158
+ self._check_data_normalization_type(encoder_input.data_norm_type)
159
+
160
+ # Get the true shape of the image for landscape/portrait mode check in patch embedding
161
+ batch_size, _, height, width = encoder_input.image.shape
162
+ if hasattr(encoder_input, "true_shape"):
163
+ true_shape = encoder_input.true_shape
164
+ else:
165
+ true_shape = torch.tensor([height, width])[None].repeat(batch_size, 1)
166
+
167
+ # Embed the image into patches
168
+ features, pos = self.patch_embed(encoder_input.image, true_shape=true_shape)
169
+
170
+ # Now apply the transformer encoder and normalization
171
+ for blk in self.enc_blocks:
172
+ features = blk(features, pos)
173
+ features = self.enc_norm(features)
174
+
175
+ # Resize the features to the expected shape
176
+ # (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size)
177
+ features = features.permute(0, 2, 1)
178
+ features = features.reshape(
179
+ -1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size
180
+ ).contiguous()
181
+
182
+ return ViTEncoderOutput(features=features)
183
+
184
+
185
+ class CroCoIntermediateFeatureReturner(CroCoEncoder, IntermediateFeatureReturner):
186
+ "Intermediate Feature Returner for UniCeption CroCo Encoder"
187
+
188
+ def __init__(
189
+ self,
190
+ name: str,
191
+ data_norm_type: str,
192
+ patch_embed_cls: str = "PatchEmbedDust3R",
193
+ img_size: Union[int, Tuple[int, int]] = (224, 224),
194
+ patch_size: int = 16,
195
+ enc_embed_dim: int = 1024,
196
+ enc_depth: int = 24,
197
+ enc_num_heads: int = 16,
198
+ mlp_ratio: int = 4,
199
+ norm_layer: Callable = partial(nn.LayerNorm, eps=1e-6),
200
+ pos_embed: str = "RoPE100",
201
+ pretrained_checkpoint_path: str = None,
202
+ indices: Optional[Union[int, List[int]]] = None,
203
+ norm_intermediate: bool = True,
204
+ stop_early: bool = False,
205
+ intermediates_only: bool = True,
206
+ *args,
207
+ **kwargs,
208
+ ):
209
+ """
210
+ Intermediate Feature Returner for the CroCo Encoder.
211
+
212
+ Args:
213
+ name (str): Name of the encoder.
214
+ data_norm_type (str): Input data normalization type.
215
+ patch_embed_cls (str, optional): The class to use for patch embedding.
216
+ Defaults to 'PatchEmbedDust3R'. Options: ['PatchEmbedCroCo', 'PatchEmbedDust3R', 'ManyAR_PatchEmbed'].
217
+ img_size (Union[int, Tuple[int, int]], optional): The size of the input image. Defaults to (224, 224).
218
+ patch_size (int, optional): The size of the patches to divide the image into. Defaults to 16.
219
+ enc_embed_dim (int, optional): The dimension of the encoder's embedding. Defaults to 1024.
220
+ enc_depth (int, optional): The number of encoder layers/transformer blocks. Defaults to 24.
221
+ enc_num_heads (int, optional): The number of encoder heads. Defaults to 16.
222
+ mlp_ratio (int, optional): The MLP ratio used for the CroCo encoder transformer. Defaults to 4.
223
+ norm_layer (nn.Module, optional): The normalization layer to use in the transformer. Defaults to nn.LayerNorm with eps=1e-6.
224
+ pos_embed (str, optional): Positional embedding. Defaults to 'RoPE100'. Format: 'RoPE<freq>', e.g., 'RoPE100'.
225
+ pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. Defaults to None.
226
+ indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. Defaults to None. Options:
227
+ - None: Return all intermediate layers.
228
+ - int: Return the last n layers.
229
+ - List[int]: Return the intermediate layers at the specified indices.
230
+ norm_intermediate (bool, optional): Whether to normalize the intermediate features. Defaults to True.
231
+ stop_early (bool, optional): Whether to stop early. Defaults to False.
232
+ intermediates_only (bool, optional): Whether to return only the intermediate features. Defaults to True.
233
+ """
234
+ # Init the base classes
235
+ CroCoEncoder.__init__(
236
+ self,
237
+ name=name,
238
+ data_norm_type=data_norm_type,
239
+ patch_embed_cls=patch_embed_cls,
240
+ img_size=img_size,
241
+ patch_size=patch_size,
242
+ enc_embed_dim=enc_embed_dim,
243
+ enc_depth=enc_depth,
244
+ enc_num_heads=enc_num_heads,
245
+ mlp_ratio=mlp_ratio,
246
+ norm_layer=norm_layer,
247
+ pos_embed=pos_embed,
248
+ pretrained_checkpoint_path=pretrained_checkpoint_path,
249
+ *args,
250
+ **kwargs,
251
+ )
252
+ IntermediateFeatureReturner.__init__(
253
+ self,
254
+ indices=indices,
255
+ norm_intermediate=norm_intermediate,
256
+ stop_early=stop_early,
257
+ intermediates_only=intermediates_only,
258
+ )
259
+
260
+ def forward(
261
+ self, encoder_input: ViTEncoderInput
262
+ ) -> Union[List[ViTEncoderOutput], Tuple[ViTEncoderOutput, List[ViTEncoderOutput]]]:
263
+ """
264
+ CroCov2 Encoder Forward Pass with Intermediate Feature Return
265
+
266
+ Args:
267
+ encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor.
268
+
269
+ Returns:
270
+ Union[List[ViTEncoderOutput], Tuple[ViTEncoderOutput, List[ViTEncoderOutput]]]: Output data from the encoder.
271
+ If `intermediates_only` is True, returns a list of intermediate features.
272
+ Otherwise, returns a tuple with the final features and a list of intermediate features.
273
+ """
274
+ # Check image normalization type
275
+ self._check_data_normalization_type(encoder_input.data_norm_type)
276
+
277
+ # Get the true shape of the image for landscape/portrait mode check in patch embedding
278
+ batch_size, _, height, width = encoder_input.image.shape
279
+ if hasattr(encoder_input, "true_shape"):
280
+ true_shape = encoder_input.true_shape
281
+ else:
282
+ true_shape = torch.tensor([height, width])[None].repeat(batch_size, 1)
283
+
284
+ # Embed the image into patches
285
+ features, pos = self.patch_embed(encoder_input.image, true_shape=true_shape)
286
+
287
+ # Get indices of the intermediate features to return
288
+ intermediate_features = []
289
+ take_indices, max_index = feature_take_indices(len(self.enc_blocks), self.indices)
290
+
291
+ # Get the blocks based on early stopping
292
+ if torch.jit.is_scripting() or not self.stop_early: # can't slice blocks in torchscript
293
+ blocks = self.enc_blocks
294
+ else:
295
+ blocks = self.enc_blocks[: max_index + 1]
296
+
297
+ # Now apply the transformer encoder and normalization
298
+ for blk_idx, blk in enumerate(blocks):
299
+ features = blk(features, pos)
300
+ if blk_idx in take_indices:
301
+ # Normalize intermediates with final norm layer if enabled
302
+ intermediate_features.append(self.enc_norm(features) if self.norm_intermediate else features)
303
+
304
+ # Reshape the intermediate features and convert to ViTEncoderOutput class
305
+ intermediate_features = [
306
+ intermediate.permute(0, 2, 1)
307
+ .reshape(-1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size)
308
+ .contiguous()
309
+ for intermediate in intermediate_features
310
+ ]
311
+ intermediate_features = [ViTEncoderOutput(features=intermediate) for intermediate in intermediate_features]
312
+
313
+ # Return only the intermediate features if enabled
314
+ if self.intermediates_only:
315
+ return intermediate_features
316
+
317
+ # Normalize and reshape the final features
318
+ features = self.enc_norm(features)
319
+ # Resize the features to the expected shape
320
+ # (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size)
321
+ features = features.permute(0, 2, 1)
322
+ features = features.reshape(
323
+ -1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size
324
+ ).contiguous()
325
+ final_features = ViTEncoderOutput(features=features)
326
+
327
+ return final_features, intermediate_features
328
+
329
+
330
+ if __name__ == "__main__":
331
+ # Init the pre-trained CroCo Encoder
332
+ pretrained_checkpoint_path = "../../../checkpoints/encoders/CroCo_Encoder_224.pth"
333
+ croco_encoder = CroCoEncoder(
334
+ name="croco",
335
+ data_norm_type="croco",
336
+ pretrained_checkpoint_path=pretrained_checkpoint_path,
337
+ patch_embed_cls="PatchEmbedCroCo",
338
+ )
339
+
340
+ # Init the pre-trained DUSt3R CroCo Encoder
341
+ pretrained_checkpoint_path = "../../../checkpoints/encoders/CroCo_Encoder_224_DUSt3R_linear.pth"
342
+ dust3r_encoder = CroCoEncoder(
343
+ name="dust3r_224",
344
+ data_norm_type="dust3r",
345
+ pretrained_checkpoint_path=pretrained_checkpoint_path,
346
+ patch_embed_cls="PatchEmbedDust3R",
347
+ )
348
+
349
+ # Init the pre-trained DUSt3R 512 linear CroCo Encoder
350
+ pretrained_checkpoint_path = "../../../checkpoints/encoders/CroCo_Encoder_512_DUSt3R_linear.pth"
351
+ dust3r_encoder_512 = CroCoEncoder(
352
+ name="dust3r_512",
353
+ data_norm_type="dust3r",
354
+ pretrained_checkpoint_path=pretrained_checkpoint_path,
355
+ patch_embed_cls="ManyAR_PatchEmbed",
356
+ img_size=(512, 512),
357
+ )
358
+
359
+ # Init the pre-trained DUSt3R 512 DPT CroCo Encoder
360
+ pretrained_checkpoint_path = "../../../checkpoints/encoders/CroCo_Encoder_512_DUSt3R_dpt.pth"
361
+ dust3r_encoder_512_dpt = CroCoEncoder(
362
+ name="dust3r_512_dpt",
363
+ data_norm_type="dust3r",
364
+ pretrained_checkpoint_path=pretrained_checkpoint_path,
365
+ patch_embed_cls="ManyAR_PatchEmbed",
366
+ img_size=(512, 512),
367
+ )
368
+
369
+ # Init the MASt3R 512 CroCo Encoder
370
+ pretrained_checkpoint_path = "../../../checkpoints/encoders/CroCo_Encoder_512_MASt3R.pth"
371
+ mast3r_encoder_512 = CroCoEncoder(
372
+ name="mast3r_512",
373
+ data_norm_type="dust3r",
374
+ pretrained_checkpoint_path=pretrained_checkpoint_path,
375
+ patch_embed_cls="ManyAR_PatchEmbed",
376
+ img_size=(512, 512),
377
+ )
378
+
379
+ print("All CroCo & DUSt3R Encoders have been initialized successfully!")
380
+
381
+ # Intermediate Feature Returner Tests
382
+ print("Running Intermediate Feature Returner Tests...")
383
+ pretrained_checkpoint_path = "../../../checkpoints/encoders/CroCo_Encoder_512_DUSt3R_dpt.pth"
384
+
385
+ # Run the intermediate feature returner with last-n index
386
+ dust3r_intermediate_feature_returner = CroCoIntermediateFeatureReturner(
387
+ name="dust3r_512_dpt",
388
+ data_norm_type="dust3r",
389
+ pretrained_checkpoint_path=pretrained_checkpoint_path,
390
+ patch_embed_cls="ManyAR_PatchEmbed",
391
+ img_size=(512, 512),
392
+ indices=6, # Last 6 layers
393
+ )
394
+ dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="dust3r")
395
+ output = dust3r_intermediate_feature_returner(dummy_input)
396
+ assert isinstance(output, list), "Output must be a list of intermediate features"
397
+ assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
398
+ assert len(output) == 6, "Output must have length of intermediate features equal to the number of indices"
399
+
400
+ # Run the intermediate feature returner with specific indices
401
+ dust3r_intermediate_feature_returner = CroCoIntermediateFeatureReturner(
402
+ name="dust3r_512_dpt",
403
+ data_norm_type="dust3r",
404
+ pretrained_checkpoint_path=pretrained_checkpoint_path,
405
+ patch_embed_cls="ManyAR_PatchEmbed",
406
+ img_size=(512, 512),
407
+ indices=[0, 2, 4, 6], # Specific layers
408
+ )
409
+ dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="dust3r")
410
+ output = dust3r_intermediate_feature_returner(dummy_input)
411
+ assert isinstance(output, list), "Output must be a list of intermediate features"
412
+ assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
413
+ assert len(output) == 4, "Output must have length of intermediate features equal to the number of indices"
414
+
415
+ # Test the normalizing of intermediate features
416
+ dust3r_intermediate_feature_returner = CroCoIntermediateFeatureReturner(
417
+ name="dust3r_512_dpt",
418
+ data_norm_type="dust3r",
419
+ pretrained_checkpoint_path=pretrained_checkpoint_path,
420
+ patch_embed_cls="ManyAR_PatchEmbed",
421
+ img_size=(512, 512),
422
+ indices=[-1],
423
+ norm_intermediate=False,
424
+ intermediates_only=False,
425
+ )
426
+ dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="dust3r")
427
+ output = dust3r_intermediate_feature_returner(dummy_input)
428
+ assert isinstance(output, tuple), "Output must be a tuple with final features and intermediate features"
429
+ assert isinstance(output[0], ViTEncoderOutput), "First element of output must be the final features"
430
+ assert isinstance(output[1], list), "Second element of output must be a list of intermediate features"
431
+ assert isinstance(output[1][0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
432
+ if not isinstance(dust3r_intermediate_feature_returner.enc_norm, torch.nn.Identity):
433
+ assert not torch.equal(
434
+ output[0].features, output[1][0].features
435
+ ), "Final features and intermediate features must be different"
436
+
437
+ dust3r_intermediate_feature_returner = CroCoIntermediateFeatureReturner(
438
+ name="dust3r_512_dpt",
439
+ data_norm_type="dust3r",
440
+ pretrained_checkpoint_path=pretrained_checkpoint_path,
441
+ patch_embed_cls="ManyAR_PatchEmbed",
442
+ img_size=(512, 512),
443
+ indices=[-1],
444
+ norm_intermediate=True,
445
+ intermediates_only=False,
446
+ )
447
+ dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="dust3r")
448
+ output = dust3r_intermediate_feature_returner(dummy_input)
449
+ assert isinstance(output, tuple), "Output must be a tuple with final features and intermediate features"
450
+ assert isinstance(output[0], ViTEncoderOutput), "First element of output must be the final features"
451
+ assert isinstance(output[1], list), "Second element of output must be a list of intermediate features"
452
+ assert isinstance(output[1][0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
453
+ assert torch.equal(
454
+ output[0].features, output[1][0].features
455
+ ), "Final features and intermediate features must be same"
456
+
457
+ print("All Intermediate Feature Returner Tests have passed successfully!")
UniCeption/uniception/models/encoders/dense_rep_encoder.py ADDED
@@ -0,0 +1,344 @@
1
+ """
2
+ Encoder class for Dense Representation Encoder
3
+ """
4
+
5
+ import math
6
+ from functools import partial
7
+ from typing import Callable, List, Optional, Tuple, Type, Union
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ from torch.nn.init import trunc_normal_
13
+
14
+ from uniception.models.encoders.base import (
15
+ UniCeptionViTEncoderBase,
16
+ ViTEncoderInput,
17
+ ViTEncoderNonImageInput,
18
+ ViTEncoderOutput,
19
+ )
20
+
21
+
22
+ def make_2tuple(x):
23
+ if isinstance(x, tuple):
24
+ assert len(x) == 2
25
+ return x
26
+
27
+ assert isinstance(x, int)
28
+ return (x, x)
29
+
30
+
31
+ class ResidualBlock(nn.Module):
32
+ "Redidual block for Dense Representation Encoder"
33
+
34
+ def __init__(self, in_channels: int, out_channels: int, act_layer: Type[nn.Module] = nn.GELU):
35
+ super(ResidualBlock, self).__init__()
36
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
37
+ self.act = act_layer()
38
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
39
+ self.shortcut = (
40
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
41
+ if in_channels != out_channels
42
+ else nn.Identity()
43
+ )
44
+
45
+ def forward(self, x):
46
+ identity = self.shortcut(x)
47
+ out = self.conv1(x)
48
+ out = self.act(out)
49
+ out = self.conv2(out)
50
+ out += identity
51
+
52
+ return self.act(out)
53
+
54
+
55
+ class DenseRepresentationEncoder(UniCeptionViTEncoderBase):
56
+ "UniCeption Dense Representation Encoder"
57
+
58
+ def __init__(
59
+ self,
60
+ name: str,
61
+ in_chans: int = 3,
62
+ enc_embed_dim: int = 1024,
63
+ apply_pe: bool = True,
64
+ input_size_for_pe: Union[int, Tuple[int, int]] = 518,
65
+ patch_size: int = 14,
66
+ intermediate_dims: List[int] = [588, 768, 1024],
67
+ data_norm_type: str = "dense_rep_encoder",
68
+ act_layer: Type[nn.Module] = nn.GELU,
69
+ norm_layer: Optional[Callable] = partial(nn.LayerNorm, eps=1e-6),
70
+ post_pe_norm_layer: Optional[Callable] = partial(nn.LayerNorm, eps=1e-6),
71
+ interpolate_antialias: bool = False,
72
+ interpolate_offset: float = 0.1,
73
+ pretrained_checkpoint_path: str = None,
74
+ *args,
75
+ **kwargs,
76
+ ):
77
+ """
78
+ Dense Representation Encoder for extracting patch-wise features from a spatial input of size (B, C, H, W).
79
+ Uses a convolution based patchify followed by some residual blocks.
80
+ Also applies positional encoding with interpolation to the patch-wise features if required.
81
+
82
+ Args:
83
+ in_chans (int): Number of input channels.
84
+ enc_embed_dim (int): Embedding dimension of the encoder.
85
+ apply_pe (bool): Whether to apply positional encoding.
86
+ input_size_for_pe (Union[int, Tuple[int, int]]): Input size for positional encoding.
87
+ patch_size (int): Patch size of the encoder.
88
+ intermediate_dims (List[int]): Intermediate dimensions of the encoder.
89
+ data_norm_type (str): Data normalization type. (Used for checking if the input images are normalized correctly.)
90
+ act_layer (Type[nn.Module]): Activation layer.
91
+ norm_layer (Optional[Callable]): Normalization layer.
92
+ post_pe_norm_layer (Optional[Callable]): Normalization layer after positional encoding.
93
+ interpolate_antialias (bool): Whether to apply antialiasing in interpolation.
94
+ interpolate_offset (float): Offset for interpolation.
95
+ pretrained_checkpoint_path (str): Path to pretrained checkpoint.
96
+ """
97
+ # Init the base class
98
+ super().__init__(
99
+ name=name,
100
+ data_norm_type=data_norm_type,
101
+ patch_size=patch_size,
102
+ *args,
103
+ **kwargs,
104
+ )
105
+
106
+ # Init the specific attributes
107
+ self.in_chans = in_chans
108
+ self.enc_embed_dim = enc_embed_dim
109
+ self.intermediate_dims = intermediate_dims
110
+ self.apply_pe = apply_pe
111
+
112
+ # Initialize the encoder with a pixel unshuffle and conv projection to patchify the input
113
+ self.unshuffle = nn.PixelUnshuffle(self.patch_size)
114
+ self.conv_in = nn.Conv2d(self.in_chans * (self.patch_size**2), self.intermediate_dims[0], 3, 1, 1)
115
+
116
+ # Add residual blocks
117
+ layers = []
118
+ for intermediate_idx in range(len(self.intermediate_dims) - 1):
119
+ layers.append(
120
+ ResidualBlock(
121
+ in_channels=self.intermediate_dims[intermediate_idx],
122
+ out_channels=self.intermediate_dims[intermediate_idx + 1],
123
+ act_layer=act_layer,
124
+ )
125
+ )
126
+
127
+ # Final projection to match encoder embeddings dim
128
+ layers.append(
129
+ nn.Conv2d(
130
+ in_channels=self.intermediate_dims[-1],
131
+ out_channels=self.enc_embed_dim,
132
+ kernel_size=1,
133
+ stride=1,
134
+ padding=0,
135
+ )
136
+ )
137
+ self.encoder = nn.Sequential(*layers)
138
+
139
+ # Init norm layer after encoder if required
140
+ self.norm_layer = norm_layer(enc_embed_dim) if norm_layer else nn.Identity()
141
+ if isinstance(self.norm_layer, nn.LayerNorm):
142
+ nn.init.constant_(self.norm_layer.bias, 0)
143
+ nn.init.constant_(self.norm_layer.weight, 1.0)
144
+
145
+ if self.apply_pe:
146
+ # Init the patch resolution details required for positional encoding
147
+ patch_HW = make_2tuple(patch_size)
148
+ self.input_size_for_pe = make_2tuple(input_size_for_pe)
149
+ self.patches_resolution = (
150
+ self.input_size_for_pe[0] // patch_HW[0],
151
+ self.input_size_for_pe[1] // patch_HW[1],
152
+ )
153
+ self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
154
+
155
+ # Init the sinusodial positional encodings
156
+ self.register_buffer(
157
+ "pos_embed",
158
+ self._get_sinusoid_encoding_table(self.num_patches, self.enc_embed_dim, 70007),
159
+ )
160
+ self.interpolate_antialias = interpolate_antialias
161
+ self.interpolate_offset = interpolate_offset
162
+
163
+ # Init the norm layer after positional encoding if required
164
+ self.post_pe_norm = post_pe_norm_layer(enc_embed_dim) if post_pe_norm_layer else nn.Identity()
165
+ if isinstance(self.post_pe_norm, nn.LayerNorm):
166
+ nn.init.constant_(self.post_pe_norm.bias, 0)
167
+ nn.init.constant_(self.post_pe_norm.weight, 1.0)
168
+
169
+ # Load the pretrained checkpoint if provided
170
+ self.pretrained_checkpoint_path = pretrained_checkpoint_path
171
+ if self.pretrained_checkpoint_path:
172
+ print(
173
+ f"Loading custom pretrained Dense Representation Encoder checkpoint from {self.pretrained_checkpoint_path} ..."
174
+ )
175
+ ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
176
+ print(self.load_state_dict(ckpt["model"]))
177
+
178
+ def _get_sinusoid_encoding_table(self, n_position, d_hid, base):
179
+ "Sinusoid position encoding table"
180
+
181
+ def get_position_angle_vec(position):
182
+ return [position / np.power(base, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
183
+
184
+ sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
185
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])
186
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])
187
+
188
+ return torch.FloatTensor(sinusoid_table)
189
+
190
+ def interpolate_pos_encoding(self, features, height, width):
191
+ """
192
+ Interpolate the positional encoding to the expected size.
193
+
194
+ Args:
195
+ features (torch.Tensor): Input tensor of shape (B, N, C).
196
+ height (int, float): Height of the input tensor.
197
+ width (int, float): Width of the input tensor.
198
+
199
+ Returns:
200
+ torch.Tensor: Interpolated positional encoding tensor of shape (1, N, C).
201
+ """
202
+ previous_dtype = features.dtype
203
+ npatch = features.shape[1]
204
+ N = self.pos_embed.unsqueeze(0).shape[1]
205
+ if npatch == N and height == width:
206
+ return self.pos_embed.unsqueeze(0)
207
+ patch_pos_embed = self.pos_embed.unsqueeze(0).float()
208
+ dim = features.shape[-1]
209
+ height0 = height // self.patch_size
210
+ width0 = width // self.patch_size
211
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
212
+ assert N == M * M
213
+ kwargs = {}
214
+ if self.interpolate_offset:
215
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
216
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
217
+ sh = float(height0 + self.interpolate_offset) / M
218
+ sw = float(width0 + self.interpolate_offset) / M
219
+ kwargs["scale_factor"] = (sh, sw)
220
+ else:
221
+ # Simply specify an output size instead of a scale factor
222
+ kwargs["size"] = (height0, width0)
223
+ patch_pos_embed = nn.functional.interpolate(
224
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
225
+ mode="bicubic",
226
+ antialias=self.interpolate_antialias,
227
+ **kwargs,
228
+ )
229
+ assert (height0, width0) == patch_pos_embed.shape[-2:]
230
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
231
+
232
+ return patch_pos_embed.to(previous_dtype)
233
+
234
+ def forward(self, encoder_input: Union[ViTEncoderInput, ViTEncoderNonImageInput]) -> ViTEncoderOutput:
235
+ """
236
+ Dense Representation Encoder Forward Pass
237
+
238
+ Args:
239
+ encoder_input (Union[ViTEncoderInput, ViTEncoderNonImageInput]): Input data for the encoder.
240
+ If input type is ViTEncoderInput, input data must contain image normalization type and normalized image tensor.
241
+ If input type is ViTEncoderNonImageInput, input data must contain a tensor of size (B, C, H, W).
242
+
243
+ Returns:
244
+ ViTEncoderOutput: Output data from the encoder.
245
+ """
246
+ # Get the input data and verify normalization if the input type is ViTEncoderInput
247
+ if isinstance(encoder_input, ViTEncoderInput):
248
+ self._check_data_normalization_type(encoder_input.data_norm_type)
249
+ input_data = encoder_input.image
250
+ elif isinstance(encoder_input, ViTEncoderNonImageInput):
251
+ input_data = encoder_input.data
252
+ else:
253
+ raise ValueError("Unsupported input type for Dense Representation Encoder.")
254
+
255
+ # Check the dtype and shape of the input
256
+ assert isinstance(input_data, torch.Tensor), "Input must be a torch.Tensor"
257
+ assert input_data.ndim == 4, "Input must be of shape (B, C, H, W)"
258
+ assert input_data.shape[1] == self.in_chans, f"Input channels must be {self.in_chans}"
259
+ batch_size, channels, height, width = input_data.shape
260
+ assert (
261
+ height % self.patch_size == 0 and width % self.patch_size == 0
262
+ ), f"Input shape must be divisible by patch size: {self.patch_size}"
263
+
264
+ # Encode the dense representation
265
+ features = self.unshuffle(input_data)
266
+ features = self.conv_in(features)
267
+ features = self.encoder(features)
268
+ features = features.flatten(2).transpose(
269
+ 1, 2
270
+ ) # (B, E, H / Patch_Size, W / Patch_Size) -> (B, H / Patch_Size * W / Patch_Size, E)
271
+ features = self.norm_layer(features) # Normalize the features after patch encoding
272
+
273
+ # Apply positional encoding if required
274
+ if self.apply_pe:
275
+ features = features + self.interpolate_pos_encoding(
276
+ features, height, width
277
+ ) # (B, H / Patch_Size * W / Patch_Size, E)
278
+ features = self.post_pe_norm(features) # Normalize the features after positional encoding
279
+
280
+ # Resize the features to the expected shape
281
+ # (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size)
282
+ features = features.permute(0, 2, 1)
283
+ features = features.reshape(
284
+ -1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size
285
+ ).contiguous()
286
+
287
+ return ViTEncoderOutput(features=features)
288
+
289
+
290
+ if __name__ == "__main__":
291
+ # Init Dense Representation Encoder for images as input
292
+ patch_embedder = DenseRepresentationEncoder(
293
+ name="dense_rep_encoder",
294
+ data_norm_type="dense_rep_encoder",
295
+ input_size_for_pe=518,
296
+ patch_size=14,
297
+ in_chans=3,
298
+ enc_embed_dim=1024,
299
+ apply_pe=False,
300
+ )
301
+
302
+ # Test dummy image input
303
+ dummy_image = torch.randn(1, 3, 518, 518)
304
+ patch_embedder_output = patch_embedder(ViTEncoderInput(data_norm_type="dense_rep_encoder", image=dummy_image))
305
+ assert patch_embedder_output.features.shape == (
306
+ 1,
307
+ 1024,
308
+ 37,
309
+ 37,
310
+ ), "Output features must have shape (1, 1024, 37, 37)"
311
+
312
+ # Init Dense Representation Encoder for non-image data as input
313
+ patch_embedder = DenseRepresentationEncoder(
314
+ name="dense_rep_encoder",
315
+ data_norm_type="dense_rep_encoder",
316
+ input_size_for_pe=518,
317
+ patch_size=14,
318
+ in_chans=6,
319
+ enc_embed_dim=1024,
320
+ )
321
+
322
+ # Init Dense Representation Encoder for single channel input
323
+ patch_embedder = DenseRepresentationEncoder(
324
+ name="dense_rep_encoder",
325
+ data_norm_type="dense_rep_encoder",
326
+ input_size_for_pe=518,
327
+ patch_size=14,
328
+ in_chans=1,
329
+ enc_embed_dim=1024,
330
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
331
+ apply_pe=True,
332
+ )
333
+
334
+ # Test dummy non-image input
335
+ dummy_image = torch.randn(1, 1, 980, 980)
336
+ patch_embedder_output = patch_embedder(ViTEncoderNonImageInput(data=dummy_image))
337
+ assert patch_embedder_output.features.shape == (
338
+ 1,
339
+ 1024,
340
+ 70,
341
+ 70,
342
+ ), "Output features must have shape (1, 1024, 70, 70)"
343
+
344
+ print("All variants of Dense Representation Encoder have been initialized successfully!")
UniCeption/uniception/models/encoders/dinov2.py ADDED
@@ -0,0 +1,333 @@
1
+ """
2
+ Encoder Class for DINOv2
3
+ """
4
+
5
+ from typing import List, Optional, Union
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from uniception.models.encoders.base import UniCeptionViTEncoderBase, ViTEncoderInput, ViTEncoderOutput
12
+ from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner
13
+
14
+
15
+ class DINOv2Encoder(UniCeptionViTEncoderBase):
16
+ "UniCeption DINOv2 Encoder"
17
+
18
+ def __init__(
19
+ self,
20
+ name: str,
21
+ data_norm_type: str = "dinov2",
22
+ patch_size: int = 14,
23
+ size: str = "large",
24
+ with_registers: bool = False,
25
+ pretrained_checkpoint_path: str = None,
26
+ torch_hub_force_reload: bool = False,
27
+ gradient_checkpointing: bool = False,
28
+ keep_first_n_layers: Optional[int] = None,
29
+ use_pytorch_sdpa=True,
30
+ *args,
31
+ **kwargs,
32
+ ):
33
+ """
34
+ DINOv2 Encoder for extracting spatial features from images.
35
+
36
+ Args:
37
+ name (str): Name of the encoder.
38
+ data_norm_type (str): Image normalization type. Default: "dinov2"
39
+ patch_size (int): Patch size for the encoder. Default: 14
40
+ size (str): Size variant of the DINOv2 model. Options: ["small", "base", "large", "giant"]. Default: "large"
41
+ with_registers (bool): Whether to use the DINOv2 model with registers. Default: False
42
+ pretrained_checkpoint_path (str): Path to the pretrained checkpoint if using custom trained version of DINOv2. Default: None
43
+ torch_hub_force_reload (bool): Whether to force reload the model from torch hub. Default: False
44
+ gradient_checkpointing (bool): Whether to use gradient checkpointing to save GPU memory during backward call. Default: False
45
+ keep_first_n_layers (Optional[int]): If specified, only the first n layers of the model will be kept. Default: None
46
+ use_pytorch_sdpa (bool): Whether to use PyTorch native SDPA for attention layers. Default: True
47
+ """
48
+ # Init the base class
49
+ name = name if not with_registers else f"{name}_reg"
50
+ super().__init__(
51
+ name=name,
52
+ data_norm_type=data_norm_type,
53
+ patch_size=patch_size,
54
+ gradient_checkpointing=gradient_checkpointing,
55
+ *args,
56
+ **kwargs,
57
+ )
58
+
59
+ # Init the DINOv2 Encoder specific attributes
60
+ self.version = size
61
+ self.with_registers = with_registers
62
+ self.enc_embed_dim = {"small": 384, "base": 768, "large": 1024, "giant": 1536}[self.version]
63
+
64
+ # Define DINOv2 model factory
65
+ DINO_MODELS = {
66
+ # No registers
67
+ False: {
68
+ "small": "dinov2_vits14",
69
+ "base": "dinov2_vitb14",
70
+ "large": "dinov2_vitl14",
71
+ "giant": "dinov2_vitg14",
72
+ },
73
+ # With registers
74
+ True: {
75
+ "small": "dinov2_vits14_reg",
76
+ "base": "dinov2_vitb14_reg",
77
+ "large": "dinov2_vitl14_reg",
78
+ "giant": "dinov2_vitg14_reg",
79
+ },
80
+ }
81
+
82
+ # Load the pretrained DINOv2 model from torch hub
83
+ print(f"Loading pretrained {DINO_MODELS[self.with_registers][self.version]} from torch hub")
84
+ try: # Requires internet access
85
+ self.model = torch.hub.load(
86
+ "facebookresearch/dinov2",
87
+ DINO_MODELS[self.with_registers][self.version],
88
+ force_reload=torch_hub_force_reload,
89
+ )
90
+ except: # Load from cache
91
+ self.model = torch.hub.load("facebookresearch/dinov2", DINO_MODELS[self.with_registers][self.version])
92
+
93
+ del (
94
+ self.model.mask_token
95
+ ) # The mask token is unused when producing patch features and would otherwise be reported as an unused parameter
96
+
97
+ # Keep only the first n layers of the model if keep_first_n_layers is specified
98
+ if keep_first_n_layers is not None:
99
+ self.model.blocks = nn.ModuleList(self.model.blocks[:keep_first_n_layers])
100
+
101
+ # Use Native Torch SDPA for attention layers if specified (instead of DINOv2's XFormers)
102
+ if use_pytorch_sdpa:
103
+ self.enable_pytorch_native_sdpa()
104
+
105
+ # Wrap the transformer blocks with support for gradient checkpointing if required
106
+ if self.gradient_checkpointing:
107
+ for i in range(len(self.model.blocks)):
108
+ self.model.blocks[i] = self.wrap_module_with_gradient_checkpointing(self.model.blocks[i])
109
+
110
+ # Load the custom pretrained checkpoint if provided
111
+ if pretrained_checkpoint_path:
112
+ print(f"Loading custom pretrained DINOv2 checkpoint from {pretrained_checkpoint_path}")
113
+ ckpt = torch.load(pretrained_checkpoint_path, weights_only=False)
114
+ print(self.load_state_dict(ckpt["model"]))
115
+
116
+ def enable_pytorch_native_sdpa(self):
117
+ "Enable PyTorch native SDPA for attention layers"
118
+ for i in range(len(self.model.blocks)):
119
+ self.model.blocks[i].attn = self.wrap_dinov2_attention_with_sdpa(self.model.blocks[i].attn)
120
+
121
+ def wrap_dinov2_attention_with_sdpa(self, module: nn.Module):
122
+ "Wrap DINOv2 attention module with PyTorch native SDPA"
123
+ assert torch.__version__ >= "2.0", "SDPA requires PyTorch 2.0 or later"
124
+
125
+ class _AttentionWrapper(module.__class__):
126
+ "SDPA Attention Wrapper Class"
127
+
128
+ def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
129
+ B, N, C = x.shape
130
+ qkv = (
131
+ self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
132
+ ) # (3, B, H, N, C // H)
133
+
134
+ q, k, v = torch.unbind(qkv, 0) # (B, H, N, C // H)
135
+
136
+ x = F.scaled_dot_product_attention(q, k, v, attn_bias)
137
+ x = x.permute(0, 2, 1, 3).reshape(B, N, C)
138
+
139
+ x = self.proj(x)
140
+ x = self.proj_drop(x)
141
+ return x
142
+
143
+ module.__class__ = _AttentionWrapper
144
+ return module
145
+
146
+ def forward(self, encoder_input: ViTEncoderInput) -> ViTEncoderOutput:
147
+ """
148
+ DINOv2 Encoder Forward Pass
149
+
150
+ Args:
151
+ encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor.
152
+
153
+ Returns:
154
+ ViTEncoderOutput: Output data from the encoder.
155
+ """
156
+ # Check image normalization type
157
+ self._check_data_normalization_type(encoder_input.data_norm_type)
158
+
159
+ # Check the dtype and shape of the input image
160
+ assert isinstance(encoder_input.image, torch.Tensor), "Input must be a torch.Tensor"
161
+ assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)"
162
+ batch_size, channels, height, width = encoder_input.image.shape
163
+ assert channels == 3, "Input must have 3 channels"
164
+ assert (
165
+ height % self.patch_size == 0 and width % self.patch_size == 0
166
+ ), f"Input shape must be divisible by patch size: {self.patch_size}"
167
+
168
+ # Extract the features from the DINOv2 model
169
+ features = self.model.forward_features(encoder_input.image)["x_norm_patchtokens"]
170
+
171
+ # Resize the features to the expected shape
172
+ # (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size)
173
+ features = features.permute(0, 2, 1)
174
+ features = features.reshape(
175
+ -1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size
176
+ ).contiguous()
177
+
178
+ return ViTEncoderOutput(features=features)
179
+
180
+
181
+ class DINOv2IntermediateFeatureReturner(DINOv2Encoder, IntermediateFeatureReturner):
182
+ "Intermediate Feature Returner for UniCeption DINOv2 Encoder"
183
+
184
+ def __init__(
185
+ self,
186
+ name: str,
187
+ data_norm_type: str = "dinov2",
188
+ patch_size: int = 14,
189
+ size: str = "large",
190
+ with_registers: bool = False,
191
+ pretrained_checkpoint_path: str = None,
192
+ indices: Optional[Union[int, List[int]]] = 1,
193
+ keep_first_n_layers: Optional[int] = None,
194
+ norm_intermediate: bool = True,
195
+ *args,
196
+ **kwargs,
197
+ ):
198
+ """
199
+ DINOv2 Encoder for extracting spatial features from images.
200
+
201
+ Args:
202
+ name (str): Name of the encoder.
203
+ data_norm_type (str): Image normalization type. Default: "dinov2"
204
+ patch_size (int): Patch size for the encoder. Default: 14
205
+ size (str): Size variant of the DINOv2 model. Options: ["small", "base", "large", "giant"]
206
+ with_registers (bool): Whether to use the DINOv2 model with registers.
207
+ pretrained_checkpoint_path (str): Path to the pretrained checkpoint if using custom trained version of DINOv2.
208
+ indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. Defaults to 1. Options:
209
+ - int: Return the last n layers.
210
+ - List[int]: Return the intermediate layers at the specified indices.
211
+ keep_first_n_layers (Optional[int], optional): If specified, only the first n layers of the model will be kept. Defaults to None.
212
+ norm_intermediate (bool, optional): Whether to normalize the intermediate features. Defaults to True.
213
+ """
214
+ # Init the base classes
215
+ DINOv2Encoder.__init__(
216
+ self,
217
+ name=name,
218
+ data_norm_type=data_norm_type,
219
+ patch_size=patch_size,
220
+ size=size,
221
+ with_registers=with_registers,
222
+ keep_first_n_layers=keep_first_n_layers,
223
+ pretrained_checkpoint_path=pretrained_checkpoint_path,
224
+ *args,
225
+ **kwargs,
226
+ )
227
+ IntermediateFeatureReturner.__init__(
228
+ self,
229
+ indices=indices,
230
+ norm_intermediate=norm_intermediate,
231
+ )
232
+
233
+ def forward(self, encoder_input: ViTEncoderInput) -> List[ViTEncoderOutput]:
234
+ """
235
+ DINOv2 Encoder Forward Pass with Intermediate Feature Return
236
+
237
+ Args:
238
+ encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor.
239
+
240
+ Returns:
241
+ List[ViTEncoderOutput]: Output data from the encoder. Returns a list of intermediate features.
242
+ """
243
+ # Check image normalization type
244
+ self._check_data_normalization_type(encoder_input.data_norm_type)
245
+
246
+ # Check the dtype and shape of the input image
247
+ assert isinstance(encoder_input.image, torch.Tensor), "Input must be a torch.Tensor"
248
+ assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)"
249
+ batch_size, channels, height, width = encoder_input.image.shape
250
+ assert channels == 3, "Input must have 3 channels"
251
+ assert (
252
+ height % self.patch_size == 0 and width % self.patch_size == 0
253
+ ), f"Input shape must be divisible by patch size: {self.patch_size}"
254
+
255
+ if self.indices is None:
256
+ self.indices = range(len(self.model.blocks))
257
+
258
+ # Extract the intermediate features from the DINOv2 model
259
+ intermediate_features = self.model.get_intermediate_layers(
260
+ encoder_input.image, n=self.indices, reshape=True, norm=self.norm_intermediate
261
+ )
262
+
263
+ # Convert the intermediate features to a list of ViTEncoderOutput
264
+ intermediate_features = [ViTEncoderOutput(features=features) for features in intermediate_features]
265
+
266
+ return intermediate_features
267
+
268
+
269
+ if __name__ == "__main__":
270
+ # Init different variants of DINOv2
271
+ for size in ["small", "base", "large", "giant"]:
272
+ for with_registers in [False, True]:
273
+ name = f"dinov2_{size}"
274
+ dinov2_encoder = DINOv2Encoder(name=name, size=size, with_registers=with_registers)
275
+
276
+ # Init the custom pretrained DINOv2 encoders
277
+ for size in ["small", "base", "large"]:
278
+ pretrained_checkpoints_dict = {
279
+ "small": "../../../checkpoints/encoders/DINOv2_ViTS_DepthAnythingV2.pth",
280
+ "base": "../../../checkpoints/encoders/DINOv2_ViTB_DepthAnythingV2.pth",
281
+ "large": "../../../checkpoints/encoders/DINOv2_ViTL_DepthAnythingV2.pth",
282
+ }
283
+ name = f"dinov2_dav2_{size}"
284
+ dinov2_encoder = DINOv2Encoder(
285
+ name=name, size=size, with_registers=False, pretrained_checkpoint_path=pretrained_checkpoints_dict[size]
286
+ )
287
+
288
+ print("All DINOv2 Encoders have been initialized successfully!")
289
+
290
+ # Intermediate Feature Returner Tests
291
+ print("Running Intermediate Feature Returner Tests...")
292
+
293
+ # Run the intermediate feature returner with last-n index
294
+ dinov2_intermediate_feature_returner = DINOv2IntermediateFeatureReturner(
295
+ name="dinov2_base", size="base", indices=6
296
+ ) # Last 6 layers
297
+ dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="dinov2")
298
+ output = dinov2_intermediate_feature_returner(dummy_input)
299
+ assert isinstance(output, list), "Output must be a list of intermediate features"
300
+ assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
301
+ assert len(output) == 6, "Output must have length of intermediate features equal to the number of indices"
302
+
303
+ # Run the intermediate feature returner with specific indices
304
+ dinov2_intermediate_feature_returner = DINOv2IntermediateFeatureReturner(
305
+ name="dinov2_base", size="base", indices=[0, 2, 4, 6]
306
+ ) # Specific layers
307
+ dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="dinov2")
308
+ output = dinov2_intermediate_feature_returner(dummy_input)
309
+ assert isinstance(output, list), "Output must be a list of intermediate features"
310
+ assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
311
+ assert len(output) == 4, "Output must have length of intermediate features equal to the number of indices"
312
+
313
+ print("All Intermediate Feature Returner Tests have passed successfully!")
314
+
315
+ from uniception.models.encoders.utils import profile_encoder
316
+
317
+ torch.backends.cuda.matmul.allow_tf32 = True
318
+ torch.backends.cudnn.allow_tf32 = True
319
+
320
+ # Profile the DINOv2 Encoder
321
+ dinov2_encoder = DINOv2Encoder(
322
+ name="dinov2_large", size="large", use_pytorch_sdpa=True, gradient_checkpointing=True, keep_first_n_layers=12
323
+ ).cuda()
324
+ dummy_input = ViTEncoderInput(image=torch.randn(24, 3, 560, 420).cuda(), data_norm_type="dinov2")
325
+
326
+ class Profiler:
327
+ @profile_encoder(num_warmup=3, num_runs=20, autocast_precision="bfloat16", use_compile=True, dynamic=False)
328
+ def run_fn(self):
329
+ output = dinov2_encoder(dummy_input)
330
+ return output
331
+
332
+ profiler = Profiler()
333
+ profiler.run_fn()
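The SDPA wrapper above only re-expresses standard multi-head attention; below is a standalone sketch of the same qkv split and `F.scaled_dot_product_attention` call with ViT-L-like shapes (16 heads, 1024-dim embeddings), using random tensors in place of a real qkv projection.

```python
import torch
import torch.nn.functional as F

B, N, C, num_heads = 2, 256, 1024, 16
head_dim = C // num_heads

qkv = torch.randn(B, N, 3 * C)  # stand-in for the output of the fused qkv linear
qkv = qkv.reshape(B, N, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
q, k, v = torch.unbind(qkv, 0)  # each (B, num_heads, N, head_dim)

out = F.scaled_dot_product_attention(q, k, v)  # PyTorch-native attention
out = out.permute(0, 2, 1, 3).reshape(B, N, C)  # back to (B, N, C)

assert out.shape == (2, 256, 1024)
```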
UniCeption/uniception/models/encoders/global_rep_encoder.py ADDED
@@ -0,0 +1,115 @@
1
+ """
2
+ Encoder class for Global Representation Encoder
3
+ """
4
+
5
+ from functools import partial
6
+ from typing import Callable, List, Optional, Type, Union
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from uniception.models.encoders.base import EncoderGlobalRepInput, EncoderGlobalRepOutput
12
+
13
+
14
+ class GlobalRepresentationEncoder(nn.Module):
15
+ "UniCeption Global Representation Encoder"
16
+
17
+ def __init__(
18
+ self,
19
+ name: str,
20
+ in_chans: int = 3,
21
+ enc_embed_dim: int = 1024,
22
+ intermediate_dims: List[int] = [128, 256, 512],
23
+ act_layer: Type[nn.Module] = nn.GELU,
24
+ norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial(nn.LayerNorm, eps=1e-6),
25
+ pretrained_checkpoint_path: Optional[str] = None,
26
+ *args,
27
+ **kwargs,
28
+ ):
29
+ """
30
+ Global Representation Encoder for projecting a global representation to a desired latent dimension.
31
+
32
+ Args:
33
+ name (str): Name of the Encoder.
34
+ in_chans (int): Number of input channels.
35
+ enc_embed_dim (int): Embedding dimension of the encoder.
36
+ intermediate_dims (List[int]): List of intermediate dimensions of the encoder.
37
+ act_layer (Type[nn.Module]): Activation layer to use in the encoder.
38
+ norm_layer (Union[Type[nn.Module], Callable[..., nn.Module]]): Final normalization layer to use in the encoder.
39
+ pretrained_checkpoint_path (Optional[str]): Path to pretrained checkpoint. (default: None)
40
+ """
41
+ super().__init__(*args, **kwargs)
42
+
43
+ # Initialize the attributes
44
+ self.name = name
45
+ self.in_chans = in_chans
46
+ self.enc_embed_dim = enc_embed_dim
47
+ self.intermediate_dims = intermediate_dims
48
+ self.pretrained_checkpoint_path = pretrained_checkpoint_path
49
+
50
+ # Init the activation layer
51
+ self.act_layer = act_layer()
52
+
53
+ # Initialize the encoder
54
+ self.encoder = nn.Sequential(
55
+ nn.Linear(self.in_chans, self.intermediate_dims[0]),
56
+ self.act_layer,
57
+ )
58
+ for intermediate_idx in range(1, len(self.intermediate_dims)):
59
+ self.encoder = nn.Sequential(
60
+ self.encoder,
61
+ nn.Linear(self.intermediate_dims[intermediate_idx - 1], self.intermediate_dims[intermediate_idx]),
62
+ self.act_layer,
63
+ )
64
+ self.encoder = nn.Sequential(
65
+ self.encoder,
66
+ nn.Linear(self.intermediate_dims[-1], self.enc_embed_dim),
67
+ )
68
+
69
+ # Init weights of the final norm layer
70
+ self.norm_layer = norm_layer(enc_embed_dim) if norm_layer else nn.Identity()
71
+ if isinstance(self.norm_layer, nn.LayerNorm):
72
+ nn.init.constant_(self.norm_layer.bias, 0)
73
+ nn.init.constant_(self.norm_layer.weight, 1.0)
74
+
75
+ # Load pretrained weights if provided
76
+ if self.pretrained_checkpoint_path is not None:
77
+ print(
78
+ f"Loading pretrained Global Representation Encoder checkpoint from {self.pretrained_checkpoint_path} ..."
79
+ )
80
+ ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
81
+ print(self.load_state_dict(ckpt["model"]))
82
+
83
+ def forward(self, encoder_input: EncoderGlobalRepInput) -> EncoderGlobalRepOutput:
84
+ """
85
+ Global Representation Encoder Forward Pass
86
+
87
+ Args:
88
+ encoder_input (EncoderGlobalRepInput): Input data for the encoder.
89
+ The provided data must contain a tensor of size (B, C).
90
+
91
+ Returns:
92
+ EncoderGlobalRepOutput: Output features from the encoder.
93
+ """
94
+ # Get the input data and verify the shape of the input
95
+ input_data = encoder_input.data
96
+ assert input_data.ndim == 2, "Input data must have shape (B, C)"
97
+ assert input_data.shape[1] == self.in_chans, f"Input data must have {self.in_chans} channels"
98
+
99
+ # Encode the global representation
100
+ features = self.encoder(input_data)
101
+
102
+ # Normalize the output
103
+ features = self.norm_layer(features)
104
+
105
+ return EncoderGlobalRepOutput(features=features)
106
+
107
+
108
+ if __name__ == "__main__":
109
+ dummy_model = GlobalRepresentationEncoder(
110
+ name="dummy", in_chans=3, enc_embed_dim=1024, intermediate_dims=[128, 256, 512]
111
+ )
112
+ dummy_input = EncoderGlobalRepInput(data=torch.randn(4, 3))
113
+ dummy_output = dummy_model(dummy_input)
114
+ assert dummy_output.features.shape == (4, 1024), "Output features must have shape (B, 1024)"
115
+ print("Global Representation Encoder has been initialized successfully!")
UniCeption/uniception/models/encoders/image_normalizations.py ADDED
@@ -0,0 +1,35 @@
1
+ """
2
+ Image normalizations for the different UniCeption image encoders.
3
+ Image encoders defined in UniCeption must have their corresponding image normalization defined here.
4
+ """
5
+
6
+ from dataclasses import dataclass
7
+
8
+ import torch
9
+
10
+
11
+ @dataclass
12
+ class ImageNormalization:
13
+ mean: torch.Tensor
14
+ std: torch.Tensor
15
+
16
+
17
+ IMAGE_NORMALIZATION_DICT = {
18
+ "dummy": ImageNormalization(mean=torch.tensor([0.0, 0.0, 0.0]), std=torch.tensor([1.0, 1.0, 1.0])),
19
+ "croco": ImageNormalization(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])),
20
+ "dust3r": ImageNormalization(mean=torch.tensor([0.5, 0.5, 0.5]), std=torch.tensor([0.5, 0.5, 0.5])),
21
+ "dinov2": ImageNormalization(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])),
22
+ "identity": ImageNormalization(mean=torch.tensor([0.0, 0.0, 0.0]), std=torch.tensor([1.0, 1.0, 1.0])),
23
+ "patch_embedder": ImageNormalization(
24
+ mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])
25
+ ),
26
+ "radio": ImageNormalization(mean=torch.tensor([0.0, 0.0, 0.0]), std=torch.tensor([1.0, 1.0, 1.0])),
27
+ "sea_raft": ImageNormalization(
28
+ mean=torch.tensor([0.0, 0.0, 0.0]), std=torch.tensor([1.0, 1.0, 1.0]) / 255
29
+ ), # Sea-RAFT uses 0-255 in FP32
30
+ "unimatch": ImageNormalization(
31
+ mean=torch.tensor([0.0, 0.0, 0.0]), std=torch.tensor([1.0, 1.0, 1.0]) / 255
32
+ ), # UniMatch uses 0-255 in FP32
33
+ "roma": ImageNormalization(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])),
34
+ "cosmos": ImageNormalization(mean=torch.tensor([0.0, 0.0, 0.0]), std=torch.tensor([0.5, 0.5, 0.5])),
35
+ }
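These entries are applied channel-wise to [0, 1] RGB tensors before calling the matching encoder; a minimal sketch (assuming the `uniception` package is importable; the broadcasting itself is plain torch).

```python
import torch

from uniception.models.encoders.image_normalizations import IMAGE_NORMALIZATION_DICT

norm = IMAGE_NORMALIZATION_DICT["dust3r"]
image = torch.rand(1, 3, 224, 224)  # RGB in [0, 1]

# Broadcast the per-channel mean/std over the spatial dimensions.
normalized = (image - norm.mean.view(1, 3, 1, 1)) / norm.std.view(1, 3, 1, 1)
print(normalized.min().item(), normalized.max().item())  # roughly [-1, 1] for "dust3r"
```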
UniCeption/uniception/models/encoders/list.py ADDED
@@ -0,0 +1,10 @@
1
+ """
2
+ List available UniCeption encoders.
3
+ """
4
+
7
+ from uniception.models.encoders import print_available_encoder_models
8
+
9
+ if __name__ == "__main__":
10
+ print_available_encoder_models()
UniCeption/uniception/models/encoders/naradio.py ADDED
@@ -0,0 +1,502 @@
1
+ """
2
+ Encoder Class for NARADIO (RayFronts)
3
+ """
4
+
5
+ import math
6
+ from typing import List, Optional, Tuple, Union
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from torch.nn.attention.flex_attention import flex_attention
12
+
13
+ from uniception.models.encoders.base import UniCeptionViTEncoderBase, ViTEncoderInput, ViTEncoderOutput
14
+ from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner
15
+
16
+
17
+ class GaussKernelAttn(nn.Module):
18
+ """Implementation of Gaussian Kernel based Attention using FlexAttention"""
19
+
20
+ def __init__(
21
+ self,
22
+ orig_attn,
23
+ gauss_std: float,
24
+ dim: int,
25
+ qk_norm: bool = False,
26
+ num_prefix_tokens: int = 8,
27
+ patch_size: int = 16,
28
+ ) -> None:
29
+ super().__init__()
30
+ num_heads = orig_attn.num_heads
31
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
32
+ self.num_heads = num_heads
33
+ self.head_dim = dim // num_heads
34
+ self.scale = self.head_dim**-0.5
35
+
36
+ self.addition_cache = dict()
37
+ self.input_resolution = None # to be set when calling forward
38
+ self.gauss_std = gauss_std
39
+ self.patch_size = patch_size
40
+
41
+ self.qkv = orig_attn.qkv
42
+ self.q_norm = orig_attn.q_norm if qk_norm else nn.Identity()
43
+ self.k_norm = orig_attn.k_norm if qk_norm else nn.Identity()
44
+ self.attn_drop = orig_attn.attn_drop
45
+ self.proj = orig_attn.proj
46
+ self.proj_drop = orig_attn.proj_drop
47
+ self.num_prefix_tokens = num_prefix_tokens
48
+
49
+ @staticmethod
50
+ def gaussian_window(dim1, dim2, std=7.0):
51
+ constant = 1 / (std * math.sqrt(2))
52
+ ks = list()
53
+ for dim in [dim1, dim2]:
54
+ start = -(dim - 1) / 2.0
55
+ k = torch.linspace(start=start * constant, end=(start + (dim - 1)) * constant, steps=dim, dtype=torch.float)
56
+ ks.append(k)
57
+ dist_square_to_mu = (torch.stack(torch.meshgrid(*ks, indexing="ij")) ** 2).sum(0)
58
+
59
+ return torch.exp(-dist_square_to_mu)
60
+
61
+ @staticmethod
62
+ def get_attention_addition(dim1, dim2, window, num_prefix_tokens=8):
63
+ m = torch.einsum("ij,kl->ijkl", torch.eye(dim1), torch.eye(dim2))
64
+ m = m.permute((0, 3, 1, 2)).contiguous()
65
+ out = F.conv2d(m.view(-1, dim1, dim2).unsqueeze(1), window.unsqueeze(0).unsqueeze(1), padding="same").squeeze(1)
66
+
67
+ out = out.view(dim1 * dim2, dim1 * dim2)
68
+ if num_prefix_tokens > 0:
69
+ v_adjusted = torch.vstack([torch.zeros((num_prefix_tokens, dim1 * dim2)), out])
70
+ out = torch.hstack([torch.zeros((dim1 * dim2 + num_prefix_tokens, num_prefix_tokens)), v_adjusted])
71
+
72
+ return out
73
+
74
+ def prepare_gaussian_addition(self, n_patches, device):
75
+ """Prepare the Gaussian addition matrix for the current input"""
76
+ # Check if we have a cached addition matrix for these dimensions
77
+ if n_patches not in self.addition_cache:
78
+ window_size = [side * 2 - 1 for side in n_patches]
79
+ window = self.gaussian_window(*window_size, std=self.gauss_std)
80
+ addition = self.get_attention_addition(*n_patches, window, self.num_prefix_tokens).to(device)
81
+
82
+ # Cache the addition matrix
83
+ self.addition_cache[n_patches] = addition
84
+
85
+ # Return the cached addition matrix
86
+ return self.addition_cache[n_patches]
87
+
88
+ def gauss_score_mod(self, score, b, h, q_idx, kv_idx, addition):
89
+ """Score modification function for FlexAttention"""
90
+ # Adding the precomputed Gaussian pattern to the attention score
91
+ return score + addition[q_idx, kv_idx]
92
+
93
+ def set_input_resolution(self, input_resolution: Tuple[int, int]):
94
+ """Set the input resolution for the Gaussian attention window"""
95
+ self.input_resolution = input_resolution
96
+
97
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
98
+ B, N, C = x.shape
99
+ assert self.input_resolution is not None, "input_resolution must be set before forward pass"
100
+ h, w = self.input_resolution
101
+ n_patches = (w // self.patch_size, h // self.patch_size)
102
+
103
+ qkv = self.qkv(x)
104
+ q, k, v = qkv.chunk(3, dim=-1)
105
+ q, k = self.q_norm(q), self.k_norm(k)
106
+
107
+ q = q.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
108
+ k = k.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
109
+ v = v.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
110
+
111
+ addition = self.prepare_gaussian_addition(n_patches, device=x.device)
112
+
113
+ # Create a score_mod function with the current addition matrix
114
+ score_mod = lambda score, b, h, q_idx, kv_idx: self.gauss_score_mod(score, b, h, q_idx, kv_idx, addition)
115
+
116
+ # Use FlexAttention
117
+ attn_output = flex_attention(q, k, v, score_mod=score_mod)
118
+
119
+ # Reshape output and apply projection
120
+ attn_output = attn_output.transpose(1, 2).reshape(B, N, C)
121
+ attn_output = self.proj(attn_output)
122
+ attn_output = self.proj_drop(attn_output)
123
+
124
+ return attn_output
125
+
126
+
127
+ class NARADIOEncoder(UniCeptionViTEncoderBase):
128
+ """
129
+ UniCeption NARADIO (RayFronts) Encoder based on NACLIP & RADIO
130
+
131
+ The model modifies the attention of the last layer of RADIO following NACLIP,
132
+ thereby improving the spatial patch features.
133
+ """
134
+
135
+ def __init__(
136
+ self,
137
+ name: str,
138
+ data_norm_type: str = "radio",
139
+ patch_size: int = 16,
140
+ model_version: str = "radio_v2.5-l",
141
+ gauss_std: float = 7.0,
142
+ pretrained_checkpoint_path: str = None,
143
+ eradio_input_shape: Optional[tuple] = None,
144
+ torch_hub_force_reload: bool = False,
145
+ keep_first_n_layers: Optional[int] = None,
146
+ *args,
147
+ **kwargs,
148
+ ):
149
+ """
150
+ NARADIO Encoder for extracting spatial features from images.
151
+
152
+ Args:
153
+ name (str): Name of the encoder.
154
+ data_norm_type (str): Image normalization type. Default: "radio"
155
+ patch_size (int): Patch size for the encoder. Default: 16
156
+ model_version (str): Version of the RADIO model to load. Default: "radio_v2.5-l"
157
+ gauss_std (float): Standard deviation of the Gaussian kernel. Default: 7.0
158
+ pretrained_checkpoint_path (str): Path to the pretrained checkpoint if using custom trained version of RADIO. Default: None
159
+ eradio_input_shape (tuple): Input shape (height, width) for E-RADIO models. Default: None
160
+ torch_hub_force_reload (bool): Whether to force reload the model from torch hub. Default: False
161
+ keep_first_n_layers (Optional[int]): Number of layers to keep from the pretrained model. Default: None
162
+ """
163
+ # Init the base class
164
+ super().__init__(
165
+ name=name,
166
+ data_norm_type=data_norm_type,
167
+ patch_size=patch_size,
168
+ *args,
169
+ **kwargs,
170
+ )
171
+
172
+ # Init the RADIO Encoder specific attributes
173
+ self.model_version = model_version
174
+ self.enc_embed_dim = {
175
+ "radio_v2.5-b": 768,
176
+ "radio_v2.5-l": 1024,
177
+ "radio_v2.5-h": 1280,
178
+ "radio_v2.5-g": 1536,
179
+ "e-radio_v2": 1536,
180
+ }[self.model_version]
181
+
182
+ if self.model_version == "radio_v2.5-g":
183
+ assert patch_size == 14, "Patch size must be 14 for RADIO v2.5-g"
184
+ else:
185
+ assert patch_size == 16, "Patch size must be 16 for all other versions of RADIO"
186
+
187
+ # Load the pretrained RADIO model from torch hub
188
+ print(f"Loading pretrained {self.model_version} from torch hub")
189
+ try: # Requires internet access
190
+ self.model = torch.hub.load(
191
+ "NVlabs/RADIO",
192
+ "radio_model",
193
+ version=self.model_version,
194
+ progress=True,
195
+ skip_validation=True,
196
+ force_reload=torch_hub_force_reload,
197
+ )
198
+ except Exception: # Load from cache
199
+ self.model = torch.hub.load(
200
+ "NVlabs/RADIO",
201
+ "radio_model",
202
+ version=self.model_version,
203
+ progress=True,
204
+ skip_validation=True,
205
+ )
206
+
207
+ # Delete the excess blocks if keep_first_n_layers is specified
208
+ if keep_first_n_layers is not None:
209
+ assert keep_first_n_layers < len(
210
+ self.model.model.blocks
211
+ ), "keep_first_n_layers must be less than the number of blocks"
212
+ print(f"Keeping only the first {keep_first_n_layers} layers of the model")
213
+ self.model.model.blocks = torch.nn.ModuleList(self.model.model.blocks[:keep_first_n_layers])
214
+
215
+ # Set the optimal window size for E-RADIO models
216
+ if "e-radio" in self.model_version:
217
+ assert eradio_input_shape is not None, "Input shape (height, width) must be provided for E-RADIO models"
218
+ self.model.model.set_optimal_window_size(eradio_input_shape)
219
+
220
+ # Load the custom pretrained checkpoint if provided
221
+ if pretrained_checkpoint_path is not None:
222
+ print(f"Loading custom pretrained NARADIO checkpoint from {pretrained_checkpoint_path}")
223
+ ckpt = torch.load(pretrained_checkpoint_path, weights_only=False)
224
+ print(self.load_state_dict(ckpt["model"]))
225
+
226
+ # Replace the attention of the last ViT block with the Gaussian Kernel based attention
227
+ self.model.model.blocks[-1] = GaussKernelAttn(
228
+ self.model.model.blocks[-1].attn,
229
+ gauss_std,
230
+ dim=self.enc_embed_dim,
231
+ num_prefix_tokens=self.model.num_summary_tokens,
232
+ patch_size=self.patch_size,
233
+ )
234
+
235
+ def forward(self, encoder_input: ViTEncoderInput) -> ViTEncoderOutput:
236
+ """
237
+ NARADIO Encoder Forward Pass
238
+
239
+ Args:
240
+ encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor.
241
+
242
+ Returns:
243
+ ViTEncoderOutput: Output data from the encoder.
244
+ """
245
+ # Check image normalization type
246
+ self._check_data_normalization_type(encoder_input.data_norm_type)
247
+
248
+ # Check the dtype and shape of the input image
249
+ assert isinstance(encoder_input.image, torch.Tensor), "Input must be a torch.Tensor"
250
+ assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)"
251
+ batch_size, channels, height, width = encoder_input.image.shape
252
+ assert channels == 3, "Input must have 3 channels"
253
+ assert (
254
+ height % self.patch_size == 0 and width % self.patch_size == 0
255
+ ), f"Input shape must be divisible by patch size: {self.patch_size}"
256
+
257
+ # Set input resolution for Gaussian attention
258
+ self.model.model.blocks[-1].set_input_resolution((height, width))
259
+
260
+ # Forward pass through the RADIO encoder
261
+ summary, features = self.model(encoder_input.image)
262
+
263
+ # Resize the features to the expected shape
264
+ # (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size)
265
+ features = features.permute(0, 2, 1)
266
+ features = features.reshape(
267
+ -1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size
268
+ ).contiguous()
269
+
270
+ return ViTEncoderOutput(features=features)
271
+
272
+
273
+ class NARADIOIntermediateFeatureReturner(NARADIOEncoder, IntermediateFeatureReturner):
274
+ "Intermediate Feature Returner for UniCeption NARADIO Encoder"
275
+
276
+ def __init__(
277
+ self,
278
+ name: str,
279
+ data_norm_type: str = "radio",
280
+ patch_size: int = 16,
281
+ model_version: str = "radio_v2.5-l",
282
+ gauss_std: float = 7.0,
283
+ pretrained_checkpoint_path: str = None,
284
+ eradio_input_shape: Optional[tuple] = None,
285
+ indices: Union[int, List[int]] = [-1],
286
+ norm_intermediate: bool = True,
287
+ stop_early: bool = False,
288
+ intermediates_only: bool = True,
289
+ feature_adaptor: Optional[str] = None,
290
+ keep_first_n_layers: Optional[int] = None,
291
+ *args,
292
+ **kwargs,
293
+ ):
294
+ """
295
+ Intermediate Feature Returner for the NARADIO Encoder.
296
+
297
+ Args:
298
+ name (str): Name of the encoder.
299
+ data_norm_type (str): Image normalization type. Default: "radio"
300
+ patch_size (int): Patch size for the encoder. Default: 16
301
+ model_version (str): Version of the RADIO model to load. Default: "radio_v2.5-l"
302
+ gauss_std (float): Standard deviation of the gaussian kernel. Default: 7.0
303
+ pretrained_checkpoint_path (str): Path to the pretrained checkpoint if using custom trained version of RADIO.
304
+ eradio_input_shape (tuple): Input shape (height, width) for E-RADIO models. Default: None
305
+ indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. Defaults to [-1]. Options:
306
+ - int: Return the last n layers.
307
+ - List[int]: Return the intermediate layers at the specified indices.
308
+ norm_intermediate (bool, optional): Whether to normalize the intermediate features. Defaults to True.
309
+ stop_early (bool, optional): Whether to stop early. Defaults to False.
310
+ intermediates_only (bool, optional): Whether to return only the intermediate features. Defaults to True.
311
+ feature_adaptor (Optional[str], optional): Feature adaptor to use. Defaults to None. Currently supported: "dino_v2".
312
+ keep_first_n_layers (Optional[int], optional): Number of layers to keep from the pretrained model. Defaults to None.
313
+ """
314
+ # Init the base classes
315
+ NARADIOEncoder.__init__(
316
+ self,
317
+ name=name,
318
+ data_norm_type=data_norm_type,
319
+ patch_size=patch_size,
320
+ model_version=model_version,
321
+ gauss_std=gauss_std,
322
+ pretrained_checkpoint_path=pretrained_checkpoint_path,
323
+ eradio_input_shape=eradio_input_shape,
324
+ keep_first_n_layers=keep_first_n_layers,
325
+ *args,
326
+ **kwargs,
327
+ )
328
+ IntermediateFeatureReturner.__init__(
329
+ self,
330
+ indices=indices,
331
+ norm_intermediate=norm_intermediate,
332
+ stop_early=stop_early,
333
+ intermediates_only=intermediates_only,
334
+ )
335
+
336
+ # Convert indices to absolute indices if indices is None
337
+ if self.indices is None:
338
+ self.indices = list(range(len(self.model.model.blocks)))
339
+
340
+ self.feature_adaptor = feature_adaptor
341
+ if self.feature_adaptor is None:
342
+ pass
343
+ elif self.feature_adaptor == "dino_v2":
344
+ # Initialize a dummy radio encoder with the adaptor setting
345
+ dummy_model = torch.hub.load(
346
+ "NVlabs/RADIO",
347
+ "radio_model",
348
+ version=self.model_version,
349
+ progress=True,
350
+ skip_validation=True,
351
+ adaptor_names="dino_v2",
352
+ )
353
+
354
+ # Extract its feature converter weights
355
+ self.spatial_feature_converter = dummy_model.adaptors["dino_v2"].feat_mlp
356
+
357
+ # Update the embedding dimension because the features have been projected
358
+ self.enc_embed_dim = self.spatial_feature_converter.final[-1].out_features
359
+
360
+ del dummy_model
361
+ else:
362
+ raise ValueError("Unsupported feature adaptor. Supported: dino_v2")
363
+
364
+ def forward(
365
+ self, encoder_input: ViTEncoderInput
366
+ ) -> Union[List[ViTEncoderOutput], Tuple[ViTEncoderOutput, List[ViTEncoderOutput]]]:
367
+ """
368
+ NARADIO Encoder Forward Pass with Intermediate Feature Return
369
+
370
+ Args:
371
+ encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor.
372
+
373
+ Returns:
374
+ Union[List[ViTEncoderOutput], Tuple[ViTEncoderOutput, List[ViTEncoderOutput]]]: Output data from the encoder.
375
+ If `intermediates_only` is True, returns a list of intermediate features.
376
+ Otherwise, returns a tuple with the final features and a list of intermediate features.
377
+ """
378
+ # Check image normalization type
379
+ self._check_data_normalization_type(encoder_input.data_norm_type)
380
+
381
+ # Check the dtype and shape of the input image
382
+ assert isinstance(encoder_input.image, torch.Tensor), "Input must be a torch.Tensor"
383
+ assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)"
384
+ batch_size, channels, height, width = encoder_input.image.shape
385
+ assert channels == 3, "Input must have 3 channels"
386
+ assert (
387
+ height % self.patch_size == 0 and width % self.patch_size == 0
388
+ ), f"Input shape must be divisible by patch size: {self.patch_size}"
389
+
390
+ # Set input resolution for Gaussian attention
391
+ self.model.model.blocks[-1].set_input_resolution((height, width))
392
+
393
+ # Extract the final features and intermediate features accordingly
394
+ model_outputs = self.model.forward_intermediates(
395
+ encoder_input.image,
396
+ indices=self.indices,
397
+ return_prefix_tokens=False,
398
+ norm=self.norm_intermediate,
399
+ stop_early=self.stop_early,
400
+ output_fmt="NLC",
401
+ intermediates_only=self.intermediates_only,
402
+ )
403
+
404
+ # Extract the final features and intermediate features accordingly
405
+ final_features, intermediate_features = None, None
406
+ if self.intermediates_only:
407
+ intermediate_features = model_outputs
408
+ else:
409
+ final_features = model_outputs[0].features.contiguous()
410
+ intermediate_features = model_outputs[1]
411
+
412
+ # Optionally convert the features using the feature adaptor
413
+ Hp, Wp = height // self.patch_size, width // self.patch_size
414
+
415
+ # Convert final features
416
+ if final_features is not None:
417
+ if self.feature_adaptor is not None:
418
+ final_features = self.spatial_feature_converter(final_features)
419
+
420
+ # Convert to BCHW and package
421
+ final_features = final_features.view(batch_size, Hp, Wp, -1).permute(0, 3, 1, 2)
422
+ final_features = ViTEncoderOutput(features=final_features)
423
+
424
+ # Convert intermediate features
425
+ if intermediate_features is not None:
426
+ num_intermediate = len(intermediate_features)
427
+ all_intermediate_feats_tensor = torch.cat(intermediate_features, dim=0)
428
+ if self.feature_adaptor is not None:
429
+ all_intermediate_feats_tensor = self.spatial_feature_converter(all_intermediate_feats_tensor)
430
+ # Convert to BCHW
431
+ all_intermediate_feats_tensor = all_intermediate_feats_tensor.view(
432
+ num_intermediate * batch_size, Hp, Wp, -1
433
+ ).permute(0, 3, 1, 2)
434
+ all_intermediate_feats = torch.chunk(all_intermediate_feats_tensor, num_intermediate, dim=0)
435
+ intermediate_features = [ViTEncoderOutput(features=x) for x in all_intermediate_feats]
436
+
437
+ # Return the final features and intermediate features accordingly
438
+ if self.intermediates_only:
439
+ return intermediate_features
440
+ else:
441
+ return final_features, intermediate_features
442
+
443
+
444
+ if __name__ == "__main__":
445
+ # Init different versions of the RADIO Encoder
446
+ for model_version in ["radio_v2.5-b", "radio_v2.5-l"]:
447
+ naradio_encoder = NARADIOEncoder(name="NARADIOv2.5", model_version=model_version)
448
+
449
+ print("All NARADIO Encoders have been initialized successfully!")
450
+
451
+ # Intermediate Feature Returner Tests
452
+ print("Running Intermediate Feature Returner Tests...")
453
+
454
+ # Run the intermediate feature returner with last-n index
455
+ naradio_intermediate_feature_returner = NARADIOIntermediateFeatureReturner(
456
+ name="NARADIOv2.5", model_version="radio_v2.5-b", indices=6
457
+ ) # Last 6 layers
458
+ dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="radio")
459
+ output = naradio_intermediate_feature_returner(dummy_input)
460
+ assert isinstance(output, list), "Output must be a list of intermediate features"
461
+ assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
462
+ assert len(output) == 6, "Output must have length of intermediate features equal to the number of indices"
463
+
464
+ # Run the intermediate feature returner with specific indices
465
+ naradio_intermediate_feature_returner = NARADIOIntermediateFeatureReturner(
466
+ name="NARADIOv2.5", model_version="radio_v2.5-b", indices=[0, 2, 4, 6]
467
+ ) # Specific layers
468
+ dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="radio")
469
+ output = naradio_intermediate_feature_returner(dummy_input)
470
+ assert isinstance(output, list), "Output must be a list of intermediate features"
471
+ assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
472
+ assert len(output) == 4, "Output must have length of intermediate features equal to the number of indices"
473
+
474
+ # Test the normalizing of intermediate features
475
+ naradio_intermediate_feature_returner = NARADIOIntermediateFeatureReturner(
476
+ name="NARADIOv2.5", model_version="radio_v2.5-b", norm_intermediate=False, intermediates_only=False
477
+ ) # Do not normalize
478
+ dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="radio")
479
+ output = naradio_intermediate_feature_returner(dummy_input)
480
+ assert isinstance(output, tuple), "Output must be a tuple with final features and intermediate features"
481
+ assert isinstance(output[0], ViTEncoderOutput), "First element of output must be the final features"
482
+ assert isinstance(output[1], list), "Second element of output must be a list of intermediate features"
483
+ assert isinstance(output[1][0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
484
+ if not isinstance(naradio_intermediate_feature_returner.model.model.norm, torch.nn.Identity):
485
+ assert not torch.equal(
486
+ output[0].features, output[1][0].features
487
+ ), "Final features and intermediate features must be different"
488
+
489
+ naradio_intermediate_feature_returner = NARADIOIntermediateFeatureReturner(
490
+ name="NARADIOv2.5", model_version="radio_v2.5-b", norm_intermediate=True, intermediates_only=False
491
+ )
492
+ dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="radio")
493
+ output = naradio_intermediate_feature_returner(dummy_input)
494
+ assert isinstance(output, tuple), "Output must be a tuple with final features and intermediate features"
495
+ assert isinstance(output[0], ViTEncoderOutput), "First element of output must be the final features"
496
+ assert isinstance(output[1], list), "Second element of output must be a list of intermediate features"
497
+ assert isinstance(output[1][0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
498
+ assert torch.equal(
499
+ output[0].features, output[1][0].features
500
+ ), "Final features and intermediate features must be same"
501
+
502
+ print("All Intermediate Feature Returner Tests have passed successfully!")
UniCeption/uniception/models/encoders/patch_embedder.py ADDED
@@ -0,0 +1,235 @@
1
+ """
2
+ Encoder class for Patch Embedder
3
+ """
4
+
5
+ import math
6
+ from functools import partial
7
+ from typing import Callable, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn.init import trunc_normal_
12
+
13
+ from uniception.models.encoders.base import (
14
+ UniCeptionViTEncoderBase,
15
+ ViTEncoderInput,
16
+ ViTEncoderNonImageInput,
17
+ ViTEncoderOutput,
18
+ )
19
+
20
+
21
+ def make_2tuple(x):
22
+ if isinstance(x, tuple):
23
+ assert len(x) == 2
24
+ return x
25
+
26
+ assert isinstance(x, int)
27
+ return (x, x)
28
+
29
+
30
+ class PatchEmbedder(UniCeptionViTEncoderBase):
31
+ "UniCeption Patch Embedder"
32
+
33
+ def __init__(
34
+ self,
35
+ name: str,
36
+ data_norm_type: str = "patch_embedder",
37
+ input_size: Union[int, Tuple[int, int]] = 518,
38
+ patch_size: int = 14,
39
+ in_chans: int = 3,
40
+ enc_embed_dim: int = 1024,
41
+ norm_layer: Optional[Callable] = None,
42
+ post_pe_norm_layer: Optional[Callable] = partial(nn.LayerNorm, eps=1e-6),
43
+ interpolate_antialias: bool = False,
44
+ interpolate_offset: float = 0.1,
45
+ pretrained_checkpoint_path: str = None,
46
+ *args,
47
+ **kwargs,
48
+ ):
49
+ """
50
+ Patch Encoder for extracting patch-wise features from a spatial input of size (B, C, H, W).
51
+ Learnable positional encoding is also applied to the patch-wise features.
52
+ """
53
+ # Init the base class
54
+ super().__init__(
55
+ name=name,
56
+ data_norm_type=data_norm_type,
57
+ patch_size=patch_size,
58
+ *args,
59
+ **kwargs,
60
+ )
61
+
62
+ # Init the Patch Embedder specific attributes
63
+ patch_HW = make_2tuple(patch_size)
64
+ self.input_size = make_2tuple(input_size)
65
+ self.patches_resolution = (self.input_size[0] // patch_HW[0], self.input_size[1] // patch_HW[1])
66
+ self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
67
+ self.in_chans = in_chans
68
+ self.enc_embed_dim = enc_embed_dim
69
+
70
+ # Init the Patch Embedder layers
71
+ self.proj = nn.Conv2d(in_chans, enc_embed_dim, kernel_size=patch_HW, stride=patch_HW)
72
+ self.norm = norm_layer(enc_embed_dim) if norm_layer else nn.Identity()
73
+
74
+ # Init the learnable positional encodings
75
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, self.enc_embed_dim))
76
+ trunc_normal_(self.pos_embed, std=0.02)
77
+ self.interpolate_antialias = interpolate_antialias
78
+ self.interpolate_offset = interpolate_offset
79
+
80
+ # Init the norm layer after positional encoding
81
+ self.post_pe_norm = post_pe_norm_layer(enc_embed_dim) if post_pe_norm_layer else nn.Identity()
82
+
83
+ # Load the pretrained checkpoint if provided
84
+ self.pretrained_checkpoint_path = pretrained_checkpoint_path
85
+ if self.pretrained_checkpoint_path:
86
+ print(f"Loading custom pretrained Patch Embedder checkpoint from {self.pretrained_checkpoint_path} ...")
87
+ ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
88
+ print(self.load_state_dict(ckpt["model"]))
89
+
90
+ def interpolate_pos_encoding(self, features, height, width):
91
+ """
92
+ Interpolate the positional encoding to the expected size.
93
+
94
+ Args:
95
+ features (torch.Tensor): Input tensor of shape (B, N, C).
96
+ height (int, float): Height of the input tensor.
97
+ width (int, float): Width of the input tensor.
98
+
99
+ Returns:
100
+ torch.Tensor: Interpolated positional encoding tensor of shape (1, N, C).
101
+ """
102
+ previous_dtype = features.dtype
103
+ npatch = features.shape[1]
104
+ N = self.pos_embed.shape[1]
105
+ if npatch == N and height == width:
106
+ return self.pos_embed
107
+ patch_pos_embed = self.pos_embed.float()
108
+ dim = features.shape[-1]
109
+ height0 = height // self.patch_size
110
+ width0 = width // self.patch_size
111
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
112
+ assert N == M * M
113
+ kwargs = {}
114
+ if self.interpolate_offset:
115
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
116
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
117
+ sh = float(height0 + self.interpolate_offset) / M
118
+ sw = float(width0 + self.interpolate_offset) / M
119
+ kwargs["scale_factor"] = (sh, sw)
120
+ else:
121
+ # Simply specify an output size instead of a scale factor
122
+ kwargs["size"] = (height0, width0)
123
+ patch_pos_embed = nn.functional.interpolate(
124
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
125
+ mode="bicubic",
126
+ antialias=self.interpolate_antialias,
127
+ **kwargs,
128
+ )
129
+ assert (height0, width0) == patch_pos_embed.shape[-2:]
130
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
131
+
132
+ return patch_pos_embed.to(previous_dtype)
133
+
134
+ def forward(self, encoder_input: Union[ViTEncoderInput, ViTEncoderNonImageInput]) -> ViTEncoderOutput:
135
+ """
136
+ Patch Embedder Forward Pass
137
+
138
+ Args:
139
+ encoder_input (Union[ViTEncoderInput, ViTEncoderNonImageInput]): Input data for the encoder.
140
+ If input type is ViTEncoderInput, input data must contain image normalization type and normalized image tensor.
141
+ If input type is ViTEncoderNonImageInput, input data must contain a tensor of size (B, C, H, W).
142
+
143
+ Returns:
144
+ ViTEncoderOutput: Output data from the encoder.
145
+ """
146
+ # Get the input data and verify normalization if the input type is ViTEncoderInput
147
+ if isinstance(encoder_input, ViTEncoderInput):
148
+ self._check_data_normalization_type(encoder_input.data_norm_type)
149
+ input_data = encoder_input.image
150
+ elif isinstance(encoder_input, ViTEncoderNonImageInput):
151
+ input_data = encoder_input.data
152
+ else:
153
+ raise ValueError("Unsupported input type for Patch Embedder.")
154
+
155
+ # Check the dtype and shape of the input
156
+ assert isinstance(input_data, torch.Tensor), "Input must be a torch.Tensor"
157
+ assert input_data.ndim == 4, "Input must be of shape (B, C, H, W)"
158
+ batch_size, channels, height, width = input_data.shape
159
+ assert (
160
+ height % self.patch_size == 0 and width % self.patch_size == 0
161
+ ), f"Input shape must be divisible by patch size: {self.patch_size}"
162
+
163
+ # Patchify the input data and project into expected latent space
164
+ features = self.proj(input_data) # (B, C, H, W) -> (B, E, H / Patch_Size, W / Patch_Size)
165
+ features = features.flatten(2).transpose(
166
+ 1, 2
167
+ ) # (B, E, H / Patch_Size, W / Patch_Size) -> (B, H / Patch_Size * W / Patch_Size, E)
168
+ features = self.norm(features) # Normalize the features after patch embedding
169
+ features = features + self.interpolate_pos_encoding(
170
+ features, height, width
171
+ ) # (B, H / Patch_Size * W / Patch_Size, E)
172
+ features = self.post_pe_norm(features) # Normalize the features after positional encoding
173
+
174
+ # Resize the features to the expected shape
175
+ # (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size)
176
+ features = features.permute(0, 2, 1)
177
+ features = features.reshape(
178
+ -1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size
179
+ ).contiguous()
180
+
181
+ return ViTEncoderOutput(features=features)
182
+
183
+
184
+ if __name__ == "__main__":
185
+ # Init Patch Embedder for images as input
186
+ patch_embedder = PatchEmbedder(
187
+ name="patch_embedder",
188
+ data_norm_type="patch_embedder",
189
+ input_size=518,
190
+ patch_size=14,
191
+ in_chans=3,
192
+ enc_embed_dim=1024,
193
+ )
194
+
195
+ # Test dummy image input
196
+ dummy_image = torch.randn(1, 3, 518, 518)
197
+ patch_embedder_output = patch_embedder(ViTEncoderInput(data_norm_type="patch_embedder", image=dummy_image))
198
+ assert patch_embedder_output.features.shape == (
199
+ 1,
200
+ 1024,
201
+ 37,
202
+ 37,
203
+ ), "Output features must have shape (1, 1024, 37, 37)"
204
+
205
+ # Init Patch Embedder for non-image data as input
206
+ patch_embedder = PatchEmbedder(
207
+ name="patch_embedder",
208
+ data_norm_type="patch_embedder",
209
+ input_size=518,
210
+ patch_size=14,
211
+ in_chans=6,
212
+ enc_embed_dim=1024,
213
+ )
214
+
215
+ # Init Patch Embedder for single channel input
216
+ patch_embedder = PatchEmbedder(
217
+ name="patch_embedder",
218
+ data_norm_type="patch_embedder",
219
+ input_size=518,
220
+ patch_size=14,
221
+ in_chans=1,
222
+ enc_embed_dim=1024,
223
+ )
224
+
225
+ # Test dummy non-image input
226
+ dummy_image = torch.randn(1, 1, 518, 518)
227
+ patch_embedder_output = patch_embedder(ViTEncoderNonImageInput(data=dummy_image))
228
+ assert patch_embedder_output.features.shape == (
229
+ 1,
230
+ 1024,
231
+ 37,
232
+ 37,
233
+ ), "Output features must have shape (1, 1024, 37, 37)"
234
+
235
+ print("All variants of Patch Embedder have been initialized successfully!")
UniCeption/uniception/models/encoders/radio.py ADDED
@@ -0,0 +1,367 @@
1
+ """
2
+ Encoder Class for RADIO (Nvidia)
3
+ """
4
+
5
+ from typing import List, Optional, Tuple, Union
6
+
7
+ import torch
8
+
9
+ from uniception.models.encoders.base import UniCeptionViTEncoderBase, ViTEncoderInput, ViTEncoderOutput
10
+ from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner
11
+
12
+
13
+ class RADIOEncoder(UniCeptionViTEncoderBase):
14
+ "UniCeption RADIO Encoder"
15
+
16
+ def __init__(
17
+ self,
18
+ name: str,
19
+ data_norm_type: str = "radio",
20
+ patch_size: int = 16,
21
+ model_version: str = "radio_v2.5-l",
22
+ pretrained_checkpoint_path: str = None,
23
+ eradio_input_shape: Optional[tuple] = None,
24
+ torch_hub_force_reload: bool = False,
25
+ keep_first_n_layers: Optional[int] = None,
26
+ *args,
27
+ **kwargs,
28
+ ):
29
+ """
30
+ RADIO Encoder for extracting spatial features from images.
31
+
32
+ Args:
33
+ name (str): Name of the encoder.
34
+ data_norm_type (str): Image normalization type. Default: "radio"
35
+ patch_size (int): Patch size for the encoder. Default: 16
36
+ model_version (str): Version of the RADIO model to load. Default: "radio_v2.5-l"
37
+ pretrained_checkpoint_path (str): Path to the pretrained checkpoint if using custom trained version of RADIO. Default: None
38
+ eradio_input_shape (tuple): Input shape (height, width) for E-RADIO models. Default: None
39
+ torch_hub_force_reload (bool): Whether to force reload the model from torch hub. Default: False
40
+ keep_first_n_layers (Optional[int]): Number of layers to keep from the pretrained model. Default: None
41
+ """
42
+ # Init the base class
43
+ super().__init__(
44
+ name=name,
45
+ data_norm_type=data_norm_type,
46
+ patch_size=patch_size,
47
+ *args,
48
+ **kwargs,
49
+ )
50
+
51
+ # Init the RADIO Encoder specific attributes
52
+ self.model_version = model_version
53
+ self.enc_embed_dim = {
54
+ "radio_v2.5-b": 768,
55
+ "radio_v2.5-l": 1024,
56
+ "radio_v2.5-h": 1280,
57
+ "radio_v2.5-g": 1536,
58
+ "e-radio_v2": 1536,
59
+ }[self.model_version]
60
+
61
+ if self.model_version == "radio_v2.5-g":
62
+ assert patch_size == 14, "Patch size must be 14 for RADIO v2.5-g"
63
+ else:
64
+ assert patch_size == 16, "Patch size must be 16 for all other versions of RADIO"
65
+
66
+ # Load the pretrained RADIO model from torch hub
67
+ print(f"Loading pretrained {self.model_version} from torch hub")
68
+ try: # Requires internet access
69
+ self.model = torch.hub.load(
70
+ "NVlabs/RADIO",
71
+ "radio_model",
72
+ version=self.model_version,
73
+ progress=True,
74
+ skip_validation=True,
75
+ force_reload=torch_hub_force_reload,
76
+ )
77
+ except Exception: # Load from cache
78
+ self.model = torch.hub.load(
79
+ "NVlabs/RADIO",
80
+ "radio_model",
81
+ version=self.model_version,
82
+ progress=True,
83
+ skip_validation=True,
84
+ )
85
+
86
+ # Delete the excess blocks if keep_first_n_layers is specified
87
+ if keep_first_n_layers is not None:
88
+ assert keep_first_n_layers < len(
89
+ self.model.model.blocks
90
+ ), "keep_first_n_layers must be less than the number of blocks"
91
+ print(f"Keeping only the first {keep_first_n_layers} layers of the model")
92
+ self.model.model.blocks = torch.nn.ModuleList(self.model.model.blocks[:keep_first_n_layers])
93
+
94
+ # Set the optimal window size for E-RADIO models
95
+ if "e-radio" in self.model_version:
96
+ assert eradio_input_shape is not None, "Input shape (height, width) must be provided for E-RADIO models"
97
+ self.model.model.set_optimal_window_size(eradio_input_shape)
98
+
99
+ # Load the custom pretrained checkpoint if provided
100
+ if pretrained_checkpoint_path is not None:
101
+ print(f"Loading custom pretrained RADIO checkpoint from {pretrained_checkpoint_path}")
102
+ ckpt = torch.load(pretrained_checkpoint_path, weights_only=False)
103
+ print(self.load_state_dict(ckpt["model"]))
104
+
105
+ def forward(self, encoder_input: ViTEncoderInput) -> ViTEncoderOutput:
106
+ """
107
+ RADIO Encoder Forward Pass
108
+
109
+ Args:
110
+ encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor.
111
+
112
+ Returns:
113
+ ViTEncoderOutput: Output data from the encoder.
114
+ """
115
+ # Check image normalization type
116
+ self._check_data_normalization_type(encoder_input.data_norm_type)
117
+
118
+ # Check the dtype and shape of the input image
119
+ assert isinstance(encoder_input.image, torch.Tensor), "Input must be a torch.Tensor"
120
+ assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)"
121
+ batch_size, channels, height, width = encoder_input.image.shape
122
+ assert channels == 3, "Input must have 3 channels"
123
+ assert (
124
+ height % self.patch_size == 0 and width % self.patch_size == 0
125
+ ), f"Input shape must be divisible by patch size: {self.patch_size}"
126
+
127
+ # Forward pass through the RADIO encoder
128
+ summary, features = self.model(encoder_input.image)
129
+
130
+ # Resize the features to the expected shape
131
+ # (B x Num_patches x Embed_dim) -> (B x Embed_dim x H / Patch_Size x W / Patch_Size)
132
+ features = features.permute(0, 2, 1)
133
+ features = features.reshape(
134
+ -1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size
135
+ ).contiguous()
136
+
137
+ return ViTEncoderOutput(features=features)
138
+
139
+
140
+ class RADIOIntermediateFeatureReturner(RADIOEncoder, IntermediateFeatureReturner):
141
+ "Intermediate Feature Returner for UniCeption RADIO Encoder"
142
+
143
+ def __init__(
144
+ self,
145
+ name: str,
146
+ data_norm_type: str = "radio",
147
+ patch_size: int = 16,
148
+ model_version: str = "radio_v2.5-l",
149
+ pretrained_checkpoint_path: str = None,
150
+ eradio_input_shape: Optional[tuple] = None,
151
+ indices: Union[int, List[int]] = [-1],
152
+ norm_intermediate: bool = True,
153
+ stop_early: bool = False,
154
+ intermediates_only: bool = True,
155
+ feature_adaptor: Optional[str] = None,
156
+ keep_first_n_layers: Optional[int] = None,
157
+ *args,
158
+ **kwargs,
159
+ ):
160
+ """
161
+ Intermediate Feature Returner for the RADIO Encoder.
162
+
163
+ Args:
164
+ name (str): Name of the encoder.
165
+ data_norm_type (str): Image normalization type. Default: "radio"
166
+ patch_size (int): Patch size for the encoder. Default: 16
167
+ model_version (str): Version of the RADIO model to load. Default: "radio_v2.5-l"
168
+ pretrained_checkpoint_path (str): Path to the pretrained checkpoint if using custom trained version of RADIO.
169
+ eradio_input_shape (tuple): Input shape (height, width) for E-RADIO models. Default: None
170
+ indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. Defaults to [-1]. Options:
171
+ - int: Return the last n layers.
172
+ - List[int]: Return the intermediate layers at the specified indices.
173
+ norm_intermediate (bool, optional): Whether to normalize the intermediate features. Defaults to True.
174
+ stop_early (bool, optional): Whether to stop early. Defaults to False.
175
+ intermediates_only (bool, optional): Whether to return only the intermediate features. Defaults to True.
176
+ feature_adaptor (Optional[str], optional): Feature adaptor to use. Defaults to None. Currently supported: "dino_v2".
177
+ keep_first_n_layers (Optional[int], optional): Number of layers to keep from the pretrained model. Defaults to None.
178
+ """
179
+ # Init the base classes
180
+ RADIOEncoder.__init__(
181
+ self,
182
+ name=name,
183
+ data_norm_type=data_norm_type,
184
+ patch_size=patch_size,
185
+ model_version=model_version,
186
+ pretrained_checkpoint_path=pretrained_checkpoint_path,
187
+ eradio_input_shape=eradio_input_shape,
188
+ keep_first_n_layers=keep_first_n_layers,
189
+ *args,
190
+ **kwargs,
191
+ )
192
+ IntermediateFeatureReturner.__init__(
193
+ self,
194
+ indices=indices,
195
+ norm_intermediate=norm_intermediate,
196
+ stop_early=stop_early,
197
+ intermediates_only=intermediates_only,
198
+ )
199
+
200
+ # Convert indices to absolute indices if indices is None
201
+ if self.indices is None:
202
+ self.indices = list(range(len(self.model.model.blocks)))
203
+
204
+ self.feature_adaptor = feature_adaptor
205
+ if self.feature_adaptor is None:
206
+ pass
207
+ elif self.feature_adaptor == "dino_v2":
208
+ # Initialize a dummy radio encoder with the adaptor setting
209
+ dummy_model = torch.hub.load(
210
+ "NVlabs/RADIO",
211
+ "radio_model",
212
+ version=self.model_version,
213
+ progress=True,
214
+ skip_validation=True,
215
+ adaptor_names="dino_v2",
216
+ )
217
+
218
+ # Extract its feature converter weights
219
+ self.spatial_feature_converter = dummy_model.adaptors["dino_v2"].feat_mlp
220
+
221
+ # Update the embedding dimension because the features have been projected
222
+ self.enc_embed_dim = self.spatial_feature_converter.final[-1].out_features
223
+
224
+ del dummy_model
225
+ else:
226
+ raise ValueError("Unsupported feature adaptor. Supported: dino_v2")
227
+
228
+ def forward(
229
+ self, encoder_input: ViTEncoderInput
230
+ ) -> Union[List[ViTEncoderOutput], Tuple[ViTEncoderOutput, List[ViTEncoderOutput]]]:
231
+ """
232
+ RADIO Encoder Forward Pass with Intermediate Feature Return
233
+
234
+ Args:
235
+ encoder_input (ViTEncoderInput): Input data for the encoder. Input data must contain image normalization type and normalized image tensor.
236
+
237
+ Returns:
238
+ Union[List[ViTEncoderOutput], Tuple[ViTEncoderOutput, List[ViTEncoderOutput]]]: Output data from the encoder.
239
+ If `intermediates_only` is True, returns a list of intermediate features.
240
+ Otherwise, returns a tuple with the final features and a list of intermediate features.
241
+ """
242
+ # Check image normalization type
243
+ self._check_data_normalization_type(encoder_input.data_norm_type)
244
+
245
+ # Check the dtype and shape of the input image
246
+ assert isinstance(encoder_input.image, torch.Tensor), "Input must be a torch.Tensor"
247
+ assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)"
248
+ batch_size, channels, height, width = encoder_input.image.shape
249
+ assert channels == 3, "Input must have 3 channels"
250
+ assert (
251
+ height % self.patch_size == 0 and width % self.patch_size == 0
252
+ ), f"Input shape must be divisible by patch size: {self.patch_size}"
253
+
254
+ # Extract the final features and intermediate features accordingly
255
+ model_outputs = self.model.forward_intermediates(
256
+ encoder_input.image,
257
+ indices=self.indices,
258
+ return_prefix_tokens=False,
259
+ norm=self.norm_intermediate,
260
+ stop_early=self.stop_early,
261
+ output_fmt="NLC",
262
+ intermediates_only=self.intermediates_only,
263
+ )
264
+
265
+ # Extract the final features and intermediate features accordingly
266
+ final_features, intermediate_features = None, None
267
+ if self.intermediates_only:
268
+ intermediate_features = model_outputs
269
+ else:
270
+ final_features = model_outputs[0].features.contiguous()
271
+ intermediate_features = model_outputs[1]
272
+
273
+ # Optionally convert the features using the feature adaptor
274
+ Hp, Wp = height // self.patch_size, width // self.patch_size
275
+
276
+ # Convert final features
277
+ if final_features is not None:
278
+ if self.feature_adaptor is not None:
279
+ final_features = self.spatial_feature_converter(final_features)
280
+
281
+ # Convert to BCHW and package
282
+ final_features = final_features.view(batch_size, Hp, Wp, -1).permute(0, 3, 1, 2)
283
+ final_features = ViTEncoderOutput(features=final_features)
284
+
285
+ # Convert intermediate features
286
+ if intermediate_features is not None:
287
+ num_intermediate = len(intermediate_features)
288
+ all_intermediate_feats_tensor = torch.cat(intermediate_features, dim=0)
289
+ if self.feature_adaptor is not None:
290
+ all_intermediate_feats_tensor = self.spatial_feature_converter(all_intermediate_feats_tensor)
291
+ # Convert to BCHW
292
+ all_intermediate_feats_tensor = all_intermediate_feats_tensor.view(
293
+ num_intermediate * batch_size, Hp, Wp, -1
294
+ ).permute(0, 3, 1, 2)
295
+ all_intermediate_feats = torch.chunk(all_intermediate_feats_tensor, num_intermediate, dim=0)
296
+ intermediate_features = [ViTEncoderOutput(features=x) for x in all_intermediate_feats]
297
+
298
+ # Return the final features and intermediate features accordingly
299
+ if self.intermediates_only:
300
+ return intermediate_features
301
+ else:
302
+ return final_features, intermediate_features
303
+
304
+
305
+ if __name__ == "__main__":
306
+ # Init different versions of the RADIO Encoder
307
+ for model_version in ["radio_v2.5-b", "radio_v2.5-l"]:
308
+ radio_encoder = RADIOEncoder(name="RADIOv2.5", model_version=model_version)
309
+
310
+ # Init the E-RADIO Encoder
311
+ eradio_input_shape = (512, 512)
312
+ eradio_encoder = RADIOEncoder(name="E-RADIO", model_version="e-radio_v2", eradio_input_shape=eradio_input_shape)
313
+
314
+ print("All RADIO Encoders have been initialized successfully!")
315
+
316
+ # Intermediate Feature Returner Tests
317
+ print("Running Intermediate Feature Returner Tests...")
318
+
319
+ # Run the intermediate feature returner with last-n index
320
+ radio_intermediate_feature_returner = RADIOIntermediateFeatureReturner(
321
+ name="RADIOv2.5", model_version="radio_v2.5-b", indices=6
322
+ ) # Last 6 layers
323
+ dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="radio")
324
+ output = radio_intermediate_feature_returner(dummy_input)
325
+ assert isinstance(output, list), "Output must be a list of intermediate features"
326
+ assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
327
+ assert len(output) == 6, "Output must have length of intermediate features equal to the number of indices"
328
+
329
+ # Run the intermediate feature returner with specific indices
330
+ radio_intermediate_feature_returner = RADIOIntermediateFeatureReturner(
331
+ name="RADIOv2.5", model_version="radio_v2.5-b", indices=[0, 2, 4, 6]
332
+ ) # Specific layers
333
+ dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="radio")
334
+ output = radio_intermediate_feature_returner(dummy_input)
335
+ assert isinstance(output, list), "Output must be a list of intermediate features"
336
+ assert isinstance(output[0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
337
+ assert len(output) == 4, "Output must have length of intermediate features equal to the number of indices"
338
+
339
+ # Test the normalizing of intermediate features
340
+ radio_intermediate_feature_returner = RADIOIntermediateFeatureReturner(
341
+ name="RADIOv2.5", model_version="radio_v2.5-b", norm_intermediate=False, intermediates_only=False
342
+ ) # Do not normalize
343
+ dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="radio")
344
+ output = radio_intermediate_feature_returner(dummy_input)
345
+ assert isinstance(output, tuple), "Output must be a tuple with final features and intermediate features"
346
+ assert isinstance(output[0], ViTEncoderOutput), "First element of output must be the final features"
347
+ assert isinstance(output[1], list), "Second element of output must be a list of intermediate features"
348
+ assert isinstance(output[1][0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
349
+ if not isinstance(radio_intermediate_feature_returner.model.model.norm, torch.nn.Identity):
350
+ assert not torch.equal(
351
+ output[0].features, output[1][0].features
352
+ ), "Final features and intermediate features must be different"
353
+
354
+ radio_intermediate_feature_returner = RADIOIntermediateFeatureReturner(
355
+ name="RADIOv2.5", model_version="radio_v2.5-b", norm_intermediate=True, intermediates_only=False
356
+ )
357
+ dummy_input = ViTEncoderInput(image=torch.randn(1, 3, 224, 224), data_norm_type="radio")
358
+ output = radio_intermediate_feature_returner(dummy_input)
359
+ assert isinstance(output, tuple), "Output must be a tuple with final features and intermediate features"
360
+ assert isinstance(output[0], ViTEncoderOutput), "First element of output must be the final features"
361
+ assert isinstance(output[1], list), "Second element of output must be a list of intermediate features"
362
+ assert isinstance(output[1][0], ViTEncoderOutput), "Output must be a list of ViTEncoderOutput"
363
+ assert torch.equal(
364
+ output[0].features, output[1][0].features
365
+ ), "Final features and intermediate features must be same"
366
+
367
+ print("All Intermediate Feature Returner Tests have passed successfully!")
UniCeption/uniception/models/encoders/utils.py ADDED
@@ -0,0 +1,86 @@
1
+ """
2
+ Utility functions for UniCeption Encoders.
3
+ """
4
+
5
+ import functools
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+
11
+ def profile_encoder(num_warmup=3, num_runs=20, autocast_precision="float16", use_compile=False, dynamic=True):
12
+ def decorator(func):
13
+ @functools.wraps(func)
14
+ def wrapper(self, *args, **kwargs):
15
+ device = "cuda"
16
+ autocast_dtype = getattr(torch, autocast_precision)
17
+
18
+ # Compile the model if requested
19
+ if use_compile:
20
+ compiled_func = torch.compile(func, dynamic=dynamic, mode="max-autotune")
21
+ else:
22
+ compiled_func = func
23
+
24
+ with torch.autocast("cuda", dtype=autocast_dtype):
25
+ # Warm-up runs
26
+ for _ in range(num_warmup):
27
+ output = compiled_func(self, *args, **kwargs)
28
+ if isinstance(output, torch.Tensor):
29
+ output.sum().backward()
30
+ else:
31
+ output.features.sum().backward()
32
+ torch.cuda.synchronize()
33
+
34
+ # Clear memory cache
35
+ torch.cuda.empty_cache()
36
+
37
+ # Lists to store results
38
+ forward_times, backward_times, memory_usages = [], [], []
39
+
40
+ for _ in range(num_runs):
41
+ start_event = torch.cuda.Event(enable_timing=True)
42
+ end_event = torch.cuda.Event(enable_timing=True)
43
+
44
+ torch.cuda.reset_peak_memory_stats()
45
+ memory_before = torch.cuda.max_memory_allocated(device)
46
+
47
+ # Forward pass
48
+ start_event.record()
49
+ output = compiled_func(self, *args, **kwargs)
50
+ end_event.record()
51
+ torch.cuda.synchronize()
52
+ forward_times.append(start_event.elapsed_time(end_event))
53
+
54
+ # Backward pass
55
+ start_event.record()
56
+ if isinstance(output, torch.Tensor):
57
+ output.sum().backward()
58
+ else:
59
+ output.features.sum().backward()
60
+ end_event.record()
61
+ torch.cuda.synchronize()
62
+ backward_times.append(start_event.elapsed_time(end_event))
63
+
64
+ memory_after = torch.cuda.max_memory_allocated(device)
65
+ memory_usages.append((memory_after - memory_before) / 1e6) # Convert to MB
66
+
67
+ # Compute mean and standard deviation
68
+ fwd_mean, fwd_std = np.mean(forward_times), np.std(forward_times)
69
+ bwd_mean, bwd_std = np.mean(backward_times), np.std(backward_times)
70
+ mem_mean, mem_std = np.mean(memory_usages), np.std(memory_usages)
71
+
72
+ compile_status = (
73
+ "with torch.compile (dynamic=True)"
74
+ if use_compile and dynamic
75
+ else "with torch.compile (dynamic=False)" if use_compile else "without torch.compile"
76
+ )
77
+ print(f"Profiling results {compile_status}:")
78
+ print(f"Forward Pass Time: {fwd_mean:.2f} ± {fwd_std:.2f} ms")
79
+ print(f"Backward Pass Time: {bwd_mean:.2f} ± {bwd_std:.2f} ms")
80
+ print(f"Peak GPU Memory Usage: {mem_mean:.2f} ± {mem_std:.2f} MB")
81
+
82
+ return output
83
+
84
+ return wrapper
85
+
86
+ return decorator
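The decorator wraps an encoder method whose first argument is the module itself, so a typical use is to decorate `forward` in a subclass. The sketch below assumes a CUDA device, since the profiler relies on CUDA events and memory statistics:

```python
import torch

from uniception.models.encoders.base import ViTEncoderNonImageInput
from uniception.models.encoders.patch_embedder import PatchEmbedder
from uniception.models.encoders.utils import profile_encoder


class ProfiledPatchEmbedder(PatchEmbedder):
    # The wrapper runs warm-up iterations, then reports forward/backward time and peak memory.
    @profile_encoder(num_warmup=2, num_runs=5, autocast_precision="float16", use_compile=False)
    def forward(self, encoder_input):
        return super().forward(encoder_input)


embedder = ProfiledPatchEmbedder(
    name="patch_embedder", input_size=518, patch_size=14, in_chans=3, enc_embed_dim=1024
).cuda()
data = torch.randn(1, 3, 518, 518, device="cuda")
embedder(ViTEncoderNonImageInput(data=data))  # prints the profiling summary
```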
UniCeption/uniception/models/factory/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from uniception.models.factory.dust3r import DUSt3R
2
+
3
+ __all__ = ["DUSt3R"]
UniCeption/uniception/models/factory/dust3r.py ADDED
@@ -0,0 +1,332 @@
1
+ from typing import List, Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from uniception.models.encoders import ViTEncoderInput
7
+ from uniception.models.encoders.croco import CroCoEncoder
8
+ from uniception.models.encoders.image_normalizations import IMAGE_NORMALIZATION_DICT
9
+ from uniception.models.info_sharing.base import MultiViewTransformerInput
10
+ from uniception.models.info_sharing.cross_attention_transformer import (
11
+ MultiViewCrossAttentionTransformer,
12
+ MultiViewCrossAttentionTransformerIFR,
13
+ )
14
+ from uniception.models.libs.croco.pos_embed import RoPE2D, get_2d_sincos_pos_embed
15
+ from uniception.models.prediction_heads.adaptors import PointMapWithConfidenceAdaptor
16
+ from uniception.models.prediction_heads.base import AdaptorInput, PredictionHeadInput, PredictionHeadLayeredInput
17
+ from uniception.models.prediction_heads.dpt import DPTFeature, DPTRegressionProcessor
18
+ from uniception.models.prediction_heads.linear import LinearFeature
19
+
20
+
21
+ def is_symmetrized(gt1, gt2):
22
+ "Function to check if input pairs are symmetrized, i.e., (a, b) and (b, a) always exist in the input"
23
+ x = gt1["instance"]
24
+ y = gt2["instance"]
25
+ if len(x) == len(y) and len(x) == 1:
26
+ return False # special case of batchsize 1
27
+ ok = True
28
+ for i in range(0, len(x), 2):
29
+ ok = ok and (x[i] == y[i + 1]) and (x[i + 1] == y[i])
30
+ return ok
31
+
32
+
33
+ def interleave(tensor1, tensor2):
34
+ "Interleave two tensors along the first dimension (used to avoid redundant encoding for symmetrized pairs)"
35
+ res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1)
36
+ res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1)
37
+ return res1, res2
38
+
39
+
40
+ class DUSt3R(nn.Module):
41
+ "DUSt3R defined with UniCeption Modules"
42
+
43
+ def __init__(
44
+ self,
45
+ name: str,
46
+ data_norm_type: str = "dust3r",
47
+ img_size: tuple = (224, 224),
48
+ patch_embed_cls: str = "PatchEmbedDust3R",
49
+ pred_head_type: str = "linear",
50
+ pred_head_output_dim: int = 4,
51
+ pred_head_feature_dim: int = 256,
52
+ depth_mode: Tuple[str, float, float] = ("exp", -float("inf"), float("inf")),
53
+ conf_mode: Tuple[str, float, float] = ("exp", 1, float("inf")),
54
+ pos_embed: str = "RoPE100",
55
+ pretrained_checkpoint_path: str = None,
56
+ pretrained_encoder_checkpoint_path: str = None,
57
+ pretrained_info_sharing_checkpoint_path: str = None,
58
+ pretrained_pred_head_checkpoint_paths: List[str] = [None, None],
59
+ pretrained_pred_head_regressor_checkpoint_paths: List[str] = [None, None],
60
+ override_encoder_checkpoint_attributes: bool = False,
61
+ *args,
62
+ **kwargs,
63
+ ):
64
+ """
65
+ Two-view model containing siamese encoders followed by a two-view cross-attention transformer and respective downstream heads.
66
+ The goal is to directly output the scene representation, with the predictions for both views expressed in view1's frame (hence the asymmetry).
67
+
68
+ Args:
69
+ name (str): Name of the model.
70
+ data_norm_type (str): Type of data normalization. (default: "dust3r")
71
+ img_size (tuple): Size of input images. (default: (224, 224))
72
+ patch_embed_cls (str): Class for patch embedding. (default: "PatchEmbedDust3R"). Options:
73
+ - "PatchEmbedDust3R"
74
+ - "ManyAR_PatchEmbed"
75
+ pred_head_type (str): Type of prediction head. (default: "linear"). Options:
76
+ - "linear"
77
+ - "dpt"
78
+ pred_head_output_dim (int): Output dimension of prediction head. (default: 4)
79
+ pred_head_feature_dim (int): Feature dimension of prediction head. (default: 256)
80
+ depth_mode (Tuple[str, float, float]): Depth mode settings (mode=['linear', 'square', 'exp'], vmin, vmax). (default: ('exp', -inf, inf))
81
+ conf_mode (Tuple[str, float, float]): Confidence mode settings (mode=['linear', 'square', 'exp'], vmin, vmax). (default: ('exp', 1, inf))
82
+ pos_embed (str): Position embedding type. (default: 'RoPE100')
84
+ pretrained_checkpoint_path (str): Path to pretrained checkpoint. (default: None)
85
+ pretrained_encoder_checkpoint_path (str): Path to pretrained encoder checkpoint. (default: None)
86
+ pretrained_info_sharing_checkpoint_path (str): Path to pretrained info_sharing checkpoint. (default: None)
87
+ pretrained_pred_head_checkpoint_paths (List[str]): Paths to pretrained prediction head checkpoints. (default: None)
88
+ pretrained_pred_head_regressor_checkpoint_paths (List[str]): Paths to pretrained prediction head regressor checkpoints. (default: None)
89
+ override_encoder_checkpoint_attributes (bool): Whether to override encoder checkpoint attributes. (default: False)
90
+ """
91
+ super().__init__(*args, **kwargs)
92
+
93
+ # Initialize the attributes
94
+ self.name = name
95
+ self.data_norm_type = data_norm_type
96
+ self.img_size = img_size
97
+ self.patch_embed_cls = patch_embed_cls
98
+ self.pred_head_type = pred_head_type
99
+ self.pred_head_output_dim = pred_head_output_dim
100
+ self.depth_mode = depth_mode
101
+ self.conf_mode = conf_mode
102
+ self.pos_embed = pos_embed
103
+ self.pretrained_checkpoint_path = pretrained_checkpoint_path
104
+ self.pretrained_encoder_checkpoint_path = pretrained_encoder_checkpoint_path
105
+ self.pretrained_info_sharing_checkpoint_path = pretrained_info_sharing_checkpoint_path
106
+ self.pretrained_pred_head_checkpoint_paths = pretrained_pred_head_checkpoint_paths
107
+ self.pretrained_pred_head_regressor_checkpoint_paths = pretrained_pred_head_regressor_checkpoint_paths
108
+ self.override_encoder_checkpoint_attributes = override_encoder_checkpoint_attributes
109
+
110
+ # Initialize RoPE for the CroCo Encoder & Two-View Cross Attention Transformer
111
+ freq = float(pos_embed[len("RoPE") :])
112
+ self.rope = RoPE2D(freq=freq)
113
+
114
+ # Initialize Encoder
115
+ self.encoder = CroCoEncoder(
116
+ name=name,
117
+ data_norm_type=data_norm_type,
118
+ patch_embed_cls=patch_embed_cls,
119
+ img_size=img_size,
120
+ pretrained_checkpoint_path=pretrained_encoder_checkpoint_path,
121
+ override_checkpoint_attributes=override_encoder_checkpoint_attributes,
122
+ )
123
+
124
+ # Initialize Multi-View Cross Attention Transformer
125
+ if self.pred_head_type == "linear":
126
+ # Returns only normalized last layer features
127
+ self.info_sharing = MultiViewCrossAttentionTransformer(
128
+ name="base_info_sharing",
129
+ input_embed_dim=self.encoder.enc_embed_dim,
130
+ num_views=2,
131
+ custom_positional_encoding=self.rope,
132
+ pretrained_checkpoint_path=pretrained_info_sharing_checkpoint_path,
133
+ )
134
+ elif self.pred_head_type == "dpt":
135
+ # Returns intermediate features and normalized last layer features
136
+ self.info_sharing = MultiViewCrossAttentionTransformerIFR(
137
+ name="base_info_sharing",
138
+ input_embed_dim=self.encoder.enc_embed_dim,
139
+ num_views=2,
140
+ indices=[5, 8],
141
+ norm_intermediate=False,
142
+ custom_positional_encoding=self.rope,
143
+ pretrained_checkpoint_path=pretrained_info_sharing_checkpoint_path,
144
+ )
145
+ else:
146
+ raise ValueError(f"Invalid prediction head type: {pred_head_type}. Must be 'linear' or 'dpt'.")
147
+
148
+ # Initialize Prediction Heads
149
+ if pred_head_type == "linear":
150
+ # Initialize Prediction Head 1
151
+ self.head1 = LinearFeature(
152
+ input_feature_dim=self.info_sharing.dim,
153
+ output_dim=pred_head_output_dim,
154
+ patch_size=self.encoder.patch_size,
155
+ pretrained_checkpoint_path=pretrained_pred_head_checkpoint_paths[0],
156
+ )
157
+ # Initialize Prediction Head 2
158
+ self.head2 = LinearFeature(
159
+ input_feature_dim=self.info_sharing.dim,
160
+ output_dim=pred_head_output_dim,
161
+ patch_size=self.encoder.patch_size,
162
+ pretrained_checkpoint_path=pretrained_pred_head_checkpoint_paths[1],
163
+ )
164
+ elif pred_head_type == "dpt":
165
+ # Initialize Prediction Head 1
166
+ self.dpt_feature_head1 = DPTFeature(
167
+ patch_size=self.encoder.patch_size,
168
+ hooks=[0, 1, 2, 3],
169
+ input_feature_dims=[self.encoder.enc_embed_dim] + [self.info_sharing.dim] * 3,
170
+ feature_dim=pred_head_feature_dim,
171
+ pretrained_checkpoint_path=pretrained_pred_head_checkpoint_paths[0],
172
+ )
173
+ self.dpt_regressor_head1 = DPTRegressionProcessor(
174
+ input_feature_dim=pred_head_feature_dim,
175
+ output_dim=pred_head_output_dim,
176
+ pretrained_checkpoint_path=pretrained_pred_head_regressor_checkpoint_paths[0],
177
+ )
178
+ self.head1 = nn.Sequential(self.dpt_feature_head1, self.dpt_regressor_head1)
179
+ # Initialize Prediction Head 2
180
+ self.dpt_feature_head2 = DPTFeature(
181
+ patch_size=self.encoder.patch_size,
182
+ hooks=[0, 1, 2, 3],
183
+ input_feature_dims=[self.encoder.enc_embed_dim] + [self.info_sharing.dim] * 3,
184
+ feature_dim=pred_head_feature_dim,
185
+ pretrained_checkpoint_path=pretrained_pred_head_checkpoint_paths[1],
186
+ )
187
+ self.dpt_regressor_head2 = DPTRegressionProcessor(
188
+ input_feature_dim=pred_head_feature_dim,
189
+ output_dim=pred_head_output_dim,
190
+ pretrained_checkpoint_path=pretrained_pred_head_regressor_checkpoint_paths[1],
191
+ )
192
+ self.head2 = nn.Sequential(self.dpt_feature_head2, self.dpt_regressor_head2)
193
+
194
+ # Initialize Final Output Adaptor
195
+ self.adaptor = PointMapWithConfidenceAdaptor(
196
+ name="pointmap",
197
+ pointmap_mode=depth_mode[0],
198
+ pointmap_vmin=depth_mode[1],
199
+ pointmap_vmax=depth_mode[2],
200
+ confidence_type=conf_mode[0],
201
+ confidence_vmin=conf_mode[1],
202
+ confidence_vmax=conf_mode[2],
203
+ )
204
+
205
+ # Load pretrained weights
206
+ if self.pretrained_checkpoint_path is not None:
207
+ print(f"Loading pretrained DUSt3R weights from {self.pretrained_checkpoint_path} ...")
208
+ ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
209
+ print(self.load_state_dict(ckpt["model"]))
210
+
211
+ def _encode_image_pairs(self, img1, img2, data_norm_type):
212
+ "Encode two different batches of images (each batch can have different image shape)"
213
+ if img1.shape[-2:] == img2.shape[-2:]:
214
+ encoder_input = ViTEncoderInput(image=torch.cat((img1, img2), dim=0), data_norm_type=data_norm_type)
215
+ encoder_output = self.encoder(encoder_input)
216
+ out, out2 = encoder_output.features.chunk(2, dim=0)
217
+ else:
218
+ encoder_input = ViTEncoderInput(image=img1, data_norm_type=data_norm_type)
219
+ out = self.encoder(encoder_input)
220
+ out = out.features
221
+ encoder_input2 = ViTEncoderInput(image=img2, data_norm_type=data_norm_type)
222
+ out2 = self.encoder(encoder_input2)
223
+ out2 = out2.features
224
+
225
+ return out, out2
226
+
227
+ def _encode_symmetrized(self, view1, view2):
228
+ "Encode image pairs accounting for symmetrization, i.e., (a, b) and (b, a) always exist in the input"
229
+ img1 = view1["img"]
230
+ img2 = view2["img"]
231
+ if is_symmetrized(view1, view2):
232
+ # Compute only half of the forward pass since the pairs are symmetrized
233
+ feat1, feat2 = self._encode_image_pairs(img1[::2], img2[::2], data_norm_type=view1["data_norm_type"])
234
+ feat1, feat2 = interleave(feat1, feat2)
235
+ else:
236
+ feat1, feat2 = self._encode_image_pairs(img1, img2, data_norm_type=view1["data_norm_type"])
237
+
238
+ return feat1, feat2
239
+
240
+ def _downstream_head(self, head_num, decout, img_shape):
241
+ "Run the respective prediction heads"
242
+ head = getattr(self, f"head{head_num}")
243
+ if self.pred_head_type == "linear":
244
+ head_input = PredictionHeadInput(last_feature=decout[f"{head_num}"])
245
+ elif self.pred_head_type == "dpt":
246
+ head_input = PredictionHeadLayeredInput(list_features=decout[f"{head_num}"], target_output_shape=img_shape)
247
+
248
+ return head(head_input)
249
+
250
+ def forward(self, view1, view2):
251
+ """
252
+ Forward pass for DUSt3R performing the following operations:
253
+ 1. Encodes the two input views (images).
254
+ 2. Combines the encoded features using a two-view cross-attention transformer.
255
+ 3. Passes the combined features through the respective prediction heads.
256
+ 4. Returns the processed final outputs for both views.
257
+
258
+ Args:
259
+ view1 (dict): Dictionary containing the first view's images and instance information.
260
+ "img" is a required key and value is a tensor of shape (B, C, H, W).
261
+ view2 (dict): Dictionary containing the second view's images and instance information.
262
+ "img" is a required key and value is a tensor of shape (B, C, H, W).
263
+
264
+ Returns:
265
+ Tuple[dict, dict]: A tuple containing the final outputs for both views.
266
+ """
267
+ # Get input shapes
268
+ _, _, height1, width1 = view1["img"].shape
269
+ _, _, height2, width2 = view2["img"].shape
270
+ shape1 = (int(height1), int(width1))
271
+ shape2 = (int(height2), int(width2))
272
+
273
+ # Encode the two images --> Each feat output: BCHW features (batch_size, feature_dim, feature_height, feature_width)
274
+ feat1, feat2 = self._encode_symmetrized(view1, view2)
275
+
276
+ # Combine all images into view-centric representation
277
+ info_sharing_input = MultiViewTransformerInput(features=[feat1, feat2])
278
+ if self.pred_head_type == "linear":
279
+ final_info_sharing_multi_view_feat = self.info_sharing(info_sharing_input)
280
+ elif self.pred_head_type == "dpt":
281
+ final_info_sharing_multi_view_feat, intermediate_info_sharing_multi_view_feat = self.info_sharing(
282
+ info_sharing_input
283
+ )
284
+
285
+ if self.pred_head_type == "linear":
286
+ # Define feature dictionary for linear head
287
+ info_sharing_outputs = {
288
+ "1": final_info_sharing_multi_view_feat.features[0].float(),
289
+ "2": final_info_sharing_multi_view_feat.features[1].float(),
290
+ }
291
+ elif self.pred_head_type == "dpt":
292
+ # Define feature dictionary for DPT head
293
+ info_sharing_outputs = {
294
+ "1": [
295
+ feat1.float(),
296
+ intermediate_info_sharing_multi_view_feat[0].features[0].float(),
297
+ intermediate_info_sharing_multi_view_feat[1].features[0].float(),
298
+ final_info_sharing_multi_view_feat.features[0].float(),
299
+ ],
300
+ "2": [
301
+ feat2.float(),
302
+ intermediate_info_sharing_multi_view_feat[0].features[1].float(),
303
+ intermediate_info_sharing_multi_view_feat[1].features[1].float(),
304
+ final_info_sharing_multi_view_feat.features[1].float(),
305
+ ],
306
+ }
307
+
308
+ # Downstream task prediction
309
+ with torch.autocast("cuda", enabled=False):
310
+ # Prediction heads
311
+ head_output1 = self._downstream_head(1, info_sharing_outputs, shape1)
312
+ head_output2 = self._downstream_head(2, info_sharing_outputs, shape2)
313
+
314
+ # Post-process outputs
315
+ final_output1 = self.adaptor(
316
+ AdaptorInput(adaptor_feature=head_output1.decoded_channels, output_shape_hw=shape1)
317
+ )
318
+ final_output2 = self.adaptor(
319
+ AdaptorInput(adaptor_feature=head_output2.decoded_channels, output_shape_hw=shape2)
320
+ )
321
+
322
+ # Convert outputs to dictionary
323
+ res1 = {
324
+ "pts3d": final_output1.value.permute(0, 2, 3, 1).contiguous(),
325
+ "conf": final_output1.confidence.permute(0, 2, 3, 1).contiguous(),
326
+ }
327
+ res2 = {
328
+ "pts3d_in_other_view": final_output2.value.permute(0, 2, 3, 1).contiguous(),
329
+ "conf": final_output2.confidence.permute(0, 2, 3, 1).contiguous(),
330
+ }
331
+
332
+ return res1, res2
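For orientation, a rough usage sketch of the class above (not part of the commit): it assumes random weights are acceptable, that 224x224 inputs match the default `img_size`, and that the view dictionaries only need the keys actually read by `forward` and the symmetrization check (`img`, `instance`, `data_norm_type`).

```python
# Rough usage sketch for the UniCeption DUSt3R assembly defined above.
import torch

from uniception.models.factory import DUSt3R

model = DUSt3R(name="dust3r_example").eval()  # random weights, linear prediction heads

view1 = {"img": torch.rand(1, 3, 224, 224), "instance": ["0"], "data_norm_type": "dust3r"}
view2 = {"img": torch.rand(1, 3, 224, 224), "instance": ["1"], "data_norm_type": "dust3r"}

with torch.no_grad():
    res1, res2 = model(view1, view2)

# Expected per the adaptor outputs above: pointmaps in view1's frame plus per-pixel confidences.
print(res1["pts3d"].shape, res1["conf"].shape)
print(res2["pts3d_in_other_view"].shape, res2["conf"].shape)
```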
UniCeption/uniception/models/info_sharing/README.md ADDED
@@ -0,0 +1,18 @@
1
+ # UniCeption Information Sharing Blocks
2
+
3
+ ## Currently Supported Information Sharing Architectures
4
+
5
+ ### UniCeptionInfoSharingBase:
6
+
7
+ - `MultiViewCrossAttentionTransformer`
8
+ - `MultiViewCrossAttentionTransformerIFR`
9
+ - `MultiViewGlobalAttentionTransformer`
10
+ - `MultiViewGlobalAttentionTransformerIFR`
11
+ - `MultiViewAlternatingAttentionTransformer`
12
+ - `MultiViewAlternatingAttentionTransformerIFR`
13
+
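For orientation, a minimal import sketch for one of the architectures listed above; the constructor arguments follow the ones the DUSt3R factory passes earlier in this commit.

```python
# Minimal sketch: instantiate one of the supported info-sharing blocks.
from uniception.models.info_sharing.cross_attention_transformer import MultiViewCrossAttentionTransformer

info_sharing = MultiViewCrossAttentionTransformer(
    name="base_info_sharing",
    input_embed_dim=768,
    num_views=2,
)
```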
14
+ ## Developer Guidelines
15
+
16
+ Please follow the main UniCeption developer guidelines described in `README.md` when contributing to the information sharing blocks. Make sure to test your implementations and add the necessary unit tests.
17
+
18
+ ## Happy Coding!
UniCeption/uniception/models/info_sharing/__init__.py ADDED
@@ -0,0 +1,35 @@
1
+ from uniception.models.info_sharing.alternating_attention_transformer import (
2
+ MultiViewAlternatingAttentionTransformer,
3
+ MultiViewAlternatingAttentionTransformerIFR,
4
+ )
5
+ from uniception.models.info_sharing.cross_attention_transformer import (
6
+ MultiViewCrossAttentionTransformer,
7
+ MultiViewCrossAttentionTransformerIFR,
8
+ MultiViewTransformerInput,
9
+ )
10
+ from uniception.models.info_sharing.diff_cross_attention_transformer import (
11
+ DifferentialMultiViewCrossAttentionTransformer,
12
+ DifferentialMultiViewCrossAttentionTransformerIFR,
13
+ )
14
+ from uniception.models.info_sharing.global_attention_transformer import (
15
+ MultiViewGlobalAttentionTransformer,
16
+ MultiViewGlobalAttentionTransformerIFR,
17
+ )
18
+
19
+ INFO_SHARING_CLASSES = {
20
+ "cross_attention": (MultiViewCrossAttentionTransformer, MultiViewCrossAttentionTransformerIFR),
21
+ "diff_cross_attention": (
22
+ DifferentialMultiViewCrossAttentionTransformer,
23
+ DifferentialMultiViewCrossAttentionTransformerIFR,
24
+ ),
25
+ "alternating_attention": (
26
+ MultiViewAlternatingAttentionTransformer,
27
+ MultiViewAlternatingAttentionTransformerIFR,
28
+ ),
29
+ "global_attention": (
30
+ MultiViewGlobalAttentionTransformer,
31
+ MultiViewGlobalAttentionTransformerIFR,
32
+ ),
33
+ }
34
+
35
+ __all__ = ["INFO_SHARING_CLASSES", "MultiViewTransformerInput"]
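A small sketch of how the `INFO_SHARING_CLASSES` registry above can be used to pick an architecture by key; the keyword arguments are the ones the DUSt3R factory uses earlier in this commit and are assumed to apply to the plain variant here.

```python
# Sketch: each registry entry maps a key to (plain, intermediate-feature-returning) classes.
from uniception.models.info_sharing import INFO_SHARING_CLASSES

plain_cls, ifr_cls = INFO_SHARING_CLASSES["cross_attention"]
info_sharing = plain_cls(name="base_info_sharing", input_embed_dim=768, num_views=2)
```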
UniCeption/uniception/models/info_sharing/alternating_attention_transformer.py ADDED
@@ -0,0 +1,944 @@
1
+ """
2
+ UniCeption Alternating-Attention Transformer for Information Sharing
3
+ """
4
+
5
+ from functools import partial
6
+ from typing import Callable, List, Optional, Tuple, Type, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from uniception.models.info_sharing.base import (
13
+ MultiViewTransformerInput,
14
+ MultiViewTransformerOutput,
15
+ UniCeptionInfoSharingBase,
16
+ )
17
+ from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner, feature_take_indices
18
+ from uniception.models.utils.positional_encoding import PositionGetter
19
+ from uniception.models.utils.transformer_blocks import Mlp, SelfAttentionBlock
20
+
21
+
22
+ class MultiViewAlternatingAttentionTransformer(UniCeptionInfoSharingBase):
23
+ "UniCeption Multi-View Alternating-Attention Transformer for information sharing across image features from different views."
24
+
25
+ def __init__(
26
+ self,
27
+ name: str,
28
+ input_embed_dim: int,
29
+ use_pe_for_non_reference_views: bool = False,
30
+ max_num_views_for_pe: int = 1000,
31
+ use_rand_idx_pe_for_non_reference_views: bool = True,
32
+ size: Optional[str] = None,
33
+ depth: int = 12,
34
+ dim: int = 768,
35
+ num_heads: int = 12,
36
+ mlp_ratio: float = 4.0,
37
+ qkv_bias: bool = True,
38
+ qk_norm: bool = False,
39
+ proj_drop: float = 0.0,
40
+ attn_drop: float = 0.0,
41
+ init_values: Optional[float] = None,
42
+ drop_path: float = 0.0,
43
+ act_layer: Type[nn.Module] = nn.GELU,
44
+ norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial(nn.LayerNorm, eps=1e-6),
45
+ mlp_layer: Type[nn.Module] = Mlp,
46
+ custom_positional_encoding: Optional[Callable] = None,
47
+ pretrained_checkpoint_path: Optional[str] = None,
48
+ gradient_checkpointing: bool = False,
49
+ *args,
50
+ **kwargs,
51
+ ):
52
+ """
53
+ Initialize the Multi-View Alternating-Attention Transformer for information sharing across image features from different views.
54
+ Alternates between global and frame-level attention.
55
+
56
+ Args:
57
+ input_embed_dim (int): Dimension of input embeddings.
58
+ use_pe_for_non_reference_views (bool): Whether to use view positional encoding for input non-reference views. (default: False)
59
+ max_num_views_for_pe (int): Maximum number of views for positional encoding. (default: 1000)
60
+ use_rand_idx_pe_for_non_reference_views (bool): Whether to use random index positional encoding for non-reference views. (default: True)
61
+ size (str): String to indicate interpretable size of the transformer (for e.g., base, large, ...). (default: None)
62
+ depth (int): Number of transformer layers. (default: 12, base size)
63
+ dim (int): Dimension of the transformer. (default: 768, base size)
64
+ num_heads (int): Number of attention heads. (default: 12, base size)
65
+ mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.)
66
+ qkv_bias (bool): Whether to include bias in qkv projection (default: True)
67
+ qk_norm (bool): Whether to normalize q and k (default: False)
68
+ proj_drop (float): Dropout rate for output (default: 0.)
69
+ attn_drop (float): Dropout rate for attention weights (default: 0.)
70
+ init_values (float): Initial value for LayerScale gamma (default: None)
71
+ drop_path (float): Dropout rate for stochastic depth (default: 0.)
72
+ act_layer (nn.Module): Activation layer (default: nn.GELU)
73
+ norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm)
74
+ mlp_layer (nn.Module): MLP layer (default: Mlp)
75
+ custom_positional_encoding (Callable): Custom positional encoding function (default: None)
76
+ pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None)
77
+ gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False)
78
+ """
79
+ # Initialize the base class
80
+ super().__init__(name=name, size=size, *args, **kwargs)
81
+
82
+ # Initialize the specific attributes of the transformer
83
+ self.input_embed_dim = input_embed_dim
84
+ self.use_pe_for_non_reference_views = use_pe_for_non_reference_views
85
+ self.max_num_views_for_pe = max_num_views_for_pe
86
+ self.use_rand_idx_pe_for_non_reference_views = use_rand_idx_pe_for_non_reference_views
87
+ self.depth = depth
88
+ self.dim = dim
89
+ self.num_heads = num_heads
90
+ self.mlp_ratio = mlp_ratio
91
+ self.qkv_bias = qkv_bias
92
+ self.qk_norm = qk_norm
93
+ self.proj_drop = proj_drop
94
+ self.attn_drop = attn_drop
95
+ self.init_values = init_values
96
+ self.drop_path = drop_path
97
+ self.act_layer = act_layer
98
+ self.norm_layer = norm_layer
99
+ self.mlp_layer = mlp_layer
100
+ self.custom_positional_encoding = custom_positional_encoding
101
+ self.pretrained_checkpoint_path = pretrained_checkpoint_path
102
+ self.gradient_checkpointing = gradient_checkpointing
103
+
104
+ # Initialize the projection layer for input embeddings
105
+ if self.input_embed_dim != self.dim:
106
+ self.proj_embed = nn.Linear(self.input_embed_dim, self.dim, bias=True)
107
+ else:
108
+ self.proj_embed = nn.Identity()
109
+
110
+ # Initialize the self-attention blocks which ingest all views at once
111
+ self.self_attention_blocks = nn.ModuleList(
112
+ [
113
+ SelfAttentionBlock(
114
+ dim=self.dim,
115
+ num_heads=self.num_heads,
116
+ mlp_ratio=self.mlp_ratio,
117
+ qkv_bias=self.qkv_bias,
118
+ qk_norm=self.qk_norm,
119
+ proj_drop=self.proj_drop,
120
+ attn_drop=self.attn_drop,
121
+ init_values=self.init_values,
122
+ drop_path=self.drop_path,
123
+ act_layer=self.act_layer,
124
+ norm_layer=self.norm_layer,
125
+ mlp_layer=self.mlp_layer,
126
+ custom_positional_encoding=self.custom_positional_encoding,
127
+ )
128
+ for _ in range(self.depth)
129
+ ]
130
+ )
131
+
132
+ # Initialize the final normalization layer
133
+ self.norm = self.norm_layer(self.dim)
134
+
135
+ # Initialize the position getter for patch positions if required
136
+ if self.custom_positional_encoding is not None:
137
+ self.position_getter = PositionGetter()
138
+
139
+ if self.use_pe_for_non_reference_views:
140
+ # Initialize the positional encoding table for the different views
141
+ self.register_buffer(
142
+ "view_pos_table",
143
+ self._get_sinusoid_encoding_table(self.max_num_views_for_pe, self.dim, 10000),
144
+ )
145
+ else:
146
+ # Initialize the positional encoding table for the reference view
147
+ self.register_buffer(
148
+ "view_pos_table",
149
+ self._get_sinusoid_encoding_table(1, self.dim, 10000),
150
+ )
151
+
152
+ # Initialize random weights
153
+ self.initialize_weights()
154
+
155
+ # Apply gradient checkpointing if enabled
156
+ if self.gradient_checkpointing:
157
+ for i, block in enumerate(self.self_attention_blocks):
158
+ self.self_attention_blocks[i] = self.wrap_module_with_gradient_checkpointing(block)
159
+
160
+ # Load pretrained weights if provided
161
+ if self.pretrained_checkpoint_path is not None:
162
+ print(
163
+ f"Loading pretrained multi-view Alternating-Attention transformer weights from {self.pretrained_checkpoint_path} ..."
164
+ )
165
+ ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
166
+ print(self.load_state_dict(ckpt["model"]))
167
+
168
+ def _get_sinusoid_encoding_table(self, n_position, d_hid, base):
169
+ "Sinusoid position encoding table"
170
+
171
+ def get_position_angle_vec(position):
172
+ return [position / np.power(base, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
173
+
174
+ sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
175
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])
176
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])
177
+
178
+ return torch.FloatTensor(sinusoid_table)
179
+
180
+ def initialize_weights(self):
181
+ "Initialize weights of the transformer."
182
+ # Linears and layer norms
183
+ self.apply(self._init_weights)
184
+
185
+ def _init_weights(self, m):
186
+ "Initialize the transformer linear and layer norm weights."
187
+ if isinstance(m, nn.Linear):
188
+ # We use xavier_uniform following official JAX ViT:
189
+ torch.nn.init.xavier_uniform_(m.weight)
190
+ if isinstance(m, nn.Linear) and m.bias is not None:
191
+ nn.init.constant_(m.bias, 0)
192
+ elif isinstance(m, nn.LayerNorm):
193
+ nn.init.constant_(m.bias, 0)
194
+ nn.init.constant_(m.weight, 1.0)
195
+
196
+ def forward(
197
+ self,
198
+ model_input: MultiViewTransformerInput,
199
+ ) -> MultiViewTransformerOutput:
200
+ """
201
+ Forward interface for the Multi-View Alternating-Attention Transformer.
202
+
203
+ Args:
204
+ model_input (MultiViewTransformerInput): Input to the model.
205
+ Expects the features to be a list of tensors, each of shape (batch, input_embed_dim, height, width),
206
+ where each entry corresponds to a different view.
207
+ Optionally, the input can also include additional_input_tokens (e.g., class token, registers, pose tokens, scale token)
208
+ which are appended to the token set from the multi-view features. The tokens are of size (batch, input_embed_dim, num_of_additional_tokens).
209
+
210
+ Returns:
211
+ MultiViewTransformerOutput: Output of the model post information sharing.
212
+ """
213
+ # Check that the number of views matches the input and the features are of expected shape
214
+ if self.use_pe_for_non_reference_views:
215
+ assert (
216
+ len(model_input.features) <= self.max_num_views_for_pe
217
+ ), f"Expected less than {self.max_num_views_for_pe} views, got {len(model_input.features)}"
218
+ assert all(
219
+ view_features.shape[1] == self.input_embed_dim for view_features in model_input.features
220
+ ), f"All views must have input dimension {self.input_embed_dim}"
221
+ assert all(
222
+ view_features.ndim == 4 for view_features in model_input.features
223
+ ), "All views must have 4 dimensions (N, C, H, W)"
224
+
225
+ # Initialize the multi-view features from the model input and number of views for current input
226
+ multi_view_features = model_input.features
227
+ num_of_views = len(multi_view_features)
228
+ batch_size, _, height, width = multi_view_features[0].shape
229
+ num_of_tokens_per_view = height * width
230
+
231
+ # Stack the multi-view features (N, C, H, W) to (N, V, C, H, W) (assumes all V views have same shape)
232
+ multi_view_features = torch.stack(multi_view_features, dim=1)
233
+
234
+ # Resize the multi-view features from NVCHW to NLC, where L = V * H * W
235
+ multi_view_features = multi_view_features.permute(0, 1, 3, 4, 2) # (N, V, H, W, C)
236
+ multi_view_features = multi_view_features.reshape(
237
+ batch_size, num_of_views * height * width, self.input_embed_dim
238
+ ).contiguous()
239
+
240
+ # Process additional input tokens if provided
241
+ if model_input.additional_input_tokens is not None:
242
+
243
+ additional_tokens = model_input.additional_input_tokens
244
+ assert additional_tokens.ndim == 3, "Additional tokens must have 3 dimensions (N, C, T)"
245
+ assert (
246
+ additional_tokens.shape[1] == self.input_embed_dim
247
+ ), f"Additional tokens must have input dimension {self.input_embed_dim}"
248
+ assert additional_tokens.shape[0] == batch_size, "Batch size mismatch for additional tokens"
249
+
250
+ # Reshape to channel-last format for transformer processing
251
+ additional_tokens = additional_tokens.permute(0, 2, 1).contiguous() # (N, C, T) -> (N, T, C)
252
+
253
+ # Concatenate the additional tokens to the multi-view features
254
+ multi_view_features = torch.cat([multi_view_features, additional_tokens], dim=1)
255
+
256
+ # Project input features to the transformer dimension
257
+ multi_view_features = self.proj_embed(multi_view_features)
258
+
259
+ # Create patch positions for each view if custom positional encoding is used
260
+ if self.custom_positional_encoding is not None:
261
+ multi_view_positions = [
262
+ self.position_getter(batch_size, height, width, multi_view_features.device)
263
+ ] * num_of_views # List of length V, where each tensor is (N, H * W, C)
264
+ multi_view_positions = torch.cat(multi_view_positions, dim=1) # (N, V * H * W, C)
265
+ else:
266
+ multi_view_positions = [None] * num_of_views
267
+
268
+ # Add None positions for additional tokens if they exist
269
+ if model_input.additional_input_tokens is not None:
270
+
271
+ additional_tokens_positions = [None] * model_input.additional_input_tokens.shape[1]
272
+ multi_view_positions = multi_view_positions + additional_tokens_positions
273
+
274
+ # Add positional encoding for reference view (idx 0)
275
+ ref_view_pe = self.view_pos_table[0].clone().detach()
276
+ ref_view_pe = ref_view_pe.reshape((1, 1, self.dim))
277
+ ref_view_pe = ref_view_pe.repeat(batch_size, num_of_tokens_per_view, 1)
278
+ ref_view_features = multi_view_features[:, :num_of_tokens_per_view, :]
279
+ ref_view_features = ref_view_features + ref_view_pe
280
+
281
+ if self.use_pe_for_non_reference_views:
282
+ # Add positional encoding for non-reference views (sequential indices starting from idx 1 or random indices which are uniformly sampled)
283
+ if self.use_rand_idx_pe_for_non_reference_views:
284
+ non_ref_view_pe_indices = torch.randint(low=1, high=self.max_num_views_for_pe, size=(num_of_views - 1,))
285
+ else:
286
+ non_ref_view_pe_indices = torch.arange(1, num_of_views)
287
+ non_ref_view_pe = self.view_pos_table[non_ref_view_pe_indices].clone().detach()
288
+ non_ref_view_pe = non_ref_view_pe.reshape((1, num_of_views - 1, self.dim))
289
+ non_ref_view_pe = non_ref_view_pe.repeat_interleave(num_of_tokens_per_view, dim=1)
290
+ non_ref_view_pe = non_ref_view_pe.repeat(batch_size, 1, 1)
291
+ non_ref_view_features = multi_view_features[
292
+ :, num_of_tokens_per_view : num_of_views * num_of_tokens_per_view, :
293
+ ]
294
+ non_ref_view_features = non_ref_view_features + non_ref_view_pe
295
+ else:
296
+ non_ref_view_features = multi_view_features[
297
+ :, num_of_tokens_per_view : num_of_views * num_of_tokens_per_view, :
298
+ ]
299
+
300
+ # Concatenate the reference and non-reference view features
301
+ # Handle additional tokens (no view-based positional encoding for them)
302
+ if model_input.additional_input_tokens is not None:
303
+
304
+ additional_features = multi_view_features[:, num_of_views * num_of_tokens_per_view :, :]
305
+ multi_view_features = torch.cat([ref_view_features, non_ref_view_features, additional_features], dim=1)
306
+ else:
307
+ multi_view_features = torch.cat([ref_view_features, non_ref_view_features], dim=1)
308
+
309
+ # Loop over the depth of the transformer
310
+ for depth_idx in range(self.depth):
311
+ if depth_idx % 2 == 0:
312
+ # Apply the self-attention block and update the multi-view features
313
+ # Global attention across all views
314
+ multi_view_features = self.self_attention_blocks[depth_idx](multi_view_features, multi_view_positions)
315
+ else:
316
+ # Handle additional tokens separately for frame-level attention
317
+ additional_features = None
318
+ additional_positions = None
319
+ if model_input.additional_input_tokens is not None:
320
+
321
+ # Extract additional token features
322
+ additional_features = multi_view_features[:, num_of_views * num_of_tokens_per_view :, :]
323
+ # Keep only view features for frame-level attention
324
+ multi_view_features = multi_view_features[:, : num_of_views * num_of_tokens_per_view, :]
325
+
326
+ # Handle positions for additional tokens if custom positional encoding is used
327
+ if self.custom_positional_encoding is not None:
328
+ additional_positions = multi_view_positions[:, num_of_views * num_of_tokens_per_view :, :]
329
+ multi_view_positions = multi_view_positions[:, : num_of_views * num_of_tokens_per_view, :]
330
+
331
+ # Reshape the multi-view features from (N, V * H * W, C) to (N * V, H * W, C)
332
+ multi_view_features = multi_view_features.reshape(
333
+ batch_size * num_of_views, num_of_tokens_per_view, self.dim
334
+ ).contiguous() # (N * V, H * W, C)
335
+ if multi_view_positions[0] is not None:
336
+ multi_view_positions = multi_view_positions.reshape(
337
+ batch_size * num_of_views, num_of_tokens_per_view, 2
338
+ ).contiguous() # (N * V, H * W, C)
339
+
340
+ # Apply the self-attention block and update the multi-view features
341
+ # Frame-level attention within each view
342
+ multi_view_features = self.self_attention_blocks[depth_idx](multi_view_features, multi_view_positions)
343
+
344
+ # Reshape the multi-view features from (N * V, H * W, C) back to (N, V * H * W, C)
345
+ multi_view_features = multi_view_features.reshape(
346
+ batch_size, num_of_views * num_of_tokens_per_view, self.dim
347
+ ).contiguous() # (N, V * H * W, C)
348
+ if multi_view_positions[0] is not None:
349
+ multi_view_positions = multi_view_positions.reshape(
350
+ batch_size, num_of_views * num_of_tokens_per_view, 2
351
+ ).contiguous() # (N, V * H * W, C)
352
+
353
+ # Reattach additional tokens if they exist
354
+ if additional_features is not None:
355
+ multi_view_features = torch.cat([multi_view_features, additional_features], dim=1)
356
+ # Reattach positions for additional tokens if they exist
357
+ if additional_positions is not None:
358
+ multi_view_positions = torch.cat([multi_view_positions, additional_positions], dim=1)
359
+
360
+ # Normalize the output features
361
+ output_multi_view_features = self.norm(multi_view_features)
362
+
363
+ # Extract only the view features (excluding additional tokens)
364
+ view_features = output_multi_view_features[:, : num_of_views * num_of_tokens_per_view, :]
365
+
366
+ # Reshape the output multi-view features (N, V * H * W, C) back to (N, V, C, H, W)
367
+ view_features = view_features.reshape(batch_size, num_of_views, height, width, self.dim) # (N, V, H, W, C)
368
+ view_features = view_features.permute(0, 1, 4, 2, 3).contiguous() # (N, V, C, H, W)
369
+
370
+ # Split the output multi-view features into separate views
371
+ view_features = view_features.split(1, dim=1)
372
+ view_features = [output_view_features.squeeze(dim=1) for output_view_features in view_features]
373
+
374
+ # Extract and return additional token features if provided
375
+ if model_input.additional_input_tokens is not None:
376
+
377
+ additional_token_features = output_multi_view_features[:, num_of_views * num_of_tokens_per_view :, :]
378
+ additional_token_features = additional_token_features.permute(0, 2, 1).contiguous() # (N, C, T)
379
+ return MultiViewTransformerOutput(
380
+ features=view_features, additional_token_features=additional_token_features
381
+ )
382
+ else:
383
+ return MultiViewTransformerOutput(features=view_features)
384
+
385
+
386
+ class MultiViewAlternatingAttentionTransformerIFR(
387
+ MultiViewAlternatingAttentionTransformer, IntermediateFeatureReturner
388
+ ):
389
+ "Intermediate Feature Returner for UniCeption Multi-View Alternating-Attention Transformer"
390
+
391
+ def __init__(
392
+ self,
393
+ name: str,
394
+ input_embed_dim: int,
395
+ use_pe_for_non_reference_views: bool = False,
396
+ max_num_views_for_pe: int = 1000,
397
+ use_rand_idx_pe_for_non_reference_views: bool = True,
398
+ size: Optional[str] = None,
399
+ depth: int = 12,
400
+ dim: int = 768,
401
+ num_heads: int = 12,
402
+ mlp_ratio: float = 4.0,
403
+ qkv_bias: bool = True,
404
+ qk_norm: bool = False,
405
+ proj_drop: float = 0.0,
406
+ attn_drop: float = 0.0,
407
+ init_values: Optional[float] = None,
408
+ drop_path: float = 0.0,
409
+ act_layer: nn.Module = nn.GELU,
410
+ norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
411
+ mlp_layer: nn.Module = Mlp,
412
+ custom_positional_encoding: Callable = None,
413
+ pretrained_checkpoint_path: str = None,
414
+ indices: Optional[Union[int, List[int]]] = None,
415
+ norm_intermediate: bool = True,
416
+ intermediates_only: bool = False,
417
+ gradient_checkpointing: bool = False,
418
+ *args,
419
+ **kwargs,
420
+ ):
421
+ """
422
+ Initialize the Multi-View Alternating-Attention Transformer for information sharing across image features from different views.
423
+ Extends the base class to return intermediate features.
424
+
425
+ Args:
426
+ input_embed_dim (int): Dimension of input embeddings.
427
+ use_pe_for_non_reference_views (bool): Whether to use view positional encoding for input non-reference views. (default: False)
428
+ max_num_views_for_pe (int): Maximum number of views for positional encoding. (default: 1000)
429
+ use_rand_idx_pe_for_non_reference_views (bool): Whether to use random index positional encoding for non-reference views. (default: True)
431
+ size (str): String to indicate interpretable size of the transformer (for e.g., base, large, ...). (default: None)
432
+ depth (int): Number of transformer layers. (default: 12, base size)
433
+ dim (int): Dimension of the transformer. (default: 768, base size)
434
+ num_heads (int): Number of attention heads. (default: 12, base size)
435
+ mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.)
436
+ qkv_bias (bool): Whether to include bias in qkv projection (default: True)
437
+ qk_norm (bool): Whether to normalize q and k (default: False)
438
+ proj_drop (float): Dropout rate for output (default: 0.)
439
+ attn_drop (float): Dropout rate for attention weights (default: 0.)
440
+ init_values (float): Initial value for LayerScale gamma (default: None)
441
+ drop_path (float): Dropout rate for stochastic depth (default: 0.)
442
+ act_layer (nn.Module): Activation layer (default: nn.GELU)
443
+ norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm)
444
+ mlp_layer (nn.Module): MLP layer (default: Mlp)
445
+ custom_positional_encoding (Callable): Custom positional encoding function (default: None)
446
+ pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None)
447
+ indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. (default: None) Options:
448
+ - None: Return all intermediate layers.
449
+ - int: Return the last n layers.
450
+ - List[int]: Return the intermediate layers at the specified indices.
451
+ norm_intermediate (bool, optional): Whether to normalize the intermediate features. (default: True)
452
+ intermediates_only (bool, optional): Whether to return only the intermediate features. (default: False)
453
+ gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False)
454
+ """
455
+ # Init the base classes
456
+ MultiViewAlternatingAttentionTransformer.__init__(
457
+ self,
458
+ name=name,
459
+ input_embed_dim=input_embed_dim,
460
+ use_pe_for_non_reference_views=use_pe_for_non_reference_views,
461
+ max_num_views_for_pe=max_num_views_for_pe,
462
+ use_rand_idx_pe_for_non_reference_views=use_rand_idx_pe_for_non_reference_views,
463
+ size=size,
464
+ depth=depth,
465
+ dim=dim,
466
+ num_heads=num_heads,
467
+ mlp_ratio=mlp_ratio,
468
+ qkv_bias=qkv_bias,
469
+ qk_norm=qk_norm,
470
+ proj_drop=proj_drop,
471
+ attn_drop=attn_drop,
472
+ init_values=init_values,
473
+ drop_path=drop_path,
474
+ act_layer=act_layer,
475
+ norm_layer=norm_layer,
476
+ mlp_layer=mlp_layer,
477
+ custom_positional_encoding=custom_positional_encoding,
478
+ pretrained_checkpoint_path=pretrained_checkpoint_path,
479
+ gradient_checkpointing=gradient_checkpointing,
480
+ *args,
481
+ **kwargs,
482
+ )
483
+ IntermediateFeatureReturner.__init__(
484
+ self,
485
+ indices=indices,
486
+ norm_intermediate=norm_intermediate,
487
+ intermediates_only=intermediates_only,
488
+ )
489
+
490
+ def forward(
491
+ self,
492
+ model_input: MultiViewTransformerInput,
493
+ ) -> Union[
494
+ List[MultiViewTransformerOutput],
495
+ Tuple[MultiViewTransformerOutput, List[MultiViewTransformerOutput]],
496
+ ]:
497
+ """
498
+ Forward interface for the Multi-View Alternating-Attention Transformer with Intermediate Feature Return.
499
+
500
+ Args:
501
+ model_input (MultiViewTransformerInput): Input to the model.
502
+ Expects the features to be a list of tensors, each of shape (batch, input_embed_dim, height, width),
503
+ where each entry corresponds to a different view.
504
+ Optionally, the input can also include additional_input_tokens (e.g., class token, registers, pose tokens, scale token)
505
+ which are appended to the token set from the multi-view features. The tokens are of size (batch, input_embed_dim, num_of_additional_tokens).
506
+
507
+ Returns:
508
+ Union[List[MultiViewTransformerOutput], Tuple[MultiViewTransformerOutput, List[MultiViewTransformerOutput]]]:
509
+ Output of the model post information sharing.
510
+ If intermediates_only is True, returns a list of intermediate outputs.
511
+ If intermediates_only is False, returns a tuple of final output and a list of intermediate outputs.
512
+ """
513
+ # Check that the number of views matches the input and the features are of expected shape
514
+ if self.use_pe_for_non_reference_views:
515
+ assert (
516
+ len(model_input.features) <= self.max_num_views_for_pe
517
+ ), f"Expected less than {self.max_num_views_for_pe} views, got {len(model_input.features)}"
518
+ assert all(
519
+ view_features.shape[1] == self.input_embed_dim for view_features in model_input.features
520
+ ), f"All views must have input dimension {self.input_embed_dim}"
521
+ assert all(
522
+ view_features.ndim == 4 for view_features in model_input.features
523
+ ), "All views must have 4 dimensions (N, C, H, W)"
524
+
525
+ # Get the indices of the intermediate features to return
526
+ intermediate_multi_view_features = []
527
+ take_indices, _ = feature_take_indices(self.depth, self.indices)
528
+
529
+ # Initialize the multi-view features from the model input and number of views for current input
530
+ multi_view_features = model_input.features
531
+ num_of_views = len(multi_view_features)
532
+ batch_size, _, height, width = multi_view_features[0].shape
533
+ num_of_tokens_per_view = height * width
534
+
535
+ # Stack the multi-view features (N, C, H, W) to (N, V, C, H, W) (assumes all V views have same shape)
536
+ multi_view_features = torch.stack(multi_view_features, dim=1)
537
+
538
+ # Resize the multi-view features from NVCHW to NLC, where L = V * H * W
539
+ multi_view_features = multi_view_features.permute(0, 1, 3, 4, 2) # (N, V, H, W, C)
540
+ multi_view_features = multi_view_features.reshape(
541
+ batch_size, num_of_views * height * width, self.input_embed_dim
542
+ ).contiguous()
543
+
544
+ # Process additional input tokens if provided
545
+ if model_input.additional_input_tokens is not None:
546
+
547
+ additional_tokens = model_input.additional_input_tokens
548
+ assert additional_tokens.ndim == 3, "Additional tokens must have 3 dimensions (N, C, T)"
549
+ assert (
550
+ additional_tokens.shape[1] == self.input_embed_dim
551
+ ), f"Additional tokens must have input dimension {self.input_embed_dim}"
552
+ assert additional_tokens.shape[0] == batch_size, "Batch size mismatch for additional tokens"
553
+
554
+ # Reshape to channel-last format for transformer processing
555
+ additional_tokens = additional_tokens.permute(0, 2, 1).contiguous() # (N, C, T) -> (N, T, C)
556
+
557
+ # Concatenate the additional tokens to the multi-view features
558
+ multi_view_features = torch.cat([multi_view_features, additional_tokens], dim=1)
559
+
560
+ # Project input features to the transformer dimension
561
+ multi_view_features = self.proj_embed(multi_view_features)
562
+
563
+ # Create patch positions for each view if custom positional encoding is used
564
+ if self.custom_positional_encoding is not None:
565
+ multi_view_positions = [
566
+ self.position_getter(batch_size, height, width, multi_view_features.device)
567
+ ] * num_of_views # List of length V, where each tensor is (N, H * W, C)
568
+ multi_view_positions = torch.cat(multi_view_positions, dim=1) # (N, V * H * W, C)
569
+ else:
570
+ multi_view_positions = [None] * num_of_views
571
+
572
+ # Add None positions for additional tokens if they exist
573
+ if model_input.additional_input_tokens is not None:
574
+
575
+ additional_tokens_positions = [None] * model_input.additional_input_tokens.shape[1]
576
+ multi_view_positions = multi_view_positions + additional_tokens_positions
577
+
578
+ # Add positional encoding for reference view (idx 0)
579
+ ref_view_pe = self.view_pos_table[0].clone().detach()
580
+ ref_view_pe = ref_view_pe.reshape((1, 1, self.dim))
581
+ ref_view_pe = ref_view_pe.repeat(batch_size, num_of_tokens_per_view, 1)
582
+ ref_view_features = multi_view_features[:, :num_of_tokens_per_view, :]
583
+ ref_view_features = ref_view_features + ref_view_pe
584
+
585
+ if self.use_pe_for_non_reference_views:
586
+ # Add positional encoding for non-reference views (sequential indices starting from idx 1 or random indices which are uniformly sampled)
587
+ if self.use_rand_idx_pe_for_non_reference_views:
588
+ non_ref_view_pe_indices = torch.randint(low=1, high=self.max_num_views_for_pe, size=(num_of_views - 1,))
589
+ else:
590
+ non_ref_view_pe_indices = torch.arange(1, num_of_views)
591
+ non_ref_view_pe = self.view_pos_table[non_ref_view_pe_indices].clone().detach()
592
+ non_ref_view_pe = non_ref_view_pe.reshape((1, num_of_views - 1, self.dim))
593
+ non_ref_view_pe = non_ref_view_pe.repeat_interleave(num_of_tokens_per_view, dim=1)
594
+ non_ref_view_pe = non_ref_view_pe.repeat(batch_size, 1, 1)
595
+ non_ref_view_features = multi_view_features[
596
+ :, num_of_tokens_per_view : num_of_views * num_of_tokens_per_view, :
597
+ ]
598
+ non_ref_view_features = non_ref_view_features + non_ref_view_pe
599
+ else:
600
+ non_ref_view_features = multi_view_features[
601
+ :, num_of_tokens_per_view : num_of_views * num_of_tokens_per_view, :
602
+ ]
603
+
604
+ # Concatenate the reference and non-reference view features
605
+ # Handle additional tokens (no view-based positional encoding for them)
606
+ if model_input.additional_input_tokens is not None:
607
+
608
+ additional_features = multi_view_features[:, num_of_views * num_of_tokens_per_view :, :]
609
+ multi_view_features = torch.cat([ref_view_features, non_ref_view_features, additional_features], dim=1)
610
+ else:
611
+ multi_view_features = torch.cat([ref_view_features, non_ref_view_features], dim=1)
612
+
613
+ # Loop over the depth of the transformer
614
+ for depth_idx in range(self.depth):
615
+ if depth_idx % 2 == 0:
616
+ # Apply the self-attention block and update the multi-view features
617
+ # Global attention across all views
618
+ multi_view_features = self.self_attention_blocks[depth_idx](multi_view_features, multi_view_positions)
619
+ else:
620
+ # Handle additional tokens separately for frame-level attention
621
+ additional_features = None
622
+ additional_positions = None
623
+ if model_input.additional_input_tokens is not None:
624
+
625
+ # Extract additional token features
626
+ additional_features = multi_view_features[:, num_of_views * num_of_tokens_per_view :, :]
627
+ # Keep only view features for frame-level attention
628
+ multi_view_features = multi_view_features[:, : num_of_views * num_of_tokens_per_view, :]
629
+
630
+ # Handle positions for additional tokens if custom positional encoding is used
631
+ if self.custom_positional_encoding is not None:
632
+ additional_positions = multi_view_positions[:, num_of_views * num_of_tokens_per_view :, :]
633
+ multi_view_positions = multi_view_positions[:, : num_of_views * num_of_tokens_per_view, :]
634
+
635
+ # Reshape the multi-view features from (N, V * H * W, C) to (N * V, H * W, C)
636
+ multi_view_features = multi_view_features.reshape(
637
+ batch_size * num_of_views, num_of_tokens_per_view, self.dim
638
+ ).contiguous() # (N * V, H * W, C)
639
+ if multi_view_positions[0] is not None:
640
+ multi_view_positions = multi_view_positions.reshape(
641
+ batch_size * num_of_views, num_of_tokens_per_view, 2
642
+ ).contiguous() # (N * V, H * W, C)
643
+
644
+ # Apply the self-attention block and update the multi-view features
645
+ # Frame-level attention within each view
646
+ multi_view_features = self.self_attention_blocks[depth_idx](multi_view_features, multi_view_positions)
647
+
648
+ # Reshape the multi-view features from (N * V, H * W, C) back to (N, V * H * W, C)
649
+ multi_view_features = multi_view_features.reshape(
650
+ batch_size, num_of_views * num_of_tokens_per_view, self.dim
651
+ ).contiguous() # (N, V * H * W, C)
652
+ if multi_view_positions[0] is not None:
653
+ multi_view_positions = multi_view_positions.reshape(
654
+ batch_size, num_of_views * num_of_tokens_per_view, 2
655
+ ).contiguous() # (N, V * H * W, C)
656
+
657
+ # Reattach additional tokens if they exist
658
+ if additional_features is not None:
659
+ multi_view_features = torch.cat([multi_view_features, additional_features], dim=1)
660
+ # Reattach positions for additional tokens if they exist
661
+ if additional_positions is not None:
662
+ multi_view_positions = torch.cat([multi_view_positions, additional_positions], dim=1)
663
+ if depth_idx in take_indices:
664
+ # Normalize the intermediate features with final norm layer if enabled
665
+ intermediate_multi_view_features.append(
666
+ self.norm(multi_view_features) if self.norm_intermediate else multi_view_features
667
+ )
668
+
669
+ # Reshape the intermediate features and convert to MultiViewTransformerOutput class
670
+ for idx in range(len(intermediate_multi_view_features)):
671
+ # Get the current intermediate features
672
+ current_features = intermediate_multi_view_features[idx]
673
+
674
+ # Extract additional token features if provided
675
+ additional_token_features = None
676
+ if model_input.additional_input_tokens is not None:
677
+
678
+ additional_token_features = current_features[:, num_of_views * num_of_tokens_per_view :, :]
679
+ additional_token_features = additional_token_features.permute(0, 2, 1).contiguous() # (N, C, T)
680
+ # Only keep the view features for reshaping
681
+ current_features = current_features[:, : num_of_views * num_of_tokens_per_view, :]
682
+
683
+ # Reshape the intermediate multi-view features (N, V * H * W, C) back to (N, V, C, H, W)
684
+ current_features = current_features.reshape(
685
+ batch_size, num_of_views, height, width, self.dim
686
+ ) # (N, V, H, W, C)
687
+ current_features = current_features.permute(0, 1, 4, 2, 3).contiguous() # (N, V, C, H, W)
688
+
689
+ # Split the intermediate multi-view features into separate views
690
+ current_features = current_features.split(1, dim=1)
691
+ current_features = [
692
+ intermediate_view_features.squeeze(dim=1) for intermediate_view_features in current_features
693
+ ]
694
+
695
+ intermediate_multi_view_features[idx] = MultiViewTransformerOutput(
696
+ features=current_features, additional_token_features=additional_token_features
697
+ )
698
+
699
+ # Return only the intermediate features if enabled
700
+ if self.intermediates_only:
701
+ return intermediate_multi_view_features
702
+
703
+ # Normalize the output features
704
+ output_multi_view_features = self.norm(multi_view_features)
705
+
706
+ # Extract view features (excluding additional tokens)
707
+ additional_token_features = None
708
+ if model_input.additional_input_tokens is not None:
709
+
710
+ additional_token_features = output_multi_view_features[:, num_of_views * num_of_tokens_per_view :, :]
711
+ additional_token_features = additional_token_features.permute(0, 2, 1).contiguous() # (N, C, T)
712
+ view_features = output_multi_view_features[:, : num_of_views * num_of_tokens_per_view, :]
713
+ else:
714
+ view_features = output_multi_view_features
715
+
716
+ # Reshape the output multi-view features (N, V * H * W, C) back to (N, V, C, H, W)
717
+ view_features = view_features.reshape(batch_size, num_of_views, height, width, self.dim) # (N, V, H, W, C)
718
+ view_features = view_features.permute(0, 1, 4, 2, 3).contiguous() # (N, V, C, H, W)
719
+
720
+ # Split the output multi-view features into separate views
721
+ view_features = view_features.split(1, dim=1)
722
+ view_features = [output_view_features.squeeze(dim=1) for output_view_features in view_features]
723
+
724
+ output_multi_view_features = MultiViewTransformerOutput(
725
+ features=view_features, additional_token_features=additional_token_features
726
+ )
727
+
728
+ return output_multi_view_features, intermediate_multi_view_features
729
+
730
+
731
+ def dummy_positional_encoding(x, xpos):
732
+ "Dummy function for positional encoding of tokens"
733
+ x = x
734
+ xpos = xpos
735
+ return x
736
+
737
+
738
+ def test_reshape_for_frame_attention():
739
+ "Test the reshape function for frame-level attention in the Alternating Attention Transformer"
740
+ batch_size = 2
741
+ num_of_views = 3
742
+ height = width = 2
743
+ dim = 4
744
+ num_of_tokens_per_view = height * width
745
+
746
+ # Create tensor with recognizable pattern
747
+ x = torch.zeros(batch_size, num_of_views * num_of_tokens_per_view, dim)
748
+ for b in range(batch_size):
749
+ for v in range(num_of_views):
750
+ for h in range(height):
751
+ for w in range(width):
752
+ token_idx = v * num_of_tokens_per_view + h * width + w
753
+ x[b, token_idx] = torch.tensor([b, v, h, w])
754
+
755
+ # Apply reshape
756
+ reshaped = x.reshape(batch_size * num_of_views, num_of_tokens_per_view, dim).contiguous()
757
+
758
+ # Verify shape
759
+ assert reshaped.shape == (batch_size * num_of_views, num_of_tokens_per_view, dim)
760
+
761
+ # Verify content (check a few values)
762
+ for b in range(batch_size):
763
+ for v in range(num_of_views):
764
+ for h in range(height):
765
+ for w in range(width):
766
+ batch_view_idx = b * num_of_views + v
767
+ token_idx = h * width + w
768
+ expected = torch.tensor([b, v, h, w])
769
+ assert torch.all(reshaped[batch_view_idx, token_idx] == expected)
770
+
771
+ # Verify reshape back works
772
+ back_to_original = reshaped.reshape(batch_size, num_of_views * num_of_tokens_per_view, dim)
773
+ assert torch.all(x == back_to_original)
774
+
775
+ print("Reshape test passed!")
776
+
777
+
778
+ if __name__ == "__main__":
779
+ # Unit test the reshape logic used for frame-level attention
780
+ test_reshape_for_frame_attention()
781
+
782
+ # Init multi-view alternating-attention transformer with no custom positional encoding and run a forward pass
783
+ for num_views in [2, 3, 4]:
784
+ print(f"Testing MultiViewAlternatingAttentionTransformer with {num_views} views ...")
785
+ # No positional encoding for non-reference views
786
+ model = MultiViewAlternatingAttentionTransformer(
787
+ name="MV-AAT",
788
+ input_embed_dim=1024,
789
+ )
790
+ model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
791
+ model_input = MultiViewTransformerInput(features=model_input)
792
+ model_output = model(model_input)
793
+ assert len(model_output.features) == num_views
794
+ assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features)
795
+ # Sequential idx based positional encoding
796
+ model = MultiViewAlternatingAttentionTransformer(
797
+ name="MV-AAT",
798
+ input_embed_dim=1024,
799
+ use_pe_for_non_reference_views=True,
800
+ max_num_views_for_pe=1000,
801
+ use_rand_idx_pe_for_non_reference_views=False,
802
+ )
803
+ model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
804
+ model_input = MultiViewTransformerInput(features=model_input)
805
+ model_output = model(model_input)
806
+ assert len(model_output.features) == num_views
807
+ assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features)
808
+ # Random idx based positional encoding
809
+ model = MultiViewAlternatingAttentionTransformer(
810
+ name="MV-AAT",
811
+ input_embed_dim=1024,
812
+ use_pe_for_non_reference_views=True,
813
+ max_num_views_for_pe=1000,
814
+ use_rand_idx_pe_for_non_reference_views=True,
815
+ )
816
+ model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
817
+ model_input = MultiViewTransformerInput(features=model_input)
818
+ model_output = model(model_input)
819
+ assert len(model_output.features) == num_views
820
+ assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features)
821
+
822
+ # Init multi-view alternating-attention transformer with custom positional encoding and run a forward pass
823
+ for num_views in [2, 3, 4]:
824
+ print(
825
+ f"Testing MultiViewAlternatingAttentionTransformer with {num_views} views and custom positional encoding ..."
826
+ )
827
+ model = MultiViewAlternatingAttentionTransformer(
828
+ name="MV-AAT",
829
+ input_embed_dim=1024,
830
+ custom_positional_encoding=dummy_positional_encoding,
831
+ )
832
+ model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
833
+ model_input = MultiViewTransformerInput(features=model_input)
834
+ model_output = model(model_input)
835
+ assert len(model_output.features) == num_views
836
+ assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features)
837
+
838
+ print("All multi-view alternating-attention transformers initialized and tested successfully!")
839
+
840
+ # Intermediate Feature Returner Tests
841
+ print("Running Intermediate Feature Returner Tests ...")
842
+
843
+ # Run the intermediate feature returner with last-n index
844
+ model_intermediate_feature_returner = MultiViewAlternatingAttentionTransformerIFR(
845
+ name="MV-AAT-IFR",
846
+ input_embed_dim=1024,
847
+ indices=6, # Last 6 layers
848
+ )
849
+ model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
850
+ model_input = MultiViewTransformerInput(features=model_input)
851
+ output = model_intermediate_feature_returner(model_input)
852
+ assert isinstance(output, tuple)
853
+ assert isinstance(output[0], MultiViewTransformerOutput)
854
+ assert len(output[1]) == 6
855
+ assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1])
856
+ assert len(output[1][0].features) == 2
857
+
858
+ # Run the intermediate feature returner with specific indices
859
+ model_intermediate_feature_returner = MultiViewAlternatingAttentionTransformerIFR(
860
+ name="MV-AAT-IFR",
861
+ input_embed_dim=1024,
862
+ indices=[0, 2, 4, 6], # Specific indices
863
+ )
864
+ model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
865
+ model_input = MultiViewTransformerInput(features=model_input)
866
+ output = model_intermediate_feature_returner(model_input)
867
+ assert isinstance(output, tuple)
868
+ assert isinstance(output[0], MultiViewTransformerOutput)
869
+ assert len(output[1]) == 4
870
+ assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1])
871
+ assert len(output[1][0].features) == 2
872
+
873
+ # Test the normalization of intermediate features
874
+ model_intermediate_feature_returner = MultiViewAlternatingAttentionTransformerIFR(
875
+ name="MV-AAT-IFR",
876
+ input_embed_dim=1024,
877
+ indices=[-1], # Last layer
878
+ norm_intermediate=False, # Disable normalization
879
+ )
880
+ model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
881
+ model_input = MultiViewTransformerInput(features=model_input)
882
+ output = model_intermediate_feature_returner(model_input)
883
+ for view_idx in range(2):
884
+ assert not torch.equal(
885
+ output[0].features[view_idx], output[1][-1].features[view_idx]
886
+ ), "Final features and intermediate features (last layer) must be different."
887
+
888
+ model_intermediate_feature_returner = MultiViewAlternatingAttentionTransformerIFR(
889
+ name="MV-AAT-IFR",
890
+ input_embed_dim=1024,
891
+ indices=[-1], # Last layer
892
+ norm_intermediate=True,
893
+ )
894
+ model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
895
+ model_input = MultiViewTransformerInput(features=model_input)
896
+ output = model_intermediate_feature_returner(model_input)
897
+ for view_idx in range(2):
898
+ assert torch.equal(
899
+ output[0].features[view_idx], output[1][-1].features[view_idx]
900
+ ), "Final features and intermediate features (last layer) must be same."
901
+
902
+ print("All Intermediate Feature Returner Tests passed!")
903
+
904
+ # Test additional input tokens for MultiViewAlternatingAttentionTransformer
905
+ print("Testing MultiViewAlternatingAttentionTransformer with additional input tokens ...")
906
+ model = MultiViewAlternatingAttentionTransformer(
907
+ name="MV-AAT",
908
+ input_embed_dim=1024,
909
+ )
910
+ num_views = 2
911
+ num_additional_tokens = 5
912
+ model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
913
+ additional_tokens = torch.rand(1, 1024, num_additional_tokens)
914
+ model_input = MultiViewTransformerInput(features=model_input, additional_input_tokens=additional_tokens)
915
+ model_output = model(model_input)
916
+ assert len(model_output.features) == num_views
917
+ assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features)
918
+ assert model_output.additional_token_features is not None
919
+ assert model_output.additional_token_features.shape == (1, model.dim, num_additional_tokens)
920
+
921
+ # Test additional input tokens for MultiViewAlternatingAttentionTransformerIFR
922
+ print("Testing MultiViewAlternatingAttentionTransformerIFR with additional input tokens ...")
923
+ model_ifr = MultiViewAlternatingAttentionTransformerIFR(
924
+ name="MV-AAT-IFR",
925
+ input_embed_dim=1024,
926
+ indices=[0, 2, 4],
927
+ )
928
+ model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
929
+ additional_tokens = torch.rand(1, 1024, num_additional_tokens)
930
+ model_input = MultiViewTransformerInput(features=model_input, additional_input_tokens=additional_tokens)
931
+ output = model_ifr(model_input)
932
+ assert isinstance(output, tuple)
933
+ assert isinstance(output[0], MultiViewTransformerOutput)
934
+ assert output[0].additional_token_features is not None
935
+ assert output[0].additional_token_features.shape == (1, model_ifr.dim, num_additional_tokens)
936
+ assert len(output[1]) == 3
937
+ assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1])
938
+ assert all(intermediate.additional_token_features is not None for intermediate in output[1])
939
+ assert all(
940
+ intermediate.additional_token_features.shape == (1, model_ifr.dim, num_additional_tokens)
941
+ for intermediate in output[1]
942
+ )
943
+
944
+ print("All tests using additional input tokens passed!")
UniCeption/uniception/models/info_sharing/base.py ADDED
@@ -0,0 +1,116 @@
1
+ """
2
+ Base Information Sharing Class for UniCeption
3
+ """
4
+
5
+ from dataclasses import dataclass
6
+ from typing import List, Optional
7
+
8
+ import torch.nn as nn
9
+ from jaxtyping import Float
10
+ from torch import Tensor
11
+ from torch.utils.checkpoint import checkpoint
12
+
13
+
14
+ @dataclass
15
+ class InfoSharingInput:
16
+ pass
17
+
18
+
19
+ @dataclass
20
+ class InfoSharingOutput:
21
+ pass
22
+
23
+
24
+ class UniCeptionInfoSharingBase(nn.Module):
25
+ "Information Sharing Base Class for UniCeption"
26
+
27
+ def __init__(
28
+ self,
29
+ name: str,
30
+ size: Optional[str] = None,
31
+ *args,
32
+ **kwargs,
33
+ ):
34
+ """
35
+ Base class for all information sharing models in UniCeption.
36
+ """
37
+ super().__init__(*args, **kwargs)
38
+
39
+ self.name: str = name
40
+ self.size: Optional[str] = size
41
+
42
+ def forward(
43
+ self,
44
+ model_input: InfoSharingInput,
45
+ ) -> InfoSharingOutput:
46
+ """
47
+ Forward interface for the UniCeption information sharing models.
48
+
49
+ Args:
50
+ model_input (InfoSharingInput): Input to the model.
51
+ This also includes any other fields required by the specific implementation of the model.
52
+
53
+ Returns:
54
+ InfoSharingOutput: Output of the model.
55
+ """
56
+
57
+ raise NotImplementedError
58
+
59
+ def wrap_module_with_gradient_checkpointing(self, module: nn.Module):
60
+ """
61
+ Wrapper for Gradient Checkpointing
62
+ """
63
+
64
+ class _CheckpointingWrapper(module.__class__):
65
+ _restore_cls = module.__class__
66
+
67
+ def forward(self, *args, **kwargs):
68
+ return checkpoint(super().forward, *args, use_reentrant=False, **kwargs)
69
+
70
+ module.__class__ = _CheckpointingWrapper
71
+ return module
72
+
73
+
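A minimal sketch of what the wrapper above does, using a toy block (names below are illustrative only): the module's class is swapped for a subclass whose forward routes through torch.utils.checkpoint, so intermediate activations are recomputed during the backward pass instead of being stored.

```python
import torch
import torch.nn as nn

from uniception.models.info_sharing.base import UniCeptionInfoSharingBase

base = UniCeptionInfoSharingBase(name="checkpoint-demo")
block = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 8))
block = base.wrap_module_with_gradient_checkpointing(block)

x = torch.randn(4, 8, requires_grad=True)
loss = block(x).sum()
loss.backward()  # gradients still flow; activations are recomputed in backward
print(type(block).__name__, x.grad.shape)  # _CheckpointingWrapper torch.Size([4, 8])
```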
74
+ @dataclass
75
+ class MultiViewTransformerInput(InfoSharingInput):
76
+ """
77
+ Input class for Multi-View Transformer.
78
+ """
79
+
80
+ features: List[Float[Tensor, "batch input_embed_dim feat_height feat_width"]]
81
+ additional_input_tokens: Optional[Float[Tensor, "batch input_embed_dim num_additional_tokens"]] = None
82
+
83
+
84
+ @dataclass
85
+ class MultiViewTransformerOutput(InfoSharingOutput):
86
+ """
87
+ Output class for Multi-View Transformer.
88
+ """
89
+
90
+ features: List[Float[Tensor, "batch transformer_embed_dim feat_height feat_width"]]
91
+ additional_token_features: Optional[Float[Tensor, "batch transformer_embed_dim num_additional_tokens"]] = None
92
+
93
+
94
+ @dataclass
95
+ class MultiSetTransformerInput(InfoSharingInput):
96
+ """
97
+ Input class for Multi-Set Transformer.
98
+ """
99
+
100
+ features: List[Float[Tensor, "batch input_embed_dim num_tokens"]]
101
+ additional_input_tokens: Optional[Float[Tensor, "batch input_embed_dim num_additional_tokens"]] = None
102
+
103
+
104
+ @dataclass
105
+ class MultiSetTransformerOutput(InfoSharingOutput):
106
+ """
107
+ Output class for Multi-Set Transformer.
108
+ """
109
+
110
+ features: List[Float[Tensor, "batch transformer_embed_dim num_tokens"]]
111
+ additional_token_features: Optional[Float[Tensor, "batch transformer_embed_dim num_additional_tokens"]] = None
112
+
113
+
114
+ if __name__ == "__main__":
115
+ dummy_model = UniCeptionInfoSharingBase(name="dummy")
116
+ print("Dummy Base InfoSharing model created successfully!")
UniCeption/uniception/models/info_sharing/cross_attention_transformer.py ADDED
@@ -0,0 +1,582 @@
1
+ """
2
+ UniCeption Cross-Attention Transformer for Information Sharing
3
+ """
4
+
5
+ from copy import deepcopy
6
+ from functools import partial
7
+ from typing import Callable, List, Optional, Tuple, Type, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from uniception.models.info_sharing.base import (
13
+ MultiViewTransformerInput,
14
+ MultiViewTransformerOutput,
15
+ UniCeptionInfoSharingBase,
16
+ )
17
+ from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner, feature_take_indices
18
+ from uniception.models.utils.positional_encoding import PositionGetter
19
+ from uniception.models.utils.transformer_blocks import CrossAttentionBlock, Mlp
20
+
21
+
22
+ class MultiViewCrossAttentionTransformer(UniCeptionInfoSharingBase):
23
+ "UniCeption Multi-View Cross-Attention Transformer for information sharing across image features from different views."
24
+
25
+ def __init__(
26
+ self,
27
+ name: str,
28
+ input_embed_dim: int,
29
+ num_views: int,
30
+ size: Optional[str] = None,
31
+ depth: int = 12,
32
+ dim: int = 768,
33
+ num_heads: int = 12,
34
+ mlp_ratio: float = 4.0,
35
+ qkv_bias: bool = True,
36
+ qk_norm: bool = False,
37
+ proj_drop: float = 0.0,
38
+ attn_drop: float = 0.0,
39
+ init_values: Optional[float] = None,
40
+ drop_path: float = 0.0,
41
+ act_layer: Type[nn.Module] = nn.GELU,
42
+ norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial(nn.LayerNorm, eps=1e-6),
43
+ mlp_layer: Type[nn.Module] = Mlp,
44
+ custom_positional_encoding: Optional[Callable] = None,
45
+ norm_cross_tokens: bool = True,
46
+ pretrained_checkpoint_path: Optional[str] = None,
47
+ gradient_checkpointing: bool = False,
48
+ *args,
49
+ **kwargs,
50
+ ):
51
+ """
52
+ Initialize the Multi-View Cross-Attention Transformer for information sharing across image features from different views.
53
+ Creates a cross-attention transformer with multiple branches for each view.
54
+
55
+ Args:
56
+ input_embed_dim (int): Dimension of input embeddings.
57
+ num_views (int): Number of views (input feature sets).
58
+ size (str): String indicating the interpretable size of the transformer (e.g., base, large, ...). (default: None)
59
+ depth (int): Number of transformer layers. (default: 12, base size)
60
+ dim (int): Dimension of the transformer. (default: 768, base size)
61
+ num_heads (int): Number of attention heads. (default: 12, base size)
62
+ mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.)
63
+ qkv_bias (bool): Whether to include bias in qkv projection (default: True)
64
+ qk_norm (bool): Whether to normalize q and k (default: False)
65
+ proj_drop (float): Dropout rate for output (default: 0.)
66
+ attn_drop (float): Dropout rate for attention weights (default: 0.)
67
+ init_values (float): Initial value for LayerScale gamma (default: None)
68
+ drop_path (float): Dropout rate for stochastic depth (default: 0.)
69
+ act_layer (nn.Module): Activation layer (default: nn.GELU)
70
+ norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm)
71
+ mlp_layer (nn.Module): MLP layer (default: Mlp)
72
+ custom_positional_encoding (Callable): Custom positional encoding function (default: None)
73
+ norm_cross_tokens (bool): Whether to normalize cross tokens (default: True)
74
+ pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None)
75
+ gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False)
76
+ """
77
+ # Initialize the base class
78
+ super().__init__(name=name, size=size, *args, **kwargs)
79
+
80
+ # Initialize the specific attributes of the transformer
81
+ self.input_embed_dim = input_embed_dim
82
+ self.num_views = num_views
83
+ self.depth = depth
84
+ self.dim = dim
85
+ self.num_heads = num_heads
86
+ self.mlp_ratio = mlp_ratio
87
+ self.qkv_bias = qkv_bias
88
+ self.qk_norm = qk_norm
89
+ self.proj_drop = proj_drop
90
+ self.attn_drop = attn_drop
91
+ self.init_values = init_values
92
+ self.drop_path = drop_path
93
+ self.act_layer = act_layer
94
+ self.norm_layer = norm_layer
95
+ self.mlp_layer = mlp_layer
96
+ self.custom_positional_encoding = custom_positional_encoding
97
+ self.norm_cross_tokens = norm_cross_tokens
98
+ self.pretrained_checkpoint_path = pretrained_checkpoint_path
99
+ self.gradient_checkpointing = gradient_checkpointing
100
+
101
+ # Initialize the projection layer for input embeddings
102
+ if self.input_embed_dim != self.dim:
103
+ self.proj_embed = nn.Linear(self.input_embed_dim, self.dim, bias=True)
104
+ else:
105
+ self.proj_embed = nn.Identity()
106
+
107
+ # Initialize the cross-attention blocks for a single view
108
+ cross_attention_blocks = nn.ModuleList(
109
+ [
110
+ CrossAttentionBlock(
111
+ dim=self.dim,
112
+ num_heads=self.num_heads,
113
+ mlp_ratio=self.mlp_ratio,
114
+ qkv_bias=self.qkv_bias,
115
+ qk_norm=self.qk_norm,
116
+ proj_drop=self.proj_drop,
117
+ attn_drop=self.attn_drop,
118
+ init_values=self.init_values,
119
+ drop_path=self.drop_path,
120
+ act_layer=self.act_layer,
121
+ norm_layer=self.norm_layer,
122
+ mlp_layer=self.mlp_layer,
123
+ custom_positional_encoding=self.custom_positional_encoding,
124
+ norm_cross_tokens=self.norm_cross_tokens,
125
+ )
126
+ for _ in range(self.depth)
127
+ ]
128
+ )
129
+
130
+ # Copy the cross-attention blocks for all other views
131
+ self.multi_view_branches = nn.ModuleList([cross_attention_blocks])
132
+ for _ in range(1, self.num_views):
133
+ self.multi_view_branches.append(deepcopy(cross_attention_blocks))
134
+
135
+ # Initialize the final normalization layer
136
+ self.norm = self.norm_layer(self.dim)
137
+
138
+ # Initialize the position getter for patch positions if required
139
+ if self.custom_positional_encoding is not None:
140
+ self.position_getter = PositionGetter()
141
+
142
+ # Initialize random weights
143
+ self.initialize_weights()
144
+
145
+ # Apply gradient checkpointing if enabled
146
+ if self.gradient_checkpointing:
147
+ # Wrap every block in every view branch (the blocks are registered under multi_view_branches)
+ for branch in self.multi_view_branches:
+ for i, block in enumerate(branch):
+ branch[i] = self.wrap_module_with_gradient_checkpointing(block)
149
+
150
+ # Load pretrained weights if provided
151
+ if self.pretrained_checkpoint_path is not None:
152
+ print(
153
+ f"Loading pretrained multi-view cross-attention transformer weights from {self.pretrained_checkpoint_path} ..."
154
+ )
155
+ ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
156
+ print(self.load_state_dict(ckpt["model"]))
157
+
158
+ def initialize_weights(self):
159
+ "Initialize weights of the transformer."
160
+ # Linears and layer norms
161
+ self.apply(self._init_weights)
162
+
163
+ def _init_weights(self, m):
164
+ "Initialize the transformer linear and layer norm weights."
165
+ if isinstance(m, nn.Linear):
166
+ # We use xavier_uniform following official JAX ViT:
167
+ torch.nn.init.xavier_uniform_(m.weight)
168
+ if isinstance(m, nn.Linear) and m.bias is not None:
169
+ nn.init.constant_(m.bias, 0)
170
+ elif isinstance(m, nn.LayerNorm):
171
+ nn.init.constant_(m.bias, 0)
172
+ nn.init.constant_(m.weight, 1.0)
173
+
174
+ def forward(
175
+ self,
176
+ model_input: MultiViewTransformerInput,
177
+ ) -> MultiViewTransformerOutput:
178
+ """
179
+ Forward interface for the Multi-View Cross-Attention Transformer.
180
+
181
+ Args:
182
+ model_input (MultiViewTransformerInput): Input to the model.
183
+ Expects the features to be a list of size (batch, input_embed_dim, height, width),
184
+ where each entry corresponds to a different view.
185
+
186
+ Returns:
187
+ MultiViewTransformerOutput: Output of the model post information sharing.
188
+ """
189
+ # Check that the number of views matches the input and the features are of expected shape
190
+ assert (
191
+ len(model_input.features) == self.num_views
192
+ ), f"Expected {self.num_views} views, got {len(model_input.features)}"
193
+ assert all(
194
+ view_features.shape[1] == self.input_embed_dim for view_features in model_input.features
195
+ ), f"All views must have input dimension {self.input_embed_dim}"
196
+ assert all(
197
+ view_features.ndim == 4 for view_features in model_input.features
198
+ ), "All views must have 4 dimensions (N, C, H, W)"
199
+
200
+ # Initialize the multi-view features from the model input
201
+ multi_view_features = model_input.features
202
+
203
+ # Resize the multi-view features from NCHW to NLC
204
+ batch_size, _, height, width = multi_view_features[0].shape
205
+ multi_view_features = [
206
+ view_features.permute(0, 2, 3, 1).reshape(batch_size, height * width, self.input_embed_dim).contiguous()
207
+ for view_features in multi_view_features
208
+ ]
209
+
210
+ # Create patch positions for each view if custom positional encoding is used
211
+ if self.custom_positional_encoding is not None:
212
+ multi_view_positions = [
213
+ self.position_getter(batch_size, height, width, view_features.device)
214
+ for view_features in multi_view_features
215
+ ]
216
+ else:
217
+ multi_view_positions = [None] * self.num_views
218
+
219
+ # Project input features to the transformer dimension
220
+ multi_view_features = [self.proj_embed(view_features) for view_features in multi_view_features]
221
+
222
+ # Pass through each view's cross-attention blocks
223
+ # Loop over the depth of the transformer
224
+ for depth_idx in range(self.depth):
225
+ updated_multi_view_features = []
226
+ # Loop over each view
227
+ for view_idx, view_features in enumerate(multi_view_features):
228
+ # Get all the other views
229
+ other_views_features = [multi_view_features[i] for i in range(self.num_views) if i != view_idx]
230
+ # Concatenate all the tokens from the other views
231
+ other_views_features = torch.cat(other_views_features, dim=1)
232
+ # Get the positions for the current view
233
+ view_positions = multi_view_positions[view_idx]
234
+ # Get the positions for all other views
235
+ other_views_positions = (
236
+ torch.cat([multi_view_positions[i] for i in range(self.num_views) if i != view_idx], dim=1)
237
+ if view_positions is not None
238
+ else None
239
+ )
240
+ # Apply the cross-attention block and update the multi-view features
241
+ updated_view_features = self.multi_view_branches[view_idx][depth_idx](
242
+ view_features, other_views_features, view_positions, other_views_positions
243
+ )
244
+ # Keep track of the updated view features
245
+ updated_multi_view_features.append(updated_view_features)
246
+ # Update the multi-view features for the next depth
247
+ multi_view_features = updated_multi_view_features
248
+
249
+ # Normalize the output features
250
+ output_multi_view_features = [self.norm(view_features) for view_features in multi_view_features]
251
+
252
+ # Resize the output multi-view features back to NCHW
253
+ output_multi_view_features = [
254
+ view_features.reshape(batch_size, height, width, self.dim).permute(0, 3, 1, 2).contiguous()
255
+ for view_features in output_multi_view_features
256
+ ]
257
+
258
+ return MultiViewTransformerOutput(features=output_multi_view_features)
259
+
260
+
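The NCHW-to-NLC round trip used inside forward is worth spelling out on its own; a standalone sketch with illustrative sizes:

```python
import torch

B, C, H, W = 2, 768, 14, 14
nchw = torch.rand(B, C, H, W)

# NCHW -> NLC: one token per spatial location, channels last.
nlc = nchw.permute(0, 2, 3, 1).reshape(B, H * W, C).contiguous()

# ... the cross-attention blocks operate on these (B, L, C) token sequences ...

# NLC -> NCHW: restore the spatial feature map for downstream prediction heads.
restored = nlc.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
assert torch.equal(nchw, restored)
```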
261
+ class MultiViewCrossAttentionTransformerIFR(MultiViewCrossAttentionTransformer, IntermediateFeatureReturner):
262
+ "Intermediate Feature Returner for UniCeption Multi-View Cross-Attention Transformer"
263
+
264
+ def __init__(
265
+ self,
266
+ name: str,
267
+ input_embed_dim: int,
268
+ num_views: int,
269
+ size: Optional[str] = None,
270
+ depth: int = 12,
271
+ dim: int = 768,
272
+ num_heads: int = 12,
273
+ mlp_ratio: float = 4.0,
274
+ qkv_bias: bool = True,
275
+ qk_norm: bool = False,
276
+ proj_drop: float = 0.0,
277
+ attn_drop: float = 0.0,
278
+ init_values: Optional[float] = None,
279
+ drop_path: float = 0.0,
280
+ act_layer: nn.Module = nn.GELU,
281
+ norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
282
+ mlp_layer: nn.Module = Mlp,
283
+ custom_positional_encoding: Callable = None,
284
+ norm_cross_tokens: bool = True,
285
+ pretrained_checkpoint_path: str = None,
286
+ indices: Optional[Union[int, List[int]]] = None,
287
+ norm_intermediate: bool = True,
288
+ intermediates_only: bool = False,
289
+ gradient_checkpointing: bool = False,
290
+ *args,
291
+ **kwargs,
292
+ ):
293
+ """
294
+ Initialize the Multi-View Cross-Attention Transformer for information sharing across image features from different views.
295
+ Creates a cross-attention transformer with multiple branches for each view.
296
+ Extends the base class to return intermediate features.
297
+
298
+ Args:
299
+ input_embed_dim (int): Dimension of input embeddings.
300
+ num_views (int): Number of views (input feature sets).
301
+ size (str): String to indicate interpretable size of the transformer (for e.g., base, large, ...). (default: None)
302
+ depth (int): Number of transformer layers. (default: 12, base size)
303
+ dim (int): Dimension of the transformer. (default: 768, base size)
304
+ num_heads (int): Number of attention heads. (default: 12, base size)
305
+ mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.)
306
+ qkv_bias (bool): Whether to include bias in qkv projection (default: True)
307
+ qk_norm (bool): Whether to normalize q and k (default: False)
308
+ proj_drop (float): Dropout rate for output (default: 0.)
309
+ attn_drop (float): Dropout rate for attention weights (default: 0.)
310
+ init_values (float): Initial value for LayerScale gamma (default: None)
311
+ drop_path (float): Dropout rate for stochastic depth (default: 0.)
312
+ act_layer (nn.Module): Activation layer (default: nn.GELU)
313
+ norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm)
314
+ mlp_layer (nn.Module): MLP layer (default: Mlp)
315
+ custom_positional_encoding (Callable): Custom positional encoding function (default: None)
316
+ norm_cross_tokens (bool): Whether to normalize cross tokens (default: True)
317
+ pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None)
318
+ indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. (default: None) Options:
319
+ - None: Return all intermediate layers.
320
+ - int: Return the last n layers.
321
+ - List[int]: Return the intermediate layers at the specified indices.
322
+ norm_intermediate (bool, optional): Whether to normalize the intermediate features. (default: True)
323
+ intermediates_only (bool, optional): Whether to return only the intermediate features. (default: False)
324
+ gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False)
325
+ """
326
+ # Init the base classes
327
+ MultiViewCrossAttentionTransformer.__init__(
328
+ self,
329
+ name=name,
330
+ input_embed_dim=input_embed_dim,
331
+ num_views=num_views,
332
+ size=size,
333
+ depth=depth,
334
+ dim=dim,
335
+ num_heads=num_heads,
336
+ mlp_ratio=mlp_ratio,
337
+ qkv_bias=qkv_bias,
338
+ qk_norm=qk_norm,
339
+ proj_drop=proj_drop,
340
+ attn_drop=attn_drop,
341
+ init_values=init_values,
342
+ drop_path=drop_path,
343
+ act_layer=act_layer,
344
+ norm_layer=norm_layer,
345
+ mlp_layer=mlp_layer,
346
+ custom_positional_encoding=custom_positional_encoding,
347
+ norm_cross_tokens=norm_cross_tokens,
348
+ pretrained_checkpoint_path=pretrained_checkpoint_path,
349
+ gradient_checkpointing=gradient_checkpointing,
350
+ *args,
351
+ **kwargs,
352
+ )
353
+ IntermediateFeatureReturner.__init__(
354
+ self,
355
+ indices=indices,
356
+ norm_intermediate=norm_intermediate,
357
+ intermediates_only=intermediates_only,
358
+ )
359
+
360
+ def forward(
361
+ self,
362
+ model_input: MultiViewTransformerInput,
363
+ ) -> Union[
364
+ List[MultiViewTransformerOutput],
365
+ Tuple[MultiViewTransformerOutput, List[MultiViewTransformerOutput]],
366
+ ]:
367
+ """
368
+ Forward interface for the Multi-View Cross-Attention Transformer with Intermediate Feature Return.
369
+
370
+ Args:
371
+ model_input (MultiViewTransformerInput): Input to the model.
372
+ Expects the features to be a list of size (batch, input_embed_dim, height, width),
373
+ where each entry corresponds to a different view.
374
+
375
+ Returns:
376
+ Union[List[MultiViewTransformerOutput], Tuple[MultiViewTransformerOutput, List[MultiViewTransformerOutput]]]:
377
+ Output of the model post information sharing.
378
+ If intermediates_only is True, returns a list of intermediate outputs.
379
+ If intermediates_only is False, returns a tuple of final output and a list of intermediate outputs.
380
+ """
381
+ # Check that the number of views matches the input and the features are of expected shape
382
+ assert (
383
+ len(model_input.features) == self.num_views
384
+ ), f"Expected {self.num_views} views, got {len(model_input.features)}"
385
+ assert all(
386
+ view_features.shape[1] == self.input_embed_dim for view_features in model_input.features
387
+ ), f"All views must have input dimension {self.input_embed_dim}"
388
+ assert all(
389
+ view_features.ndim == 4 for view_features in model_input.features
390
+ ), "All views must have 4 dimensions (N, C, H, W)"
391
+
392
+ # Get the indices of the intermediate features to return
393
+ intermediate_multi_view_features = []
394
+ take_indices, _ = feature_take_indices(self.depth, self.indices)
395
+
396
+ # Initialize the multi-view features from the model input
397
+ multi_view_features = model_input.features
398
+
399
+ # Resize the multi-view features from NCHW to NLC
400
+ batch_size, _, height, width = multi_view_features[0].shape
401
+ multi_view_features = [
402
+ view_features.permute(0, 2, 3, 1).reshape(batch_size, height * width, self.input_embed_dim).contiguous()
403
+ for view_features in multi_view_features
404
+ ]
405
+
406
+ # Create patch positions for each view if custom positional encoding is used
407
+ if self.custom_positional_encoding is not None:
408
+ multi_view_positions = [
409
+ self.position_getter(batch_size, height, width, view_features.device)
410
+ for view_features in multi_view_features
411
+ ]
412
+ else:
413
+ multi_view_positions = [None] * self.num_views
414
+
415
+ # Project input features to the transformer dimension
416
+ multi_view_features = [self.proj_embed(view_features) for view_features in multi_view_features]
417
+
418
+ # Pass through each view's cross-attention blocks
419
+ # Loop over the depth of the transformer
420
+ for depth_idx in range(self.depth):
421
+ updated_multi_view_features = []
422
+ # Loop over each view
423
+ for view_idx, view_features in enumerate(multi_view_features):
424
+ # Get all the other views
425
+ other_views_features = [multi_view_features[i] for i in range(self.num_views) if i != view_idx]
426
+ # Concatenate all the tokens from the other views
427
+ other_views_features = torch.cat(other_views_features, dim=1)
428
+ # Get the positions for the current view
429
+ view_positions = multi_view_positions[view_idx]
430
+ # Get the positions for all other views
431
+ other_views_positions = (
432
+ torch.cat([multi_view_positions[i] for i in range(self.num_views) if i != view_idx], dim=1)
433
+ if view_positions is not None
434
+ else None
435
+ )
436
+ # Apply the cross-attention block and update the multi-view features
437
+ updated_view_features = self.multi_view_branches[view_idx][depth_idx](
438
+ view_features, other_views_features, view_positions, other_views_positions
439
+ )
440
+ # Keep track of the updated view features
441
+ updated_multi_view_features.append(updated_view_features)
442
+ # Update the multi-view features for the next depth
443
+ multi_view_features = updated_multi_view_features
444
+ # Append the intermediate features if required
445
+ if depth_idx in take_indices:
446
+ # Normalize the intermediate features with final norm layer if enabled
447
+ intermediate_multi_view_features.append(
448
+ [self.norm(view_features) for view_features in multi_view_features]
449
+ if self.norm_intermediate
450
+ else multi_view_features
451
+ )
452
+
453
+ # Reshape the intermediate features and convert to MultiViewTransformerOutput class
454
+ for idx in range(len(intermediate_multi_view_features)):
455
+ intermediate_multi_view_features[idx] = [
456
+ view_features.reshape(batch_size, height, width, self.dim).permute(0, 3, 1, 2).contiguous()
457
+ for view_features in intermediate_multi_view_features[idx]
458
+ ]
459
+ intermediate_multi_view_features[idx] = MultiViewTransformerOutput(
460
+ features=intermediate_multi_view_features[idx]
461
+ )
462
+
463
+ # Return only the intermediate features if enabled
464
+ if self.intermediates_only:
465
+ return intermediate_multi_view_features
466
+
467
+ # Normalize the output features
468
+ output_multi_view_features = [self.norm(view_features) for view_features in multi_view_features]
469
+
470
+ # Resize the output multi-view features back to NCHW
471
+ output_multi_view_features = [
472
+ view_features.reshape(batch_size, height, width, self.dim).permute(0, 3, 1, 2).contiguous()
473
+ for view_features in output_multi_view_features
474
+ ]
475
+
476
+ output_multi_view_features = MultiViewTransformerOutput(features=output_multi_view_features)
477
+
478
+ return output_multi_view_features, intermediate_multi_view_features
479
+
480
+
481
+ def dummy_positional_encoding(x, xpos):
482
+ "Dummy function for positional encoding of tokens"
483
+ x = x
484
+ xpos = xpos
485
+ return x
486
+
487
+
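dummy_positional_encoding above only documents the expected (tokens, positions) -> tokens signature. Below is a hedged sketch of a non-trivial custom_positional_encoding; it assumes positions are integer (y, x) grid coordinates of shape (batch, num_tokens, 2), as a CroCo-style PositionGetter typically yields, and an embedding width divisible by 4.

```python
import torch


def sincos_positional_encoding(x, xpos):
    """Add a fixed sinusoidal embedding of the (y, x) grid position to each token.

    Assumes x is (B, N, D) with D divisible by 4 and xpos is (B, N, 2) integer coordinates.
    """
    _, _, D = x.shape
    freqs = torch.exp(-torch.arange(D // 4, device=x.device, dtype=x.dtype) * (8.0 / D))
    y_term = xpos[..., 0:1].to(x.dtype) * freqs  # (B, N, D // 4)
    x_term = xpos[..., 1:2].to(x.dtype) * freqs
    pe = torch.cat([y_term.sin(), y_term.cos(), x_term.sin(), x_term.cos()], dim=-1)  # (B, N, D)
    return x + pe
```

Passed to the constructor as custom_positional_encoding=sincos_positional_encoding, it is handed to each cross-attention block together with the per-view patch positions created in forward.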
488
+ if __name__ == "__main__":
489
+ # Init multi-view cross-attention transformer with no custom positional encoding and run a forward pass
490
+ for num_views in [2, 3, 4]:
491
+ print(f"Testing MultiViewCrossAttentionTransformer with {num_views} views ...")
492
+ model = MultiViewCrossAttentionTransformer(name="MV-CAT", input_embed_dim=1024, num_views=num_views)
493
+ model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
494
+ model_input = MultiViewTransformerInput(features=model_input)
495
+ model_output = model(model_input)
496
+ assert len(model_output.features) == num_views
497
+ assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features)
498
+
499
+ # Init multi-view cross-attention transformer with custom positional encoding and run a forward pass
500
+ for num_views in [2, 3, 4]:
501
+ print(f"Testing MultiViewCrossAttentionTransformer with {num_views} views and custom positional encoding ...")
502
+ model = MultiViewCrossAttentionTransformer(
503
+ name="MV-CAT",
504
+ input_embed_dim=1024,
505
+ num_views=num_views,
506
+ custom_positional_encoding=dummy_positional_encoding,
507
+ )
508
+ model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
509
+ model_input = MultiViewTransformerInput(features=model_input)
510
+ model_output = model(model_input)
511
+ assert len(model_output.features) == num_views
512
+ assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features)
513
+
514
+ print("All multi-view cross-attention transformers initialized and tested successfully!")
515
+
516
+ # Intermediate Feature Returner Tests
517
+ print("Running Intermediate Feature Returner Tests ...")
518
+
519
+ # Run the intermediate feature returner with last-n index
520
+ model_intermediate_feature_returner = MultiViewCrossAttentionTransformerIFR(
521
+ name="MV-CAT-IFR",
522
+ input_embed_dim=1024,
523
+ num_views=2,
524
+ indices=6, # Last 6 layers
525
+ )
526
+ model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
527
+ model_input = MultiViewTransformerInput(features=model_input)
528
+ output = model_intermediate_feature_returner(model_input)
529
+ assert isinstance(output, tuple)
530
+ assert isinstance(output[0], MultiViewTransformerOutput)
531
+ assert len(output[1]) == 6
532
+ assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1])
533
+ assert len(output[1][0].features) == 2
534
+
535
+ # Run the intermediate feature returner with specific indices
536
+ model_intermediate_feature_returner = MultiViewCrossAttentionTransformerIFR(
537
+ name="MV-CAT-IFR",
538
+ input_embed_dim=1024,
539
+ num_views=2,
540
+ indices=[0, 2, 4, 6], # Specific indices
541
+ )
542
+ model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
543
+ model_input = MultiViewTransformerInput(features=model_input)
544
+ output = model_intermediate_feature_returner(model_input)
545
+ assert isinstance(output, tuple)
546
+ assert isinstance(output[0], MultiViewTransformerOutput)
547
+ assert len(output[1]) == 4
548
+ assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1])
549
+ assert len(output[1][0].features) == 2
550
+
551
+ # Test the normalization of intermediate features
552
+ model_intermediate_feature_returner = MultiViewCrossAttentionTransformerIFR(
553
+ name="MV-CAT-IFR",
554
+ input_embed_dim=1024,
555
+ num_views=2,
556
+ indices=[-1], # Last layer
557
+ norm_intermediate=False, # Disable normalization
558
+ )
559
+ model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
560
+ model_input = MultiViewTransformerInput(features=model_input)
561
+ output = model_intermediate_feature_returner(model_input)
562
+ for view_idx in range(2):
563
+ assert not torch.equal(
564
+ output[0].features[view_idx], output[1][-1].features[view_idx]
565
+ ), "Final features and intermediate features (last layer) must be different."
566
+
567
+ model_intermediate_feature_returner = MultiViewCrossAttentionTransformerIFR(
568
+ name="MV-CAT-IFR",
569
+ input_embed_dim=1024,
570
+ num_views=2,
571
+ indices=[-1], # Last layer
572
+ norm_intermediate=True,
573
+ )
574
+ model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
575
+ model_input = MultiViewTransformerInput(features=model_input)
576
+ output = model_intermediate_feature_returner(model_input)
577
+ for view_idx in range(2):
578
+ assert torch.equal(
579
+ output[0].features[view_idx], output[1][-1].features[view_idx]
580
+ ), "Final features and intermediate features (last layer) must be same."
581
+
582
+ print("All Intermediate Feature Returner Tests passed!")
UniCeption/uniception/models/info_sharing/diff_cross_attention_transformer.py ADDED
@@ -0,0 +1,588 @@
1
+ """
2
+ UniCeption Differential Cross-Attention Transformer for Information Sharing
3
+ """
4
+
5
+ from copy import deepcopy
6
+ from functools import partial
7
+ from typing import Callable, List, Optional, Tuple, Type, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from uniception.models.info_sharing.base import UniCeptionInfoSharingBase
13
+ from uniception.models.info_sharing.cross_attention_transformer import (
14
+ MultiViewTransformerInput,
15
+ MultiViewTransformerOutput,
16
+ PositionGetter,
17
+ )
18
+ from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner, feature_take_indices
19
+ from uniception.models.utils.transformer_blocks import DiffCrossAttentionBlock, Mlp
20
+
21
+
22
+ class DifferentialMultiViewCrossAttentionTransformer(UniCeptionInfoSharingBase):
23
+ "UniCeption Multi-View Cross-Attention Transformer for information sharing across image features from different views."
24
+
25
+ def __init__(
26
+ self,
27
+ name: str,
28
+ input_embed_dim: int,
29
+ num_views: int,
30
+ size: Optional[str] = None,
31
+ depth: int = 12,
32
+ dim: int = 768,
33
+ num_heads: int = 12,
34
+ mlp_ratio: float = 4.0,
35
+ qkv_bias: bool = True,
36
+ qk_norm: bool = False,
37
+ proj_drop: float = 0.0,
38
+ attn_drop: float = 0.0,
39
+ init_values: Optional[float] = None,
40
+ drop_path: float = 0.0,
41
+ act_layer: Type[nn.Module] = nn.GELU,
42
+ norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial(nn.LayerNorm, eps=1e-6),
43
+ mlp_layer: Type[nn.Module] = Mlp,
44
+ custom_positional_encoding: Optional[Callable] = None,
45
+ norm_cross_tokens: bool = True,
46
+ pretrained_checkpoint_path: Optional[str] = None,
47
+ gradient_checkpointing: bool = False,
48
+ *args,
49
+ **kwargs,
50
+ ):
51
+ """
52
+ Initialize the Differential Multi-View Cross-Attention Transformer for information sharing across image features from different views.
53
+ Creates a cross-attention transformer with multiple branches for each view.
54
+
55
+ Args:
56
+ input_embed_dim (int): Dimension of input embeddings.
57
+ num_views (int): Number of views (input feature sets).
58
+ depth (int): Number of transformer layers. (default: 12, base size)
59
+ dim (int): Dimension of the transformer. (default: 768, base size)
60
+ num_heads (int): Number of attention heads. (default: 12, base size)
61
+ mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.)
62
+ qkv_bias (bool): Whether to include bias in qkv projection (default: True)
63
+ qk_norm (bool): Whether to normalize q and k (default: False)
64
+ proj_drop (float): Dropout rate for output (default: 0.)
65
+ attn_drop (float): Dropout rate for attention weights (default: 0.)
66
+ init_values (float): Initial value for LayerScale gamma (default: None)
67
+ drop_path (float): Dropout rate for stochastic depth (default: 0.)
68
+ act_layer (nn.Module): Activation layer (default: nn.GELU)
69
+ norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm)
70
+ mlp_layer (nn.Module): MLP layer (default: Mlp)
71
+ custom_positional_encoding (Callable): Custom positional encoding function (default: None)
72
+ norm_cross_tokens (bool): Whether to normalize cross tokens (default: True)
73
+ pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None)
74
+ gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False)
75
+ """
76
+ # Initialize the base class
77
+ super().__init__(name=name, size=size, *args, **kwargs)
78
+
79
+ # Initialize the specific attributes of the transformer
80
+ self.input_embed_dim = input_embed_dim
81
+ self.num_views = num_views
82
+ self.depth = depth
83
+ self.dim = dim
84
+ self.num_heads = num_heads
85
+ self.mlp_ratio = mlp_ratio
86
+ self.qkv_bias = qkv_bias
87
+ self.qk_norm = qk_norm
88
+ self.proj_drop = proj_drop
89
+ self.attn_drop = attn_drop
90
+ self.init_values = init_values
91
+ self.drop_path = drop_path
92
+ self.act_layer = act_layer
93
+ self.norm_layer = norm_layer
94
+ self.mlp_layer = mlp_layer
95
+ self.custom_positional_encoding = custom_positional_encoding
96
+ self.norm_cross_tokens = norm_cross_tokens
97
+ self.pretrained_checkpoint_path = pretrained_checkpoint_path
98
+ self.gradient_checkpointing = gradient_checkpointing
99
+
100
+ # Initialize the projection layer for input embeddings
101
+ if self.input_embed_dim != self.dim:
102
+ self.proj_embed = nn.Linear(self.input_embed_dim, self.dim, bias=True)
103
+ else:
104
+ self.proj_embed = nn.Identity()
105
+
106
+ # Initialize the cross-attention blocks for a single view
107
+ assert num_heads % 2 == 0, "Number of heads must be divisible by 2 for differential cross-attention."
108
+ cross_attention_blocks = nn.ModuleList(
109
+ [
110
+ DiffCrossAttentionBlock(
111
+ depth=i,
112
+ dim=self.dim,
113
+ num_heads=self.num_heads // 2,
114
+ mlp_ratio=self.mlp_ratio,
115
+ qkv_bias=self.qkv_bias,
116
+ qk_norm=self.qk_norm,
117
+ proj_drop=self.proj_drop,
118
+ attn_drop=self.attn_drop,
119
+ init_values=self.init_values,
120
+ drop_path=self.drop_path,
121
+ act_layer=self.act_layer,
122
+ norm_layer=self.norm_layer,
123
+ mlp_layer=self.mlp_layer,
124
+ custom_positional_encoding=self.custom_positional_encoding,
125
+ norm_cross_tokens=self.norm_cross_tokens,
126
+ )
127
+ for i in range(self.depth)
128
+ ]
129
+ )
130
+
131
+ # Copy the cross-attention blocks for all other views
132
+ self.multi_view_branches = nn.ModuleList([cross_attention_blocks])
133
+ for _ in range(1, self.num_views):
134
+ self.multi_view_branches.append(deepcopy(cross_attention_blocks))
135
+
136
+ # Initialize the final normalization layer
137
+ self.norm = self.norm_layer(self.dim)
138
+
139
+ # Initialize the position getter for patch positions if required
140
+ if self.custom_positional_encoding is not None:
141
+ self.position_getter = PositionGetter()
142
+
143
+ # Initialize random weights
144
+ self.initialize_weights()
145
+
146
+ # Apply gradient checkpointing if enabled
147
+ if self.gradient_checkpointing:
148
+ for i, block in enumerate(self.cross_attention_blocks):
149
+ self.cross_attention_blocks[i] = self.wrap_module_with_gradient_checkpointing(block)
150
+
151
+ # Load pretrained weights if provided
152
+ if self.pretrained_checkpoint_path is not None:
153
+ print(
154
+ f"Loading pretrained multi-view cross-attention transformer weights from {self.pretrained_checkpoint_path} ..."
155
+ )
156
+ ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
157
+ print(self.load_state_dict(ckpt["model"]))
158
+
159
+ def initialize_weights(self):
160
+ "Initialize weights of the transformer."
161
+ # Linears and layer norms
162
+ self.apply(self._init_weights)
163
+
164
+ def _init_weights(self, m):
165
+ "Initialize the transformer linear and layer norm weights."
166
+ if isinstance(m, nn.Linear):
167
+ # We use xavier_uniform following official JAX ViT:
168
+ torch.nn.init.xavier_uniform_(m.weight)
169
+ if isinstance(m, nn.Linear) and m.bias is not None:
170
+ nn.init.constant_(m.bias, 0)
171
+ elif isinstance(m, nn.LayerNorm):
172
+ nn.init.constant_(m.bias, 0)
173
+ nn.init.constant_(m.weight, 1.0)
174
+
175
+ def forward(
176
+ self,
177
+ model_input: MultiViewTransformerInput,
178
+ ) -> MultiViewTransformerOutput:
179
+ """
180
+ Forward interface for the Multi-View Cross-Attention Transformer.
181
+
182
+ Args:
183
+ model_input (MultiViewTransformerInput): Input to the model.
184
+ Expects the features to be a list of size (batch, input_embed_dim, height, width),
185
+ where each entry corresponds to a different view.
186
+
187
+ Returns:
188
+ MultiViewTransformerOutput: Output of the model post information sharing.
189
+ """
190
+ # Check that the number of views matches the input and the features are of expected shape
191
+ assert (
192
+ len(model_input.features) == self.num_views
193
+ ), f"Expected {self.num_views} views, got {len(model_input.features)}"
194
+ assert all(
195
+ view_features.shape[1] == self.input_embed_dim for view_features in model_input.features
196
+ ), f"All views must have input dimension {self.input_embed_dim}"
197
+ assert all(
198
+ view_features.ndim == 4 for view_features in model_input.features
199
+ ), "All views must have 4 dimensions (N, C, H, W)"
200
+
201
+ # Initialize the multi-view features from the model input
202
+ multi_view_features = model_input.features
203
+
204
+ # Resize the multi-view features from NCHW to NLC
205
+ batch_size, _, height, width = multi_view_features[0].shape
206
+ multi_view_features = [
207
+ view_features.permute(0, 2, 3, 1).reshape(batch_size, height * width, self.input_embed_dim).contiguous()
208
+ for view_features in multi_view_features
209
+ ]
210
+
211
+ # Create patch positions for each view if custom positional encoding is used
212
+ if self.custom_positional_encoding is not None:
213
+ multi_view_positions = [
214
+ self.position_getter(batch_size, height, width, view_features.device)
215
+ for view_features in multi_view_features
216
+ ]
217
+ else:
218
+ multi_view_positions = [None] * self.num_views
219
+
220
+ # Project input features to the transformer dimension
221
+ multi_view_features = [self.proj_embed(view_features) for view_features in multi_view_features]
222
+
223
+ # Pass through each view's cross-attention blocks
224
+ # Loop over the depth of the transformer
225
+ for depth_idx in range(self.depth):
226
+ updated_multi_view_features = []
227
+ # Loop over each view
228
+ for view_idx, view_features in enumerate(multi_view_features):
229
+ # Get all the other views
230
+ other_views_features = [multi_view_features[i] for i in range(self.num_views) if i != view_idx]
231
+ # Concatenate all the tokens from the other views
232
+ other_views_features = torch.cat(other_views_features, dim=1)
233
+ # Get the positions for the current view
234
+ view_positions = multi_view_positions[view_idx]
235
+ # Get the positions for all other views
236
+ other_views_positions = (
237
+ torch.cat([multi_view_positions[i] for i in range(self.num_views) if i != view_idx], dim=1)
238
+ if view_positions is not None
239
+ else None
240
+ )
241
+ # Apply the cross-attention block and update the multi-view features
242
+ updated_view_features = self.multi_view_branches[view_idx][depth_idx](
243
+ view_features, other_views_features, view_positions, other_views_positions
244
+ )
245
+ # Keep track of the updated view features
246
+ updated_multi_view_features.append(updated_view_features)
247
+ # Update the multi-view features for the next depth
248
+ multi_view_features = updated_multi_view_features
249
+
250
+ # Normalize the output features
251
+ output_multi_view_features = [self.norm(view_features) for view_features in multi_view_features]
252
+
253
+ # Resize the output multi-view features back to NCHW
254
+ output_multi_view_features = [
255
+ view_features.reshape(batch_size, height, width, self.dim).permute(0, 3, 1, 2).contiguous()
256
+ for view_features in output_multi_view_features
257
+ ]
258
+
259
+ return MultiViewTransformerOutput(features=output_multi_view_features)
260
+
261
+
262
+ class DifferentialMultiViewCrossAttentionTransformerIFR(
263
+ DifferentialMultiViewCrossAttentionTransformer, IntermediateFeatureReturner
264
+ ):
265
+ "Intermediate Feature Returner for UniCeption Multi-View Cross-Attention Transformer"
266
+
267
+ def __init__(
268
+ self,
269
+ name: str,
270
+ input_embed_dim: int,
271
+ num_views: int,
272
+ size: Optional[str] = None,
273
+ depth: int = 12,
274
+ dim: int = 768,
275
+ num_heads: int = 12,
276
+ mlp_ratio: float = 4.0,
277
+ qkv_bias: bool = True,
278
+ qk_norm: bool = False,
279
+ proj_drop: float = 0.0,
280
+ attn_drop: float = 0.0,
281
+ init_values: Optional[float] = None,
282
+ drop_path: float = 0.0,
283
+ act_layer: nn.Module = nn.GELU,
284
+ norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
285
+ mlp_layer: nn.Module = Mlp,
286
+ custom_positional_encoding: Callable = None,
287
+ norm_cross_tokens: bool = True,
288
+ pretrained_checkpoint_path: str = None,
289
+ indices: Optional[Union[int, List[int]]] = None,
290
+ norm_intermediate: bool = True,
291
+ intermediates_only: bool = False,
292
+ gradient_checkpointing: bool = False,
293
+ *args,
294
+ **kwargs,
295
+ ):
296
+ """
297
+ Initialize the Differential Multi-View Cross-Attention Transformer for information sharing across image features from different views.
298
+ Creates a cross-attention transformer with multiple branches for each view.
299
+ Extends the base class to return intermediate features.
300
+
301
+ Args:
302
+ input_embed_dim (int): Dimension of input embeddings.
303
+ num_views (int): Number of views (input feature sets).
304
+ depth (int): Number of transformer layers. (default: 12, base size)
305
+ dim (int): Dimension of the transformer. (default: 768, base size)
306
+ num_heads (int): Number of attention heads. (default: 12, base size)
307
+ mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.)
308
+ qkv_bias (bool): Whether to include bias in qkv projection (default: True)
309
+ qk_norm (bool): Whether to normalize q and k (default: False)
310
+ proj_drop (float): Dropout rate for output (default: 0.)
311
+ attn_drop (float): Dropout rate for attention weights (default: 0.)
312
+ init_values (float): Initial value for LayerScale gamma (default: None)
313
+ drop_path (float): Dropout rate for stochastic depth (default: 0.)
314
+ act_layer (nn.Module): Activation layer (default: nn.GELU)
315
+ norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm)
316
+ mlp_layer (nn.Module): MLP layer (default: Mlp)
317
+ custom_positional_encoding (Callable): Custom positional encoding function (default: None)
318
+ norm_cross_tokens (bool): Whether to normalize cross tokens (default: True)
319
+ pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None)
320
+ indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. (default: None) Options:
321
+ - None: Return all intermediate layers.
322
+ - int: Return the last n layers.
323
+ - List[int]: Return the intermediate layers at the specified indices.
324
+ norm_intermediate (bool, optional): Whether to normalize the intermediate features. (default: True)
325
+ intermediates_only (bool, optional): Whether to return only the intermediate features. (default: False)
326
+ gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False)
327
+ """
328
+ # Init the base classes
329
+ DifferentialMultiViewCrossAttentionTransformer.__init__(
330
+ self,
331
+ name=name,
332
+ input_embed_dim=input_embed_dim,
333
+ num_views=num_views,
334
+ size=size,
335
+ depth=depth,
336
+ dim=dim,
337
+ num_heads=num_heads,
338
+ mlp_ratio=mlp_ratio,
339
+ qkv_bias=qkv_bias,
340
+ qk_norm=qk_norm,
341
+ proj_drop=proj_drop,
342
+ attn_drop=attn_drop,
343
+ init_values=init_values,
344
+ drop_path=drop_path,
345
+ act_layer=act_layer,
346
+ norm_layer=norm_layer,
347
+ mlp_layer=mlp_layer,
348
+ custom_positional_encoding=custom_positional_encoding,
349
+ norm_cross_tokens=norm_cross_tokens,
350
+ pretrained_checkpoint_path=pretrained_checkpoint_path,
351
+ gradient_checkpointing=gradient_checkpointing,
352
+ *args,
353
+ **kwargs,
354
+ )
355
+ IntermediateFeatureReturner.__init__(
356
+ self,
357
+ indices=indices,
358
+ norm_intermediate=norm_intermediate,
359
+ intermediates_only=intermediates_only,
360
+ )
361
+
362
+ def forward(
363
+ self,
364
+ model_input: MultiViewTransformerInput,
365
+ ) -> Union[
366
+ List[MultiViewTransformerOutput],
367
+ Tuple[MultiViewTransformerOutput, List[MultiViewTransformerOutput]],
368
+ ]:
369
+ """
370
+ Forward interface for the Multi-View Cross-Attention Transformer with Intermediate Feature Return.
371
+
372
+ Args:
373
+ model_input (MultiViewTransformerInput): Input to the model.
374
+ Expects the features to be a list of size (batch, input_embed_dim, height, width),
375
+ where each entry corresponds to a different view.
376
+
377
+ Returns:
378
+ Union[List[MultiViewTransformerOutput], Tuple[MultiViewTransformerOutput, List[MultiViewTransformerOutput]]]:
379
+ Output of the model post information sharing.
380
+ If intermediates_only is True, returns a list of intermediate outputs.
381
+ If intermediates_only is False, returns a tuple of final output and a list of intermediate outputs.
382
+ """
383
+ # Check that the number of views matches the input and the features are of expected shape
384
+ assert (
385
+ len(model_input.features) == self.num_views
386
+ ), f"Expected {self.num_views} views, got {len(model_input.features)}"
387
+ assert all(
388
+ view_features.shape[1] == self.input_embed_dim for view_features in model_input.features
389
+ ), f"All views must have input dimension {self.input_embed_dim}"
390
+ assert all(
391
+ view_features.ndim == 4 for view_features in model_input.features
392
+ ), "All views must have 4 dimensions (N, C, H, W)"
393
+
394
+ # Get the indices of the intermediate features to return
395
+ intermediate_multi_view_features = []
396
+ take_indices, _ = feature_take_indices(self.depth, self.indices)
397
+
398
+ # Initialize the multi-view features from the model input
399
+ multi_view_features = model_input.features
400
+
401
+ # Resize the multi-view features from NCHW to NLC
402
+ batch_size, _, height, width = multi_view_features[0].shape
403
+ multi_view_features = [
404
+ view_features.permute(0, 2, 3, 1).reshape(batch_size, height * width, self.input_embed_dim).contiguous()
405
+ for view_features in multi_view_features
406
+ ]
407
+
408
+ # Create patch positions for each view if custom positional encoding is used
409
+ if self.custom_positional_encoding is not None:
410
+ multi_view_positions = [
411
+ self.position_getter(batch_size, height, width, view_features.device)
412
+ for view_features in multi_view_features
413
+ ]
414
+ else:
415
+ multi_view_positions = [None] * self.num_views
416
+
417
+ # Project input features to the transformer dimension
418
+ multi_view_features = [self.proj_embed(view_features) for view_features in multi_view_features]
419
+
420
+ # Pass through each view's cross-attention blocks
421
+ # Loop over the depth of the transformer
422
+ for depth_idx in range(self.depth):
423
+ updated_multi_view_features = []
424
+ # Loop over each view
425
+ for view_idx, view_features in enumerate(multi_view_features):
426
+ # Get all the other views
427
+ other_views_features = [multi_view_features[i] for i in range(self.num_views) if i != view_idx]
428
+ # Concatenate all the tokens from the other views
429
+ other_views_features = torch.cat(other_views_features, dim=1)
430
+ # Get the positions for the current view
431
+ view_positions = multi_view_positions[view_idx]
432
+ # Get the positions for all other views
433
+ other_views_positions = (
434
+ torch.cat([multi_view_positions[i] for i in range(self.num_views) if i != view_idx], dim=1)
435
+ if view_positions is not None
436
+ else None
437
+ )
438
+ # Apply the cross-attention block and update the multi-view features
439
+ updated_view_features = self.multi_view_branches[view_idx][depth_idx](
440
+ view_features, other_views_features, view_positions, other_views_positions
441
+ )
442
+ # Keep track of the updated view features
443
+ updated_multi_view_features.append(updated_view_features)
444
+ # Update the multi-view features for the next depth
445
+ multi_view_features = updated_multi_view_features
446
+ # Append the intermediate features if required
447
+ if depth_idx in take_indices:
448
+ # Normalize the intermediate features with final norm layer if enabled
449
+ intermediate_multi_view_features.append(
450
+ [self.norm(view_features) for view_features in multi_view_features]
451
+ if self.norm_intermediate
452
+ else multi_view_features
453
+ )
454
+
455
+ # Reshape the intermediate features and convert to MultiViewTransformerOutput class
456
+ for idx in range(len(intermediate_multi_view_features)):
457
+ intermediate_multi_view_features[idx] = [
458
+ view_features.reshape(batch_size, height, width, self.dim).permute(0, 3, 1, 2).contiguous()
459
+ for view_features in intermediate_multi_view_features[idx]
460
+ ]
461
+ intermediate_multi_view_features[idx] = MultiViewTransformerOutput(
462
+ features=intermediate_multi_view_features[idx]
463
+ )
464
+
465
+ # Return only the intermediate features if enabled
466
+ if self.intermediates_only:
467
+ return intermediate_multi_view_features
468
+
469
+ # Normalize the output features
470
+ output_multi_view_features = [self.norm(view_features) for view_features in multi_view_features]
471
+
472
+ # Resize the output multi-view features back to NCHW
473
+ output_multi_view_features = [
474
+ view_features.reshape(batch_size, height, width, self.dim).permute(0, 3, 1, 2).contiguous()
475
+ for view_features in output_multi_view_features
476
+ ]
477
+
478
+ output_multi_view_features = MultiViewTransformerOutput(features=output_multi_view_features)
479
+
480
+ return output_multi_view_features, intermediate_multi_view_features
481
+
482
+
483
+ def dummy_positional_encoding(x, xpos):
484
+ "Dummy function for positional encoding of tokens"
485
+ x = x
486
+ xpos = xpos
487
+ return x
488
+
489
+
490
+ if __name__ == "__main__":
491
+ # Init multi-view cross-attention transformer with no custom positional encoding and run a forward pass
492
+ for num_views in [2, 3, 4]:
493
+ print(f"Testing MultiViewCrossAttentionTransformer with {num_views} views ...")
494
+ model = DifferentialMultiViewCrossAttentionTransformer(
495
+ name="MV-DCAT", input_embed_dim=1024, num_views=num_views
496
+ )
497
+ model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
498
+ model_input = MultiViewTransformerInput(features=model_input)
499
+ model_output = model(model_input)
500
+ assert len(model_output.features) == num_views
501
+ assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features)
502
+
503
+ # Init multi-view cross-attention transformer with custom positional encoding and run a forward pass
504
+ for num_views in [2, 3, 4]:
505
+ print(
506
+ f"Testing Differential MultiViewCrossAttentionTransformer with {num_views} views and custom positional encoding ..."
507
+ )
508
+ model = DifferentialMultiViewCrossAttentionTransformer(
509
+ name="MV-DCAT",
510
+ input_embed_dim=1024,
511
+ num_views=num_views,
512
+ custom_positional_encoding=dummy_positional_encoding,
513
+ )
514
+ model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
515
+ model_input = MultiViewTransformerInput(features=model_input)
516
+ model_output = model(model_input)
517
+ assert len(model_output.features) == num_views
518
+ assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features)
519
+
520
+ print("All multi-view cross-attention transformers initialized and tested successfully!")
521
+
522
+ # Intermediate Feature Returner Tests
523
+ print("Running Intermediate Feature Returner Tests ...")
524
+
525
+ # Run the intermediate feature returner with last-n index
526
+ model_intermediate_feature_returner = DifferentialMultiViewCrossAttentionTransformerIFR(
527
+ name="MV-DCAT-IFR",
528
+ input_embed_dim=1024,
529
+ num_views=2,
530
+ indices=6, # Last 6 layers
531
+ )
532
+ model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
533
+ model_input = MultiViewTransformerInput(features=model_input)
534
+ output = model_intermediate_feature_returner(model_input)
535
+ assert isinstance(output, tuple)
536
+ assert isinstance(output[0], MultiViewTransformerOutput)
537
+ assert len(output[1]) == 6
538
+ assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1])
539
+ assert len(output[1][0].features) == 2
540
+
541
+ # Run the intermediate feature returner with specific indices
542
+ model_intermediate_feature_returner = DifferentialMultiViewCrossAttentionTransformerIFR(
543
+ name="MV-DCAT-IFR",
544
+ input_embed_dim=1024,
545
+ num_views=2,
546
+ indices=[0, 2, 4, 6], # Specific indices
547
+ )
548
+ model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
549
+ model_input = MultiViewTransformerInput(features=model_input)
550
+ output = model_intermediate_feature_returner(model_input)
551
+ assert isinstance(output, tuple)
552
+ assert isinstance(output[0], MultiViewTransformerOutput)
553
+ assert len(output[1]) == 4
554
+ assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1])
555
+ assert len(output[1][0].features) == 2
556
+
557
+ # Test the normalizing of intermediate features
558
+ model_intermediate_feature_returner = DifferentialMultiViewCrossAttentionTransformerIFR(
559
+ name="MV-DCAT-IFR",
560
+ input_embed_dim=1024,
561
+ num_views=2,
562
+ indices=[-1], # Last layer
563
+ norm_intermediate=False, # Disable normalization
564
+ )
565
+ model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
566
+ model_input = MultiViewTransformerInput(features=model_input)
567
+ output = model_intermediate_feature_returner(model_input)
568
+ for view_idx in range(2):
569
+ assert not torch.equal(
570
+ output[0].features[view_idx], output[1][-1].features[view_idx]
571
+ ), "Final features and intermediate features (last layer) must be different."
572
+
573
+ model_intermediate_feature_returner = DifferentialMultiViewCrossAttentionTransformerIFR(
574
+ name="MV-DCAT-IFR",
575
+ input_embed_dim=1024,
576
+ num_views=2,
577
+ indices=[-1], # Last layer
578
+ norm_intermediate=True,
579
+ )
580
+ model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
581
+ model_input = MultiViewTransformerInput(features=model_input)
582
+ output = model_intermediate_feature_returner(model_input)
583
+ for view_idx in range(2):
584
+ assert torch.equal(
585
+ output[0].features[view_idx], output[1][-1].features[view_idx]
586
+ ), "Final features and intermediate features (last layer) must be same."
587
+
588
+ print("All Intermediate Feature Returner Tests passed!")
UniCeption/uniception/models/info_sharing/global_attention_transformer.py ADDED
@@ -0,0 +1,1107 @@
1
+ """
2
+ UniCeption Global-Attention Transformer for Information Sharing
3
+ """
4
+
5
+ from functools import partial
6
+ from typing import Callable, List, Optional, Tuple, Type, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from uniception.models.info_sharing.base import (
13
+ MultiSetTransformerInput,
14
+ MultiSetTransformerOutput,
15
+ MultiViewTransformerInput,
16
+ MultiViewTransformerOutput,
17
+ UniCeptionInfoSharingBase,
18
+ )
19
+ from uniception.models.libs.croco.pos_embed import RoPE2D
20
+ from uniception.models.utils.intermediate_feature_return import IntermediateFeatureReturner, feature_take_indices
21
+ from uniception.models.utils.positional_encoding import PositionGetter
22
+ from uniception.models.utils.transformer_blocks import Mlp, SelfAttentionBlock
23
+
24
+
25
+ class MultiViewGlobalAttentionTransformer(UniCeptionInfoSharingBase):
26
+ "UniCeption Multi-View Global-Attention Transformer for information sharing across image features from different views."
27
+
28
+ def __init__(
29
+ self,
30
+ name: str,
31
+ input_embed_dim: int,
32
+ max_num_views: int,
33
+ use_rand_idx_pe_for_non_reference_views: bool,
34
+ size: Optional[str] = None,
35
+ depth: int = 12,
36
+ dim: int = 768,
37
+ num_heads: int = 12,
38
+ mlp_ratio: float = 4.0,
39
+ qkv_bias: bool = True,
40
+ qk_norm: bool = False,
41
+ proj_drop: float = 0.0,
42
+ attn_drop: float = 0.0,
43
+ init_values: Optional[float] = None,
44
+ drop_path: float = 0.0,
45
+ act_layer: Type[nn.Module] = nn.GELU,
46
+ norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial(nn.LayerNorm, eps=1e-6),
47
+ mlp_layer: Type[nn.Module] = Mlp,
48
+ custom_positional_encoding: Optional[Union[str, Callable]] = None,
49
+ pretrained_checkpoint_path: Optional[str] = None,
50
+ gradient_checkpointing: bool = False,
51
+ *args,
52
+ **kwargs,
53
+ ):
54
+ """
55
+ Initialize the Multi-View Global-Attention Transformer for information sharing across image features from different views.
56
+
57
+ Args:
58
+ input_embed_dim (int): Dimension of input embeddings.
59
+ max_num_views (int): Maximum number of views for positional encoding.
60
+ use_rand_idx_pe_for_non_reference_views (bool): Whether to use random index positional encoding for non-reference views.
61
+ size (str): String to indicate interpretable size of the transformer (e.g., base, large, ...). (default: None)
62
+ depth (int): Number of transformer layers. (default: 12, base size)
63
+ dim (int): Dimension of the transformer. (default: 768, base size)
64
+ num_heads (int): Number of attention heads. (default: 12, base size)
65
+ mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.)
66
+ qkv_bias (bool): Whether to include bias in qkv projection (default: True)
67
+ qk_norm (bool): Whether to normalize q and k (default: False)
68
+ proj_drop (float): Dropout rate for output (default: 0.)
69
+ attn_drop (float): Dropout rate for attention weights (default: 0.)
70
+ init_values (float): Initial value for LayerScale gamma (default: None)
71
+ drop_path (float): Dropout rate for stochastic depth (default: 0.)
72
+ act_layer (nn.Module): Activation layer (default: nn.GELU)
73
+ norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm)
74
+ mlp_layer (nn.Module): MLP layer (default: Mlp)
75
+ custom_positional_encoding (Callable): Custom positional encoding function (default: None)
76
+ pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None)
77
+ gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False)
78
+ """
79
+ # Initialize the base class
80
+ super().__init__(name=name, size=size, *args, **kwargs)
81
+
82
+ # Initialize the specific attributes of the transformer
83
+ self.input_embed_dim = input_embed_dim
84
+ self.max_num_views = max_num_views
85
+ self.use_rand_idx_pe_for_non_reference_views = use_rand_idx_pe_for_non_reference_views
86
+ self.depth = depth
87
+ self.dim = dim
88
+ self.num_heads = num_heads
89
+ self.mlp_ratio = mlp_ratio
90
+ self.qkv_bias = qkv_bias
91
+ self.qk_norm = qk_norm
92
+ self.proj_drop = proj_drop
93
+ self.attn_drop = attn_drop
94
+ self.init_values = init_values
95
+ self.drop_path = drop_path
96
+ self.act_layer = act_layer
97
+ self.norm_layer = norm_layer
98
+ self.mlp_layer = mlp_layer
99
+ self.custom_positional_encoding = custom_positional_encoding
100
+ self.pretrained_checkpoint_path = pretrained_checkpoint_path
101
+ self.gradient_checkpointing = gradient_checkpointing
102
+
103
+ # Initialize the projection layer for input embeddings
104
+ if self.input_embed_dim != self.dim:
105
+ self.proj_embed = nn.Linear(self.input_embed_dim, self.dim, bias=True)
106
+ else:
107
+ self.proj_embed = nn.Identity()
108
+
109
+ # Initialize custom position encodings
110
+ if self.custom_positional_encoding is not None and isinstance(self.custom_positional_encoding, str):
111
+ if self.custom_positional_encoding == "rope":
112
+ self.rope = RoPE2D(freq=100.0, F0=1.0)
113
+ self.custom_positional_encoding = self.rope
114
+ else:
115
+ raise ValueError(f"Unknown custom positional encoding: {self.custom_positional_encoding}")
116
+
117
+ # Initialize the self-attention blocks which ingest all views at once
118
+ self.self_attention_blocks = nn.ModuleList(
119
+ [
120
+ SelfAttentionBlock(
121
+ dim=self.dim,
122
+ num_heads=self.num_heads,
123
+ mlp_ratio=self.mlp_ratio,
124
+ qkv_bias=self.qkv_bias,
125
+ qk_norm=self.qk_norm,
126
+ proj_drop=self.proj_drop,
127
+ attn_drop=self.attn_drop,
128
+ init_values=self.init_values,
129
+ drop_path=self.drop_path,
130
+ act_layer=self.act_layer,
131
+ norm_layer=self.norm_layer,
132
+ mlp_layer=self.mlp_layer,
133
+ custom_positional_encoding=self.custom_positional_encoding,
134
+ )
135
+ for _ in range(self.depth)
136
+ ]
137
+ )
138
+
139
+ # Initialize the final normalization layer
140
+ self.norm = self.norm_layer(self.dim)
141
+
142
+ # Initialize the position getter for patch positions if required
143
+ if self.custom_positional_encoding is not None:
144
+ self.position_getter = PositionGetter()
145
+
146
+ # Initialize the positional encoding table for the different views
147
+ self.register_buffer(
148
+ "view_pos_table",
149
+ self._get_sinusoid_encoding_table(self.max_num_views, self.dim, 10000),
150
+ )
151
+
152
+ # Initialize random weights
153
+ self.initialize_weights()
154
+
155
+ # Load pretrained weights if provided
156
+ if self.pretrained_checkpoint_path is not None:
157
+ print(
158
+ f"Loading pretrained multi-view global-attention transformer weights from {self.pretrained_checkpoint_path} ..."
159
+ )
160
+ ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
161
+ print(self.load_state_dict(ckpt["model"]))
162
+
163
+ # Apply gradient checkpointing if enabled
164
+ if self.gradient_checkpointing:
165
+ for i, block in enumerate(self.self_attention_blocks):
166
+ self.self_attention_blocks[i] = self.wrap_module_with_gradient_checkpointing(block)
167
+
168
+ def _get_sinusoid_encoding_table(self, n_position, d_hid, base):
169
+ "Sinusoid position encoding table"
170
+
171
+ def get_position_angle_vec(position):
172
+ return [position / np.power(base, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
173
+
174
+ sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
175
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])
176
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])
177
+
178
+ return torch.FloatTensor(sinusoid_table)
179
+
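
The table built above is the standard sinusoidal encoding, indexed by view rather than by token position: row v holds sin(v / base^(2i/d)) in the even dimensions and cos(v / base^(2i/d)) in the odd ones. A standalone sketch of the same construction, shown only to make the formula explicit:

```python
import numpy as np
import torch


def sinusoid_table(n_position: int, d_hid: int, base: float = 10000.0) -> torch.Tensor:
    # One row per view index; mirrors _get_sinusoid_encoding_table above.
    angles = np.array(
        [[pos / np.power(base, 2 * (j // 2) / d_hid) for j in range(d_hid)] for pos in range(n_position)]
    )
    angles[:, 0::2] = np.sin(angles[:, 0::2])  # even dims -> sin
    angles[:, 1::2] = np.cos(angles[:, 1::2])  # odd dims  -> cos
    return torch.FloatTensor(angles)


table = sinusoid_table(n_position=4, d_hid=8)
print(table.shape)  # torch.Size([4, 8]); row v is broadcast over every token of view v
```
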
180
+ def initialize_weights(self):
181
+ "Initialize weights of the transformer."
182
+ # Linears and layer norms
183
+ self.apply(self._init_weights)
184
+
185
+ def _init_weights(self, m):
186
+ "Initialize the transformer linear and layer norm weights."
187
+ if isinstance(m, nn.Linear):
188
+ # We use xavier_uniform following official JAX ViT:
189
+ torch.nn.init.xavier_uniform_(m.weight)
190
+ if isinstance(m, nn.Linear) and m.bias is not None:
191
+ nn.init.constant_(m.bias, 0)
192
+ elif isinstance(m, nn.LayerNorm):
193
+ nn.init.constant_(m.bias, 0)
194
+ nn.init.constant_(m.weight, 1.0)
195
+
196
+ def forward(
197
+ self,
198
+ model_input: MultiViewTransformerInput,
199
+ ) -> MultiViewTransformerOutput:
200
+ """
201
+ Forward interface for the Multi-View Global-Attention Transformer.
202
+
203
+ Args:
204
+ model_input (MultiViewTransformerInput): Input to the model.
205
+ Expects the features to be a list of size (batch, input_embed_dim, height, width),
206
+ where each entry corresponds to a different view.
207
+ Optionally, the input can also include additional_input_tokens (e.g., class token, registers, pose tokens, scale token)
208
+ which are appended to the token set from the multi-view features. The tokens are of size (batch, input_embed_dim, num_of_additional_tokens).
209
+
210
+ Returns:
211
+ MultiViewTransformerOutput: Output of the model post information sharing.
212
+ """
213
+ # Check that the number of views matches the input and the features are of expected shape
214
+ assert (
215
+ len(model_input.features) <= self.max_num_views
216
+ ), f"Expected less than {self.max_num_views} views, got {len(model_input.features)}"
217
+ assert all(
218
+ curr_view_features.shape[1] == self.input_embed_dim for curr_view_features in model_input.features
219
+ ), f"All views must have input dimension {self.input_embed_dim}"
220
+ assert all(
221
+ curr_view_features.ndim == 4 for curr_view_features in model_input.features
222
+ ), "All views must have 4 dimensions (N, C, H, W)"
223
+
224
+ # Initialize the multi-view features from the model input and number of views for current input
225
+ multi_view_features = model_input.features
226
+ num_of_views = len(multi_view_features)
227
+ batch_size, _, height, width = multi_view_features[0].shape
228
+ num_of_tokens_per_view = height * width
229
+
230
+ # Stack the multi-view features (N, C, H, W) to (N, V, C, H, W) (assumes all V views have same shape)
231
+ multi_view_features = torch.stack(multi_view_features, dim=1)
232
+
233
+ # Resize the multi-view features from NVCHW to NLC, where L = V * H * W
234
+ multi_view_features = multi_view_features.permute(0, 1, 3, 4, 2) # (N, V, H, W, C)
235
+ multi_view_features = multi_view_features.reshape(
236
+ batch_size, num_of_views * height * width, self.input_embed_dim
237
+ ).contiguous()
238
+
239
+ # Process additional input tokens if provided
240
+ if model_input.additional_input_tokens is not None:
241
+ additional_tokens = model_input.additional_input_tokens
242
+ assert additional_tokens.ndim == 3, "Additional tokens must have 3 dimensions (N, C, T)"
243
+ assert (
244
+ additional_tokens.shape[1] == self.input_embed_dim
245
+ ), f"Additional tokens must have input dimension {self.input_embed_dim}"
246
+ assert additional_tokens.shape[0] == batch_size, "Batch size mismatch for additional tokens"
247
+
248
+ # Reshape to channel-last format for transformer processing
249
+ additional_tokens = additional_tokens.permute(0, 2, 1).contiguous() # (N, C, T) -> (N, T, C)
250
+
251
+ # Concatenate the additional tokens to the multi-view features
252
+ multi_view_features = torch.cat([multi_view_features, additional_tokens], dim=1)
253
+
254
+ # Project input features to the transformer dimension
255
+ multi_view_features = self.proj_embed(multi_view_features)
256
+
257
+ # Create patch positions for each view if custom positional encoding is used
258
+ if self.custom_positional_encoding is not None:
259
+ multi_view_positions = [
260
+ self.position_getter(batch_size, height, width, multi_view_features.device)
261
+ ] * num_of_views # List of length V, where each tensor is (N, H * W, C)
262
+ multi_view_positions = torch.cat(multi_view_positions, dim=1) # (N, V * H * W, C)
263
+ else:
264
+ multi_view_positions = [None] * num_of_views
265
+
266
+ # Add None positions for additional tokens if they exist
267
+ if model_input.additional_input_tokens is not None:
268
+ additional_tokens_positions = [None] * model_input.additional_input_tokens.shape[2]  # one (empty) position per additional token
269
+ multi_view_positions = multi_view_positions + additional_tokens_positions
270
+
271
+ # Add positional encoding for reference view (idx 0)
272
+ ref_view_pe = self.view_pos_table[0].clone().detach()
273
+ ref_view_pe = ref_view_pe.reshape((1, 1, self.dim))
274
+ ref_view_pe = ref_view_pe.repeat(batch_size, num_of_tokens_per_view, 1)
275
+ ref_view_features = multi_view_features[:, :num_of_tokens_per_view, :]
276
+ ref_view_features = ref_view_features + ref_view_pe
277
+
278
+ # Add positional encoding for non-reference views (sequential indices starting from idx 1 or random indices which are uniformly sampled)
279
+ if self.use_rand_idx_pe_for_non_reference_views:
280
+ non_ref_view_pe_indices = torch.randint(low=1, high=self.max_num_views, size=(num_of_views - 1,))
281
+ else:
282
+ non_ref_view_pe_indices = torch.arange(1, num_of_views)
283
+ non_ref_view_pe = self.view_pos_table[non_ref_view_pe_indices].clone().detach()
284
+ non_ref_view_pe = non_ref_view_pe.reshape((1, num_of_views - 1, self.dim))
285
+ non_ref_view_pe = non_ref_view_pe.repeat_interleave(num_of_tokens_per_view, dim=1)
286
+ non_ref_view_pe = non_ref_view_pe.repeat(batch_size, 1, 1)
287
+ non_ref_view_features = multi_view_features[
288
+ :, num_of_tokens_per_view : num_of_views * num_of_tokens_per_view, :
289
+ ]
290
+ non_ref_view_features = non_ref_view_features + non_ref_view_pe
291
+
292
+ # Concatenate the reference and non-reference view features
293
+ # Handle additional tokens (no view-based positional encoding for them)
294
+ if model_input.additional_input_tokens is not None:
295
+ additional_features = multi_view_features[:, num_of_views * num_of_tokens_per_view :, :]
296
+ multi_view_features = torch.cat([ref_view_features, non_ref_view_features, additional_features], dim=1)
297
+ else:
298
+ multi_view_features = torch.cat([ref_view_features, non_ref_view_features], dim=1)
299
+
300
+ # Loop over the depth of the transformer
301
+ for depth_idx in range(self.depth):
302
+ # Apply the self-attention block and update the multi-view features
303
+ multi_view_features = self.self_attention_blocks[depth_idx](multi_view_features, multi_view_positions)
304
+
305
+ # Normalize the output features
306
+ output_multi_view_features = self.norm(multi_view_features)
307
+
308
+ # Extract only the view features (excluding additional tokens)
309
+ view_features = output_multi_view_features[:, : num_of_views * num_of_tokens_per_view, :]
310
+
311
+ # Reshape the output multi-view features (N, V * H * W, C) back to (N, V, C, H, W)
312
+ view_features = view_features.reshape(batch_size, num_of_views, height, width, self.dim) # (N, V, H, W, C)
313
+ view_features = view_features.permute(0, 1, 4, 2, 3).contiguous() # (N, V, C, H, W)
314
+
315
+ # Split the output multi-view features into separate views
316
+ view_features = view_features.split(1, dim=1)
317
+ view_features = [output_view_features.squeeze(dim=1) for output_view_features in view_features]
318
+
319
+ # Extract and return additional token features if provided
320
+ if model_input.additional_input_tokens is not None:
321
+ additional_token_features = output_multi_view_features[:, num_of_views * num_of_tokens_per_view :, :]
322
+ additional_token_features = additional_token_features.permute(0, 2, 1).contiguous() # (N, C, T)
323
+ return MultiViewTransformerOutput(
324
+ features=view_features, additional_token_features=additional_token_features
325
+ )
326
+ else:
327
+ return MultiViewTransformerOutput(features=view_features)
328
+
329
+
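
An illustrative forward pass (not part of the commit) exercising the additional-token path documented above; the (N, C, T) layout follows the assertions in `forward`, and the module path is assumed from the file location.

```python
import torch

from uniception.models.info_sharing.base import MultiViewTransformerInput
from uniception.models.info_sharing.global_attention_transformer import MultiViewGlobalAttentionTransformer

model = MultiViewGlobalAttentionTransformer(
    name="MV-GAT",
    input_embed_dim=1024,
    max_num_views=16,
    use_rand_idx_pe_for_non_reference_views=False,
)
views = [torch.rand(2, 1024, 14, 14) for _ in range(3)]  # 3 views of NCHW features
extra = torch.rand(2, 1024, 5)                           # e.g. 1 class + 4 register tokens, (N, C, T)
out = model(MultiViewTransformerInput(features=views, additional_input_tokens=extra))
print(len(out.features), out.features[0].shape)          # 3, (2, model.dim, 14, 14)
print(out.additional_token_features.shape)               # (2, model.dim, 5)
```
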
330
+ class MultiViewGlobalAttentionTransformerIFR(MultiViewGlobalAttentionTransformer, IntermediateFeatureReturner):
331
+ "Intermediate Feature Returner for UniCeption Multi-View Global-Attention Transformer"
332
+
333
+ def __init__(
334
+ self,
335
+ name: str,
336
+ input_embed_dim: int,
337
+ max_num_views: int,
338
+ use_rand_idx_pe_for_non_reference_views: bool,
339
+ size: Optional[str] = None,
340
+ depth: int = 12,
341
+ dim: int = 768,
342
+ num_heads: int = 12,
343
+ mlp_ratio: float = 4.0,
344
+ qkv_bias: bool = True,
345
+ qk_norm: bool = False,
346
+ proj_drop: float = 0.0,
347
+ attn_drop: float = 0.0,
348
+ init_values: Optional[float] = None,
349
+ drop_path: float = 0.0,
350
+ act_layer: nn.Module = nn.GELU,
351
+ norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
352
+ mlp_layer: nn.Module = Mlp,
353
+ custom_positional_encoding: Optional[Callable] = None,
354
+ pretrained_checkpoint_path: Optional[str] = None,
355
+ indices: Optional[Union[int, List[int]]] = None,
356
+ norm_intermediate: bool = True,
357
+ intermediates_only: bool = False,
358
+ gradient_checkpointing: bool = False,
359
+ *args,
360
+ **kwargs,
361
+ ):
362
+ """
363
+ Initialize the Multi-View Global-Attention Transformer for information sharing across image features from different views.
364
+ Extends the base class to return intermediate features.
365
+
366
+ Args:
367
+ input_embed_dim (int): Dimension of input embeddings.
368
+ max_num_views (int): Maximum number of views for positional encoding.
369
+ use_rand_idx_pe_for_non_reference_views (bool): Whether to use random index positional encoding for non-reference views.
370
+ size (str): String to indicate interpretable size of the transformer (e.g., base, large, ...). (default: None)
371
+ depth (int): Number of transformer layers. (default: 12, base size)
372
+ dim (int): Dimension of the transformer. (default: 768, base size)
373
+ num_heads (int): Number of attention heads. (default: 12, base size)
374
+ mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.)
375
+ qkv_bias (bool): Whether to include bias in qkv projection (default: True)
376
+ qk_norm (bool): Whether to normalize q and k (default: False)
377
+ proj_drop (float): Dropout rate for output (default: 0.)
378
+ attn_drop (float): Dropout rate for attention weights (default: 0.)
379
+ init_values (float): Initial value for LayerScale gamma (default: None)
380
+ drop_path (float): Dropout rate for stochastic depth (default: 0.)
381
+ act_layer (nn.Module): Activation layer (default: nn.GELU)
382
+ norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm)
383
+ mlp_layer (nn.Module): MLP layer (default: Mlp)
384
+ custom_positional_encoding (Callable): Custom positional encoding function (default: None)
385
+ pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None)
386
+ indices (Optional[Union[int, List[int]]], optional): Indices of the layers to return. (default: None) Options:
387
+ - None: Return all intermediate layers.
388
+ - int: Return the last n layers.
389
+ - List[int]: Return the intermediate layers at the specified indices.
390
+ norm_intermediate (bool, optional): Whether to normalize the intermediate features. (default: True)
391
+ intermediates_only (bool, optional): Whether to return only the intermediate features. (default: False)
392
+ gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False)
393
+ """
394
+ # Init the base classes
395
+ MultiViewGlobalAttentionTransformer.__init__(
396
+ self,
397
+ name=name,
398
+ input_embed_dim=input_embed_dim,
399
+ max_num_views=max_num_views,
400
+ use_rand_idx_pe_for_non_reference_views=use_rand_idx_pe_for_non_reference_views,
401
+ size=size,
402
+ depth=depth,
403
+ dim=dim,
404
+ num_heads=num_heads,
405
+ mlp_ratio=mlp_ratio,
406
+ qkv_bias=qkv_bias,
407
+ qk_norm=qk_norm,
408
+ proj_drop=proj_drop,
409
+ attn_drop=attn_drop,
410
+ init_values=init_values,
411
+ drop_path=drop_path,
412
+ act_layer=act_layer,
413
+ norm_layer=norm_layer,
414
+ mlp_layer=mlp_layer,
415
+ custom_positional_encoding=custom_positional_encoding,
416
+ pretrained_checkpoint_path=pretrained_checkpoint_path,
417
+ gradient_checkpointing=gradient_checkpointing,
418
+ *args,
419
+ **kwargs,
420
+ )
421
+ IntermediateFeatureReturner.__init__(
422
+ self,
423
+ indices=indices,
424
+ norm_intermediate=norm_intermediate,
425
+ intermediates_only=intermediates_only,
426
+ )
427
+
428
+ def forward(
429
+ self,
430
+ model_input: MultiViewTransformerInput,
431
+ ) -> Union[
432
+ List[MultiViewTransformerOutput],
433
+ Tuple[MultiViewTransformerOutput, List[MultiViewTransformerOutput]],
434
+ ]:
435
+ """
436
+ Forward interface for the Multi-View Global-Attention Transformer with Intermediate Feature Return.
437
+
438
+ Args:
439
+ model_input (MultiViewTransformerInput): Input to the model.
440
+ Expects the features to be a list of size (batch, input_embed_dim, height, width),
441
+ where each entry corresponds to a different view.
442
+ Optionally, the input can also include additional_input_tokens (e.g., class token, registers, pose tokens, scale token)
443
+ which are appended to the token set from the multi-view features. The tokens are of size (batch, input_embed_dim, num_of_additional_tokens).
444
+
445
+ Returns:
446
+ Union[List[MultiViewTransformerOutput], Tuple[MultiViewTransformerOutput, List[MultiViewTransformerOutput]]]:
447
+ Output of the model post information sharing.
448
+ If intermediates_only is True, returns a list of intermediate outputs.
449
+ If intermediates_only is False, returns a tuple of final output and a list of intermediate outputs.
450
+ """
451
+ # Check that the number of views matches the input and the features are of expected shape
452
+ assert (
453
+ len(model_input.features) <= self.max_num_views
454
+ ), f"Expected {self.num_views} views, got {len(model_input.features)}"
455
+ assert all(
456
+ curr_view_features.shape[1] == self.input_embed_dim for curr_view_features in model_input.features
457
+ ), f"All views must have input dimension {self.input_embed_dim}"
458
+ assert all(
459
+ curr_view_features.ndim == 4 for curr_view_features in model_input.features
460
+ ), "All views must have 4 dimensions (N, C, H, W)"
461
+
462
+ # Get the indices of the intermediate features to return
463
+ intermediate_multi_view_features = []
464
+ take_indices, _ = feature_take_indices(self.depth, self.indices)
465
+
466
+ # Initialize the multi-view features from the model input and number of views for current input
467
+ multi_view_features = model_input.features
468
+ num_of_views = len(multi_view_features)
469
+ batch_size, _, height, width = multi_view_features[0].shape
470
+ num_of_tokens_per_view = height * width
471
+
472
+ # Stack the multi-view features (N, C, H, W) to (N, V, C, H, W) (assumes all V views have same shape)
473
+ multi_view_features = torch.stack(multi_view_features, dim=1)
474
+
475
+ # Resize the multi-view features from NVCHW to NLC, where L = V * H * W
476
+ multi_view_features = multi_view_features.permute(0, 1, 3, 4, 2) # (N, V, H, W, C)
477
+ multi_view_features = multi_view_features.reshape(
478
+ batch_size, num_of_views * height * width, self.input_embed_dim
479
+ ).contiguous()
480
+
481
+ # Process additional input tokens if provided
482
+ if model_input.additional_input_tokens is not None:
483
+ additional_tokens = model_input.additional_input_tokens
484
+ assert additional_tokens.ndim == 3, "Additional tokens must have 3 dimensions (N, C, T)"
485
+ assert (
486
+ additional_tokens.shape[1] == self.input_embed_dim
487
+ ), f"Additional tokens must have input dimension {self.input_embed_dim}"
488
+ assert additional_tokens.shape[0] == batch_size, "Batch size mismatch for additional tokens"
489
+
490
+ # Reshape to channel-last format for transformer processing
491
+ additional_tokens = additional_tokens.permute(0, 2, 1).contiguous() # (N, C, T) -> (N, T, C)
492
+
493
+ # Concatenate the additional tokens to the multi-view features
494
+ multi_view_features = torch.cat([multi_view_features, additional_tokens], dim=1)
495
+
496
+ # Project input features to the transformer dimension
497
+ multi_view_features = self.proj_embed(multi_view_features)
498
+
499
+ # Create patch positions for each view if custom positional encoding is used
500
+ if self.custom_positional_encoding is not None:
501
+ multi_view_positions = [
502
+ self.position_getter(batch_size, height, width, multi_view_features.device)
503
+ ] * num_of_views # List of length V, where each tensor is (N, H * W, C)
504
+ multi_view_positions = torch.cat(multi_view_positions, dim=1) # (N, V * H * W, C)
505
+ else:
506
+ multi_view_positions = [None] * num_of_views
507
+
508
+ # Add None positions for additional tokens if they exist
509
+ if model_input.additional_input_tokens is not None:
510
+ additional_tokens_positions = [None] * model_input.additional_input_tokens.shape[2]  # one (empty) position per additional token
511
+ multi_view_positions = multi_view_positions + additional_tokens_positions
512
+
513
+ # Add positional encoding for reference view (idx 0)
514
+ ref_view_pe = self.view_pos_table[0].clone().detach()
515
+ ref_view_pe = ref_view_pe.reshape((1, 1, self.dim))
516
+ ref_view_pe = ref_view_pe.repeat(batch_size, num_of_tokens_per_view, 1)
517
+ ref_view_features = multi_view_features[:, :num_of_tokens_per_view, :]
518
+ ref_view_features = ref_view_features + ref_view_pe
519
+
520
+ # Add positional encoding for non-reference views (sequential indices starting from idx 1 or random indices which are uniformly sampled)
521
+ if self.use_rand_idx_pe_for_non_reference_views:
522
+ non_ref_view_pe_indices = torch.randint(low=1, high=self.max_num_views, size=(num_of_views - 1,))
523
+ else:
524
+ non_ref_view_pe_indices = torch.arange(1, num_of_views)
525
+ non_ref_view_pe = self.view_pos_table[non_ref_view_pe_indices].clone().detach()
526
+ non_ref_view_pe = non_ref_view_pe.reshape((1, num_of_views - 1, self.dim))
527
+ non_ref_view_pe = non_ref_view_pe.repeat_interleave(num_of_tokens_per_view, dim=1)
528
+ non_ref_view_pe = non_ref_view_pe.repeat(batch_size, 1, 1)
529
+ non_ref_view_features = multi_view_features[
530
+ :, num_of_tokens_per_view : num_of_views * num_of_tokens_per_view, :
531
+ ]
532
+ non_ref_view_features = non_ref_view_features + non_ref_view_pe
533
+
534
+ # Concatenate the reference and non-reference view features
535
+ # Handle additional tokens (no view-based positional encoding for them)
536
+ if model_input.additional_input_tokens is not None:
537
+ additional_features = multi_view_features[:, num_of_views * num_of_tokens_per_view :, :]
538
+ multi_view_features = torch.cat([ref_view_features, non_ref_view_features, additional_features], dim=1)
539
+ else:
540
+ multi_view_features = torch.cat([ref_view_features, non_ref_view_features], dim=1)
541
+
542
+ # Loop over the depth of the transformer
543
+ for depth_idx in range(self.depth):
544
+ # Apply the self-attention block and update the multi-view features
545
+ multi_view_features = self.self_attention_blocks[depth_idx](multi_view_features, multi_view_positions)
546
+ if depth_idx in take_indices:
547
+ # Normalize the intermediate features with final norm layer if enabled
548
+ intermediate_multi_view_features.append(
549
+ self.norm(multi_view_features) if self.norm_intermediate else multi_view_features
550
+ )
551
+
552
+ # Reshape the intermediate features and convert to MultiViewTransformerOutput class
553
+ for idx in range(len(intermediate_multi_view_features)):
554
+ # Get the current intermediate features
555
+ current_features = intermediate_multi_view_features[idx]
556
+
557
+ # Extract additional token features if provided
558
+ additional_token_features = None
559
+ if model_input.additional_input_tokens is not None:
560
+ additional_token_features = current_features[:, num_of_views * num_of_tokens_per_view :, :]
561
+ additional_token_features = additional_token_features.permute(0, 2, 1).contiguous() # (N, C, T)
562
+ # Only keep the view features for reshaping
563
+ current_features = current_features[:, : num_of_views * num_of_tokens_per_view, :]
564
+
565
+ # Reshape the intermediate multi-view features (N, V * H * W, C) back to (N, V, C, H, W)
566
+ current_features = current_features.reshape(
567
+ batch_size, num_of_views, height, width, self.dim
568
+ ) # (N, V, H, W, C)
569
+ current_features = current_features.permute(0, 1, 4, 2, 3).contiguous() # (N, V, C, H, W)
570
+
571
+ # Split the intermediate multi-view features into separate views
572
+ current_features = current_features.split(1, dim=1)
573
+ current_features = [
574
+ intermediate_view_features.squeeze(dim=1) for intermediate_view_features in current_features
575
+ ]
576
+
577
+ intermediate_multi_view_features[idx] = MultiViewTransformerOutput(
578
+ features=current_features, additional_token_features=additional_token_features
579
+ )
580
+
581
+ # Return only the intermediate features if enabled
582
+ if self.intermediates_only:
583
+ return intermediate_multi_view_features
584
+
585
+ # Normalize the output features
586
+ output_multi_view_features = self.norm(multi_view_features)
587
+
588
+ # Extract view features (excluding additional tokens)
589
+ additional_token_features = None
590
+ if model_input.additional_input_tokens is not None:
591
+ additional_token_features = output_multi_view_features[:, num_of_views * num_of_tokens_per_view :, :]
592
+ additional_token_features = additional_token_features.permute(0, 2, 1).contiguous() # (N, C, T)
593
+ view_features = output_multi_view_features[:, : num_of_views * num_of_tokens_per_view, :]
594
+ else:
595
+ view_features = output_multi_view_features
596
+
597
+ # Reshape the output multi-view features (N, V * H * W, C) back to (N, V, C, H, W)
598
+ view_features = view_features.reshape(batch_size, num_of_views, height, width, self.dim) # (N, V, H, W, C)
599
+ view_features = view_features.permute(0, 1, 4, 2, 3).contiguous() # (N, V, C, H, W)
600
+
601
+ # Split the output multi-view features into separate views
602
+ view_features = view_features.split(1, dim=1)
603
+ view_features = [output_view_features.squeeze(dim=1) for output_view_features in view_features]
604
+
605
+ output_multi_view_features = MultiViewTransformerOutput(
606
+ features=view_features, additional_token_features=additional_token_features
607
+ )
608
+
609
+ return output_multi_view_features, intermediate_multi_view_features
610
+
611
+
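
As with the differential variant earlier in this commit, the IFR subclass returns `(final_output, intermediates)` unless `intermediates_only=True`. A brief sketch of the tuple form under assumed arguments; `indices=2` means the last two layers:

```python
import torch

from uniception.models.info_sharing.base import MultiViewTransformerInput
from uniception.models.info_sharing.global_attention_transformer import MultiViewGlobalAttentionTransformerIFR

model = MultiViewGlobalAttentionTransformerIFR(
    name="MV-GAT-IFR",
    input_embed_dim=1024,
    max_num_views=16,
    use_rand_idx_pe_for_non_reference_views=False,
    indices=2,               # keep the last two layers
    norm_intermediate=True,  # intermediates pass through the final LayerNorm
)
views = MultiViewTransformerInput(features=[torch.rand(1, 1024, 14, 14) for _ in range(2)])
final, taps = model(views)
print(len(taps), final.features[0].shape)  # 2, (1, model.dim, 14, 14)
# With norm_intermediate=True, the last tap matches the final features (cf. the tests below).
print(torch.equal(final.features[0], taps[-1].features[0]))  # True
```
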
612
+ class GlobalAttentionTransformer(UniCeptionInfoSharingBase):
613
+ "UniCeption Global-Attention Transformer for information sharing across different set of features."
614
+
615
+ def __init__(
616
+ self,
617
+ name: str,
618
+ input_embed_dim: int,
619
+ max_num_sets: int,
620
+ use_rand_idx_pe_for_non_reference_sets: bool,
621
+ size: Optional[str] = None,
622
+ depth: int = 12,
623
+ dim: int = 768,
624
+ num_heads: int = 12,
625
+ mlp_ratio: float = 4.0,
626
+ qkv_bias: bool = True,
627
+ qk_norm: bool = False,
628
+ proj_drop: float = 0.0,
629
+ attn_drop: float = 0.0,
630
+ init_values: Optional[float] = None,
631
+ drop_path: float = 0.0,
632
+ act_layer: Type[nn.Module] = nn.GELU,
633
+ norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial(nn.LayerNorm, eps=1e-6),
634
+ mlp_layer: Type[nn.Module] = Mlp,
635
+ pretrained_checkpoint_path: Optional[str] = None,
636
+ gradient_checkpointing: bool = False,
637
+ *args,
638
+ **kwargs,
639
+ ):
640
+ """
641
+ Initialize the Global-Attention Transformer for information sharing across features from different sets.
642
+
643
+ Args:
644
+ input_embed_dim (int): Dimension of input embeddings.
645
+ max_num_sets (int): Maximum number of sets for positional encoding.
646
+ use_rand_idx_pe_for_non_reference_sets (bool): Whether to use random index positional encoding for non-reference sets.
647
+ size (str): String to indicate interpretable size of the transformer (e.g., base, large, ...). (default: None)
648
+ depth (int): Number of transformer layers. (default: 12, base size)
649
+ dim (int): Dimension of the transformer. (default: 768, base size)
650
+ num_heads (int): Number of attention heads. (default: 12, base size)
651
+ mlp_ratio (float): Ratio of hidden to input dimension in MLP (default: 4.)
652
+ qkv_bias (bool): Whether to include bias in qkv projection (default: True)
653
+ qk_norm (bool): Whether to normalize q and k (default: False)
654
+ proj_drop (float): Dropout rate for output (default: 0.)
655
+ attn_drop (float): Dropout rate for attention weights (default: 0.)
656
+ init_values (float): Initial value for LayerScale gamma (default: None)
657
+ drop_path (float): Dropout rate for stochastic depth (default: 0.)
658
+ act_layer (nn.Module): Activation layer (default: nn.GELU)
659
+ norm_layer (nn.Module): Normalization layer (default: nn.LayerNorm)
660
+ mlp_layer (nn.Module): MLP layer (default: Mlp)
661
+ pretrained_checkpoint_path (str, optional): Path to the pretrained checkpoint. (default: None)
662
+ gradient_checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. (default: False)
663
+ """
664
+ # Initialize the base class
665
+ super().__init__(name=name, size=size, *args, **kwargs)
666
+
667
+ # Initialize the specific attributes of the transformer
668
+ self.input_embed_dim = input_embed_dim
669
+ self.max_num_sets = max_num_sets
670
+ self.use_rand_idx_pe_for_non_reference_sets = use_rand_idx_pe_for_non_reference_sets
671
+ self.depth = depth
672
+ self.dim = dim
673
+ self.num_heads = num_heads
674
+ self.mlp_ratio = mlp_ratio
675
+ self.qkv_bias = qkv_bias
676
+ self.qk_norm = qk_norm
677
+ self.proj_drop = proj_drop
678
+ self.attn_drop = attn_drop
679
+ self.init_values = init_values
680
+ self.drop_path = drop_path
681
+ self.act_layer = act_layer
682
+ self.norm_layer = norm_layer
683
+ self.mlp_layer = mlp_layer
684
+ self.pretrained_checkpoint_path = pretrained_checkpoint_path
685
+ self.gradient_checkpointing = gradient_checkpointing
686
+
687
+ # Initialize the projection layer for input embeddings
688
+ if self.input_embed_dim != self.dim:
689
+ self.proj_embed = nn.Linear(self.input_embed_dim, self.dim, bias=True)
690
+ else:
691
+ self.proj_embed = nn.Identity()
692
+
693
+ # Initialize the self-attention blocks which ingest all sets at once
694
+ self.self_attention_blocks = nn.ModuleList(
695
+ [
696
+ SelfAttentionBlock(
697
+ dim=self.dim,
698
+ num_heads=self.num_heads,
699
+ mlp_ratio=self.mlp_ratio,
700
+ qkv_bias=self.qkv_bias,
701
+ qk_norm=self.qk_norm,
702
+ proj_drop=self.proj_drop,
703
+ attn_drop=self.attn_drop,
704
+ init_values=self.init_values,
705
+ drop_path=self.drop_path,
706
+ act_layer=self.act_layer,
707
+ norm_layer=self.norm_layer,
708
+ mlp_layer=self.mlp_layer,
709
+ )
710
+ for _ in range(self.depth)
711
+ ]
712
+ )
713
+
714
+ # Initialize the final normalization layer
715
+ self.norm = self.norm_layer(self.dim)
716
+
717
+ # Initialize the positional encoding table for the different sets
718
+ self.register_buffer(
719
+ "set_pos_table",
720
+ self._get_sinusoid_encoding_table(self.max_num_sets, self.dim, 10000),
721
+ )
722
+
723
+ # Initialize random weights
724
+ self.initialize_weights()
725
+
726
+ # Load pretrained weights if provided
727
+ if self.pretrained_checkpoint_path is not None:
728
+ print(f"Loading pretrained global-attention transformer weights from {self.pretrained_checkpoint_path} ...")
729
+ ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
730
+ print(self.load_state_dict(ckpt["model"]))
731
+
732
+ # Apply gradient checkpointing if enabled
733
+ if self.gradient_checkpointing:
734
+ for i, block in enumerate(self.self_attention_blocks):
735
+ self.self_attention_blocks[i] = self.wrap_module_with_gradient_checkpointing(block)
736
+
737
+ def _get_sinusoid_encoding_table(self, n_position, d_hid, base):
738
+ "Sinusoid position encoding table"
739
+
740
+ def get_position_angle_vec(position):
741
+ return [position / np.power(base, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
742
+
743
+ sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
744
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])
745
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])
746
+
747
+ return torch.FloatTensor(sinusoid_table)
748
+
749
+ def initialize_weights(self):
750
+ "Initialize weights of the transformer."
751
+ # Linears and layer norms
752
+ self.apply(self._init_weights)
753
+
754
+ def _init_weights(self, m):
755
+ "Initialize the transformer linear and layer norm weights."
756
+ if isinstance(m, nn.Linear):
757
+ # We use xavier_uniform following official JAX ViT:
758
+ torch.nn.init.xavier_uniform_(m.weight)
759
+ if isinstance(m, nn.Linear) and m.bias is not None:
760
+ nn.init.constant_(m.bias, 0)
761
+ elif isinstance(m, nn.LayerNorm):
762
+ nn.init.constant_(m.bias, 0)
763
+ nn.init.constant_(m.weight, 1.0)
764
+
765
+ def forward(
766
+ self,
767
+ model_input: MultiSetTransformerInput,
768
+ ) -> MultiSetTransformerOutput:
769
+ """
770
+ Forward interface for the Multi-Set Global-Attention Transformer.
771
+
772
+ Args:
773
+ model_input (MultiSetTransformerInput): Input to the model.
774
+ Expects the features to be a list of size (batch, input_embed_dim, num_tokens),
775
+ where each entry corresponds to a different set of tokens and
776
+ the number of tokens can be different for each set.
777
+ Optionally, the input can also include additional_input_tokens (e.g., class token, registers, pose tokens, scale token)
778
+ which are appended to the token set from the multi-view features. The tokens are of size (batch, input_embed_dim, num_of_additional_tokens).
779
+
780
+ Returns:
781
+ MultiSetTransformerOutput: Output of the model post information sharing.
782
+ """
783
+ # Check that the number of sets matches the input and the features are of expected shape
784
+ assert (
785
+ len(model_input.features) <= self.max_num_sets
786
+ ), f"Expected less than {self.max_num_sets} sets, got {len(model_input.features)}"
787
+ assert all(
788
+ set_features.shape[1] == self.input_embed_dim for set_features in model_input.features
789
+ ), f"All sets must have input dimension {self.input_embed_dim}"
790
+ assert all(
791
+ set_features.ndim == 3 for set_features in model_input.features
792
+ ), "All sets must have 3 dimensions (N, C, T)"
793
+
794
+ # Initialize the multi-set features from the model input and number of sets for current input
795
+ multi_set_features = model_input.features
796
+ num_of_sets = len(multi_set_features)
797
+ batch_size, _, _ = multi_set_features[0].shape
798
+ num_of_tokens_per_set = [set_features.shape[2] for set_features in multi_set_features]
799
+
800
+ # Permute the multi-set features from (N, C, T) to (N, T, C)
801
+ multi_set_features = [set_features.permute(0, 2, 1).contiguous() for set_features in multi_set_features]
802
+
803
+ # Stack the multi-set features along the number of tokens dimension
804
+ multi_set_features = torch.cat(multi_set_features, dim=1)
805
+
806
+ # Process additional input tokens if provided
807
+ if model_input.additional_input_tokens is not None:
808
+ additional_tokens = model_input.additional_input_tokens
809
+ assert additional_tokens.ndim == 3, "Additional tokens must have 3 dimensions (N, C, T)"
810
+ assert (
811
+ additional_tokens.shape[1] == self.input_embed_dim
812
+ ), f"Additional tokens must have input dimension {self.input_embed_dim}"
813
+ assert additional_tokens.shape[0] == batch_size, "Batch size mismatch for additional tokens"
814
+
815
+ # Reshape to channel-last format for transformer processing
816
+ additional_tokens = additional_tokens.permute(0, 2, 1).contiguous() # (N, C, T) -> (N, T, C)
817
+
818
+ # Concatenate the additional tokens to the multi-set features
819
+ multi_set_features = torch.cat([multi_set_features, additional_tokens], dim=1)
820
+
821
+ # Project input features to the transformer dimension
822
+ multi_set_features = self.proj_embed(multi_set_features)
823
+
824
+ # Create dummy patch positions for each set
825
+ multi_set_positions = [None] * num_of_sets
826
+
827
+ # Add positional encoding for reference set (idx 0)
828
+ ref_set_pe = self.set_pos_table[0].clone().detach()
829
+ ref_set_pe = ref_set_pe.reshape((1, 1, self.dim))
830
+ ref_set_pe = ref_set_pe.repeat(batch_size, num_of_tokens_per_set[0], 1)
831
+ ref_set_features = multi_set_features[:, : num_of_tokens_per_set[0], :]
832
+ ref_set_features = ref_set_features + ref_set_pe
833
+
834
+ # Add positional encoding for non-reference sets (sequential indices starting from idx 1 or random indices which are uniformly sampled)
835
+ if self.use_rand_idx_pe_for_non_reference_sets:
836
+ non_ref_set_pe_indices = torch.randint(low=1, high=self.max_num_sets, size=(num_of_sets - 1,))
837
+ else:
838
+ non_ref_set_pe_indices = torch.arange(1, num_of_sets)
839
+ non_ref_set_pe_list = []
840
+ for non_ref_set_idx in range(1, num_of_sets):
841
+ non_ref_set_pe_for_idx = self.set_pos_table[non_ref_set_pe_indices[non_ref_set_idx - 1]].clone().detach()
842
+ non_ref_set_pe_for_idx = non_ref_set_pe_for_idx.reshape((1, 1, self.dim))
843
+ non_ref_set_pe_for_idx = non_ref_set_pe_for_idx.repeat(
844
+ batch_size, num_of_tokens_per_set[non_ref_set_idx], 1
845
+ )
846
+ non_ref_set_pe_list.append(non_ref_set_pe_for_idx)
847
+ non_ref_set_pe = torch.cat(non_ref_set_pe_list, dim=1)
848
+ non_ref_set_features = multi_set_features[:, num_of_tokens_per_set[0] : sum(num_of_tokens_per_set), :]
849
+ non_ref_set_features = non_ref_set_features + non_ref_set_pe
850
+
851
+ # Concatenate the reference and non-reference set features
852
+ # Handle additional tokens (no set-based positional encoding for them)
853
+ if model_input.additional_input_tokens is not None:
854
+ additional_features = multi_set_features[:, sum(num_of_tokens_per_set) :, :]
855
+ multi_set_features = torch.cat([ref_set_features, non_ref_set_features, additional_features], dim=1)
856
+ else:
857
+ multi_set_features = torch.cat([ref_set_features, non_ref_set_features], dim=1)
858
+
859
+ # Add None positions for additional tokens if they exist
860
+ if model_input.additional_input_tokens is not None:
861
+ additional_tokens_positions = [None] * model_input.additional_input_tokens.shape[2]
862
+ multi_set_positions = multi_set_positions + additional_tokens_positions
863
+
864
+ # Loop over the depth of the transformer
865
+ for depth_idx in range(self.depth):
866
+ # Apply the self-attention block and update the multi-set features
867
+ multi_set_features = self.self_attention_blocks[depth_idx](multi_set_features, multi_set_positions)
868
+
869
+ # Normalize the output features
870
+ output_multi_set_features = self.norm(multi_set_features)
871
+
872
+ # Extract additional token features if provided
873
+ additional_token_features = None
874
+ if model_input.additional_input_tokens is not None:
875
+ additional_token_features = output_multi_set_features[:, sum(num_of_tokens_per_set) :, :]
876
+ additional_token_features = additional_token_features.permute(
877
+ 0, 2, 1
878
+ ).contiguous() # (N, T, C) -> (N, C, T)
879
+ # Only keep the set features for reshaping
880
+ output_multi_set_features = output_multi_set_features[:, : sum(num_of_tokens_per_set), :]
881
+
882
+ # Reshape the output multi-set features from (N, T, C) to (N, C, T)
883
+ output_multi_set_features = output_multi_set_features.permute(0, 2, 1).contiguous()
884
+
885
+ # Split the output multi-set features into separate sets using the list of number of tokens per set
886
+ output_multi_set_features = torch.split(output_multi_set_features, num_of_tokens_per_set, dim=2)
887
+
888
+ # Return the output multi-set features with additional token features if provided
889
+ return MultiSetTransformerOutput(
890
+ features=output_multi_set_features, additional_token_features=additional_token_features
891
+ )
892
+
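
A sketch of the multi-set interface above with different token counts per set plus extra tokens; shapes follow the (N, C, T) convention asserted in `forward`, and the module path is assumed from the file location.

```python
import torch

from uniception.models.info_sharing.base import MultiSetTransformerInput
from uniception.models.info_sharing.global_attention_transformer import GlobalAttentionTransformer

model = GlobalAttentionTransformer(
    name="GAT",
    input_embed_dim=512,
    max_num_sets=4,
    use_rand_idx_pe_for_non_reference_sets=False,
)
sets = [torch.rand(1, 512, 196), torch.rand(1, 512, 64)]  # two sets, different token counts
extra = torch.rand(1, 512, 3)                             # e.g. pose + scale tokens
out = model(MultiSetTransformerInput(features=sets, additional_input_tokens=extra))
print([f.shape for f in out.features])                    # [(1, model.dim, 196), (1, model.dim, 64)]
print(out.additional_token_features.shape)                # (1, model.dim, 3)
```
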
893
+
894
+ def dummy_positional_encoding(x, xpos):
895
+ "Dummy function for positional encoding of tokens"
896
+ x = x
897
+ xpos = xpos
898
+ return x
899
+
900
+
901
+ if __name__ == "__main__":
902
+ # Init multi-view global-attention transformer with no custom positional encoding and run a forward pass
903
+ for num_views in [2, 3, 4]:
904
+ print(f"Testing MultiViewGlobalAttentionTransformer with {num_views} views ...")
905
+ # Sequential idx based positional encoding
906
+ model = MultiViewGlobalAttentionTransformer(
907
+ name="MV-GAT", input_embed_dim=1024, max_num_views=1000, use_rand_idx_pe_for_non_reference_views=False
908
+ )
909
+ model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
910
+ model_input = MultiViewTransformerInput(features=model_input)
911
+ model_output = model(model_input)
912
+ assert len(model_output.features) == num_views
913
+ assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features)
914
+ # Random idx based positional encoding
915
+ model = MultiViewGlobalAttentionTransformer(
916
+ name="MV-GAT", input_embed_dim=1024, max_num_views=1000, use_rand_idx_pe_for_non_reference_views=True
917
+ )
918
+ model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
919
+ model_input = MultiViewTransformerInput(features=model_input)
920
+ model_output = model(model_input)
921
+ assert len(model_output.features) == num_views
922
+ assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features)
923
+
924
+ # Init multi-view global-attention transformer with custom positional encoding and run a forward pass
925
+ for num_views in [2, 3, 4]:
926
+ print(f"Testing MultiViewGlobalAttentionTransformer with {num_views} views and custom positional encoding ...")
927
+ model = MultiViewGlobalAttentionTransformer(
928
+ name="MV-GAT",
929
+ input_embed_dim=1024,
930
+ max_num_views=1000,
931
+ use_rand_idx_pe_for_non_reference_views=True,
932
+ custom_positional_encoding=dummy_positional_encoding,
933
+ )
934
+ model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
935
+ model_input = MultiViewTransformerInput(features=model_input)
936
+ model_output = model(model_input)
937
+ assert len(model_output.features) == num_views
938
+ assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features)
939
+
940
+ print("All multi-view global-attention transformers initialized and tested successfully!")
941
+
+    # Intermediate Feature Returner Tests
+    print("Running Intermediate Feature Returner Tests ...")
+
+    # Run the intermediate feature returner with last-n index
+    model_intermediate_feature_returner = MultiViewGlobalAttentionTransformerIFR(
+        name="MV-GAT-IFR",
+        input_embed_dim=1024,
+        max_num_views=1000,
+        use_rand_idx_pe_for_non_reference_views=True,
+        indices=6,  # Last 6 layers
+    )
+    model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
+    model_input = MultiViewTransformerInput(features=model_input)
+    output = model_intermediate_feature_returner(model_input)
+    assert isinstance(output, tuple)
+    assert isinstance(output[0], MultiViewTransformerOutput)
+    assert len(output[1]) == 6
+    assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1])
+    assert len(output[1][0].features) == 2
+
+    # Run the intermediate feature returner with specific indices
+    model_intermediate_feature_returner = MultiViewGlobalAttentionTransformerIFR(
+        name="MV-GAT-IFR",
+        input_embed_dim=1024,
+        max_num_views=1000,
+        use_rand_idx_pe_for_non_reference_views=True,
+        indices=[0, 2, 4, 6],  # Specific indices
+    )
+    model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
+    model_input = MultiViewTransformerInput(features=model_input)
+    output = model_intermediate_feature_returner(model_input)
+    assert isinstance(output, tuple)
+    assert isinstance(output[0], MultiViewTransformerOutput)
+    assert len(output[1]) == 4
+    assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1])
+    assert len(output[1][0].features) == 2
+
+    # Test the normalizing of intermediate features
+    model_intermediate_feature_returner = MultiViewGlobalAttentionTransformerIFR(
+        name="MV-GAT-IFR",
+        input_embed_dim=1024,
+        max_num_views=1000,
+        use_rand_idx_pe_for_non_reference_views=True,
+        indices=[-1],  # Last layer
+        norm_intermediate=False,  # Disable normalization
+    )
+    model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
+    model_input = MultiViewTransformerInput(features=model_input)
+    output = model_intermediate_feature_returner(model_input)
+    for view_idx in range(2):
+        assert not torch.equal(
+            output[0].features[view_idx], output[1][-1].features[view_idx]
+        ), "Final features and intermediate features (last layer) must be different."
+
+    model_intermediate_feature_returner = MultiViewGlobalAttentionTransformerIFR(
+        name="MV-GAT-IFR",
+        input_embed_dim=1024,
+        max_num_views=1000,
+        use_rand_idx_pe_for_non_reference_views=True,
+        indices=[-1],  # Last layer
+        norm_intermediate=True,
+    )
+    model_input = [torch.rand(1, 1024, 14, 14) for _ in range(2)]
+    model_input = MultiViewTransformerInput(features=model_input)
+    output = model_intermediate_feature_returner(model_input)
+    for view_idx in range(2):
+        assert torch.equal(
+            output[0].features[view_idx], output[1][-1].features[view_idx]
+        ), "Final features and intermediate features (last layer) must be same."
+
+    print("All Intermediate Feature Returner Tests passed!")
+
+    # Init multi-set global-attention transformer and run a forward pass with different number of sets and set token sizes
+    import random
+
+    model = GlobalAttentionTransformer(
+        name="GAT", input_embed_dim=1024, max_num_sets=3, use_rand_idx_pe_for_non_reference_sets=False
+    )
+    for num_sets in [2, 3]:
+        print(f"Testing GlobalAttentionTransformer with {num_sets} sets ...")
+        model_input = [torch.rand(1, 1024, random.randint(256, 513)) for _ in range(num_sets)]
+        model_input = MultiSetTransformerInput(features=model_input)
+        model_output = model(model_input)
+        assert len(model_output.features) == num_sets
+        for feat, rand_input in zip(model_output.features, model_input.features):
+            assert feat.shape[2] == rand_input.shape[2]
+            assert feat.shape[1] == model.dim
+            assert feat.shape[0] == rand_input.shape[0]
+    # Random idx based positional encoding
+    model = GlobalAttentionTransformer(
+        name="GAT", input_embed_dim=1024, max_num_sets=1000, use_rand_idx_pe_for_non_reference_sets=True
+    )
+    for num_sets in [2, 3, 4]:
+        print(f"Testing GlobalAttentionTransformer with {num_sets} sets ...")
+        model_input = [torch.rand(1, 1024, random.randint(256, 513)) for _ in range(num_sets)]
+        model_input = MultiSetTransformerInput(features=model_input)
+        model_output = model(model_input)
+        assert len(model_output.features) == num_sets
+        for feat, rand_input in zip(model_output.features, model_input.features):
+            assert feat.shape[2] == rand_input.shape[2]
+            assert feat.shape[1] == model.dim
+            assert feat.shape[0] == rand_input.shape[0]
+
+    print("All Global Attention Transformer Tests passed!")
+
+    # Test additional input tokens for MultiViewGlobalAttentionTransformer
+    print("Testing MultiViewGlobalAttentionTransformer with additional input tokens...")
+    model = MultiViewGlobalAttentionTransformer(
+        name="MV-GAT", input_embed_dim=1024, max_num_views=1000, use_rand_idx_pe_for_non_reference_views=False
+    )
+    num_views = 2
+    num_additional_tokens = 5
+    model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
+    additional_tokens = torch.rand(1, 1024, num_additional_tokens)
+    model_input = MultiViewTransformerInput(features=model_input, additional_input_tokens=additional_tokens)
+    model_output = model(model_input)
+    assert len(model_output.features) == num_views
+    assert all(f.shape == (1, model.dim, 14, 14) for f in model_output.features)
+    assert model_output.additional_token_features is not None
+    assert model_output.additional_token_features.shape == (1, model.dim, num_additional_tokens)
+
+    # Test additional input tokens for MultiViewGlobalAttentionTransformerIFR
+    print("Testing MultiViewGlobalAttentionTransformerIFR with additional input tokens...")
+    model_ifr = MultiViewGlobalAttentionTransformerIFR(
+        name="MV-GAT-IFR",
+        input_embed_dim=1024,
+        max_num_views=1000,
+        use_rand_idx_pe_for_non_reference_views=True,
+        indices=[0, 2, 4],
+    )
+    model_input = [torch.rand(1, 1024, 14, 14) for _ in range(num_views)]
+    additional_tokens = torch.rand(1, 1024, num_additional_tokens)
+    model_input = MultiViewTransformerInput(features=model_input, additional_input_tokens=additional_tokens)
+    output = model_ifr(model_input)
+    assert isinstance(output, tuple)
+    assert isinstance(output[0], MultiViewTransformerOutput)
+    assert output[0].additional_token_features is not None
+    assert output[0].additional_token_features.shape == (1, model_ifr.dim, num_additional_tokens)
+    assert len(output[1]) == 3
+    assert all(isinstance(intermediate, MultiViewTransformerOutput) for intermediate in output[1])
+    assert all(intermediate.additional_token_features is not None for intermediate in output[1])
+    assert all(
+        intermediate.additional_token_features.shape == (1, model_ifr.dim, num_additional_tokens)
+        for intermediate in output[1]
+    )
+
+    # Test additional input tokens for GlobalAttentionTransformer
+    print("Testing GlobalAttentionTransformer with additional input tokens...")
+    model = GlobalAttentionTransformer(
+        name="GAT", input_embed_dim=1024, max_num_sets=1000, use_rand_idx_pe_for_non_reference_sets=False
+    )
+    num_sets = 3
+    num_additional_tokens = 8
+    model_input = [torch.rand(1, 1024, random.randint(256, 513)) for _ in range(num_sets)]
+    additional_tokens = torch.rand(1, 1024, num_additional_tokens)
+    model_input = MultiSetTransformerInput(features=model_input, additional_input_tokens=additional_tokens)
+    model_output = model(model_input)
+    assert len(model_output.features) == num_sets
+    for feat, rand_input in zip(model_output.features, model_input.features):
+        assert feat.shape[2] == rand_input.shape[2]
+        assert feat.shape[1] == model.dim
+        assert feat.shape[0] == rand_input.shape[0]
+    assert model_output.additional_token_features is not None
+    assert model_output.additional_token_features.shape == (1, model.dim, num_additional_tokens)
+
+    print("All tests using additional input tokens passed!")
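# Not part of the diff above -- a short usage sketch of the tuple returned by the
# intermediate-feature-returner variants exercised in the tests: element 0 is the final
# MultiViewTransformerOutput and element 1 holds one MultiViewTransformerOutput per requested
# layer. It reuses model_ifr and model_input exactly as constructed in the tests above; the
# variable names introduced here are hypothetical.
final_output, intermediate_outputs = model_ifr(model_input)
final_view_features = final_output.features  # one (B, model_ifr.dim, H, W) tensor per view
per_layer_view0_features = [layer.features[0] for layer in intermediate_outputs]  # view 0 at indices [0, 2, 4]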
UniCeption/uniception/models/libs/__init__.py ADDED
File without changes