hengli committed on
Commit b7f83b0 · 1 Parent(s): 132427c
This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +1 -0
  2. .gitattributes copy +2 -0
  3. .gitignore +162 -0
  4. .pre-commit-config.yaml +27 -0
  5. CODE_OF_CONDUCT.md +80 -0
  6. CONTRIBUTING.md +31 -0
  7. LICENSE.txt +21 -0
  8. README copy.md +93 -0
  9. app.py +914 -0
  10. demo.py +131 -0
  11. demo_gradio.py +921 -0
  12. docs/traj_ply.png +3 -0
  13. eval/datasets/mip_360.py +115 -0
  14. eval/datasets/seven_scenes.py +58 -0
  15. eval/datasets/tnt.py +116 -0
  16. eval/datasets/tum.py +53 -0
  17. eval/readme.md +110 -0
  18. eval/utils/cropping.py +289 -0
  19. eval/utils/device.py +95 -0
  20. eval/utils/eval_pose_ransac.py +315 -0
  21. eval/utils/eval_utils.py +74 -0
  22. eval/utils/geometry.py +572 -0
  23. eval/utils/image.py +232 -0
  24. eval/utils/load_fn.py +155 -0
  25. eval/utils/misc.py +131 -0
  26. eval/utils/pose_enc.py +135 -0
  27. eval/utils/rotation.py +142 -0
  28. eval/utils/visual_track.py +244 -0
  29. pyproject.toml +58 -0
  30. requirements.txt +10 -0
  31. requirements_demo.txt +16 -0
  32. sailrecon/dependency/__init__.py +3 -0
  33. sailrecon/dependency/distortion.py +223 -0
  34. sailrecon/dependency/np_to_pycolmap.py +355 -0
  35. sailrecon/dependency/projection.py +249 -0
  36. sailrecon/dependency/track_modules/__init__.py +0 -0
  37. sailrecon/dependency/track_modules/base_track_predictor.py +210 -0
  38. sailrecon/dependency/track_modules/blocks.py +396 -0
  39. sailrecon/dependency/track_modules/modules.py +216 -0
  40. sailrecon/dependency/track_modules/track_refine.py +493 -0
  41. sailrecon/dependency/track_modules/utils.py +235 -0
  42. sailrecon/dependency/track_predict.py +349 -0
  43. sailrecon/dependency/vggsfm_tracker.py +148 -0
  44. sailrecon/dependency/vggsfm_utils.py +341 -0
  45. sailrecon/heads/camera_head.py +228 -0
  46. sailrecon/heads/dpt_head.py +598 -0
  47. sailrecon/heads/head_act.py +127 -0
  48. sailrecon/heads/track_head.py +116 -0
  49. sailrecon/heads/track_modules/__init__.py +5 -0
  50. sailrecon/heads/track_modules/base_track_predictor.py +242 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ docs/traj_ply.png filter=lfs diff=lfs merge=lfs -text
.gitattributes copy ADDED
@@ -0,0 +1,2 @@
1
+ # SCM syntax highlighting & preventing 3-way merges
2
+ pixi.lock merge=binary linguist-language=YAML linguist-generated=true
.gitignore ADDED
@@ -0,0 +1,162 @@
1
+ .hydra/
2
+ output/
3
+ ckpt/
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ **/__pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ pip-wheel-metadata/
28
+ share/python-wheels/
29
+ *.egg-info/
30
+ .installed.cfg
31
+ *.egg
32
+ MANIFEST
33
+
34
+ # PyInstaller
35
+ # Usually these files are written by a python script from a template
36
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
37
+ *.manifest
38
+ *.spec
39
+
40
+ # Installer logs
41
+ pip-log.txt
42
+ pip-delete-this-directory.txt
43
+
44
+ # Unit test / coverage reports
45
+ htmlcov/
46
+ .tox/
47
+ .nox/
48
+ .coverage
49
+ .coverage.*
50
+ .cache
51
+ nosetests.xml
52
+ coverage.xml
53
+ *.cover
54
+ *.py,cover
55
+ .hypothesis/
56
+ .pytest_cache/
57
+ cover/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+ db.sqlite3-journal
68
+
69
+ # Flask stuff:
70
+ instance/
71
+ .webassets-cache
72
+
73
+ # Scrapy stuff:
74
+ .scrapy
75
+
76
+ # Sphinx documentation
77
+ docs/_build/
78
+
79
+ # PyBuilder
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ .python-version
91
+
92
+ # pipenv
93
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
95
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
96
+ # install all needed dependencies.
97
+ #Pipfile.lock
98
+
99
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
100
+ __pypackages__/
101
+
102
+ # Celery stuff
103
+ celerybeat-schedule
104
+ celerybeat.pid
105
+
106
+ # SageMath parsed files
107
+ *.sage.py
108
+
109
+ # Environments
110
+ .env
111
+ .venv
112
+ env/
113
+ venv/
114
+ ENV/
115
+ env.bak/
116
+ venv.bak/
117
+
118
+ # Spyder project settings
119
+ .spyderproject
120
+ .spyproject
121
+
122
+ # Rope project settings
123
+ .ropeproject
124
+
125
+ # mkdocs documentation
126
+ /site
127
+
128
+ # mypy
129
+ .mypy_cache/
130
+ .dmypy.json
131
+ dmypy.json
132
+
133
+ # Pyre type checker
134
+ .pyre/
135
+
136
+ # pytype static type analyzer
137
+ .pytype/
138
+
139
+ # Profiling data
140
+ .prof
141
+
142
+ # Folder specific to your needs
143
+ **/tmp/
144
+ **/outputs/skyseg.onnx
145
+ skyseg.onnx
146
+
147
+ # pixi environments
148
+ .pixi
149
+ *.egg-info
150
+
151
+ # demo images
152
+ /docs/demo_image
153
+
154
+ # vscode settings
155
+ .vscode/
156
+
157
+
158
+ outputs/
159
+ samples/
160
+ tmp_video/
161
+ .gradio/
162
+ input_images_*
.pre-commit-config.yaml ADDED
@@ -0,0 +1,27 @@
1
+ default_language_version:
2
+ python: python3
3
+
4
+ repos:
5
+ - repo: https://github.com/pre-commit/pre-commit-hooks
6
+ rev: v4.4.0
7
+ hooks:
8
+ - id: trailing-whitespace
9
+ - id: check-ast
10
+ - id: check-merge-conflict
11
+ - id: check-yaml
12
+ - id: end-of-file-fixer
13
+ - id: trailing-whitespace
14
+ args: [--markdown-linebreak-ext=md]
15
+
16
+ - repo: https://github.com/psf/black
17
+ rev: 23.3.0
18
+ hooks:
19
+ - id: black
20
+ language_version: python3
21
+
22
+ - repo: https://github.com/pycqa/isort
23
+ rev: 5.12.0
24
+ hooks:
25
+ - id: isort
26
+ exclude: README.md
27
+ args: ["--profile", "black"]
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,80 @@
1
+ # Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to make participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies within all project spaces, and it also applies when
49
+ an individual is representing the project or its community in public spaces.
50
+ Examples of representing a project or community include using an official
51
+ project e-mail address, posting via an official social media account, or acting
52
+ as an appointed representative at an online or offline event. Representation of
53
+ a project may be further defined and clarified by project maintainers.
54
+
55
+ This Code of Conduct also applies outside the project spaces when there is a
56
+ reasonable belief that an individual's behavior may have a negative impact on
57
+ the project or its community.
58
+
59
+ ## Enforcement
60
+
61
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
+ reported by contacting the project team at <[email protected]>. All
63
+ complaints will be reviewed and investigated and will result in a response that
64
+ is deemed necessary and appropriate to the circumstances. The project team is
65
+ obligated to maintain confidentiality with regard to the reporter of an incident.
66
+ Further details of specific enforcement policies may be posted separately.
67
+
68
+ Project maintainers who do not follow or enforce the Code of Conduct in good
69
+ faith may face temporary or permanent repercussions as determined by other
70
+ members of the project's leadership.
71
+
72
+ ## Attribution
73
+
74
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76
+
77
+ [homepage]: https://www.contributor-covenant.org
78
+
79
+ For answers to common questions about this code of conduct, see
80
+ https://www.contributor-covenant.org/faq
CONTRIBUTING.md ADDED
@@ -0,0 +1,31 @@
1
+ # Contributing to vggt
2
+ We want to make contributing to this project as easy and transparent as
3
+ possible.
4
+
5
+ ## Pull Requests
6
+ We actively welcome your pull requests.
7
+
8
+ 1. Fork the repo and create your branch from `main`.
9
+ 2. If you've added code that should be tested, add tests.
10
+ 3. If you've changed APIs, update the documentation.
11
+ 4. Ensure the test suite passes.
12
+ 5. Make sure your code lints.
13
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14
+
15
+ ## Contributor License Agreement ("CLA")
16
+ In order to accept your pull request, we need you to submit a CLA. You only need
17
+ to do this once to work on any of Facebook's open source projects.
18
+
19
+ Complete your CLA here: <https://code.facebook.com/cla>
20
+
21
+ ## Issues
22
+ We use GitHub issues to track public bugs. Please ensure your description is
23
+ clear and has sufficient instructions to be able to reproduce the issue.
24
+
25
+ Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
26
+ disclosure of security bugs. In those cases, please go through the process
27
+ outlined on that page and do not file a public issue.
28
+
29
+ ## License
30
+ By contributing to vggt, you agree that your contributions will be licensed
31
+ under the LICENSE file in the root directory of this source tree.
LICENSE.txt ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 HKUST SAIL-LAB and Horizon Robotics.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README copy.md ADDED
@@ -0,0 +1,93 @@
1
+ <div align="center">
2
+ <h1>SAIL-Recon: Large SfM by Augmenting Scene Regression with Localization</h1>
3
+
4
+
5
+ <a href="https://arxiv.org/pdf/2508.17972"><img src="https://img.shields.io/badge/arXiv-2508.17972-b31b1b" alt="arXiv"></a>
6
+ <a href="https://hkust-sail.github.io/sail-recon/"><img src="https://img.shields.io/badge/Project_Page-green" alt="Project Page"></a>
7
+ <!--<a href='https://huggingface.co/'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo-blue'></a>-->
8
+
9
+
10
+ **[HKUST Spatial Artificial Intelligence Lab](https://github.com/HKUST-SAIL)**; **[Horizon Robotics](https://en.horizon.auto/)**
11
+
12
+
13
+ [Junyuan Deng](https://scholar.google.com/citations?user=KTCPC5IAAAAJ&hl=en), [Heng Li](https://hengli.me/), [Tao Xie](https://github.com/xbillowy), [Weiqiang Ren](https://cn.linkedin.com/in/weiqiang-ren-b2798636), [Qian Zhang](https://cn.linkedin.com/in/qian-zhang-10234b73), [Ping Tan](https://facultyprofiles.hkust.edu.hk/profiles.php?profile=ping-tan-pingtan), [Xiaoyang Guo](https://xy-guo.github.io/)
14
+ </div>
15
+
16
+ ![SAIL-Recon trajectory and point cloud](docs/traj_ply.png)
17
+
18
+
19
+
20
+ ## Overview
21
+
22
+ SAIL-Recon is a feed-forward Transformer that scales neural scene regression to large-scale Structure-from-Motion by augmenting it with visual localization. From a few anchor views, it constructs a global latent scene representation that encodes both geometry and appearance. Conditioned on this representation, the network directly regresses camera poses, intrinsics, depth maps, and scene coordinate maps for thousands of images in minutes, enabling precise and robust reconstruction without iterative optimization.
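
For readers who want to see how this two-stage design surfaces in the released code, the following is a condensed, illustrative sketch based on the inference flow of `demo.py` later in this commit. It is not a drop-in script: `image_paths` is a placeholder for your own list of frame paths, and CUDA with bfloat16 autocast is assumed.

```python
# Condensed sketch of the two-stage inference used in demo.py (anchors -> relocalization).
# `image_paths` is a placeholder for your own list of frame paths.
import torch

from eval.utils.device import to_cpu
from eval.utils.eval_utils import uniform_sample
from sailrecon.models.sail_recon import SailRecon
from sailrecon.utils.load_fn import load_and_preprocess_images

device = "cuda" if torch.cuda.is_available() else "cpu"
model = SailRecon(kv_cache=True).to(device).eval()
images = load_and_preprocess_images(image_paths).to(device)

# Stage 1: build the latent scene representation from up to 100 anchor views.
anchors = images[uniform_sample(len(image_paths), min(100, len(image_paths)))]

predictions = []
with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
    model.tmp_forward(anchors)  # fills the kv-cache with the scene representation
    # Stage 2: regress pose, intrinsics, depth and scene coordinates per frame chunk.
    for chunk in images.split(20, dim=0):
        predictions += to_cpu(model.reloc(chunk))
```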
23
+
24
+
25
+ ## TODO
26
+ - [x] Inference Code Release
27
+ - [x] Gradio Demo
28
+ - [ ] Evaluation Script
29
+
30
+ ## Quick Start
31
+
32
+ First, clone this repository to your local machine and install the dependencies (torch, torchvision, numpy, Pillow, and huggingface_hub), following VGGT.
33
+
34
+ ```bash
35
+ git clone https://github.com/HKUST-SAIL/sail-recon.git
36
+ cd sail-recon
37
+ pip install -e .
38
+ ```
39
+
40
+ You can download demo images (e.g., [Barn](https://drive.google.com/file/d/0B-ePgl6HF260NzQySklGdXZyQzA/view?resourcekey=0-luQ7Jaym5BQL6IjxsgXY9A) from [Tanks & Temples](https://www.tanksandtemples.org/)) and put them in `examples/demo_image`.
41
+
42
+ Now, you can try the model demo:
43
+ ```bash
44
+ # Images
45
+ python demo.py --img_dir path/to/your/images --out_dir outputs
46
+ # Video
47
+ python demo.py --vid_dir path/to/your/video --out_dir outputs
48
+ ```
49
+
50
+ You can find the PLY file and camera poses under `outputs`.
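
As a rough guide to consuming those outputs, the sketch below loads the camera poses that `demo.py` writes to `outputs/<scene>/pred.txt`, assuming the standard KITTI text convention of 12 floats per line (the top 3×4 block of each camera-to-world matrix in row-major order); the exact format is defined by `save_kitti_poses` in `eval/utils/eval_utils.py`, which is not shown in this view.

```python
# Hedged sketch: read the camera poses that demo.py saves as outputs/<scene>/pred.txt,
# assuming one KITTI-style line per frame (12 floats = top 3x4 of the 4x4 c2w matrix).
import numpy as np

def load_kitti_poses(path: str) -> np.ndarray:
    poses = []
    with open(path) as f:
        for line in f:
            vals = np.array(line.split(), dtype=float)
            if vals.size != 12:
                continue  # skip blank or malformed lines
            pose = np.eye(4)
            pose[:3, :4] = vals.reshape(3, 4)
            poses.append(pose)
    return np.stack(poses)  # (N, 4, 4) camera-to-world matrices

poses_c2w = load_kitti_poses("outputs/kitchen/pred.txt")
print(poses_c2w.shape, poses_c2w[0, :3, 3])  # number of frames, first camera position
```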
51
+
52
+ We also provide a Gradio demo for easier usage. You can run the demo by:
53
+ ```bash
54
+ python demo_gradio.py
55
+ ```
56
+ Please note that the Gradio demo is slower than `demo.py` due to the visualization part.
57
+
58
+
59
+ ## Evaluation
60
+
61
+ Please refer to [this](eval/readme.md) for more details.
62
+
63
+ ## Acknowledgements
64
+
65
+ Thanks to these great repositories:
66
+
67
+ [ACE0](https://github.com/nianticlabs/acezero) for the PSNR evaluation;
68
+
69
+ [VGGT](https://github.com/facebookresearch/vggt) for the GitHub repository template, the Gradio demo, and the visualization code;
70
+
71
+ [Fast3R](https://github.com/facebookresearch/fast3r) for the training data processing and some utility functions;
72
+
73
+ And many other inspiring works in the community.
74
+
75
+ If you find this project useful in your research, please consider citing:
76
+ ```bibtex
77
+ @article{dengli2025sail,
78
+ title={SAIL-Recon: Large SfM by Augmenting Scene Regression with Localization},
79
+ author={Deng, Junyuan and Li, Heng and Xie, Tao and Ren, Weiqiang and Zhang, Qian and Tan, Ping and Guo, Xiaoyang},
80
+ journal={arXiv preprint arXiv:2508.17972},
81
+ year={2025}
82
+ }
83
+ ```
84
+
85
+ ## License
86
+
87
+ See the [LICENSE](./LICENSE.txt) file for details about the license under which this code is made available.
88
+
89
+ Please see the license of [VGGT](https://github.com/facebookresearch/vggt) for the other code used in this project.
90
+
91
+ Please see the license of [ACE0](https://github.com/nianticlabs/acezero) for the evaluation code used in this project.
92
+
93
+ Please see the license of [Fast3R](https://github.com/facebookresearch/fast3r) for the utility functions used in this project.
app.py ADDED
@@ -0,0 +1,914 @@
1
+ import gc
2
+ import glob
3
+ import os
4
+ import shutil
5
+ import sys
6
+ import time
7
+ from datetime import datetime
8
+
9
+ import cv2
10
+ import gradio as gr
11
+ import numpy as np
12
+ import torch
13
+ from tqdm import tqdm
14
+
15
+ from eval.utils.device import to_cpu
16
+ from eval.utils.eval_utils import uniform_sample
17
+ from sailrecon.models.sail_recon import SailRecon
18
+ from sailrecon.utils.geometry import unproject_depth_map_to_point_map
19
+ from sailrecon.utils.load_fn import load_and_preprocess_images
20
+ from sailrecon.utils.pose_enc import (
21
+ extri_intri_to_pose_encoding,
22
+ pose_encoding_to_extri_intri,
23
+ )
24
+ from visual_util import predictions_to_glb
25
+
26
+ device = "cuda" if torch.cuda.is_available() else "cpu"
27
+
28
+ print("Initializing and loading SailRecon model...")
29
+
30
+ model = SailRecon(kv_cache=True)
31
+
32
+ model_dir = "ckpt/sailrecon.pt"
33
+ if os.path.exists(model_dir):
34
+ model.load_state_dict(torch.load(model_dir))
35
+ else:
36
+ _URL = "https://huggingface.co/HKUST-SAIL/SAIL-Recon/resolve/main/sailrecon.pt"
37
+ model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
38
+
39
+
40
+ model.eval()
41
+ model = model.to(device)
42
+
43
+
44
+ # -------------------------------------------------------------------------
45
+ # 1) Core model inference
46
+ # -------------------------------------------------------------------------
47
+ def run_model(target_dir, model, anchor_size=100) -> dict:
48
+ """
49
+ Run the SAIL-Recon model on images in the 'target_dir/images' folder and return predictions.
50
+ """
51
+ print(f"Processing images from {target_dir}")
52
+
53
+ # Device check
54
+ device = "cuda" if torch.cuda.is_available() else "cpu"
55
+ if not torch.cuda.is_available():
56
+ raise ValueError("CUDA is not available. Check your environment.")
57
+
58
+ # Move model to device
59
+ model = model.to(device)
60
+ model.eval()
61
+
62
+ # Load and preprocess images
63
+ image_names = glob.glob(os.path.join(target_dir, "images", "*"))
64
+ image_names = sorted(image_names)
65
+ print(f"Found {len(image_names)} images")
66
+ if len(image_names) == 0:
67
+ raise ValueError("No images found. Check your upload.")
68
+
69
+ images = load_and_preprocess_images(image_names).to(device)
70
+ print(f"Preprocessed images shape: {images.shape}")
71
+ # anchor image selection
72
+ select_indices = uniform_sample(len(image_names), min(anchor_size, len(image_names)))
73
+ anchor_images = images[select_indices]
74
+
75
+ # Run inference
76
+ print("Running inference...")
77
+ dtype = (
78
+ torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
79
+ )
80
+
81
+ with torch.no_grad():
82
+ with torch.cuda.amp.autocast(dtype=dtype):
83
+ print("Processing anchor images ...")
84
+ model.tmp_forward(anchor_images)
85
+ # del model.aggregator.global_blocks
86
+ # relocalization on all images
87
+ predictions_s = []
88
+ with tqdm(total=len(image_names), desc="Relocalizing") as pbar:
89
+ for img_split in images.split(10, dim=0):
90
+ pbar.update(10)
91
+ predictions_s += to_cpu(
92
+ model.reloc(img_split, ret_img=True, memory_save=False)
93
+ )
94
+
95
+ predictions = {}
96
+ predictions["extrinsic"] = torch.cat(
97
+ [s["extrinsic"] for s in predictions_s], dim=0
98
+ ) # (S, 4, 4)
99
+ predictions["intrinsic"] = torch.cat(
100
+ [s["intrinsic"] for s in predictions_s], dim=0
101
+ ) # (S, 4, 4)
102
+ predictions["depth"] = torch.cat(
103
+ [s["depth_map"] for s in predictions_s], dim=0
104
+ ) # (S, H, W, 1)
105
+ predictions["depth_conf"] = torch.cat(
106
+ [s["dpt_cnf"] for s in predictions_s], dim=0
107
+ ) # (S, H, W, 1)
108
+ predictions["images"] = torch.cat(
109
+ [s["images"] for s in predictions_s], dim=0
110
+ ) # (S, H, W, 3)
111
+ predictions["world_points"] = torch.cat(
112
+ [s["point_map"] for s in predictions_s], dim=0
113
+ ) # (S, H, W, 3)
114
+ predictions["world_points_conf"] = torch.cat(
115
+ [s["xyz_cnf"] for s in predictions_s], dim=0
116
+ ) # (S, H, W, 3)
117
+ predictions["pose_enc"] = extri_intri_to_pose_encoding(
118
+ predictions["extrinsic"].unsqueeze(0),
119
+ predictions["intrinsic"].unsqueeze(0),
120
+ images.shape[-2:],
121
+ )[
122
+ 0
123
+ ] # drop the batch dimension added by unsqueeze above
124
+ del predictions_s
125
+
126
+ # Convert tensors to numpy
127
+ for key in predictions.keys():
128
+ if isinstance(predictions[key], torch.Tensor):
129
+ predictions[key] = predictions[key].cpu().numpy() # move to CPU and convert to numpy
130
+ predictions["pose_enc_list"] = None # remove pose_enc_list
131
+
132
+ # Generate world points from depth map
133
+ print("Computing world points from depth map...")
134
+ depth_map = predictions["depth"] # (S, H, W, 1)
135
+ world_points = unproject_depth_map_to_point_map(
136
+ depth_map, predictions["extrinsic"], predictions["intrinsic"]
137
+ )
138
+ predictions["world_points_from_depth"] = world_points
139
+
140
+ # Clean up
141
+ torch.cuda.empty_cache()
142
+ return predictions
143
+
144
+
145
+ # -------------------------------------------------------------------------
146
+ # 2) Handle uploaded video/images --> produce target_dir + images
147
+ # -------------------------------------------------------------------------
148
+ def handle_uploads(input_video, input_images):
149
+ """
150
+ Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
151
+ images or extracted frames from video into it. Return (target_dir, image_paths).
152
+ """
153
+ start_time = time.time()
154
+ gc.collect()
155
+ torch.cuda.empty_cache()
156
+
157
+ # Create a unique folder name
158
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
159
+ target_dir = f"input_images_{timestamp}"
160
+ target_dir_images = os.path.join(target_dir, "images")
161
+
162
+ # Clean up if somehow that folder already exists
163
+ if os.path.exists(target_dir):
164
+ shutil.rmtree(target_dir)
165
+ os.makedirs(target_dir)
166
+ os.makedirs(target_dir_images)
167
+
168
+ image_paths = []
169
+
170
+ # --- Handle images ---
171
+ if input_images is not None:
172
+ for file_data in input_images:
173
+ if isinstance(file_data, dict) and "name" in file_data:
174
+ file_path = file_data["name"]
175
+ else:
176
+ file_path = file_data
177
+ dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
178
+ shutil.copy(file_path, dst_path)
179
+ image_paths.append(dst_path)
180
+
181
+ # --- Handle video ---
182
+ if input_video is not None:
183
+ if isinstance(input_video, dict) and "name" in input_video:
184
+ video_path = input_video["name"]
185
+ else:
186
+ video_path = input_video
187
+
188
+ vs = cv2.VideoCapture(video_path)
189
+ fps = vs.get(cv2.CAP_PROP_FPS)
190
+
191
+ count = 0
192
+ video_frame_num = 0
193
+ while True:
194
+ gotit, frame = vs.read()
195
+ if not gotit:
196
+ break
197
+ count += 1
198
+ image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png")
199
+ cv2.imwrite(image_path, frame)
200
+ image_paths.append(image_path)
201
+ video_frame_num += 1
202
+
203
+ # Sort final images for gallery
204
+ image_paths = sorted(image_paths)
205
+
206
+ end_time = time.time()
207
+ print(
208
+ f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds"
209
+ )
210
+ return target_dir, image_paths
211
+
212
+
213
+ # -------------------------------------------------------------------------
214
+ # 3) Update gallery on upload
215
+ # -------------------------------------------------------------------------
216
+ def update_gallery_on_upload(input_video, input_images):
217
+ """
218
+ Whenever user uploads or changes files, immediately handle them
219
+ and show in the gallery. Return (target_dir, image_paths).
220
+ If nothing is uploaded, returns "None" and empty list.
221
+ """
222
+ if not input_video and not input_images:
223
+ return None, None, None, None
224
+ target_dir, image_paths = handle_uploads(input_video, input_images)
225
+ return (
226
+ None,
227
+ target_dir,
228
+ image_paths,
229
+ "Upload complete. Click 'Reconstruct' to begin 3D processing.",
230
+ )
231
+
232
+
233
+ # -------------------------------------------------------------------------
234
+ # 4) Reconstruction: uses the target_dir plus any viz parameters
235
+ # -------------------------------------------------------------------------
236
+ def gradio_demo(
237
+ target_dir,
238
+ conf_thres=3.0,
239
+ frame_filter="All",
240
+ mask_black_bg=False,
241
+ mask_white_bg=False,
242
+ show_cam=True,
243
+ mask_sky=False,
244
+ downsample_ratio=100.0,
245
+ prediction_mode="Pointmap Regression",
246
+ ):
247
+ """
248
+ Perform reconstruction using the already-created target_dir/images.
249
+ """
250
+ if not os.path.isdir(target_dir) or target_dir == "None":
251
+ return None, "No valid target directory found. Please upload first.", None
252
+
253
+ start_time = time.time()
254
+ gc.collect()
255
+ torch.cuda.empty_cache()
256
+
257
+ # Prepare frame_filter dropdown
258
+ target_dir_images = os.path.join(target_dir, "images")
259
+ all_files = (
260
+ sorted(os.listdir(target_dir_images))
261
+ if os.path.isdir(target_dir_images)
262
+ else []
263
+ )
264
+ all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
265
+ frame_filter_choices = ["All"] + all_files
266
+
267
+ print("Running run_model...")
268
+ with torch.no_grad():
269
+ predictions = run_model(target_dir, model)
270
+
271
+ # Save predictions
272
+ prediction_save_path = os.path.join(target_dir, "predictions.npz")
273
+ np.savez(prediction_save_path, **predictions)
274
+
275
+ # Handle None frame_filter
276
+ if frame_filter is None:
277
+ frame_filter = "All"
278
+
279
+ # Build a GLB file name
280
+ glbfile = os.path.join(
281
+ target_dir,
282
+ f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
283
+ )
284
+
285
+ # Convert predictions to GLB
286
+ glbscene = predictions_to_glb(
287
+ predictions,
288
+ conf_thres=conf_thres,
289
+ filter_by_frames=frame_filter,
290
+ mask_black_bg=mask_black_bg,
291
+ mask_white_bg=mask_white_bg,
292
+ show_cam=show_cam,
293
+ mask_sky=mask_sky,
294
+ target_dir=target_dir,
295
+ downsample_ratio=downsample_ratio / 100.0,
296
+ prediction_mode=prediction_mode,
297
+ )
298
+ glbscene.export(file_obj=glbfile)
299
+
300
+ # Cleanup
301
+ del predictions
302
+ gc.collect()
303
+ torch.cuda.empty_cache()
304
+
305
+ end_time = time.time()
306
+ print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
307
+ log_msg = (
308
+ f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
309
+ )
310
+
311
+ return (
312
+ glbfile,
313
+ log_msg,
314
+ gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
315
+ )
316
+
317
+
318
+ # -------------------------------------------------------------------------
319
+ # 5) Helper functions for UI resets + re-visualization
320
+ # -------------------------------------------------------------------------
321
+ def clear_fields():
322
+ """
323
+ Clears the 3D viewer, the stored target_dir, and empties the gallery.
324
+ """
325
+ return None
326
+
327
+
328
+ def update_log():
329
+ """
330
+ Display a quick log message while waiting.
331
+ """
332
+ return "Loading and Reconstructing..."
333
+
334
+
335
+ def update_visualization(
336
+ target_dir,
337
+ conf_thres,
338
+ frame_filter,
339
+ mask_black_bg,
340
+ mask_white_bg,
341
+ show_cam,
342
+ mask_sky,
343
+ downsample_ratio,
344
+ prediction_mode,
345
+ is_example,
346
+ ):
347
+ """
348
+ Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
349
+ and return it for the 3D viewer. If is_example == "True", skip.
350
+ """
351
+
352
+ # If it's an example click, skip as requested
353
+ if is_example == "True":
354
+ return (
355
+ None,
356
+ "No reconstruction available. Please click the Reconstruct button first.",
357
+ )
358
+
359
+ if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
360
+ return (
361
+ None,
362
+ "No reconstruction available. Please click the Reconstruct button first.",
363
+ )
364
+
365
+ predictions_path = os.path.join(target_dir, "predictions.npz")
366
+ if not os.path.exists(predictions_path):
367
+ return (
368
+ None,
369
+ f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.",
370
+ )
371
+
372
+ key_list = [
373
+ "pose_enc",
374
+ "depth",
375
+ "depth_conf",
376
+ "world_points",
377
+ "world_points_conf",
378
+ "images",
379
+ "extrinsic",
380
+ "intrinsic",
381
+ "world_points_from_depth",
382
+ ]
383
+
384
+ loaded = np.load(predictions_path)
385
+ predictions = {key: np.array(loaded[key]) for key in key_list if key in loaded}
386
+ print(downsample_ratio)
387
+ glbfile = os.path.join(
388
+ target_dir,
389
+ f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_dr{downsample_ratio}_pred{prediction_mode.replace(' ', '_')}.glb",
390
+ )
391
+
392
+ if not os.path.exists(glbfile):
393
+ glbscene = predictions_to_glb(
394
+ predictions,
395
+ conf_thres=conf_thres,
396
+ filter_by_frames=frame_filter,
397
+ mask_black_bg=mask_black_bg,
398
+ mask_white_bg=mask_white_bg,
399
+ show_cam=show_cam,
400
+ mask_sky=mask_sky,
401
+ target_dir=target_dir,
402
+ downsample_ratio=downsample_ratio * 1.0 / 100.0,
403
+ prediction_mode=prediction_mode,
404
+ )
405
+ glbscene.export(file_obj=glbfile)
406
+
407
+ return glbfile, "Updating Visualization"
408
+
409
+
410
+ # -------------------------------------------------------------------------
411
+ # Example images
412
+ # -------------------------------------------------------------------------
413
+
414
+ great_wall_video = "examples/videos/great_wall.mp4"
415
+ colosseum_video = "examples/videos/Colosseum.mp4"
416
+ room_video = "examples/videos/room.mp4"
417
+ kitchen_video = "examples/videos/kitchen.mp4"
418
+ fern_video = "examples/videos/fern.mp4"
419
+ single_cartoon_video = "examples/videos/single_cartoon.mp4"
420
+ single_oil_painting_video = "examples/videos/single_oil_painting.mp4"
421
+ pyramid_video = "examples/videos/pyramid.mp4"
422
+
423
+
424
+ # -------------------------------------------------------------------------
425
+ # 6) Build Gradio UI
426
+ # -------------------------------------------------------------------------
427
+ theme = gr.themes.Ocean()
428
+ theme.set(
429
+ checkbox_label_background_fill_selected="*button_primary_background_fill",
430
+ checkbox_label_text_color_selected="*button_primary_text_color",
431
+ )
432
+
433
+ with gr.Blocks(
434
+ theme=theme,
435
+ css="""
436
+ .custom-log * {
437
+ font-style: italic;
438
+ font-size: 22px !important;
439
+ background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
440
+ -webkit-background-clip: text;
441
+ background-clip: text;
442
+ font-weight: bold !important;
443
+ color: transparent !important;
444
+ text-align: center !important;
445
+ }
446
+
447
+ .example-log * {
448
+ font-style: italic;
449
+ font-size: 16px !important;
450
+ background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
451
+ -webkit-background-clip: text;
452
+ background-clip: text;
453
+ color: transparent !important;
454
+ }
455
+
456
+ #my_radio .wrap {
457
+ display: flex;
458
+ flex-wrap: nowrap;
459
+ justify-content: center;
460
+ align-items: center;
461
+ }
462
+
463
+ #my_radio .wrap label {
464
+ display: flex;
465
+ width: 50%;
466
+ justify-content: center;
467
+ align-items: center;
468
+ margin: 0;
469
+ padding: 10px 0;
470
+ box-sizing: border-box;
471
+ }
472
+ """,
473
+ ) as demo:
474
+ # Instead of gr.State, we use a hidden Textbox:
475
+ is_example = gr.Textbox(label="is_example", visible=False, value="None")
476
+ num_images = gr.Textbox(label="num_images", visible=False, value="None")
477
+
478
+ gr.HTML(
479
+ """
480
+ <h1>🏛️ SAIL-Recon: Large SfM by Augmenting Scene Regression with Localization</h1>
481
+ <p>
482
+ <a href="https://github.com/HKUST-SAIL/sail-recon">🐙 GitHub Repository</a> |
483
+ <a href="https://hkust-sail.github.io/sail-recon/">Project Page</a>
484
+ </p>
485
+
486
+ <div style="font-size: 16px; line-height: 1.5;">
487
+ <p>Upload a video or a set of images to create a 3D reconstruction of a scene or object. SAIL-Recon takes these images and generates a 3D point cloud, along with estimated camera poses.</p>
488
+
489
+ <h3>Getting Started:</h3>
490
+ <ol>
491
+ <li><strong>Upload Your Data:</strong> Use the "Upload Video" or "Upload Images" buttons on the left to provide your input. Videos will be automatically split into individual frames (one frame per second).</li>
492
+ <li><strong>Preview:</strong> Your uploaded images will appear in the gallery on the left.</li>
493
+ <li><strong>Reconstruct:</strong> Click the "Reconstruct" button to start the 3D reconstruction process.</li>
494
+ <li><strong>Visualize:</strong> The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file. Note that the visualization of 3D points may be slow for a large number of input images.</li>
495
+ <li>
496
+ <strong>Adjust Visualization (Optional):</strong>
497
+ After reconstruction, you can fine-tune the visualization using the options below
498
+ <details style="display:inline;">
499
+ <summary style="display:inline;">(<strong>click to expand</strong>):</summary>
500
+ <ul>
501
+ <li><em>Confidence Threshold:</em> Adjust the filtering of points based on confidence.</li>
502
+ <li><em>Show Points from Frame:</em> Select specific frames to display in the point cloud.</li>
503
+ <li><em>Show Camera:</em> Toggle the display of estimated camera positions.</li>
504
+ <li><em>Filter Sky / Filter Black Background:</em> Remove sky or black-background points.</li>
505
+ <li><em>Select a Prediction Mode:</em> Choose between "Depthmap and Camera Branch" or "Pointmap Branch."</li>
506
+ </ul>
507
+ </details>
508
+ </li>
509
+ </ol>
510
+ <p><strong style="color: #0ea5e9;">Please note:</strong> <span style="color: #0ea5e9; font-weight: bold;">SAIL-Recon typically reconstructs a scene at 5FPS with full 3D attributes. However, visualizing 3D points may take tens of seconds due to third-party rendering, which is independent of SAIL-Recon's processing time. Using the 'demo.py' can provide much faster processing.</span></p>
511
+ </div>
512
+ """
513
+ )
514
+
515
+ target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
516
+
517
+ with gr.Row():
518
+ with gr.Column(scale=2):
519
+ input_video = gr.Video(label="Upload Video", interactive=True)
520
+ input_images = gr.File(
521
+ file_count="multiple", label="Upload Images", interactive=True
522
+ )
523
+
524
+ image_gallery = gr.Gallery(
525
+ label="Preview",
526
+ columns=4,
527
+ height="300px",
528
+ show_download_button=True,
529
+ object_fit="contain",
530
+ preview=True,
531
+ )
532
+
533
+ with gr.Column(scale=4):
534
+ with gr.Column():
535
+ gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses)**")
536
+ log_output = gr.Markdown(
537
+ "Please upload a video or images, then click Reconstruct.",
538
+ elem_classes=["custom-log"],
539
+ )
540
+ reconstruction_output = gr.Model3D(
541
+ height=520, zoom_speed=0.5, pan_speed=0.5
542
+ )
543
+
544
+ with gr.Row():
545
+ submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
546
+ clear_btn = gr.ClearButton(
547
+ [
548
+ input_video,
549
+ input_images,
550
+ reconstruction_output,
551
+ log_output,
552
+ target_dir_output,
553
+ image_gallery,
554
+ ],
555
+ scale=1,
556
+ )
557
+
558
+ with gr.Row():
559
+ prediction_mode = gr.Radio(
560
+ ["Depthmap and Camera Branch", "Pointmap Branch"],
561
+ label="Select a Prediction Mode",
562
+ value="Depthmap and Camera Branch",
563
+ scale=1,
564
+ elem_id="my_radio",
565
+ )
566
+
567
+ with gr.Row():
568
+ conf_thres = gr.Slider(
569
+ minimum=0,
570
+ maximum=100,
571
+ value=50,
572
+ step=0.1,
573
+ label="Confidence Threshold (%)",
574
+ )
575
+ downsample_ratio = gr.Slider(
576
+ minimum=1.0,
577
+ maximum=100,
578
+ value=100,
579
+ step=0.1,
580
+ label="Downsample Ratio(%)",
581
+ )
582
+ frame_filter = gr.Dropdown(
583
+ choices=["All"], value="All", label="Show Points from Frame"
584
+ )
585
+ with gr.Column():
586
+ show_cam = gr.Checkbox(label="Show Camera", value=True)
587
+ mask_sky = gr.Checkbox(label="Filter Sky", value=False)
588
+ mask_black_bg = gr.Checkbox(
589
+ label="Filter Black Background", value=False
590
+ )
591
+ mask_white_bg = gr.Checkbox(
592
+ label="Filter White Background", value=False
593
+ )
594
+
595
+ # ---------------------- Examples section ----------------------
596
+ examples = [
597
+ [
598
+ colosseum_video,
599
+ "22",
600
+ None,
601
+ 20.0,
602
+ False,
603
+ False,
604
+ True,
605
+ False,
606
+ "Depthmap and Camera Branch",
607
+ "True",
608
+ ],
609
+ [
610
+ pyramid_video,
611
+ "30",
612
+ None,
613
+ 35.0,
614
+ False,
615
+ False,
616
+ True,
617
+ False,
618
+ "Depthmap and Camera Branch",
619
+ "True",
620
+ ],
621
+ [
622
+ single_cartoon_video,
623
+ "1",
624
+ None,
625
+ 15.0,
626
+ False,
627
+ False,
628
+ True,
629
+ False,
630
+ "Depthmap and Camera Branch",
631
+ "True",
632
+ ],
633
+ [
634
+ single_oil_painting_video,
635
+ "1",
636
+ None,
637
+ 20.0,
638
+ False,
639
+ False,
640
+ True,
641
+ True,
642
+ "Depthmap and Camera Branch",
643
+ "True",
644
+ ],
645
+ [
646
+ room_video,
647
+ "8",
648
+ None,
649
+ 5.0,
650
+ False,
651
+ False,
652
+ True,
653
+ False,
654
+ "Depthmap and Camera Branch",
655
+ "True",
656
+ ],
657
+ [
658
+ kitchen_video,
659
+ "25",
660
+ None,
661
+ 50.0,
662
+ False,
663
+ False,
664
+ True,
665
+ False,
666
+ "Depthmap and Camera Branch",
667
+ "True",
668
+ ],
669
+ [
670
+ fern_video,
671
+ "20",
672
+ None,
673
+ 45.0,
674
+ False,
675
+ False,
676
+ True,
677
+ False,
678
+ "Depthmap and Camera Branch",
679
+ "True",
680
+ ],
681
+ ]
682
+
683
+ def example_pipeline(
684
+ input_video,
685
+ num_images_str,
686
+ input_images,
687
+ conf_thres,
688
+ mask_black_bg,
689
+ mask_white_bg,
690
+ show_cam,
691
+ mask_sky,
692
+ downsample_ratio,
693
+ prediction_mode,
694
+ is_example_str,
695
+ ):
696
+ """
697
+ 1) Copy example images to new target_dir
698
+ 2) Reconstruct
699
+ 3) Return model3D + logs + new_dir + updated dropdown + gallery
700
+ We do NOT return is_example. It's just an input.
701
+ """
702
+ target_dir, image_paths = handle_uploads(input_video, input_images)
703
+ # Always use "All" for frame_filter in examples
704
+ frame_filter = "All"
705
+ glbfile, log_msg, dropdown = gradio_demo(
706
+ target_dir,
707
+ conf_thres,
708
+ frame_filter,
709
+ mask_black_bg,
710
+ mask_white_bg,
711
+ show_cam,
712
+ mask_sky,
713
+ downsample_ratio,
714
+ prediction_mode,
715
+ )
716
+ return glbfile, log_msg, target_dir, dropdown, image_paths
717
+
718
+ gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
719
+
720
+ # gr.Examples(
721
+ # examples=examples,
722
+ # inputs=[
723
+ # input_video,
724
+ # num_images,
725
+ # input_images,
726
+ # conf_thres,
727
+ # mask_black_bg,
728
+ # mask_white_bg,
729
+ # show_cam,
730
+ # mask_sky,
731
+ # downsample_ratio,
732
+ # prediction_mode,
733
+ # is_example,
734
+ # ],
735
+ # outputs=[reconstruction_output, log_output, target_dir_output, frame_filter, image_gallery],
736
+ # fn=example_pipeline,
737
+ # cache_examples=False,
738
+ # examples_per_page=50,
739
+ # )
740
+
741
+ # -------------------------------------------------------------------------
742
+ # "Reconstruct" button logic:
743
+ # - Clear fields
744
+ # - Update log
745
+ # - gradio_demo(...) with the existing target_dir
746
+ # - Then set is_example = "False"
747
+ # -------------------------------------------------------------------------
748
+ submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then(
749
+ fn=update_log, inputs=[], outputs=[log_output]
750
+ ).then(
751
+ fn=gradio_demo,
752
+ inputs=[
753
+ target_dir_output,
754
+ conf_thres,
755
+ frame_filter,
756
+ mask_black_bg,
757
+ mask_white_bg,
758
+ show_cam,
759
+ mask_sky,
760
+ downsample_ratio,
761
+ prediction_mode,
762
+ ],
763
+ outputs=[reconstruction_output, log_output, frame_filter],
764
+ ).then(
765
+ fn=lambda: "False", inputs=[], outputs=[is_example] # set is_example to "False"
766
+ )
767
+
768
+ # -------------------------------------------------------------------------
769
+ # Real-time Visualization Updates
770
+ # -------------------------------------------------------------------------
771
+ conf_thres.change(
772
+ update_visualization,
773
+ [
774
+ target_dir_output,
775
+ conf_thres,
776
+ frame_filter,
777
+ mask_black_bg,
778
+ mask_white_bg,
779
+ show_cam,
780
+ mask_sky,
781
+ downsample_ratio,
782
+ prediction_mode,
783
+ is_example,
784
+ ],
785
+ [reconstruction_output, log_output],
786
+ )
787
+ downsample_ratio.change(
788
+ update_visualization,
789
+ [
790
+ target_dir_output,
791
+ conf_thres,
792
+ frame_filter,
793
+ mask_black_bg,
794
+ mask_white_bg,
795
+ show_cam,
796
+ mask_sky,
797
+ downsample_ratio,
798
+ prediction_mode,
799
+ is_example,
800
+ ],
801
+ [reconstruction_output, log_output],
802
+ )
803
+ frame_filter.change(
804
+ update_visualization,
805
+ [
806
+ target_dir_output,
807
+ conf_thres,
808
+ frame_filter,
809
+ mask_black_bg,
810
+ mask_white_bg,
811
+ show_cam,
812
+ mask_sky,
813
+ downsample_ratio,
814
+ prediction_mode,
815
+ is_example,
816
+ ],
817
+ [reconstruction_output, log_output],
818
+ )
819
+ mask_black_bg.change(
820
+ update_visualization,
821
+ [
822
+ target_dir_output,
823
+ conf_thres,
824
+ frame_filter,
825
+ mask_black_bg,
826
+ mask_white_bg,
827
+ show_cam,
828
+ mask_sky,
829
+ downsample_ratio,
830
+ prediction_mode,
831
+ is_example,
832
+ ],
833
+ [reconstruction_output, log_output],
834
+ )
835
+ mask_white_bg.change(
836
+ update_visualization,
837
+ [
838
+ target_dir_output,
839
+ conf_thres,
840
+ frame_filter,
841
+ mask_black_bg,
842
+ mask_white_bg,
843
+ show_cam,
844
+ mask_sky,
845
+ downsample_ratio,
846
+ prediction_mode,
847
+ is_example,
848
+ ],
849
+ [reconstruction_output, log_output],
850
+ )
851
+ show_cam.change(
852
+ update_visualization,
853
+ [
854
+ target_dir_output,
855
+ conf_thres,
856
+ frame_filter,
857
+ mask_black_bg,
858
+ mask_white_bg,
859
+ show_cam,
860
+ mask_sky,
861
+ downsample_ratio,
862
+ prediction_mode,
863
+ is_example,
864
+ ],
865
+ [reconstruction_output, log_output],
866
+ )
867
+ mask_sky.change(
868
+ update_visualization,
869
+ [
870
+ target_dir_output,
871
+ conf_thres,
872
+ frame_filter,
873
+ mask_black_bg,
874
+ mask_white_bg,
875
+ show_cam,
876
+ mask_sky,
877
+ downsample_ratio,
878
+ prediction_mode,
879
+ is_example,
880
+ ],
881
+ [reconstruction_output, log_output],
882
+ )
883
+ prediction_mode.change(
884
+ update_visualization,
885
+ [
886
+ target_dir_output,
887
+ conf_thres,
888
+ frame_filter,
889
+ mask_black_bg,
890
+ mask_white_bg,
891
+ show_cam,
892
+ mask_sky,
893
+ downsample_ratio,
894
+ prediction_mode,
895
+ is_example,
896
+ ],
897
+ [reconstruction_output, log_output],
898
+ )
899
+
900
+ # -------------------------------------------------------------------------
901
+ # Auto-update gallery whenever user uploads or changes their files
902
+ # -------------------------------------------------------------------------
903
+ input_video.change(
904
+ fn=update_gallery_on_upload,
905
+ inputs=[input_video, input_images],
906
+ outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
907
+ )
908
+ input_images.change(
909
+ fn=update_gallery_on_upload,
910
+ inputs=[input_video, input_images],
911
+ outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
912
+ )
913
+
914
+ demo.queue(max_size=20).launch(show_error=True, share=True)
demo.py ADDED
@@ -0,0 +1,131 @@
1
+ # Copyright (c) HKUST SAIL-Lab and Horizon Robotics.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
7
+ #
8
+ # This source code is licensed under the Apache License, Version 2.0
9
+ # found in the LICENSE file in the root directory of this source tree.
10
+
11
+ import argparse
12
+ import os
13
+
14
+ import torch
15
+ from tqdm import tqdm
16
+
17
+ from eval.utils.device import to_cpu
18
+ from eval.utils.eval_utils import uniform_sample
19
+ from sailrecon.models.sail_recon import SailRecon
20
+ from sailrecon.utils.load_fn import load_and_preprocess_images
21
+
22
+ device = "cuda" if torch.cuda.is_available() else "cpu"
23
+ # bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+)
24
+ dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
25
+
26
+
27
+ def demo(args):
28
+ # Initialize the model and load the pretrained weights.
29
+ # This will automatically download the model weights the first time it's run, which may take a while.
30
+ _URL = "https://huggingface.co/HKUST-SAIL/SAIL-Recon/resolve/main/sailrecon.pt"
31
+ model_dir = args.ckpt
32
+ # model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
33
+ model = SailRecon(kv_cache=True)
34
+ if model_dir is not None:
35
+ model.load_state_dict(torch.load(model_dir))
36
+ else:
37
+ model.load_state_dict(
38
+ torch.hub.load_state_dict_from_url(_URL, model_dir=model_dir)
39
+ )
40
+ model = model.to(device=device)
41
+ model.eval()
42
+
43
+ # Load and preprocess example images
44
+ scene_name = "1"
45
+ if args.vid_dir is not None:
46
+ import cv2
47
+
48
+ image_names = []
49
+ video_path = args.vid_dir
50
+ vs = cv2.VideoCapture(video_path)
51
+ fps = vs.get(cv2.CAP_PROP_FPS)
52
+ tmp_file = os.path.join("tmp_video", os.path.basename(video_path).split(".")[0])
53
+ os.makedirs(tmp_file, exist_ok=True)
54
+ count = 0
55
+ video_frame_num = 0
56
+ while True:
57
+ gotit, frame = vs.read()
58
+ if not gotit:
59
+ break
60
+ count += 1
61
+ image_path = os.path.join(tmp_file, f"{video_frame_num:06}.png")
62
+ cv2.imwrite(image_path, frame)
63
+ image_names.append(image_path)
64
+ video_frame_num += 1
65
+ images = load_and_preprocess_images(image_names).to(device)
66
+ scene_name = os.path.basename(video_path).split(".")[0]
67
+ else:
68
+ image_names = os.listdir(args.img_dir)
69
+ image_names = [os.path.join(args.img_dir, f) for f in sorted(image_names)]
70
+ images = load_and_preprocess_images(image_names).to(device)
71
+ scene_name = os.path.basename(args.img_dir)
72
+
73
+ # anchor image selection
74
+ select_indices = uniform_sample(len(image_names), min(100, len(image_names)))
75
+ anchor_images = images[select_indices]
76
+
77
+ os.makedirs(os.path.join(args.out_dir, scene_name), exist_ok=True)
78
+
79
+ with torch.no_grad():
80
+ with torch.cuda.amp.autocast(dtype=dtype):
81
+ # processing anchor images to build scene representation (kv_cache)
82
+ print("Processing anchor images ...")
83
+ model.tmp_forward(anchor_images)
84
+ # remove the global transformer blocks to save memory during relocalization
85
+ del model.aggregator.global_blocks
86
+ # relocalization on all images
87
+ predictions = []
88
+
89
+ with tqdm(total=len(image_names), desc="Relocalizing") as pbar:
90
+ for img_split in images.split(20, dim=0):
91
+ pbar.update(20)
92
+ predictions += to_cpu(model.reloc(img_split))
93
+
94
+ # save the predicted point cloud and camera poses
95
+
96
+ from eval.utils.geometry import save_pointcloud_with_plyfile
97
+
98
+ save_pointcloud_with_plyfile(
99
+ predictions, os.path.join(args.out_dir, scene_name, "pred.ply")
100
+ )
101
+
102
+ import numpy as np
103
+
104
+ from eval.utils.eval_utils import save_kitti_poses
105
+
106
+ poses_w2c_estimated = [
107
+ one_result["extrinsic"][0].cpu().numpy() for one_result in predictions
108
+ ]
109
+ poses_c2w_estimated = [
110
+ np.linalg.inv(np.vstack([pose, np.array([0, 0, 0, 1])]))
111
+ for pose in poses_w2c_estimated
112
+ ]
113
+
114
+ save_kitti_poses(
115
+ poses_c2w_estimated,
116
+ os.path.join(args.out_dir, scene_name, "pred.txt"),
117
+ )
118
+
119
+
120
+ if __name__ == "__main__":
121
+ args = argparse.ArgumentParser()
122
+ args.add_argument(
123
+ "--img_dir", type=str, default="samples/kitchen", help="input image folder"
124
+ )
125
+ args.add_argument("--vid_dir", type=str, default=None, help="input video path")
126
+ args.add_argument("--out_dir", type=str, default="outputs", help="output folder")
127
+ args.add_argument(
128
+ "--ckpt", type=str, default=None, help="pretrained model checkpoint"
129
+ )
130
+ args = args.parse_args()
131
+ demo(args)
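
Both `app.py` and `demo.py` pick their anchor frames with `eval.utils.eval_utils.uniform_sample`, whose implementation is not included in this 50-file view. A minimal equivalent, assuming it simply returns evenly spaced frame indices, might look like the sketch below; treat it as an assumption about the helper's behaviour rather than its source.

```python
# Minimal sketch of evenly spaced anchor-frame selection; the real
# eval.utils.eval_utils.uniform_sample is not shown in this truncated diff,
# so this is an assumed, illustrative equivalent.
import numpy as np

def uniform_sample(n: int, k: int) -> list:
    """Return up to k evenly spaced, deduplicated indices in [0, n)."""
    idx = np.linspace(0, n - 1, num=k, dtype=int)
    return sorted(set(idx.tolist()))

print(uniform_sample(1000, 5))   # [0, 249, 499, 749, 999]
print(uniform_sample(3, 100))    # [0, 1, 2] -- fewer anchors than requested
```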
demo_gradio.py ADDED
@@ -0,0 +1,921 @@
1
+ # Copyright (c) HKUST SAIL-Lab and Horizon Robotics.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
7
+ #
8
+ # This source code is licensed under the Apache License, Version 2.0
9
+ # found in the LICENSE file in the root directory of this source tree.
10
+
11
+ import gc
12
+ import glob
13
+ import os
14
+ import shutil
15
+ import sys
16
+ import time
17
+ from datetime import datetime
18
+
19
+ import cv2
20
+ import gradio as gr
21
+ import numpy as np
22
+ import torch
23
+ from tqdm import tqdm
24
+
25
+ from eval.utils.device import to_cpu
26
+ from eval.utils.eval_utils import uniform_sample
27
+ from sailrecon.models.sail_recon import SailRecon
28
+ from sailrecon.utils.geometry import unproject_depth_map_to_point_map
29
+ from sailrecon.utils.load_fn import load_and_preprocess_images
30
+ from sailrecon.utils.pose_enc import (
31
+ extri_intri_to_pose_encoding,
32
+ pose_encoding_to_extri_intri,
33
+ )
34
+ from visual_util import predictions_to_glb
35
+
36
+ device = "cuda" if torch.cuda.is_available() else "cpu"
37
+
38
+ print("Initializing and loading SailRecon model...")
39
+
40
+ model = SailRecon(kv_cache=True)
41
+ # _URL = "https://huggingface.co/HKUST-SAIL/SAIL-Recon/resolve/main/sailrecon.pt"
42
+ # model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
43
+ model_dir = "ckpt/sailrecon.pt"
44
+ model.load_state_dict(torch.load(model_dir))
45
+
46
+
47
+ model.eval()
48
+ model = model.to(device)
49
+
50
+
51
+ # -------------------------------------------------------------------------
52
+ # 1) Core model inference
53
+ # -------------------------------------------------------------------------
54
+ def run_model(target_dir, model, anchor_size=100) -> dict:
55
+ """
56
+ Run the SAIL-Recon model on images in the 'target_dir/images' folder and return predictions.
57
+ """
58
+ print(f"Processing images from {target_dir}")
59
+
60
+ # Device check
61
+ device = "cuda" if torch.cuda.is_available() else "cpu"
62
+ if not torch.cuda.is_available():
63
+ raise ValueError("CUDA is not available. Check your environment.")
64
+
65
+ # Move model to device
66
+ model = model.to(device)
67
+ model.eval()
68
+
69
+ # Load and preprocess images
70
+ image_names = glob.glob(os.path.join(target_dir, "images", "*"))
71
+ image_names = sorted(image_names)
72
+ print(f"Found {len(image_names)} images")
73
+ if len(image_names) == 0:
74
+ raise ValueError("No images found. Check your upload.")
75
+
76
+ images = load_and_preprocess_images(image_names).to(device)
77
+ print(f"Preprocessed images shape: {images.shape}")
78
+ # anchor image selection
79
+ select_indices = uniform_sample(len(image_names), min(anchor_size, len(image_names)))
80
+ anchor_images = images[select_indices]
81
+
82
+ # Run inference
83
+ print("Running inference...")
84
+ dtype = (
85
+ torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
86
+ )
87
+
88
+ with torch.no_grad():
89
+ with torch.cuda.amp.autocast(dtype=dtype):
90
+ print("Processing anchor images ...")
91
+ model.tmp_forward(anchor_images)
92
+ # del model.aggregator.global_blocks
93
+ # relocalization on all images
94
+ predictions_s = []
95
+ with tqdm(total=len(image_names), desc="Relocalizing") as pbar:
96
+ for img_split in images.split(10, dim=0):
97
+ pbar.update(len(img_split))
98
+ predictions_s += to_cpu(
99
+ model.reloc(img_split, ret_img=True, memory_save=False)
100
+ )
101
+
102
+ predictions = {}
103
+ predictions["extrinsic"] = torch.cat(
104
+ [s["extrinsic"] for s in predictions_s], dim=0
105
+ ) # (S, 4, 4)
106
+ predictions["intrinsic"] = torch.cat(
107
+ [s["intrinsic"] for s in predictions_s], dim=0
108
+ ) # (S, 4, 4)
109
+ predictions["depth"] = torch.cat(
110
+ [s["depth_map"] for s in predictions_s], dim=0
111
+ ) # (S, H, W, 1)
112
+ predictions["depth_conf"] = torch.cat(
113
+ [s["dpt_cnf"] for s in predictions_s], dim=0
114
+ ) # (S, H, W, 1)
115
+ predictions["images"] = torch.cat(
116
+ [s["images"] for s in predictions_s], dim=0
117
+ ) # (S, H, W, 3)
118
+ predictions["world_points"] = torch.cat(
119
+ [s["point_map"] for s in predictions_s], dim=0
120
+ ) # (S, H, W, 3)
121
+ predictions["world_points_conf"] = torch.cat(
122
+ [s["xyz_cnf"] for s in predictions_s], dim=0
123
+ ) # (S, H, W, 3)
124
+ predictions["pose_enc"] = extri_intri_to_pose_encoding(
125
+ predictions["extrinsic"].unsqueeze(0),
126
+ predictions["intrinsic"].unsqueeze(0),
127
+ images.shape[-2:],
128
+ )[
129
+ 0
130
+ ]  # drop the batch dimension added by unsqueeze(0)
131
+ del predictions_s
132
+
133
+ # Convert tensors to numpy
134
+ for key in predictions.keys():
135
+ if isinstance(predictions[key], torch.Tensor):
136
+ predictions[key] = predictions[key].cpu().numpy()  # move to CPU and convert to numpy
137
+ predictions["pose_enc_list"] = None # remove pose_enc_list
138
+
139
+ # Generate world points from depth map
140
+ print("Computing world points from depth map...")
141
+ depth_map = predictions["depth"] # (S, H, W, 1)
142
+ world_points = unproject_depth_map_to_point_map(
143
+ depth_map, predictions["extrinsic"], predictions["intrinsic"]
144
+ )
145
+ predictions["world_points_from_depth"] = world_points
146
+
147
+ # Clean up
148
+ torch.cuda.empty_cache()
149
+ return predictions
150
+
151
+
152
+ # -------------------------------------------------------------------------
153
+ # 2) Handle uploaded video/images --> produce target_dir + images
154
+ # -------------------------------------------------------------------------
155
+ def handle_uploads(input_video, input_images):
156
+ """
157
+ Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
158
+ images or extracted frames from video into it. Return (target_dir, image_paths).
159
+ """
160
+ start_time = time.time()
161
+ gc.collect()
162
+ torch.cuda.empty_cache()
163
+
164
+ # Create a unique folder name
165
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
166
+ target_dir = f"input_images_{timestamp}"
167
+ target_dir_images = os.path.join(target_dir, "images")
168
+
169
+ # Clean up if somehow that folder already exists
170
+ if os.path.exists(target_dir):
171
+ shutil.rmtree(target_dir)
172
+ os.makedirs(target_dir)
173
+ os.makedirs(target_dir_images)
174
+
175
+ image_paths = []
176
+
177
+ # --- Handle images ---
178
+ if input_images is not None:
179
+ for file_data in input_images:
180
+ if isinstance(file_data, dict) and "name" in file_data:
181
+ file_path = file_data["name"]
182
+ else:
183
+ file_path = file_data
184
+ dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
185
+ shutil.copy(file_path, dst_path)
186
+ image_paths.append(dst_path)
187
+
188
+ # --- Handle video ---
189
+ if input_video is not None:
190
+ if isinstance(input_video, dict) and "name" in input_video:
191
+ video_path = input_video["name"]
192
+ else:
193
+ video_path = input_video
194
+
195
+ vs = cv2.VideoCapture(video_path)
196
+ fps = vs.get(cv2.CAP_PROP_FPS)
197
+
198
+ count = 0
199
+ video_frame_num = 0
200
+ while True:
201
+ gotit, frame = vs.read()
202
+ if not gotit:
203
+ break
204
+ count += 1
205
+ image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png")
206
+ cv2.imwrite(image_path, frame)
207
+ image_paths.append(image_path)
208
+ video_frame_num += 1
209
+
210
+ # Sort final images for gallery
211
+ image_paths = sorted(image_paths)
212
+
213
+ end_time = time.time()
214
+ print(
215
+ f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds"
216
+ )
217
+ return target_dir, image_paths
218
+
219
+
220
+ # -------------------------------------------------------------------------
221
+ # 3) Update gallery on upload
222
+ # -------------------------------------------------------------------------
223
+ def update_gallery_on_upload(input_video, input_images):
224
+ """
225
+ Whenever user uploads or changes files, immediately handle them
226
+ and show them in the gallery. Returns (None, target_dir, image_paths, status message),
228
+ where the leading None clears the 3D viewer. If nothing is uploaded, returns four Nones.
228
+ """
229
+ if not input_video and not input_images:
230
+ return None, None, None, None
231
+ target_dir, image_paths = handle_uploads(input_video, input_images)
232
+ return (
233
+ None,
234
+ target_dir,
235
+ image_paths,
236
+ "Upload complete. Click 'Reconstruct' to begin 3D processing.",
237
+ )
238
+
239
+
240
+ # -------------------------------------------------------------------------
241
+ # 4) Reconstruction: uses the target_dir plus any viz parameters
242
+ # -------------------------------------------------------------------------
243
+ def gradio_demo(
244
+ target_dir,
245
+ conf_thres=3.0,
246
+ frame_filter="All",
247
+ mask_black_bg=False,
248
+ mask_white_bg=False,
249
+ show_cam=True,
250
+ mask_sky=False,
251
+ downsample_ratio=100.0,
252
+ prediction_mode="Pointmap Regression",
253
+ ):
254
+ """
255
+ Perform reconstruction using the already-created target_dir/images.
256
+ """
257
+ if not os.path.isdir(target_dir) or target_dir == "None":
258
+ return None, "No valid target directory found. Please upload first.", None
259
+
260
+ start_time = time.time()
261
+ gc.collect()
262
+ torch.cuda.empty_cache()
263
+
264
+ # Prepare frame_filter dropdown
265
+ target_dir_images = os.path.join(target_dir, "images")
266
+ all_files = (
267
+ sorted(os.listdir(target_dir_images))
268
+ if os.path.isdir(target_dir_images)
269
+ else []
270
+ )
271
+ all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
272
+ frame_filter_choices = ["All"] + all_files
273
+
274
+ print("Running run_model...")
275
+ with torch.no_grad():
276
+ predictions = run_model(target_dir, model)
277
+
278
+ # Save predictions
279
+ prediction_save_path = os.path.join(target_dir, "predictions.npz")
280
+ np.savez(prediction_save_path, **predictions)
281
+
282
+ # Handle None frame_filter
283
+ if frame_filter is None:
284
+ frame_filter = "All"
285
+
286
+ # Build a GLB file name
287
+ glbfile = os.path.join(
288
+ target_dir,
289
+ f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
290
+ )
291
+
292
+ # Convert predictions to GLB
293
+ glbscene = predictions_to_glb(
294
+ predictions,
295
+ conf_thres=conf_thres,
296
+ filter_by_frames=frame_filter,
297
+ mask_black_bg=mask_black_bg,
298
+ mask_white_bg=mask_white_bg,
299
+ show_cam=show_cam,
300
+ mask_sky=mask_sky,
301
+ target_dir=target_dir,
302
+ downsample_ratio=downsample_ratio / 100.0,
303
+ prediction_mode=prediction_mode,
304
+ )
305
+ glbscene.export(file_obj=glbfile)
306
+
307
+ # Cleanup
308
+ del predictions
309
+ gc.collect()
310
+ torch.cuda.empty_cache()
311
+
312
+ end_time = time.time()
313
+ print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
314
+ log_msg = (
315
+ f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
316
+ )
317
+
318
+ return (
319
+ glbfile,
320
+ log_msg,
321
+ gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
322
+ )
323
+
324
+
325
+ # -------------------------------------------------------------------------
326
+ # 5) Helper functions for UI resets + re-visualization
327
+ # -------------------------------------------------------------------------
328
+ def clear_fields():
329
+ """
330
+ Clears the 3D viewer before a new reconstruction starts.
331
+ """
332
+ return None
333
+
334
+
335
+ def update_log():
336
+ """
337
+ Display a quick log message while waiting.
338
+ """
339
+ return "Loading and Reconstructing..."
340
+
341
+
342
+ def update_visualization(
343
+ target_dir,
344
+ conf_thres,
345
+ frame_filter,
346
+ mask_black_bg,
347
+ mask_white_bg,
348
+ show_cam,
349
+ mask_sky,
350
+ downsample_ratio,
351
+ prediction_mode,
352
+ is_example,
353
+ ):
354
+ """
355
+ Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
356
+ and return it for the 3D viewer. If is_example == "True", skip.
357
+ """
358
+
359
+ # If it's an example click, skip as requested
360
+ if is_example == "True":
361
+ return (
362
+ None,
363
+ "No reconstruction available. Please click the Reconstruct button first.",
364
+ )
365
+
366
+ if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
367
+ return (
368
+ None,
369
+ "No reconstruction available. Please click the Reconstruct button first.",
370
+ )
371
+
372
+ predictions_path = os.path.join(target_dir, "predictions.npz")
373
+ if not os.path.exists(predictions_path):
374
+ return (
375
+ None,
376
+ f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.",
377
+ )
378
+
379
+ key_list = [
380
+ "pose_enc",
381
+ "depth",
382
+ "depth_conf",
383
+ "world_points",
384
+ "world_points_conf",
385
+ "images",
386
+ "extrinsic",
387
+ "intrinsic",
388
+ "world_points_from_depth",
389
+ ]
390
+
391
+ loaded = np.load(predictions_path)
392
+ predictions = {key: np.array(loaded[key]) for key in key_list if key in loaded}
393
+ print(downsample_ratio)
394
+ glbfile = os.path.join(
395
+ target_dir,
396
+ f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_dr{downsample_ratio}_pred{prediction_mode.replace(' ', '_')}.glb",
397
+ )
398
+
399
+ if not os.path.exists(glbfile):
400
+ glbscene = predictions_to_glb(
401
+ predictions,
402
+ conf_thres=conf_thres,
403
+ filter_by_frames=frame_filter,
404
+ mask_black_bg=mask_black_bg,
405
+ mask_white_bg=mask_white_bg,
406
+ show_cam=show_cam,
407
+ mask_sky=mask_sky,
408
+ target_dir=target_dir,
409
+ downsample_ratio=downsample_ratio * 1.0 / 100.0,
410
+ prediction_mode=prediction_mode,
411
+ )
412
+ glbscene.export(file_obj=glbfile)
413
+
414
+ return glbfile, "Updating Visualization"
415
+
416
+
417
+ # -------------------------------------------------------------------------
418
+ # Example images
419
+ # -------------------------------------------------------------------------
420
+
421
+ great_wall_video = "examples/videos/great_wall.mp4"
422
+ colosseum_video = "examples/videos/Colosseum.mp4"
423
+ room_video = "examples/videos/room.mp4"
424
+ kitchen_video = "examples/videos/kitchen.mp4"
425
+ fern_video = "examples/videos/fern.mp4"
426
+ single_cartoon_video = "examples/videos/single_cartoon.mp4"
427
+ single_oil_painting_video = "examples/videos/single_oil_painting.mp4"
428
+ pyramid_video = "examples/videos/pyramid.mp4"
429
+
430
+
431
+ # -------------------------------------------------------------------------
432
+ # 6) Build Gradio UI
433
+ # -------------------------------------------------------------------------
434
+ theme = gr.themes.Ocean()
435
+ theme.set(
436
+ checkbox_label_background_fill_selected="*button_primary_background_fill",
437
+ checkbox_label_text_color_selected="*button_primary_text_color",
438
+ )
439
+
440
+ with gr.Blocks(
441
+ theme=theme,
442
+ css="""
443
+ .custom-log * {
444
+ font-style: italic;
445
+ font-size: 22px !important;
446
+ background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
447
+ -webkit-background-clip: text;
448
+ background-clip: text;
449
+ font-weight: bold !important;
450
+ color: transparent !important;
451
+ text-align: center !important;
452
+ }
453
+
454
+ .example-log * {
455
+ font-style: italic;
456
+ font-size: 16px !important;
457
+ background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
458
+ -webkit-background-clip: text;
459
+ background-clip: text;
460
+ color: transparent !important;
461
+ }
462
+
463
+ #my_radio .wrap {
464
+ display: flex;
465
+ flex-wrap: nowrap;
466
+ justify-content: center;
467
+ align-items: center;
468
+ }
469
+
470
+ #my_radio .wrap label {
471
+ display: flex;
472
+ width: 50%;
473
+ justify-content: center;
474
+ align-items: center;
475
+ margin: 0;
476
+ padding: 10px 0;
477
+ box-sizing: border-box;
478
+ }
479
+ """,
480
+ ) as demo:
481
+ # Instead of gr.State, we use a hidden Textbox:
482
+ is_example = gr.Textbox(label="is_example", visible=False, value="None")
483
+ num_images = gr.Textbox(label="num_images", visible=False, value="None")
484
+
485
+ gr.HTML(
486
+ """
487
+ <h1>🏛️ SAIL-Recon: Large SfM by Augmenting Scene Regression with Localization</h1>
488
+ <p>
489
+ <a href="https://github.com/HKUST-SAIL/sail-recon">🐙 GitHub Repository</a> |
490
+ <a href="https://hkust-sail.github.io/sail-recon/">Project Page</a>
491
+ </p>
492
+
493
+ <div style="font-size: 16px; line-height: 1.5;">
494
+ <p>Upload a video or a set of images to create a 3D reconstruction of a scene or object. SAIL-Recon takes these images and generates a 3D point cloud, along with estimated camera poses.</p>
495
+
496
+ <h3>Getting Started:</h3>
497
+ <ol>
498
+ <li><strong>Upload Your Data:</strong> Use the "Upload Video" or "Upload Images" buttons on the left to provide your input. Videos will be automatically split into individual frames (one frame per second).</li>
499
+ <li><strong>Preview:</strong> Your uploaded images will appear in the gallery on the left.</li>
500
+ <li><strong>Reconstruct:</strong> Click the "Reconstruct" button to start the 3D reconstruction process.</li>
501
+ <li><strong>Visualize:</strong> The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file. Note that the visualization of 3D points may be slow for a large number of input images.</li>
502
+ <li>
503
+ <strong>Adjust Visualization (Optional):</strong>
504
+ After reconstruction, you can fine-tune the visualization using the options below
505
+ <details style="display:inline;">
506
+ <summary style="display:inline;">(<strong>click to expand</strong>):</summary>
507
+ <ul>
508
+ <li><em>Confidence Threshold:</em> Adjust the filtering of points based on confidence.</li>
509
+ <li><em>Show Points from Frame:</em> Select specific frames to display in the point cloud.</li>
510
+ <li><em>Show Camera:</em> Toggle the display of estimated camera positions.</li>
511
+ <li><em>Filter Sky / Filter Black Background:</em> Remove sky or black-background points.</li>
512
+ <li><em>Select a Prediction Mode:</em> Choose between "Depthmap and Camera Branch" or "Pointmap Branch."</li>
513
+ </ul>
514
+ </details>
515
+ </li>
516
+ </ol>
517
+ <p><strong style="color: #0ea5e9;">Please note:</strong> <span style="color: #0ea5e9; font-weight: bold;">SAIL-Recon typically reconstructs a scene at 5 FPS with full 3D attributes. However, visualizing 3D points may take tens of seconds due to third-party rendering, which is independent of SAIL-Recon's processing time. Running 'demo.py' directly provides much faster processing.</span></p>
518
+ </div>
519
+ """
520
+ )
521
+
522
+ target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
523
+
524
+ with gr.Row():
525
+ with gr.Column(scale=2):
526
+ input_video = gr.Video(label="Upload Video", interactive=True)
527
+ input_images = gr.File(
528
+ file_count="multiple", label="Upload Images", interactive=True
529
+ )
530
+
531
+ image_gallery = gr.Gallery(
532
+ label="Preview",
533
+ columns=4,
534
+ height="300px",
535
+ show_download_button=True,
536
+ object_fit="contain",
537
+ preview=True,
538
+ )
539
+
540
+ with gr.Column(scale=4):
541
+ with gr.Column():
542
+ gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses)**")
543
+ log_output = gr.Markdown(
544
+ "Please upload a video or images, then click Reconstruct.",
545
+ elem_classes=["custom-log"],
546
+ )
547
+ reconstruction_output = gr.Model3D(
548
+ height=520, zoom_speed=0.5, pan_speed=0.5
549
+ )
550
+
551
+ with gr.Row():
552
+ submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
553
+ clear_btn = gr.ClearButton(
554
+ [
555
+ input_video,
556
+ input_images,
557
+ reconstruction_output,
558
+ log_output,
559
+ target_dir_output,
560
+ image_gallery,
561
+ ],
562
+ scale=1,
563
+ )
564
+
565
+ with gr.Row():
566
+ prediction_mode = gr.Radio(
567
+ ["Depthmap and Camera Branch", "Pointmap Branch"],
568
+ label="Select a Prediction Mode",
569
+ value="Depthmap and Camera Branch",
570
+ scale=1,
571
+ elem_id="my_radio",
572
+ )
573
+
574
+ with gr.Row():
575
+ conf_thres = gr.Slider(
576
+ minimum=0,
577
+ maximum=100,
578
+ value=50,
579
+ step=0.1,
580
+ label="Confidence Threshold (%)",
581
+ )
582
+ downsample_ratio = gr.Slider(
583
+ minimum=1.0,
584
+ maximum=100,
585
+ value=100,
586
+ step=0.1,
587
+ label="Downsample Ratio(%)",
588
+ )
589
+ frame_filter = gr.Dropdown(
590
+ choices=["All"], value="All", label="Show Points from Frame"
591
+ )
592
+ with gr.Column():
593
+ show_cam = gr.Checkbox(label="Show Camera", value=True)
594
+ mask_sky = gr.Checkbox(label="Filter Sky", value=False)
595
+ mask_black_bg = gr.Checkbox(
596
+ label="Filter Black Background", value=False
597
+ )
598
+ mask_white_bg = gr.Checkbox(
599
+ label="Filter White Background", value=False
600
+ )
601
+
602
+ # ---------------------- Examples section ----------------------
603
+ examples = [
604
+ [
605
+ colosseum_video,
606
+ "22",
607
+ None,
608
+ 20.0,
609
+ False,
610
+ False,
611
+ True,
612
+ False,
613
+ "Depthmap and Camera Branch",
614
+ "True",
615
+ ],
616
+ [
617
+ pyramid_video,
618
+ "30",
619
+ None,
620
+ 35.0,
621
+ False,
622
+ False,
623
+ True,
624
+ False,
625
+ "Depthmap and Camera Branch",
626
+ "True",
627
+ ],
628
+ [
629
+ single_cartoon_video,
630
+ "1",
631
+ None,
632
+ 15.0,
633
+ False,
634
+ False,
635
+ True,
636
+ False,
637
+ "Depthmap and Camera Branch",
638
+ "True",
639
+ ],
640
+ [
641
+ single_oil_painting_video,
642
+ "1",
643
+ None,
644
+ 20.0,
645
+ False,
646
+ False,
647
+ True,
648
+ True,
649
+ "Depthmap and Camera Branch",
650
+ "True",
651
+ ],
652
+ [
653
+ room_video,
654
+ "8",
655
+ None,
656
+ 5.0,
657
+ False,
658
+ False,
659
+ True,
660
+ False,
661
+ "Depthmap and Camera Branch",
662
+ "True",
663
+ ],
664
+ [
665
+ kitchen_video,
666
+ "25",
667
+ None,
668
+ 50.0,
669
+ False,
670
+ False,
671
+ True,
672
+ False,
673
+ "Depthmap and Camera Branch",
674
+ "True",
675
+ ],
676
+ [
677
+ fern_video,
678
+ "20",
679
+ None,
680
+ 45.0,
681
+ False,
682
+ False,
683
+ True,
684
+ False,
685
+ "Depthmap and Camera Branch",
686
+ "True",
687
+ ],
688
+ ]
689
+
690
+ def example_pipeline(
691
+ input_video,
692
+ num_images_str,
693
+ input_images,
694
+ conf_thres,
695
+ mask_black_bg,
696
+ mask_white_bg,
697
+ show_cam,
698
+ mask_sky,
699
+ downsample_ratio,
700
+ prediction_mode,
701
+ is_example_str,
702
+ ):
703
+ """
704
+ 1) Copy example images to new target_dir
705
+ 2) Reconstruct
706
+ 3) Return model3D + logs + new_dir + updated dropdown + gallery
707
+ We do NOT return is_example. It's just an input.
708
+ """
709
+ target_dir, image_paths = handle_uploads(input_video, input_images)
710
+ # Always use "All" for frame_filter in examples
711
+ frame_filter = "All"
712
+ glbfile, log_msg, dropdown = gradio_demo(
713
+ target_dir,
714
+ conf_thres,
715
+ frame_filter,
716
+ mask_black_bg,
717
+ mask_white_bg,
718
+ show_cam,
719
+ mask_sky,
720
+ downsample_ratio,
721
+ prediction_mode,
722
+ )
723
+ return glbfile, log_msg, target_dir, dropdown, image_paths
724
+
725
+ gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
726
+
727
+ # gr.Examples(
728
+ # examples=examples,
729
+ # inputs=[
730
+ # input_video,
731
+ # num_images,
732
+ # input_images,
733
+ # conf_thres,
734
+ # mask_black_bg,
735
+ # mask_white_bg,
736
+ # show_cam,
737
+ # mask_sky,
738
+ # downsample_ratio,
739
+ # prediction_mode,
740
+ # is_example,
741
+ # ],
742
+ # outputs=[reconstruction_output, log_output, target_dir_output, frame_filter, image_gallery],
743
+ # fn=example_pipeline,
744
+ # cache_examples=False,
745
+ # examples_per_page=50,
746
+ # )
747
+
748
+ # -------------------------------------------------------------------------
749
+ # "Reconstruct" button logic:
750
+ # - Clear fields
751
+ # - Update log
752
+ # - gradio_demo(...) with the existing target_dir
753
+ # - Then set is_example = "False"
754
+ # -------------------------------------------------------------------------
755
+ submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then(
756
+ fn=update_log, inputs=[], outputs=[log_output]
757
+ ).then(
758
+ fn=gradio_demo,
759
+ inputs=[
760
+ target_dir_output,
761
+ conf_thres,
762
+ frame_filter,
763
+ mask_black_bg,
764
+ mask_white_bg,
765
+ show_cam,
766
+ mask_sky,
767
+ downsample_ratio,
768
+ prediction_mode,
769
+ ],
770
+ outputs=[reconstruction_output, log_output, frame_filter],
771
+ ).then(
772
+ fn=lambda: "False", inputs=[], outputs=[is_example] # set is_example to "False"
773
+ )
774
+
775
+ # -------------------------------------------------------------------------
776
+ # Real-time Visualization Updates
777
+ # -------------------------------------------------------------------------
778
+ conf_thres.change(
779
+ update_visualization,
780
+ [
781
+ target_dir_output,
782
+ conf_thres,
783
+ frame_filter,
784
+ mask_black_bg,
785
+ mask_white_bg,
786
+ show_cam,
787
+ mask_sky,
788
+ downsample_ratio,
789
+ prediction_mode,
790
+ is_example,
791
+ ],
792
+ [reconstruction_output, log_output],
793
+ )
794
+ downsample_ratio.change(
795
+ update_visualization,
796
+ [
797
+ target_dir_output,
798
+ conf_thres,
799
+ frame_filter,
800
+ mask_black_bg,
801
+ mask_white_bg,
802
+ show_cam,
803
+ mask_sky,
804
+ downsample_ratio,
805
+ prediction_mode,
806
+ is_example,
807
+ ],
808
+ [reconstruction_output, log_output],
809
+ )
810
+ frame_filter.change(
811
+ update_visualization,
812
+ [
813
+ target_dir_output,
814
+ conf_thres,
815
+ frame_filter,
816
+ mask_black_bg,
817
+ mask_white_bg,
818
+ show_cam,
819
+ mask_sky,
820
+ downsample_ratio,
821
+ prediction_mode,
822
+ is_example,
823
+ ],
824
+ [reconstruction_output, log_output],
825
+ )
826
+ mask_black_bg.change(
827
+ update_visualization,
828
+ [
829
+ target_dir_output,
830
+ conf_thres,
831
+ frame_filter,
832
+ mask_black_bg,
833
+ mask_white_bg,
834
+ show_cam,
835
+ mask_sky,
836
+ downsample_ratio,
837
+ prediction_mode,
838
+ is_example,
839
+ ],
840
+ [reconstruction_output, log_output],
841
+ )
842
+ mask_white_bg.change(
843
+ update_visualization,
844
+ [
845
+ target_dir_output,
846
+ conf_thres,
847
+ frame_filter,
848
+ mask_black_bg,
849
+ mask_white_bg,
850
+ show_cam,
851
+ mask_sky,
852
+ downsample_ratio,
853
+ prediction_mode,
854
+ is_example,
855
+ ],
856
+ [reconstruction_output, log_output],
857
+ )
858
+ show_cam.change(
859
+ update_visualization,
860
+ [
861
+ target_dir_output,
862
+ conf_thres,
863
+ frame_filter,
864
+ mask_black_bg,
865
+ mask_white_bg,
866
+ show_cam,
867
+ mask_sky,
868
+ downsample_ratio,
869
+ prediction_mode,
870
+ is_example,
871
+ ],
872
+ [reconstruction_output, log_output],
873
+ )
874
+ mask_sky.change(
875
+ update_visualization,
876
+ [
877
+ target_dir_output,
878
+ conf_thres,
879
+ frame_filter,
880
+ mask_black_bg,
881
+ mask_white_bg,
882
+ show_cam,
883
+ mask_sky,
884
+ downsample_ratio,
885
+ prediction_mode,
886
+ is_example,
887
+ ],
888
+ [reconstruction_output, log_output],
889
+ )
890
+ prediction_mode.change(
891
+ update_visualization,
892
+ [
893
+ target_dir_output,
894
+ conf_thres,
895
+ frame_filter,
896
+ mask_black_bg,
897
+ mask_white_bg,
898
+ show_cam,
899
+ mask_sky,
900
+ downsample_ratio,
901
+ prediction_mode,
902
+ is_example,
903
+ ],
904
+ [reconstruction_output, log_output],
905
+ )
906
+
907
+ # -------------------------------------------------------------------------
908
+ # Auto-update gallery whenever user uploads or changes their files
909
+ # -------------------------------------------------------------------------
910
+ input_video.change(
911
+ fn=update_gallery_on_upload,
912
+ inputs=[input_video, input_images],
913
+ outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
914
+ )
915
+ input_images.change(
916
+ fn=update_gallery_on_upload,
917
+ inputs=[input_video, input_images],
918
+ outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
919
+ )
920
+
921
+ demo.queue(max_size=20).launch(show_error=True, share=True)
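A side note on reuse: `run_model` above does not depend on the Gradio widgets, so it can be driven from a plain script. A minimal sketch, assuming `model` and `run_model` are in scope as defined earlier in this file and that `my_scene/images/` is a hypothetical folder of input frames:

```python
# Sketch only (not part of the diff): call run_model outside the UI.
# Assumes `model` (SailRecon) and run_model() from demo_gradio.py are in scope.
import numpy as np

predictions = run_model("my_scene", model, anchor_size=100)
np.savez("my_scene/predictions.npz", **predictions)   # same file the UI writes
print(predictions["extrinsic"].shape, predictions["world_points"].shape)
```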
docs/traj_ply.png ADDED

Git LFS Details

  • SHA256: 473ea3a5a5056f6f2e905307a71e3fe1fa597cccd03b9545bdbeabb495c417a3
  • Pointer size: 132 Bytes
  • Size of remote file: 3.16 MB
eval/datasets/mip_360.py ADDED
@@ -0,0 +1,115 @@
1
+ import glob
2
+ import os
3
+ import random
4
+ import struct
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch.utils.data import Dataset
10
+ from vggt.utils.load_fn import load_and_preprocess_images
11
+
12
+
13
+ class Mip360Dataset(Dataset):
14
+ def __init__(self, root_dir, scene_name="bicycle"):
15
+ self.scene_dir = os.path.join(
16
+ root_dir,
17
+ f"{scene_name}",
18
+ )
19
+
20
+ self.test_samples = sorted(
21
+ glob.glob(os.path.join(self.scene_dir, "images_8", "*.JPG"))
22
+ )
23
+ # self.train_samples = sorted(glob.glob(os.path.join(self.train_seqs, "rgb", "*.png")))
24
+ self.all_samples = self.test_samples # + self.train_samples
25
+ bin_path = os.path.join(self.scene_dir, "sparse", "0", "images.bin")
26
+ self.poses = read_images_bin(bin_path)
27
+
28
+ def __len__(self):
29
+ return len(self.all_samples)
30
+
31
+ def __getitem__(self, idx):
32
+ return self._load_sample(self.all_samples[idx])
33
+
34
+ def get_train_sample(self, n=4):
35
+ # _rng = np.random.default_rng(seed=777)
36
+ gap = len(self.all_samples) // n
37
+ gap = max(gap, 1) # Ensure at least one sample is selected
38
+ gap = min(gap, len(self.all_samples)) # Ensure gap does not exceed length
39
+ if gap == 1:
40
+ selected = sorted(
41
+ random.sample(self.all_samples, min(n, len(self.all_samples)))
42
+ )
43
+ else:
44
+ selected = self.all_samples[::gap]
45
+ if len(selected) > n:
46
+ selected = sorted(random.sample(selected, n))
47
+ return [self._load_sample(s) for s in selected]
48
+
49
+ def _load_sample(self, rgb_path):
50
+ img_name = os.path.basename(rgb_path)
51
+ color = load_and_preprocess_images([rgb_path])[0]
52
+ pose = torch.from_numpy(self.poses[img_name]).float()
53
+
54
+ return dict(
55
+ img=color,
56
+ camera_pose=pose, # cam2world
57
+ dataset="7Scenes",
58
+ true_shape=torch.tensor([392, 518]),
59
+ label=img_name,
60
+ instance=img_name,
61
+ )
62
+
63
+
64
+ def read_images_bin(bin_path: str | Path):
65
+ bin_path = Path(bin_path)
66
+ poses = {}
67
+
68
+ with bin_path.open("rb") as f:
69
+ num_images = struct.unpack("<Q", f.read(8))[0] # uint64
70
+ for _ in range(num_images):
71
+ image_id = struct.unpack("<I", f.read(4))[0]
72
+ qvec = np.frombuffer(f.read(8 * 4), dtype=np.float64) # qw,qx,qy,qz
73
+ tvec = np.frombuffer(f.read(8 * 3), dtype=np.float64) # tx,ty,tz
74
+ cam_id = struct.unpack("<I", f.read(4))[0] # camera_id
75
+
76
+ name_bytes = bytearray()
77
+ while True:
78
+ c = f.read(1)
79
+ if c == b"\0":
80
+ break
81
+ name_bytes.extend(c)
82
+ name = name_bytes.decode("utf-8")
83
+
84
+ n_pts = struct.unpack("<Q", f.read(8))[0]
85
+ f.seek(n_pts * 24, 1)
86
+
87
+ # world→cam to cam→world
88
+ qw, qx, qy, qz = qvec
89
+ R_wc = np.array(
90
+ [
91
+ [
92
+ 1 - 2 * qy * qy - 2 * qz * qz,
93
+ 2 * qx * qy + 2 * qz * qw,
94
+ 2 * qx * qz - 2 * qy * qw,
95
+ ],
96
+ [
97
+ 2 * qx * qy - 2 * qz * qw,
98
+ 1 - 2 * qx * qx - 2 * qz * qz,
99
+ 2 * qy * qz + 2 * qx * qw,
100
+ ],
101
+ [
102
+ 2 * qx * qz + 2 * qy * qw,
103
+ 2 * qy * qz - 2 * qx * qw,
104
+ 1 - 2 * qx * qx - 2 * qy * qy,
105
+ ],
106
+ ]
107
+ )
108
+ t_wc = -R_wc @ tvec
109
+ c2w = np.eye(4, dtype=np.float32)
110
+ c2w[:3, :3] = R_wc.astype(np.float32)
111
+ c2w[:3, 3] = t_wc.astype(np.float32)
112
+
113
+ poses[name] = c2w
114
+
115
+ return poses
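The hand-written quaternion expansion in `read_images_bin` builds the camera-to-world rotation directly. A sketch of the equivalent conversion with SciPy, assuming (as the code does) that COLMAP stores a world-to-camera rotation as (qw, qx, qy, qz):

```python
# Illustrative cross-check of the conversion above (arbitrary unit quaternion).
import numpy as np
from scipy.spatial.transform import Rotation

qw, qx, qy, qz = 0.5, 0.5, 0.5, 0.5
tvec = np.array([1.0, 2.0, 3.0])
R_cw = Rotation.from_quat([qx, qy, qz, qw]).as_matrix()  # SciPy is scalar-last
R_wc = R_cw.T                      # transpose of world->cam gives cam->world
t_wc = -R_wc @ tvec                # camera centre in world coordinates
print(np.round(R_wc, 3), t_wc)
```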
eval/datasets/seven_scenes.py ADDED
@@ -0,0 +1,58 @@
1
+ import glob
2
+ import os
3
+
4
+ import numpy as np
5
+ import torch
6
+ from torch.utils.data import Dataset
7
+ from vggt.utils.load_fn import load_and_preprocess_images
8
+
9
+ from eval.utils.eval_utils import uniform_sample
10
+
11
+
12
+ class SevenScenesUnifiedDataset(Dataset):
13
+ def __init__(self, root_dir, scene_name="chess"):
14
+ self.scene_dir = os.path.join(root_dir, f"pgt_7scenes_{scene_name}")
15
+
16
+ self.train_seqs = os.path.join(self.scene_dir, "train")
17
+ self.test_seqs = os.path.join(self.scene_dir, "test")
18
+
19
+ self.test_samples = sorted(
20
+ glob.glob(os.path.join(self.test_seqs, "rgb", "*.png"))
21
+ )
22
+ self.train_samples = sorted(
23
+ glob.glob(os.path.join(self.train_seqs, "rgb", "*.png"))
24
+ )
25
+ self.all_samples = self.test_samples # + self.train_samples
26
+ # len_samples = len(self.all_samples)
27
+ # self.all_samples = self.all_samples[::len_samples//200]
28
+
29
+ def __len__(self):
30
+ return len(self.all_samples)
31
+
32
+ def __getitem__(self, idx):
33
+ return self._load_sample(self.all_samples[idx])
34
+
35
+ def get_train_sample(self, n=4):
36
+ uniform_sampled = uniform_sample(len(self.all_samples), n)
37
+ selected = [self.all_samples[i] for i in uniform_sampled]
38
+ return [self._load_sample(s) for s in selected]
39
+
40
+ def _load_sample(self, rgb_path):
41
+ img_name = os.path.basename(rgb_path)
42
+ color = load_and_preprocess_images([rgb_path])[0]
43
+ pose_path = (
44
+ rgb_path.replace("rgb", "poses")
45
+ .replace("color", "pose")
46
+ .replace(".png", ".txt")
47
+ )
48
+ pose = np.loadtxt(pose_path)
49
+ pose = torch.from_numpy(pose).float()
50
+
51
+ return dict(
52
+ img=color,
53
+ camera_pose=pose, # cam2world
54
+ dataset="7Scenes",
55
+ true_shape=torch.tensor([392, 518]),
56
+ label=img_name,
57
+ instance=img_name,
58
+ )
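For orientation, the chained `replace` calls in `_load_sample` map an RGB frame path to its pose file. A tiny sketch with a hypothetical file name (the real layout comes from the PGT 7-Scenes release):

```python
# Illustrative only: the path mapping performed by _load_sample above.
rgb_path = "pgt_7scenes_chess/test/rgb/seq-03-frame-000000.color.png"   # hypothetical
pose_path = (
    rgb_path.replace("rgb", "poses").replace("color", "pose").replace(".png", ".txt")
)
print(pose_path)  # pgt_7scenes_chess/test/poses/seq-03-frame-000000.pose.txt
```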
eval/datasets/tnt.py ADDED
@@ -0,0 +1,116 @@
1
+ import os
2
+ import struct
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import PIL
7
+ import torch
8
+ from PIL import Image
9
+ from torch.utils.data import Dataset
10
+ from vggt.utils.load_fn import load_and_preprocess_images
11
+
12
+ from eval.utils.eval_utils import uniform_sample
13
+
14
+
15
+ class TnTDataset(Dataset):
16
+ def __init__(self, root_dir, colmap_dir, scene_name="advanced__Auditorium"):
17
+ scene_name_ori = scene_name
18
+
19
+ level, scene_name = scene_name.split("__")
20
+
21
+ self.scene_dir = os.path.join(root_dir, f"{level}", f"{scene_name}")
22
+
23
+ self.test_samples = []
24
+
25
+ bin_path = os.path.join(colmap_dir, scene_name_ori, "0", "images.bin")
26
+ self.poses = read_images_bin(bin_path)
27
+ for img in self.poses.keys():
28
+ self.test_samples.append(os.path.join(self.scene_dir, img))
29
+ self.all_samples = self.test_samples
30
+
31
+ def __len__(self):
32
+ return len(self.all_samples)
33
+
34
+ def __getitem__(self, idx):
35
+ return self._load_sample(self.all_samples[idx])
36
+
37
+ def get_train_sample(self, n=4):
38
+ gap = len(self.all_samples) // n
39
+ gap = max(gap, 1) # Ensure at least one sample is selected
40
+ gap = min(gap, len(self.all_samples)) # Ensure gap does not exceed length
41
+ if gap == 1:
42
+ uniform_sampled = uniform_sample(len(self.all_samples), n)
43
+ selected = [self.all_samples[i] for i in uniform_sampled]
44
+ else:
45
+ selected = self.all_samples[::gap]
46
+ if len(selected) > n:
47
+ uniform_sampled = uniform_sample(len(selected), n)
48
+ selected = [selected[i] for i in uniform_sampled]
49
+ return [self._load_sample(s) for s in selected]
50
+
51
+ def _load_sample(self, rgb_path):
52
+ img_name = os.path.basename(rgb_path)
53
+ color = load_and_preprocess_images([rgb_path])[0]
54
+ pose = torch.from_numpy(self.poses[img_name]).float()
55
+
56
+ return dict(
57
+ img=color,
58
+ camera_pose=pose, # cam2world
59
+ dataset="7Scenes",
60
+ true_shape=torch.tensor([392, 518]),
61
+ label=img_name,
62
+ instance=img_name,
63
+ )
64
+
65
+
66
+ def read_images_bin(bin_path: str | Path):
67
+ bin_path = Path(bin_path)
68
+ poses = {}
69
+
70
+ with bin_path.open("rb") as f:
71
+ num_images = struct.unpack("<Q", f.read(8))[0] # uint64
72
+ for _ in range(num_images):
73
+ image_id = struct.unpack("<I", f.read(4))[0]
74
+ qvec = np.frombuffer(f.read(8 * 4), dtype=np.float64) # qw,qx,qy,qz
75
+ tvec = np.frombuffer(f.read(8 * 3), dtype=np.float64) # tx,ty,tz
76
+ cam_id = struct.unpack("<I", f.read(4))[0] # camera_id
77
+
78
+ name_bytes = bytearray()
79
+ while True:
80
+ c = f.read(1)
81
+ if c == b"\0":
82
+ break
83
+ name_bytes.extend(c)
84
+ name = name_bytes.decode("utf-8").split("/")[-1]  # keep only the base file name
85
+
86
+ n_pts = struct.unpack("<Q", f.read(8))[0]
87
+ f.seek(n_pts * 24, 1)
88
+
89
+ # world→cam to cam→world
90
+ qw, qx, qy, qz = qvec
91
+ R_wc = np.array(
92
+ [
93
+ [
94
+ 1 - 2 * qy * qy - 2 * qz * qz,
95
+ 2 * qx * qy + 2 * qz * qw,
96
+ 2 * qx * qz - 2 * qy * qw,
97
+ ],
98
+ [
99
+ 2 * qx * qy - 2 * qz * qw,
100
+ 1 - 2 * qx * qx - 2 * qz * qz,
101
+ 2 * qy * qz + 2 * qx * qw,
102
+ ],
103
+ [
104
+ 2 * qx * qz + 2 * qy * qw,
105
+ 2 * qy * qz - 2 * qx * qw,
106
+ 1 - 2 * qx * qx - 2 * qy * qy,
107
+ ],
108
+ ]
109
+ )
110
+ t_wc = -R_wc @ tvec
111
+ c2w = np.eye(4, dtype=np.float32)
112
+ c2w[:3, :3] = R_wc.astype(np.float32)
113
+ c2w[:3, 3] = t_wc.astype(np.float32)
114
+
115
+ poses[name] = c2w
116
+ return poses
eval/datasets/tum.py ADDED
@@ -0,0 +1,53 @@
1
+ import glob
2
+ import os
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ from PIL import Image
7
+ from torch.utils.data import Dataset
8
+ from vggt.utils.load_fn import load_and_preprocess_images
9
+
10
+ from eval.utils.eval_utils import uniform_sample
11
+
12
+
13
+ class TumDatasetAll(Dataset):
14
+ def __init__(self, root_dir, scene_name="rgbd_dataset_freiburg1_360"):
15
+ self.scene_name = scene_name
16
+ self.scene_all_dir = os.path.join(root_dir, f"{scene_name}", "rgb")
17
+
18
+ self.test_samples = sorted(glob.glob(os.path.join(self.scene_all_dir, "*.png")))
19
+ self.all_samples = self.test_samples
20
+
21
+ def __len__(self):
22
+ return len(self.all_samples)
23
+
24
+ def __getitem__(self, idx):
25
+ return self._load_sample(self.all_samples[idx])
26
+
27
+ def get_train_sample(self, n=4):
28
+ gap = len(self.all_samples) // n
29
+ gap = max(gap, 1) # Ensure at least one sample is selected
30
+ gap = min(gap, len(self.all_samples)) # Ensure gap does not exceed length
31
+ if gap == 1:
32
+ uniform_sampled = uniform_sample(len(self.all_samples), n)
33
+ selected = [self.all_samples[i] for i in uniform_sampled]
34
+ else:
35
+ selected = self.all_samples[::gap]
36
+ if len(selected) > n:
37
+ uniform_sampled = uniform_sample(len(selected), n)
38
+ selected = [selected[i] for i in uniform_sampled]
39
+ if self.scene_name == "rgbd_dataset_freiburg1_floor":
40
+ selected += self.all_samples[-20::5]
41
+ return [self._load_sample(s) for s in selected]
42
+
43
+ def _load_sample(self, rgb_path):
44
+ img_name = os.path.basename(rgb_path)
45
+ color = load_and_preprocess_images([rgb_path])[0]
46
+
47
+ return dict(
48
+ img=color,
49
+ dataset="tnt_all",
50
+ true_shape=torch.tensor([392, 518]),
51
+ label=img_name,
52
+ instance=img_name,
53
+ )
eval/readme.md ADDED
@@ -0,0 +1,110 @@
1
+ ## Tanks and Temples
2
+
3
+ ### Image set
4
+ 1. Data Preparation
5
+
6
+ Download the image data from [here](https://cvg.cit.tum.de/data/datasets/rgbd-dataset/download) (for the intermediate and advanced sets, please download from [here](https://github.com/isl-org/TanksAndTemples/issues/35)), and the COLMAP results from [here](https://storage.googleapis.com/niantic-lon-static/research/acezero/colmap_raw.tar.gz). We thank [ACE0](https://github.com/nianticlabs/acezero) again for providing the COLMAP results.
7
+
8
+ 2. Adjust the parameters in `run_tnt.sh`
9
+
10
+ Specify the `dataset_root`, `colmap_dir`, `model_path` and `save_dir` in the file.
11
+
12
+ 3. Get the inference results.
13
+
14
+ ```
15
+ sh run_tnt.sh
16
+ ```
17
+ ### Video set
18
+ <details>
19
+ <summary>Click to expand</summary>
20
+
21
+ 1. Data Preparation
22
+
23
+ Download the video sequences from [here](https://www.tanksandtemples.org/download/) and extract images from the videos following [this tutorial](https://www.tanksandtemples.org/tutorial/).
24
+
25
+ 2. Run Inference
26
+
27
+ Replace `docs/demo_image` in `../demo.py` with the path storing the images extracted from the video.
28
+ </details>
29
+
30
+ ## 7 scenes
31
+
32
+ 1. Data Preparation
33
+ Download the corresponding sequence from [here](https://jonbarron.info/mipnerf360/).
34
+
35
+
36
+ ## TUM-RGBD
37
+
38
+ 1. Data Preparation
39
+
40
+ Download the corresponding sequence from [here](https://cvg.cit.tum.de/data/datasets/rgbd-dataset/download).
41
+
42
+ 2. Adjust the parameters in `run_tum.sh`
43
+
44
+ Specify the `dataset_root`, `recon_img_num`, `model_path` and `save_dir` in the file.
45
+
46
+ 3. Evaluate the results.
47
+
48
+ ```
49
+ sh run_tum.sh
50
+ ```
51
+ Note that we set `recon_img_num` to 50 or 100 according to the length of the sequence. Please refer to the supplementary material of the paper for details.
52
+
53
+ 4. Use evo to evaluate the results (the expected trajectory file format is sketched after the command below)
54
+
55
+ ```
56
+ evo_ape tum gt_pose.txt pred_tum.txt -vas
57
+ ```
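For reference, evo's `tum` mode expects one pose per line in the TUM trajectory format (`timestamp tx ty tz qx qy qz qw`). A minimal sketch of writing camera-to-world poses in that format; the file name and pose array here are placeholders:

```python
# Sketch: dump (N, 4, 4) cam-to-world poses as a TUM-format trajectory file.
import numpy as np
from scipy.spatial.transform import Rotation

poses_c2w = np.tile(np.eye(4), (3, 1, 1))             # placeholder poses
with open("pred_tum.txt", "w") as f:
    for i, T in enumerate(poses_c2w):
        qx, qy, qz, qw = Rotation.from_matrix(T[:3, :3]).as_quat()  # scalar-last
        tx, ty, tz = T[:3, 3]
        f.write(f"{i:.6f} {tx} {ty} {tz} {qx} {qy} {qz} {qw}\n")
```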
58
+
59
+
60
+ ## 7 scenes
61
+
62
+ 1. Download the dataset from [here](https://www.microsoft.com/en-us/research/project/rgb-d-dataset-7-scenes/) and _Pseudo Ground Truth (PGT)_
63
+ (see
64
+ the [ICCV 2021 paper](https://openaccess.thecvf.com/content/ICCV2021/html/Brachmann_On_the_Limits_of_Pseudo_Ground_Truth_in_Visual_Camera_ICCV_2021_paper.html)
65
+ ,
66
+ and [associated code](https://github.com/tsattler/visloc_pseudo_gt_limitations/) for details).
67
+
68
+
69
+ 2. Adjust the parameter in `run_7scenes.sh`
70
+
71
+ Specify the `dataset_root`, `recon_img_num`, `model_path` and `save_dir` in the file.
72
+
73
+ 3. Evaluate the results.
74
+
75
+ ```
76
+ sh run_7scenes.sh
77
+ ```
78
+ You will see a `result.txt` file reporting the evaluation results.
79
+
80
+
81
+ ## Mip-NeRF 360
82
+
83
+ 1. Data Preparation
84
+
85
+ Download the data from [here](https://jonbarron.info/mipnerf360/).
86
+
87
+ 2. Adjust the parameters in `run_mip.sh`
88
+
89
+ Specify the `dataset_root`, `model_path` and `save_dir` in the file.
90
+
91
+ 3. Get the inference results.
92
+
93
+ ```
94
+ sh run_mip.sh
95
+ ```
96
+
97
+ ## Co3D-V2
98
+
99
+ 1. We thank VGGT for providing the evaluation code for the Co3D-V2 dataset. Please see the link [here](https://github.com/facebookresearch/vggt/tree/evaluation/evaluation#dataset-preparation) for data preparation and processing.
100
+
101
+ 2. Adjust the parameters in `run_co3d.sh`
102
+
103
+ Specify the `dataset_root`, `recon_img_num`, `model_path`, `recon`, `reloc` and `fixed_rank` in the file.
104
+
105
+ 3. Evaluate the results.
106
+
107
+ ```
108
+ sh run_co3d.sh
109
+ ```
110
+ You will see evaluation result in the terminal.
eval/utils/cropping.py ADDED
@@ -0,0 +1,289 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+
9
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
10
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
11
+ #
12
+ # --------------------------------------------------------
13
+ # cropping utilities
14
+ # --------------------------------------------------------
15
+ import PIL.Image
16
+
17
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
18
+ import cv2 # noqa
19
+ import numpy as np # noqa
20
+ from eval.utils.device import to_numpy
21
+ from eval.utils.geometry import ( # noqa
22
+ colmap_to_opencv_intrinsics,
23
+ geotrf,
24
+ inv,
25
+ opencv_to_colmap_intrinsics,
26
+ )
27
+
28
+ try:
29
+ lanczos = PIL.Image.Resampling.LANCZOS
30
+ bicubic = PIL.Image.Resampling.BICUBIC
31
+ except AttributeError:
32
+ lanczos = PIL.Image.LANCZOS
33
+ bicubic = PIL.Image.BICUBIC
34
+
35
+
36
+ class ImageList:
37
+ """Convenience class to apply the same operation to a whole set of images."""
38
+
39
+ def __init__(self, images):
40
+ if not isinstance(images, (tuple, list, set)):
41
+ images = [images]
42
+ self.images = []
43
+ for image in images:
44
+ if not isinstance(image, PIL.Image.Image):
45
+ image = PIL.Image.fromarray(image)
46
+ self.images.append(image)
47
+
48
+ def __len__(self):
49
+ return len(self.images)
50
+
51
+ def to_pil(self):
52
+ return tuple(self.images) if len(self.images) > 1 else self.images[0]
53
+
54
+ @property
55
+ def size(self):
56
+ sizes = [im.size for im in self.images]
57
+ assert all(sizes[0] == s for s in sizes)
58
+ return sizes[0]
59
+
60
+ def resize(self, *args, **kwargs):
61
+ return ImageList(self._dispatch("resize", *args, **kwargs))
62
+
63
+ def crop(self, *args, **kwargs):
64
+ return ImageList(self._dispatch("crop", *args, **kwargs))
65
+
66
+ def _dispatch(self, func, *args, **kwargs):
67
+ return [getattr(im, func)(*args, **kwargs) for im in self.images]
68
+
69
+
70
+ def rescale_image_depthmap(
71
+ image, depthmap, camera_intrinsics, output_resolution, force=True
72
+ ):
73
+ """Jointly rescale an (image, depthmap) pair
74
+ so that (out_width, out_height) >= output_res
75
+ """
76
+ image = ImageList(image)
77
+ input_resolution = np.array(image.size) # (W,H)
78
+ output_resolution = np.array(output_resolution)
79
+ if depthmap is not None:
80
+ # can also use this with masks instead of depthmaps
81
+ assert tuple(depthmap.shape[:2]) == image.size[::-1]
82
+
83
+ # define output resolution
84
+ assert output_resolution.shape == (2,)
85
+ scale_final = max(output_resolution / image.size) + 1e-8
86
+ if scale_final >= 1 and not force: # image is already smaller than what is asked
87
+ return (image.to_pil(), depthmap, camera_intrinsics)
88
+ output_resolution = np.floor(input_resolution * scale_final).astype(int)
89
+
90
+ # first rescale the image so that it contains the crop
91
+ image = image.resize(
92
+ tuple(output_resolution), resample=lanczos if scale_final < 1 else bicubic
93
+ )
94
+ if depthmap is not None:
95
+ depthmap = cv2.resize(
96
+ depthmap,
97
+ output_resolution,
98
+ fx=scale_final,
99
+ fy=scale_final,
100
+ interpolation=cv2.INTER_NEAREST,
101
+ )
102
+
103
+ # no offset here; simple rescaling
104
+ camera_intrinsics = camera_matrix_of_crop(
105
+ camera_intrinsics, input_resolution, output_resolution, scaling=scale_final
106
+ )
107
+
108
+ return image.to_pil(), depthmap, camera_intrinsics
109
+
110
+
111
+ def camera_matrix_of_crop(
112
+ input_camera_matrix,
113
+ input_resolution,
114
+ output_resolution,
115
+ scaling=1,
116
+ offset_factor=0.5,
117
+ offset=None,
118
+ ):
119
+ # Margins to offset the origin
120
+ margins = np.asarray(input_resolution) * scaling - output_resolution
121
+ assert np.all(margins >= 0.0)
122
+ if offset is None:
123
+ offset = offset_factor * margins
124
+
125
+ # Generate new camera parameters
126
+ output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix)
127
+ output_camera_matrix_colmap[:2, :] *= scaling
128
+ output_camera_matrix_colmap[:2, 2] -= offset
129
+ output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap)
130
+
131
+ return output_camera_matrix
132
+
133
+
134
+ def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox):
135
+ """
136
+ Return a crop of the input view.
137
+ """
138
+ image = ImageList(image)
139
+ l, t, r, b = crop_bbox
140
+
141
+ image = image.crop((l, t, r, b))
142
+ if depthmap is not None:
143
+ depthmap = depthmap[t:b, l:r]
144
+
145
+ camera_intrinsics = camera_intrinsics.copy()
146
+ camera_intrinsics[0, 2] -= l
147
+ camera_intrinsics[1, 2] -= t
148
+
149
+ return image.to_pil(), depthmap, camera_intrinsics
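To make the intrinsics bookkeeping in `rescale_image_depthmap` and `crop_image_depthmap` concrete: rescaling multiplies the focal lengths and principal point by the scale factor, while cropping only shifts the principal point by the crop origin. A standalone numeric sketch with toy values (it ignores the OpenCV/COLMAP half-pixel convention that `camera_matrix_of_crop` handles via the conversion helpers):

```python
# Toy example: how a pinhole intrinsics matrix changes under scaling and cropping.
import numpy as np

K = np.array([[500.0, 0.0, 320.0],
              [0.0, 500.0, 240.0],
              [0.0, 0.0, 1.0]])
scale = 0.5                      # e.g. 640x480 -> 320x240
K_scaled = K.copy()
K_scaled[:2, :] *= scale         # fx, fy, cx, cy all scale with the image
l, t = 40, 30                    # crop origin (left, top) in the scaled image
K_cropped = K_scaled.copy()
K_cropped[0, 2] -= l             # principal point shifts by the crop origin
K_cropped[1, 2] -= t
print(K_cropped)
```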
150
+
151
+
152
+ def bbox_from_intrinsics_in_out(
153
+ input_camera_matrix, output_camera_matrix, output_resolution
154
+ ):
155
+ out_width, out_height = output_resolution
156
+ l, t = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]))
157
+ crop_bbox = (l, t, l + out_width, t + out_height)
158
+ return crop_bbox
159
+
160
+
161
+ def reciprocal_1d(corres_1_to_2, corres_2_to_1, ret_recip=False):
162
+ is_reciprocal1 = corres_2_to_1[corres_1_to_2] == np.arange(len(corres_1_to_2))
163
+ pos1 = is_reciprocal1.nonzero()[0]
164
+ pos2 = corres_1_to_2[pos1]
165
+ if ret_recip:
166
+ return is_reciprocal1, pos1, pos2
167
+ return pos1, pos2
168
+
169
+
170
+ def generate_non_self_pairs(n):
171
+ i, j = np.meshgrid(np.arange(n), np.arange(n), indexing="ij")
172
+
173
+ pairs = np.stack([i.ravel(), j.ravel()], axis=1)
174
+
175
+ mask = pairs[:, 0] != pairs[:, 1]
176
+ filtered_pairs = pairs[mask]
177
+
178
+ return filtered_pairs
179
+
180
+
181
+ def unravel_xy(pos, shape):
182
+ # convert (x+W*y) back to 2d (x,y) coordinates
183
+ return np.unravel_index(pos, shape)[0].base[:, ::-1].copy()
184
+
185
+
186
+ def ravel_xy(pos, shape):
187
+ H, W = shape
188
+ with np.errstate(invalid="ignore"):
189
+ qx, qy = pos.reshape(-1, 2).round().astype(np.int32).T
190
+ quantized_pos = qx.clip(min=0, max=W - 1, out=qx) + W * qy.clip(
191
+ min=0, max=H - 1, out=qy
192
+ )
193
+ return quantized_pos
194
+
195
+
196
+ def extract_correspondences_from_pts3d(
197
+ view1, view2, target_n_corres, rng=np.random, ret_xy=True, nneg=0
198
+ ):
199
+ view1, view2 = to_numpy((view1, view2))
200
+ # project pixels from image1 --> 3d points --> image2 pixels
201
+ shape1, corres1_to_2 = reproject_view(view1["pts3d"], view2)
202
+ shape2, corres2_to_1 = reproject_view(view2["pts3d"], view1)
203
+
204
+ # compute reciprocal correspondences:
205
+ # pos1 == valid pixels (correspondences) in image1
206
+ is_reciprocal1, pos1, pos2 = reciprocal_1d(
207
+ corres1_to_2, corres2_to_1, ret_recip=True
208
+ )
209
+ is_reciprocal2 = corres1_to_2[corres2_to_1] == np.arange(len(corres2_to_1))
210
+
211
+ if target_n_corres is None:
212
+ if ret_xy:
213
+ pos1 = unravel_xy(pos1, shape1)
214
+ pos2 = unravel_xy(pos2, shape2)
215
+ return pos1, pos2
216
+
217
+ available_negatives = min((~is_reciprocal1).sum(), (~is_reciprocal2).sum())
218
+ target_n_positives = int(target_n_corres * (1 - nneg))
219
+ n_positives = min(len(pos1), target_n_positives)
220
+ n_negatives = min(target_n_corres - n_positives, available_negatives)
221
+
222
+ if n_negatives + n_positives != target_n_corres:
223
+ # should be really rare => when there are not enough negatives
224
+ # in that case, break nneg and add a few more positives ?
225
+ n_positives = target_n_corres - n_negatives
226
+ assert n_positives <= len(pos1)
227
+
228
+ assert n_positives <= len(pos1)
229
+ assert n_positives <= len(pos2)
230
+ assert n_negatives <= (~is_reciprocal1).sum()
231
+ assert n_negatives <= (~is_reciprocal2).sum()
232
+ assert n_positives + n_negatives == target_n_corres
233
+
234
+ valid = np.ones(n_positives, dtype=bool)
235
+ if n_positives < len(pos1):
236
+ # random sub-sampling of valid correspondences
237
+ perm = rng.permutation(len(pos1))[:n_positives]
238
+ pos1 = pos1[perm]
239
+ pos2 = pos2[perm]
240
+
241
+ if n_negatives > 0:
242
+ # add false correspondences if not enough
243
+ def norm(p):
244
+ return p / p.sum()
245
+
246
+ pos1 = np.r_[
247
+ pos1,
248
+ rng.choice(
249
+ shape1[0] * shape1[1],
250
+ size=n_negatives,
251
+ replace=False,
252
+ p=norm(~is_reciprocal1),
253
+ ),
254
+ ]
255
+ pos2 = np.r_[
256
+ pos2,
257
+ rng.choice(
258
+ shape2[0] * shape2[1],
259
+ size=n_negatives,
260
+ replace=False,
261
+ p=norm(~is_reciprocal2),
262
+ ),
263
+ ]
264
+ valid = np.r_[valid, np.zeros(n_negatives, dtype=bool)]
265
+
266
+ # convert (x+W*y) back to 2d (x,y) coordinates
267
+ if ret_xy:
268
+ pos1 = unravel_xy(pos1, shape1)
269
+ pos2 = unravel_xy(pos2, shape2)
270
+ return pos1, pos2, valid
271
+
272
+
273
+ def reproject_view(pts3d, view2):
274
+ shape = view2["pts3d"].shape[:2]
275
+ return reproject(
276
+ pts3d, view2["camera_intrinsics"], inv(view2["camera_pose"]), shape
277
+ )
278
+
279
+
280
+ def reproject(pts3d, K, world2cam, shape):
281
+ H, W, THREE = pts3d.shape
282
+ assert THREE == 3
283
+
284
+ # reproject in camera2 space
285
+ with np.errstate(divide="ignore", invalid="ignore"):
286
+ pos = geotrf(K @ world2cam[:3], pts3d, norm=1, ncol=2)
287
+
288
+ # quantize to pixel positions
289
+ return (H, W), ravel_xy(pos, shape)
eval/utils/device.py ADDED
@@ -0,0 +1,95 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
8
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
9
+ #
10
+ # --------------------------------------------------------
11
+ # utilitary functions for DUSt3R
12
+ # --------------------------------------------------------
13
+ import numpy as np
14
+ import torch
15
+
16
+
17
+ def todevice(batch, device, callback=None, non_blocking=False):
18
+ """Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).
19
+
20
+ batch: list, tuple, dict of tensors or other things
21
+ device: pytorch device or 'numpy'
22
+ callback: function that would be called on every sub-elements.
23
+ """
24
+ if callback:
25
+ batch = callback(batch)
26
+
27
+ if isinstance(batch, dict):
28
+ return {k: todevice(v, device) for k, v in batch.items()}
29
+
30
+ if isinstance(batch, (tuple, list)):
31
+ return type(batch)(todevice(x, device) for x in batch)
32
+
33
+ x = batch
34
+ if device == "numpy":
35
+ if isinstance(x, torch.Tensor):
36
+ x = x.detach().cpu().numpy()
37
+ elif x is not None:
38
+ if isinstance(x, np.ndarray):
39
+ x = torch.from_numpy(x)
40
+ if torch.is_tensor(x):
41
+ x = x.to(device, non_blocking=non_blocking)
42
+ return x
43
+
44
+
45
+ to_device = todevice # alias
46
+
47
+
48
+ def to_numpy(x):
49
+ return todevice(x, "numpy")
50
+
51
+
52
+ def to_cpu(x):
53
+ return todevice(x, "cpu")
54
+
55
+
56
+ def to_cuda(x):
57
+ return todevice(x, "cuda")
58
+
59
+
60
+ def collate_with_cat(whatever, lists=False):
61
+ if isinstance(whatever, dict):
62
+ return {k: collate_with_cat(vals, lists=lists) for k, vals in whatever.items()}
63
+
64
+ elif isinstance(whatever, (tuple, list)):
65
+ if len(whatever) == 0:
66
+ return whatever
67
+ elem = whatever[0]
68
+ T = type(whatever)
69
+
70
+ if elem is None:
71
+ return None
72
+ if isinstance(elem, (bool, float, int, str)):
73
+ return whatever
74
+ if isinstance(elem, tuple):
75
+ return T(collate_with_cat(x, lists=lists) for x in zip(*whatever))
76
+ if isinstance(elem, dict):
77
+ return {
78
+ k: collate_with_cat([e[k] for e in whatever], lists=lists) for k in elem
79
+ }
80
+
81
+ if isinstance(elem, torch.Tensor):
82
+ return listify(whatever) if lists else torch.cat(whatever)
83
+ if isinstance(elem, np.ndarray):
84
+ return (
85
+ listify(whatever)
86
+ if lists
87
+ else torch.cat([torch.from_numpy(x) for x in whatever])
88
+ )
89
+
90
+ # otherwise, we just chain lists
91
+ return sum(whatever, T())
92
+
93
+
94
+ def listify(elems):
95
+ return [x for e in elems for x in e]
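`todevice` and its aliases walk arbitrarily nested dicts/lists, which is what lets the demo call `to_cpu(...)` on whole prediction dictionaries. A small illustrative use:

```python
# Illustrative: recursively convert a nested, prediction-like structure.
import numpy as np
import torch

batch = {"depth": torch.rand(2, 4, 4), "meta": [torch.arange(3), np.zeros(2)]}
batch_np = to_numpy(batch)   # tensors -> numpy arrays (numpy arrays pass through)
batch_pt = to_cpu(batch)     # arrays and tensors -> CPU torch tensors
print(type(batch_np["depth"]), type(batch_pt["meta"][1]))
```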
eval/utils/eval_pose_ransac.py ADDED
@@ -0,0 +1,315 @@
1
+ import logging
2
+ import math
3
+ import random
4
+ from collections import namedtuple
5
+
6
+ import cv2
7
+ import numpy as np
8
+ from scipy.spatial.transform import Rotation
9
+
10
+ _logger = logging.getLogger(__name__)
11
+ _logger.setLevel(logging.DEBUG)
12
+
13
+
14
+ def kabsch(pts1, pts2, estimate_scale=False):
15
+ c_pts1 = pts1 - pts1.mean(axis=0)
16
+ c_pts2 = pts2 - pts2.mean(axis=0)
17
+
18
+ covariance = np.matmul(c_pts1.T, c_pts2) / c_pts1.shape[0]
19
+
20
+ U, S, VT = np.linalg.svd(covariance)
21
+
22
+ d = np.sign(np.linalg.det(np.matmul(VT.T, U.T)))
23
+ correction = np.eye(3)
24
+ correction[2, 2] = d
25
+
26
+ if estimate_scale:
27
+ pts_var = np.mean(np.linalg.norm(c_pts2, axis=1) ** 2)
28
+ scale_factor = pts_var / np.trace(S * correction)
29
+ else:
30
+ scale_factor = 1.0
31
+
32
+ R = scale_factor * np.matmul(np.matmul(VT.T, correction), U.T)
33
+ t = pts2.mean(axis=0) - np.matmul(R, pts1.mean(axis=0))
34
+
35
+ T = np.eye(4)
36
+ T[:3, :3] = R
37
+ T[:3, 3] = t
38
+
39
+ return T, scale_factor
40
+
41
+
42
+ def get_inliers(h_T, poses_gt, poses_est, inlier_threshold_t, inlier_threshold_r):
43
+ # h_T aligns ground truth poses with estimated poses
44
+ poses_gt_transformed = h_T @ poses_gt
45
+
46
+ # calculate differences in position and rotations
47
+ translations_delta = poses_gt_transformed[:, :3, 3] - poses_est[:, :3, 3]
48
+ rotations_delta = poses_gt_transformed[:, :3, :3] @ poses_est[:, :3, :3].transpose(
49
+ [0, 2, 1]
50
+ )
51
+
52
+ # translation inliers
53
+ inliers_t = np.linalg.norm(translations_delta, axis=1) < inlier_threshold_t
54
+ # rotation inliers
55
+ inliers_r = Rotation.from_matrix(rotations_delta).magnitude() < (
56
+ inlier_threshold_r / 180 * math.pi
57
+ )
58
+ # intersection of both
59
+ return np.logical_and(inliers_r, inliers_t)
60
+
61
+
62
+ def print_hyp(hypothesis, hyp_name):
63
+ h_translation = np.linalg.norm(hypothesis["transformation"][:3, 3])
64
+ h_angle = (
65
+ np.linalg.norm(
66
+ Rotation.from_matrix(hypothesis["transformation"][:3, :3]).as_rotvec()
67
+ )
68
+ * 180
69
+ / math.pi
70
+ )
71
+ print(
72
+ f"{hyp_name}: score={hypothesis['score']}, translation={h_translation:.2f}m, "
73
+ f"rotation={h_angle:.1f}deg."
74
+ )
75
+
76
+
77
+ def estimated_alignment(
78
+ pose_est,
79
+ pose_gt,
80
+ inlier_threshold_t=0.05,
81
+ inlier_threshold_r=5,
82
+ ransac_iterations=1000,
83
+ refinement_max_hyp=12,
84
+ refinement_max_it=8,
85
+ estimate_scale=False,
86
+ ):
87
+ n_pose = len(pose_est)
88
+ ransac_hypotheses = []
89
+ for i in range(ransac_iterations):
90
+ min_sample_size = 3
91
+ samples = random.sample(range(n_pose), min_sample_size)
92
+ h_pts1 = pose_gt[samples, :3, 3]
93
+ h_pts2 = pose_est[samples, :3, 3]
94
+
95
+ h_T, h_scale = kabsch(h_pts1, h_pts2, estimate_scale=estimate_scale)
96
+
97
+ inliers = get_inliers(
98
+ h_T, pose_gt, pose_est, inlier_threshold_t, inlier_threshold_r
99
+ )
100
+
101
+ if inliers[samples].sum() >= 3:
102
+ # only keep hypotheses if minimal sample is all inliers
103
+ ransac_hypotheses.append(
104
+ {
105
+ "transformation": h_T,
106
+ "inliers": inliers,
107
+ "score": inliers.sum(),
108
+ "scale": h_scale,
109
+ }
110
+ )
111
+ if len(ransac_hypotheses) == 0:
112
+ print(
113
+ f"Did not fine a single valid RANSAC hypothesis, abort alignment estimation."
114
+ )
115
+ return None, 1
116
+
117
+ # sort according to score
118
+ ransac_hypotheses = sorted(
119
+ ransac_hypotheses, key=lambda x: x["score"], reverse=True
120
+ )
121
+
122
+ # for hyp_idx, hyp in enumerate(ransac_hypotheses):
123
+ # print_hyp(hyp, f"Hypothesis {hyp_idx}")
124
+
125
+ # create shortlist of best hypotheses for refinement
126
+ # print(f"Starting refinement of {refinement_max_hyp} best hypotheses.")
127
+ ransac_hypotheses = ransac_hypotheses[:refinement_max_hyp]
128
+
129
+ # refine all hypotheses in the short list
130
+ for ref_hyp in ransac_hypotheses:
131
+ # print_hyp(ref_hyp, "Pre-Refinement")
132
+
133
+ # refinement loop
134
+ for ref_it in range(refinement_max_it):
135
+ # re-solve alignment on all inliers
136
+ h_pts1 = pose_gt[ref_hyp["inliers"], :3, 3]
137
+ h_pts2 = pose_est[ref_hyp["inliers"], :3, 3]
138
+
139
+ h_T, h_scale = kabsch(h_pts1, h_pts2, estimate_scale)
140
+
141
+ # calculate new inliers
142
+ inliers = get_inliers(
143
+ h_T, pose_gt, pose_est, inlier_threshold_t, inlier_threshold_r
144
+ )
145
+
146
+ # check whether hypothesis score improved
147
+ refined_score = inliers.sum()
148
+
149
+ if refined_score > ref_hyp["score"]:
150
+ ref_hyp["transformation"] = h_T
151
+ ref_hyp["inliers"] = inliers
152
+ ref_hyp["score"] = refined_score
153
+ ref_hyp["scale"] = h_scale
154
+
155
+ # print_hyp(ref_hyp, f"Refinement interation {ref_it}")
156
+
157
+ else:
158
+ # print(f"Stopping refinement. Score did not improve: New score={refined_score}, "
159
+ # f"Old score={ref_hyp['score']}")
160
+ break
161
+
162
+ # re-sort refined hypotheses
163
+ ransac_hypotheses = sorted(
164
+ ransac_hypotheses, key=lambda x: x["score"], reverse=True
165
+ )
166
+
167
+ # for hyp_idx, hyp in enumerate(ransac_hypotheses):
168
+ # print_hyp(hyp, f"Hypothesis {hyp_idx}")
169
+
170
+ return ransac_hypotheses[0]["transformation"], ransac_hypotheses[0]["scale"]
171
+
172
+
173
+ def eval_pose_ransac(gt, est, t_thres=0.05, r_thres=5, aligned=True, save_dir=None):
174
+ if aligned:
175
+ alignment_transformation, alignment_scale = estimated_alignment(
176
+ est,
177
+ gt,
178
+ inlier_threshold_t=0.05,
179
+ inlier_threshold_r=5,
180
+ ransac_iterations=1000,
181
+ refinement_max_hyp=12,
182
+ refinement_max_it=8,
183
+ estimate_scale=True,
184
+ )
185
+ if alignment_transformation is None:
186
+ _logger.info(
187
+ f"Alignment requested but failed. Setting all pose errors to {math.inf}."
188
+ )
189
+ else:
190
+ alignment_transformation = np.eye(4)
191
+ alignment_scale = 1.0
192
+ # Evaluation Loop
193
+
194
+ rErrs = []
195
+ tErrs = []
196
+ accuracy = 0
197
+ r_acc_5 = 0
198
+ r_acc_2 = 0
199
+ r_acc_1 = 0
200
+ t_acc_15 = 0
201
+ t_acc_10 = 0
202
+ t_acc_5 = 0
203
+ t_acc_2 = 0
204
+ t_acc_1 = 0
205
+ acc_10 = 0
206
+ acc_5 = 0
207
+ acc_2 = 0
208
+ acc_1 = 0
209
+
210
+ for pred_pose, gt_pose in zip(est, gt):
211
+ if alignment_transformation is not None:
212
+ # Apply alignment transformation to GT pose
213
+ gt_pose = alignment_transformation @ gt_pose
214
+
215
+ # Calculate translation error.
216
+ t_err = float(np.linalg.norm(gt_pose[0:3, 3] - pred_pose[0:3, 3]))
217
+
218
+ # Correct translation scale with the inverse alignment scale (since we align GT with estimates)
219
+ t_err = t_err / alignment_scale
220
+
221
+ # Rotation error.
222
+ gt_R = gt_pose[0:3, 0:3]
223
+ out_R = pred_pose[0:3, 0:3]
224
+
225
+ r_err = np.matmul(out_R, np.transpose(gt_R))
226
+ # Compute angle-axis representation.
227
+ r_err = cv2.Rodrigues(r_err)[0]
228
+ # Extract the angle.
229
+ r_err = np.linalg.norm(r_err) * 180 / math.pi
230
+ else:
231
+ pose_gt = None
232
+ t_err, r_err = math.inf, math.inf
233
+
234
+ # _logger.info(f"Rotation Error: {r_err:.2f}deg, Translation Error: {t_err * 100:.1f}cm")
235
+
236
+ # Save the errors.
237
+ rErrs.append(r_err)
238
+ tErrs.append(t_err * 100) # in cm
239
+
240
+ # Check various thresholds.
241
+ if r_err < r_thres and t_err < t_thres:
242
+ accuracy += 1
243
+ if r_err < 5:
244
+ r_acc_5 += 1
245
+ if r_err < 2:
246
+ r_acc_2 += 1
247
+ if r_err < 1:
248
+ r_acc_1 += 1
249
+ if t_err < 0.15:
250
+ t_acc_15 += 1
251
+ if t_err < 0.10:
252
+ t_acc_10 += 1
253
+ if t_err < 0.05:
254
+ t_acc_5 += 1
255
+ if t_err < 0.02:
256
+ t_acc_2 += 1
257
+ if t_err < 0.01:
258
+ t_acc_1 += 1
259
+ if r_err < 10 and t_err < 0.10:
260
+ acc_10 += 1
261
+ if r_err < 5 and t_err < 0.05:
262
+ acc_5 += 1
263
+ if r_err < 2 and t_err < 0.02:
264
+ acc_2 += 1
265
+ if r_err < 1 and t_err < 0.01:
266
+ acc_1 += 1
267
+
268
+ total_frames = len(rErrs)
269
+ assert total_frames == len(est)
270
+
271
+ # Compute median errors.
272
+ tErrs.sort()
273
+ rErrs.sort()
274
+ median_idx = total_frames // 2
275
+ median_rErr = rErrs[median_idx]
276
+ median_tErr = tErrs[median_idx]
277
+
278
+ # Compute final precision.
279
+ accuracy = accuracy / total_frames * 100
280
+ r_acc_5 = r_acc_5 / total_frames * 100
281
+ r_acc_2 = r_acc_2 / total_frames * 100
282
+ r_acc_1 = r_acc_1 / total_frames * 100
283
+ t_acc_15 = t_acc_15 / total_frames * 100
284
+ t_acc_10 = t_acc_10 / total_frames * 100
285
+ t_acc_5 = t_acc_5 / total_frames * 100
286
+ t_acc_2 = t_acc_2 / total_frames * 100
287
+ t_acc_1 = t_acc_1 / total_frames * 100
288
+ acc_10 = acc_10 / total_frames * 100
289
+ acc_5 = acc_5 / total_frames * 100
290
+ acc_2 = acc_2 / total_frames * 100
291
+ acc_1 = acc_1 / total_frames * 100
292
+
293
+ # _logger.info("===================================================")
294
+ # _logger.info("Test complete.")
295
+
296
+ # _logger.info(f'Accuracy: {accuracy:.1f}%')
297
+ # _logger.info(f"Median Error: {median_rErr:.1f}deg, {median_tErr:.1f}cm")
298
+ # print("===================================================")
299
+ # print("Test complete.")
300
+
301
+ with open(save_dir, "w") as f:
302
+ f.write(f"Accuracy: {accuracy:.1f}%\n\n")
303
+ f.write(f"Median Error: {median_rErr:.1f}deg, {median_tErr:.1f}cm\n")
304
+ f.write(f"R acc 5: {r_acc_5:.1f}%\n")
305
+ f.write(f"R acc 2: {r_acc_2:.1f}%\n")
306
+ f.write(f"R acc 1: {r_acc_1:.1f}%\n")
307
+ f.write(f"T acc 15: {t_acc_15:.1f}%\n")
308
+ f.write(f"T acc 10: {t_acc_10:.1f}%\n")
309
+ f.write(f"T acc 5: {t_acc_5:.1f}%\n")
310
+ f.write(f"T acc 2: {t_acc_2:.1f}%\n")
311
+ f.write(f"T acc 1: {t_acc_1:.1f}%\n")
312
+ f.write(f"Acc 10: {acc_10:.1f}%\n")
313
+ f.write(f"Acc 5: {acc_5:.1f}%\n")
314
+ f.write(f"Acc 2: {acc_2:.1f}%\n")
315
+ f.write(f"Acc 1: {acc_1:.1f}%\n")
eval/utils/eval_utils.py ADDED
@@ -0,0 +1,74 @@
1
+ import math
2
+ import sys
3
+
4
+ import numpy as np
5
+ from evo.core.trajectory import PosePath3D, PoseTrajectory3D
6
+
7
+
8
+ def save_kitti_poses(poses, save_path):
9
+ with open(save_path, "w") as f:
10
+ for pose in poses: # pose: 4x4 numpy array
11
+ pose_line = pose[:3].reshape(-1) # flatten first 3 rows
12
+ f.write(" ".join(map(str, pose_line)) + "\n")
13
+
14
+
15
+ def save_tum_poses(poses, timestamps, save_path):
16
+ """
17
+ Save poses in TUM RGB-D format.
18
+ Args:
19
+ poses: list or array of 4x4 numpy arrays (T_w_c)
20
+ timestamps: list or array of float timestamps (same length as poses)
21
+ save_path: output file path
22
+ """
23
+ assert len(poses) == len(timestamps), "poses and timestamps length mismatch"
24
+
25
+ with open(save_path, "w") as f:
26
+ for ts, pose in zip(timestamps, poses):
27
+ tx, ty, tz = pose[0, 3], pose[1, 3], pose[2, 3]
28
+
29
+ R = pose[:3, :3]
30
+ qw = np.sqrt(max(0, 1 + R[0, 0] + R[1, 1] + R[2, 2])) / 2
31
+ qx = np.sqrt(max(0, 1 + R[0, 0] - R[1, 1] - R[2, 2])) / 2
32
+ qy = np.sqrt(max(0, 1 - R[0, 0] + R[1, 1] - R[2, 2])) / 2
33
+ qz = np.sqrt(max(0, 1 - R[0, 0] - R[1, 1] + R[2, 2])) / 2
34
+ qx = math.copysign(qx, R[2, 1] - R[1, 2])
35
+ qy = math.copysign(qy, R[0, 2] - R[2, 0])
36
+ qz = math.copysign(qz, R[1, 0] - R[0, 1])
37
+
38
+ f.write(
39
+ f"{ts:.6f} {tx:.6f} {ty:.6f} {tz:.6f} {qx:.6f} {qy:.6f} {qz:.6f} {qw:.6f}\n"
40
+ )
41
+
42
+
43
+ def align_gt_pred(gt_views, poses_c2w_estimated):
44
+ poses_c2w_gt = [view["camera_pose"][0] for view in gt_views]
45
+ gt = PosePath3D(poses_se3=poses_c2w_gt)
46
+ pred = PosePath3D(poses_se3=poses_c2w_estimated[0])
47
+ r_a, t_a, s = pred.align(gt, correct_scale=True)
48
+
49
+ return pred.poses_se3, gt.poses_se3
50
+
51
+
52
+ def align_gt_pred_2(poses_c2w_gt, poses_c2w_estimated):
53
+ # poses_c2w_gt = [view["camera_pose"][0] for view in gt_views]
54
+ gt = PosePath3D(poses_se3=poses_c2w_gt)
55
+ pred = PosePath3D(poses_se3=poses_c2w_estimated)
56
+ r_a, t_a, s = pred.align(gt, correct_scale=True)
57
+
58
+ return pred.poses_se3, gt.poses_se3
59
+
60
+
61
+ def save_all_intrinsics_to_txt(result, filename="all_intrinsics.txt"):
62
+ with open(filename, "w") as f:
63
+ for i, res in enumerate(result):
64
+ intrinsic = res["intrinsic"].squeeze(0).reshape(-1).cpu().numpy() # (9,)
65
+ line = "\t".join([f"{v:.6f}" for v in intrinsic])
66
+ f.write(line + "\n")
67
+ print(f"[TXT] Saved {len(result)} intrinsics to {filename}")
68
+
69
+
70
+ def uniform_sample(total: int, select: int) -> list:
71
+ if select > total:
72
+ raise ValueError("select cannot be greater than total")
73
+ step = total / select
74
+ return [int(i * step) for i in range(select)]
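A short sketch of the trajectory writers and the frame sampler above (the file names are placeholders):

    import numpy as np

    poses = [np.eye(4) for _ in range(3)]                       # camera-to-world matrices
    save_kitti_poses(poses, "traj_kitti.txt")                   # one flattened 3x4 matrix per line
    save_tum_poses(poses, [0.0, 0.033, 0.066], "traj_tum.txt")  # "t tx ty tz qx qy qz qw" per line
    print(uniform_sample(total=100, select=4))                  # [0, 25, 50, 75]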
eval/utils/geometry.py ADDED
@@ -0,0 +1,572 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from pathlib import Path
8
+
9
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
10
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
11
+ #
12
+ # --------------------------------------------------------
13
+ # geometry utility functions
14
+ # --------------------------------------------------------
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn.functional as F
18
+ import torchvision.utils as vutils
19
+ from plyfile import PlyData, PlyElement
20
+ from scipy.spatial import cKDTree as KDTree
21
+ from tqdm import tqdm
22
+
23
+ from eval.utils.device import to_numpy
24
+ from eval.utils.misc import invalid_to_nans, invalid_to_zeros
25
+
26
+
27
+ def xy_grid(
28
+ W,
29
+ H,
30
+ device=None,
31
+ origin=(0, 0),
32
+ unsqueeze=None,
33
+ cat_dim=-1,
34
+ homogeneous=False,
35
+ **arange_kw,
36
+ ):
37
+ """Output a (H,W,2) array of int32
38
+ with output[j,i,0] = i + origin[0]
39
+ output[j,i,1] = j + origin[1]
40
+ """
41
+ if device is None:
42
+ # numpy
43
+ arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
44
+ else:
45
+ # torch
46
+ arange = lambda *a, **kw: torch.arange(*a, device=device, **kw)
47
+ meshgrid, stack = torch.meshgrid, torch.stack
48
+ ones = lambda *a: torch.ones(*a, device=device)
49
+
50
+ tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)]
51
+ grid = meshgrid(tw, th, indexing="xy")
52
+ if homogeneous:
53
+ grid = grid + (ones((H, W)),)
54
+ if unsqueeze is not None:
55
+ grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze))
56
+ if cat_dim is not None:
57
+ grid = stack(grid, cat_dim)
58
+ return grid
59
+
60
+
61
+ def geotrf(Trf, pts, ncol=None, norm=False):
62
+ """Apply a geometric transformation to a list of 3-D points.
63
+
64
+ Trf: 3x3 or 4x4 projection matrix (typically a homography)
65
+ pts: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)
66
+
67
+ ncol: int. number of columns of the result (2 or 3)
68
+ norm: float. if != 0, the result is projected on the z=norm plane.
69
+
70
+ Returns an array of projected 2d points.
71
+ """
72
+ assert Trf.ndim >= 2
73
+ if isinstance(Trf, np.ndarray):
74
+ pts = np.asarray(pts)
75
+ elif isinstance(Trf, torch.Tensor):
76
+ pts = torch.as_tensor(pts, dtype=Trf.dtype)
77
+
78
+ # adapt shape if necessary
79
+ output_reshape = pts.shape[:-1]
80
+ ncol = ncol or pts.shape[-1]
81
+
82
+ # optimized code
83
+ if (
84
+ isinstance(Trf, torch.Tensor)
85
+ and isinstance(pts, torch.Tensor)
86
+ and Trf.ndim == 3
87
+ and pts.ndim == 4
88
+ ):
89
+ d = pts.shape[3]
90
+ if Trf.shape[-1] == d:
91
+ pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
92
+ elif Trf.shape[-1] == d + 1:
93
+ pts = (
94
+ torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
95
+ + Trf[:, None, None, :d, d]
96
+ )
97
+ else:
98
+ raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
99
+ else:
100
+ if Trf.ndim >= 3:
101
+ n = Trf.ndim - 2
102
+ assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
103
+ Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
104
+
105
+ if pts.ndim > Trf.ndim:
106
+ # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
107
+ pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
108
+ elif pts.ndim == 2:
109
+ # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
110
+ pts = pts[:, None, :]
111
+
112
+ if pts.shape[-1] + 1 == Trf.shape[-1]:
113
+ Trf = Trf.swapaxes(-1, -2) # transpose Trf
114
+ pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
115
+ elif pts.shape[-1] == Trf.shape[-1]:
116
+ Trf = Trf.swapaxes(-1, -2) # transpose Trf
117
+ pts = pts @ Trf
118
+ else:
119
+ pts = Trf @ pts.T
120
+ if pts.ndim >= 2:
121
+ pts = pts.swapaxes(-1, -2)
122
+
123
+ if norm:
124
+ pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
125
+ if norm != 1:
126
+ pts *= norm
127
+
128
+ res = pts[..., :ncol].reshape(*output_reshape, ncol)
129
+ return res
130
+
131
+
132
+ def inv(mat):
133
+ """Invert a torch or numpy matrix"""
134
+ if isinstance(mat, torch.Tensor):
135
+ return torch.linalg.inv(mat)
136
+ if isinstance(mat, np.ndarray):
137
+ return np.linalg.inv(mat)
138
+ raise ValueError(f"bad matrix type = {type(mat)}")
139
+
140
+
141
+ def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_):
142
+ """
143
+ Args:
144
+ - depthmap (BxHxW array):
145
+ - pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W]
146
+ Returns:
147
+ pointmap of absolute coordinates (BxHxWx3 array)
148
+ """
149
+
150
+ if len(depth.shape) == 4:
151
+ B, H, W, n = depth.shape
152
+ else:
153
+ B, H, W = depth.shape
154
+ n = None
155
+
156
+ if len(pseudo_focal.shape) == 3: # [B,H,W]
157
+ pseudo_focalx = pseudo_focaly = pseudo_focal
158
+ elif len(pseudo_focal.shape) == 4: # [B,2,H,W] or [B,1,H,W]
159
+ pseudo_focalx = pseudo_focal[:, 0]
160
+ if pseudo_focal.shape[1] == 2:
161
+ pseudo_focaly = pseudo_focal[:, 1]
162
+ else:
163
+ pseudo_focaly = pseudo_focalx
164
+ else:
165
+ raise NotImplementedError("Error, unknown input focal shape format.")
166
+
167
+ assert pseudo_focalx.shape == depth.shape[:3]
168
+ assert pseudo_focaly.shape == depth.shape[:3]
169
+ grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None]
170
+
171
+ # set principal point
172
+ if pp is None:
173
+ grid_x = grid_x - (W - 1) / 2
174
+ grid_y = grid_y - (H - 1) / 2
175
+ else:
176
+ grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None]
177
+ grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None]
178
+
179
+ if n is None:
180
+ pts3d = torch.empty((B, H, W, 3), device=depth.device)
181
+ pts3d[..., 0] = depth * grid_x / pseudo_focalx
182
+ pts3d[..., 1] = depth * grid_y / pseudo_focaly
183
+ pts3d[..., 2] = depth
184
+ else:
185
+ pts3d = torch.empty((B, H, W, 3, n), device=depth.device)
186
+ pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None]
187
+ pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None]
188
+ pts3d[..., 2, :] = depth
189
+ return pts3d
190
+
191
+
192
+ def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
193
+ """
194
+ Args:
195
+ - depthmap (HxW array):
196
+ - camera_intrinsics: a 3x3 matrix
197
+ Returns:
198
+ pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
199
+ """
200
+ camera_intrinsics = np.float32(camera_intrinsics)
201
+ H, W = depthmap.shape
202
+
203
+ # Compute 3D ray associated with each pixel
204
+ # Strong assumption: there are no skew terms
205
+ assert camera_intrinsics[0, 1] == 0.0
206
+ assert camera_intrinsics[1, 0] == 0.0
207
+ if pseudo_focal is None:
208
+ fu = camera_intrinsics[0, 0]
209
+ fv = camera_intrinsics[1, 1]
210
+ else:
211
+ assert pseudo_focal.shape == (H, W)
212
+ fu = fv = pseudo_focal
213
+ cu = camera_intrinsics[0, 2]
214
+ cv = camera_intrinsics[1, 2]
215
+
216
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
217
+ z_cam = depthmap
218
+ x_cam = (u - cu) * z_cam / fu
219
+ y_cam = (v - cv) * z_cam / fv
220
+ X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
221
+
222
+ # Mask for valid coordinates
223
+ valid_mask = depthmap > 0.0
224
+ return X_cam, valid_mask
225
+
226
+
227
+ def depthmap_to_absolute_camera_coordinates(
228
+ depthmap, camera_intrinsics, camera_pose, **kw
229
+ ):
230
+ """
231
+ Args:
232
+ - depthmap (HxW array):
233
+ - camera_intrinsics: a 3x3 matrix
234
+ - camera_pose: a 4x3 or 4x4 cam2world matrix
235
+ Returns:
236
+ pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
237
+ """
238
+ X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)
239
+
240
+ # R_cam2world = np.float32(camera_params["R_cam2world"])
241
+ # t_cam2world = np.float32(camera_params["t_cam2world"]).squeeze()
242
+ R_cam2world = camera_pose[:3, :3]
243
+ t_cam2world = camera_pose[:3, 3]
244
+
245
+ # Express in absolute coordinates (invalid depth values)
246
+ X_world = (
247
+ np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]
248
+ )
249
+ return X_world, valid_mask
250
+
251
+
252
+ def colmap_to_opencv_intrinsics(K):
253
+ """
254
+ Modify camera intrinsics to follow a different convention.
255
+ Coordinates of the center of the top-left pixels are by default:
256
+ - (0.5, 0.5) in Colmap
257
+ - (0,0) in OpenCV
258
+ """
259
+ K = K.copy()
260
+ K[0, 2] -= 0.5
261
+ K[1, 2] -= 0.5
262
+ return K
263
+
264
+
265
+ def opencv_to_colmap_intrinsics(K):
266
+ """
267
+ Modify camera intrinsics to follow a different convention.
268
+ Coordinates of the center of the top-left pixels are by default:
269
+ - (0.5, 0.5) in Colmap
270
+ - (0,0) in OpenCV
271
+ """
272
+ K = K.copy()
273
+ K[0, 2] += 0.5
274
+ K[1, 2] += 0.5
275
+ return K
276
+
277
+
278
+ def normalize_pointcloud(pts1, pts2, norm_mode="avg_dis", valid1=None, valid2=None):
279
+ """renorm pointmaps pts1, pts2 with norm_mode"""
280
+ assert pts1.ndim >= 3 and pts1.shape[-1] == 3
281
+ assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3)
282
+ norm_mode, dis_mode = norm_mode.split("_")
283
+
284
+ if norm_mode == "avg":
285
+ # gather all points together (joint normalization)
286
+ nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3)
287
+ nan_pts2, nnz2 = (
288
+ invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0)
289
+ )
290
+ all_pts = (
291
+ torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
292
+ )
293
+
294
+ # compute distance to origin
295
+ all_dis = all_pts.norm(dim=-1)
296
+ if dis_mode == "dis":
297
+ pass # do nothing
298
+ elif dis_mode == "log1p":
299
+ all_dis = torch.log1p(all_dis)
300
+ elif dis_mode == "warp-log1p":
301
+ # actually warp input points before normalizing them
302
+ log_dis = torch.log1p(all_dis)
303
+ warp_factor = log_dis / all_dis.clip(min=1e-8)
304
+ H1, W1 = pts1.shape[1:-1]
305
+ pts1 = pts1 * warp_factor[:, : W1 * H1].view(-1, H1, W1, 1)
306
+ if pts2 is not None:
307
+ H2, W2 = pts2.shape[1:-1]
308
+ pts2 = pts2 * warp_factor[:, W1 * H1 :].view(-1, H2, W2, 1)
309
+ all_dis = log_dis # this is their true distance afterwards
310
+ else:
311
+ raise ValueError(f"bad {dis_mode=}")
312
+
313
+ norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8)
314
+ else:
315
+ # gather all points together (joint normalization)
316
+ nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3)
317
+ nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None
318
+ all_pts = (
319
+ torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
320
+ )
321
+
322
+ # compute distance to origin
323
+ all_dis = all_pts.norm(dim=-1)
324
+
325
+ if norm_mode == "avg":
326
+ norm_factor = all_dis.nanmean(dim=1)
327
+ elif norm_mode == "median":
328
+ norm_factor = all_dis.nanmedian(dim=1).values.detach()
329
+ elif norm_mode == "sqrt":
330
+ norm_factor = all_dis.sqrt().nanmean(dim=1) ** 2
331
+ else:
332
+ raise ValueError(f"bad {norm_mode=}")
333
+
334
+ norm_factor = norm_factor.clip(min=1e-8)
335
+ while norm_factor.ndim < pts1.ndim:
336
+ norm_factor.unsqueeze_(-1)
337
+
338
+ res = pts1 / norm_factor
339
+ if pts2 is not None:
340
+ res = (res, pts2 / norm_factor)
341
+ return res
342
+
343
+
344
+ @torch.no_grad()
345
+ def get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, quantile=0.5):
346
+ # set invalid points to NaN
347
+ _z1 = invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1)
348
+ _z2 = (
349
+ invalid_to_nans(z2, valid_mask2).reshape(len(z2), -1)
350
+ if z2 is not None
351
+ else None
352
+ )
353
+ _z = torch.cat((_z1, _z2), dim=-1) if z2 is not None else _z1
354
+
355
+ # compute median depth overall (ignoring nans)
356
+ if quantile == 0.5:
357
+ shift_z = torch.nanmedian(_z, dim=-1).values
358
+ else:
359
+ shift_z = torch.nanquantile(_z, quantile, dim=-1)
360
+ return shift_z # (B,)
361
+
362
+
363
+ @torch.no_grad()
364
+ def get_joint_pointcloud_center_scale(
365
+ pts1, pts2, valid_mask1=None, valid_mask2=None, z_only=False, center=True
366
+ ):
367
+ # set invalid points to NaN
368
+ _pts1 = invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3)
369
+ _pts2 = (
370
+ invalid_to_nans(pts2, valid_mask2).reshape(len(pts2), -1, 3)
371
+ if pts2 is not None
372
+ else None
373
+ )
374
+ _pts = torch.cat((_pts1, _pts2), dim=1) if pts2 is not None else _pts1
375
+
376
+ # compute median center
377
+ _center = torch.nanmedian(_pts, dim=1, keepdim=True).values # (B,1,3)
378
+ if z_only:
379
+ _center[..., :2] = 0 # do not center X and Y
380
+
381
+ # compute median norm
382
+ _norm = ((_pts - _center) if center else _pts).norm(dim=-1)
383
+ scale = torch.nanmedian(_norm, dim=1).values
384
+ return _center[:, None, :, :], scale[:, None, None, None]
385
+
386
+
387
+ def find_reciprocal_matches(P1, P2):
388
+ """
389
+ returns 3 values:
390
+ 1 - reciprocal_in_P2: a boolean array of size P2.shape[0], a "True" value indicates a match
391
+ 2 - nn2_in_P1: a int array of size P2.shape[0], it contains the indexes of the closest points in P1
392
+ 3 - reciprocal_in_P2.sum(): the number of matches
393
+ """
394
+ tree1 = KDTree(P1)
395
+ tree2 = KDTree(P2)
396
+
397
+ _, nn1_in_P2 = tree2.query(P1, workers=8)
398
+ _, nn2_in_P1 = tree1.query(P2, workers=8)
399
+
400
+ reciprocal_in_P1 = nn2_in_P1[nn1_in_P2] == np.arange(len(nn1_in_P2))
401
+ reciprocal_in_P2 = nn1_in_P2[nn2_in_P1] == np.arange(len(nn2_in_P1))
402
+ assert reciprocal_in_P1.sum() == reciprocal_in_P2.sum()
403
+ return reciprocal_in_P2, nn2_in_P1, reciprocal_in_P2.sum()
404
+
405
+
406
+ def get_med_dist_between_poses(poses):
407
+ from scipy.spatial.distance import pdist
408
+
409
+ return np.median(pdist([to_numpy(p[:3, 3]) for p in poses]))
410
+
411
+
412
+ def save_pointcloud_with_plyfile(result, filename="output.ply", downsample_ratio=10):
413
+ all_points = []
414
+ all_colors = []
415
+
416
+ for view in result:
417
+ pts = view["point_map_by_unprojection"] # (1, H, W, 3)
418
+ rgbs = view["rgbs"] # (1, 3, H, W)
419
+ dpt_cnf = view["dpt_cnf"]
420
+
421
+ # Remove batch dimension
422
+ pts = pts.squeeze(0) # (H, W, 3)
423
+ rgbs = rgbs.squeeze(0).permute(1, 2, 0) # (3, H, W) -> (H, W, 3)
424
+
425
+ # Flatten
426
+ pts = pts.reshape(-1, 3) # (N, 3)
427
+ rgbs = rgbs.reshape(-1, 3) # (N, 3)
428
+
429
+ # Remove invalid points
430
+ valid = torch.isfinite(pts).all(dim=1) & (pts.norm(dim=1) > 0)
431
+ valid = valid & (dpt_cnf > torch.quantile(view["dpt_cnf"], 0.5)).flatten()
432
+ pts = pts[valid]
433
+ rgbs = rgbs[valid]
434
+
435
+ # Downsample this view
436
+ N = pts.shape[0]
437
+ if downsample_ratio > 1 and N >= downsample_ratio:
438
+ idx = torch.randperm(N)[: N // downsample_ratio]
439
+ pts = pts[idx]
440
+ rgbs = rgbs[idx]
441
+
442
+ all_points.append(pts)
443
+ all_colors.append(rgbs)
444
+
445
+ # Merge all views
446
+ all_points = torch.cat(all_points, dim=0).cpu().numpy()
447
+ all_colors = torch.cat(all_colors, dim=0).cpu().numpy()
448
+
449
+ # Normalize color
450
+ if all_colors.max() <= 1.0:
451
+ all_colors = (all_colors * 255).astype(np.uint8)
452
+ else:
453
+ all_colors = all_colors.astype(np.uint8)
454
+
455
+ # Build structured array
456
+ vertex_data = np.empty(
457
+ len(all_points),
458
+ dtype=[
459
+ ("x", "f4"),
460
+ ("y", "f4"),
461
+ ("z", "f4"),
462
+ ("red", "u1"),
463
+ ("green", "u1"),
464
+ ("blue", "u1"),
465
+ ],
466
+ )
467
+ vertex_data["x"] = all_points[:, 0]
468
+ vertex_data["y"] = all_points[:, 1]
469
+ vertex_data["z"] = all_points[:, 2]
470
+ vertex_data["red"] = all_colors[:, 0]
471
+ vertex_data["green"] = all_colors[:, 1]
472
+ vertex_data["blue"] = all_colors[:, 2]
473
+
474
+ # Save with plyfile
475
+ el = PlyElement.describe(vertex_data, "vertex")
476
+ PlyData([el], text=False).write(filename)
477
+
478
+ print(f"[PLY] Saved {len(all_points)} points to {filename}")
479
+
480
+
481
+ def save_pointcloud_with_plyfile_each_frame(
482
+ result, filename="output.ply", downsample_ratio=10
483
+ ):
484
+ for frame_number, view in enumerate(result):
485
+ all_points = []
486
+ all_colors = []
487
+ pts = view["point_map_by_unprojection"] # (1, H, W, 3)
488
+ rgbs = view["rgbs"] # (1, 3, H, W)
489
+ dpt_cnf = view["dpt_cnf"]
490
+
491
+ # Remove batch dimension
492
+ pts = pts.squeeze(0) # (H, W, 3)
493
+ rgbs = rgbs.squeeze(0).permute(1, 2, 0) # (3, H, W) -> (H, W, 3)
494
+
495
+ # Flatten
496
+ pts = pts.reshape(-1, 3) # (N, 3)
497
+ rgbs = rgbs.reshape(-1, 3) # (N, 3)
498
+
499
+ # Remove invalid points
500
+ valid = torch.isfinite(pts).all(dim=1) & (pts.norm(dim=1) > 0)
501
+ valid = valid & (dpt_cnf > torch.quantile(view["dpt_cnf"], 0.5)).flatten()
502
+ pts = pts[valid]
503
+ rgbs = rgbs[valid]
504
+
505
+ # Downsample this view
506
+ N = pts.shape[0]
507
+ if downsample_ratio > 1 and N >= downsample_ratio:
508
+ idx = torch.randperm(N)[: N // downsample_ratio]
509
+ pts = pts[idx]
510
+ rgbs = rgbs[idx]
511
+
512
+ all_points.append(pts)
513
+ all_colors.append(rgbs)
514
+
515
+ # Merge all views
516
+ all_points = torch.cat(all_points, dim=0).cpu().numpy()
517
+ all_colors = torch.cat(all_colors, dim=0).cpu().numpy()
518
+
519
+ # Normalize color
520
+ if all_colors.max() <= 1.0:
521
+ all_colors = (all_colors * 255).astype(np.uint8)
522
+ else:
523
+ all_colors = all_colors.astype(np.uint8)
524
+
525
+ # Build structured array
526
+ vertex_data = np.empty(
527
+ len(all_points),
528
+ dtype=[
529
+ ("x", "f4"),
530
+ ("y", "f4"),
531
+ ("z", "f4"),
532
+ ("red", "u1"),
533
+ ("green", "u1"),
534
+ ("blue", "u1"),
535
+ ],
536
+ )
537
+ vertex_data["x"] = all_points[:, 0]
538
+ vertex_data["y"] = all_points[:, 1]
539
+ vertex_data["z"] = all_points[:, 2]
540
+ vertex_data["red"] = all_colors[:, 0]
541
+ vertex_data["green"] = all_colors[:, 1]
542
+ vertex_data["blue"] = all_colors[:, 2]
543
+
544
+ # Save with plyfile
545
+ el = PlyElement.describe(vertex_data, "vertex")
546
+ PlyData([el], text=False).write(
547
+ filename.split(".ply")[0] + f"_{frame_number:05d}.ply"
548
+ )
549
+
550
+ print(
551
+ f"[PLY] Saved {len(all_points)} points to {filename} idx {frame_number:05d}"
552
+ )
553
+
554
+
555
+ def save_concatenated_images(samples, save_path):
556
+ imgs = []
557
+
558
+ for sample in samples:
559
+ img = sample["img"] # (1, C, H, W)
560
+ img = F.interpolate(
561
+ img, scale_factor=0.25, mode="bilinear", align_corners=False
562
+ )
563
+ imgs.append(img)
564
+
565
+ imgs = torch.cat(imgs, dim=0).cpu() # (N, C, H, W)
566
+
567
+ save_path = Path(save_path)
568
+ save_path.parent.mkdir(parents=True, exist_ok=True)
569
+
570
+ vutils.save_image(imgs, save_path, normalize=True)
571
+
572
+ print(f"[Image] Saved concatenated image to {save_path}")
eval/utils/image.py ADDED
@@ -0,0 +1,232 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
8
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
9
+ #
10
+ # --------------------------------------------------------
11
+ # utility functions about images (loading/converting...)
12
+ # --------------------------------------------------------
13
+ import os
14
+ from typing import Dict, Optional
15
+
16
+ import numpy as np
17
+ import PIL.Image
18
+ import torch
19
+ import torchvision.transforms as tvf
20
+ from PIL.ImageOps import exif_transpose
21
+
22
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
23
+ import cv2
24
+
25
+ try:
26
+ from pillow_heif import register_heif_opener
27
+
28
+ register_heif_opener()
29
+ heif_support_enabled = True
30
+ except ImportError:
31
+ heif_support_enabled = False
32
+
33
+ ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
34
+
35
+
36
+ def imread_cv2(path, options=cv2.IMREAD_COLOR):
37
+ """Open an image or a depthmap with opencv-python."""
38
+ if path.endswith((".exr", "EXR")):
39
+ options = cv2.IMREAD_ANYDEPTH
40
+ img = cv2.imread(path, options)
41
+ if img is None:
42
+ raise IOError(f"Could not load image={path} with {options=}")
43
+ if img.ndim == 3:
44
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
45
+ return img
46
+
47
+
48
+ def rgb(ftensor, true_shape=None):
49
+ if isinstance(ftensor, list):
50
+ return [rgb(x, true_shape=true_shape) for x in ftensor]
51
+ if isinstance(ftensor, torch.Tensor):
52
+ ftensor = ftensor.detach().cpu().numpy() # H,W,3
53
+ if ftensor.ndim == 3 and ftensor.shape[0] == 3:
54
+ ftensor = ftensor.transpose(1, 2, 0)
55
+ elif ftensor.ndim == 4 and ftensor.shape[1] == 3:
56
+ ftensor = ftensor.transpose(0, 2, 3, 1)
57
+ if true_shape is not None:
58
+ H, W = true_shape
59
+ ftensor = ftensor[:H, :W]
60
+ if ftensor.dtype == np.uint8:
61
+ img = np.float32(ftensor) / 255
62
+ else:
63
+ img = (ftensor * 0.5) + 0.5
64
+ return img.clip(min=0, max=1)
65
+
66
+
67
+ def _resize_pil_image(img, long_edge_size):
68
+ S = max(img.size)
69
+ if S > long_edge_size:
70
+ interp = PIL.Image.LANCZOS
71
+ elif S <= long_edge_size:
72
+ interp = PIL.Image.BICUBIC
73
+ new_size = tuple(int(round(x * long_edge_size / S)) for x in img.size)
74
+ return img.resize(new_size, interp)
75
+
76
+
77
+ def load_images(
78
+ folder_or_list,
79
+ size,
80
+ square_ok=False,
81
+ verbose=True,
82
+ rotate_clockwise_90=False,
83
+ crop_to_landscape=False,
84
+ ):
85
+ """open and convert all images in a list or folder to proper input format for DUSt3R"""
86
+ if isinstance(folder_or_list, str):
87
+ if verbose:
88
+ print(f">> Loading images from {folder_or_list}")
89
+ root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))
90
+
91
+ elif isinstance(folder_or_list, list):
92
+ if verbose:
93
+ print(f">> Loading a list of {len(folder_or_list)} images")
94
+ root, folder_content = "", folder_or_list
95
+
96
+ else:
97
+ raise ValueError(f"bad {folder_or_list=} ({type(folder_or_list)})")
98
+
99
+ supported_images_extensions = [".jpg", ".jpeg", ".png"]
100
+ if heif_support_enabled:
101
+ supported_images_extensions += [".heic", ".heif"]
102
+ supported_images_extensions = tuple(supported_images_extensions)
103
+
104
+ imgs = []
105
+ for path in folder_content:
106
+ if not path.lower().endswith(supported_images_extensions):
107
+ continue
108
+ img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert("RGB")
109
+ if rotate_clockwise_90:
110
+ img = img.rotate(-90, expand=True)
111
+ if crop_to_landscape:
112
+ # Crop to a landscape aspect ratio (e.g., 16:9)
113
+ desired_aspect_ratio = 4 / 3
114
+ width, height = img.size
115
+ current_aspect_ratio = width / height
116
+
117
+ if current_aspect_ratio > desired_aspect_ratio:
118
+ # Wider than landscape: crop width
119
+ new_width = int(height * desired_aspect_ratio)
120
+ left = (width - new_width) // 2
121
+ right = left + new_width
122
+ top = 0
123
+ bottom = height
124
+ else:
125
+ # Taller than landscape: crop height
126
+ new_height = int(width / desired_aspect_ratio)
127
+ top = (height - new_height) // 2
128
+ bottom = top + new_height
129
+ left = 0
130
+ right = width
131
+
132
+ img = img.crop((left, top, right, bottom))
133
+
134
+ W1, H1 = img.size
135
+ if size == 224:
136
+ # resize short side to 224 (then crop)
137
+ img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1)))
138
+ else:
139
+ # resize long side to 512
140
+ img = _resize_pil_image(img, size)
141
+ W, H = img.size
142
+ cx, cy = W // 2, H // 2
143
+ if size == 224:
144
+ half = min(cx, cy)
145
+ img = img.crop((cx - half, cy - half, cx + half, cy + half))
146
+ else:
147
+ halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8
148
+ if not (square_ok) and W == H:
149
+ halfh = 3 * halfw / 4
150
+ img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh))
151
+
152
+ W2, H2 = img.size
153
+ if verbose:
154
+ print(f" - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}")
155
+ imgs.append(
156
+ dict(
157
+ img=ImgNorm(img)[None],
158
+ true_shape=np.int32([img.size[::-1]]),
159
+ idx=len(imgs),
160
+ instance=str(len(imgs)),
161
+ )
162
+ )
163
+
164
+ assert imgs, "no images foud at " + root
165
+ if verbose:
166
+ print(f" (Found {len(imgs)} images)")
167
+ return imgs
168
+
169
+
170
+ def get_image_vggt_augmentation(
171
+ color_jitter: Optional[Dict[str, float]] = None,
172
+ gray_scale: bool = True,
173
+ gau_blur: bool = False,
174
+ ) -> Optional[tvf.Compose]:
175
+ """Create a composition of image augmentations.
176
+
177
+ Args:
178
+ color_jitter: Dictionary containing color jitter parameters:
179
+ - brightness: float (default: 0.5)
180
+ - contrast: float (default: 0.5)
181
+ - saturation: float (default: 0.5)
182
+ - hue: float (default: 0.1)
183
+ - p: probability of applying (default: 0.9)
184
+ If None, uses default values
185
+ gray_scale: Whether to apply random grayscale (default: True)
186
+ gau_blur: Whether to apply gaussian blur (default: False)
187
+
188
+ Returns:
189
+ A Compose object of transforms or None if no transforms are added
190
+ """
191
+ transform_list = []
192
+ default_jitter = {
193
+ "brightness": 0.5,
194
+ "contrast": 0.5,
195
+ "saturation": 0.5,
196
+ "hue": 0.1,
197
+ "p": 0.9,
198
+ }
199
+
200
+ # Handle color jitter
201
+ if color_jitter is not None:
202
+ if not isinstance(color_jitter, dict):
203
+ raise ValueError("color_jitter must be a dictionary or None")
204
+ # Merge with defaults for missing keys
205
+ effective_jitter = {**default_jitter, **color_jitter}
206
+ else:
207
+ effective_jitter = default_jitter
208
+
209
+ transform_list.append(
210
+ tvf.RandomApply(
211
+ [
212
+ tvf.ColorJitter(
213
+ brightness=effective_jitter["brightness"],
214
+ contrast=effective_jitter["contrast"],
215
+ saturation=effective_jitter["saturation"],
216
+ hue=effective_jitter["hue"],
217
+ )
218
+ ],
219
+ p=effective_jitter["p"],
220
+ )
221
+ )
222
+
223
+ if gray_scale:
224
+ transform_list.append(tvf.RandomGrayscale(p=0.05))
225
+
226
+ if gau_blur:
227
+ transform_list.append(
228
+ tvf.RandomApply([tvf.GaussianBlur(5, sigma=(0.1, 1.0))], p=0.05)
229
+ )
230
+ # transform_list.append(tvf.ToTensor())
231
+
232
+ return tvf.Compose(transform_list) if transform_list else None
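load_images resizes the long side to `size` (or the short side for size=224), center-crops to dimensions divisible by 16, and returns one dict per image in the DUSt3R input format. A hedged usage sketch (the folder path and the printed shape are illustrative assumptions):

    views = load_images("examples/kitchen", size=512, verbose=False)  # hypothetical folder of .jpg frames
    print(len(views), views[0]["img"].shape)   # e.g. N and torch.Size([1, 3, 384, 512]) for 4:3 inputs
    print(views[0]["true_shape"])              # np.int32 array [[H, W]] after resizing/cropping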
eval/utils/load_fn.py ADDED
@@ -0,0 +1,155 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ from PIL import Image
9
+ from torchvision import transforms as TF
10
+
11
+
12
+ def load_and_preprocess_images(image_path_list, mode="crop"):
13
+ """
14
+ A quick start function to load and preprocess images for model input.
15
+ This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.
16
+
17
+ Args:
18
+ image_path_list (list): List of paths to image files
19
+ mode (str, optional): Preprocessing mode, either "crop" or "pad".
20
+ - "crop" (default): Sets width to 518px and center crops height if needed.
21
+ - "pad": Preserves all pixels by making the largest dimension 518px
22
+ and padding the smaller dimension to reach a square shape.
23
+
24
+ Returns:
25
+ torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)
26
+
27
+ Raises:
28
+ ValueError: If the input list is empty or if mode is invalid
29
+
30
+ Notes:
31
+ - Images with different dimensions will be padded with white (value=1.0)
32
+ - A warning is printed when images have different shapes
33
+ - When mode="crop": The function ensures width=518px while maintaining aspect ratio
34
+ and height is center-cropped if larger than 518px
35
+ - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio
36
+ and the smaller dimension is padded to reach a square shape (518x518)
37
+ - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements
38
+ """
39
+ # Check for empty list
40
+ if len(image_path_list) == 0:
41
+ raise ValueError("At least 1 image is required")
42
+
43
+ # Validate mode
44
+ if mode not in ["crop", "pad"]:
45
+ raise ValueError("Mode must be either 'crop' or 'pad'")
46
+
47
+ images = []
48
+ shapes = set()
49
+ to_tensor = TF.ToTensor()
50
+ target_size = 518
51
+
52
+ # First process all images and collect their shapes
53
+ for image_path in image_path_list:
54
+ # Open image
55
+ img = Image.open(image_path)
56
+
57
+ # If there's an alpha channel, blend onto white background:
58
+ if img.mode == "RGBA":
59
+ # Create white background
60
+ background = Image.new("RGBA", img.size, (255, 255, 255, 255))
61
+ # Alpha composite onto the white background
62
+ img = Image.alpha_composite(background, img)
63
+
64
+ # Now convert to "RGB" (this step assigns white for transparent areas)
65
+ img = img.convert("RGB")
66
+
67
+ width, height = img.size
68
+
69
+ if mode == "pad":
70
+ # Make the largest dimension 518px while maintaining aspect ratio
71
+ if width >= height:
72
+ new_width = target_size
73
+ new_height = (
74
+ round(height * (new_width / width) / 14) * 14
75
+ ) # Make divisible by 14
76
+ else:
77
+ new_height = target_size
78
+ new_width = (
79
+ round(width * (new_height / height) / 14) * 14
80
+ ) # Make divisible by 14
81
+ else: # mode == "crop"
82
+ # Original behavior: set width to 518px
83
+ new_width = target_size
84
+ # Calculate height maintaining aspect ratio, divisible by 14
85
+ new_height = round(height * (new_width / width) / 14) * 14
86
+
87
+ # Resize with new dimensions (width, height)
88
+ img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
89
+ img = to_tensor(img) # Convert to tensor (0, 1)
90
+
91
+ # Center crop height if it's larger than 518 (only in crop mode)
92
+ if mode == "crop" and new_height > target_size:
93
+ start_y = (new_height - target_size) // 2
94
+ img = img[:, start_y : start_y + target_size, :]
95
+
96
+ # For pad mode, pad to make a square of target_size x target_size
97
+ if mode == "pad":
98
+ h_padding = target_size - img.shape[1]
99
+ w_padding = target_size - img.shape[2]
100
+
101
+ if h_padding > 0 or w_padding > 0:
102
+ pad_top = h_padding // 2
103
+ pad_bottom = h_padding - pad_top
104
+ pad_left = w_padding // 2
105
+ pad_right = w_padding - pad_left
106
+
107
+ # Pad with white (value=1.0)
108
+ img = torch.nn.functional.pad(
109
+ img,
110
+ (pad_left, pad_right, pad_top, pad_bottom),
111
+ mode="constant",
112
+ value=1.0,
113
+ )
114
+
115
+ shapes.add((img.shape[1], img.shape[2]))
116
+ images.append(img)
117
+
118
+ # Check if we have different shapes
119
+ # In theory our model can also work well with different shapes
120
+ if len(shapes) > 1:
121
+ print(f"Warning: Found images with different shapes: {shapes}")
122
+ # Find maximum dimensions
123
+ max_height = max(shape[0] for shape in shapes)
124
+ max_width = max(shape[1] for shape in shapes)
125
+
126
+ # Pad images if necessary
127
+ padded_images = []
128
+ for img in images:
129
+ h_padding = max_height - img.shape[1]
130
+ w_padding = max_width - img.shape[2]
131
+
132
+ if h_padding > 0 or w_padding > 0:
133
+ pad_top = h_padding // 2
134
+ pad_bottom = h_padding - pad_top
135
+ pad_left = w_padding // 2
136
+ pad_right = w_padding - pad_left
137
+
138
+ img = torch.nn.functional.pad(
139
+ img,
140
+ (pad_left, pad_right, pad_top, pad_bottom),
141
+ mode="constant",
142
+ value=1.0,
143
+ )
144
+ padded_images.append(img)
145
+ images = padded_images
146
+
147
+ images = torch.stack(images) # concatenate images
148
+
149
+ # Ensure correct shape when single image
150
+ if len(image_path_list) == 1:
151
+ # Verify shape is (1, C, H, W)
152
+ if images.dim() == 3:
153
+ images = images.unsqueeze(0)
154
+
155
+ return images
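A sketch of the two preprocessing modes documented above (the file names are placeholders for real images):

    paths = ["frame_000.png", "frame_001.png"]   # placeholders

    imgs_crop = load_and_preprocess_images(paths, mode="crop")
    print(imgs_crop.shape)   # (2, 3, H, 518) with H <= 518 and divisible by 14

    imgs_pad = load_and_preprocess_images(paths, mode="pad")
    print(imgs_pad.shape)    # (2, 3, 518, 518); the smaller side is padded with white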
eval/utils/misc.py ADDED
@@ -0,0 +1,131 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
8
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
9
+ #
10
+ # --------------------------------------------------------
11
+ # utility functions for DUSt3R
12
+ # --------------------------------------------------------
13
+ import torch
14
+
15
+
16
+ def fill_default_args(kwargs, func):
17
+ import inspect # a bit hacky but it works reliably
18
+
19
+ signature = inspect.signature(func)
20
+
21
+ for k, v in signature.parameters.items():
22
+ if v.default is inspect.Parameter.empty:
23
+ continue
24
+ kwargs.setdefault(k, v.default)
25
+
26
+ return kwargs
27
+
28
+
29
+ def freeze_all_params(modules):
30
+ for module in modules:
31
+ try:
32
+ for n, param in module.named_parameters():
33
+ param.requires_grad = False
34
+ except AttributeError:
35
+ # module is directly a parameter
36
+ module.requires_grad = False
37
+
38
+
39
+ def is_symmetrized(gt1, gt2):
40
+ x = gt1["instance"]
41
+ y = gt2["instance"]
42
+ if len(x) == len(y) and len(x) == 1:
43
+ return False # special case of batchsize 1
44
+ ok = True
45
+ for i in range(0, len(x), 2):
46
+ ok = ok and (x[i] == y[i + 1]) and (x[i + 1] == y[i])
47
+ return ok
48
+
49
+
50
+ def flip(tensor):
51
+ """flip so that tensor[0::2] <=> tensor[1::2]"""
52
+ return torch.stack((tensor[1::2], tensor[0::2]), dim=1).flatten(0, 1)
53
+
54
+
55
+ def interleave(tensor1, tensor2):
56
+ res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1)
57
+ res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1)
58
+ return res1, res2
59
+
60
+
61
+ def transpose_to_landscape(head, activate=True):
62
+ """Predict in the correct aspect-ratio,
63
+ then transpose the result in landscape
64
+ and stack everything back together.
65
+ """
66
+
67
+ def wrapper_no(decout, true_shape):
68
+ B = len(true_shape)
69
+ assert true_shape[0:1].allclose(true_shape), "true_shape must be all identical"
70
+ H, W = true_shape[0].cpu().tolist()
71
+ res = head(decout, (H, W))
72
+ return res
73
+
74
+ def wrapper_yes(decout, true_shape):
75
+ B = len(true_shape)
76
+ # by definition, the batch is in landscape mode so W >= H
77
+ H, W = int(true_shape.min()), int(true_shape.max())
78
+
79
+ height, width = true_shape.T
80
+ is_landscape = width >= height
81
+ is_portrait = ~is_landscape
82
+
83
+ # true_shape = true_shape.cpu()
84
+ if is_landscape.all():
85
+ return head(decout, (H, W))
86
+ if is_portrait.all():
87
+ return transposed(head(decout, (W, H)))
88
+
89
+ # batch is a mix of both portrait & landscape
90
+ def selout(ar):
91
+ return [d[ar] for d in decout]
92
+
93
+ l_result = head(selout(is_landscape), (H, W))
94
+ p_result = transposed(head(selout(is_portrait), (W, H)))
95
+
96
+ # allocate full result
97
+ result = {}
98
+ for k in l_result | p_result:
99
+ x = l_result[k].new(B, *l_result[k].shape[1:])
100
+ x[is_landscape] = l_result[k]
101
+ x[is_portrait] = p_result[k]
102
+ result[k] = x
103
+
104
+ return result
105
+
106
+ return wrapper_yes if activate else wrapper_no
107
+
108
+
109
+ def transposed(dic):
110
+ return {k: v.swapaxes(1, 2) for k, v in dic.items()}
111
+
112
+
113
+ def invalid_to_nans(arr, valid_mask, ndim=999):
114
+ if valid_mask is not None:
115
+ arr = arr.clone()
116
+ arr[~valid_mask] = float("nan")
117
+ if arr.ndim > ndim:
118
+ arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
119
+ return arr
120
+
121
+
122
+ def invalid_to_zeros(arr, valid_mask, ndim=999):
123
+ if valid_mask is not None:
124
+ arr = arr.clone()
125
+ arr[~valid_mask] = 0
126
+ nnz = valid_mask.view(len(valid_mask), -1).sum(1)
127
+ else:
128
+ nnz = arr.numel() // len(arr) if len(arr) else 0 # number of point per image
129
+ if arr.ndim > ndim:
130
+ arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
131
+ return arr, nnz
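flip/interleave reorder symmetrized batches, and invalid_to_nans/invalid_to_zeros mask out invalid pixels before the normalization code in eval/utils/geometry.py. A minimal sketch with made-up shapes:

    import torch

    print(flip(torch.arange(6)))            # tensor([1, 0, 3, 2, 5, 4])

    pts = torch.rand(2, 4, 4, 3)            # (B, H, W, 3) pointmaps
    mask = torch.rand(2, 4, 4) > 0.5        # per-pixel validity
    nan_pts = invalid_to_nans(pts, mask, ndim=3)          # invalid -> NaN, flattened to (2, 16, 3)
    zero_pts, nnz = invalid_to_zeros(pts, mask, ndim=3)   # invalid -> 0, plus valid count per image
    print(nan_pts.shape, zero_pts.shape, nnz.shape)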
eval/utils/pose_enc.py ADDED
@@ -0,0 +1,135 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+
9
+ from .rotation import mat_to_quat, quat_to_mat
10
+
11
+
12
+ def extri_intri_to_pose_encoding(
13
+ extrinsics,
14
+ intrinsics,
15
+ image_size_hw=None, # e.g., (256, 512)
16
+ pose_encoding_type="absT_quaR_FoV",
17
+ ):
18
+ """Convert camera extrinsics and intrinsics to a compact pose encoding.
19
+
20
+ This function transforms camera parameters into a unified pose encoding format,
21
+ which can be used for various downstream tasks like pose prediction or representation.
22
+
23
+ Args:
24
+ extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4,
25
+ where B is batch size and S is sequence length.
26
+ In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation.
27
+ The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector.
28
+ intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3.
29
+ Defined in pixels, with format:
30
+ [[fx, 0, cx],
31
+ [0, fy, cy],
32
+ [0, 0, 1]]
33
+ where fx, fy are focal lengths and (cx, cy) is the principal point
34
+ image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
35
+ Required for computing field of view values. For example: (256, 512).
36
+ pose_encoding_type (str): Type of pose encoding to use. Currently only
37
+ supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
38
+
39
+ Returns:
40
+ torch.Tensor: Encoded camera pose parameters with shape BxSx9.
41
+ For "absT_quaR_FoV" type, the 9 dimensions are:
42
+ - [:3] = absolute translation vector T (3D)
43
+ - [3:7] = rotation as quaternion quat (4D)
44
+ - [7:] = field of view (2D)
45
+ """
46
+
47
+ # extrinsics: BxSx3x4
48
+ # intrinsics: BxSx3x3
49
+
50
+ if pose_encoding_type == "absT_quaR_FoV":
51
+ R = extrinsics[:, :, :3, :3] # BxSx3x3
52
+ T = extrinsics[:, :, :3, 3] # BxSx3
53
+
54
+ quat = mat_to_quat(R)
55
+ # Note the order of h and w here
56
+ H, W = image_size_hw
57
+ fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
58
+ fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
59
+ pose_encoding = torch.cat(
60
+ [T, quat, fov_h[..., None], fov_w[..., None]], dim=-1
61
+ ).float()
62
+ else:
63
+ raise NotImplementedError
64
+
65
+ return pose_encoding
66
+
67
+
68
+ def pose_encoding_to_extri_intri(
69
+ pose_encoding,
70
+ image_size_hw=None, # e.g., (256, 512)
71
+ pose_encoding_type="absT_quaR_FoV",
72
+ build_intrinsics=True,
73
+ ):
74
+ """Convert a pose encoding back to camera extrinsics and intrinsics.
75
+
76
+ This function performs the inverse operation of extri_intri_to_pose_encoding,
77
+ reconstructing the full camera parameters from the compact encoding.
78
+
79
+ Args:
80
+ pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9,
81
+ where B is batch size and S is sequence length.
82
+ For "absT_quaR_FoV" type, the 9 dimensions are:
83
+ - [:3] = absolute translation vector T (3D)
84
+ - [3:7] = rotation as quaternion quat (4D)
85
+ - [7:] = field of view (2D)
86
+ image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
87
+ Required for reconstructing intrinsics from field of view values.
88
+ For example: (256, 512).
89
+ pose_encoding_type (str): Type of pose encoding used. Currently only
90
+ supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
91
+ build_intrinsics (bool): Whether to reconstruct the intrinsics matrix.
92
+ If False, only extrinsics are returned and intrinsics will be None.
93
+
94
+ Returns:
95
+ tuple: (extrinsics, intrinsics)
96
+ - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4.
97
+ In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world
98
+ transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is
99
+ a 3x1 translation vector.
100
+ - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3,
101
+ or None if build_intrinsics is False. Defined in pixels, with format:
102
+ [[fx, 0, cx],
103
+ [0, fy, cy],
104
+ [0, 0, 1]]
105
+ where fx, fy are focal lengths and (cx, cy) is the principal point,
106
+ assumed to be at the center of the image (W/2, H/2).
107
+ """
108
+
109
+ intrinsics = None
110
+
111
+ if pose_encoding_type == "absT_quaR_FoV":
112
+ T = pose_encoding[..., :3]
113
+ quat = pose_encoding[..., 3:7]
114
+ fov_h = pose_encoding[..., 7]
115
+ fov_w = pose_encoding[..., 8]
116
+
117
+ R = quat_to_mat(quat)
118
+ extrinsics = torch.cat([R, T[..., None]], dim=-1)
119
+
120
+ if build_intrinsics:
121
+ H, W = image_size_hw
122
+ fy = (H / 2.0) / torch.tan(fov_h / 2.0)
123
+ fx = (W / 2.0) / torch.tan(fov_w / 2.0)
124
+ intrinsics = torch.zeros(
125
+ pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device
126
+ )
127
+ intrinsics[..., 0, 0] = fx
128
+ intrinsics[..., 1, 1] = fy
129
+ intrinsics[..., 0, 2] = W / 2
130
+ intrinsics[..., 1, 2] = H / 2
131
+ intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1
132
+ else:
133
+ raise NotImplementedError
134
+
135
+ return extrinsics, intrinsics
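The 9-D encoding stores translation, a scalar-last quaternion, and the vertical/horizontal FoV; intrinsics are recovered as fx = (W/2)/tan(fov_w/2) and fy = (H/2)/tan(fov_h/2) with the principal point at the image centre. A round-trip sketch with made-up camera parameters (not taken from the evaluation code):

    import torch

    B, S, H, W = 1, 2, 256, 512
    extrinsics = torch.eye(3, 4).expand(B, S, 3, 4).clone()          # [R|t] with R = I, t = 0
    intrinsics = torch.tensor([[200.0, 0.0, W / 2],
                               [0.0, 200.0, H / 2],
                               [0.0, 0.0, 1.0]]).expand(B, S, 3, 3).clone()

    enc = extri_intri_to_pose_encoding(extrinsics, intrinsics, image_size_hw=(H, W))
    extr_back, intr_back = pose_encoding_to_extri_intri(enc, image_size_hw=(H, W))
    assert enc.shape == (B, S, 9)
    assert torch.allclose(extr_back, extrinsics, atol=1e-5)
    assert torch.allclose(intr_back, intrinsics, atol=1e-3)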
eval/utils/rotation.py ADDED
@@ -0,0 +1,142 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn.functional as F
12
+
13
+
14
+ def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
15
+ """
16
+ Quaternion Order: XYZW or say ijkr, scalar-last
17
+
18
+ Convert rotations given as quaternions to rotation matrices.
19
+ Args:
20
+ quaternions: quaternions with real part last,
21
+ as tensor of shape (..., 4).
22
+
23
+ Returns:
24
+ Rotation matrices as tensor of shape (..., 3, 3).
25
+ """
26
+ i, j, k, r = torch.unbind(quaternions, -1)
27
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
28
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
29
+
30
+ o = torch.stack(
31
+ (
32
+ 1 - two_s * (j * j + k * k),
33
+ two_s * (i * j - k * r),
34
+ two_s * (i * k + j * r),
35
+ two_s * (i * j + k * r),
36
+ 1 - two_s * (i * i + k * k),
37
+ two_s * (j * k - i * r),
38
+ two_s * (i * k - j * r),
39
+ two_s * (j * k + i * r),
40
+ 1 - two_s * (i * i + j * j),
41
+ ),
42
+ -1,
43
+ )
44
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
45
+
46
+
47
+ def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
48
+ """
49
+ Convert rotations given as rotation matrices to quaternions.
50
+
51
+ Args:
52
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
53
+
54
+ Returns:
55
+ quaternions with real part last, as tensor of shape (..., 4).
56
+ Quaternion order: XYZW, i.e. ijkr (scalar-last).
57
+ """
58
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
59
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
60
+
61
+ batch_dim = matrix.shape[:-2]
62
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
63
+ matrix.reshape(batch_dim + (9,)), dim=-1
64
+ )
65
+
66
+ q_abs = _sqrt_positive_part(
67
+ torch.stack(
68
+ [
69
+ 1.0 + m00 + m11 + m22,
70
+ 1.0 + m00 - m11 - m22,
71
+ 1.0 - m00 + m11 - m22,
72
+ 1.0 - m00 - m11 + m22,
73
+ ],
74
+ dim=-1,
75
+ )
76
+ )
77
+
78
+ # we produce the desired quaternion multiplied by each of r, i, j, k
79
+ quat_by_rijk = torch.stack(
80
+ [
81
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
82
+ # `int`.
83
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
84
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
85
+ # `int`.
86
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
87
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
88
+ # `int`.
89
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
90
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
91
+ # `int`.
92
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
93
+ ],
94
+ dim=-2,
95
+ )
96
+
97
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
98
+ # the candidate won't be picked.
99
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
100
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
101
+
102
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
103
+ # forall i; we pick the best-conditioned one (with the largest denominator)
104
+ out = quat_candidates[
105
+ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
106
+ ].reshape(batch_dim + (4,))
107
+
108
+ # Convert from rijk to ijkr
109
+ out = out[..., [1, 2, 3, 0]]
110
+
111
+ out = standardize_quaternion(out)
112
+
113
+ return out
114
+
115
+
116
+ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
117
+ """
118
+ Returns torch.sqrt(torch.max(0, x))
119
+ but with a zero subgradient where x is 0.
120
+ """
121
+ ret = torch.zeros_like(x)
122
+ positive_mask = x > 0
123
+ if torch.is_grad_enabled():
124
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
125
+ else:
126
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
127
+ return ret
128
+
129
+
130
+ def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
131
+ """
132
+ Convert a unit quaternion to a standard form: one in which the real
133
+ part is non-negative.
134
+
135
+ Args:
136
+ quaternions: Quaternions with real part last,
137
+ as tensor of shape (..., 4).
138
+
139
+ Returns:
140
+ Standardized quaternions as tensor of shape (..., 4).
141
+ """
142
+ return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
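
A quick round-trip sketch for the two converters above (import path assumed to mirror eval/utils/rotation.py):

import torch
import torch.nn.functional as F
from eval.utils.rotation import mat_to_quat, quat_to_mat  # assumed import path

q = F.normalize(torch.randn(5, 4), dim=-1)   # random unit quaternions, XYZW (scalar-last)
R = quat_to_mat(q)                           # (5, 3, 3) rotation matrices
q_back = mat_to_quat(R)                      # standardized: real part is non-negative

# q and -q describe the same rotation, so compare reconstructed matrices rather than quaternions
assert torch.allclose(quat_to_mat(q_back), R, atol=1e-5)
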
eval/utils/visual_track.py ADDED
@@ -0,0 +1,244 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+
9
+ import cv2
10
+ import numpy as np
11
+ import torch
12
+
13
+
14
+ def color_from_xy(x, y, W, H, cmap_name="hsv"):
15
+ """
16
+ Map (x, y) -> color in (R, G, B).
17
+ 1) Normalize x,y to [0,1].
18
+ 2) Combine them into a single scalar c in [0,1].
19
+ 3) Use matplotlib's colormap to convert c -> (R,G,B).
20
+
21
+ You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y).
22
+ """
23
+ import matplotlib.cm
24
+ import matplotlib.colors
25
+
26
+ x_norm = x / max(W - 1, 1)
27
+ y_norm = y / max(H - 1, 1)
28
+ # Simple combination:
29
+ c = (x_norm + y_norm) / 2.0
30
+
31
+ cmap = matplotlib.cm.get_cmap(cmap_name)
32
+ # cmap(c) -> (r,g,b,a) in [0,1]
33
+ rgba = cmap(c)
34
+ r, g, b = rgba[0], rgba[1], rgba[2]
35
+ return (r, g, b) # in [0,1], RGB order
36
+
37
+
38
+ def get_track_colors_by_position(
39
+ tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"
40
+ ):
41
+ """
42
+ Given all tracks in one sample (b), compute a (N,3) array of RGB color values
43
+ in [0,255]. The color is determined by the (x,y) position in the first
44
+ visible frame for each track.
45
+
46
+ Args:
47
+ tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame.
48
+ vis_mask_b: (S, N) boolean mask; if None, assume all are visible.
49
+ image_width, image_height: used for normalizing (x, y).
50
+ cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet').
51
+
52
+ Returns:
53
+ track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255].
54
+ """
55
+ S, N, _ = tracks_b.shape
56
+ track_colors = np.zeros((N, 3), dtype=np.uint8)
57
+
58
+ if vis_mask_b is None:
59
+ # treat all as visible
60
+ vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device)
61
+
62
+ for i in range(N):
63
+ # Find first visible frame for track i
64
+ visible_frames = torch.where(vis_mask_b[:, i])[0]
65
+ if len(visible_frames) == 0:
66
+ # track is never visible; just assign black or something
67
+ track_colors[i] = (0, 0, 0)
68
+ continue
69
+
70
+ first_s = int(visible_frames[0].item())
71
+ # use that frame's (x,y)
72
+ x, y = tracks_b[first_s, i].tolist()
73
+
74
+ # map (x,y) -> (R,G,B) in [0,1]
75
+ r, g, b = color_from_xy(
76
+ x, y, W=image_width, H=image_height, cmap_name=cmap_name
77
+ )
78
+ # scale to [0,255]
79
+ r, g, b = int(r * 255), int(g * 255), int(b * 255)
80
+ track_colors[i] = (r, g, b)
81
+
82
+ return track_colors
83
+
84
+
85
+ def visualize_tracks_on_images(
86
+ images,
87
+ tracks,
88
+ track_vis_mask=None,
89
+ out_dir="track_visuals_concat_by_xy",
90
+ image_format="CHW", # "CHW" or "HWC"
91
+ normalize_mode="[0,1]",
92
+ cmap_name="hsv", # e.g. "hsv", "rainbow", "jet"
93
+ frames_per_row=4, # New parameter for grid layout
94
+ save_grid=True, # Flag to control whether to save the grid image
95
+ ):
96
+ """
97
+ Visualizes frames in a grid layout with specified frames per row.
98
+ Each track's color is determined by its (x,y) position
99
+ in the first visible frame (or frame 0 if always visible).
100
+ Finally convert the BGR result to RGB before saving.
101
+ Also saves each individual frame as a separate PNG file.
102
+
103
+ Args:
104
+ images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC.
105
+ tracks: torch.Tensor (S, N, 2), last dim = (x, y).
106
+ track_vis_mask: torch.Tensor (S, N) or None.
107
+ out_dir: folder to save visualizations.
108
+ image_format: "CHW" or "HWC".
109
+ normalize_mode: "[0,1]", "[-1,1]", or None to use the raw values directly as 0..255.
110
+ cmap_name: a matplotlib colormap name for color_from_xy.
111
+ frames_per_row: number of frames to display in each row of the grid.
112
+ save_grid: whether to save all frames in one grid image.
113
+
114
+ Returns:
115
+ None (saves images in out_dir).
116
+ """
117
+
118
+ if len(tracks.shape) == 4:
119
+ tracks = tracks.squeeze(0)
120
+ images = images.squeeze(0)
121
+ if track_vis_mask is not None:
122
+ track_vis_mask = track_vis_mask.squeeze(0)
123
+
124
+ import matplotlib
125
+
126
+ matplotlib.use("Agg") # for non-interactive (optional)
127
+
128
+ os.makedirs(out_dir, exist_ok=True)
129
+
130
+ S = images.shape[0]
131
+ _, N, _ = tracks.shape # (S, N, 2)
132
+
133
+ # Move to CPU
134
+ images = images.cpu().clone()
135
+ tracks = tracks.cpu().clone()
136
+ if track_vis_mask is not None:
137
+ track_vis_mask = track_vis_mask.cpu().clone()
138
+
139
+ # Infer H, W from images shape
140
+ if image_format == "CHW":
141
+ # e.g. images[s].shape = (3, H, W)
142
+ H, W = images.shape[2], images.shape[3]
143
+ else:
144
+ # e.g. images[s].shape = (H, W, 3)
145
+ H, W = images.shape[1], images.shape[2]
146
+
147
+ # Pre-compute the color for each track i based on first visible position
148
+ track_colors_rgb = get_track_colors_by_position(
149
+ tracks, # shape (S, N, 2)
150
+ vis_mask_b=track_vis_mask if track_vis_mask is not None else None,
151
+ image_width=W,
152
+ image_height=H,
153
+ cmap_name=cmap_name,
154
+ )
155
+
156
+ # We'll accumulate each frame's drawn image in a list
157
+ frame_images = []
158
+
159
+ for s in range(S):
160
+ # shape => either (3, H, W) or (H, W, 3)
161
+ img = images[s]
162
+
163
+ # Convert to (H, W, 3)
164
+ if image_format == "CHW":
165
+ img = img.permute(1, 2, 0) # (H, W, 3)
166
+ # else "HWC", do nothing
167
+
168
+ img = img.numpy().astype(np.float32)
169
+
170
+ # Scale to [0,255] if needed
171
+ if normalize_mode == "[0,1]":
172
+ img = np.clip(img, 0, 1) * 255.0
173
+ elif normalize_mode == "[-1,1]":
174
+ img = (img + 1.0) * 0.5 * 255.0
175
+ img = np.clip(img, 0, 255.0)
176
+ # else no normalization
177
+
178
+ # Convert to uint8
179
+ img = img.astype(np.uint8)
180
+
181
+ # For drawing in OpenCV, convert to BGR
182
+ img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
183
+
184
+ # Draw each visible track
185
+ cur_tracks = tracks[s] # shape (N, 2)
186
+ if track_vis_mask is not None:
187
+ valid_indices = torch.where(track_vis_mask[s])[0]
188
+ else:
189
+ valid_indices = range(N)
190
+
191
+ cur_tracks_np = cur_tracks.numpy()
192
+ for i in valid_indices:
193
+ x, y = cur_tracks_np[i]
194
+ pt = (int(round(x)), int(round(y)))
195
+
196
+ # track_colors_rgb[i] is (R,G,B). For OpenCV circle, we need BGR
197
+ R, G, B = track_colors_rgb[i]
198
+ color_bgr = (int(B), int(G), int(R))
199
+ cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1)
200
+
201
+ # Convert back to RGB for consistent final saving:
202
+ img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
203
+
204
+ # Save individual frame
205
+ frame_path = os.path.join(out_dir, f"frame_{s:04d}.png")
206
+ # Convert to BGR for OpenCV imwrite
207
+ frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
208
+ cv2.imwrite(frame_path, frame_bgr)
209
+
210
+ frame_images.append(img_rgb)
211
+
212
+ # Only create and save the grid image if save_grid is True
213
+ if save_grid:
214
+ # Calculate grid dimensions
215
+ num_rows = (S + frames_per_row - 1) // frames_per_row # Ceiling division
216
+
217
+ # Create a grid of images
218
+ grid_img = None
219
+ for row in range(num_rows):
220
+ start_idx = row * frames_per_row
221
+ end_idx = min(start_idx + frames_per_row, S)
222
+
223
+ # Concatenate this row horizontally
224
+ row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1)
225
+
226
+ # If this row has fewer than frames_per_row images, pad with black
227
+ if end_idx - start_idx < frames_per_row:
228
+ padding_width = (frames_per_row - (end_idx - start_idx)) * W
229
+ padding = np.zeros((H, padding_width, 3), dtype=np.uint8)
230
+ row_img = np.concatenate([row_img, padding], axis=1)
231
+
232
+ # Add this row to the grid
233
+ if grid_img is None:
234
+ grid_img = row_img
235
+ else:
236
+ grid_img = np.concatenate([grid_img, row_img], axis=0)
237
+
238
+ out_path = os.path.join(out_dir, "tracks_grid.png")
239
+ # Convert back to BGR for OpenCV imwrite
240
+ grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR)
241
+ cv2.imwrite(out_path, grid_img_bgr)
242
+ print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}")
243
+
244
+ print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png")
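
A minimal invocation sketch for the visualizer above with synthetic inputs (the output directory name is arbitrary; cv2 and matplotlib must be available):

import torch
from eval.utils.visual_track import visualize_tracks_on_images  # assumed import path

S, N, H, W = 8, 32, 128, 160
images = torch.rand(S, 3, H, W)                                  # CHW frames in [0, 1]
tracks = torch.rand(S, N, 2) * torch.tensor([W - 1.0, H - 1.0])  # (x, y) pixel coordinates
vis_mask = torch.rand(S, N) > 0.2                                # per-frame, per-track visibility

visualize_tracks_on_images(
    images, tracks, track_vis_mask=vis_mask,
    out_dir="track_vis_demo", image_format="CHW",
    normalize_mode="[0,1]", frames_per_row=4, save_grid=True,
)
# Writes track_vis_demo/frame_0000.png ... frame_0007.png and track_vis_demo/tracks_grid.png
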
pyproject.toml ADDED
@@ -0,0 +1,58 @@
1
+ [project]
2
+ authors = [{name = "Junyuan DENG"},{name="Heng LI"}]
3
+ dependencies = [
4
+ "numpy<2",
5
+ "Pillow",
6
+ "huggingface_hub",
7
+ "einops",
8
+ "safetensors",
9
+ "opencv-python",
10
+ "torch>=2.3.1",
11
+ "torchvision>=0.18.1",
12
+ "numpy==1.26.1",
13
+ "evo",
14
+ "plyfile",
15
+ #"python-opencv",
16
+ ]
17
+ name = "sailrecon"
18
+ requires-python = ">= 3.10"
19
+ version = "0.0.1"
20
+
21
+ [project.optional-dependencies]
22
+ demo = [
23
+ "gradio>=5.17.1",
24
+ "viser>=0.2.23",
25
+ "tqdm",
26
+ "hydra-core",
27
+ "omegaconf",
28
+ "opencv-python",
29
+ "scipy",
30
+ "onnxruntime",
31
+ "requests",
32
+ "trimesh",
33
+ "matplotlib",
34
+ ]
35
+
36
+ # Using setuptools as the build backend
37
+ [build-system]
38
+ requires = ["setuptools>=61.0", "wheel"]
39
+ build-backend = "setuptools.build_meta"
40
+
41
+ # setuptools configuration
42
+ [tool.setuptools.packages.find]
43
+ where = ["."]
44
+ include = ["sailrecon*"]
45
+
46
+ # Pixi configuration
47
+ [tool.pixi.workspace]
48
+ channels = ["conda-forge"]
49
+ platforms = ["linux-64"]
50
+
51
+ [tool.pixi.pypi-dependencies]
52
+ sailrecon = { path = ".", editable = true }
53
+
54
+ [tool.pixi.environments]
55
+ default = { solve-group = "default" }
56
+ demo = { features = ["demo"], solve-group = "default" }
57
+
58
+ [tool.pixi.tasks]
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ torch==2.3.1
2
+ torchvision==0.18.1
3
+ numpy==1.26.1
4
+ Pillow
5
+ huggingface_hub
6
+ einops
7
+ safetensors
8
+ evo
9
+ plyfile
10
+ opencv-python
requirements_demo.txt ADDED
@@ -0,0 +1,16 @@
1
+ gradio==5.17.1
2
+ viser==0.2.23
3
+ tqdm
4
+ hydra-core
5
+ omegaconf
6
+ opencv-python
7
+ scipy
8
+ onnxruntime
9
+ requests
10
+ trimesh
11
+ matplotlib
12
+ pydantic==2.10.6
13
+ # feel free to skip the dependencies below if you do not need demo_colmap.py
14
+ # pycolmap==3.10.0
15
+ # pyceres==2.3
16
+ # git+https://github.com/jytime/LightGlue.git#egg=lightglue
sailrecon/dependency/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .track_modules.base_track_predictor import BaseTrackerPredictor
2
+ from .track_modules.blocks import BasicEncoder, ShallowEncoder
3
+ from .track_modules.track_refine import refine_track
sailrecon/dependency/distortion.py ADDED
@@ -0,0 +1,223 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Union
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+ ArrayLike = Union[np.ndarray, torch.Tensor]
13
+
14
+
15
+ def _is_numpy(x: ArrayLike) -> bool:
16
+ return isinstance(x, np.ndarray)
17
+
18
+
19
+ def _is_torch(x: ArrayLike) -> bool:
20
+ return isinstance(x, torch.Tensor)
21
+
22
+
23
+ def _ensure_torch(x: ArrayLike) -> torch.Tensor:
24
+ """Convert input to torch tensor if it's not already one."""
25
+ if _is_numpy(x):
26
+ return torch.from_numpy(x)
27
+ elif _is_torch(x):
28
+ return x
29
+ else:
30
+ return torch.tensor(x)
31
+
32
+
33
+ def single_undistortion(params, tracks_normalized):
34
+ """
35
+ Apply undistortion to the normalized tracks using the given distortion parameters once.
36
+
37
+ Args:
38
+ params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN.
39
+ tracks_normalized (torch.Tensor or numpy.ndarray): Normalized tracks tensor of shape [batch_size, num_tracks, 2].
40
+
41
+ Returns:
42
+ torch.Tensor: Undistorted normalized tracks tensor.
43
+ """
44
+ params = _ensure_torch(params)
45
+ tracks_normalized = _ensure_torch(tracks_normalized)
46
+
47
+ u, v = tracks_normalized[..., 0].clone(), tracks_normalized[..., 1].clone()
48
+ u_undist, v_undist = apply_distortion(params, u, v)
49
+ return torch.stack([u_undist, v_undist], dim=-1)
50
+
51
+
52
+ def iterative_undistortion(
53
+ params,
54
+ tracks_normalized,
55
+ max_iterations=100,
56
+ max_step_norm=1e-10,
57
+ rel_step_size=1e-6,
58
+ ):
59
+ """
60
+ Iteratively undistort the normalized tracks using the given distortion parameters.
61
+
62
+ Args:
63
+ params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN.
64
+ tracks_normalized (torch.Tensor or numpy.ndarray): Normalized tracks tensor of shape [batch_size, num_tracks, 2].
65
+ max_iterations (int): Maximum number of iterations for the undistortion process.
66
+ max_step_norm (float): Maximum step norm for convergence.
67
+ rel_step_size (float): Relative step size for numerical differentiation.
68
+
69
+ Returns:
70
+ torch.Tensor: Undistorted normalized tracks tensor.
71
+ """
72
+ params = _ensure_torch(params)
73
+ tracks_normalized = _ensure_torch(tracks_normalized)
74
+
75
+ B, N, _ = tracks_normalized.shape
76
+ u, v = tracks_normalized[..., 0].clone(), tracks_normalized[..., 1].clone()
77
+ original_u, original_v = u.clone(), v.clone()
78
+
79
+ eps = torch.finfo(u.dtype).eps
80
+ for idx in range(max_iterations):
81
+ u_undist, v_undist = apply_distortion(params, u, v)
82
+ dx = original_u - u_undist
83
+ dy = original_v - v_undist
84
+
85
+ step_u = torch.clamp(torch.abs(u) * rel_step_size, min=eps)
86
+ step_v = torch.clamp(torch.abs(v) * rel_step_size, min=eps)
87
+
88
+ J_00 = (
89
+ apply_distortion(params, u + step_u, v)[0]
90
+ - apply_distortion(params, u - step_u, v)[0]
91
+ ) / (2 * step_u)
92
+ J_01 = (
93
+ apply_distortion(params, u, v + step_v)[0]
94
+ - apply_distortion(params, u, v - step_v)[0]
95
+ ) / (2 * step_v)
96
+ J_10 = (
97
+ apply_distortion(params, u + step_u, v)[1]
98
+ - apply_distortion(params, u - step_u, v)[1]
99
+ ) / (2 * step_u)
100
+ J_11 = (
101
+ apply_distortion(params, u, v + step_v)[1]
102
+ - apply_distortion(params, u, v - step_v)[1]
103
+ ) / (2 * step_v)
104
+
105
+ J = torch.stack(
106
+ [
107
+ torch.stack([J_00 + 1, J_01], dim=-1),
108
+ torch.stack([J_10, J_11 + 1], dim=-1),
109
+ ],
110
+ dim=-2,
111
+ )
112
+
113
+ delta = torch.linalg.solve(J, torch.stack([dx, dy], dim=-1))
114
+
115
+ u += delta[..., 0]
116
+ v += delta[..., 1]
117
+
118
+ if torch.max((delta**2).sum(dim=-1)) < max_step_norm:
119
+ break
120
+
121
+ return torch.stack([u, v], dim=-1)
122
+
123
+
124
+ def apply_distortion(extra_params, u, v):
125
+ """
126
+ Applies radial or OpenCV distortion to the given 2D points.
127
+
128
+ Args:
129
+ extra_params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN, where N can be 1, 2, or 4.
130
+ u (torch.Tensor or numpy.ndarray): Normalized x coordinates of shape Bxnum_tracks.
131
+ v (torch.Tensor or numpy.ndarray): Normalized y coordinates of shape Bxnum_tracks.
132
+
133
+ Returns:
134
+ points2D (torch.Tensor): Distorted 2D points of shape BxNx2.
135
+ """
136
+ extra_params = _ensure_torch(extra_params)
137
+ u = _ensure_torch(u)
138
+ v = _ensure_torch(v)
139
+
140
+ num_params = extra_params.shape[1]
141
+
142
+ if num_params == 1:
143
+ # Simple radial distortion
144
+ k = extra_params[:, 0]
145
+ u2 = u * u
146
+ v2 = v * v
147
+ r2 = u2 + v2
148
+ radial = k[:, None] * r2
149
+ du = u * radial
150
+ dv = v * radial
151
+
152
+ elif num_params == 2:
153
+ # RadialCameraModel distortion
154
+ k1, k2 = extra_params[:, 0], extra_params[:, 1]
155
+ u2 = u * u
156
+ v2 = v * v
157
+ r2 = u2 + v2
158
+ radial = k1[:, None] * r2 + k2[:, None] * r2 * r2
159
+ du = u * radial
160
+ dv = v * radial
161
+
162
+ elif num_params == 4:
163
+ # OpenCVCameraModel distortion
164
+ k1, k2, p1, p2 = (
165
+ extra_params[:, 0],
166
+ extra_params[:, 1],
167
+ extra_params[:, 2],
168
+ extra_params[:, 3],
169
+ )
170
+ u2 = u * u
171
+ v2 = v * v
172
+ uv = u * v
173
+ r2 = u2 + v2
174
+ radial = k1[:, None] * r2 + k2[:, None] * r2 * r2
175
+ du = u * radial + 2 * p1[:, None] * uv + p2[:, None] * (r2 + 2 * u2)
176
+ dv = v * radial + 2 * p2[:, None] * uv + p1[:, None] * (r2 + 2 * v2)
177
+ else:
178
+ raise ValueError("Unsupported number of distortion parameters")
179
+
180
+ u = u.clone() + du
181
+ v = v.clone() + dv
182
+
183
+ return u, v
184
+
185
+
186
+ if __name__ == "__main__":
187
+ import random
188
+
189
+ import pycolmap
190
+
191
+ max_diff = 0
192
+ for i in range(1000):
193
+ # Define distortion parameters (assuming 1 parameter for simplicity)
194
+ B = random.randint(1, 500)
195
+ track_num = random.randint(100, 1000)
196
+ params = torch.rand((B, 1), dtype=torch.float32) # Batch size 1, 4 parameters
197
+ tracks_normalized = torch.rand(
198
+ (B, track_num, 2), dtype=torch.float32
199
+ ) # Batch size 1, 5 points
200
+
201
+ # Undistort the tracks
202
+ undistorted_tracks = iterative_undistortion(params, tracks_normalized)
203
+
204
+ for b in range(B):
205
+ pycolmap_intri = np.array([1, 0, 0, params[b].item()])
206
+ pycam = pycolmap.Camera(
207
+ model="SIMPLE_RADIAL",
208
+ width=1,
209
+ height=1,
210
+ params=pycolmap_intri,
211
+ camera_id=0,
212
+ )
213
+
214
+ undistorted_tracks_pycolmap = pycam.cam_from_img(
215
+ tracks_normalized[b].numpy()
216
+ )
217
+ diff = (undistorted_tracks[b] - undistorted_tracks_pycolmap).abs().median()
218
+ max_diff = max(max_diff, diff)
219
+ print(f"diff: {diff}, max_diff: {max_diff}")
220
+
221
+ import pdb
222
+
223
+ pdb.set_trace()
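
A round-trip sketch for the two distortion helpers above: distort clean normalized points with apply_distortion, then invert with iterative_undistortion (two-parameter radial model):

import torch
from sailrecon.dependency.distortion import apply_distortion, iterative_undistortion

B, N = 2, 100
params = torch.tensor([[0.05, -0.01], [0.02, 0.003]])   # (k1, k2) per camera
clean = (torch.rand(B, N, 2) - 0.5) * 0.6               # normalized coords near the image center

u_d, v_d = apply_distortion(params, clean[..., 0], clean[..., 1])
distorted = torch.stack([u_d, v_d], dim=-1)

recovered = iterative_undistortion(params, distorted)
assert torch.allclose(recovered, clean, atol=1e-4)      # Newton iterations invert the distortion
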
sailrecon/dependency/np_to_pycolmap.py ADDED
@@ -0,0 +1,355 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import pycolmap
9
+
10
+ from .projection import project_3D_points_np
11
+
12
+
13
+ def batch_np_matrix_to_pycolmap(
14
+ points3d,
15
+ extrinsics,
16
+ intrinsics,
17
+ tracks,
18
+ image_size,
19
+ masks=None,
20
+ max_reproj_error=None,
21
+ max_points3D_val=3000,
22
+ shared_camera=False,
23
+ camera_type="SIMPLE_PINHOLE",
24
+ extra_params=None,
25
+ min_inlier_per_frame=64,
26
+ points_rgb=None,
27
+ ):
28
+ """
29
+ Convert Batched NumPy Arrays to PyCOLMAP
30
+
31
+ Check https://github.com/colmap/pycolmap for more details about its format
32
+
33
+ NOTE that colmap expects images/cameras/points3D to be 1-indexed
34
+ so there is a +1 offset between colmap index and batch index
35
+
36
+
37
+ NOTE: different from VGGSfM, this function:
38
+ 1. Use np instead of torch
39
+ 2. Frame index and camera id starts from 1 rather than 0 (to fit the format of PyCOLMAP)
40
+ """
41
+ # points3d: Px3
42
+ # extrinsics: Nx3x4
43
+ # intrinsics: Nx3x3
44
+ # tracks: NxPx2
45
+ # masks: NxP
46
+ # image_size: 2, assume all the frames have been padded to the same size
47
+ # where N is the number of frames and P is the number of tracks
48
+
49
+ N, P, _ = tracks.shape
50
+ assert len(extrinsics) == N
51
+ assert len(intrinsics) == N
52
+ assert len(points3d) == P
53
+ assert image_size.shape[0] == 2
54
+
55
+ reproj_mask = None
56
+
57
+ if max_reproj_error is not None:
58
+ projected_points_2d, projected_points_cam = project_3D_points_np(
59
+ points3d, extrinsics, intrinsics
60
+ )
61
+ projected_diff = np.linalg.norm(projected_points_2d - tracks, axis=-1)
62
+ projected_points_2d[projected_points_cam[:, -1] <= 0] = 1e6
63
+ reproj_mask = projected_diff < max_reproj_error
64
+
65
+ if masks is not None and reproj_mask is not None:
66
+ masks = np.logical_and(masks, reproj_mask)
67
+ elif masks is not None:
68
+ masks = masks
69
+ else:
70
+ masks = reproj_mask
71
+
72
+ assert masks is not None
73
+
74
+ if masks.sum(1).min() < min_inlier_per_frame:
75
+ print(f"Not enough inliers per frame, skip BA.")
76
+ return None, None
77
+
78
+ # Reconstruction object, following the format of PyCOLMAP/COLMAP
79
+ reconstruction = pycolmap.Reconstruction()
80
+
81
+ inlier_num = masks.sum(0)
82
+ valid_mask = inlier_num >= 2 # a track is invalid if without two inliers
83
+ valid_idx = np.nonzero(valid_mask)[0]
84
+
85
+ # Only add 3D points that have sufficient 2D points
86
+ for vidx in valid_idx:
87
+ # Use RGB colors if provided, otherwise use zeros
88
+ rgb = points_rgb[vidx] if points_rgb is not None else np.zeros(3)
89
+ reconstruction.add_point3D(points3d[vidx], pycolmap.Track(), rgb)
90
+
91
+ num_points3D = len(valid_idx)
92
+ camera = None
93
+ # frame idx
94
+ for fidx in range(N):
95
+ # set camera
96
+ if camera is None or (not shared_camera):
97
+ pycolmap_intri = _build_pycolmap_intri(
98
+ fidx, intrinsics, camera_type, extra_params
99
+ )
100
+
101
+ camera = pycolmap.Camera(
102
+ model=camera_type,
103
+ width=image_size[0],
104
+ height=image_size[1],
105
+ params=pycolmap_intri,
106
+ camera_id=fidx + 1,
107
+ )
108
+
109
+ # add camera
110
+ reconstruction.add_camera(camera)
111
+
112
+ # set image
113
+ cam_from_world = pycolmap.Rigid3d(
114
+ pycolmap.Rotation3d(extrinsics[fidx][:3, :3]), extrinsics[fidx][:3, 3]
115
+ ) # Rot and Trans
116
+
117
+ image = pycolmap.Image(
118
+ id=fidx + 1,
119
+ name=f"image_{fidx + 1}",
120
+ camera_id=camera.camera_id,
121
+ cam_from_world=cam_from_world,
122
+ )
123
+
124
+ points2D_list = []
125
+
126
+ point2D_idx = 0
127
+
128
+ # NOTE point3D_id start by 1
129
+ for point3D_id in range(1, num_points3D + 1):
130
+ original_track_idx = valid_idx[point3D_id - 1]
131
+
132
+ if (reconstruction.points3D[point3D_id].xyz < max_points3D_val).all():
133
+ if masks[fidx][original_track_idx]:
134
+ # It seems we don't need +0.5 for BA
135
+ point2D_xy = tracks[fidx][original_track_idx]
136
+ # Please note when adding the Point2D object
137
+ # It not only requires the 2D xy location, but also the id to 3D point
138
+ points2D_list.append(pycolmap.Point2D(point2D_xy, point3D_id))
139
+
140
+ # add element
141
+ track = reconstruction.points3D[point3D_id].track
142
+ track.add_element(fidx + 1, point2D_idx)
143
+ point2D_idx += 1
144
+
145
+ assert point2D_idx == len(points2D_list)
146
+
147
+ try:
148
+ image.points2D = pycolmap.ListPoint2D(points2D_list)
149
+ image.registered = True
150
+ except:
151
+ print(f"frame {fidx + 1} is out of BA")
152
+ image.registered = False
153
+
154
+ # add image
155
+ reconstruction.add_image(image)
156
+
157
+ return reconstruction, valid_mask
158
+
159
+
160
+ def pycolmap_to_batch_np_matrix(
161
+ reconstruction, device="cpu", camera_type="SIMPLE_PINHOLE"
162
+ ):
163
+ """
164
+ Convert a PyCOLMAP Reconstruction Object to batched NumPy arrays.
165
+
166
+ Args:
167
+ reconstruction (pycolmap.Reconstruction): The reconstruction object from PyCOLMAP.
168
+ device (str): Ignored in NumPy version (kept for API compatibility).
169
+ camera_type (str): The type of camera model used (default: "SIMPLE_PINHOLE").
170
+
171
+ Returns:
172
+ tuple: A tuple containing points3D, extrinsics, intrinsics, and optionally extra_params.
173
+ """
174
+
175
+ num_images = len(reconstruction.images)
176
+ max_points3D_id = max(reconstruction.point3D_ids())
177
+ points3D = np.zeros((max_points3D_id, 3))
178
+
179
+ for point3D_id in reconstruction.points3D:
180
+ points3D[point3D_id - 1] = reconstruction.points3D[point3D_id].xyz
181
+
182
+ extrinsics = []
183
+ intrinsics = []
184
+
185
+ extra_params = [] if camera_type == "SIMPLE_RADIAL" else None
186
+
187
+ for i in range(num_images):
188
+ # Extract and append extrinsics
189
+ pyimg = reconstruction.images[i + 1]
190
+ pycam = reconstruction.cameras[pyimg.camera_id]
191
+ matrix = pyimg.cam_from_world.matrix()
192
+ extrinsics.append(matrix)
193
+
194
+ # Extract and append intrinsics
195
+ calibration_matrix = pycam.calibration_matrix()
196
+ intrinsics.append(calibration_matrix)
197
+
198
+ if camera_type == "SIMPLE_RADIAL":
199
+ extra_params.append(pycam.params[-1])
200
+
201
+ # Convert lists to NumPy arrays instead of torch tensors
202
+ extrinsics = np.stack(extrinsics)
203
+ intrinsics = np.stack(intrinsics)
204
+
205
+ if camera_type == "SIMPLE_RADIAL":
206
+ extra_params = np.stack(extra_params)
207
+ extra_params = extra_params[:, None]
208
+
209
+ return points3D, extrinsics, intrinsics, extra_params
210
+
211
+
212
+ ########################################################
213
+
214
+
215
+ def batch_np_matrix_to_pycolmap_wo_track(
216
+ points3d,
217
+ points_xyf,
218
+ points_rgb,
219
+ extrinsics,
220
+ intrinsics,
221
+ image_size,
222
+ shared_camera=False,
223
+ camera_type="SIMPLE_PINHOLE",
224
+ ):
225
+ """
226
+ Convert Batched NumPy Arrays to PyCOLMAP
227
+
228
+ Different from batch_np_matrix_to_pycolmap, this function does not use tracks.
229
+
230
+ It saves points3d to colmap reconstruction format only to serve as init for Gaussians or other nvs methods.
231
+
232
+ Do NOT use this for BA.
233
+ """
234
+ # points3d: Px3
235
+ # points_xyf: Px3, with x, y coordinates and frame indices
236
+ # points_rgb: Px3, rgb colors
237
+ # extrinsics: Nx3x4
238
+ # intrinsics: Nx3x3
239
+ # image_size: 2, assume all the frames have been padded to the same size
240
+ # where N is the number of frames and P is the number of tracks
241
+
242
+ N = len(extrinsics)
243
+ P = len(points3d)
244
+
245
+ # Reconstruction object, following the format of PyCOLMAP/COLMAP
246
+ reconstruction = pycolmap.Reconstruction()
247
+
248
+ for vidx in range(P):
249
+ reconstruction.add_point3D(points3d[vidx], pycolmap.Track(), points_rgb[vidx])
250
+
251
+ camera = None
252
+ # frame idx
253
+ for fidx in range(N):
254
+ # set camera
255
+ if camera is None or (not shared_camera):
256
+ pycolmap_intri = _build_pycolmap_intri(fidx, intrinsics, camera_type)
257
+
258
+ camera = pycolmap.Camera(
259
+ model=camera_type,
260
+ width=image_size[0],
261
+ height=image_size[1],
262
+ params=pycolmap_intri,
263
+ camera_id=fidx + 1,
264
+ )
265
+
266
+ # add camera
267
+ reconstruction.add_camera(camera)
268
+
269
+ # set image
270
+ cam_from_world = pycolmap.Rigid3d(
271
+ pycolmap.Rotation3d(extrinsics[fidx][:3, :3]), extrinsics[fidx][:3, 3]
272
+ ) # Rot and Trans
273
+
274
+ image = pycolmap.Image(
275
+ id=fidx + 1,
276
+ name=f"image_{fidx + 1}",
277
+ camera_id=camera.camera_id,
278
+ cam_from_world=cam_from_world,
279
+ )
280
+
281
+ points2D_list = []
282
+
283
+ point2D_idx = 0
284
+
285
+ points_belong_to_fidx = points_xyf[:, 2].astype(np.int32) == fidx
286
+ points_belong_to_fidx = np.nonzero(points_belong_to_fidx)[0]
287
+
288
+ for point3D_batch_idx in points_belong_to_fidx:
289
+ point3D_id = point3D_batch_idx + 1
290
+ point2D_xyf = points_xyf[point3D_batch_idx]
291
+ point2D_xy = point2D_xyf[:2]
292
+ points2D_list.append(pycolmap.Point2D(point2D_xy, point3D_id))
293
+
294
+ # add element
295
+ track = reconstruction.points3D[point3D_id].track
296
+ track.add_element(fidx + 1, point2D_idx)
297
+ point2D_idx += 1
298
+
299
+ assert point2D_idx == len(points2D_list)
300
+
301
+ try:
302
+ image.points2D = pycolmap.ListPoint2D(points2D_list)
303
+ image.registered = True
304
+ except:
305
+ print(f"frame {fidx + 1} does not have any points")
306
+ image.registered = False
307
+
308
+ # add image
309
+ reconstruction.add_image(image)
310
+
311
+ return reconstruction
312
+
313
+
314
+ def _build_pycolmap_intri(fidx, intrinsics, camera_type, extra_params=None):
315
+ """
316
+ Helper function to get camera parameters based on camera type.
317
+
318
+ Args:
319
+ fidx: Frame index
320
+ intrinsics: Camera intrinsic parameters
321
+ camera_type: Type of camera model
322
+ extra_params: Additional parameters for certain camera types
323
+
324
+ Returns:
325
+ pycolmap_intri: NumPy array of camera parameters
326
+ """
327
+ if camera_type == "PINHOLE":
328
+ pycolmap_intri = np.array(
329
+ [
330
+ intrinsics[fidx][0, 0],
331
+ intrinsics[fidx][1, 1],
332
+ intrinsics[fidx][0, 2],
333
+ intrinsics[fidx][1, 2],
334
+ ]
335
+ )
336
+ elif camera_type == "SIMPLE_PINHOLE":
337
+ focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2
338
+ pycolmap_intri = np.array(
339
+ [focal, intrinsics[fidx][0, 2], intrinsics[fidx][1, 2]]
340
+ )
341
+ elif camera_type == "SIMPLE_RADIAL":
342
+ raise NotImplementedError("SIMPLE_RADIAL is not supported yet")
343
+ focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2
344
+ pycolmap_intri = np.array(
345
+ [
346
+ focal,
347
+ intrinsics[fidx][0, 2],
348
+ intrinsics[fidx][1, 2],
349
+ extra_params[fidx][0],
350
+ ]
351
+ )
352
+ else:
353
+ raise ValueError(f"Camera type {camera_type} is not supported yet")
354
+
355
+ return pycolmap_intri
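
A small sketch of the intrinsics helper above (pycolmap must be installed to import this module):

import numpy as np
from sailrecon.dependency.np_to_pycolmap import _build_pycolmap_intri

K = np.array([[500.0, 0.0, 320.0],
              [0.0, 480.0, 240.0],
              [0.0, 0.0, 1.0]])
intrinsics = K[None]  # (1, 3, 3): a single frame

print(_build_pycolmap_intri(0, intrinsics, "PINHOLE"))         # [500. 480. 320. 240.]
print(_build_pycolmap_intri(0, intrinsics, "SIMPLE_PINHOLE"))  # [490. 320. 240.], focal = (fx + fy) / 2
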
sailrecon/dependency/projection.py ADDED
@@ -0,0 +1,249 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ from .distortion import apply_distortion
11
+
12
+
13
+ def img_from_cam_np(
14
+ intrinsics: np.ndarray,
15
+ points_cam: np.ndarray,
16
+ extra_params: np.ndarray | None = None,
17
+ default: float = 0.0,
18
+ ) -> np.ndarray:
19
+ """
20
+ Apply intrinsics (and optional radial distortion) to camera-space points.
21
+
22
+ Args
23
+ ----
24
+ intrinsics : (B,3,3) camera matrix K.
25
+ points_cam : (B,3,N) homogeneous camera coords (x, y, z)ᵀ.
26
+ extra_params: (B, N) or (B, k) distortion params (k = 1,2,4) or None.
27
+ default : value used for np.nan replacement.
28
+
29
+ Returns
30
+ -------
31
+ points2D : (B,N,2) pixel coordinates.
32
+ """
33
+ # 1. perspective divide ───────────────────────────────────────
34
+ z = points_cam[:, 2:3, :] # (B,1,N)
35
+ points_cam_norm = points_cam / z # (B,3,N)
36
+ uv = points_cam_norm[:, :2, :] # (B,2,N)
37
+
38
+ # 2. optional distortion ──────────────────────────────────────
39
+ if extra_params is not None:
40
+ uu, vv = apply_distortion(extra_params, uv[:, 0], uv[:, 1])
41
+ uv = np.stack([uu, vv], axis=1) # (B,2,N)
42
+
43
+ # 3. homogeneous coords then K multiplication ─────────────────
44
+ ones = np.ones_like(uv[:, :1, :]) # (B,1,N)
45
+ points_cam_h = np.concatenate([uv, ones], axis=1) # (B,3,N)
46
+
47
+ # batched mat-mul: K · [u v 1]ᵀ
48
+ points2D_h = np.einsum("bij,bjk->bik", intrinsics, points_cam_h) # (B,3,N)
49
+ points2D = np.nan_to_num(points2D_h[:, :2, :], nan=default) # (B,2,N)
50
+
51
+ return points2D.transpose(0, 2, 1) # (B,N,2)
52
+
53
+
54
+ def project_3D_points_np(
55
+ points3D: np.ndarray,
56
+ extrinsics: np.ndarray,
57
+ intrinsics: np.ndarray | None = None,
58
+ extra_params: np.ndarray | None = None,
59
+ *,
60
+ default: float = 0.0,
61
+ only_points_cam: bool = False,
62
+ ):
63
+ """
64
+ NumPy clone of ``project_3D_points``.
65
+
66
+ Parameters
67
+ ----------
68
+ points3D : (N,3) world-space points.
69
+ extrinsics : (B,3,4) [R|t] matrix for each of B cameras.
70
+ intrinsics : (B,3,3) K matrix (optional if you only need cam-space).
71
+ extra_params : (B,k) or (B,N) distortion parameters (k ∈ {1,2,4}) or None.
72
+ default : value used to replace NaNs.
73
+ only_points_cam : if True, skip the projection and return points_cam with points2D as None.
74
+
75
+ Returns
76
+ -------
77
+ (points2D, points_cam) : A tuple where points2D is (B,N,2) pixel coords or None if only_points_cam=True,
78
+ and points_cam is (B,3,N) camera-space coordinates.
79
+ """
80
+ # ----- 0. prep sizes -----------------------------------------------------
81
+ N = points3D.shape[0] # #points
82
+ B = extrinsics.shape[0] # #cameras
83
+
84
+ # ----- 1. world → homogeneous -------------------------------------------
85
+ w_h = np.ones((N, 1), dtype=points3D.dtype)
86
+ points3D_h = np.concatenate([points3D, w_h], axis=1) # (N,4)
87
+
88
+ # broadcast to every camera (no actual copying with np.broadcast_to) ------
89
+ points3D_h_B = np.broadcast_to(points3D_h, (B, N, 4)) # (B,N,4)
90
+
91
+ # ----- 2. apply extrinsics (camera frame) ------------------------------
92
+ # X_cam = E · X_hom
93
+ # einsum: E_(b i j) · X_(b n j) → (b n i)
94
+ points_cam = np.einsum("bij,bnj->bni", extrinsics, points3D_h_B) # (B,N,3)
95
+ points_cam = points_cam.transpose(0, 2, 1) # (B,3,N)
96
+
97
+ if only_points_cam:
98
+ return None, points_cam
99
+
100
+ # ----- 3. intrinsics + distortion ---------------------------------------
101
+ if intrinsics is None:
102
+ raise ValueError("`intrinsics` must be provided unless only_points_cam=True")
103
+
104
+ points2D = img_from_cam_np(
105
+ intrinsics, points_cam, extra_params=extra_params, default=default
106
+ )
107
+
108
+ return points2D, points_cam
109
+
110
+
111
+ def project_3D_points(
112
+ points3D,
113
+ extrinsics,
114
+ intrinsics=None,
115
+ extra_params=None,
116
+ default=0,
117
+ only_points_cam=False,
118
+ ):
119
+ """
120
+ Transforms 3D points to 2D using extrinsic and intrinsic parameters.
121
+ Args:
122
+ points3D (torch.Tensor): 3D points of shape Px3.
123
+ extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4.
124
+ intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3.
125
+ extra_params (torch.Tensor): Extra parameters of shape BxN, used for radial distortion.
126
+ default (float): Default value to replace NaNs.
127
+ only_points_cam (bool): If True, skip the projection and return points2D as None.
128
+
129
+ Returns:
130
+ tuple: (points2D, points_cam) where points2D is of shape BxNx2 or None if only_points_cam=True,
131
+ and points_cam is of shape Bx3xN.
132
+ """
133
+ with torch.cuda.amp.autocast(dtype=torch.double):
134
+ N = points3D.shape[0] # Number of points
135
+ B = extrinsics.shape[0] # Batch size, i.e., number of cameras
136
+ points3D_homogeneous = torch.cat(
137
+ [points3D, torch.ones_like(points3D[..., 0:1])], dim=1
138
+ ) # Nx4
139
+ # Reshape for batch processing
140
+ points3D_homogeneous = points3D_homogeneous.unsqueeze(0).expand(
141
+ B, -1, -1
142
+ ) # BxNx4
143
+
144
+ # Step 1: Apply extrinsic parameters
145
+ # Transform 3D points to camera coordinate system for all cameras
146
+ points_cam = torch.bmm(extrinsics, points3D_homogeneous.transpose(-1, -2))
147
+
148
+ if only_points_cam:
149
+ return None, points_cam
150
+
151
+ # Step 2: Apply intrinsic parameters and (optional) distortion
152
+ points2D = img_from_cam(intrinsics, points_cam, extra_params, default)
153
+
154
+ return points2D, points_cam
155
+
156
+
157
+ def img_from_cam(intrinsics, points_cam, extra_params=None, default=0.0):
158
+ """
159
+ Applies intrinsic parameters and optional distortion to the given 3D points.
160
+
161
+ Args:
162
+ intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3.
163
+ points_cam (torch.Tensor): 3D points in camera coordinates of shape Bx3xN.
164
+ extra_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4.
165
+ default (float, optional): Default value to replace NaNs in the output.
166
+
167
+ Returns:
168
+ points2D (torch.Tensor): 2D points in pixel coordinates of shape BxNx2.
169
+ """
170
+
171
+ # Normalize by the third coordinate (homogeneous division)
172
+ points_cam = points_cam / points_cam[:, 2:3, :]
173
+ # Extract uv
174
+ uv = points_cam[:, :2, :]
175
+
176
+ # Apply distortion if extra_params are provided
177
+ if extra_params is not None:
178
+ uu, vv = apply_distortion(extra_params, uv[:, 0], uv[:, 1])
179
+ uv = torch.stack([uu, vv], dim=1)
180
+
181
+ # Prepare points_cam for batch matrix multiplication
182
+ points_cam_homo = torch.cat((uv, torch.ones_like(uv[:, :1, :])), dim=1) # Bx3xN
183
+ # Apply intrinsic parameters using batch matrix multiplication
184
+ points2D_homo = torch.bmm(intrinsics, points_cam_homo) # Bx3xN
185
+
186
+ # Extract x and y coordinates
187
+ points2D = points2D_homo[:, :2, :] # Bx2xN
188
+
189
+ # Replace NaNs with default value
190
+ points2D = torch.nan_to_num(points2D, nan=default)
191
+
192
+ return points2D.transpose(1, 2) # BxNx2
193
+
194
+
195
+ if __name__ == "__main__":
196
+ # Set up example input
197
+ B, N = 24, 10240
198
+
199
+ for _ in range(100):
200
+ points3D = np.random.rand(N, 3).astype(np.float64)
201
+ extrinsics = np.random.rand(B, 3, 4).astype(np.float64)
202
+ intrinsics = np.random.rand(B, 3, 3).astype(np.float64)
203
+
204
+ # Convert to torch tensors
205
+ points3D_torch = torch.tensor(points3D)
206
+ extrinsics_torch = torch.tensor(extrinsics)
207
+ intrinsics_torch = torch.tensor(intrinsics)
208
+
209
+ # Run NumPy implementation
210
+ points2D_np, points_cam_np = project_3D_points_np(
211
+ points3D, extrinsics, intrinsics
212
+ )
213
+
214
+ # Run torch implementation
215
+ points2D_torch, points_cam_torch = project_3D_points(
216
+ points3D_torch, extrinsics_torch, intrinsics_torch
217
+ )
218
+
219
+ # Convert torch output to numpy
220
+ points2D_torch_np = points2D_torch.detach().numpy()
221
+ points_cam_torch_np = points_cam_torch.detach().numpy()
222
+
223
+ # Compute difference
224
+ diff = np.abs(points2D_np - points2D_torch_np)
225
+ print("Difference between NumPy and PyTorch implementations:")
226
+ print(diff)
227
+
228
+ # Check max error
229
+ max_diff = np.max(diff)
230
+ print(f"Maximum difference: {max_diff}")
231
+
232
+ if np.allclose(points2D_np, points2D_torch_np, atol=1e-6):
233
+ print("Implementations match closely.")
234
+ else:
235
+ print("Significant differences detected.")
236
+
237
+ if points_cam_np is not None:
238
+ points_cam_diff = np.abs(points_cam_np - points_cam_torch_np)
239
+ print("Difference between NumPy and PyTorch camera-space coordinates:")
240
+ print(points_cam_diff)
241
+
242
+ # Check max error
243
+ max_cam_diff = np.max(points_cam_diff)
244
+ print(f"Maximum camera-space coordinate difference: {max_cam_diff}")
245
+
246
+ if np.allclose(points_cam_np, points_cam_torch_np, atol=1e-6):
247
+ print("Camera-space coordinates match closely.")
248
+ else:
249
+ print("Significant differences detected in camera-space coordinates.")
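
A worked numeric check for project_3D_points_np above: with identity extrinsics and a pinhole K (f = 100, principal point (64, 64)), a point on the optical axis projects to the principal point, and a point offset by 1 unit at depth 2 moves by f * 1/2 pixels:

import numpy as np
from sailrecon.dependency.projection import project_3D_points_np

K = np.array([[[100.0, 0.0, 64.0],
               [0.0, 100.0, 64.0],
               [0.0, 0.0, 1.0]]])                                # (1, 3, 3)
E = np.concatenate([np.eye(3), np.zeros((3, 1))], axis=1)[None]  # (1, 3, 4), camera at the origin

points3D = np.array([[0.0, 0.0, 2.0],
                     [1.0, 0.0, 2.0]])

points2D, points_cam = project_3D_points_np(points3D, E, K)
print(points2D[0])        # [[ 64.  64.]  [114.  64.]]
print(points_cam.shape)   # (1, 3, 2)
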
sailrecon/dependency/track_modules/__init__.py ADDED
File without changes
sailrecon/dependency/track_modules/base_track_predictor.py ADDED
@@ -0,0 +1,210 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from einops import rearrange, repeat
10
+
11
+ from .blocks import CorrBlock, EfficientUpdateFormer
12
+ from .utils import get_2d_embedding, get_2d_sincos_pos_embed, sample_features4d
13
+
14
+
15
+ class BaseTrackerPredictor(nn.Module):
16
+ def __init__(
17
+ self,
18
+ stride=4,
19
+ corr_levels=5,
20
+ corr_radius=4,
21
+ latent_dim=128,
22
+ hidden_size=384,
23
+ use_spaceatt=True,
24
+ depth=6,
25
+ fine=False,
26
+ ):
27
+ super(BaseTrackerPredictor, self).__init__()
28
+ """
29
+ The base template to create a track predictor
30
+
31
+ Modified from https://github.com/facebookresearch/co-tracker/
32
+ """
33
+
34
+ self.stride = stride
35
+ self.latent_dim = latent_dim
36
+ self.corr_levels = corr_levels
37
+ self.corr_radius = corr_radius
38
+ self.hidden_size = hidden_size
39
+ self.fine = fine
40
+
41
+ self.flows_emb_dim = latent_dim // 2
42
+ self.transformer_dim = (
43
+ self.corr_levels * (self.corr_radius * 2 + 1) ** 2 + self.latent_dim * 2
44
+ )
45
+
46
+ if self.fine:
47
+ # TODO this is the old dummy code, will remove this when we train next model
48
+ self.transformer_dim += 4 if self.transformer_dim % 2 == 0 else 5
49
+ else:
50
+ self.transformer_dim += (4 - self.transformer_dim % 4) % 4
51
+
52
+ space_depth = depth if use_spaceatt else 0
53
+ time_depth = depth
54
+
55
+ self.updateformer = EfficientUpdateFormer(
56
+ space_depth=space_depth,
57
+ time_depth=time_depth,
58
+ input_dim=self.transformer_dim,
59
+ hidden_size=self.hidden_size,
60
+ output_dim=self.latent_dim + 2,
61
+ mlp_ratio=4.0,
62
+ add_space_attn=use_spaceatt,
63
+ )
64
+
65
+ self.norm = nn.GroupNorm(1, self.latent_dim)
66
+
67
+ # A linear layer to update track feats at each iteration
68
+ self.ffeat_updater = nn.Sequential(
69
+ nn.Linear(self.latent_dim, self.latent_dim), nn.GELU()
70
+ )
71
+
72
+ if not self.fine:
73
+ self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
74
+
75
+ def forward(
76
+ self, query_points, fmaps=None, iters=4, return_feat=False, down_ratio=1
77
+ ):
78
+ """
79
+ query_points: B x N x 2, the number of batches, tracks, and xy
80
+ fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
81
+ note HH and WW is the size of feature maps instead of original images
82
+ """
83
+ B, N, D = query_points.shape
84
+ B, S, C, HH, WW = fmaps.shape
85
+
86
+ assert D == 2
87
+
88
+ # Scale the input query_points because we may downsample the images
89
+ # by down_ratio or self.stride
90
+ # e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
91
+ # its query_points should be query_points/4
92
+ if down_ratio > 1:
93
+ query_points = query_points / float(down_ratio)
94
+ query_points = query_points / float(self.stride)
95
+
96
+ # Init with coords as the query points
97
+ # It means the search will start from the position of query points at the reference frames
98
+ coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)
99
+
100
+ # Sample/extract the features of the query points in the query frame
101
+ query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])
102
+
103
+ # init track feats by query feats
104
+ track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C
105
+ # back up the init coords
106
+ coords_backup = coords.clone()
107
+
108
+ # Construct the correlation block
109
+
110
+ fcorr_fn = CorrBlock(
111
+ fmaps, num_levels=self.corr_levels, radius=self.corr_radius
112
+ )
113
+
114
+ coord_preds = []
115
+
116
+ # Iterative Refinement
117
+ for itr in range(iters):
118
+ # Detach the gradients from the last iteration
119
+ # (in my experience, not very important for performance)
120
+ coords = coords.detach()
121
+
122
+ # Compute the correlation (check the implementation of CorrBlock)
123
+
124
+ fcorr_fn.corr(track_feats)
125
+ fcorrs = fcorr_fn.sample(coords) # B, S, N, corrdim
126
+
127
+ corrdim = fcorrs.shape[3]
128
+
129
+ fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corrdim)
130
+
131
+ # Movement of current coords relative to query points
132
+ flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
133
+
134
+ flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)
135
+
136
+ # (In my trials, it is also okay to just add the flows_emb instead of concat)
137
+ flows_emb = torch.cat([flows_emb, flows], dim=-1)
138
+
139
+ track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(
140
+ B * N, S, self.latent_dim
141
+ )
142
+
143
+ # Concatenate them as the input for the transformers
144
+ transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)
145
+
146
+ if transformer_input.shape[2] < self.transformer_dim:
147
+ # pad the features to match the dimension
148
+ pad_dim = self.transformer_dim - transformer_input.shape[2]
149
+ pad = torch.zeros_like(flows_emb[..., 0:pad_dim])
150
+ transformer_input = torch.cat([transformer_input, pad], dim=2)
151
+
152
+ # 2D positional embed
153
+ # TODO: this can be much simplified
154
+ pos_embed = get_2d_sincos_pos_embed(
155
+ self.transformer_dim, grid_size=(HH, WW)
156
+ ).to(query_points.device)
157
+ sampled_pos_emb = sample_features4d(
158
+ pos_embed.expand(B, -1, -1, -1), coords[:, 0]
159
+ )
160
+ sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(
161
+ 1
162
+ )
163
+
164
+ x = transformer_input + sampled_pos_emb
165
+
166
+ # B, N, S, C
167
+ x = rearrange(x, "(b n) s d -> b n s d", b=B)
168
+
169
+ # Compute the delta coordinates and delta track features
170
+ delta = self.updateformer(x)
171
+ # BN, S, C
172
+ delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
173
+ delta_coords_ = delta[:, :, :2]
174
+ delta_feats_ = delta[:, :, 2:]
175
+
176
+ track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
177
+ delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
178
+
179
+ # Update the track features
180
+ track_feats_ = self.ffeat_updater(self.norm(delta_feats_)) + track_feats_
181
+ track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(
182
+ 0, 2, 1, 3
183
+ ) # BxSxNxC
184
+
185
+ # B x S x N x 2
186
+ coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
187
+
188
+ # Force coord0 as query
189
+ # because we assume the query points should not be changed
190
+ coords[:, 0] = coords_backup[:, 0]
191
+
192
+ # The predicted tracks are in the original image scale
193
+ if down_ratio > 1:
194
+ coord_preds.append(coords * self.stride * down_ratio)
195
+ else:
196
+ coord_preds.append(coords * self.stride)
197
+
198
+ # B, S, N
199
+ if not self.fine:
200
+ vis_e = self.vis_predictor(
201
+ track_feats.reshape(B * S * N, self.latent_dim)
202
+ ).reshape(B, S, N)
203
+ vis_e = torch.sigmoid(vis_e)
204
+ else:
205
+ vis_e = None
206
+
207
+ if return_feat:
208
+ return coord_preds, vis_e, track_feats, query_track_feat
209
+ else:
210
+ return coord_preds, vis_e
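
The query/feature coordinate scaling described in the comments above, as a small arithmetic sketch (stride=4 as in the default constructor; down_ratio is whatever extra downsampling was applied before feature extraction):

# A 1024x1024 image whose feature map is 256x256 (stride = 4, down_ratio = 1)
stride, down_ratio = 4, 1
query_xy = (512.0, 640.0)                                    # pixel coordinates in the original image
feat_xy = tuple(c / down_ratio / stride for c in query_xy)
print(feat_xy)                                               # (128.0, 160.0) on the feature grid

# Predicted coords are mapped back to image scale exactly as in the forward pass
coords_feat = (130.5, 158.25)                                # hypothetical refined position on the feature grid
print(tuple(c * stride * down_ratio for c in coords_feat))   # (522.0, 633.0)
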
sailrecon/dependency/track_modules/blocks.py ADDED
@@ -0,0 +1,396 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ # Modified from https://github.com/facebookresearch/co-tracker/
9
+
10
+
11
+ import collections
12
+ from functools import partial
13
+ from itertools import repeat
14
+ from typing import Callable
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from torch import Tensor
20
+
21
+ from .modules import AttnBlock, CrossAttnBlock, Mlp, ResidualBlock
22
+ from .utils import bilinear_sampler
23
+
24
+
25
+ class BasicEncoder(nn.Module):
26
+ def __init__(self, input_dim=3, output_dim=128, stride=4):
27
+ super(BasicEncoder, self).__init__()
28
+
29
+ self.stride = stride
30
+ self.norm_fn = "instance"
31
+ self.in_planes = output_dim // 2
32
+
33
+ self.norm1 = nn.InstanceNorm2d(self.in_planes)
34
+ self.norm2 = nn.InstanceNorm2d(output_dim * 2)
35
+
36
+ self.conv1 = nn.Conv2d(
37
+ input_dim,
38
+ self.in_planes,
39
+ kernel_size=7,
40
+ stride=2,
41
+ padding=3,
42
+ padding_mode="zeros",
43
+ )
44
+ self.relu1 = nn.ReLU(inplace=True)
45
+ self.layer1 = self._make_layer(output_dim // 2, stride=1)
46
+ self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
47
+ self.layer3 = self._make_layer(output_dim, stride=2)
48
+ self.layer4 = self._make_layer(output_dim, stride=2)
49
+
50
+ self.conv2 = nn.Conv2d(
51
+ output_dim * 3 + output_dim // 4,
52
+ output_dim * 2,
53
+ kernel_size=3,
54
+ padding=1,
55
+ padding_mode="zeros",
56
+ )
57
+ self.relu2 = nn.ReLU(inplace=True)
58
+ self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
59
+
60
+ for m in self.modules():
61
+ if isinstance(m, nn.Conv2d):
62
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
63
+ elif isinstance(m, (nn.InstanceNorm2d)):
64
+ if m.weight is not None:
65
+ nn.init.constant_(m.weight, 1)
66
+ if m.bias is not None:
67
+ nn.init.constant_(m.bias, 0)
68
+
69
+ def _make_layer(self, dim, stride=1):
70
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
71
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
72
+ layers = (layer1, layer2)
73
+
74
+ self.in_planes = dim
75
+ return nn.Sequential(*layers)
76
+
77
+ def forward(self, x):
78
+ _, _, H, W = x.shape
79
+
80
+ x = self.conv1(x)
81
+ x = self.norm1(x)
82
+ x = self.relu1(x)
83
+
84
+ a = self.layer1(x)
85
+ b = self.layer2(a)
86
+ c = self.layer3(b)
87
+ d = self.layer4(c)
88
+
89
+ a = _bilinear_intepolate(a, self.stride, H, W)
90
+ b = _bilinear_intepolate(b, self.stride, H, W)
91
+ c = _bilinear_intepolate(c, self.stride, H, W)
92
+ d = _bilinear_intepolate(d, self.stride, H, W)
93
+
94
+ x = self.conv2(torch.cat([a, b, c, d], dim=1))
95
+ x = self.norm2(x)
96
+ x = self.relu2(x)
97
+ x = self.conv3(x)
98
+ return x
99
+
100
+
101
+ class ShallowEncoder(nn.Module):
102
+ def __init__(self, input_dim=3, output_dim=32, stride=1, norm_fn="instance"):
103
+ super(ShallowEncoder, self).__init__()
104
+ self.stride = stride
105
+ self.norm_fn = norm_fn
106
+ self.in_planes = output_dim
107
+
108
+ if self.norm_fn == "group":
109
+ self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes)
110
+ self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2)
111
+ elif self.norm_fn == "batch":
112
+ self.norm1 = nn.BatchNorm2d(self.in_planes)
113
+ self.norm2 = nn.BatchNorm2d(output_dim * 2)
114
+ elif self.norm_fn == "instance":
115
+ self.norm1 = nn.InstanceNorm2d(self.in_planes)
116
+ self.norm2 = nn.InstanceNorm2d(output_dim * 2)
117
+ elif self.norm_fn == "none":
118
+ self.norm1 = nn.Sequential()
119
+
120
+ self.conv1 = nn.Conv2d(
121
+ input_dim,
122
+ self.in_planes,
123
+ kernel_size=3,
124
+ stride=2,
125
+ padding=1,
126
+ padding_mode="zeros",
127
+ )
128
+ self.relu1 = nn.ReLU(inplace=True)
129
+
130
+ self.layer1 = self._make_layer(output_dim, stride=2)
131
+
132
+ self.layer2 = self._make_layer(output_dim, stride=2)
133
+ self.conv2 = nn.Conv2d(output_dim, output_dim, kernel_size=1)
134
+
135
+ for m in self.modules():
136
+ if isinstance(m, nn.Conv2d):
137
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
138
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
139
+ if m.weight is not None:
140
+ nn.init.constant_(m.weight, 1)
141
+ if m.bias is not None:
142
+ nn.init.constant_(m.bias, 0)
143
+
144
+ def _make_layer(self, dim, stride=1):
145
+ self.in_planes = dim
146
+
147
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
148
+ return layer1
149
+
150
+ def forward(self, x):
151
+ _, _, H, W = x.shape
152
+
153
+ x = self.conv1(x)
154
+ x = self.norm1(x)
155
+ x = self.relu1(x)
156
+
157
+ tmp = self.layer1(x)
158
+ x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True)
159
+ tmp = self.layer2(tmp)
160
+ x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True)
161
+ tmp = None
162
+ x = self.conv2(x) + x
163
+
164
+ x = F.interpolate(
165
+ x, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True
166
+ )
167
+
168
+ return x
169
+
170
+
171
+ def _bilinear_intepolate(x, stride, H, W):
172
+ return F.interpolate(
173
+ x, (H // stride, W // stride), mode="bilinear", align_corners=True
174
+ )
175
+
176
+
177
+ class EfficientUpdateFormer(nn.Module):
178
+ """
179
+ Transformer model that updates track estimates.
180
+ """
181
+
182
+ def __init__(
183
+ self,
184
+ space_depth=6,
185
+ time_depth=6,
186
+ input_dim=320,
187
+ hidden_size=384,
188
+ num_heads=8,
189
+ output_dim=130,
190
+ mlp_ratio=4.0,
191
+ add_space_attn=True,
192
+ num_virtual_tracks=64,
193
+ ):
194
+ super().__init__()
195
+
196
+ self.out_channels = 2
197
+ self.num_heads = num_heads
198
+ self.hidden_size = hidden_size
199
+ self.add_space_attn = add_space_attn
200
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
201
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
202
+ self.num_virtual_tracks = num_virtual_tracks
203
+
204
+ if self.add_space_attn:
205
+ self.virual_tracks = nn.Parameter(
206
+ torch.randn(1, num_virtual_tracks, 1, hidden_size)
207
+ )
208
+ else:
209
+ self.virual_tracks = None
210
+
211
+ self.time_blocks = nn.ModuleList(
212
+ [
213
+ AttnBlock(
214
+ hidden_size,
215
+ num_heads,
216
+ mlp_ratio=mlp_ratio,
217
+ attn_class=nn.MultiheadAttention,
218
+ )
219
+ for _ in range(time_depth)
220
+ ]
221
+ )
222
+
223
+ if add_space_attn:
224
+ self.space_virtual_blocks = nn.ModuleList(
225
+ [
226
+ AttnBlock(
227
+ hidden_size,
228
+ num_heads,
229
+ mlp_ratio=mlp_ratio,
230
+ attn_class=nn.MultiheadAttention,
231
+ )
232
+ for _ in range(space_depth)
233
+ ]
234
+ )
235
+ self.space_point2virtual_blocks = nn.ModuleList(
236
+ [
237
+ CrossAttnBlock(
238
+ hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
239
+ )
240
+ for _ in range(space_depth)
241
+ ]
242
+ )
243
+ self.space_virtual2point_blocks = nn.ModuleList(
244
+ [
245
+ CrossAttnBlock(
246
+ hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
247
+ )
248
+ for _ in range(space_depth)
249
+ ]
250
+ )
251
+ assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
252
+ self.initialize_weights()
253
+
254
+ def initialize_weights(self):
255
+ def _basic_init(module):
256
+ if isinstance(module, nn.Linear):
257
+ torch.nn.init.xavier_uniform_(module.weight)
258
+ if module.bias is not None:
259
+ nn.init.constant_(module.bias, 0)
260
+
261
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
262
+ """ViT weight initialization, original timm impl (for reproducibility)"""
263
+ if isinstance(module, nn.Linear):
264
+ trunc_normal_(module.weight, std=0.02)
265
+ if module.bias is not None:
266
+ nn.init.zeros_(module.bias)
267
+
268
+ def forward(self, input_tensor, mask=None):
269
+ tokens = self.input_transform(input_tensor)
270
+
271
+ init_tokens = tokens
272
+
273
+ B, _, T, _ = tokens.shape
274
+
275
+ if self.add_space_attn:
276
+ virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
277
+ tokens = torch.cat([tokens, virtual_tokens], dim=1)
278
+
279
+ _, N, _, _ = tokens.shape
280
+
281
+ j = 0
282
+ for i in range(len(self.time_blocks)):
283
+ time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
284
+ time_tokens = self.time_blocks[i](time_tokens)
285
+
286
+ tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
287
+ if self.add_space_attn and (
288
+ i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0
289
+ ):
290
+ space_tokens = (
291
+ tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
292
+ ) # B N T C -> (B T) N C
293
+ point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
294
+ virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
295
+
296
+ virtual_tokens = self.space_virtual2point_blocks[j](
297
+ virtual_tokens, point_tokens, mask=mask
298
+ )
299
+ virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
300
+ point_tokens = self.space_point2virtual_blocks[j](
301
+ point_tokens, virtual_tokens, mask=mask
302
+ )
303
+ space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
304
+ tokens = space_tokens.view(B, T, N, -1).permute(
305
+ 0, 2, 1, 3
306
+ ) # (B T) N C -> B N T C
307
+ j += 1
308
+
309
+ if self.add_space_attn:
310
+ tokens = tokens[:, : N - self.num_virtual_tracks]
311
+
312
+ tokens = tokens + init_tokens
313
+
314
+ flow = self.flow_head(tokens)
315
+ return flow
316
+
317
+
318
+ class CorrBlock:
319
+ def __init__(
320
+ self,
321
+ fmaps,
322
+ num_levels=4,
323
+ radius=4,
324
+ multiple_track_feats=False,
325
+ padding_mode="zeros",
326
+ ):
327
+ B, S, C, H, W = fmaps.shape
328
+ self.S, self.C, self.H, self.W = S, C, H, W
329
+ self.padding_mode = padding_mode
330
+ self.num_levels = num_levels
331
+ self.radius = radius
332
+ self.fmaps_pyramid = []
333
+ self.multiple_track_feats = multiple_track_feats
334
+
335
+ self.fmaps_pyramid.append(fmaps)
336
+ for i in range(self.num_levels - 1):
337
+ fmaps_ = fmaps.reshape(B * S, C, H, W)
338
+ fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
339
+ _, _, H, W = fmaps_.shape
340
+ fmaps = fmaps_.reshape(B, S, C, H, W)
341
+ self.fmaps_pyramid.append(fmaps)
342
+
343
+ def sample(self, coords):
344
+ r = self.radius
345
+ B, S, N, D = coords.shape
346
+ assert D == 2
347
+
348
+ H, W = self.H, self.W
349
+ out_pyramid = []
350
+ for i in range(self.num_levels):
351
+ corrs = self.corrs_pyramid[i] # B, S, N, H, W
352
+ *_, H, W = corrs.shape
353
+
354
+ dx = torch.linspace(-r, r, 2 * r + 1)
355
+ dy = torch.linspace(-r, r, 2 * r + 1)
356
+ delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
357
+ coords.device
358
+ )
359
+
360
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
361
+ delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
362
+ coords_lvl = centroid_lvl + delta_lvl
363
+
364
+ corrs = bilinear_sampler(
365
+ corrs.reshape(B * S * N, 1, H, W),
366
+ coords_lvl,
367
+ padding_mode=self.padding_mode,
368
+ )
369
+ corrs = corrs.view(B, S, N, -1)
370
+
371
+ out_pyramid.append(corrs)
372
+
373
+ out = torch.cat(out_pyramid, dim=-1).contiguous() # B, S, N, L*(2*r+1)**2
374
+ return out
375
+
376
+ def corr(self, targets):
377
+ B, S, N, C = targets.shape
378
+ if self.multiple_track_feats:
379
+ targets_split = targets.split(C // self.num_levels, dim=-1)
380
+ B, S, N, C = targets_split[0].shape
381
+
382
+ assert C == self.C
383
+ assert S == self.S
384
+
385
+ fmap1 = targets
386
+
387
+ self.corrs_pyramid = []
388
+ for i, fmaps in enumerate(self.fmaps_pyramid):
389
+ *_, H, W = fmaps.shape
390
+ fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W)
391
+ if self.multiple_track_feats:
392
+ fmap1 = targets_split[i]
393
+ corrs = torch.matmul(fmap1, fmap2s)
394
+ corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W
395
+ corrs = corrs / torch.sqrt(torch.tensor(C).float())
396
+ self.corrs_pyramid.append(corrs)
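
A minimal usage sketch for the CorrBlock defined above (the import path mirrors this commit's file layout and all tensor sizes are illustrative assumptions): corr() builds one correlation volume per pyramid level between each track's feature vector and every feature-map location, and sample() then reads out a (2*radius+1) x (2*radius+1) window around each track at every level.

import torch
from sailrecon.dependency.track_modules.blocks import CorrBlock

B, S, N, C, H, W = 1, 4, 8, 128, 32, 32       # batch, frames, tracks, channels, feature-map size
fmaps = torch.randn(B, S, C, H, W)            # per-frame feature maps
targets = torch.randn(B, S, N, C)             # per-track feature vectors
coords = torch.rand(B, S, N, 2) * (W - 1)     # track locations in feature-map pixels

corr_block = CorrBlock(fmaps, num_levels=4, radius=4)
corr_block.corr(targets)                      # fills corr_block.corrs_pyramid
corr_feat = corr_block.sample(coords)         # local correlation windows at every level
print(corr_feat.shape)                        # torch.Size([1, 4, 8, 324]) = 4 levels * 9 * 9
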
sailrecon/dependency/track_modules/modules.py ADDED
@@ -0,0 +1,216 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import collections
9
+ from functools import partial
10
+ from itertools import repeat
11
+ from typing import Callable
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from torch import Tensor
17
+
18
+
19
+ # From PyTorch internals
20
+ def _ntuple(n):
21
+ def parse(x):
22
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
23
+ return tuple(x)
24
+ return tuple(repeat(x, n))
25
+
26
+ return parse
27
+
28
+
29
+ def exists(val):
30
+ return val is not None
31
+
32
+
33
+ def default(val, d):
34
+ return val if exists(val) else d
35
+
36
+
37
+ to_2tuple = _ntuple(2)
38
+
39
+
40
+ class ResidualBlock(nn.Module):
41
+ """
42
+ ResidualBlock: construct a block of two conv layers with residual connections
43
+ """
44
+
45
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
46
+ super(ResidualBlock, self).__init__()
47
+
48
+ self.conv1 = nn.Conv2d(
49
+ in_planes,
50
+ planes,
51
+ kernel_size=kernel_size,
52
+ padding=1,
53
+ stride=stride,
54
+ padding_mode="zeros",
55
+ )
56
+ self.conv2 = nn.Conv2d(
57
+ planes, planes, kernel_size=kernel_size, padding=1, padding_mode="zeros"
58
+ )
59
+ self.relu = nn.ReLU(inplace=True)
60
+
61
+ num_groups = planes // 8
62
+
63
+ if norm_fn == "group":
64
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
65
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
66
+ if not stride == 1:
67
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
68
+
69
+ elif norm_fn == "batch":
70
+ self.norm1 = nn.BatchNorm2d(planes)
71
+ self.norm2 = nn.BatchNorm2d(planes)
72
+ if not stride == 1:
73
+ self.norm3 = nn.BatchNorm2d(planes)
74
+
75
+ elif norm_fn == "instance":
76
+ self.norm1 = nn.InstanceNorm2d(planes)
77
+ self.norm2 = nn.InstanceNorm2d(planes)
78
+ if not stride == 1:
79
+ self.norm3 = nn.InstanceNorm2d(planes)
80
+
81
+ elif norm_fn == "none":
82
+ self.norm1 = nn.Sequential()
83
+ self.norm2 = nn.Sequential()
84
+ if not stride == 1:
85
+ self.norm3 = nn.Sequential()
86
+ else:
87
+ raise NotImplementedError
88
+
89
+ if stride == 1:
90
+ self.downsample = None
91
+ else:
92
+ self.downsample = nn.Sequential(
93
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
94
+ )
95
+
96
+ def forward(self, x):
97
+ y = x
98
+ y = self.relu(self.norm1(self.conv1(y)))
99
+ y = self.relu(self.norm2(self.conv2(y)))
100
+
101
+ if self.downsample is not None:
102
+ x = self.downsample(x)
103
+
104
+ return self.relu(x + y)
105
+
106
+
107
+ class Mlp(nn.Module):
108
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
109
+
110
+ def __init__(
111
+ self,
112
+ in_features,
113
+ hidden_features=None,
114
+ out_features=None,
115
+ act_layer=nn.GELU,
116
+ norm_layer=None,
117
+ bias=True,
118
+ drop=0.0,
119
+ use_conv=False,
120
+ ):
121
+ super().__init__()
122
+ out_features = out_features or in_features
123
+ hidden_features = hidden_features or in_features
124
+ bias = to_2tuple(bias)
125
+ drop_probs = to_2tuple(drop)
126
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
127
+
128
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
129
+ self.act = act_layer()
130
+ self.drop1 = nn.Dropout(drop_probs[0])
131
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
132
+ self.drop2 = nn.Dropout(drop_probs[1])
133
+
134
+ def forward(self, x):
135
+ x = self.fc1(x)
136
+ x = self.act(x)
137
+ x = self.drop1(x)
138
+ x = self.fc2(x)
139
+ x = self.drop2(x)
140
+ return x
141
+
142
+
143
+ class AttnBlock(nn.Module):
144
+ def __init__(
145
+ self,
146
+ hidden_size,
147
+ num_heads,
148
+ attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
149
+ mlp_ratio=4.0,
150
+ **block_kwargs,
151
+ ):
152
+ """
153
+ Self attention block
154
+ """
155
+ super().__init__()
156
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
157
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
158
+
159
+ self.attn = attn_class(
160
+ embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
161
+ )
162
+
163
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
164
+
165
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
166
+
167
+ def forward(self, x, mask=None):
168
+ # Prepare the mask for PyTorch's attention (it expects a different format)
169
+ # attn_mask = mask if mask is not None else None
170
+ # Normalize before attention
171
+ x = self.norm1(x)
172
+
173
+ # PyTorch's MultiheadAttention returns attn_output, attn_output_weights
174
+ # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)
175
+
176
+ attn_output, _ = self.attn(x, x, x)
177
+
178
+ # Add & Norm
179
+ x = x + attn_output
180
+ x = x + self.mlp(self.norm2(x))
181
+ return x
182
+
183
+
184
+ class CrossAttnBlock(nn.Module):
185
+ def __init__(
186
+ self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs
187
+ ):
188
+ """
189
+ Cross attention block
190
+ """
191
+ super().__init__()
192
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
193
+ self.norm_context = nn.LayerNorm(hidden_size)
194
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
195
+
196
+ self.cross_attn = nn.MultiheadAttention(
197
+ embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
198
+ )
199
+
200
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
201
+
202
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
203
+
204
+ def forward(self, x, context, mask=None):
205
+ # Normalize inputs
206
+ x = self.norm1(x)
207
+ context = self.norm_context(context)
208
+
209
+ # Apply cross attention
210
+ # Note: nn.MultiheadAttention returns attn_output, attn_output_weights
211
+ attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)
212
+
213
+ # Add & Norm
214
+ x = x + attn_output
215
+ x = x + self.mlp(self.norm2(x))
216
+ return x
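
A short usage sketch for the two attention blocks above (import path and sizes are illustrative assumptions): AttnBlock runs pre-norm self-attention plus an MLP over a (B, N, C) token sequence, and CrossAttnBlock lets those tokens attend to a separate context sequence with the same hidden size.

import torch
from sailrecon.dependency.track_modules.modules import AttnBlock, CrossAttnBlock

B, N, M, C = 2, 16, 64, 384                   # batch, query tokens, context tokens, hidden size
self_block = AttnBlock(hidden_size=C, num_heads=8, mlp_ratio=4.0)
cross_block = CrossAttnBlock(hidden_size=C, context_dim=C, num_heads=8, mlp_ratio=4.0)

tokens = torch.randn(B, N, C)
context = torch.randn(B, M, C)

tokens = self_block(tokens)                   # self-attention + MLP, shape stays (B, N, C)
tokens = cross_block(tokens, context)         # cross-attention to the context, shape stays (B, N, C)
print(tokens.shape)                           # torch.Size([2, 16, 384])
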
sailrecon/dependency/track_modules/track_refine.py ADDED
@@ -0,0 +1,493 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import os
9
+ from functools import partial
10
+ from typing import Tuple, Union
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from einops import rearrange, repeat
17
+ from einops.layers.torch import Rearrange, Reduce
18
+ from PIL import Image
19
+ from torch import einsum, nn
20
+
21
+
22
+ def refine_track(
23
+ images,
24
+ fine_fnet,
25
+ fine_tracker,
26
+ coarse_pred,
27
+ compute_score=False,
28
+ pradius=15,
29
+ sradius=2,
30
+ fine_iters=6,
31
+ chunk=40960,
32
+ ):
33
+ """
34
+ Refines the tracking of images using a fine track predictor and a fine feature network.
35
+ Check https://arxiv.org/abs/2312.04563 for more details.
36
+
37
+ Args:
38
+ images (torch.Tensor): The images to be tracked.
39
+ fine_fnet (nn.Module): The fine feature network.
40
+ fine_tracker (nn.Module): The fine track predictor.
41
+ coarse_pred (torch.Tensor): The coarse predictions of tracks.
42
+ compute_score (bool, optional): Whether to compute the score. Defaults to False.
43
+ pradius (int, optional): The radius of a patch. Defaults to 15.
44
+ sradius (int, optional): The search radius. Defaults to 2.
45
+
46
+ Returns:
47
+ torch.Tensor: The refined tracks.
48
+ torch.Tensor, optional: The score.
49
+ """
50
+
51
+ # coarse_pred shape: BxSxNx2,
52
+ # where B is the batch, S is the video/images length, and N is the number of tracks
53
+ # now we are going to extract patches with the center at coarse_pred
54
+ # Please note that the last dimension indicates x and y, and hence has a dim number of 2
55
+ B, S, N, _ = coarse_pred.shape
56
+ _, _, _, H, W = images.shape
57
+
58
+ # Given the radius of a patch, compute the patch size
59
+ psize = pradius * 2 + 1
60
+
61
+ # Note that we assume the first frame is the query frame
62
+ # so the 2D locations of the first frame are the query points
63
+ query_points = coarse_pred[:, 0]
64
+
65
+ # Given 2D positions, we can use grid_sample to extract patches
66
+ # but it takes too much memory.
67
+ # Instead, we use the floored track xy to sample patches.
68
+
69
+ # For example, if the query point xy is (128.16, 252.78),
70
+ # and the patch size is (31, 31),
71
+ # our goal is to extract the content of a rectangle
72
+ # with left top: (113.16, 237.78)
73
+ # and right bottom: (143.16, 267.78).
74
+ # However, we record the floored left top: (113, 237)
75
+ # and the offset (0.16, 0.78)
76
+ # Then what we need is just unfolding the images like in CNN,
77
+ # picking the content at [(113, 237), (143, 267)].
78
+ # Such operations are highly optimized in PyTorch
79
+ # (well if you really want to use interpolation, check the function extract_glimpse() below)
80
+
81
+ with torch.no_grad():
82
+ content_to_extract = images.reshape(B * S, 3, H, W)
83
+ C_in = content_to_extract.shape[1]
84
+
85
+ # Please refer to https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html
86
+ # for the detailed explanation of unfold()
87
+ # Here it runs sliding windows (psize x psize) to build patches
88
+ # The shape changes from
89
+ # (B*S)x C_in x H x W to (B*S)x C_in x H_new x W_new x Psize x Psize
90
+ # where Psize is the size of patch
91
+ content_to_extract = content_to_extract.unfold(2, psize, 1).unfold(3, psize, 1)
92
+
93
+ # Floor the coarse predictions to get integers and save the fractional/decimal
94
+ track_int = coarse_pred.floor().int()
95
+ track_frac = coarse_pred - track_int
96
+
97
+ # Note the points represent the center of patches
98
+ # now we get the location of the top left corner of patches
99
+ # because the output of pytorch unfold is indexed by the top left corner
100
+ topleft = track_int - pradius
101
+ topleft_BSN = topleft.clone()
102
+
103
+ # clamp the values so that we will not go out of indexes
104
+ # NOTE: (VERY IMPORTANT: This operation ASSUMES H=W).
105
+ # You need to separately clamp x and y if H!=W
106
+ topleft = topleft.clamp(0, H - psize)
107
+
108
+ # Reshape from BxSxNx2 -> (B*S)xNx2
109
+ topleft = topleft.reshape(B * S, N, 2)
110
+
111
+ # Prepare batches for indexing, shape: (B*S)xN
112
+ batch_indices = (
113
+ torch.arange(B * S)[:, None].expand(-1, N).to(content_to_extract.device)
114
+ )
115
+
116
+ # extracted_patches: (B*S) x N x C_in x Psize x Psize
117
+ extracted_patches = content_to_extract[
118
+ batch_indices, :, topleft[..., 1], topleft[..., 0]
119
+ ]
120
+
121
+ if chunk < 0:
122
+ # Extract image patches based on top left corners
123
+ # Feed patches to the fine fnet for features
124
+ patch_feat = fine_fnet(extracted_patches.reshape(B * S * N, C_in, psize, psize))
125
+ else:
126
+ patches = extracted_patches.reshape(B * S * N, C_in, psize, psize)
127
+
128
+ patch_feat_list = []
129
+ for p in torch.split(patches, chunk):
130
+ patch_feat_list += [fine_fnet(p)]
131
+ patch_feat = torch.cat(patch_feat_list, 0)
132
+
133
+ C_out = patch_feat.shape[1]
134
+
135
+ # Refine the coarse tracks by fine_tracker
136
+ # reshape back to B x S x N x C_out x Psize x Psize
137
+ patch_feat = patch_feat.reshape(B, S, N, C_out, psize, psize)
138
+ patch_feat = rearrange(patch_feat, "b s n c p q -> (b n) s c p q")
139
+
140
+ # Prepare for the query points for fine tracker
141
+ # They are relative to the patch left top corner,
142
+ # instead of the image top left corner now
143
+ # patch_query_points: N x 1 x 2
144
+ # only 1 here because for each patch we only have 1 query point
145
+ patch_query_points = track_frac[:, 0] + pradius
146
+ patch_query_points = patch_query_points.reshape(B * N, 2).unsqueeze(1)
147
+
148
+ # Feed the PATCH query points and tracks into fine tracker
149
+ fine_pred_track_lists, _, _, query_point_feat = fine_tracker(
150
+ query_points=patch_query_points,
151
+ fmaps=patch_feat,
152
+ iters=fine_iters,
153
+ return_feat=True,
154
+ )
155
+
156
+ # relative to the patch top left
157
+ fine_pred_track = fine_pred_track_lists[-1].clone()
158
+
159
+ # From (relative to the patch top left) to (relative to the image top left)
160
+ for idx in range(len(fine_pred_track_lists)):
161
+ fine_level = rearrange(
162
+ fine_pred_track_lists[idx], "(b n) s u v -> b s n u v", b=B, n=N
163
+ )
164
+ fine_level = fine_level.squeeze(-2)
165
+ fine_level = fine_level + topleft_BSN
166
+ fine_pred_track_lists[idx] = fine_level
167
+
168
+ # relative to the image top left
169
+ refined_tracks = fine_pred_track_lists[-1].clone()
170
+ refined_tracks[:, 0] = query_points
171
+
172
+ score = None
173
+
174
+ if compute_score:
175
+ score = compute_score_fn(
176
+ query_point_feat,
177
+ patch_feat,
178
+ fine_pred_track,
179
+ sradius,
180
+ psize,
181
+ B,
182
+ N,
183
+ S,
184
+ C_out,
185
+ )
186
+
187
+ return refined_tracks, score
188
+
189
+
190
+ def refine_track_v0(
191
+ images,
192
+ fine_fnet,
193
+ fine_tracker,
194
+ coarse_pred,
195
+ compute_score=False,
196
+ pradius=15,
197
+ sradius=2,
198
+ fine_iters=6,
199
+ ):
200
+ """
201
+ COPIED FROM VGGSfM
202
+
203
+ Refines the tracking of images using a fine track predictor and a fine feature network.
204
+ Check https://arxiv.org/abs/2312.04563 for more details.
205
+
206
+ Args:
207
+ images (torch.Tensor): The images to be tracked.
208
+ fine_fnet (nn.Module): The fine feature network.
209
+ fine_tracker (nn.Module): The fine track predictor.
210
+ coarse_pred (torch.Tensor): The coarse predictions of tracks.
211
+ compute_score (bool, optional): Whether to compute the score. Defaults to False.
212
+ pradius (int, optional): The radius of a patch. Defaults to 15.
213
+ sradius (int, optional): The search radius. Defaults to 2.
214
+
215
+ Returns:
216
+ torch.Tensor: The refined tracks.
217
+ torch.Tensor, optional: The score.
218
+ """
219
+
220
+ # coarse_pred shape: BxSxNx2,
221
+ # where B is the batch, S is the video/images length, and N is the number of tracks
222
+ # now we are going to extract patches with the center at coarse_pred
223
+ # Please note that the last dimension indicates x and y, and hence has a dim number of 2
224
+ B, S, N, _ = coarse_pred.shape
225
+ _, _, _, H, W = images.shape
226
+
227
+ # Given the radius of a patch, compute the patch size
228
+ psize = pradius * 2 + 1
229
+
230
+ # Note that we assume the first frame is the query frame
231
+ # so the 2D locations of the first frame are the query points
232
+ query_points = coarse_pred[:, 0]
233
+
234
+ # Given 2D positions, we can use grid_sample to extract patches
235
+ # but it takes too much memory.
236
+ # Instead, we use the floored track xy to sample patches.
237
+
238
+ # For example, if the query point xy is (128.16, 252.78),
239
+ # and the patch size is (31, 31),
240
+ # our goal is to extract the content of a rectangle
241
+ # with left top: (113.16, 237.78)
242
+ # and right bottom: (143.16, 267.78).
243
+ # However, we record the floored left top: (113, 237)
244
+ # and the offset (0.16, 0.78)
245
+ # Then what we need is just unfolding the images like in CNN,
246
+ # picking the content at [(113, 237), (143, 267)].
247
+ # Such operations are highly optimized in PyTorch
248
+ # (well if you really want to use interpolation, check the function extract_glimpse() below)
249
+
250
+ with torch.no_grad():
251
+ content_to_extract = images.reshape(B * S, 3, H, W)
252
+ C_in = content_to_extract.shape[1]
253
+
254
+ # Please refer to https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html
255
+ # for the detailed explanation of unfold()
256
+ # Here it runs sliding windows (psize x psize) to build patches
257
+ # The shape changes from
258
+ # (B*S)x C_in x H x W to (B*S)x C_in x H_new x W_new x Psize x Psize
259
+ # where Psize is the size of patch
260
+ content_to_extract = content_to_extract.unfold(2, psize, 1).unfold(3, psize, 1)
261
+
262
+ # Floor the coarse predictions to get integers and save the fractional/decimal
263
+ track_int = coarse_pred.floor().int()
264
+ track_frac = coarse_pred - track_int
265
+
266
+ # Note the points represent the center of patches
267
+ # now we get the location of the top left corner of patches
268
+ # because the output of pytorch unfold is indexed by the top left corner
269
+ topleft = track_int - pradius
270
+ topleft_BSN = topleft.clone()
271
+
272
+ # clamp the values so that we will not go out of indexes
273
+ # NOTE: (VERY IMPORTANT: This operation ASSUMES H=W).
274
+ # You need to separately clamp x and y if H!=W
275
+ topleft = topleft.clamp(0, H - psize)
276
+
277
+ # Reshape from BxSxNx2 -> (B*S)xNx2
278
+ topleft = topleft.reshape(B * S, N, 2)
279
+
280
+ # Prepare batches for indexing, shape: (B*S)xN
281
+ batch_indices = (
282
+ torch.arange(B * S)[:, None].expand(-1, N).to(content_to_extract.device)
283
+ )
284
+
285
+ # Extract image patches based on top left corners
286
+ # extracted_patches: (B*S) x N x C_in x Psize x Psize
287
+ extracted_patches = content_to_extract[
288
+ batch_indices, :, topleft[..., 1], topleft[..., 0]
289
+ ]
290
+
291
+ # Feed patches to the fine fnet for features
292
+ patch_feat = fine_fnet(extracted_patches.reshape(B * S * N, C_in, psize, psize))
293
+
294
+ C_out = patch_feat.shape[1]
295
+
296
+ # Refine the coarse tracks by fine_tracker
297
+
298
+ # reshape back to B x S x N x C_out x Psize x Psize
299
+ patch_feat = patch_feat.reshape(B, S, N, C_out, psize, psize)
300
+ patch_feat = rearrange(patch_feat, "b s n c p q -> (b n) s c p q")
301
+
302
+ # Prepare for the query points for fine tracker
303
+ # They are relative to the patch left top corner,
304
+ # instead of the image top left corner now
305
+ # patch_query_points: N x 1 x 2
306
+ # only 1 here because for each patch we only have 1 query point
307
+ patch_query_points = track_frac[:, 0] + pradius
308
+ patch_query_points = patch_query_points.reshape(B * N, 2).unsqueeze(1)
309
+
310
+ # Feed the PATCH query points and tracks into fine tracker
311
+ fine_pred_track_lists, _, _, query_point_feat = fine_tracker(
312
+ query_points=patch_query_points,
313
+ fmaps=patch_feat,
314
+ iters=fine_iters,
315
+ return_feat=True,
316
+ )
317
+
318
+ # relative to the patch top left
319
+ fine_pred_track = fine_pred_track_lists[-1].clone()
320
+
321
+ # From (relative to the patch top left) to (relative to the image top left)
322
+ for idx in range(len(fine_pred_track_lists)):
323
+ fine_level = rearrange(
324
+ fine_pred_track_lists[idx], "(b n) s u v -> b s n u v", b=B, n=N
325
+ )
326
+ fine_level = fine_level.squeeze(-2)
327
+ fine_level = fine_level + topleft_BSN
328
+ fine_pred_track_lists[idx] = fine_level
329
+
330
+ # relative to the image top left
331
+ refined_tracks = fine_pred_track_lists[-1].clone()
332
+ refined_tracks[:, 0] = query_points
333
+
334
+ score = None
335
+
336
+ if compute_score:
337
+ score = compute_score_fn(
338
+ query_point_feat,
339
+ patch_feat,
340
+ fine_pred_track,
341
+ sradius,
342
+ psize,
343
+ B,
344
+ N,
345
+ S,
346
+ C_out,
347
+ )
348
+
349
+ return refined_tracks, score
350
+
351
+
352
+ ################################## NOTE: NOT USED ##################################
353
+
354
+
355
+ def compute_score_fn(
356
+ query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out
357
+ ):
358
+ """
359
+ Compute the scores, i.e., the standard deviation of the 2D similarity heatmaps,
360
+ given the query point features and reference frame feature maps
361
+ """
362
+
363
+ from kornia.geometry.subpix import dsnt
364
+ from kornia.utils.grid import create_meshgrid
365
+
366
+ # query_point_feat initial shape: B x N x C_out,
367
+ # query_point_feat indicates the feat at the corresponding query points
368
+ # Therefore we don't have S dimension here
369
+ query_point_feat = query_point_feat.reshape(B, N, C_out)
370
+ # reshape and expand to B x (S-1) x N x C_out
371
+ query_point_feat = query_point_feat.unsqueeze(1).expand(-1, S - 1, -1, -1)
372
+ # and reshape to (B*(S-1)*N) x C_out
373
+ query_point_feat = query_point_feat.reshape(B * (S - 1) * N, C_out)
374
+
375
+ # Radius and size for computing the score
376
+ ssize = sradius * 2 + 1
377
+
378
+ # Reshape, you know it, so many reshaping operations
379
+ patch_feat = rearrange(patch_feat, "(b n) s c p q -> b s n c p q", b=B, n=N)
380
+
381
+ # Again, we unfold the patches to smaller patches
382
+ # so that we can then focus on smaller patches
383
+ # patch_feat_unfold shape:
384
+ # B x S x N x C_out x (psize - 2*sradius) x (psize - 2*sradius) x ssize x ssize
385
+ # well a bit scary, but actually not
386
+ patch_feat_unfold = patch_feat.unfold(4, ssize, 1).unfold(5, ssize, 1)
387
+
388
+ # Do the same stuffs above, i.e., the same as extracting patches
389
+ fine_prediction_floor = fine_pred_track.floor().int()
390
+ fine_level_floor_topleft = fine_prediction_floor - sradius
391
+
392
+ # Clamp to ensure the smaller patch is valid
393
+ fine_level_floor_topleft = fine_level_floor_topleft.clamp(0, psize - ssize)
394
+ fine_level_floor_topleft = fine_level_floor_topleft.squeeze(2)
395
+
396
+ # Prepare the batch indices and xy locations
397
+
398
+ batch_indices_score = torch.arange(B)[:, None, None].expand(-1, S, N) # BxSxN
399
+ batch_indices_score = batch_indices_score.reshape(-1).to(
400
+ patch_feat_unfold.device
401
+ ) # B*S*N
402
+ y_indices = fine_level_floor_topleft[..., 0].flatten() # Flatten H indices
403
+ x_indices = fine_level_floor_topleft[..., 1].flatten() # Flatten W indices
404
+
405
+ reference_frame_feat = patch_feat_unfold.reshape(
406
+ B * S * N, C_out, psize - sradius * 2, psize - sradius * 2, ssize, ssize
407
+ )
408
+
409
+ # Note again, according to pytorch convention
410
+ # x_indices corresponds to [..., 1] and y_indices corresponds to [..., 0]
411
+ reference_frame_feat = reference_frame_feat[
412
+ batch_indices_score, :, x_indices, y_indices
413
+ ]
414
+ reference_frame_feat = reference_frame_feat.reshape(B, S, N, C_out, ssize, ssize)
415
+ # pick the frames other than the first one, so we have S-1 frames here
416
+ reference_frame_feat = reference_frame_feat[:, 1:].reshape(
417
+ B * (S - 1) * N, C_out, ssize * ssize
418
+ )
419
+
420
+ # Compute similarity
421
+ sim_matrix = torch.einsum("mc,mcr->mr", query_point_feat, reference_frame_feat)
422
+ softmax_temp = 1.0 / C_out**0.5
423
+ heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1)
424
+ # 2D heatmaps
425
+ heatmap = heatmap.reshape(B * (S - 1) * N, ssize, ssize) # * x ssize x ssize
426
+
427
+ coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0]
428
+ grid_normalized = create_meshgrid(
429
+ ssize, ssize, normalized_coordinates=True, device=heatmap.device
430
+ ).reshape(1, -1, 2)
431
+
432
+ var = (
433
+ torch.sum(grid_normalized**2 * heatmap.view(-1, ssize * ssize, 1), dim=1)
434
+ - coords_normalized**2
435
+ )
436
+ std = torch.sum(
437
+ torch.sqrt(torch.clamp(var, min=1e-10)), -1
438
+ ) # clamp needed for numerical stability
439
+
440
+ score = std.reshape(B, S - 1, N)
441
+ # set score as 1 for the query frame
442
+ score = torch.cat([torch.ones_like(score[:, 0:1]), score], dim=1)
443
+
444
+ return score
445
+
446
+
447
+ def extract_glimpse(
448
+ tensor: torch.Tensor,
449
+ size: Tuple[int, int],
450
+ offsets,
451
+ mode="bilinear",
452
+ padding_mode="zeros",
453
+ debug=False,
454
+ orib=None,
455
+ ):
456
+ B, C, W, H = tensor.shape
457
+
458
+ h, w = size
459
+ xs = torch.arange(0, w, dtype=tensor.dtype, device=tensor.device) - (w - 1) / 2.0
460
+ ys = torch.arange(0, h, dtype=tensor.dtype, device=tensor.device) - (h - 1) / 2.0
461
+
462
+ vy, vx = torch.meshgrid(ys, xs)
463
+ grid = torch.stack([vx, vy], dim=-1) # h, w, 2
464
+ grid = grid[None]
465
+
466
+ B, N, _ = offsets.shape
467
+
468
+ offsets = offsets.reshape((B * N), 1, 1, 2)
469
+ offsets_grid = offsets + grid
470
+
471
+ # normalised grid to [-1, 1]
472
+ offsets_grid = (
473
+ offsets_grid - offsets_grid.new_tensor([W / 2, H / 2])
474
+ ) / offsets_grid.new_tensor([W / 2, H / 2])
475
+
476
+ # BxCxHxW -> Bx1xCxHxW
477
+ tensor = tensor[:, None]
478
+
479
+ # Bx1xCxHxW -> BxNxCxHxW
480
+ tensor = tensor.expand(-1, N, -1, -1, -1)
481
+
482
+ # BxNxCxHxW -> (B*N)xCxHxW
483
+ tensor = tensor.reshape((B * N), C, W, H)
484
+
485
+ sampled = torch.nn.functional.grid_sample(
486
+ tensor, offsets_grid, mode=mode, align_corners=False, padding_mode=padding_mode
487
+ )
488
+
489
+ # NOTE: I am not sure it should be h, w or w, h here
490
+ # but okay for squares
491
+ sampled = sampled.reshape(B, N, C, h, w)
492
+
493
+ return sampled
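
The long comment inside refine_track describes how patches are gathered without grid_sample: floor each track coordinate, keep the fractional remainder for the fine tracker, and pick one unfolded window by its integer top-left corner. Below is a tiny standalone sketch of that trick using only plain torch ops; the image size, patch radius, and query point are illustrative assumptions.

import torch

H = W = 64
pradius = 15
psize = 2 * pradius + 1                              # 31

image = torch.randn(1, 3, H, W)
point = torch.tensor([[28.16, 40.78]])               # (x, y), sub-pixel

point_int = point.floor().int()                      # integer part, used for indexing
point_frac = point - point_int                       # fractional part, kept for the tracker

# top-left corner of the patch; clamping this way assumes H == W, as the NOTE above warns
topleft = (point_int - pradius).clamp(0, H - psize)

# unfold produces every sliding psize x psize window:
# (1, 3, H, W) -> (1, 3, H - psize + 1, W - psize + 1, psize, psize)
windows = image.unfold(2, psize, 1).unfold(3, psize, 1)

# select the window whose top-left corner matches (y indexes dim 2, x indexes dim 3)
patch = windows[0, :, topleft[0, 1], topleft[0, 0]]  # (3, psize, psize)

# the query point expressed relative to the patch's top-left corner
patch_query = point_frac + pradius
print(patch.shape, patch_query)
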
sailrecon/dependency/track_modules/utils.py ADDED
@@ -0,0 +1,235 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from https://github.com/facebookresearch/PoseDiffusion
8
+ # and https://github.com/facebookresearch/co-tracker/tree/main
9
+
10
+
11
+ from typing import Optional, Tuple, Union
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from einops import rearrange, repeat
17
+
18
+
19
+ def get_2d_sincos_pos_embed(
20
+ embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False
21
+ ) -> torch.Tensor:
22
+ """
23
+ This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
24
+ It is a wrapper of get_2d_sincos_pos_embed_from_grid.
25
+ Args:
26
+ - embed_dim: The embedding dimension.
27
+ - grid_size: The grid size.
28
+ Returns:
29
+ - pos_embed: The generated 2D positional embedding.
30
+ """
31
+ if isinstance(grid_size, tuple):
32
+ grid_size_h, grid_size_w = grid_size
33
+ else:
34
+ grid_size_h = grid_size_w = grid_size
35
+ grid_h = torch.arange(grid_size_h, dtype=torch.float)
36
+ grid_w = torch.arange(grid_size_w, dtype=torch.float)
37
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
38
+ grid = torch.stack(grid, dim=0)
39
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
40
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
41
+ if return_grid:
42
+ return (
43
+ pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2),
44
+ grid,
45
+ )
46
+ return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
47
+
48
+
49
+ def get_2d_sincos_pos_embed_from_grid(
50
+ embed_dim: int, grid: torch.Tensor
51
+ ) -> torch.Tensor:
52
+ """
53
+ This function generates a 2D positional embedding from a given grid using sine and cosine functions.
54
+
55
+ Args:
56
+ - embed_dim: The embedding dimension.
57
+ - grid: The grid to generate the embedding from.
58
+
59
+ Returns:
60
+ - emb: The generated 2D positional embedding.
61
+ """
62
+ assert embed_dim % 2 == 0
63
+
64
+ # use half of dimensions to encode grid_h
65
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
66
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
67
+
68
+ emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
69
+ return emb
70
+
71
+
72
+ def get_1d_sincos_pos_embed_from_grid(
73
+ embed_dim: int, pos: torch.Tensor
74
+ ) -> torch.Tensor:
75
+ """
76
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
77
+
78
+ Args:
79
+ - embed_dim: The embedding dimension.
80
+ - pos: The position to generate the embedding from.
81
+
82
+ Returns:
83
+ - emb: The generated 1D positional embedding.
84
+ """
85
+ assert embed_dim % 2 == 0
86
+ omega = torch.arange(embed_dim // 2, dtype=torch.double)
87
+ omega /= embed_dim / 2.0
88
+ omega = 1.0 / 10000**omega # (D/2,)
89
+
90
+ pos = pos.reshape(-1) # (M,)
91
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
92
+
93
+ emb_sin = torch.sin(out) # (M, D/2)
94
+ emb_cos = torch.cos(out) # (M, D/2)
95
+
96
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
97
+ return emb[None].float()
98
+
99
+
100
+ def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
101
+ """
102
+ This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
103
+
104
+ Args:
105
+ - xy: The coordinates to generate the embedding from.
106
+ - C: The size of the embedding.
107
+ - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
108
+
109
+ Returns:
110
+ - pe: The generated 2D positional embedding.
111
+ """
112
+ B, N, D = xy.shape
113
+ assert D == 2
114
+
115
+ x = xy[:, :, 0:1]
116
+ y = xy[:, :, 1:2]
117
+ div_term = (
118
+ torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
119
+ ).reshape(1, 1, int(C / 2))
120
+
121
+ pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
122
+ pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
123
+
124
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
125
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
126
+
127
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
128
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
129
+
130
+ pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*2)
131
+ if cat_coords:
132
+ pe = torch.cat([xy, pe], dim=2) # (B, N, C*2+2)
133
+ return pe
134
+
135
+
136
+ def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
137
+ r"""Sample a tensor using bilinear interpolation
138
+
139
+ `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
140
+ coordinates :attr:`coords` using bilinear interpolation. It is the same
141
+ as `torch.nn.functional.grid_sample()` but with a different coordinate
142
+ convention.
143
+
144
+ The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
145
+ :math:`B` is the batch size, :math:`C` is the number of channels,
146
+ :math:`H` is the height of the image, and :math:`W` is the width of the
147
+ image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
148
+ interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
149
+
150
+ Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
151
+ in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
152
+ that in this case the order of the components is slightly different
153
+ from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
154
+
155
+ If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
156
+ in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
157
+ left-most image pixel, and :math:`W-1` to the center of the right-most
158
+ pixel.
159
+
160
+ If `align_corners` is `False`, the coordinate :math:`x` is assumed to
161
+ be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
162
+ the left-most pixel, and :math:`W` to the right edge of the right-most
163
+ pixel.
164
+
165
+ Similar conventions apply to the :math:`y` for the range
166
+ :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
167
+ :math:`[0,T-1]` and :math:`[0,T]`.
168
+
169
+ Args:
170
+ input (Tensor): batch of input images.
171
+ coords (Tensor): batch of coordinates.
172
+ align_corners (bool, optional): Coordinate convention. Defaults to `True`.
173
+ padding_mode (str, optional): Padding mode. Defaults to `"border"`.
174
+
175
+ Returns:
176
+ Tensor: sampled points.
177
+ """
178
+
179
+ sizes = input.shape[2:]
180
+
181
+ assert len(sizes) in [2, 3]
182
+
183
+ if len(sizes) == 3:
184
+ # t x y -> x y t to match dimensions T H W in grid_sample
185
+ coords = coords[..., [1, 2, 0]]
186
+
187
+ if align_corners:
188
+ coords = coords * torch.tensor(
189
+ [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
190
+ )
191
+ else:
192
+ coords = coords * torch.tensor(
193
+ [2 / size for size in reversed(sizes)], device=coords.device
194
+ )
195
+
196
+ coords -= 1
197
+
198
+ return F.grid_sample(
199
+ input, coords, align_corners=align_corners, padding_mode=padding_mode
200
+ )
201
+
202
+
203
+ def sample_features4d(input, coords):
204
+ r"""Sample spatial features
205
+
206
+ `sample_features4d(input, coords)` samples the spatial features
207
+ :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
208
+
209
+ The field is sampled at coordinates :attr:`coords` using bilinear
210
+ interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
211
+ 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
212
+ same convention as :func:`bilinear_sampler` with `align_corners=True`.
213
+
214
+ The output tensor has one feature per point, and has shape :math:`(B,
215
+ R, C)`.
216
+
217
+ Args:
218
+ input (Tensor): spatial features.
219
+ coords (Tensor): points.
220
+
221
+ Returns:
222
+ Tensor: sampled features.
223
+ """
224
+
225
+ B, _, _, _ = input.shape
226
+
227
+ # B R 2 -> B R 1 2
228
+ coords = coords.unsqueeze(2)
229
+
230
+ # B C R 1
231
+ feats = bilinear_sampler(input, coords)
232
+
233
+ return feats.permute(0, 2, 1, 3).view(
234
+ B, -1, feats.shape[1] * feats.shape[3]
235
+ ) # B C R 1 -> B R C
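
A shape-level sketch of the two helpers the trackers rely on most (import path and sizes are illustrative assumptions): get_2d_embedding turns (x, y) coordinates into sinusoidal features, and sample_features4d reads per-point features out of a (B, C, H, W) map with bilinear interpolation.

import torch
from sailrecon.dependency.track_modules.utils import get_2d_embedding, sample_features4d

B, N, C, H, W = 2, 10, 64, 48, 48
xy = torch.rand(B, N, 2) * (W - 1)               # 2D points in pixel coordinates

pe = get_2d_embedding(xy, C=C, cat_coords=True)  # sin/cos features per point, plus the raw coords
print(pe.shape)                                  # torch.Size([2, 10, 130]) = 2*C + 2

fmap = torch.randn(B, 32, H, W)
feats = sample_features4d(fmap, xy)              # bilinear read-out at each point
print(feats.shape)                               # torch.Size([2, 10, 32])
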
sailrecon/dependency/track_predict.py ADDED
@@ -0,0 +1,349 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ from .vggsfm_utils import *
11
+
12
+
13
+ def predict_tracks(
14
+ images,
15
+ conf=None,
16
+ points_3d=None,
17
+ masks=None,
18
+ max_query_pts=2048,
19
+ query_frame_num=5,
20
+ keypoint_extractor="aliked+sp",
21
+ max_points_num=163840,
22
+ fine_tracking=True,
23
+ complete_non_vis=True,
24
+ ):
25
+ """
26
+ Predict tracks for the given images and masks.
27
+
28
+ TODO: support non-square images
29
+ TODO: support masks
30
+
31
+
32
+ This function predicts the tracks for the given images and masks using the specified query method
33
+ and track predictor. It finds query points, and predicts the tracks, visibility, and scores for the query frames.
34
+
35
+ Args:
36
+ images: Tensor of shape [S, 3, H, W] containing the input images.
37
+ conf: Tensor of shape [S, 1, H, W] containing the confidence scores. Default is None.
38
+ points_3d: Tensor containing 3D points. Default is None.
39
+ masks: Optional tensor of shape [S, 1, H, W] containing masks. Default is None.
40
+ max_query_pts: Maximum number of query points. Default is 2048.
41
+ query_frame_num: Number of query frames to use. Default is 5.
42
+ keypoint_extractor: Method for keypoint extraction. Default is "aliked+sp".
43
+ max_points_num: Maximum number of points to process at once. Default is 163840.
44
+ fine_tracking: Whether to use fine tracking. Default is True.
45
+ complete_non_vis: Whether to augment non-visible frames. Default is True.
46
+
47
+ Returns:
48
+ pred_tracks: Numpy array containing the predicted tracks.
49
+ pred_vis_scores: Numpy array containing the visibility scores for the tracks.
50
+ pred_confs: Numpy array containing the confidence scores for the tracks.
51
+ pred_points_3d: Numpy array containing the 3D points for the tracks.
52
+ pred_colors: Numpy array containing the point colors for the tracks. (0, 255)
53
+ """
54
+
55
+ device = images.device
56
+ dtype = images.dtype
57
+ tracker = build_vggsfm_tracker().to(device, dtype)
58
+
59
+ # Find query frames
60
+ query_frame_indexes = generate_rank_by_dino(
61
+ images, query_frame_num=query_frame_num, device=device
62
+ )
63
+
64
+ # Add the first image to the front if not already present
65
+ if 0 in query_frame_indexes:
66
+ query_frame_indexes.remove(0)
67
+ query_frame_indexes = [0, *query_frame_indexes]
68
+
69
+ # TODO: add the functionality to handle the masks
70
+ keypoint_extractors = initialize_feature_extractors(
71
+ max_query_pts, extractor_method=keypoint_extractor, device=device
72
+ )
73
+
74
+ pred_tracks = []
75
+ pred_vis_scores = []
76
+ pred_confs = []
77
+ pred_points_3d = []
78
+ pred_colors = []
79
+
80
+ fmaps_for_tracker = tracker.process_images_to_fmaps(images)
81
+
82
+ if fine_tracking:
83
+ print("For faster inference, consider disabling fine_tracking")
84
+
85
+ for query_index in query_frame_indexes:
86
+ print(f"Predicting tracks for query frame {query_index}")
87
+ pred_track, pred_vis, pred_conf, pred_point_3d, pred_color = _forward_on_query(
88
+ query_index,
89
+ images,
90
+ conf,
91
+ points_3d,
92
+ fmaps_for_tracker,
93
+ keypoint_extractors,
94
+ tracker,
95
+ max_points_num,
96
+ fine_tracking,
97
+ device,
98
+ )
99
+
100
+ pred_tracks.append(pred_track)
101
+ pred_vis_scores.append(pred_vis)
102
+ pred_confs.append(pred_conf)
103
+ pred_points_3d.append(pred_point_3d)
104
+ pred_colors.append(pred_color)
105
+
106
+ if complete_non_vis:
107
+ (
108
+ pred_tracks,
109
+ pred_vis_scores,
110
+ pred_confs,
111
+ pred_points_3d,
112
+ pred_colors,
113
+ ) = _augment_non_visible_frames(
114
+ pred_tracks,
115
+ pred_vis_scores,
116
+ pred_confs,
117
+ pred_points_3d,
118
+ pred_colors,
119
+ images,
120
+ conf,
121
+ points_3d,
122
+ fmaps_for_tracker,
123
+ keypoint_extractors,
124
+ tracker,
125
+ max_points_num,
126
+ fine_tracking,
127
+ min_vis=500,
128
+ non_vis_thresh=0.1,
129
+ device=device,
130
+ )
131
+
132
+ pred_tracks = np.concatenate(pred_tracks, axis=1)
133
+ pred_vis_scores = np.concatenate(pred_vis_scores, axis=1)
134
+ pred_confs = np.concatenate(pred_confs, axis=0) if pred_confs else None
135
+ pred_points_3d = np.concatenate(pred_points_3d, axis=0) if pred_points_3d else None
136
+ pred_colors = np.concatenate(pred_colors, axis=0) if pred_colors else None
137
+
138
+ # from vggt.utils.visual_track import visualize_tracks_on_images
139
+ # visualize_tracks_on_images(images[None], torch.from_numpy(pred_tracks[None]), torch.from_numpy(pred_vis_scores[None])>0.2, out_dir="track_visuals")
140
+
141
+ return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors
142
+
143
+
144
+ def _forward_on_query(
145
+ query_index,
146
+ images,
147
+ conf,
148
+ points_3d,
149
+ fmaps_for_tracker,
150
+ keypoint_extractors,
151
+ tracker,
152
+ max_points_num,
153
+ fine_tracking,
154
+ device,
155
+ ):
156
+ """
157
+ Process a single query frame for track prediction.
158
+
159
+ Args:
160
+ query_index: Index of the query frame
161
+ images: Tensor of shape [S, 3, H, W] containing the input images
162
+ conf: Confidence tensor
163
+ points_3d: 3D points tensor
164
+ fmaps_for_tracker: Feature maps for the tracker
165
+ keypoint_extractors: Initialized feature extractors
166
+ tracker: VGG-SFM tracker
167
+ max_points_num: Maximum number of points to process at once
168
+ fine_tracking: Whether to use fine tracking
169
+ device: Device to use for computation
170
+
171
+ Returns:
172
+ pred_track: Predicted tracks
173
+ pred_vis: Visibility scores for the tracks
174
+ pred_conf: Confidence scores for the tracks
175
+ pred_point_3d: 3D points for the tracks
176
+ pred_color: Point colors for the tracks (0, 255)
177
+ """
178
+ frame_num, _, height, width = images.shape
179
+
180
+ query_image = images[query_index]
181
+ query_points = extract_keypoints(
182
+ query_image, keypoint_extractors, round_keypoints=False
183
+ )
184
+ query_points = query_points[:, torch.randperm(query_points.shape[1], device=device)]
185
+
186
+ # Extract the color at the keypoint locations
187
+ query_points_long = query_points.squeeze(0).round().long()
188
+ pred_color = images[query_index][
189
+ :, query_points_long[:, 1], query_points_long[:, 0]
190
+ ]
191
+ pred_color = (pred_color.permute(1, 0).cpu().numpy() * 255).astype(np.uint8)
192
+
193
+ # Query the confidence and points_3d at the keypoint locations
194
+ if (conf is not None) and (points_3d is not None):
195
+ assert height == width
196
+ assert conf.shape[-2] == conf.shape[-1]
197
+ assert conf.shape[:3] == points_3d.shape[:3]
198
+ scale = conf.shape[-1] / width
199
+
200
+ query_points_scaled = (query_points.squeeze(0) * scale).round().long()
201
+ query_points_scaled = query_points_scaled.cpu().numpy()
202
+
203
+ pred_conf = conf[query_index][
204
+ query_points_scaled[:, 1], query_points_scaled[:, 0]
205
+ ]
206
+ pred_point_3d = points_3d[query_index][
207
+ query_points_scaled[:, 1], query_points_scaled[:, 0]
208
+ ]
209
+
210
+ # heuristic to remove low confidence points
211
+ # should I export this as an input parameter?
212
+ valid_mask = pred_conf > 1.2
213
+ if valid_mask.sum() > 512:
214
+ query_points = query_points[:, valid_mask] # Make sure shape is compatible
215
+ pred_conf = pred_conf[valid_mask]
216
+ pred_point_3d = pred_point_3d[valid_mask]
217
+ pred_color = pred_color[valid_mask]
218
+ else:
219
+ pred_conf = None
220
+ pred_point_3d = None
221
+
222
+ reorder_index = calculate_index_mappings(query_index, frame_num, device=device)
223
+
224
+ images_feed, fmaps_feed = switch_tensor_order(
225
+ [images, fmaps_for_tracker], reorder_index, dim=0
226
+ )
227
+ images_feed = images_feed[None] # add batch dimension
228
+ fmaps_feed = fmaps_feed[None] # add batch dimension
229
+
230
+ all_points_num = images_feed.shape[1] * query_points.shape[1]
231
+
232
+ # Don't need to be scared, this is just chunking to make GPU happy
233
+ if all_points_num > max_points_num:
234
+ num_splits = (all_points_num + max_points_num - 1) // max_points_num
235
+ query_points = torch.chunk(query_points, num_splits, dim=1)
236
+ else:
237
+ query_points = [query_points]
238
+
239
+ pred_track, pred_vis, _ = predict_tracks_in_chunks(
240
+ tracker, images_feed, query_points, fmaps_feed, fine_tracking=fine_tracking
241
+ )
242
+
243
+ pred_track, pred_vis = switch_tensor_order(
244
+ [pred_track, pred_vis], reorder_index, dim=1
245
+ )
246
+
247
+ pred_track = pred_track.squeeze(0).float().cpu().numpy()
248
+ pred_vis = pred_vis.squeeze(0).float().cpu().numpy()
249
+
250
+ return pred_track, pred_vis, pred_conf, pred_point_3d, pred_color
251
+
252
+
253
+ def _augment_non_visible_frames(
254
+ pred_tracks: list, # ← running list of np.ndarrays
255
+ pred_vis_scores: list, # ← running list of np.ndarrays
256
+ pred_confs: list, # ← running list of np.ndarrays for confidence scores
257
+ pred_points_3d: list, # ← running list of np.ndarrays for 3D points
258
+ pred_colors: list, # ← running list of np.ndarrays for colors
259
+ images: torch.Tensor,
260
+ conf,
261
+ points_3d,
262
+ fmaps_for_tracker,
263
+ keypoint_extractors,
264
+ tracker,
265
+ max_points_num: int,
266
+ fine_tracking: bool,
267
+ *,
268
+ min_vis: int = 500,
269
+ non_vis_thresh: float = 0.1,
270
+ device: torch.device = None,
271
+ ):
272
+ """
273
+ Augment tracking for frames with insufficient visibility.
274
+
275
+ Args:
276
+ pred_tracks: List of numpy arrays containing predicted tracks.
277
+ pred_vis_scores: List of numpy arrays containing visibility scores.
278
+ pred_confs: List of numpy arrays containing confidence scores.
279
+ pred_points_3d: List of numpy arrays containing 3D points.
280
+ pred_colors: List of numpy arrays containing point colors.
281
+ images: Tensor of shape [S, 3, H, W] containing the input images.
282
+ conf: Tensor of shape [S, 1, H, W] containing confidence scores
283
+ points_3d: Tensor containing 3D points
284
+ fmaps_for_tracker: Feature maps for the tracker
285
+ keypoint_extractors: Initialized feature extractors
286
+ tracker: VGG-SFM tracker
287
+ max_points_num: Maximum number of points to process at once
288
+ fine_tracking: Whether to use fine tracking
289
+ min_vis: Minimum visibility threshold
290
+ non_vis_thresh: Non-visibility threshold
291
+ device: Device to use for computation
292
+
293
+ Returns:
294
+ Updated pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, and pred_colors lists.
295
+ """
296
+ last_query = -1
297
+ final_trial = False
298
+ cur_extractors = keypoint_extractors # may be replaced on the final trial
299
+
300
+ while True:
301
+ # Visibility per frame
302
+ vis_array = np.concatenate(pred_vis_scores, axis=1)
303
+
304
+ # Count frames with sufficient visibility using numpy
305
+ sufficient_vis_count = (vis_array > non_vis_thresh).sum(axis=-1)
306
+ non_vis_frames = np.where(sufficient_vis_count < min_vis)[0].tolist()
307
+
308
+ if len(non_vis_frames) == 0:
309
+ break
310
+
311
+ print("Processing non visible frames:", non_vis_frames)
312
+
313
+ # Decide the frames & extractor for this round
314
+ if non_vis_frames[0] == last_query:
315
+ # Same frame failed twice - final "all-in" attempt
316
+ final_trial = True
317
+ cur_extractors = initialize_feature_extractors(
318
+ 2048, extractor_method="sp+sift+aliked", device=device
319
+ )
320
+ query_frame_list = non_vis_frames # blast them all at once
321
+ else:
322
+ query_frame_list = [non_vis_frames[0]] # Process one at a time
323
+
324
+ last_query = non_vis_frames[0]
325
+
326
+ # Run the tracker for every selected frame
327
+ for query_index in query_frame_list:
328
+ new_track, new_vis, new_conf, new_point_3d, new_color = _forward_on_query(
329
+ query_index,
330
+ images,
331
+ conf,
332
+ points_3d,
333
+ fmaps_for_tracker,
334
+ cur_extractors,
335
+ tracker,
336
+ max_points_num,
337
+ fine_tracking,
338
+ device,
339
+ )
340
+ pred_tracks.append(new_track)
341
+ pred_vis_scores.append(new_vis)
342
+ pred_confs.append(new_conf)
343
+ pred_points_3d.append(new_point_3d)
344
+ pred_colors.append(new_color)
345
+
346
+ if final_trial:
347
+ break # Stop after final attempt
348
+
349
+ return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors
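
A minimal, self-contained sketch of the per-frame visibility test that drives the loop above, using synthetic arrays (the real lists hold one `(S, Ni)` score array per tracking round; the sizes below are placeholders):

```python
import numpy as np

# Two tracking rounds produced 4 and 3 tracks for a 3-frame sequence.
pred_vis_scores = [np.random.rand(3, 4), np.random.rand(3, 3)]

min_vis = 5           # require at least 5 visible tracks per frame
non_vis_thresh = 0.1  # a track counts as visible above this score

vis_array = np.concatenate(pred_vis_scores, axis=1)                # (S, N_total)
sufficient_vis_count = (vis_array > non_vis_thresh).sum(axis=-1)   # visible tracks per frame
non_vis_frames = np.where(sufficient_vis_count < min_vis)[0].tolist()
print(non_vis_frames)  # frames that would trigger another tracking round
```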
sailrecon/dependency/vggsfm_tracker.py ADDED
@@ -0,0 +1,148 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ from functools import partial
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from einops import rearrange, repeat
15
+ from einops.layers.torch import Rearrange, Reduce
16
+ from hydra.utils import instantiate
17
+ from omegaconf import OmegaConf
18
+ from torch import einsum, nn
19
+
20
+ from .track_modules.base_track_predictor import BaseTrackerPredictor
21
+ from .track_modules.blocks import BasicEncoder, ShallowEncoder
22
+ from .track_modules.track_refine import refine_track
23
+
24
+
25
+ class TrackerPredictor(nn.Module):
26
+ def __init__(self, **extra_args):
27
+ super(TrackerPredictor, self).__init__()
28
+ """
29
+ Initializes the tracker predictor.
30
+
31
+ Both coarse_predictor and fine_predictor are constructed as a BaseTrackerPredictor,
32
+ check track_modules/base_track_predictor.py
33
+
34
+ Both coarse_fnet and fine_fnet are constructed as a 2D CNN network
35
+ check track_modules/blocks.py for BasicEncoder and ShallowEncoder
36
+ """
37
+ # Define coarse predictor configuration
38
+ coarse_stride = 4
39
+ self.coarse_down_ratio = 2
40
+
41
+ # Create networks directly instead of using instantiate
42
+ self.coarse_fnet = BasicEncoder(stride=coarse_stride)
43
+ self.coarse_predictor = BaseTrackerPredictor(stride=coarse_stride)
44
+
45
+ # Create fine predictor with stride = 1
46
+ self.fine_fnet = ShallowEncoder(stride=1)
47
+ self.fine_predictor = BaseTrackerPredictor(
48
+ stride=1,
49
+ depth=4,
50
+ corr_levels=3,
51
+ corr_radius=3,
52
+ latent_dim=32,
53
+ hidden_size=256,
54
+ fine=True,
55
+ use_spaceatt=False,
56
+ )
57
+
58
+ def forward(
59
+ self,
60
+ images,
61
+ query_points,
62
+ fmaps=None,
63
+ coarse_iters=6,
64
+ inference=True,
65
+ fine_tracking=True,
66
+ fine_chunk=40960,
67
+ ):
68
+ """
69
+ Args:
70
+ images (torch.Tensor): Images as RGB, in the range of [0, 1], with a shape of B x S x 3 x H x W.
71
+ query_points (torch.Tensor): 2D xy of query points, relative to top left, with a shape of B x N x 2.
72
+ fmaps (torch.Tensor, optional): Precomputed feature maps. Defaults to None.
73
+ coarse_iters (int, optional): Number of iterations for coarse prediction. Defaults to 6.
74
+ inference (bool, optional): Whether to perform inference. Defaults to True.
75
+ fine_tracking (bool, optional): Whether to perform fine tracking. Defaults to True.
76
+
77
+ Returns:
78
+ tuple: A tuple containing fine_pred_track, coarse_pred_track, pred_vis, and pred_score.
79
+ """
80
+
81
+ if fmaps is None:
82
+ batch_num, frame_num, image_dim, height, width = images.shape
83
+ reshaped_image = images.reshape(
84
+ batch_num * frame_num, image_dim, height, width
85
+ )
86
+ fmaps = self.process_images_to_fmaps(reshaped_image)
87
+ fmaps = fmaps.reshape(
88
+ batch_num, frame_num, -1, fmaps.shape[-2], fmaps.shape[-1]
89
+ )
90
+
91
+ if inference:
92
+ torch.cuda.empty_cache()
93
+
94
+ # Coarse prediction
95
+ coarse_pred_track_lists, pred_vis = self.coarse_predictor(
96
+ query_points=query_points,
97
+ fmaps=fmaps,
98
+ iters=coarse_iters,
99
+ down_ratio=self.coarse_down_ratio,
100
+ )
101
+ coarse_pred_track = coarse_pred_track_lists[-1]
102
+
103
+ if inference:
104
+ torch.cuda.empty_cache()
105
+
106
+ if fine_tracking:
107
+ # Refine the coarse prediction
108
+ fine_pred_track, pred_score = refine_track(
109
+ images,
110
+ self.fine_fnet,
111
+ self.fine_predictor,
112
+ coarse_pred_track,
113
+ compute_score=False,
114
+ chunk=fine_chunk,
115
+ )
116
+
117
+ if inference:
118
+ torch.cuda.empty_cache()
119
+ else:
120
+ fine_pred_track = coarse_pred_track
121
+ pred_score = torch.ones_like(pred_vis)
122
+
123
+ return fine_pred_track, coarse_pred_track, pred_vis, pred_score
124
+
125
+ def process_images_to_fmaps(self, images):
126
+ """
127
+ This function processes images for inference.
128
+
129
+ Args:
130
+ images (torch.Tensor): The images to be processed with shape S x 3 x H x W.
131
+
132
+ Returns:
133
+ torch.Tensor: The processed feature maps.
134
+ """
135
+ if self.coarse_down_ratio > 1:
136
+ # whether or not scale down the input images to save memory
137
+ fmaps = self.coarse_fnet(
138
+ F.interpolate(
139
+ images,
140
+ scale_factor=1 / self.coarse_down_ratio,
141
+ mode="bilinear",
142
+ align_corners=True,
143
+ )
144
+ )
145
+ else:
146
+ fmaps = self.coarse_fnet(images)
147
+
148
+ return fmaps
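
For reference, a hedged sketch of the calling convention documented above. The shapes follow the docstring; the device, the 518-pixel resolution, and the use of `build_vggsfm_tracker` (defined in `vggsfm_utils.py` further down in this commit) for weight loading are assumptions, not a tested snippet from the repo:

```python
import torch
from sailrecon.dependency.vggsfm_utils import build_vggsfm_tracker

tracker = build_vggsfm_tracker().cuda()             # downloads the VGGSfM tracker weights
images = torch.rand(1, 4, 3, 518, 518).cuda()       # B x S x 3 x H x W, RGB in [0, 1]
query_points = torch.rand(1, 256, 2).cuda() * 518   # B x N x 2, pixel xy in frame 0

with torch.no_grad():
    fine_track, coarse_track, vis, score = tracker(
        images, query_points, fine_tracking=True
    )
# fine_track / coarse_track: per-frame xy for every query point; vis: visibility scores.
```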
sailrecon/dependency/vggsfm_utils.py ADDED
@@ -0,0 +1,341 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import warnings
9
+ from typing import Dict, List, Optional, Tuple, Union
10
+
11
+ import numpy as np
12
+ import pycolmap
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from lightglue import ALIKED, SIFT, SuperPoint
16
+
17
+ from .vggsfm_tracker import TrackerPredictor
18
+
19
+ # Suppress verbose logging from dependencies
20
+ logging.getLogger("dinov2").setLevel(logging.WARNING)
21
+ warnings.filterwarnings("ignore", message="xFormers is available")
22
+ warnings.filterwarnings("ignore", message="dinov2")
23
+
24
+ # Constants
25
+ _RESNET_MEAN = [0.485, 0.456, 0.406]
26
+ _RESNET_STD = [0.229, 0.224, 0.225]
27
+
28
+
29
+ def build_vggsfm_tracker(model_path=None):
30
+ """
31
+ Build and initialize the VGGSfM tracker.
32
+
33
+ Args:
34
+ model_path: Path to the model weights file. If None, weights are downloaded from HuggingFace.
35
+
36
+ Returns:
37
+ Initialized tracker model in eval mode.
38
+ """
39
+ tracker = TrackerPredictor()
40
+
41
+ if model_path is None:
42
+ default_url = (
43
+ "https://huggingface.co/facebook/VGGSfM/resolve/main/vggsfm_v2_tracker.pt"
44
+ )
45
+ tracker.load_state_dict(torch.hub.load_state_dict_from_url(default_url))
46
+ else:
47
+ tracker.load_state_dict(torch.load(model_path))
48
+
49
+ tracker.eval()
50
+ return tracker
51
+
52
+
53
+ def generate_rank_by_dino(
54
+ images,
55
+ query_frame_num,
56
+ image_size=336,
57
+ model_name="dinov2_vitb14_reg",
58
+ device="cuda",
59
+ spatial_similarity=False,
60
+ ):
61
+ """
62
+ Generate a ranking of frames using DINO ViT features.
63
+
64
+ Args:
65
+ images: Tensor of shape (S, 3, H, W) with values in range [0, 1]
66
+ query_frame_num: Number of frames to select
67
+ image_size: Size to resize images to before processing
68
+ model_name: Name of the DINO model to use
69
+ device: Device to run the model on
70
+ spatial_similarity: Whether to use spatial token similarity or CLS token similarity
71
+
72
+ Returns:
73
+ List of frame indices ranked by their representativeness
74
+ """
75
+ # Resize images to the target size
76
+ images = F.interpolate(
77
+ images, (image_size, image_size), mode="bilinear", align_corners=False
78
+ )
79
+
80
+ # Load DINO model
81
+ dino_v2_model = torch.hub.load("facebookresearch/dinov2", model_name)
82
+ dino_v2_model.eval()
83
+ dino_v2_model = dino_v2_model.to(device)
84
+
85
+ # Normalize images using ResNet normalization
86
+ resnet_mean = torch.tensor(_RESNET_MEAN, device=device).view(1, 3, 1, 1)
87
+ resnet_std = torch.tensor(_RESNET_STD, device=device).view(1, 3, 1, 1)
88
+ images_resnet_norm = (images - resnet_mean) / resnet_std
89
+
90
+ with torch.no_grad():
91
+ frame_feat = dino_v2_model(images_resnet_norm, is_training=True)
92
+
93
+ # Process features based on similarity type
94
+ if spatial_similarity:
95
+ frame_feat = frame_feat["x_norm_patchtokens"]
96
+ frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)
97
+
98
+ # Compute the similarity matrix
99
+ frame_feat_norm = frame_feat_norm.permute(1, 0, 2)
100
+ similarity_matrix = torch.bmm(
101
+ frame_feat_norm, frame_feat_norm.transpose(-1, -2)
102
+ )
103
+ similarity_matrix = similarity_matrix.mean(dim=0)
104
+ else:
105
+ frame_feat = frame_feat["x_norm_clstoken"]
106
+ frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)
107
+ similarity_matrix = torch.mm(frame_feat_norm, frame_feat_norm.transpose(-1, -2))
108
+
109
+ distance_matrix = 100 - similarity_matrix.clone()
110
+
111
+ # Ignore self-pairing
112
+ similarity_matrix.fill_diagonal_(-100)
113
+ similarity_sum = similarity_matrix.sum(dim=1)
114
+
115
+ # Find the most common frame
116
+ most_common_frame_index = torch.argmax(similarity_sum).item()
117
+
118
+ # Conduct FPS sampling starting from the most common frame
119
+ fps_idx = farthest_point_sampling(
120
+ distance_matrix, query_frame_num, most_common_frame_index
121
+ )
122
+
123
+ # Clean up all tensors and models to free memory
124
+ del frame_feat, frame_feat_norm, similarity_matrix, distance_matrix
125
+ del dino_v2_model
126
+ torch.cuda.empty_cache()
127
+
128
+ return fps_idx
129
+
130
+
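A short usage sketch for the keyframe ranking above (it downloads DINOv2 via torch.hub and assumes a CUDA device; the 8-frame clip is a placeholder):

```python
import torch
from sailrecon.dependency.vggsfm_utils import generate_rank_by_dino

frames = torch.rand(8, 3, 336, 336).cuda()   # (S, 3, H, W), RGB in [0, 1]
query_idx = generate_rank_by_dino(frames, query_frame_num=3, device="cuda")
print(query_idx)  # most "representative" frame first, then FPS-diverse picks
```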
131
+ def farthest_point_sampling(distance_matrix, num_samples, most_common_frame_index=0):
132
+ """
133
+ Farthest point sampling algorithm to select diverse frames.
134
+
135
+ Args:
136
+ distance_matrix: Matrix of distances between frames
137
+ num_samples: Number of frames to select
138
+ most_common_frame_index: Index of the first frame to select
139
+
140
+ Returns:
141
+ List of selected frame indices
142
+ """
143
+ distance_matrix = distance_matrix.clamp(min=0)
144
+ N = distance_matrix.size(0)
145
+
146
+ # Initialize with the most common frame
147
+ selected_indices = [most_common_frame_index]
148
+ check_distances = distance_matrix[selected_indices]
149
+
150
+ while len(selected_indices) < num_samples:
151
+ # Find the farthest point from the current set of selected points
152
+ farthest_point = torch.argmax(check_distances)
153
+ selected_indices.append(farthest_point.item())
154
+
155
+ check_distances = distance_matrix[farthest_point]
156
+ # Mark already selected points to avoid selecting them again
157
+ check_distances[selected_indices] = 0
158
+
159
+ # Break if all points have been selected
160
+ if len(selected_indices) == N:
161
+ break
162
+
163
+ return selected_indices
164
+
165
+
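A tiny worked example of the sampler above with a hand-made 4x4 distance matrix (import path assumed from this commit). Note the greedy step measures distance only to the most recently selected frame, not to the whole selected set:

```python
import torch
from sailrecon.dependency.vggsfm_utils import farthest_point_sampling

# Pairwise "distances" between 4 frames (larger = more dissimilar).
D = torch.tensor([[0., 1., 9., 2.],
                  [1., 0., 3., 4.],
                  [9., 3., 0., 5.],
                  [2., 4., 5., 0.]])

picked = farthest_point_sampling(D, num_samples=3, most_common_frame_index=0)
print(picked)  # [0, 2, 3]: frame 2 is farthest from 0, then frame 3 is farthest from 2
```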
166
+ def calculate_index_mappings(query_index, S, device=None):
167
+ """
168
+ Construct an order that switches [query_index] and [0]
169
+ so that the content of query_index would be placed at [0].
170
+
171
+ Args:
172
+ query_index: Index to swap with 0
173
+ S: Total number of elements
174
+ device: Device to place the tensor on
175
+
176
+ Returns:
177
+ Tensor of indices with the swapped order
178
+ """
179
+ new_order = torch.arange(S)
180
+ new_order[0] = query_index
181
+ new_order[query_index] = 0
182
+ if device is not None:
183
+ new_order = new_order.to(device)
184
+ return new_order
185
+
186
+
187
+ def switch_tensor_order(tensors, order, dim=1):
188
+ """
189
+ Reorder tensors along a specific dimension according to the given order.
190
+
191
+ Args:
192
+ tensors: List of tensors to reorder
193
+ order: Tensor of indices specifying the new order
194
+ dim: Dimension along which to reorder
195
+
196
+ Returns:
197
+ List of reordered tensors
198
+ """
199
+ return [
200
+ torch.index_select(tensor, dim, order) if tensor is not None else None
201
+ for tensor in tensors
202
+ ]
203
+
204
+
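A small runnable check of the two helpers above: the index mapping is a single transposition, so applying `switch_tensor_order` with the same order twice restores the original layout:

```python
import torch
from sailrecon.dependency.vggsfm_utils import calculate_index_mappings, switch_tensor_order

S = 5
order = calculate_index_mappings(query_index=3, S=S)   # tensor([3, 1, 2, 0, 4])
frames = torch.arange(S).view(1, S, 1).float()         # toy B x S x 1 "frames"

(swapped,) = switch_tensor_order([frames], order, dim=1)
print(swapped.squeeze())                               # tensor([3., 1., 2., 0., 4.])

(restored,) = switch_tensor_order([swapped], order, dim=1)
print(torch.equal(restored, frames))                   # True: the swap is its own inverse
```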
205
+ def initialize_feature_extractors(
206
+ max_query_num, det_thres=0.005, extractor_method="aliked", device="cuda"
207
+ ):
208
+ """
209
+ Initialize feature extractors that can be reused based on a method string.
210
+
211
+ Args:
212
+ max_query_num: Maximum number of keypoints to extract
213
+ det_thres: Detection threshold for keypoint extraction
214
+ extractor_method: String specifying which extractors to use (e.g., "aliked", "sp+sift", "aliked+sp+sift")
215
+ device: Device to run extraction on
216
+
217
+ Returns:
218
+ Dictionary of initialized extractors
219
+ """
220
+ extractors = {}
221
+ methods = extractor_method.lower().split("+")
222
+
223
+ for method in methods:
224
+ method = method.strip()
225
+ if method == "aliked":
226
+ aliked_extractor = ALIKED(
227
+ max_num_keypoints=max_query_num, detection_threshold=det_thres
228
+ )
229
+ extractors["aliked"] = aliked_extractor.to(device).eval()
230
+ elif method == "sp":
231
+ sp_extractor = SuperPoint(
232
+ max_num_keypoints=max_query_num, detection_threshold=det_thres
233
+ )
234
+ extractors["sp"] = sp_extractor.to(device).eval()
235
+ elif method == "sift":
236
+ sift_extractor = SIFT(max_num_keypoints=max_query_num)
237
+ extractors["sift"] = sift_extractor.to(device).eval()
238
+ else:
239
+ print(f"Warning: Unknown feature extractor '{method}', ignoring.")
240
+
241
+ if not extractors:
242
+ print(
243
+ f"Warning: No valid extractors found in '{extractor_method}'. Using ALIKED by default."
244
+ )
245
+ aliked_extractor = ALIKED(
246
+ max_num_keypoints=max_query_num, detection_threshold=det_thres
247
+ )
248
+ extractors["aliked"] = aliked_extractor.to(device).eval()
249
+
250
+ return extractors
251
+
252
+
253
+ def extract_keypoints(query_image, extractors, round_keypoints=True):
254
+ """
255
+ Extract keypoints using pre-initialized feature extractors.
256
+
257
+ Args:
258
+ query_image: Input image tensor (3xHxW, range [0, 1])
259
+ extractors: Dictionary of initialized extractors
260
+
261
+ Returns:
262
+ Tensor of keypoint coordinates (1xNx2)
263
+ """
264
+ query_points = None
265
+
266
+ with torch.no_grad():
267
+ for extractor_name, extractor in extractors.items():
268
+ query_points_data = extractor.extract(query_image, invalid_mask=None)
269
+ extractor_points = query_points_data["keypoints"]
270
+ if round_keypoints:
271
+ extractor_points = extractor_points.round()
272
+
273
+ if query_points is not None:
274
+ query_points = torch.cat([query_points, extractor_points], dim=1)
275
+ else:
276
+ query_points = extractor_points
277
+
278
+ return query_points
279
+
280
+
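A hedged end-to-end sketch of the extractor helpers above (requires the lightglue dependency of this repo and a CUDA device; the image is random noise, so the number of detected keypoints may be small):

```python
import torch
from sailrecon.dependency.vggsfm_utils import extract_keypoints, initialize_feature_extractors

extractors = initialize_feature_extractors(
    max_query_num=1024, extractor_method="sp", device="cuda"
)
image = torch.rand(3, 480, 640).cuda()   # 3 x H x W, RGB in [0, 1] (per the docstring)
keypoints = extract_keypoints(image, extractors)
print(keypoints.shape)                   # 1 x N x 2, N capped at max_query_num per extractor
```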
281
+ def predict_tracks_in_chunks(
282
+ track_predictor,
283
+ images_feed,
284
+ query_points_list,
285
+ fmaps_feed,
286
+ fine_tracking,
287
+ num_splits=None,
288
+ fine_chunk=40960,
289
+ ):
290
+ """
291
+ Process a list of query points to avoid memory issues.
292
+
293
+ Args:
294
+ track_predictor (object): The track predictor object used for predicting tracks.
295
+ images_feed (torch.Tensor): A tensor of shape (B, T, C, H, W) representing a batch of images.
296
+ query_points_list (list or tuple): A list/tuple of tensors, each of shape (B, Ni, 2) representing chunks of query points.
297
+ fmaps_feed (torch.Tensor): A tensor of feature maps for the tracker.
298
+ fine_tracking (bool): Whether to perform fine tracking.
299
+ num_splits (int, optional): Ignored when query_points_list is provided. Kept for backward compatibility.
300
+
301
+ Returns:
302
+ tuple: A tuple containing the concatenated predicted tracks, visibility, and scores.
303
+ """
304
+ # If query_points_list is not a list or tuple but a single tensor, handle it like the old version for backward compatibility
305
+ if not isinstance(query_points_list, (list, tuple)):
306
+ query_points = query_points_list
307
+ if num_splits is None:
308
+ num_splits = 1
309
+ query_points_list = torch.chunk(query_points, num_splits, dim=1)
310
+
311
+ # Ensure query_points_list is a list for iteration (as torch.chunk returns a tuple)
312
+ if isinstance(query_points_list, tuple):
313
+ query_points_list = list(query_points_list)
314
+
315
+ fine_pred_track_list = []
316
+ pred_vis_list = []
317
+ pred_score_list = []
318
+
319
+ for split_points in query_points_list:
320
+ # Feed into track predictor for each split
321
+ fine_pred_track, _, pred_vis, pred_score = track_predictor(
322
+ images_feed,
323
+ split_points,
324
+ fmaps=fmaps_feed,
325
+ fine_tracking=fine_tracking,
326
+ fine_chunk=fine_chunk,
327
+ )
328
+ fine_pred_track_list.append(fine_pred_track)
329
+ pred_vis_list.append(pred_vis)
330
+ pred_score_list.append(pred_score)
331
+
332
+ # Concatenate the results from all splits
333
+ fine_pred_track = torch.cat(fine_pred_track_list, dim=2)
334
+ pred_vis = torch.cat(pred_vis_list, dim=2)
335
+
336
+ if pred_score is not None:
337
+ pred_score = torch.cat(pred_score_list, dim=2)
338
+ else:
339
+ pred_score = None
340
+
341
+ return fine_pred_track, pred_vis, pred_score
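
The chunking convention expected by `predict_tracks_in_chunks` is plain `torch.chunk` over the point dimension; a minimal illustration (chunk sizes are arbitrary, and the tracker call itself is omitted):

```python
import torch

query_points = torch.rand(1, 4096, 2)                 # B x N x 2 query points
chunks = torch.chunk(query_points, chunks=4, dim=1)   # four (1, 1024, 2) tensors
print([tuple(c.shape) for c in chunks])

# Each chunk would be passed as query_points_list to predict_tracks_in_chunks(...),
# and the per-chunk tracks are concatenated back along the point dimension (dim=2).
```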
sailrecon/heads/camera_head.py ADDED
@@ -0,0 +1,228 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from sailrecon.heads.head_act import activate_pose
15
+ from sailrecon.layers import Mlp
16
+ from sailrecon.layers.block import Block
17
+
18
+
19
+ class CameraHead(nn.Module):
20
+ """
21
+ CameraHead predicts camera parameters from token representations using iterative refinement.
22
+
23
+ It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ dim_in: int = 2048,
29
+ trunk_depth: int = 4,
30
+ pose_encoding_type: str = "absT_quaR_FoV",
31
+ num_heads: int = 16,
32
+ mlp_ratio: int = 4,
33
+ init_values: float = 0.01,
34
+ trans_act: str = "linear",
35
+ quat_act: str = "linear",
36
+ fl_act: str = "relu", # Field of view activations: ensures FOV values are positive.
37
+ ):
38
+ super().__init__()
39
+
40
+ if pose_encoding_type == "absT_quaR_FoV":
41
+ self.target_dim = 9
42
+ else:
43
+ raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
44
+
45
+ self.trans_act = trans_act
46
+ self.quat_act = quat_act
47
+ self.fl_act = fl_act
48
+ self.trunk_depth = trunk_depth
49
+
50
+ # Build the trunk using a sequence of transformer blocks.
51
+ self.trunk = nn.Sequential(
52
+ *[
53
+ Block(
54
+ dim=dim_in,
55
+ num_heads=num_heads,
56
+ mlp_ratio=mlp_ratio,
57
+ init_values=init_values,
58
+ )
59
+ for _ in range(trunk_depth)
60
+ ]
61
+ )
62
+
63
+ # Normalizations for camera token and trunk output.
64
+ self.token_norm = nn.LayerNorm(dim_in)
65
+ self.trunk_norm = nn.LayerNorm(dim_in)
66
+
67
+ # Learnable empty camera pose token.
68
+ self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
69
+ self.embed_pose = nn.Linear(self.target_dim, dim_in)
70
+
71
+ # Module for producing modulation parameters: shift, scale, and a gate.
72
+ self.poseLN_modulation = nn.Sequential(
73
+ nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True)
74
+ )
75
+
76
+ # Adaptive layer normalization without affine parameters.
77
+ self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
78
+ self.pose_branch = Mlp(
79
+ in_features=dim_in,
80
+ hidden_features=dim_in // 2,
81
+ out_features=self.target_dim,
82
+ drop=0,
83
+ )
84
+
85
+ def forward(
86
+ self,
87
+ aggregated_tokens_list: list,
88
+ cam_token_last_layer: torch.Tensor | None,
89
+ num_iterations: int = 4,
90
+ ) -> list:
91
+ """
92
+ Forward pass to predict camera parameters.
93
+
94
+ Args:
95
+ aggregated_tokens_list (list): List of token tensors from the network;
96
+ the last tensor is used for prediction.
+ cam_token_last_layer (torch.Tensor | None): Camera tokens of the reconstruction (anchor) frames, prepended to the relocalization tokens.
97
+ num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
98
+
99
+ Returns:
100
+ list: A list of predicted camera encodings (post-activation) from each iteration.
101
+ """
102
+ # Use tokens from the last block for camera prediction.
103
+ tokens = aggregated_tokens_list[-1]
104
+
105
+ # Extract the camera tokens
106
+ pose_tokens = tokens[:, :, 0]
107
+ num_recon = cam_token_last_layer.shape[1]
108
+ num_reloc = pose_tokens.shape[1]
109
+ pose_tokens = torch.cat([cam_token_last_layer, pose_tokens], dim=1)
110
+ pose_tokens = self.token_norm(pose_tokens)
111
+
112
+ attention_mask = build_lr_mask(
113
+ S=num_recon + num_reloc,
114
+ no_reloc_list=[i for i in range(num_recon)],
115
+ device=pose_tokens.device,
116
+ )
117
+ pred_pose_enc_list = self.trunk_fn(pose_tokens, attention_mask, num_iterations)
118
+ pred_pose_enc_list_returned = [[] for i in range(num_iterations)]
119
+ for i in range(len(pred_pose_enc_list)):
120
+ pred_pose_enc_list_returned[i] = pred_pose_enc_list[i][:, num_recon:]
121
+ return pred_pose_enc_list_returned
122
+
123
+ def trunk_fn(
124
+ self,
125
+ pose_tokens: torch.Tensor,
126
+ attention_mask: torch.Tensor,
127
+ num_iterations: int,
128
+ ) -> list:
129
+ """
130
+ Iteratively refine camera pose predictions.
131
+
132
+ Args:
133
+ pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C], where S = num_recon + num_reloc.
+ attention_mask (torch.Tensor): Boolean attention mask from build_lr_mask (True means masked).
134
+ num_iterations (int): Number of refinement iterations.
135
+
136
+ Returns:
137
+ list: List of activated camera encodings from each iteration.
138
+ """
139
+ B, S, C = pose_tokens.shape  # S = num_recon + num_reloc
140
+ pred_pose_enc = None
141
+ pred_pose_enc_list = []
142
+
143
+ for _ in range(num_iterations):
144
+ # Use a learned empty pose for the first iteration.
145
+ if pred_pose_enc is None:
146
+ module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
147
+ else:
148
+ # Detach the previous prediction to avoid backprop through time.
149
+ pred_pose_enc = pred_pose_enc.detach()
150
+ module_input = self.embed_pose(pred_pose_enc)
151
+
152
+ # Generate modulation parameters and split them into shift, scale, and gate components.
153
+ shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(
154
+ 3, dim=-1
155
+ )
156
+
157
+ # Adaptive layer normalization and modulation.
158
+ pose_tokens_modulated = gate_msa * modulate(
159
+ self.adaln_norm(pose_tokens), shift_msa, scale_msa
160
+ )
161
+ pose_tokens_modulated = pose_tokens_modulated + pose_tokens
162
+
163
+ for idx_, blk in enumerate(self.trunk):
164
+ pose_tokens_modulated = blk(
165
+ pose_tokens_modulated, None, ~attention_mask
166
+ )
167
+ # Compute the delta update for the pose encoding.
168
+ pred_pose_enc_delta = self.pose_branch(
169
+ self.trunk_norm(pose_tokens_modulated)
170
+ )
171
+
172
+ if pred_pose_enc is None:
173
+ pred_pose_enc = pred_pose_enc_delta
174
+ else:
175
+ pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
176
+
177
+ # Apply final activation functions for translation, quaternion, and field-of-view.
178
+ activated_pose = activate_pose(
179
+ pred_pose_enc,
180
+ trans_act=self.trans_act,
181
+ quat_act=self.quat_act,
182
+ fl_act=self.fl_act,
183
+ )
184
+ pred_pose_enc_list.append(activated_pose)
185
+
186
+ return pred_pose_enc_list
187
+
188
+
189
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
190
+ """
191
+ Modulate the input tensor using scaling and shifting parameters.
192
+ """
193
+ # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
194
+ return x * (1 + scale) + shift
195
+
196
+
197
+ def build_lr_mask(S: int, no_reloc_list, device="cpu"):
198
+ """
199
+ Args:
200
+ S (int) : total number of tokens in the sequence.
201
+ no_reloc_list (Sequence) : indices of reconstruction (non-relocalization) tokens (unique, 0-based).
202
+ device (str) : target device for the mask tensor.
203
+
204
+ Returns:
205
+ attn_mask (torch.BoolTensor) of shape (1, 1, S, S)
206
+ — ready for F.scaled_dot_product_attention (True == masked).
207
+ """
208
+ # ---- 1. Split indices into relocalization (r) and reconstruction (l) tokens
209
+ r_idx = torch.tensor(
210
+ [i for i in range(S) if i not in no_reloc_list], dtype=torch.long, device=device
211
+ )
212
+ l_idx = torch.as_tensor(no_reloc_list, dtype=torch.long, device=device).unique(
213
+ sorted=True
214
+ )
215
+
216
+ mask = torch.zeros(S, S, dtype=torch.bool, device=device)
217
+
218
+ # ---- 2. Reconstruction tokens must not attend to relocalization tokens
219
+ if l_idx.numel() and r_idx.numel():
220
+ mask[l_idx[:, None], r_idx[None, :]] = True
221
+
222
+ # ---- Relocalization tokens attend only to themselves, not to each other
223
+ if r_idx.numel() > 1:
224
+ mask[r_idx[:, None], r_idx[None, :]] = True
225
+ mask[r_idx, r_idx] = False
226
+
227
+ # ---- 3. 补 batch & head 维度
228
+ return mask.unsqueeze(0).unsqueeze(0)
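
A worked example of `build_lr_mask` (assuming the module imports as `sailrecon.heads.camera_head`): with two reconstruction tokens and two relocalization tokens, anchors see only anchors, while each relocalization token sees the anchors and itself:

```python
import torch
from sailrecon.heads.camera_head import build_lr_mask

mask = build_lr_mask(S=4, no_reloc_list=[0, 1])   # tokens 0, 1 are reconstruction anchors
print(mask[0, 0].int())
# [[0, 0, 1, 1],
#  [0, 0, 1, 1],
#  [0, 0, 0, 1],
#  [0, 0, 1, 0]]   (1 == masked; rows are queries, columns are keys)
```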
sailrecon/heads/dpt_head.py ADDED
@@ -0,0 +1,598 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ # Inspired by https://github.com/DepthAnything/Depth-Anything-V2
9
+
10
+
11
+ import os
12
+ from typing import Dict, List, Tuple, Union
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+ from .head_act import activate_head
19
+ from .utils import create_uv_grid, position_grid_to_embed
20
+
21
+
22
+ class DPTHead(nn.Module):
23
+ """
24
+ DPT Head for dense prediction tasks.
25
+
26
+ This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
27
+ (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
28
+ backbone and produces dense predictions by fusing multi-scale features.
29
+
30
+ Args:
31
+ dim_in (int): Input dimension (channels).
32
+ patch_size (int, optional): Patch size. Default is 14.
33
+ output_dim (int, optional): Number of output channels. Default is 4.
34
+ activation (str, optional): Activation type. Default is "inv_log".
35
+ conf_activation (str, optional): Confidence activation type. Default is "expp1".
36
+ features (int, optional): Feature channels for intermediate representations. Default is 256.
37
+ out_channels (List[int], optional): Output channels for each intermediate layer.
38
+ intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
39
+ pos_embed (bool, optional): Whether to use positional embedding. Default is True.
40
+ feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
41
+ down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ dim_in: int,
47
+ patch_size: int = 14,
48
+ output_dim: int = 4,
49
+ activation: str = "inv_log",
50
+ conf_activation: str = "expp1",
51
+ features: int = 256,
52
+ out_channels: List[int] = [256, 512, 1024, 1024],
53
+ intermediate_layer_idx: List[int] = [4, 11, 17, 23],
54
+ pos_embed: bool = True,
55
+ feature_only: bool = False,
56
+ down_ratio: int = 1,
57
+ ) -> None:
58
+ super(DPTHead, self).__init__()
59
+ self.patch_size = patch_size
60
+ self.activation = activation
61
+ self.conf_activation = conf_activation
62
+ self.pos_embed = pos_embed
63
+ self.feature_only = feature_only
64
+ self.down_ratio = down_ratio
65
+ self.intermediate_layer_idx = intermediate_layer_idx
66
+
67
+ self.norm = nn.LayerNorm(dim_in)
68
+
69
+ # Projection layers for each output channel from tokens.
70
+ self.projects = nn.ModuleList(
71
+ [
72
+ nn.Conv2d(
73
+ in_channels=dim_in,
74
+ out_channels=oc,
75
+ kernel_size=1,
76
+ stride=1,
77
+ padding=0,
78
+ )
79
+ for oc in out_channels
80
+ ]
81
+ )
82
+
83
+ # Resize layers for upsampling feature maps.
84
+ self.resize_layers = nn.ModuleList(
85
+ [
86
+ nn.ConvTranspose2d(
87
+ in_channels=out_channels[0],
88
+ out_channels=out_channels[0],
89
+ kernel_size=4,
90
+ stride=4,
91
+ padding=0,
92
+ ),
93
+ nn.ConvTranspose2d(
94
+ in_channels=out_channels[1],
95
+ out_channels=out_channels[1],
96
+ kernel_size=2,
97
+ stride=2,
98
+ padding=0,
99
+ ),
100
+ nn.Identity(),
101
+ nn.Conv2d(
102
+ in_channels=out_channels[3],
103
+ out_channels=out_channels[3],
104
+ kernel_size=3,
105
+ stride=2,
106
+ padding=1,
107
+ ),
108
+ ]
109
+ )
110
+
111
+ self.scratch = _make_scratch(out_channels, features, expand=False)
112
+
113
+ # Attach additional modules to scratch.
114
+ self.scratch.stem_transpose = None
115
+ self.scratch.refinenet1 = _make_fusion_block(features)
116
+ self.scratch.refinenet2 = _make_fusion_block(features)
117
+ self.scratch.refinenet3 = _make_fusion_block(features)
118
+ self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
119
+
120
+ head_features_1 = features
121
+ head_features_2 = 32
122
+
123
+ if feature_only:
124
+ self.scratch.output_conv1 = nn.Conv2d(
125
+ head_features_1, head_features_1, kernel_size=3, stride=1, padding=1
126
+ )
127
+ else:
128
+ self.scratch.output_conv1 = nn.Conv2d(
129
+ head_features_1,
130
+ head_features_1 // 2,
131
+ kernel_size=3,
132
+ stride=1,
133
+ padding=1,
134
+ )
135
+ conv2_in_channels = head_features_1 // 2
136
+
137
+ self.scratch.output_conv2 = nn.Sequential(
138
+ nn.Conv2d(
139
+ conv2_in_channels,
140
+ head_features_2,
141
+ kernel_size=3,
142
+ stride=1,
143
+ padding=1,
144
+ ),
145
+ nn.ReLU(inplace=True),
146
+ nn.Conv2d(
147
+ head_features_2, output_dim, kernel_size=1, stride=1, padding=0
148
+ ),
149
+ )
150
+
151
+ def forward(
152
+ self,
153
+ aggregated_tokens_list: List[torch.Tensor],
154
+ images: torch.Tensor,
155
+ patch_start_idx: int,
156
+ frames_chunk_size: int = 8,
157
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
158
+ """
159
+ Forward pass through the DPT head, supports processing by chunking frames.
160
+ Args:
161
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
162
+ images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
163
+ patch_start_idx (int): Starting index for patch tokens in the token sequence.
164
+ Used to separate patch tokens from other tokens (e.g., camera or register tokens).
165
+ frames_chunk_size (int, optional): Number of frames to process in each chunk.
166
+ If None or larger than S, all frames are processed at once. Default: 8.
167
+
168
+ Returns:
169
+ Tensor or Tuple[Tensor, Tensor]:
170
+ - If feature_only=True: Feature maps with shape [B, S, C, H, W]
171
+ - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
172
+ """
173
+ B, S, _, H, W = images.shape
174
+
175
+ # If frames_chunk_size is not specified or greater than S, process all frames at once
176
+ if frames_chunk_size is None or frames_chunk_size >= S:
177
+ return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
178
+
179
+ # Otherwise, process frames in chunks to manage memory usage
180
+ assert frames_chunk_size > 0
181
+
182
+ # Process frames in batches
183
+ all_preds = []
184
+ all_conf = []
185
+
186
+ for frames_start_idx in range(0, S, frames_chunk_size):
187
+ frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
188
+
189
+ # Process batch of frames
190
+ if self.feature_only:
191
+ chunk_output = self._forward_impl(
192
+ aggregated_tokens_list,
193
+ images,
194
+ patch_start_idx,
195
+ frames_start_idx,
196
+ frames_end_idx,
197
+ )
198
+ all_preds.append(chunk_output)
199
+ else:
200
+ chunk_preds, chunk_conf = self._forward_impl(
201
+ aggregated_tokens_list,
202
+ images,
203
+ patch_start_idx,
204
+ frames_start_idx,
205
+ frames_end_idx,
206
+ )
207
+ all_preds.append(chunk_preds)
208
+ all_conf.append(chunk_conf)
209
+
210
+ # Concatenate results along the sequence dimension
211
+ if self.feature_only:
212
+ return torch.cat(all_preds, dim=1)
213
+ else:
214
+ return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
215
+
216
+ def _forward_impl(
217
+ self,
218
+ aggregated_tokens_list: List[torch.Tensor],
219
+ images: torch.Tensor,
220
+ patch_start_idx: int,
221
+ frames_start_idx: int = None,
222
+ frames_end_idx: int = None,
223
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
224
+ """
225
+ Implementation of the forward pass through the DPT head.
226
+
227
+ This method processes a specific chunk of frames from the sequence.
228
+
229
+ Args:
230
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
231
+ images (Tensor): Input images with shape [B, S, 3, H, W].
232
+ patch_start_idx (int): Starting index for patch tokens.
233
+ frames_start_idx (int, optional): Starting index for frames to process.
234
+ frames_end_idx (int, optional): Ending index for frames to process.
235
+
236
+ Returns:
237
+ Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
238
+ """
239
+ if frames_start_idx is not None and frames_end_idx is not None:
240
+ images = images[:, frames_start_idx:frames_end_idx].contiguous()
241
+
242
+ B, S, _, H, W = images.shape
243
+
244
+ patch_h, patch_w = H // self.patch_size, W // self.patch_size
245
+
246
+ out = []
247
+ dpt_idx = 0
248
+
249
+ for layer_idx in self.intermediate_layer_idx:
250
+ x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
251
+
252
+ # Select frames if processing a chunk
253
+ if frames_start_idx is not None and frames_end_idx is not None:
254
+ x = x[:, frames_start_idx:frames_end_idx]
255
+
256
+ x = x.reshape(B * S, -1, x.shape[-1])
257
+
258
+ x = self.norm(x)
259
+
260
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
261
+
262
+ x = self.projects[dpt_idx](x)
263
+ if self.pos_embed:
264
+ x = self._apply_pos_embed(x, W, H)
265
+ x = self.resize_layers[dpt_idx](x)
266
+
267
+ out.append(x)
268
+ dpt_idx += 1
269
+
270
+ # Fuse features from multiple layers.
271
+ out = self.scratch_forward(out)
272
+ # Interpolate fused output to match target image resolution.
273
+ out = custom_interpolate(
274
+ out,
275
+ (
276
+ int(patch_h * self.patch_size / self.down_ratio),
277
+ int(patch_w * self.patch_size / self.down_ratio),
278
+ ),
279
+ mode="bilinear",
280
+ align_corners=True,
281
+ )
282
+
283
+ if self.pos_embed:
284
+ out = self._apply_pos_embed(out, W, H)
285
+
286
+ if self.feature_only:
287
+ return out.view(B, S, *out.shape[1:])
288
+
289
+ out = self.scratch.output_conv2(out)
290
+ preds, conf = activate_head(
291
+ out, activation=self.activation, conf_activation=self.conf_activation
292
+ )
293
+
294
+ preds = preds.view(B, S, *preds.shape[1:])
295
+ conf = conf.view(B, S, *conf.shape[1:])
296
+ return preds, conf
297
+
298
+ def _apply_pos_embed(
299
+ self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1
300
+ ) -> torch.Tensor:
301
+ """
302
+ Apply positional embedding to tensor x.
303
+ """
304
+ patch_w = x.shape[-1]
305
+ patch_h = x.shape[-2]
306
+ pos_embed = create_uv_grid(
307
+ patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device
308
+ )
309
+ pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
310
+ pos_embed = pos_embed * ratio
311
+ pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
312
+ return x + pos_embed
313
+
314
+ def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
315
+ """
316
+ Forward pass through the fusion blocks.
317
+
318
+ Args:
319
+ features (List[Tensor]): List of feature maps from different layers.
320
+
321
+ Returns:
322
+ Tensor: Fused feature map.
323
+ """
324
+ layer_1, layer_2, layer_3, layer_4 = features
325
+
326
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
327
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
328
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
329
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
330
+
331
+ out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
332
+ del layer_4_rn, layer_4
333
+
334
+ out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
335
+ del layer_3_rn, layer_3
336
+
337
+ out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
338
+ del layer_2_rn, layer_2
339
+
340
+ out = self.scratch.refinenet1(out, layer_1_rn)
341
+ del layer_1_rn, layer_1
342
+
343
+ out = self.scratch.output_conv1(out)
344
+ return out
345
+
346
+
347
+ ################################################################################
348
+ # Modules
349
+ ################################################################################
350
+
351
+
352
+ def _make_fusion_block(
353
+ features: int, size: int = None, has_residual: bool = True, groups: int = 1
354
+ ) -> nn.Module:
355
+ return FeatureFusionBlock(
356
+ features,
357
+ nn.ReLU(inplace=True),
358
+ deconv=False,
359
+ bn=False,
360
+ expand=False,
361
+ align_corners=True,
362
+ size=size,
363
+ has_residual=has_residual,
364
+ groups=groups,
365
+ )
366
+
367
+
368
+ def _make_scratch(
369
+ in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False
370
+ ) -> nn.Module:
371
+ scratch = nn.Module()
372
+ out_shape1 = out_shape
373
+ out_shape2 = out_shape
374
+ out_shape3 = out_shape
375
+ if len(in_shape) >= 4:
376
+ out_shape4 = out_shape
377
+
378
+ if expand:
379
+ out_shape1 = out_shape
380
+ out_shape2 = out_shape * 2
381
+ out_shape3 = out_shape * 4
382
+ if len(in_shape) >= 4:
383
+ out_shape4 = out_shape * 8
384
+
385
+ scratch.layer1_rn = nn.Conv2d(
386
+ in_shape[0],
387
+ out_shape1,
388
+ kernel_size=3,
389
+ stride=1,
390
+ padding=1,
391
+ bias=False,
392
+ groups=groups,
393
+ )
394
+ scratch.layer2_rn = nn.Conv2d(
395
+ in_shape[1],
396
+ out_shape2,
397
+ kernel_size=3,
398
+ stride=1,
399
+ padding=1,
400
+ bias=False,
401
+ groups=groups,
402
+ )
403
+ scratch.layer3_rn = nn.Conv2d(
404
+ in_shape[2],
405
+ out_shape3,
406
+ kernel_size=3,
407
+ stride=1,
408
+ padding=1,
409
+ bias=False,
410
+ groups=groups,
411
+ )
412
+ if len(in_shape) >= 4:
413
+ scratch.layer4_rn = nn.Conv2d(
414
+ in_shape[3],
415
+ out_shape4,
416
+ kernel_size=3,
417
+ stride=1,
418
+ padding=1,
419
+ bias=False,
420
+ groups=groups,
421
+ )
422
+ return scratch
423
+
424
+
425
+ class ResidualConvUnit(nn.Module):
426
+ """Residual convolution module."""
427
+
428
+ def __init__(self, features, activation, bn, groups=1):
429
+ """Init.
430
+
431
+ Args:
432
+ features (int): number of features
433
+ """
434
+ super().__init__()
435
+
436
+ self.bn = bn
437
+ self.groups = groups
438
+ self.conv1 = nn.Conv2d(
439
+ features,
440
+ features,
441
+ kernel_size=3,
442
+ stride=1,
443
+ padding=1,
444
+ bias=True,
445
+ groups=self.groups,
446
+ )
447
+ self.conv2 = nn.Conv2d(
448
+ features,
449
+ features,
450
+ kernel_size=3,
451
+ stride=1,
452
+ padding=1,
453
+ bias=True,
454
+ groups=self.groups,
455
+ )
456
+
457
+ self.norm1 = None
458
+ self.norm2 = None
459
+
460
+ self.activation = activation
461
+ self.skip_add = nn.quantized.FloatFunctional()
462
+
463
+ def forward(self, x):
464
+ """Forward pass.
465
+
466
+ Args:
467
+ x (tensor): input
468
+
469
+ Returns:
470
+ tensor: output
471
+ """
472
+
473
+ out = self.activation(x)
474
+ out = self.conv1(out)
475
+ if self.norm1 is not None:
476
+ out = self.norm1(out)
477
+
478
+ out = self.activation(out)
479
+ out = self.conv2(out)
480
+ if self.norm2 is not None:
481
+ out = self.norm2(out)
482
+
483
+ return self.skip_add.add(out, x)
484
+
485
+
486
+ class FeatureFusionBlock(nn.Module):
487
+ """Feature fusion block."""
488
+
489
+ def __init__(
490
+ self,
491
+ features,
492
+ activation,
493
+ deconv=False,
494
+ bn=False,
495
+ expand=False,
496
+ align_corners=True,
497
+ size=None,
498
+ has_residual=True,
499
+ groups=1,
500
+ ):
501
+ """Init.
502
+
503
+ Args:
504
+ features (int): number of features
505
+ """
506
+ super(FeatureFusionBlock, self).__init__()
507
+
508
+ self.deconv = deconv
509
+ self.align_corners = align_corners
510
+ self.groups = groups
511
+ self.expand = expand
512
+ out_features = features
513
+ if self.expand == True:
514
+ out_features = features // 2
515
+
516
+ self.out_conv = nn.Conv2d(
517
+ features,
518
+ out_features,
519
+ kernel_size=1,
520
+ stride=1,
521
+ padding=0,
522
+ bias=True,
523
+ groups=self.groups,
524
+ )
525
+
526
+ if has_residual:
527
+ self.resConfUnit1 = ResidualConvUnit(
528
+ features, activation, bn, groups=self.groups
529
+ )
530
+
531
+ self.has_residual = has_residual
532
+ self.resConfUnit2 = ResidualConvUnit(
533
+ features, activation, bn, groups=self.groups
534
+ )
535
+
536
+ self.skip_add = nn.quantized.FloatFunctional()
537
+ self.size = size
538
+
539
+ def forward(self, *xs, size=None):
540
+ """Forward pass.
541
+
542
+ Returns:
543
+ tensor: output
544
+ """
545
+ output = xs[0]
546
+
547
+ if self.has_residual:
548
+ res = self.resConfUnit1(xs[1])
549
+ output = self.skip_add.add(output, res)
550
+
551
+ output = self.resConfUnit2(output)
552
+
553
+ if (size is None) and (self.size is None):
554
+ modifier = {"scale_factor": 2}
555
+ elif size is None:
556
+ modifier = {"size": self.size}
557
+ else:
558
+ modifier = {"size": size}
559
+
560
+ output = custom_interpolate(
561
+ output, **modifier, mode="bilinear", align_corners=self.align_corners
562
+ )
563
+ output = self.out_conv(output)
564
+
565
+ return output
566
+
567
+
568
+ def custom_interpolate(
569
+ x: torch.Tensor,
570
+ size: Tuple[int, int] = None,
571
+ scale_factor: float = None,
572
+ mode: str = "bilinear",
573
+ align_corners: bool = True,
574
+ ) -> torch.Tensor:
575
+ """
576
+ Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
577
+ """
578
+ if size is None:
579
+ size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
580
+
581
+ INT_MAX = 1610612736
582
+
583
+ input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
584
+
585
+ if input_elements > INT_MAX:
586
+ chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
587
+ interpolated_chunks = [
588
+ nn.functional.interpolate(
589
+ chunk, size=size, mode=mode, align_corners=align_corners
590
+ )
591
+ for chunk in chunks
592
+ ]
593
+ x = torch.cat(interpolated_chunks, dim=0)
594
+ return x.contiguous()
595
+ else:
596
+ return nn.functional.interpolate(
597
+ x, size=size, mode=mode, align_corners=align_corners
598
+ )
sailrecon/heads/head_act.py ADDED
@@ -0,0 +1,127 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+
12
+ def activate_pose(
13
+ pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"
14
+ ):
15
+ """
16
+ Activate pose parameters with specified activation functions.
17
+
18
+ Args:
19
+ pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
20
+ trans_act: Activation type for translation component
21
+ quat_act: Activation type for quaternion component
22
+ fl_act: Activation type for focal length component
23
+
24
+ Returns:
25
+ Activated pose parameters tensor
26
+ """
27
+ T = pred_pose_enc[..., :3]
28
+ quat = pred_pose_enc[..., 3:7]
29
+ fl = pred_pose_enc[..., 7:] # or fov
30
+
31
+ T = base_pose_act(T, trans_act)
32
+ quat = base_pose_act(quat, quat_act)
33
+ fl = base_pose_act(fl, fl_act) # or fov
34
+
35
+ pred_pose_enc = torch.cat([T, quat, fl], dim=-1)
36
+
37
+ return pred_pose_enc
38
+
39
+
40
+ def base_pose_act(pose_enc, act_type="linear"):
41
+ """
42
+ Apply basic activation function to pose parameters.
43
+
44
+ Args:
45
+ pose_enc: Tensor containing encoded pose parameters
46
+ act_type: Activation type ("linear", "inv_log", "exp", "relu")
47
+
48
+ Returns:
49
+ Activated pose parameters
50
+ """
51
+ if act_type == "linear":
52
+ return pose_enc
53
+ elif act_type == "inv_log":
54
+ return inverse_log_transform(pose_enc)
55
+ elif act_type == "exp":
56
+ return torch.exp(pose_enc)
57
+ elif act_type == "relu":
58
+ return F.relu(pose_enc)
59
+ else:
60
+ raise ValueError(f"Unknown act_type: {act_type}")
61
+
62
+
63
+ def activate_head(out, activation="norm_exp", conf_activation="expp1"):
64
+ """
65
+ Process network output to extract 3D points and confidence values.
66
+
67
+ Args:
68
+ out: Network output tensor (B, C, H, W)
69
+ activation: Activation type for 3D points
70
+ conf_activation: Activation type for confidence values
71
+
72
+ Returns:
73
+ Tuple of (3D points tensor, confidence tensor)
74
+ """
75
+ # Move channels from last dim to the 4th dimension => (B, H, W, C)
76
+ fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected
77
+
78
+ # Split into xyz (first C-1 channels) and confidence (last channel)
79
+ xyz = fmap[:, :, :, :-1]
80
+ conf = fmap[:, :, :, -1]
81
+
82
+ if activation == "norm_exp":
83
+ d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
84
+ xyz_normed = xyz / d
85
+ pts3d = xyz_normed * torch.expm1(d)
86
+ elif activation == "norm":
87
+ pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
88
+ elif activation == "exp":
89
+ pts3d = torch.exp(xyz)
90
+ elif activation == "relu":
91
+ pts3d = F.relu(xyz)
92
+ elif activation == "inv_log":
93
+ pts3d = inverse_log_transform(xyz)
94
+ elif activation == "xy_inv_log":
95
+ xy, z = xyz.split([2, 1], dim=-1)
96
+ z = inverse_log_transform(z)
97
+ pts3d = torch.cat([xy * z, z], dim=-1)
98
+ elif activation == "sigmoid":
99
+ pts3d = torch.sigmoid(xyz)
100
+ elif activation == "linear":
101
+ pts3d = xyz
102
+ else:
103
+ raise ValueError(f"Unknown activation: {activation}")
104
+
105
+ if conf_activation == "expp1":
106
+ conf_out = 1 + conf.exp()
107
+ elif conf_activation == "expp0":
108
+ conf_out = conf.exp()
109
+ elif conf_activation == "sigmoid":
110
+ conf_out = torch.sigmoid(conf)
111
+ else:
112
+ raise ValueError(f"Unknown conf_activation: {conf_activation}")
113
+
114
+ return pts3d, conf_out
115
+
116
+
117
+ def inverse_log_transform(y):
118
+ """
119
+ Apply inverse log transform: sign(y) * (exp(|y|) - 1)
120
+
121
+ Args:
122
+ y: Input tensor
123
+
124
+ Returns:
125
+ Transformed tensor
126
+ """
127
+ return torch.sign(y) * (torch.expm1(torch.abs(y)))
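
`inverse_log_transform` is the exact inverse of a signed `log1p` compression, which is what the `inv_log` activations above undo. A quick numeric check (import path assumed from this commit):

```python
import torch
from sailrecon.heads.head_act import inverse_log_transform

x = torch.tensor([-10.0, -0.5, 0.0, 2.0, 100.0])
y = torch.sign(x) * torch.log1p(torch.abs(x))        # forward compression
print(torch.allclose(inverse_log_transform(y), x))   # True
```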
sailrecon/heads/track_head.py ADDED
@@ -0,0 +1,116 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch.nn as nn
8
+
9
+ from .dpt_head import DPTHead
10
+ from .track_modules.base_track_predictor import BaseTrackerPredictor
11
+
12
+
13
+ class TrackHead(nn.Module):
14
+ """
15
+ Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking.
16
+ The tracking is performed iteratively, refining predictions over multiple iterations.
17
+ """
18
+
19
+ def __init__(
20
+ self,
21
+ dim_in,
22
+ patch_size=14,
23
+ features=128,
24
+ iters=4,
25
+ predict_conf=True,
26
+ stride=2,
27
+ corr_levels=7,
28
+ corr_radius=4,
29
+ hidden_size=384,
30
+ ):
31
+ """
32
+ Initialize the TrackHead module.
33
+
34
+ Args:
35
+ dim_in (int): Input dimension of tokens from the backbone.
36
+ patch_size (int): Size of image patches used in the vision transformer.
37
+ features (int): Number of feature channels in the feature extractor output.
38
+ iters (int): Number of refinement iterations for tracking predictions.
39
+ predict_conf (bool): Whether to predict confidence scores for tracked points.
40
+ stride (int): Stride value for the tracker predictor.
41
+ corr_levels (int): Number of correlation pyramid levels
42
+ corr_radius (int): Radius for correlation computation, controlling the search area.
43
+ hidden_size (int): Size of hidden layers in the tracker network.
44
+ """
45
+ super().__init__()
46
+
47
+ self.patch_size = patch_size
48
+
49
+ # Feature extractor based on DPT architecture
50
+ # Processes tokens into feature maps for tracking
51
+ self.feature_extractor = DPTHead(
52
+ dim_in=dim_in,
53
+ patch_size=patch_size,
54
+ features=features,
55
+ feature_only=True, # Only output features, no activation
56
+ down_ratio=2, # Reduces spatial dimensions by factor of 2
57
+ pos_embed=False,
58
+ )
59
+
60
+ # Tracker module that predicts point trajectories
61
+ # Takes feature maps and predicts coordinates and visibility
62
+ self.tracker = BaseTrackerPredictor(
63
+ latent_dim=features, # Match the output_dim of feature extractor
64
+ predict_conf=predict_conf,
65
+ stride=stride,
66
+ corr_levels=corr_levels,
67
+ corr_radius=corr_radius,
68
+ hidden_size=hidden_size,
69
+ )
70
+
71
+ self.iters = iters
72
+
73
+ def forward(
74
+ self,
75
+ aggregated_tokens_list,
76
+ images,
77
+ patch_start_idx,
78
+ query_points=None,
79
+ iters=None,
80
+ ):
81
+ """
82
+ Forward pass of the TrackHead.
83
+
84
+ Args:
85
+ aggregated_tokens_list (list): List of aggregated tokens from the backbone.
86
+ images (torch.Tensor): Input images of shape (B, S, C, H, W) where:
87
+ B = batch size, S = sequence length.
88
+ patch_start_idx (int): Starting index for patch tokens.
89
+ query_points (torch.Tensor, optional): Initial query points to track.
90
+ If None, points are initialized by the tracker.
91
+ iters (int, optional): Number of refinement iterations. If None, uses self.iters.
92
+
93
+ Returns:
94
+ tuple:
95
+ - coord_preds (torch.Tensor): Predicted coordinates for tracked points.
96
+ - vis_scores (torch.Tensor): Visibility scores for tracked points.
97
+ - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True).
98
+ """
99
+ B, S, _, H, W = images.shape
100
+
101
+ # Extract features from tokens
102
+ # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2
103
+ feature_maps = self.feature_extractor(
104
+ aggregated_tokens_list, images, patch_start_idx
105
+ )
106
+
107
+ # Use default iterations if not specified
108
+ if iters is None:
109
+ iters = self.iters
110
+
111
+ # Perform tracking using the extracted features
112
+ coord_preds, vis_scores, conf_scores = self.tracker(
113
+ query_points=query_points, fmaps=feature_maps, iters=iters
114
+ )
115
+
116
+ return coord_preds, vis_scores, conf_scores
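
A shape-level sketch of the inputs `TrackHead` consumes. The values below (24 token layers, `dim_in=384`, 140-pixel frames) are placeholders chosen to match the default `intermediate_layer_idx` and the 14-pixel patch size; the actual call is left commented because it depends on backbone modules not shown in this diff:

```python
import torch

B, S, H, W = 1, 2, 140, 140            # batch, frames, image size (multiple of patch_size=14)
patch_start_idx = 5                     # camera/register tokens precede the patch tokens
num_patches = (H // 14) * (W // 14)     # patch tokens per frame
dim_in = 384

images = torch.rand(B, S, 3, H, W)
aggregated_tokens_list = [
    torch.rand(B, S, patch_start_idx + num_patches, dim_in) for _ in range(24)
]
query_points = torch.rand(B, 64, 2) * W  # 64 pixel-space points to track from frame 0

# head = TrackHead(dim_in=dim_in)
# coords, vis, conf = head(aggregated_tokens_list, images, patch_start_idx, query_points)
```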
sailrecon/heads/track_modules/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
sailrecon/heads/track_modules/base_track_predictor.py ADDED
@@ -0,0 +1,242 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from einops import rearrange, repeat
10
+
11
+ from .blocks import CorrBlock, EfficientUpdateFormer
12
+ from .modules import Mlp
13
+ from .utils import get_2d_embedding, get_2d_sincos_pos_embed, sample_features4d
14
+
15
+
16
+ class BaseTrackerPredictor(nn.Module):
17
+ def __init__(
18
+ self,
19
+ stride=1,
20
+ corr_levels=5,
21
+ corr_radius=4,
22
+ latent_dim=128,
23
+ hidden_size=384,
24
+ use_spaceatt=True,
25
+ depth=6,
26
+ max_scale=518,
27
+ predict_conf=True,
28
+ ):
29
+ super(BaseTrackerPredictor, self).__init__()
30
+ """
31
+ The base template to create a track predictor
32
+
33
+ Modified from https://github.com/facebookresearch/co-tracker/
34
+ and https://github.com/facebookresearch/vggsfm
35
+ """
36
+
37
+ self.stride = stride
38
+ self.latent_dim = latent_dim
39
+ self.corr_levels = corr_levels
40
+ self.corr_radius = corr_radius
41
+ self.hidden_size = hidden_size
42
+ self.max_scale = max_scale
43
+ self.predict_conf = predict_conf
44
+
45
+ self.flows_emb_dim = latent_dim // 2
46
+
47
+ self.corr_mlp = Mlp(
48
+ in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,
49
+ hidden_features=self.hidden_size,
50
+ out_features=self.latent_dim,
51
+ )
52
+
53
+ self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4
54
+
55
+ self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))
56
+
57
+ space_depth = depth if use_spaceatt else 0
58
+ time_depth = depth
59
+
60
+ self.updateformer = EfficientUpdateFormer(
61
+ space_depth=space_depth,
62
+ time_depth=time_depth,
63
+ input_dim=self.transformer_dim,
64
+ hidden_size=self.hidden_size,
65
+ output_dim=self.latent_dim + 2,
66
+ mlp_ratio=4.0,
67
+ add_space_attn=use_spaceatt,
68
+ )
69
+
70
+ self.fmap_norm = nn.LayerNorm(self.latent_dim)
71
+ self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)
72
+
73
+ # A linear layer to update track feats at each iteration
74
+ self.ffeat_updater = nn.Sequential(
75
+ nn.Linear(self.latent_dim, self.latent_dim), nn.GELU()
76
+ )
77
+
78
+ self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
79
+
80
+ if predict_conf:
81
+ self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
82
+
    def forward(
        self,
        query_points,
        fmaps=None,
        iters=6,
        return_feat=False,
        down_ratio=1,
        apply_sigmoid=True,
    ):
        """
        query_points: B x N x 2, the number of batches, tracks, and xy
        fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
            Note that HH and WW are the sizes of the feature maps, not of the original images.
        """
        B, N, D = query_points.shape
        B, S, C, HH, WW = fmaps.shape

        assert D == 2, "Input points must be 2D coordinates"

        # Apply a layernorm to fmaps here
        fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))
        fmaps = fmaps.permute(0, 1, 4, 2, 3)

        # Scale the input query_points because we may downsample the images
        # by down_ratio or self.stride.
        # E.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map,
        # its query_points should be query_points / 4.
        if down_ratio > 1:
            query_points = query_points / float(down_ratio)

        query_points = query_points / float(self.stride)

        # Init coords with the query points, i.e. the search starts from the
        # position of the query points in the reference frame
        coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)

        # Sample/extract the features of the query points in the query frame
        query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])

        # Init track feats by query feats
        track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1)  # B, S, N, C
        # Back up the init coords
        coords_backup = coords.clone()

        fcorr_fn = CorrBlock(
            fmaps, num_levels=self.corr_levels, radius=self.corr_radius
        )

        coord_preds = []

        # Iterative refinement
        for _ in range(iters):
            # Detach the gradients from the last iteration
            # (in my experience, not very important for performance)
            coords = coords.detach()

            fcorrs = fcorr_fn.corr_sample(track_feats, coords)

            corr_dim = fcorrs.shape[3]
            fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)
            fcorrs_ = self.corr_mlp(fcorrs_)

            # Movement of current coords relative to query points
            flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)

            flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)

            # (In my trials, it is also okay to just add the flows_emb instead of concat)
            flows_emb = torch.cat(
                [flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1
            )

            track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(
                B * N, S, self.latent_dim
            )

            # Concatenate them as the input for the transformers
            transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)

            # 2D positional embedding
            # TODO: this can be much simplified
            pos_embed = get_2d_sincos_pos_embed(
                self.transformer_dim, grid_size=(HH, WW)
            ).to(query_points.device)
            sampled_pos_emb = sample_features4d(
                pos_embed.expand(B, -1, -1, -1), coords[:, 0]
            )

            sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(
                1
            )

            x = transformer_input + sampled_pos_emb

            # Add the query ref token to the track feats
            query_ref_token = torch.cat(
                [
                    self.query_ref_token[:, 0:1],
                    self.query_ref_token[:, 1:2].expand(-1, S - 1, -1),
                ],
                dim=1,
            )
            x = x + query_ref_token.to(x.device).to(x.dtype)

            # B, N, S, C
            x = rearrange(x, "(b n) s d -> b n s d", b=B)

            # Compute the delta coordinates and delta track features
            delta, _ = self.updateformer(x)

            # BN, S, C
            delta = rearrange(delta, "b n s d -> (b n) s d", b=B)
            delta_coords_ = delta[:, :, :2]
            delta_feats_ = delta[:, :, 2:]

            track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
            delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)

            # Update the track features
            track_feats_ = (
                self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_
            )

            track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(
                0, 2, 1, 3
            )  # B x S x N x C

            # B x S x N x 2
            coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)

            # Force coord0 to stay at the query points,
            # because we assume the query points should not be changed
            coords[:, 0] = coords_backup[:, 0]

            # The predicted tracks are in the original image scale
            if down_ratio > 1:
                coord_preds.append(coords * self.stride * down_ratio)
            else:
                coord_preds.append(coords * self.stride)

        # B, S, N
        vis_e = self.vis_predictor(
            track_feats.reshape(B * S * N, self.latent_dim)
        ).reshape(B, S, N)
        if apply_sigmoid:
            vis_e = torch.sigmoid(vis_e)

        if self.predict_conf:
            conf_e = self.conf_predictor(
                track_feats.reshape(B * S * N, self.latent_dim)
            ).reshape(B, S, N)
            if apply_sigmoid:
                conf_e = torch.sigmoid(conf_e)
        else:
            conf_e = None

        if return_feat:
            return coord_preds, vis_e, track_feats, query_track_feat, conf_e
        else:
            return coord_preds, vis_e, conf_e
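
To show how the predictor above is wired together, here is a minimal usage sketch. It is not part of the commit: the shapes, the iters value, and the assumption that BaseTrackerPredictor and its track-module dependencies (CorrBlock, EfficientUpdateFormer, Mlp, and the utils helpers) are importable from this package are all hypothetical. The one hard constraint is that the feature channel count C equals latent_dim, since fmap_norm is a LayerNorm over latent_dim channels.

import torch

# Hypothetical sizes: 1 batch, 8 frames, 128-channel feature maps of 37 x 37,
# and 64 query points given in the coordinate frame of the first (query) frame.
B, S, C, HH, WW = 1, 8, 128, 37, 37
N = 64

# latent_dim must equal C: fmap_norm normalizes over latent_dim channels.
tracker = BaseTrackerPredictor(latent_dim=C)

fmaps = torch.randn(B, S, C, HH, WW)        # per-frame feature maps
query_points = torch.rand(B, N, 2) * HH     # (x, y) locations in the query frame

coord_preds, vis_e, conf_e = tracker(query_points, fmaps=fmaps, iters=4)

# coord_preds: list with one B x S x N x 2 tensor per refinement iteration,
#              already rescaled by stride (and down_ratio, if > 1).
# vis_e, conf_e: B x S x N visibility / confidence scores in [0, 1]
#                when apply_sigmoid=True (conf_e is None if predict_conf=False).
print(coord_preds[-1].shape, vis_e.shape, conf_e.shape)

If the feature maps were computed from images downsampled relative to the coordinates of query_points, pass the corresponding down_ratio so the queries are scaled into feature-map space and the returned tracks are mapped back to the original image scale.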