Spaces: Running on Zero

hengli committed · Commit b7f83b0 · 1 Parent(s): 132427c

first

This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +1 -0
- .gitattributes copy +2 -0
- .gitignore +162 -0
- .pre-commit-config.yaml +27 -0
- CODE_OF_CONDUCT.md +80 -0
- CONTRIBUTING.md +31 -0
- LICENSE.txt +21 -0
- README copy.md +93 -0
- app.py +914 -0
- demo.py +131 -0
- demo_gradio.py +921 -0
- docs/traj_ply.png +3 -0
- eval/datasets/mip_360.py +115 -0
- eval/datasets/seven_scenes.py +58 -0
- eval/datasets/tnt.py +116 -0
- eval/datasets/tum.py +53 -0
- eval/readme.md +110 -0
- eval/utils/cropping.py +289 -0
- eval/utils/device.py +95 -0
- eval/utils/eval_pose_ransac.py +315 -0
- eval/utils/eval_utils.py +74 -0
- eval/utils/geometry.py +572 -0
- eval/utils/image.py +232 -0
- eval/utils/load_fn.py +155 -0
- eval/utils/misc.py +131 -0
- eval/utils/pose_enc.py +135 -0
- eval/utils/rotation.py +142 -0
- eval/utils/visual_track.py +244 -0
- pyproject.toml +58 -0
- requirements.txt +10 -0
- requirements_demo.txt +16 -0
- sailrecon/dependency/__init__.py +3 -0
- sailrecon/dependency/distortion.py +223 -0
- sailrecon/dependency/np_to_pycolmap.py +355 -0
- sailrecon/dependency/projection.py +249 -0
- sailrecon/dependency/track_modules/__init__.py +0 -0
- sailrecon/dependency/track_modules/base_track_predictor.py +210 -0
- sailrecon/dependency/track_modules/blocks.py +396 -0
- sailrecon/dependency/track_modules/modules.py +216 -0
- sailrecon/dependency/track_modules/track_refine.py +493 -0
- sailrecon/dependency/track_modules/utils.py +235 -0
- sailrecon/dependency/track_predict.py +349 -0
- sailrecon/dependency/vggsfm_tracker.py +148 -0
- sailrecon/dependency/vggsfm_utils.py +341 -0
- sailrecon/heads/camera_head.py +228 -0
- sailrecon/heads/dpt_head.py +598 -0
- sailrecon/heads/head_act.py +127 -0
- sailrecon/heads/track_head.py +116 -0
- sailrecon/heads/track_modules/__init__.py +5 -0
- sailrecon/heads/track_modules/base_track_predictor.py +242 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+docs/traj_ply.png filter=lfs diff=lfs merge=lfs -text
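An aside for context (not part of this commit's diff): a pattern line like the one added above is what the git-lfs CLI writes into `.gitattributes` when a path is tracked, so the same change could have been produced with:

```bash
git lfs track "docs/traj_ply.png"        # appends the filter line to .gitattributes
git add .gitattributes docs/traj_ply.png
```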
.gitattributes copy
ADDED
@@ -0,0 +1,2 @@
# SCM syntax highlighting & preventing 3-way merges
pixi.lock merge=binary linguist-language=YAML linguist-generated=true
.gitignore
ADDED
@@ -0,0 +1,162 @@
.hydra/
output/
ckpt/
# Byte-compiled / optimized / DLL files
__pycache__/
**/__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Profiling data
.prof

# Folder specific to your needs
**/tmp/
**/outputs/skyseg.onnx
skyseg.onnx

# pixi environments
.pixi
*.egg-info

# demo images
/docs/demo_image

# vscode settings
.vscode/


outputs/
samples/
tmp_video/
.gradio/
input_images_*
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,27 @@
default_language_version:
  python: python3

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: trailing-whitespace
      - id: check-ast
      - id: check-merge-conflict
      - id: check-yaml
      - id: end-of-file-fixer
      - id: trailing-whitespace
        args: [--markdown-linebreak-ext=md]

  - repo: https://github.com/psf/black
    rev: 23.3.0
    hooks:
      - id: black
        language_version: python3

  - repo: https://github.com/pycqa/isort
    rev: 5.12.0
    hooks:
      - id: isort
        exclude: README.md
        args: ["--profile", "black"]
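For context on the config above (an aside, not a file in this commit): with `.pre-commit-config.yaml` in place, the hooks are normally activated per clone with the standard pre-commit CLI:

```bash
pip install pre-commit
pre-commit install           # registers the git hook in .git/hooks
pre-commit run --all-files   # optional one-off pass over the whole tree
```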
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,80 @@
# Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment
include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.

## Scope

This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.

This Code of Conduct also applies outside the project spaces when there is a
reasonable belief that an individual's behavior may have a negative impact on
the project or its community.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <[email protected]>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq
CONTRIBUTING.md
ADDED
@@ -0,0 +1,31 @@
# Contributing to vggt
We want to make contributing to this project as easy and transparent as
possible.

## Pull Requests
We actively welcome your pull requests.

1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## License
By contributing to vggt, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
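As a generic illustration of step 1 of the pull-request checklist above (the fork URL and branch name below are placeholders, not taken from this repo's docs):

```bash
git clone https://github.com/<your-username>/sail-recon.git   # your fork
cd sail-recon
git checkout -b <my-feature> main                              # branch from main
```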
LICENSE.txt
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 HKUST SAIL-LAB and Horizon Robotics.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README copy.md
ADDED
@@ -0,0 +1,93 @@
<div align="center">
<h1>SAIL-Recon: Large SfM by Augmenting Scene Regression with Localization</h1>


<a href="https://arxiv.org/pdf/2508.17972"><img src="https://img.shields.io/badge/arXiv-2508.17972-b31b1b" alt="arXiv"></a>
<a href="https://hkust-sail.github.io/sail-recon/"><img src="https://img.shields.io/badge/Project_Page-green" alt="Project Page"></a>
<!--<a href='https://huggingface.co/'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo-blue'></a>-->


**[HKUST Spatial Artificial Intelligence Lab](https://github.com/HKUST-SAIL)**; **[Horizon Robotics](https://en.horizon.auto/)**


[Junyuan Deng](https://scholar.google.com/citations?user=KTCPC5IAAAAJ&hl=en), [Heng Li](https://hengli.me/), [Tao Xie](https://github.com/xbillowy), [Weiqiang Ren](https://cn.linkedin.com/in/weiqiang-ren-b2798636), [Qian Zhang](https://cn.linkedin.com/in/qian-zhang-10234b73), [Ping Tan](https://facultyprofiles.hkust.edu.hk/profiles.php?profile=ping-tan-pingtan), [Xiaoyang Guo](https://xy-guo.github.io/)
</div>




## Overview

Sail-Recon is a feed-forward Transformer that scales neural scene regression to large-scale Structure-from-Motion by augmenting it with visual localization. From a few anchor views, it constructs a global latent scene representation that encodes both geometry and appearance. Conditioned on this representation, the network directly regresses camera poses, intrinsics, depth maps, and scene coordinate maps for thousands of images in minutes, enabling precise and robust reconstruction without iterative optimization.


## TODO
- [x] Inference Code Release
- [x] Gradio Demo
- [ ] Evaluation Script

## Quick Start

First, clone this repository to your local machine, and install the dependencies (torch, torchvision, numpy, Pillow, and huggingface_hub) following VGGT.

```bash
git clone https://github.com/HKUST-SAIL/sail-recon.git
cd sail-recon
pip install -e .
```

You can download the demo image (e.g., [Barn](https://drive.google.com/file/d/0B-ePgl6HF260NzQySklGdXZyQzA/view?resourcekey=0-luQ7Jaym5BQL6IjxsgXY9A) from [Tanks & Temples](https://www.tanksandtemples.org/)) and put the images in `examples/demo_image`.

Now, you can try the model demo:
```bash
# Images
python demo.py --img_dir path/to/your/images --out_dir outputs
# Video
python demo.py --vid_dir path/to/your/images --out_dir outputs
```

You can find the ply file and camera pose under `outputs`.

We also provide a Gradio demo for easier usage. You can run the demo by:
```bash
python demo_gradio.py
```
Please note that the Gradio demo is slower than `demo.py` due to the visualization part.


## Evaluation

Please refer to [this](eval/readme.md) for more details.

## Acknowledgements

Thanks to these great repositories:

[ACE0](https://github.com/nianticlabs/acezero) for the PSNR evaluation;

[VGGT](https://github.com/facebookresearch/vggt) for the template of github, gradio and visualization;

[Fast3R](https://github.com/facebookresearch/fast3r) for the training data processing and some utility functions;

And many other inspiring works in the community.

If you find this project useful in your research, please consider citing:
```bibtex
@article{dengli2025sail,
  title={SAIL-Recon: Large SfM by Augmenting Scene Regression with Localization},
  author={Deng, Junyuan and Li, Heng and Xie, Tao and Ren, Weiqiang and Zhang, Qian and Tan, Ping and Guo, Xiaoyang},
  journal={arXiv preprint arXiv:2508.17972},
  year={2025}
}
```

## License

See the [LICENSE](./LICENSE.txt) file for details about the license under which this code is made available.

Please see the license of [VGGT](https://github.com/facebookresearch/vggt) about the other code used in this project.

Please see the license of [ACE0](https://github.com/nianticlabs/acezero) about the evaluation used in this project.

Please see the license of [Fast3R](https://github.com/facebookresearch/fast3r) about the utility functions used in this project.
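As an aside on the Quick Start above (not part of the committed README): the Gradio app added in this commit (`app.py`, below) shows the programmatic path that the CLI demos wrap. The following is a minimal sketch assembled from its imports and calls; the image paths are placeholders, and the checkpoint URL and the `tmp_forward`/`reloc` API are taken from `app.py` rather than from separate documentation:

```python
import torch

from sailrecon.models.sail_recon import SailRecon
from sailrecon.utils.load_fn import load_and_preprocess_images

device = "cuda" if torch.cuda.is_available() else "cpu"

# Checkpoint URL as used in app.py
_URL = "https://huggingface.co/HKUST-SAIL/SAIL-Recon/resolve/main/sailrecon.pt"
model = SailRecon(kv_cache=True)
model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
model = model.to(device).eval()

# Placeholder image list; app.py globs target_dir/images/* instead
image_paths = ["examples/demo_image/0001.jpg", "examples/demo_image/0002.jpg"]
images = load_and_preprocess_images(image_paths).to(device)

with torch.no_grad():
    model.tmp_forward(images)  # encode anchor views into the latent scene
    preds = model.reloc(images, ret_img=True, memory_save=False)  # poses, depth, point maps
```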
app.py
ADDED
|
@@ -0,0 +1,914 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
import glob
|
| 3 |
+
import os
|
| 4 |
+
import shutil
|
| 5 |
+
import sys
|
| 6 |
+
import time
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
|
| 9 |
+
import cv2
|
| 10 |
+
import gradio as gr
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
from eval.utils.device import to_cpu
|
| 16 |
+
from eval.utils.eval_utils import uniform_sample
|
| 17 |
+
from sailrecon.models.sail_recon import SailRecon
|
| 18 |
+
from sailrecon.utils.geometry import unproject_depth_map_to_point_map
|
| 19 |
+
from sailrecon.utils.load_fn import load_and_preprocess_images
|
| 20 |
+
from sailrecon.utils.pose_enc import (
|
| 21 |
+
extri_intri_to_pose_encoding,
|
| 22 |
+
pose_encoding_to_extri_intri,
|
| 23 |
+
)
|
| 24 |
+
from visual_util import predictions_to_glb
|
| 25 |
+
|
| 26 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 27 |
+
|
| 28 |
+
print("Initializing and loading SailRecon model...")
|
| 29 |
+
|
| 30 |
+
model = SailRecon(kv_cache=True)
|
| 31 |
+
|
| 32 |
+
model_dir = "ckpt/sailrecon.pt"
|
| 33 |
+
if os.path.exists(model_dir):
|
| 34 |
+
model.load_state_dict(torch.load(model_dir))
|
| 35 |
+
else:
|
| 36 |
+
_URL = "https://huggingface.co/HKUST-SAIL/SAIL-Recon/resolve/main/sailrecon.pt"
|
| 37 |
+
model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
model.eval()
|
| 41 |
+
model = model.to(device)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# -------------------------------------------------------------------------
|
| 45 |
+
# 1) Core model inference
|
| 46 |
+
# -------------------------------------------------------------------------
|
| 47 |
+
def run_model(target_dir, model, anchor_size=100) -> dict:
|
| 48 |
+
"""
|
| 49 |
+
Run the SAIL-Recon model on images in the 'target_dir/images' folder and return predictions.
|
| 50 |
+
"""
|
| 51 |
+
print(f"Processing images from {target_dir}")
|
| 52 |
+
|
| 53 |
+
# Device check
|
| 54 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 55 |
+
if not torch.cuda.is_available():
|
| 56 |
+
raise ValueError("CUDA is not available. Check your environment.")
|
| 57 |
+
|
| 58 |
+
# Move model to device
|
| 59 |
+
model = model.to(device)
|
| 60 |
+
model.eval()
|
| 61 |
+
|
| 62 |
+
# Load and preprocess images
|
| 63 |
+
image_names = glob.glob(os.path.join(target_dir, "images", "*"))
|
| 64 |
+
image_names = sorted(image_names)
|
| 65 |
+
print(f"Found {len(image_names)} images")
|
| 66 |
+
if len(image_names) == 0:
|
| 67 |
+
raise ValueError("No images found. Check your upload.")
|
| 68 |
+
|
| 69 |
+
images = load_and_preprocess_images(image_names).to(device)
|
| 70 |
+
print(f"Preprocessed images shape: {images.shape}")
|
| 71 |
+
# anchor image selection
|
| 72 |
+
select_indices = uniform_sample(len(image_names), min(100, len(image_names)))
|
| 73 |
+
anchor_images = images[select_indices]
|
| 74 |
+
|
| 75 |
+
# Run inference
|
| 76 |
+
print("Running inference...")
|
| 77 |
+
dtype = (
|
| 78 |
+
torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
with torch.no_grad():
|
| 82 |
+
with torch.cuda.amp.autocast(dtype=dtype):
|
| 83 |
+
print("Processing anchor images ...")
|
| 84 |
+
model.tmp_forward(anchor_images)
|
| 85 |
+
# del model.aggregator.global_blocks
|
| 86 |
+
# relocalization on all images
|
| 87 |
+
predictions_s = []
|
| 88 |
+
with tqdm(total=len(image_names), desc="Relocalizing") as pbar:
|
| 89 |
+
for img_split in images.split(10, dim=0):
|
| 90 |
+
pbar.update(10)
|
| 91 |
+
predictions_s += to_cpu(
|
| 92 |
+
model.reloc(img_split, ret_img=True, memory_save=False)
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
predictions = {}
|
| 96 |
+
predictions["extrinsic"] = torch.cat(
|
| 97 |
+
[s["extrinsic"] for s in predictions_s], dim=0
|
| 98 |
+
) # (S, 4, 4)
|
| 99 |
+
predictions["intrinsic"] = torch.cat(
|
| 100 |
+
[s["intrinsic"] for s in predictions_s], dim=0
|
| 101 |
+
) # (S, 4, 4)
|
| 102 |
+
predictions["depth"] = torch.cat(
|
| 103 |
+
[s["depth_map"] for s in predictions_s], dim=0
|
| 104 |
+
) # (S, H, W, 1)
|
| 105 |
+
predictions["depth_conf"] = torch.cat(
|
| 106 |
+
[s["dpt_cnf"] for s in predictions_s], dim=0
|
| 107 |
+
) # (S, H, W, 1)
|
| 108 |
+
predictions["images"] = torch.cat(
|
| 109 |
+
[s["images"] for s in predictions_s], dim=0
|
| 110 |
+
) # (S, H, W, 3)
|
| 111 |
+
predictions["world_points"] = torch.cat(
|
| 112 |
+
[s["point_map"] for s in predictions_s], dim=0
|
| 113 |
+
) # (S, H, W, 3)
|
| 114 |
+
predictions["world_points_conf"] = torch.cat(
|
| 115 |
+
[s["xyz_cnf"] for s in predictions_s], dim=0
|
| 116 |
+
) # (S, H, W, 3)
|
| 117 |
+
predictions["pose_enc"] = extri_intri_to_pose_encoding(
|
| 118 |
+
predictions["extrinsic"].unsqueeze(0),
|
| 119 |
+
predictions["intrinsic"].unsqueeze(0),
|
| 120 |
+
images.shape[-2:],
|
| 121 |
+
)[
|
| 122 |
+
0
|
| 123 |
+
] # a
|
| 124 |
+
del predictions_s
|
| 125 |
+
|
| 126 |
+
# Convert tensors to numpy
|
| 127 |
+
for key in predictions.keys():
|
| 128 |
+
if isinstance(predictions[key], torch.Tensor):
|
| 129 |
+
predictions[key] = predictions[key].cpu().numpy() # remove batch dimension
|
| 130 |
+
predictions["pose_enc_list"] = None # remove pose_enc_list
|
| 131 |
+
|
| 132 |
+
# Generate world points from depth map
|
| 133 |
+
print("Computing world points from depth map...")
|
| 134 |
+
depth_map = predictions["depth"] # (S, H, W, 1)
|
| 135 |
+
world_points = unproject_depth_map_to_point_map(
|
| 136 |
+
depth_map, predictions["extrinsic"], predictions["intrinsic"]
|
| 137 |
+
)
|
| 138 |
+
predictions["world_points_from_depth"] = world_points
|
| 139 |
+
|
| 140 |
+
# Clean up
|
| 141 |
+
torch.cuda.empty_cache()
|
| 142 |
+
return predictions
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# -------------------------------------------------------------------------
|
| 146 |
+
# 2) Handle uploaded video/images --> produce target_dir + images
|
| 147 |
+
# -------------------------------------------------------------------------
|
| 148 |
+
def handle_uploads(input_video, input_images):
|
| 149 |
+
"""
|
| 150 |
+
Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
|
| 151 |
+
images or extracted frames from video into it. Return (target_dir, image_paths).
|
| 152 |
+
"""
|
| 153 |
+
start_time = time.time()
|
| 154 |
+
gc.collect()
|
| 155 |
+
torch.cuda.empty_cache()
|
| 156 |
+
|
| 157 |
+
# Create a unique folder name
|
| 158 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
| 159 |
+
target_dir = f"input_images_{timestamp}"
|
| 160 |
+
target_dir_images = os.path.join(target_dir, "images")
|
| 161 |
+
|
| 162 |
+
# Clean up if somehow that folder already exists
|
| 163 |
+
if os.path.exists(target_dir):
|
| 164 |
+
shutil.rmtree(target_dir)
|
| 165 |
+
os.makedirs(target_dir)
|
| 166 |
+
os.makedirs(target_dir_images)
|
| 167 |
+
|
| 168 |
+
image_paths = []
|
| 169 |
+
|
| 170 |
+
# --- Handle images ---
|
| 171 |
+
if input_images is not None:
|
| 172 |
+
for file_data in input_images:
|
| 173 |
+
if isinstance(file_data, dict) and "name" in file_data:
|
| 174 |
+
file_path = file_data["name"]
|
| 175 |
+
else:
|
| 176 |
+
file_path = file_data
|
| 177 |
+
dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
|
| 178 |
+
shutil.copy(file_path, dst_path)
|
| 179 |
+
image_paths.append(dst_path)
|
| 180 |
+
|
| 181 |
+
# --- Handle video ---
|
| 182 |
+
if input_video is not None:
|
| 183 |
+
if isinstance(input_video, dict) and "name" in input_video:
|
| 184 |
+
video_path = input_video["name"]
|
| 185 |
+
else:
|
| 186 |
+
video_path = input_video
|
| 187 |
+
|
| 188 |
+
vs = cv2.VideoCapture(video_path)
|
| 189 |
+
fps = vs.get(cv2.CAP_PROP_FPS)
|
| 190 |
+
|
| 191 |
+
count = 0
|
| 192 |
+
video_frame_num = 0
|
| 193 |
+
while True:
|
| 194 |
+
gotit, frame = vs.read()
|
| 195 |
+
if not gotit:
|
| 196 |
+
break
|
| 197 |
+
count += 1
|
| 198 |
+
image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png")
|
| 199 |
+
cv2.imwrite(image_path, frame)
|
| 200 |
+
image_paths.append(image_path)
|
| 201 |
+
video_frame_num += 1
|
| 202 |
+
|
| 203 |
+
# Sort final images for gallery
|
| 204 |
+
image_paths = sorted(image_paths)
|
| 205 |
+
|
| 206 |
+
end_time = time.time()
|
| 207 |
+
print(
|
| 208 |
+
f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds"
|
| 209 |
+
)
|
| 210 |
+
return target_dir, image_paths
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
# -------------------------------------------------------------------------
|
| 214 |
+
# 3) Update gallery on upload
|
| 215 |
+
# -------------------------------------------------------------------------
|
| 216 |
+
def update_gallery_on_upload(input_video, input_images):
|
| 217 |
+
"""
|
| 218 |
+
Whenever user uploads or changes files, immediately handle them
|
| 219 |
+
and show in the gallery. Return (target_dir, image_paths).
|
| 220 |
+
If nothing is uploaded, returns "None" and empty list.
|
| 221 |
+
"""
|
| 222 |
+
if not input_video and not input_images:
|
| 223 |
+
return None, None, None, None
|
| 224 |
+
target_dir, image_paths = handle_uploads(input_video, input_images)
|
| 225 |
+
return (
|
| 226 |
+
None,
|
| 227 |
+
target_dir,
|
| 228 |
+
image_paths,
|
| 229 |
+
"Upload complete. Click 'Reconstruct' to begin 3D processing.",
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
# -------------------------------------------------------------------------
|
| 234 |
+
# 4) Reconstruction: uses the target_dir plus any viz parameters
|
| 235 |
+
# -------------------------------------------------------------------------
|
| 236 |
+
def gradio_demo(
|
| 237 |
+
target_dir,
|
| 238 |
+
conf_thres=3.0,
|
| 239 |
+
frame_filter="All",
|
| 240 |
+
mask_black_bg=False,
|
| 241 |
+
mask_white_bg=False,
|
| 242 |
+
show_cam=True,
|
| 243 |
+
mask_sky=False,
|
| 244 |
+
downsample_ratio=100.0,
|
| 245 |
+
prediction_mode="Pointmap Regression",
|
| 246 |
+
):
|
| 247 |
+
"""
|
| 248 |
+
Perform reconstruction using the already-created target_dir/images.
|
| 249 |
+
"""
|
| 250 |
+
if not os.path.isdir(target_dir) or target_dir == "None":
|
| 251 |
+
return None, "No valid target directory found. Please upload first.", None, None
|
| 252 |
+
|
| 253 |
+
start_time = time.time()
|
| 254 |
+
gc.collect()
|
| 255 |
+
torch.cuda.empty_cache()
|
| 256 |
+
|
| 257 |
+
# Prepare frame_filter dropdown
|
| 258 |
+
target_dir_images = os.path.join(target_dir, "images")
|
| 259 |
+
all_files = (
|
| 260 |
+
sorted(os.listdir(target_dir_images))
|
| 261 |
+
if os.path.isdir(target_dir_images)
|
| 262 |
+
else []
|
| 263 |
+
)
|
| 264 |
+
all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
|
| 265 |
+
frame_filter_choices = ["All"] + all_files
|
| 266 |
+
|
| 267 |
+
print("Running run_model...")
|
| 268 |
+
with torch.no_grad():
|
| 269 |
+
predictions = run_model(target_dir, model)
|
| 270 |
+
|
| 271 |
+
# Save predictions
|
| 272 |
+
prediction_save_path = os.path.join(target_dir, "predictions.npz")
|
| 273 |
+
np.savez(prediction_save_path, **predictions)
|
| 274 |
+
|
| 275 |
+
# Handle None frame_filter
|
| 276 |
+
if frame_filter is None:
|
| 277 |
+
frame_filter = "All"
|
| 278 |
+
|
| 279 |
+
# Build a GLB file name
|
| 280 |
+
glbfile = os.path.join(
|
| 281 |
+
target_dir,
|
| 282 |
+
f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
# Convert predictions to GLB
|
| 286 |
+
glbscene = predictions_to_glb(
|
| 287 |
+
predictions,
|
| 288 |
+
conf_thres=conf_thres,
|
| 289 |
+
filter_by_frames=frame_filter,
|
| 290 |
+
mask_black_bg=mask_black_bg,
|
| 291 |
+
mask_white_bg=mask_white_bg,
|
| 292 |
+
show_cam=show_cam,
|
| 293 |
+
mask_sky=mask_sky,
|
| 294 |
+
target_dir=target_dir,
|
| 295 |
+
downsample_ratio=downsample_ratio / 100.0,
|
| 296 |
+
prediction_mode=prediction_mode,
|
| 297 |
+
)
|
| 298 |
+
glbscene.export(file_obj=glbfile)
|
| 299 |
+
|
| 300 |
+
# Cleanup
|
| 301 |
+
del predictions
|
| 302 |
+
gc.collect()
|
| 303 |
+
torch.cuda.empty_cache()
|
| 304 |
+
|
| 305 |
+
end_time = time.time()
|
| 306 |
+
print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
|
| 307 |
+
log_msg = (
|
| 308 |
+
f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
return (
|
| 312 |
+
glbfile,
|
| 313 |
+
log_msg,
|
| 314 |
+
gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
# -------------------------------------------------------------------------
|
| 319 |
+
# 5) Helper functions for UI resets + re-visualization
|
| 320 |
+
# -------------------------------------------------------------------------
|
| 321 |
+
def clear_fields():
|
| 322 |
+
"""
|
| 323 |
+
Clears the 3D viewer, the stored target_dir, and empties the gallery.
|
| 324 |
+
"""
|
| 325 |
+
return None
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def update_log():
|
| 329 |
+
"""
|
| 330 |
+
Display a quick log message while waiting.
|
| 331 |
+
"""
|
| 332 |
+
return "Loading and Reconstructing..."
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def update_visualization(
|
| 336 |
+
target_dir,
|
| 337 |
+
conf_thres,
|
| 338 |
+
frame_filter,
|
| 339 |
+
mask_black_bg,
|
| 340 |
+
mask_white_bg,
|
| 341 |
+
show_cam,
|
| 342 |
+
mask_sky,
|
| 343 |
+
downsample_ratio,
|
| 344 |
+
prediction_mode,
|
| 345 |
+
is_example,
|
| 346 |
+
):
|
| 347 |
+
"""
|
| 348 |
+
Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
|
| 349 |
+
and return it for the 3D viewer. If is_example == "True", skip.
|
| 350 |
+
"""
|
| 351 |
+
|
| 352 |
+
# If it's an example click, skip as requested
|
| 353 |
+
if is_example == "True":
|
| 354 |
+
return (
|
| 355 |
+
None,
|
| 356 |
+
"No reconstruction available. Please click the Reconstruct button first.",
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
|
| 360 |
+
return (
|
| 361 |
+
None,
|
| 362 |
+
"No reconstruction available. Please click the Reconstruct button first.",
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
predictions_path = os.path.join(target_dir, "predictions.npz")
|
| 366 |
+
if not os.path.exists(predictions_path):
|
| 367 |
+
return (
|
| 368 |
+
None,
|
| 369 |
+
f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.",
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
key_list = [
|
| 373 |
+
"pose_enc",
|
| 374 |
+
"depth",
|
| 375 |
+
"depth_conf",
|
| 376 |
+
"world_points",
|
| 377 |
+
"world_points_conf",
|
| 378 |
+
"images",
|
| 379 |
+
"extrinsic",
|
| 380 |
+
"intrinsic",
|
| 381 |
+
"world_points_from_depth",
|
| 382 |
+
]
|
| 383 |
+
|
| 384 |
+
loaded = np.load(predictions_path)
|
| 385 |
+
predictions = {key: np.array(loaded[key]) for key in key_list if key in loaded}
|
| 386 |
+
print(downsample_ratio)
|
| 387 |
+
glbfile = os.path.join(
|
| 388 |
+
target_dir,
|
| 389 |
+
f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_dr{downsample_ratio}_pred{prediction_mode.replace(' ', '_')}.glb",
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
if not os.path.exists(glbfile):
|
| 393 |
+
glbscene = predictions_to_glb(
|
| 394 |
+
predictions,
|
| 395 |
+
conf_thres=conf_thres,
|
| 396 |
+
filter_by_frames=frame_filter,
|
| 397 |
+
mask_black_bg=mask_black_bg,
|
| 398 |
+
mask_white_bg=mask_white_bg,
|
| 399 |
+
show_cam=show_cam,
|
| 400 |
+
mask_sky=mask_sky,
|
| 401 |
+
target_dir=target_dir,
|
| 402 |
+
downsample_ratio=downsample_ratio * 1.0 / 100.0,
|
| 403 |
+
prediction_mode=prediction_mode,
|
| 404 |
+
)
|
| 405 |
+
glbscene.export(file_obj=glbfile)
|
| 406 |
+
|
| 407 |
+
return glbfile, "Updating Visualization"
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
# -------------------------------------------------------------------------
|
| 411 |
+
# Example images
|
| 412 |
+
# -------------------------------------------------------------------------
|
| 413 |
+
|
| 414 |
+
great_wall_video = "examples/videos/great_wall.mp4"
|
| 415 |
+
colosseum_video = "examples/videos/Colosseum.mp4"
|
| 416 |
+
room_video = "examples/videos/room.mp4"
|
| 417 |
+
kitchen_video = "examples/videos/kitchen.mp4"
|
| 418 |
+
fern_video = "examples/videos/fern.mp4"
|
| 419 |
+
single_cartoon_video = "examples/videos/single_cartoon.mp4"
|
| 420 |
+
single_oil_painting_video = "examples/videos/single_oil_painting.mp4"
|
| 421 |
+
pyramid_video = "examples/videos/pyramid.mp4"
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
# -------------------------------------------------------------------------
|
| 425 |
+
# 6) Build Gradio UI
|
| 426 |
+
# -------------------------------------------------------------------------
|
| 427 |
+
theme = gr.themes.Ocean()
|
| 428 |
+
theme.set(
|
| 429 |
+
checkbox_label_background_fill_selected="*button_primary_background_fill",
|
| 430 |
+
checkbox_label_text_color_selected="*button_primary_text_color",
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
with gr.Blocks(
|
| 434 |
+
theme=theme,
|
| 435 |
+
css="""
|
| 436 |
+
.custom-log * {
|
| 437 |
+
font-style: italic;
|
| 438 |
+
font-size: 22px !important;
|
| 439 |
+
background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
|
| 440 |
+
-webkit-background-clip: text;
|
| 441 |
+
background-clip: text;
|
| 442 |
+
font-weight: bold !important;
|
| 443 |
+
color: transparent !important;
|
| 444 |
+
text-align: center !important;
|
| 445 |
+
}
|
| 446 |
+
|
| 447 |
+
.example-log * {
|
| 448 |
+
font-style: italic;
|
| 449 |
+
font-size: 16px !important;
|
| 450 |
+
background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
|
| 451 |
+
-webkit-background-clip: text;
|
| 452 |
+
background-clip: text;
|
| 453 |
+
color: transparent !important;
|
| 454 |
+
}
|
| 455 |
+
|
| 456 |
+
#my_radio .wrap {
|
| 457 |
+
display: flex;
|
| 458 |
+
flex-wrap: nowrap;
|
| 459 |
+
justify-content: center;
|
| 460 |
+
align-items: center;
|
| 461 |
+
}
|
| 462 |
+
|
| 463 |
+
#my_radio .wrap label {
|
| 464 |
+
display: flex;
|
| 465 |
+
width: 50%;
|
| 466 |
+
justify-content: center;
|
| 467 |
+
align-items: center;
|
| 468 |
+
margin: 0;
|
| 469 |
+
padding: 10px 0;
|
| 470 |
+
box-sizing: border-box;
|
| 471 |
+
}
|
| 472 |
+
""",
|
| 473 |
+
) as demo:
|
| 474 |
+
# Instead of gr.State, we use a hidden Textbox:
|
| 475 |
+
is_example = gr.Textbox(label="is_example", visible=False, value="None")
|
| 476 |
+
num_images = gr.Textbox(label="num_images", visible=False, value="None")
|
| 477 |
+
|
| 478 |
+
gr.HTML(
|
| 479 |
+
"""
|
| 480 |
+
<h1>🏛️ SAIL-Recon: Large SfM by Augmenting Scene Regression with Localization</h1>
|
| 481 |
+
<p>
|
| 482 |
+
<a href="https://github.com/HKUST-SAIL/sail-recon">🐙 GitHub Repository</a> |
|
| 483 |
+
<a href="https://hkust-sail.github.io/sail-recon/">Project Page</a>
|
| 484 |
+
</p>
|
| 485 |
+
|
| 486 |
+
<div style="font-size: 16px; line-height: 1.5;">
|
| 487 |
+
<p>Upload a video or a set of images to create a 3D reconstruction of a scene or object. SAIL-Recon takes these images and generates a 3D point cloud, along with estimated camera poses.</p>
|
| 488 |
+
|
| 489 |
+
<h3>Getting Started:</h3>
|
| 490 |
+
<ol>
|
| 491 |
+
<li><strong>Upload Your Data:</strong> Use the "Upload Video" or "Upload Images" buttons on the left to provide your input. Videos will be automatically split into individual frames (one frame per second).</li>
|
| 492 |
+
<li><strong>Preview:</strong> Your uploaded images will appear in the gallery on the left.</li>
|
| 493 |
+
<li><strong>Reconstruct:</strong> Click the "Reconstruct" button to start the 3D reconstruction process.</li>
|
| 494 |
+
<li><strong>Visualize:</strong> The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file. Note that the visualization of 3D points may be slow for a large number of input images.</li>
|
| 495 |
+
<li>
|
| 496 |
+
<strong>Adjust Visualization (Optional):</strong>
|
| 497 |
+
After reconstruction, you can fine-tune the visualization using the options below
|
| 498 |
+
<details style="display:inline;">
|
| 499 |
+
<summary style="display:inline;">(<strong>click to expand</strong>):</summary>
|
| 500 |
+
<ul>
|
| 501 |
+
<li><em>Confidence Threshold:</em> Adjust the filtering of points based on confidence.</li>
|
| 502 |
+
<li><em>Show Points from Frame:</em> Select specific frames to display in the point cloud.</li>
|
| 503 |
+
<li><em>Show Camera:</em> Toggle the display of estimated camera positions.</li>
|
| 504 |
+
<li><em>Filter Sky / Filter Black Background:</em> Remove sky or black-background points.</li>
|
| 505 |
+
<li><em>Select a Prediction Mode:</em> Choose between "Depthmap and Camera Branch" or "Pointmap Branch."</li>
|
| 506 |
+
</ul>
|
| 507 |
+
</details>
|
| 508 |
+
</li>
|
| 509 |
+
</ol>
|
| 510 |
+
<p><strong style="color: #0ea5e9;">Please note:</strong> <span style="color: #0ea5e9; font-weight: bold;">SAIL-Recon typically reconstructs a scene at 5FPS with full 3D attributes. However, visualizing 3D points may take tens of seconds due to third-party rendering, which is independent of SAIL-Recon's processing time. Using the 'demo.py' can provide much faster processing.</span></p>
|
| 511 |
+
</div>
|
| 512 |
+
"""
|
| 513 |
+
)
|
| 514 |
+
|
| 515 |
+
target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
|
| 516 |
+
|
| 517 |
+
with gr.Row():
|
| 518 |
+
with gr.Column(scale=2):
|
| 519 |
+
input_video = gr.Video(label="Upload Video", interactive=True)
|
| 520 |
+
input_images = gr.File(
|
| 521 |
+
file_count="multiple", label="Upload Images", interactive=True
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
image_gallery = gr.Gallery(
|
| 525 |
+
label="Preview",
|
| 526 |
+
columns=4,
|
| 527 |
+
height="300px",
|
| 528 |
+
show_download_button=True,
|
| 529 |
+
object_fit="contain",
|
| 530 |
+
preview=True,
|
| 531 |
+
)
|
| 532 |
+
|
| 533 |
+
with gr.Column(scale=4):
|
| 534 |
+
with gr.Column():
|
| 535 |
+
gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses)**")
|
| 536 |
+
log_output = gr.Markdown(
|
| 537 |
+
"Please upload a video or images, then click Reconstruct.",
|
| 538 |
+
elem_classes=["custom-log"],
|
| 539 |
+
)
|
| 540 |
+
reconstruction_output = gr.Model3D(
|
| 541 |
+
height=520, zoom_speed=0.5, pan_speed=0.5
|
| 542 |
+
)
|
| 543 |
+
|
| 544 |
+
with gr.Row():
|
| 545 |
+
submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
|
| 546 |
+
clear_btn = gr.ClearButton(
|
| 547 |
+
[
|
| 548 |
+
input_video,
|
| 549 |
+
input_images,
|
| 550 |
+
reconstruction_output,
|
| 551 |
+
log_output,
|
| 552 |
+
target_dir_output,
|
| 553 |
+
image_gallery,
|
| 554 |
+
],
|
| 555 |
+
scale=1,
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
with gr.Row():
|
| 559 |
+
prediction_mode = gr.Radio(
|
| 560 |
+
["Depthmap and Camera Branch", "Pointmap Branch"],
|
| 561 |
+
label="Select a Prediction Mode",
|
| 562 |
+
value="Depthmap and Camera Branch",
|
| 563 |
+
scale=1,
|
| 564 |
+
elem_id="my_radio",
|
| 565 |
+
)
|
| 566 |
+
|
| 567 |
+
with gr.Row():
|
| 568 |
+
conf_thres = gr.Slider(
|
| 569 |
+
minimum=0,
|
| 570 |
+
maximum=100,
|
| 571 |
+
value=50,
|
| 572 |
+
step=0.1,
|
| 573 |
+
label="Confidence Threshold (%)",
|
| 574 |
+
)
|
| 575 |
+
downsample_ratio = gr.Slider(
|
| 576 |
+
minimum=1.0,
|
| 577 |
+
maximum=100,
|
| 578 |
+
value=100,
|
| 579 |
+
step=0.1,
|
| 580 |
+
label="Downsample Ratio(%)",
|
| 581 |
+
)
|
| 582 |
+
frame_filter = gr.Dropdown(
|
| 583 |
+
choices=["All"], value="All", label="Show Points from Frame"
|
| 584 |
+
)
|
| 585 |
+
with gr.Column():
|
| 586 |
+
show_cam = gr.Checkbox(label="Show Camera", value=True)
|
| 587 |
+
mask_sky = gr.Checkbox(label="Filter Sky", value=False)
|
| 588 |
+
mask_black_bg = gr.Checkbox(
|
| 589 |
+
label="Filter Black Background", value=False
|
| 590 |
+
)
|
| 591 |
+
mask_white_bg = gr.Checkbox(
|
| 592 |
+
label="Filter White Background", value=False
|
| 593 |
+
)
|
| 594 |
+
|
| 595 |
+
# ---------------------- Examples section ----------------------
|
| 596 |
+
examples = [
|
| 597 |
+
[
|
| 598 |
+
colosseum_video,
|
| 599 |
+
"22",
|
| 600 |
+
None,
|
| 601 |
+
20.0,
|
| 602 |
+
False,
|
| 603 |
+
False,
|
| 604 |
+
True,
|
| 605 |
+
False,
|
| 606 |
+
"Depthmap and Camera Branch",
|
| 607 |
+
"True",
|
| 608 |
+
],
|
| 609 |
+
[
|
| 610 |
+
pyramid_video,
|
| 611 |
+
"30",
|
| 612 |
+
None,
|
| 613 |
+
35.0,
|
| 614 |
+
False,
|
| 615 |
+
False,
|
| 616 |
+
True,
|
| 617 |
+
False,
|
| 618 |
+
"Depthmap and Camera Branch",
|
| 619 |
+
"True",
|
| 620 |
+
],
|
| 621 |
+
[
|
| 622 |
+
single_cartoon_video,
|
| 623 |
+
"1",
|
| 624 |
+
None,
|
| 625 |
+
15.0,
|
| 626 |
+
False,
|
| 627 |
+
False,
|
| 628 |
+
True,
|
| 629 |
+
False,
|
| 630 |
+
"Depthmap and Camera Branch",
|
| 631 |
+
"True",
|
| 632 |
+
],
|
| 633 |
+
[
|
| 634 |
+
single_oil_painting_video,
|
| 635 |
+
"1",
|
| 636 |
+
None,
|
| 637 |
+
20.0,
|
| 638 |
+
False,
|
| 639 |
+
False,
|
                True,
                True,
                "Depthmap and Camera Branch",
                "True",
            ],
            [
                room_video,
                "8",
                None,
                5.0,
                False,
                False,
                True,
                False,
                "Depthmap and Camera Branch",
                "True",
            ],
            [
                kitchen_video,
                "25",
                None,
                50.0,
                False,
                False,
                True,
                False,
                "Depthmap and Camera Branch",
                "True",
            ],
            [
                fern_video,
                "20",
                None,
                45.0,
                False,
                False,
                True,
                False,
                "Depthmap and Camera Branch",
                "True",
            ],
        ]

    def example_pipeline(
        input_video,
        num_images_str,
        input_images,
        conf_thres,
        mask_black_bg,
        mask_white_bg,
        show_cam,
        mask_sky,
        downsample_ratio,
        prediction_mode,
        is_example_str,
    ):
        """
        1) Copy example images to new target_dir
        2) Reconstruct
        3) Return model3D + logs + new_dir + updated dropdown + gallery
        We do NOT return is_example. It's just an input.
        """
        target_dir, image_paths = handle_uploads(input_video, input_images)
        # Always use "All" for frame_filter in examples
        frame_filter = "All"
        glbfile, log_msg, dropdown = gradio_demo(
            target_dir,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            downsample_ratio,
            prediction_mode,
        )
        return glbfile, log_msg, target_dir, dropdown, image_paths

    gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])

    # gr.Examples(
    #     examples=examples,
    #     inputs=[
    #         input_video,
    #         num_images,
    #         input_images,
    #         conf_thres,
    #         mask_black_bg,
    #         mask_white_bg,
    #         show_cam,
    #         mask_sky,
    #         downsample_ratio,
    #         prediction_mode,
    #         is_example,
    #     ],
    #     outputs=[reconstruction_output, log_output, target_dir_output, frame_filter, image_gallery],
    #     fn=example_pipeline,
    #     cache_examples=False,
    #     examples_per_page=50,
    # )

    # -------------------------------------------------------------------------
    # "Reconstruct" button logic:
    # - Clear fields
    # - Update log
    # - gradio_demo(...) with the existing target_dir
    # - Then set is_example = "False"
    # -------------------------------------------------------------------------
    submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then(
        fn=update_log, inputs=[], outputs=[log_output]
    ).then(
        fn=gradio_demo,
        inputs=[
            target_dir_output,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            downsample_ratio,
            prediction_mode,
        ],
        outputs=[reconstruction_output, log_output, frame_filter],
    ).then(
        fn=lambda: "False", inputs=[], outputs=[is_example]  # set is_example to "False"
    )

    # -------------------------------------------------------------------------
    # Real-time Visualization Updates
    # -------------------------------------------------------------------------
    conf_thres.change(
        update_visualization,
        [
            target_dir_output,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            downsample_ratio,
            prediction_mode,
            is_example,
        ],
        [reconstruction_output, log_output],
    )
    downsample_ratio.change(
        update_visualization,
        [
            target_dir_output,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            downsample_ratio,
            prediction_mode,
            is_example,
        ],
        [reconstruction_output, log_output],
    )
    frame_filter.change(
        update_visualization,
        [
            target_dir_output,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            downsample_ratio,
            prediction_mode,
            is_example,
        ],
        [reconstruction_output, log_output],
    )
    mask_black_bg.change(
        update_visualization,
        [
            target_dir_output,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            downsample_ratio,
            prediction_mode,
            is_example,
        ],
        [reconstruction_output, log_output],
    )
    mask_white_bg.change(
        update_visualization,
        [
            target_dir_output,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            downsample_ratio,
            prediction_mode,
            is_example,
        ],
        [reconstruction_output, log_output],
    )
    show_cam.change(
        update_visualization,
        [
            target_dir_output,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            downsample_ratio,
            prediction_mode,
            is_example,
        ],
        [reconstruction_output, log_output],
    )
    mask_sky.change(
        update_visualization,
        [
            target_dir_output,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            downsample_ratio,
            prediction_mode,
            is_example,
        ],
        [reconstruction_output, log_output],
    )
    prediction_mode.change(
        update_visualization,
        [
            target_dir_output,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            downsample_ratio,
            prediction_mode,
            is_example,
        ],
        [reconstruction_output, log_output],
    )

    # -------------------------------------------------------------------------
    # Auto-update gallery whenever user uploads or changes their files
    # -------------------------------------------------------------------------
    input_video.change(
        fn=update_gallery_on_upload,
        inputs=[input_video, input_images],
        outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
    )
    input_images.change(
        fn=update_gallery_on_upload,
        inputs=[input_video, input_images],
        outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
    )

demo.queue(max_size=20).launch(show_error=True, share=True)
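The eight .change() handlers above differ only in the component that triggers them. A minimal refactoring sketch (not part of this commit; it assumes the same component variables are in scope inside the Blocks context) that registers the identical wiring in a loop:

# Sketch only: equivalent to the repeated .change() calls above.
viz_inputs = [
    target_dir_output, conf_thres, frame_filter, mask_black_bg, mask_white_bg,
    show_cam, mask_sky, downsample_ratio, prediction_mode, is_example,
]
viz_outputs = [reconstruction_output, log_output]
for trigger in [
    conf_thres, downsample_ratio, frame_filter, mask_black_bg,
    mask_white_bg, show_cam, mask_sky, prediction_mode,
]:
    trigger.change(update_visualization, viz_inputs, viz_outputs)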
|
demo.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
# Copyright (c) HKUST SAIL-Lab and Horizon Robotics.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import argparse
import os

import torch
from tqdm import tqdm

from eval.utils.device import to_cpu
from eval.utils.eval_utils import uniform_sample
from sailrecon.models.sail_recon import SailRecon
from sailrecon.utils.load_fn import load_and_preprocess_images

device = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+)
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16


def demo(args):
    # Initialize the model and load the pretrained weights.
    # This will automatically download the model weights the first time it's run, which may take a while.
    _URL = "https://huggingface.co/HKUST-SAIL/SAIL-Recon/resolve/main/sailrecon.pt"
    model_dir = args.ckpt
    # model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
    model = SailRecon(kv_cache=True)
    if model_dir is not None:
        model.load_state_dict(torch.load(model_dir))
    else:
        model.load_state_dict(
            torch.hub.load_state_dict_from_url(_URL, model_dir=model_dir)
        )
    model = model.to(device=device)
    model.eval()

    # Load and preprocess example images
    scene_name = "1"
    if args.vid_dir is not None:
        import cv2

        image_names = []
        video_path = args.vid_dir
        vs = cv2.VideoCapture(video_path)
        fps = vs.get(cv2.CAP_PROP_FPS)
        tmp_file = os.path.join("tmp_video", os.path.basename(video_path).split(".")[0])
        os.makedirs(tmp_file, exist_ok=True)
        count = 0
        video_frame_num = 0
        while True:
            gotit, frame = vs.read()
            if not gotit:
                break
            count += 1
            image_path = os.path.join(tmp_file, f"{video_frame_num:06}.png")
            cv2.imwrite(image_path, frame)
            image_names.append(image_path)
            video_frame_num += 1
        images = load_and_preprocess_images(image_names).to(device)
        scene_name = os.path.basename(video_path).split(".")[0]
    else:
        image_names = os.listdir(args.img_dir)
        image_names = [os.path.join(args.img_dir, f) for f in sorted(image_names)]
        images = load_and_preprocess_images(image_names).to(device)
        scene_name = os.path.basename(args.img_dir)

    # anchor image selection
    select_indices = uniform_sample(len(image_names), min(100, len(image_names)))
    anchor_images = images[select_indices]

    os.makedirs(os.path.join(args.out_dir, scene_name), exist_ok=True)

    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=dtype):
            # processing anchor images to build scene representation (kv_cache)
            print("Processing anchor images ...")
            model.tmp_forward(anchor_images)
            # remove the global transformer blocks to save memory during relocalization
            del model.aggregator.global_blocks
            # relocalization on all images
            predictions = []

            with tqdm(total=len(image_names), desc="Relocalizing") as pbar:
                for img_split in images.split(20, dim=0):
                    pbar.update(20)
                    predictions += to_cpu(model.reloc(img_split))

    # save the predicted point cloud and camera poses

    from eval.utils.geometry import save_pointcloud_with_plyfile

    save_pointcloud_with_plyfile(
        predictions, os.path.join(args.out_dir, scene_name, "pred.ply")
    )

    import numpy as np

    from eval.utils.eval_utils import save_kitti_poses

    poses_w2c_estimated = [
        one_result["extrinsic"][0].cpu().numpy() for one_result in predictions
    ]
    poses_c2w_estimated = [
        np.linalg.inv(np.vstack([pose, np.array([0, 0, 0, 1])]))
        for pose in poses_w2c_estimated
    ]

    save_kitti_poses(
        poses_c2w_estimated,
        os.path.join(args.out_dir, scene_name, "pred.txt"),
    )


if __name__ == "__main__":
    args = argparse.ArgumentParser()
    args.add_argument(
        "--img_dir", type=str, default="samples/kitchen", help="input image folder"
    )
    args.add_argument("--vid_dir", type=str, default=None, help="input video path")
    args.add_argument("--out_dir", type=str, default="outputs", help="output folder")
    args.add_argument(
        "--ckpt", type=str, default=None, help="pretrained model checkpoint"
    )
    args = args.parse_args()
    demo(args)
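For each scene, demo.py writes the point cloud to pred.ply and the camera-to-world poses to pred.txt via save_kitti_poses. A minimal sketch for loading those poses back, assuming the usual KITTI text convention of one flattened row-major 3x4 [R|t] matrix per line (which matches how the 4x4 poses are assembled above); the path follows the default --img_dir/--out_dir arguments:

import numpy as np

# Each line of pred.txt is assumed to hold 12 values: a row-major 3x4 cam-to-world matrix.
rows = np.loadtxt("outputs/kitchen/pred.txt").reshape(-1, 3, 4)
bottom = np.tile(np.array([[[0.0, 0.0, 0.0, 1.0]]]), (rows.shape[0], 1, 1))
poses_c2w = np.concatenate([rows, bottom], axis=1)  # (N, 4, 4)
print(poses_c2w.shape)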
|
demo_gradio.py
ADDED
|
@@ -0,0 +1,921 @@
|
|
| 1 |
+
# Copyright (c) HKUST SAIL-Lab and Horizon Robotics.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 7 |
+
#
|
| 8 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 9 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 10 |
+
|
| 11 |
+
import gc
|
| 12 |
+
import glob
|
| 13 |
+
import os
|
| 14 |
+
import shutil
|
| 15 |
+
import sys
|
| 16 |
+
import time
|
| 17 |
+
from datetime import datetime
|
| 18 |
+
|
| 19 |
+
import cv2
|
| 20 |
+
import gradio as gr
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
from tqdm import tqdm
|
| 24 |
+
|
| 25 |
+
from eval.utils.device import to_cpu
|
| 26 |
+
from eval.utils.eval_utils import uniform_sample
|
| 27 |
+
from sailrecon.models.sail_recon import SailRecon
|
| 28 |
+
from sailrecon.utils.geometry import unproject_depth_map_to_point_map
|
| 29 |
+
from sailrecon.utils.load_fn import load_and_preprocess_images
|
| 30 |
+
from sailrecon.utils.pose_enc import (
|
| 31 |
+
extri_intri_to_pose_encoding,
|
| 32 |
+
pose_encoding_to_extri_intri,
|
| 33 |
+
)
|
| 34 |
+
from visual_util import predictions_to_glb
|
| 35 |
+
|
| 36 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 37 |
+
|
| 38 |
+
print("Initializing and loading SailRecon model...")
|
| 39 |
+
|
| 40 |
+
model = SailRecon(kv_cache=True)
|
| 41 |
+
# _URL = "https://huggingface.co/HKUST-SAIL/SAIL-Recon/resolve/main/sailrecon.pt"
|
| 42 |
+
# model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
|
| 43 |
+
model_dir = "ckpt/sailrecon.pt"
|
| 44 |
+
model.load_state_dict(torch.load(model_dir))
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
model.eval()
|
| 48 |
+
model = model.to(device)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# -------------------------------------------------------------------------
|
| 52 |
+
# 1) Core model inference
|
| 53 |
+
# -------------------------------------------------------------------------
|
| 54 |
+
def run_model(target_dir, model, anchor_size=100) -> dict:
|
| 55 |
+
"""
|
| 56 |
+
Run the SAIL-Recon model on images in the 'target_dir/images' folder and return predictions.
|
| 57 |
+
"""
|
| 58 |
+
print(f"Processing images from {target_dir}")
|
| 59 |
+
|
| 60 |
+
# Device check
|
| 61 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 62 |
+
if not torch.cuda.is_available():
|
| 63 |
+
raise ValueError("CUDA is not available. Check your environment.")
|
| 64 |
+
|
| 65 |
+
# Move model to device
|
| 66 |
+
model = model.to(device)
|
| 67 |
+
model.eval()
|
| 68 |
+
|
| 69 |
+
# Load and preprocess images
|
| 70 |
+
image_names = glob.glob(os.path.join(target_dir, "images", "*"))
|
| 71 |
+
image_names = sorted(image_names)
|
| 72 |
+
print(f"Found {len(image_names)} images")
|
| 73 |
+
if len(image_names) == 0:
|
| 74 |
+
raise ValueError("No images found. Check your upload.")
|
| 75 |
+
|
| 76 |
+
images = load_and_preprocess_images(image_names).to(device)
|
| 77 |
+
print(f"Preprocessed images shape: {images.shape}")
|
| 78 |
+
# anchor image selection
|
| 79 |
+
select_indices = uniform_sample(len(image_names), min(100, len(image_names)))
|
| 80 |
+
anchor_images = images[select_indices]
|
| 81 |
+
|
| 82 |
+
# Run inference
|
| 83 |
+
print("Running inference...")
|
| 84 |
+
dtype = (
|
| 85 |
+
torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
with torch.no_grad():
|
| 89 |
+
with torch.cuda.amp.autocast(dtype=dtype):
|
| 90 |
+
print("Processing anchor images ...")
|
| 91 |
+
model.tmp_forward(anchor_images)
|
| 92 |
+
# del model.aggregator.global_blocks
|
| 93 |
+
# relocalization on all images
|
| 94 |
+
predictions_s = []
|
| 95 |
+
with tqdm(total=len(image_names), desc="Relocalizing") as pbar:
|
| 96 |
+
for img_split in images.split(10, dim=0):
|
| 97 |
+
pbar.update(10)
|
| 98 |
+
predictions_s += to_cpu(
|
| 99 |
+
model.reloc(img_split, ret_img=True, memory_save=False)
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
predictions = {}
|
| 103 |
+
predictions["extrinsic"] = torch.cat(
|
| 104 |
+
[s["extrinsic"] for s in predictions_s], dim=0
|
| 105 |
+
) # (S, 4, 4)
|
| 106 |
+
predictions["intrinsic"] = torch.cat(
|
| 107 |
+
[s["intrinsic"] for s in predictions_s], dim=0
|
| 108 |
+
) # (S, 4, 4)
|
| 109 |
+
predictions["depth"] = torch.cat(
|
| 110 |
+
[s["depth_map"] for s in predictions_s], dim=0
|
| 111 |
+
) # (S, H, W, 1)
|
| 112 |
+
predictions["depth_conf"] = torch.cat(
|
| 113 |
+
[s["dpt_cnf"] for s in predictions_s], dim=0
|
| 114 |
+
) # (S, H, W, 1)
|
| 115 |
+
predictions["images"] = torch.cat(
|
| 116 |
+
[s["images"] for s in predictions_s], dim=0
|
| 117 |
+
) # (S, H, W, 3)
|
| 118 |
+
predictions["world_points"] = torch.cat(
|
| 119 |
+
[s["point_map"] for s in predictions_s], dim=0
|
| 120 |
+
) # (S, H, W, 3)
|
| 121 |
+
predictions["world_points_conf"] = torch.cat(
|
| 122 |
+
[s["xyz_cnf"] for s in predictions_s], dim=0
|
| 123 |
+
) # (S, H, W, 3)
|
| 124 |
+
predictions["pose_enc"] = extri_intri_to_pose_encoding(
|
| 125 |
+
predictions["extrinsic"].unsqueeze(0),
|
| 126 |
+
predictions["intrinsic"].unsqueeze(0),
|
| 127 |
+
images.shape[-2:],
|
| 128 |
+
)[
|
| 129 |
+
0
|
| 130 |
+
] # a
|
| 131 |
+
del predictions_s
|
| 132 |
+
|
| 133 |
+
# Convert tensors to numpy
|
| 134 |
+
for key in predictions.keys():
|
| 135 |
+
if isinstance(predictions[key], torch.Tensor):
|
| 136 |
+
predictions[key] = predictions[key].cpu().numpy() # remove batch dimension
|
| 137 |
+
predictions["pose_enc_list"] = None # remove pose_enc_list
|
| 138 |
+
|
| 139 |
+
# Generate world points from depth map
|
| 140 |
+
print("Computing world points from depth map...")
|
| 141 |
+
depth_map = predictions["depth"] # (S, H, W, 1)
|
| 142 |
+
world_points = unproject_depth_map_to_point_map(
|
| 143 |
+
depth_map, predictions["extrinsic"], predictions["intrinsic"]
|
| 144 |
+
)
|
| 145 |
+
predictions["world_points_from_depth"] = world_points
|
| 146 |
+
|
| 147 |
+
# Clean up
|
| 148 |
+
torch.cuda.empty_cache()
|
| 149 |
+
return predictions
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# -------------------------------------------------------------------------
|
| 153 |
+
# 2) Handle uploaded video/images --> produce target_dir + images
|
| 154 |
+
# -------------------------------------------------------------------------
|
| 155 |
+
def handle_uploads(input_video, input_images):
|
| 156 |
+
"""
|
| 157 |
+
Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
|
| 158 |
+
images or extracted frames from video into it. Return (target_dir, image_paths).
|
| 159 |
+
"""
|
| 160 |
+
start_time = time.time()
|
| 161 |
+
gc.collect()
|
| 162 |
+
torch.cuda.empty_cache()
|
| 163 |
+
|
| 164 |
+
# Create a unique folder name
|
| 165 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
| 166 |
+
target_dir = f"input_images_{timestamp}"
|
| 167 |
+
target_dir_images = os.path.join(target_dir, "images")
|
| 168 |
+
|
| 169 |
+
# Clean up if somehow that folder already exists
|
| 170 |
+
if os.path.exists(target_dir):
|
| 171 |
+
shutil.rmtree(target_dir)
|
| 172 |
+
os.makedirs(target_dir)
|
| 173 |
+
os.makedirs(target_dir_images)
|
| 174 |
+
|
| 175 |
+
image_paths = []
|
| 176 |
+
|
| 177 |
+
# --- Handle images ---
|
| 178 |
+
if input_images is not None:
|
| 179 |
+
for file_data in input_images:
|
| 180 |
+
if isinstance(file_data, dict) and "name" in file_data:
|
| 181 |
+
file_path = file_data["name"]
|
| 182 |
+
else:
|
| 183 |
+
file_path = file_data
|
| 184 |
+
dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
|
| 185 |
+
shutil.copy(file_path, dst_path)
|
| 186 |
+
image_paths.append(dst_path)
|
| 187 |
+
|
| 188 |
+
# --- Handle video ---
|
| 189 |
+
if input_video is not None:
|
| 190 |
+
if isinstance(input_video, dict) and "name" in input_video:
|
| 191 |
+
video_path = input_video["name"]
|
| 192 |
+
else:
|
| 193 |
+
video_path = input_video
|
| 194 |
+
|
| 195 |
+
vs = cv2.VideoCapture(video_path)
|
| 196 |
+
fps = vs.get(cv2.CAP_PROP_FPS)
|
| 197 |
+
|
| 198 |
+
count = 0
|
| 199 |
+
video_frame_num = 0
|
| 200 |
+
while True:
|
| 201 |
+
gotit, frame = vs.read()
|
| 202 |
+
if not gotit:
|
| 203 |
+
break
|
| 204 |
+
count += 1
|
| 205 |
+
image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png")
|
| 206 |
+
cv2.imwrite(image_path, frame)
|
| 207 |
+
image_paths.append(image_path)
|
| 208 |
+
video_frame_num += 1
|
| 209 |
+
|
| 210 |
+
# Sort final images for gallery
|
| 211 |
+
image_paths = sorted(image_paths)
|
| 212 |
+
|
| 213 |
+
end_time = time.time()
|
| 214 |
+
print(
|
| 215 |
+
f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds"
|
| 216 |
+
)
|
| 217 |
+
return target_dir, image_paths
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# -------------------------------------------------------------------------
|
| 221 |
+
# 3) Update gallery on upload
|
| 222 |
+
# -------------------------------------------------------------------------
|
| 223 |
+
def update_gallery_on_upload(input_video, input_images):
|
| 224 |
+
"""
|
| 225 |
+
Whenever user uploads or changes files, immediately handle them
|
| 226 |
+
and show in the gallery. Return (target_dir, image_paths).
|
| 227 |
+
If nothing is uploaded, returns "None" and empty list.
|
| 228 |
+
"""
|
| 229 |
+
if not input_video and not input_images:
|
| 230 |
+
return None, None, None, None
|
| 231 |
+
target_dir, image_paths = handle_uploads(input_video, input_images)
|
| 232 |
+
return (
|
| 233 |
+
None,
|
| 234 |
+
target_dir,
|
| 235 |
+
image_paths,
|
| 236 |
+
"Upload complete. Click 'Reconstruct' to begin 3D processing.",
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
# -------------------------------------------------------------------------
|
| 241 |
+
# 4) Reconstruction: uses the target_dir plus any viz parameters
|
| 242 |
+
# -------------------------------------------------------------------------
|
| 243 |
+
def gradio_demo(
|
| 244 |
+
target_dir,
|
| 245 |
+
conf_thres=3.0,
|
| 246 |
+
frame_filter="All",
|
| 247 |
+
mask_black_bg=False,
|
| 248 |
+
mask_white_bg=False,
|
| 249 |
+
show_cam=True,
|
| 250 |
+
mask_sky=False,
|
| 251 |
+
downsample_ratio=100.0,
|
| 252 |
+
prediction_mode="Pointmap Regression",
|
| 253 |
+
):
|
| 254 |
+
"""
|
| 255 |
+
Perform reconstruction using the already-created target_dir/images.
|
| 256 |
+
"""
|
| 257 |
+
if not os.path.isdir(target_dir) or target_dir == "None":
|
| 258 |
+
return None, "No valid target directory found. Please upload first.", None, None
|
| 259 |
+
|
| 260 |
+
start_time = time.time()
|
| 261 |
+
gc.collect()
|
| 262 |
+
torch.cuda.empty_cache()
|
| 263 |
+
|
| 264 |
+
# Prepare frame_filter dropdown
|
| 265 |
+
target_dir_images = os.path.join(target_dir, "images")
|
| 266 |
+
all_files = (
|
| 267 |
+
sorted(os.listdir(target_dir_images))
|
| 268 |
+
if os.path.isdir(target_dir_images)
|
| 269 |
+
else []
|
| 270 |
+
)
|
| 271 |
+
all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
|
| 272 |
+
frame_filter_choices = ["All"] + all_files
|
| 273 |
+
|
| 274 |
+
print("Running run_model...")
|
| 275 |
+
with torch.no_grad():
|
| 276 |
+
predictions = run_model(target_dir, model)
|
| 277 |
+
|
| 278 |
+
# Save predictions
|
| 279 |
+
prediction_save_path = os.path.join(target_dir, "predictions.npz")
|
| 280 |
+
np.savez(prediction_save_path, **predictions)
|
| 281 |
+
|
| 282 |
+
# Handle None frame_filter
|
| 283 |
+
if frame_filter is None:
|
| 284 |
+
frame_filter = "All"
|
| 285 |
+
|
| 286 |
+
# Build a GLB file name
|
| 287 |
+
glbfile = os.path.join(
|
| 288 |
+
target_dir,
|
| 289 |
+
f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
# Convert predictions to GLB
|
| 293 |
+
glbscene = predictions_to_glb(
|
| 294 |
+
predictions,
|
| 295 |
+
conf_thres=conf_thres,
|
| 296 |
+
filter_by_frames=frame_filter,
|
| 297 |
+
mask_black_bg=mask_black_bg,
|
| 298 |
+
mask_white_bg=mask_white_bg,
|
| 299 |
+
show_cam=show_cam,
|
| 300 |
+
mask_sky=mask_sky,
|
| 301 |
+
target_dir=target_dir,
|
| 302 |
+
downsample_ratio=downsample_ratio / 100.0,
|
| 303 |
+
prediction_mode=prediction_mode,
|
| 304 |
+
)
|
| 305 |
+
glbscene.export(file_obj=glbfile)
|
| 306 |
+
|
| 307 |
+
# Cleanup
|
| 308 |
+
del predictions
|
| 309 |
+
gc.collect()
|
| 310 |
+
torch.cuda.empty_cache()
|
| 311 |
+
|
| 312 |
+
end_time = time.time()
|
| 313 |
+
print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
|
| 314 |
+
log_msg = (
|
| 315 |
+
f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
return (
|
| 319 |
+
glbfile,
|
| 320 |
+
log_msg,
|
| 321 |
+
gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
# -------------------------------------------------------------------------
|
| 326 |
+
# 5) Helper functions for UI resets + re-visualization
|
| 327 |
+
# -------------------------------------------------------------------------
|
| 328 |
+
def clear_fields():
|
| 329 |
+
"""
|
| 330 |
+
Clears the 3D viewer, the stored target_dir, and empties the gallery.
|
| 331 |
+
"""
|
| 332 |
+
return None
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def update_log():
|
| 336 |
+
"""
|
| 337 |
+
Display a quick log message while waiting.
|
| 338 |
+
"""
|
| 339 |
+
return "Loading and Reconstructing..."
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def update_visualization(
|
| 343 |
+
target_dir,
|
| 344 |
+
conf_thres,
|
| 345 |
+
frame_filter,
|
| 346 |
+
mask_black_bg,
|
| 347 |
+
mask_white_bg,
|
| 348 |
+
show_cam,
|
| 349 |
+
mask_sky,
|
| 350 |
+
downsample_ratio,
|
| 351 |
+
prediction_mode,
|
| 352 |
+
is_example,
|
| 353 |
+
):
|
| 354 |
+
"""
|
| 355 |
+
Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
|
| 356 |
+
and return it for the 3D viewer. If is_example == "True", skip.
|
| 357 |
+
"""
|
| 358 |
+
|
| 359 |
+
# If it's an example click, skip as requested
|
| 360 |
+
if is_example == "True":
|
| 361 |
+
return (
|
| 362 |
+
None,
|
| 363 |
+
"No reconstruction available. Please click the Reconstruct button first.",
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
|
| 367 |
+
return (
|
| 368 |
+
None,
|
| 369 |
+
"No reconstruction available. Please click the Reconstruct button first.",
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
predictions_path = os.path.join(target_dir, "predictions.npz")
|
| 373 |
+
if not os.path.exists(predictions_path):
|
| 374 |
+
return (
|
| 375 |
+
None,
|
| 376 |
+
f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.",
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
key_list = [
|
| 380 |
+
"pose_enc",
|
| 381 |
+
"depth",
|
| 382 |
+
"depth_conf",
|
| 383 |
+
"world_points",
|
| 384 |
+
"world_points_conf",
|
| 385 |
+
"images",
|
| 386 |
+
"extrinsic",
|
| 387 |
+
"intrinsic",
|
| 388 |
+
"world_points_from_depth",
|
| 389 |
+
]
|
| 390 |
+
|
| 391 |
+
loaded = np.load(predictions_path)
|
| 392 |
+
predictions = {key: np.array(loaded[key]) for key in key_list if key in loaded}
|
| 393 |
+
print(downsample_ratio)
|
| 394 |
+
glbfile = os.path.join(
|
| 395 |
+
target_dir,
|
| 396 |
+
f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_dr{downsample_ratio}_pred{prediction_mode.replace(' ', '_')}.glb",
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
if not os.path.exists(glbfile):
|
| 400 |
+
glbscene = predictions_to_glb(
|
| 401 |
+
predictions,
|
| 402 |
+
conf_thres=conf_thres,
|
| 403 |
+
filter_by_frames=frame_filter,
|
| 404 |
+
mask_black_bg=mask_black_bg,
|
| 405 |
+
mask_white_bg=mask_white_bg,
|
| 406 |
+
show_cam=show_cam,
|
| 407 |
+
mask_sky=mask_sky,
|
| 408 |
+
target_dir=target_dir,
|
| 409 |
+
downsample_ratio=downsample_ratio * 1.0 / 100.0,
|
| 410 |
+
prediction_mode=prediction_mode,
|
| 411 |
+
)
|
| 412 |
+
glbscene.export(file_obj=glbfile)
|
| 413 |
+
|
| 414 |
+
return glbfile, "Updating Visualization"
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
# -------------------------------------------------------------------------
|
| 418 |
+
# Example images
|
| 419 |
+
# -------------------------------------------------------------------------
|
| 420 |
+
|
| 421 |
+
great_wall_video = "examples/videos/great_wall.mp4"
|
| 422 |
+
colosseum_video = "examples/videos/Colosseum.mp4"
|
| 423 |
+
room_video = "examples/videos/room.mp4"
|
| 424 |
+
kitchen_video = "examples/videos/kitchen.mp4"
|
| 425 |
+
fern_video = "examples/videos/fern.mp4"
|
| 426 |
+
single_cartoon_video = "examples/videos/single_cartoon.mp4"
|
| 427 |
+
single_oil_painting_video = "examples/videos/single_oil_painting.mp4"
|
| 428 |
+
pyramid_video = "examples/videos/pyramid.mp4"
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
# -------------------------------------------------------------------------
|
| 432 |
+
# 6) Build Gradio UI
|
| 433 |
+
# -------------------------------------------------------------------------
|
| 434 |
+
theme = gr.themes.Ocean()
|
| 435 |
+
theme.set(
|
| 436 |
+
checkbox_label_background_fill_selected="*button_primary_background_fill",
|
| 437 |
+
checkbox_label_text_color_selected="*button_primary_text_color",
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
with gr.Blocks(
|
| 441 |
+
theme=theme,
|
| 442 |
+
css="""
|
| 443 |
+
.custom-log * {
|
| 444 |
+
font-style: italic;
|
| 445 |
+
font-size: 22px !important;
|
| 446 |
+
background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
|
| 447 |
+
-webkit-background-clip: text;
|
| 448 |
+
background-clip: text;
|
| 449 |
+
font-weight: bold !important;
|
| 450 |
+
color: transparent !important;
|
| 451 |
+
text-align: center !important;
|
| 452 |
+
}
|
| 453 |
+
|
| 454 |
+
.example-log * {
|
| 455 |
+
font-style: italic;
|
| 456 |
+
font-size: 16px !important;
|
| 457 |
+
background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
|
| 458 |
+
-webkit-background-clip: text;
|
| 459 |
+
background-clip: text;
|
| 460 |
+
color: transparent !important;
|
| 461 |
+
}
|
| 462 |
+
|
| 463 |
+
#my_radio .wrap {
|
| 464 |
+
display: flex;
|
| 465 |
+
flex-wrap: nowrap;
|
| 466 |
+
justify-content: center;
|
| 467 |
+
align-items: center;
|
| 468 |
+
}
|
| 469 |
+
|
| 470 |
+
#my_radio .wrap label {
|
| 471 |
+
display: flex;
|
| 472 |
+
width: 50%;
|
| 473 |
+
justify-content: center;
|
| 474 |
+
align-items: center;
|
| 475 |
+
margin: 0;
|
| 476 |
+
padding: 10px 0;
|
| 477 |
+
box-sizing: border-box;
|
| 478 |
+
}
|
| 479 |
+
""",
|
| 480 |
+
) as demo:
|
| 481 |
+
# Instead of gr.State, we use a hidden Textbox:
|
| 482 |
+
is_example = gr.Textbox(label="is_example", visible=False, value="None")
|
| 483 |
+
num_images = gr.Textbox(label="num_images", visible=False, value="None")
|
| 484 |
+
|
| 485 |
+
gr.HTML(
|
| 486 |
+
"""
|
| 487 |
+
<h1>🏛️ SAIL-Recon: Large SfM by Augmenting Scene Regression with Localization</h1>
|
| 488 |
+
<p>
|
| 489 |
+
<a href="https://github.com/HKUST-SAIL/sail-recon">🐙 GitHub Repository</a> |
|
| 490 |
+
<a href="https://hkust-sail.github.io/sail-recon/">Project Page</a>
|
| 491 |
+
</p>
|
| 492 |
+
|
| 493 |
+
<div style="font-size: 16px; line-height: 1.5;">
|
| 494 |
+
<p>Upload a video or a set of images to create a 3D reconstruction of a scene or object. SAIL-Recon takes these images and generates a 3D point cloud, along with estimated camera poses.</p>
|
| 495 |
+
|
| 496 |
+
<h3>Getting Started:</h3>
|
| 497 |
+
<ol>
|
| 498 |
+
<li><strong>Upload Your Data:</strong> Use the "Upload Video" or "Upload Images" buttons on the left to provide your input. Videos will be automatically split into individual frames (one frame per second).</li>
|
| 499 |
+
<li><strong>Preview:</strong> Your uploaded images will appear in the gallery on the left.</li>
|
| 500 |
+
<li><strong>Reconstruct:</strong> Click the "Reconstruct" button to start the 3D reconstruction process.</li>
|
| 501 |
+
<li><strong>Visualize:</strong> The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file. Note that the visualization of 3D points may be slow for a large number of input images.</li>
|
| 502 |
+
<li>
|
| 503 |
+
<strong>Adjust Visualization (Optional):</strong>
|
| 504 |
+
After reconstruction, you can fine-tune the visualization using the options below
|
| 505 |
+
<details style="display:inline;">
|
| 506 |
+
<summary style="display:inline;">(<strong>click to expand</strong>):</summary>
|
| 507 |
+
<ul>
|
| 508 |
+
<li><em>Confidence Threshold:</em> Adjust the filtering of points based on confidence.</li>
|
| 509 |
+
<li><em>Show Points from Frame:</em> Select specific frames to display in the point cloud.</li>
|
| 510 |
+
<li><em>Show Camera:</em> Toggle the display of estimated camera positions.</li>
|
| 511 |
+
<li><em>Filter Sky / Filter Black Background:</em> Remove sky or black-background points.</li>
|
| 512 |
+
<li><em>Select a Prediction Mode:</em> Choose between "Depthmap and Camera Branch" or "Pointmap Branch."</li>
|
| 513 |
+
</ul>
|
| 514 |
+
</details>
|
| 515 |
+
</li>
|
| 516 |
+
</ol>
|
| 517 |
+
<p><strong style="color: #0ea5e9;">Please note:</strong> <span style="color: #0ea5e9; font-weight: bold;">SAIL-Recon typically reconstructs a scene at 5FPS with full 3D attributes. However, visualizing 3D points may take tens of seconds due to third-party rendering, which is independent of SAIL-Recon's processing time. Using the 'demo.py' can provide much faster processing.</span></p>
|
| 518 |
+
</div>
|
| 519 |
+
"""
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
+
target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
|
| 523 |
+
|
| 524 |
+
with gr.Row():
|
| 525 |
+
with gr.Column(scale=2):
|
| 526 |
+
input_video = gr.Video(label="Upload Video", interactive=True)
|
| 527 |
+
input_images = gr.File(
|
| 528 |
+
file_count="multiple", label="Upload Images", interactive=True
|
| 529 |
+
)
|
| 530 |
+
|
| 531 |
+
image_gallery = gr.Gallery(
|
| 532 |
+
label="Preview",
|
| 533 |
+
columns=4,
|
| 534 |
+
height="300px",
|
| 535 |
+
show_download_button=True,
|
| 536 |
+
object_fit="contain",
|
| 537 |
+
preview=True,
|
| 538 |
+
)
|
| 539 |
+
|
| 540 |
+
with gr.Column(scale=4):
|
| 541 |
+
with gr.Column():
|
| 542 |
+
gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses)**")
|
| 543 |
+
log_output = gr.Markdown(
|
| 544 |
+
"Please upload a video or images, then click Reconstruct.",
|
| 545 |
+
elem_classes=["custom-log"],
|
| 546 |
+
)
|
| 547 |
+
reconstruction_output = gr.Model3D(
|
| 548 |
+
height=520, zoom_speed=0.5, pan_speed=0.5
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
with gr.Row():
|
| 552 |
+
submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
|
| 553 |
+
clear_btn = gr.ClearButton(
|
| 554 |
+
[
|
| 555 |
+
input_video,
|
| 556 |
+
input_images,
|
| 557 |
+
reconstruction_output,
|
| 558 |
+
log_output,
|
| 559 |
+
target_dir_output,
|
| 560 |
+
image_gallery,
|
| 561 |
+
],
|
| 562 |
+
scale=1,
|
| 563 |
+
)
|
| 564 |
+
|
| 565 |
+
with gr.Row():
|
| 566 |
+
prediction_mode = gr.Radio(
|
| 567 |
+
["Depthmap and Camera Branch", "Pointmap Branch"],
|
| 568 |
+
label="Select a Prediction Mode",
|
| 569 |
+
value="Depthmap and Camera Branch",
|
| 570 |
+
scale=1,
|
| 571 |
+
elem_id="my_radio",
|
| 572 |
+
)
|
| 573 |
+
|
| 574 |
+
with gr.Row():
|
| 575 |
+
conf_thres = gr.Slider(
|
| 576 |
+
minimum=0,
|
| 577 |
+
maximum=100,
|
| 578 |
+
value=50,
|
| 579 |
+
step=0.1,
|
| 580 |
+
label="Confidence Threshold (%)",
|
| 581 |
+
)
|
| 582 |
+
downsample_ratio = gr.Slider(
|
| 583 |
+
minimum=1.0,
|
| 584 |
+
maximum=100,
|
| 585 |
+
value=100,
|
| 586 |
+
step=0.1,
|
| 587 |
+
label="Downsample Ratio(%)",
|
| 588 |
+
)
|
| 589 |
+
frame_filter = gr.Dropdown(
|
| 590 |
+
choices=["All"], value="All", label="Show Points from Frame"
|
| 591 |
+
)
|
| 592 |
+
with gr.Column():
|
| 593 |
+
show_cam = gr.Checkbox(label="Show Camera", value=True)
|
| 594 |
+
mask_sky = gr.Checkbox(label="Filter Sky", value=False)
|
| 595 |
+
mask_black_bg = gr.Checkbox(
|
| 596 |
+
label="Filter Black Background", value=False
|
| 597 |
+
)
|
| 598 |
+
mask_white_bg = gr.Checkbox(
|
| 599 |
+
label="Filter White Background", value=False
|
| 600 |
+
)
|
| 601 |
+
|
| 602 |
+
# ---------------------- Examples section ----------------------
|
| 603 |
+
examples = [
|
| 604 |
+
[
|
| 605 |
+
colosseum_video,
|
| 606 |
+
"22",
|
| 607 |
+
None,
|
| 608 |
+
20.0,
|
| 609 |
+
False,
|
| 610 |
+
False,
|
| 611 |
+
True,
|
| 612 |
+
False,
|
| 613 |
+
"Depthmap and Camera Branch",
|
| 614 |
+
"True",
|
| 615 |
+
],
|
| 616 |
+
[
|
| 617 |
+
pyramid_video,
|
| 618 |
+
"30",
|
| 619 |
+
None,
|
| 620 |
+
35.0,
|
| 621 |
+
False,
|
| 622 |
+
False,
|
| 623 |
+
True,
|
| 624 |
+
False,
|
| 625 |
+
"Depthmap and Camera Branch",
|
| 626 |
+
"True",
|
| 627 |
+
],
|
| 628 |
+
[
|
| 629 |
+
single_cartoon_video,
|
| 630 |
+
"1",
|
| 631 |
+
None,
|
| 632 |
+
15.0,
|
| 633 |
+
False,
|
| 634 |
+
False,
|
| 635 |
+
True,
|
| 636 |
+
False,
|
| 637 |
+
"Depthmap and Camera Branch",
|
| 638 |
+
"True",
|
| 639 |
+
],
|
| 640 |
+
[
|
| 641 |
+
single_oil_painting_video,
|
| 642 |
+
"1",
|
| 643 |
+
None,
|
| 644 |
+
20.0,
|
| 645 |
+
False,
|
| 646 |
+
False,
|
| 647 |
+
True,
|
| 648 |
+
True,
|
| 649 |
+
"Depthmap and Camera Branch",
|
| 650 |
+
"True",
|
| 651 |
+
],
|
| 652 |
+
[
|
| 653 |
+
room_video,
|
| 654 |
+
"8",
|
| 655 |
+
None,
|
| 656 |
+
5.0,
|
| 657 |
+
False,
|
| 658 |
+
False,
|
| 659 |
+
True,
|
| 660 |
+
False,
|
| 661 |
+
"Depthmap and Camera Branch",
|
| 662 |
+
"True",
|
| 663 |
+
],
|
| 664 |
+
[
|
| 665 |
+
kitchen_video,
|
| 666 |
+
"25",
|
| 667 |
+
None,
|
| 668 |
+
50.0,
|
| 669 |
+
False,
|
| 670 |
+
False,
|
| 671 |
+
True,
|
| 672 |
+
False,
|
| 673 |
+
"Depthmap and Camera Branch",
|
| 674 |
+
"True",
|
| 675 |
+
],
|
| 676 |
+
[
|
| 677 |
+
fern_video,
|
| 678 |
+
"20",
|
| 679 |
+
None,
|
| 680 |
+
45.0,
|
| 681 |
+
False,
|
| 682 |
+
False,
|
| 683 |
+
True,
|
| 684 |
+
False,
|
| 685 |
+
"Depthmap and Camera Branch",
|
| 686 |
+
"True",
|
| 687 |
+
],
|
| 688 |
+
]
|
| 689 |
+
|
| 690 |
+
def example_pipeline(
|
| 691 |
+
input_video,
|
| 692 |
+
num_images_str,
|
| 693 |
+
input_images,
|
| 694 |
+
conf_thres,
|
| 695 |
+
mask_black_bg,
|
| 696 |
+
mask_white_bg,
|
| 697 |
+
show_cam,
|
| 698 |
+
mask_sky,
|
| 699 |
+
downsample_ratio,
|
| 700 |
+
prediction_mode,
|
| 701 |
+
is_example_str,
|
| 702 |
+
):
|
| 703 |
+
"""
|
| 704 |
+
1) Copy example images to new target_dir
|
| 705 |
+
2) Reconstruct
|
| 706 |
+
3) Return model3D + logs + new_dir + updated dropdown + gallery
|
| 707 |
+
We do NOT return is_example. It's just an input.
|
| 708 |
+
"""
|
| 709 |
+
target_dir, image_paths = handle_uploads(input_video, input_images)
|
| 710 |
+
# Always use "All" for frame_filter in examples
|
| 711 |
+
frame_filter = "All"
|
| 712 |
+
glbfile, log_msg, dropdown = gradio_demo(
|
| 713 |
+
target_dir,
|
| 714 |
+
conf_thres,
|
| 715 |
+
frame_filter,
|
| 716 |
+
mask_black_bg,
|
| 717 |
+
mask_white_bg,
|
| 718 |
+
show_cam,
|
| 719 |
+
mask_sky,
|
| 720 |
+
downsample_ratio,
|
| 721 |
+
prediction_mode,
|
| 722 |
+
)
|
| 723 |
+
return glbfile, log_msg, target_dir, dropdown, image_paths
|
| 724 |
+
|
| 725 |
+
gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
|
| 726 |
+
|
| 727 |
+
# gr.Examples(
|
| 728 |
+
# examples=examples,
|
| 729 |
+
# inputs=[
|
| 730 |
+
# input_video,
|
| 731 |
+
# num_images,
|
| 732 |
+
# input_images,
|
| 733 |
+
# conf_thres,
|
| 734 |
+
# mask_black_bg,
|
| 735 |
+
# mask_white_bg,
|
| 736 |
+
# show_cam,
|
| 737 |
+
# mask_sky,
|
| 738 |
+
# downsample_ratio,
|
| 739 |
+
# prediction_mode,
|
| 740 |
+
# is_example,
|
| 741 |
+
# ],
|
| 742 |
+
# outputs=[reconstruction_output, log_output, target_dir_output, frame_filter, image_gallery],
|
| 743 |
+
# fn=example_pipeline,
|
| 744 |
+
# cache_examples=False,
|
| 745 |
+
# examples_per_page=50,
|
| 746 |
+
# )
|
| 747 |
+
|
| 748 |
+
# -------------------------------------------------------------------------
|
| 749 |
+
# "Reconstruct" button logic:
|
| 750 |
+
# - Clear fields
|
| 751 |
+
# - Update log
|
| 752 |
+
# - gradio_demo(...) with the existing target_dir
|
| 753 |
+
# - Then set is_example = "False"
|
| 754 |
+
# -------------------------------------------------------------------------
|
| 755 |
+
submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then(
|
| 756 |
+
fn=update_log, inputs=[], outputs=[log_output]
|
| 757 |
+
).then(
|
| 758 |
+
fn=gradio_demo,
|
| 759 |
+
inputs=[
|
| 760 |
+
target_dir_output,
|
| 761 |
+
conf_thres,
|
| 762 |
+
frame_filter,
|
| 763 |
+
mask_black_bg,
|
| 764 |
+
mask_white_bg,
|
| 765 |
+
show_cam,
|
| 766 |
+
mask_sky,
|
| 767 |
+
downsample_ratio,
|
| 768 |
+
prediction_mode,
|
| 769 |
+
],
|
| 770 |
+
outputs=[reconstruction_output, log_output, frame_filter],
|
| 771 |
+
).then(
|
| 772 |
+
fn=lambda: "False", inputs=[], outputs=[is_example] # set is_example to "False"
|
| 773 |
+
)
|
| 774 |
+
|
| 775 |
+
# -------------------------------------------------------------------------
|
| 776 |
+
# Real-time Visualization Updates
|
| 777 |
+
# -------------------------------------------------------------------------
|
| 778 |
+
conf_thres.change(
|
| 779 |
+
update_visualization,
|
| 780 |
+
[
|
| 781 |
+
target_dir_output,
|
| 782 |
+
conf_thres,
|
| 783 |
+
frame_filter,
|
| 784 |
+
mask_black_bg,
|
| 785 |
+
mask_white_bg,
|
| 786 |
+
show_cam,
|
| 787 |
+
mask_sky,
|
| 788 |
+
downsample_ratio,
|
| 789 |
+
prediction_mode,
|
| 790 |
+
is_example,
|
| 791 |
+
],
|
| 792 |
+
[reconstruction_output, log_output],
|
| 793 |
+
)
|
| 794 |
+
downsample_ratio.change(
|
| 795 |
+
update_visualization,
|
| 796 |
+
[
|
| 797 |
+
target_dir_output,
|
| 798 |
+
conf_thres,
|
| 799 |
+
frame_filter,
|
| 800 |
+
mask_black_bg,
|
| 801 |
+
mask_white_bg,
|
| 802 |
+
show_cam,
|
| 803 |
+
mask_sky,
|
| 804 |
+
downsample_ratio,
|
| 805 |
+
prediction_mode,
|
| 806 |
+
is_example,
|
| 807 |
+
],
|
| 808 |
+
[reconstruction_output, log_output],
|
| 809 |
+
)
|
| 810 |
+
frame_filter.change(
|
| 811 |
+
update_visualization,
|
| 812 |
+
[
|
| 813 |
+
target_dir_output,
|
| 814 |
+
conf_thres,
|
| 815 |
+
frame_filter,
|
| 816 |
+
mask_black_bg,
|
| 817 |
+
mask_white_bg,
|
| 818 |
+
show_cam,
|
| 819 |
+
mask_sky,
|
| 820 |
+
downsample_ratio,
|
| 821 |
+
prediction_mode,
|
| 822 |
+
is_example,
|
| 823 |
+
],
|
| 824 |
+
[reconstruction_output, log_output],
|
| 825 |
+
)
|
| 826 |
+
mask_black_bg.change(
|
| 827 |
+
update_visualization,
|
| 828 |
+
[
|
| 829 |
+
target_dir_output,
|
| 830 |
+
conf_thres,
|
| 831 |
+
frame_filter,
|
| 832 |
+
mask_black_bg,
|
| 833 |
+
mask_white_bg,
|
| 834 |
+
show_cam,
|
| 835 |
+
mask_sky,
|
| 836 |
+
downsample_ratio,
|
| 837 |
+
prediction_mode,
|
| 838 |
+
is_example,
|
| 839 |
+
],
|
| 840 |
+
[reconstruction_output, log_output],
|
| 841 |
+
)
|
| 842 |
+
mask_white_bg.change(
|
| 843 |
+
update_visualization,
|
| 844 |
+
[
|
| 845 |
+
target_dir_output,
|
| 846 |
+
conf_thres,
|
| 847 |
+
frame_filter,
|
| 848 |
+
mask_black_bg,
|
| 849 |
+
mask_white_bg,
|
| 850 |
+
show_cam,
|
| 851 |
+
mask_sky,
|
| 852 |
+
downsample_ratio,
|
| 853 |
+
prediction_mode,
|
| 854 |
+
is_example,
|
| 855 |
+
],
|
| 856 |
+
[reconstruction_output, log_output],
|
| 857 |
+
)
|
| 858 |
+
show_cam.change(
|
| 859 |
+
update_visualization,
|
| 860 |
+
[
|
| 861 |
+
target_dir_output,
|
| 862 |
+
conf_thres,
|
| 863 |
+
frame_filter,
|
| 864 |
+
mask_black_bg,
|
| 865 |
+
mask_white_bg,
|
| 866 |
+
show_cam,
|
| 867 |
+
mask_sky,
|
| 868 |
+
downsample_ratio,
|
| 869 |
+
prediction_mode,
|
| 870 |
+
is_example,
|
| 871 |
+
],
|
| 872 |
+
[reconstruction_output, log_output],
|
| 873 |
+
)
|
| 874 |
+
mask_sky.change(
|
| 875 |
+
update_visualization,
|
| 876 |
+
[
|
| 877 |
+
target_dir_output,
|
| 878 |
+
conf_thres,
|
| 879 |
+
frame_filter,
|
| 880 |
+
mask_black_bg,
|
| 881 |
+
mask_white_bg,
|
| 882 |
+
show_cam,
|
| 883 |
+
mask_sky,
|
| 884 |
+
downsample_ratio,
|
| 885 |
+
prediction_mode,
|
| 886 |
+
is_example,
|
| 887 |
+
],
|
| 888 |
+
[reconstruction_output, log_output],
|
| 889 |
+
)
|
| 890 |
+
prediction_mode.change(
|
| 891 |
+
update_visualization,
|
| 892 |
+
[
|
| 893 |
+
target_dir_output,
|
| 894 |
+
conf_thres,
|
| 895 |
+
frame_filter,
|
| 896 |
+
mask_black_bg,
|
| 897 |
+
mask_white_bg,
|
| 898 |
+
show_cam,
|
| 899 |
+
mask_sky,
|
| 900 |
+
downsample_ratio,
|
| 901 |
+
prediction_mode,
|
| 902 |
+
is_example,
|
| 903 |
+
],
|
| 904 |
+
[reconstruction_output, log_output],
|
| 905 |
+
)
|
| 906 |
+
|
| 907 |
+
# -------------------------------------------------------------------------
|
| 908 |
+
# Auto-update gallery whenever user uploads or changes their files
|
| 909 |
+
# -------------------------------------------------------------------------
|
| 910 |
+
input_video.change(
|
| 911 |
+
fn=update_gallery_on_upload,
|
| 912 |
+
inputs=[input_video, input_images],
|
| 913 |
+
outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
|
| 914 |
+
)
|
| 915 |
+
input_images.change(
|
| 916 |
+
fn=update_gallery_on_upload,
|
| 917 |
+
inputs=[input_video, input_images],
|
| 918 |
+
outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
|
| 919 |
+
)
|
| 920 |
+
|
| 921 |
+
demo.queue(max_size=20).launch(show_error=True, share=True)
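gradio_demo saves every per-frame output to predictions.npz inside the timestamped upload folder before exporting the GLB. A minimal sketch for inspecting that file offline; the folder name below is a placeholder, and the key names follow run_model above:

import numpy as np

loaded = np.load("input_images_YYYYMMDD_HHMMSS_ffffff/predictions.npz", allow_pickle=True)
for key in ["extrinsic", "intrinsic", "depth", "depth_conf", "world_points_from_depth"]:
    if key in loaded:
        print(key, loaded[key].shape)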
|
docs/traj_ply.png
ADDED
|
Git LFS Details
|
eval/datasets/mip_360.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
import glob
import os
import random
import struct
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset
from vggt.utils.load_fn import load_and_preprocess_images


class Mip360Dataset(Dataset):
    def __init__(self, root_dir, scene_name="bicycle"):
        self.scene_dir = os.path.join(
            root_dir,
            f"{scene_name}",
        )

        self.test_samples = sorted(
            glob.glob(os.path.join(self.scene_dir, "images_8", "*.JPG"))
        )
        # self.train_samples = sorted(glob.glob(os.path.join(self.train_seqs, "rgb", "*.png")))
        self.all_samples = self.test_samples  # + self.train_samples
        bin_path = os.path.join(self.scene_dir, "sparse", "0", "images.bin")
        self.poses = read_images_bin(bin_path)

    def __len__(self):
        return len(self.all_samples)

    def __getitem__(self, idx):
        return self._load_sample(self.all_samples[idx])

    def get_train_sample(self, n=4):
        # _rng = np.random.default_rng(seed=777)
        gap = len(self.all_samples) // n
        gap = max(gap, 1)  # Ensure at least one sample is selected
        gap = min(gap, len(self.all_samples))  # Ensure gap does not exceed length
        if gap == 1:
            selected = sorted(
                random.sample(self.all_samples, min(n, len(self.all_samples)))
            )
        else:
            selected = self.all_samples[::gap]
        if len(selected) > n:
            selected = sorted(random.sample(selected, n))
        return [self._load_sample(s) for s in selected]

    def _load_sample(self, rgb_path):
        img_name = os.path.basename(rgb_path)
        color = load_and_preprocess_images([rgb_path])[0]
        pose = torch.from_numpy(self.poses[img_name]).float()

        return dict(
            img=color,
            camera_pose=pose,  # cam2world
            dataset="7Scenes",
            true_shape=torch.tensor([392, 518]),
            label=img_name,
            instance=img_name,
        )


def read_images_bin(bin_path: str | Path):
    bin_path = Path(bin_path)
    poses = {}

    with bin_path.open("rb") as f:
        num_images = struct.unpack("<Q", f.read(8))[0]  # uint64
        for _ in range(num_images):
            image_id = struct.unpack("<I", f.read(4))[0]
            qvec = np.frombuffer(f.read(8 * 4), dtype=np.float64)  # qw,qx,qy,qz
            tvec = np.frombuffer(f.read(8 * 3), dtype=np.float64)  # tx,ty,tz
            cam_id = struct.unpack("<I", f.read(4))[0]  # camera_id

            name_bytes = bytearray()
            while True:
                c = f.read(1)
                if c == b"\0":
                    break
                name_bytes.extend(c)
            name = name_bytes.decode("utf-8")

            n_pts = struct.unpack("<Q", f.read(8))[0]
            f.seek(n_pts * 24, 1)

            # world→cam to cam→world
            qw, qx, qy, qz = qvec
            R_wc = np.array(
                [
                    [
                        1 - 2 * qy * qy - 2 * qz * qz,
                        2 * qx * qy + 2 * qz * qw,
                        2 * qx * qz - 2 * qy * qw,
                    ],
                    [
                        2 * qx * qy - 2 * qz * qw,
                        1 - 2 * qx * qx - 2 * qz * qz,
                        2 * qy * qz + 2 * qx * qw,
                    ],
                    [
                        2 * qx * qz + 2 * qy * qw,
                        2 * qy * qz - 2 * qx * qw,
                        1 - 2 * qx * qx - 2 * qy * qy,
                    ],
                ]
            )
            t_wc = -R_wc @ tvec
            c2w = np.eye(4, dtype=np.float32)
            c2w[:3, :3] = R_wc.astype(np.float32)
            c2w[:3, 3] = t_wc.astype(np.float32)

            poses[name] = c2w

    return poses
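read_images_bin returns a dict mapping each image name to a 4x4 cam-to-world matrix. A small illustrative sketch (the file path and image names are placeholders) of turning two of those poses into the relative motion between frames, which is the quantity pose-accuracy metrics ultimately compare:

import numpy as np

def relative_pose(c2w_a, c2w_b):
    # Transform taking coordinates of camera a into coordinates of camera b:
    # T_b<-a = (world -> b) @ (a -> world)
    return np.linalg.inv(c2w_b) @ c2w_a

# poses = read_images_bin("bicycle/sparse/0/images.bin")            # placeholder path
# rel = relative_pose(poses["frame_0001.JPG"], poses["frame_0002.JPG"])  # placeholder names
# print(np.linalg.norm(rel[:3, 3]))  # translation magnitude between the two frames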
|
eval/datasets/seven_scenes.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
import glob
import os

import numpy as np
import torch
from torch.utils.data import Dataset
from vggt.utils.load_fn import load_and_preprocess_images

from eval.utils.eval_utils import uniform_sample


class SevenScenesUnifiedDataset(Dataset):
    def __init__(self, root_dir, scene_name="chess"):
        self.scene_dir = os.path.join(root_dir, f"pgt_7scenes_{scene_name}")

        self.train_seqs = os.path.join(self.scene_dir, "train")
        self.test_seqs = os.path.join(self.scene_dir, "test")

        self.test_samples = sorted(
            glob.glob(os.path.join(self.test_seqs, "rgb", "*.png"))
        )
        self.train_samples = sorted(
            glob.glob(os.path.join(self.train_seqs, "rgb", "*.png"))
        )
        self.all_samples = self.test_samples  # + self.train_samples
        # len_samples = len(self.all_samples)
        # self.all_samples = self.all_samples[::len_samples//200]

    def __len__(self):
        return len(self.all_samples)

    def __getitem__(self, idx):
        return self._load_sample(self.all_samples[idx])

    def get_train_sample(self, n=4):
        uniform_sampled = uniform_sample(len(self.all_samples), n)
        selected = [self.all_samples[i] for i in uniform_sampled]
        return [self._load_sample(s) for s in selected]

    def _load_sample(self, rgb_path):
        img_name = os.path.basename(rgb_path)
        color = load_and_preprocess_images([rgb_path])[0]
        pose_path = (
            rgb_path.replace("rgb", "poses")
            .replace("color", "pose")
            .replace(".png", ".txt")
        )
        pose = np.loadtxt(pose_path)
        pose = torch.from_numpy(pose).float()

        return dict(
            img=color,
            camera_pose=pose,  # cam2world
            dataset="7Scenes",
            true_shape=torch.tensor([392, 518]),
            label=img_name,
            instance=img_name,
        )
|
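A minimal usage sketch for this loader, assuming the PGT 7-Scenes data sits under a local `data/7scenes` directory (the path is a placeholder):

from torch.utils.data import DataLoader

dataset = SevenScenesUnifiedDataset("data/7scenes", scene_name="chess")
reference_views = dataset.get_train_sample(n=4)  # uniformly spaced reference frames
loader = DataLoader(dataset, batch_size=1, shuffle=False)
for view in loader:
    img, c2w = view["img"], view["camera_pose"]  # preprocessed image and cam2world pose
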
eval/datasets/tnt.py
ADDED
|
@@ -0,0 +1,116 @@
| 1 |
+
import os
|
| 2 |
+
import struct
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import PIL
|
| 7 |
+
import torch
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from torch.utils.data import Dataset
|
| 10 |
+
from vggt.utils.load_fn import load_and_preprocess_images
|
| 11 |
+
|
| 12 |
+
from eval.utils.eval_utils import uniform_sample
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class TnTDataset(Dataset):
|
| 16 |
+
def __init__(self, root_dir, colmap_dir, scene_name="advanced__Auditorium"):
|
| 17 |
+
scene_name_ori = scene_name
|
| 18 |
+
|
| 19 |
+
level, scene_name = scene_name.split("__")
|
| 20 |
+
|
| 21 |
+
self.scene_dir = os.path.join(root_dir, f"{level}", f"{scene_name}")
|
| 22 |
+
|
| 23 |
+
self.test_samples = []
|
| 24 |
+
|
| 25 |
+
bin_path = os.path.join(colmap_dir, scene_name_ori, "0", "images.bin")
|
| 26 |
+
self.poses = read_images_bin(bin_path)
|
| 27 |
+
for img in self.poses.keys():
|
| 28 |
+
self.test_samples.append(os.path.join(self.scene_dir, img))
|
| 29 |
+
self.all_samples = self.test_samples
|
| 30 |
+
|
| 31 |
+
def __len__(self):
|
| 32 |
+
return len(self.all_samples)
|
| 33 |
+
|
| 34 |
+
def __getitem__(self, idx):
|
| 35 |
+
return self._load_sample(self.all_samples[idx])
|
| 36 |
+
|
| 37 |
+
def get_train_sample(self, n=4):
|
| 38 |
+
gap = len(self.all_samples) // n
|
| 39 |
+
gap = max(gap, 1) # Ensure at least one sample is selected
|
| 40 |
+
gap = min(gap, len(self.all_samples)) # Ensure gap does not exceed length
|
| 41 |
+
if gap == 1:
|
| 42 |
+
uniform_sampled = uniform_sample(len(self.all_samples), n)
|
| 43 |
+
selected = [self.all_samples[i] for i in uniform_sampled]
|
| 44 |
+
else:
|
| 45 |
+
selected = self.all_samples[::gap]
|
| 46 |
+
if len(selected) > n:
|
| 47 |
+
uniform_sampled = uniform_sample(len(selected), n)
|
| 48 |
+
selected = [selected[i] for i in uniform_sampled]
|
| 49 |
+
return [self._load_sample(s) for s in selected]
|
| 50 |
+
|
| 51 |
+
def _load_sample(self, rgb_path):
|
| 52 |
+
img_name = os.path.basename(rgb_path)
|
| 53 |
+
color = load_and_preprocess_images([rgb_path])[0]
|
| 54 |
+
pose = torch.from_numpy(self.poses[img_name]).float()
|
| 55 |
+
|
| 56 |
+
return dict(
|
| 57 |
+
img=color,
|
| 58 |
+
camera_pose=pose, # cam2world
|
| 59 |
+
dataset="7Scenes",
|
| 60 |
+
true_shape=torch.tensor([392, 518]),
|
| 61 |
+
label=img_name,
|
| 62 |
+
instance=img_name,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def read_images_bin(bin_path: str | Path):
|
| 67 |
+
bin_path = Path(bin_path)
|
| 68 |
+
poses = {}
|
| 69 |
+
|
| 70 |
+
with bin_path.open("rb") as f:
|
| 71 |
+
num_images = struct.unpack("<Q", f.read(8))[0] # uint64
|
| 72 |
+
for _ in range(num_images):
|
| 73 |
+
image_id = struct.unpack("<I", f.read(4))[0]
|
| 74 |
+
qvec = np.frombuffer(f.read(8 * 4), dtype=np.float64) # qw,qx,qy,qz
|
| 75 |
+
tvec = np.frombuffer(f.read(8 * 3), dtype=np.float64) # tx,ty,tz
|
| 76 |
+
cam_id = struct.unpack("<I", f.read(4))[0] # camera_id
|
| 77 |
+
|
| 78 |
+
name_bytes = bytearray()
|
| 79 |
+
while True:
|
| 80 |
+
c = f.read(1)
|
| 81 |
+
if c == b"\0":
|
| 82 |
+
break
|
| 83 |
+
name_bytes.extend(c)
|
| 84 |
+
name = name_bytes.decode("utf-8").split("/")[-1]  # keep only the basename
|
| 85 |
+
|
| 86 |
+
n_pts = struct.unpack("<Q", f.read(8))[0]
|
| 87 |
+
f.seek(n_pts * 24, 1)
|
| 88 |
+
|
| 89 |
+
# world→cam to cam→world
|
| 90 |
+
qw, qx, qy, qz = qvec
|
| 91 |
+
R_wc = np.array(
|
| 92 |
+
[
|
| 93 |
+
[
|
| 94 |
+
1 - 2 * qy * qy - 2 * qz * qz,
|
| 95 |
+
2 * qx * qy + 2 * qz * qw,
|
| 96 |
+
2 * qx * qz - 2 * qy * qw,
|
| 97 |
+
],
|
| 98 |
+
[
|
| 99 |
+
2 * qx * qy - 2 * qz * qw,
|
| 100 |
+
1 - 2 * qx * qx - 2 * qz * qz,
|
| 101 |
+
2 * qy * qz + 2 * qx * qw,
|
| 102 |
+
],
|
| 103 |
+
[
|
| 104 |
+
2 * qx * qz + 2 * qy * qw,
|
| 105 |
+
2 * qy * qz - 2 * qx * qw,
|
| 106 |
+
1 - 2 * qx * qx - 2 * qy * qy,
|
| 107 |
+
],
|
| 108 |
+
]
|
| 109 |
+
)
|
| 110 |
+
t_wc = -R_wc @ tvec
|
| 111 |
+
c2w = np.eye(4, dtype=np.float32)
|
| 112 |
+
c2w[:3, :3] = R_wc.astype(np.float32)
|
| 113 |
+
c2w[:3, 3] = t_wc.astype(np.float32)
|
| 114 |
+
|
| 115 |
+
poses[name] = c2w
|
| 116 |
+
return poses
|
eval/datasets/tum.py
ADDED
|
@@ -0,0 +1,53 @@
import glob
import os
from pathlib import Path

import torch
from PIL import Image
from torch.utils.data import Dataset
from vggt.utils.load_fn import load_and_preprocess_images

from eval.utils.eval_utils import uniform_sample


class TumDatasetAll(Dataset):
    def __init__(self, root_dir, scene_name="rgbd_dataset_freiburg1_360"):
        self.scene_name = scene_name
        self.scene_all_dir = os.path.join(root_dir, f"{scene_name}", "rgb")

        self.test_samples = sorted(glob.glob(os.path.join(self.scene_all_dir, "*.png")))
        self.all_samples = self.test_samples

    def __len__(self):
        return len(self.all_samples)

    def __getitem__(self, idx):
        return self._load_sample(self.all_samples[idx])

    def get_train_sample(self, n=4):
        gap = len(self.all_samples) // n
        gap = max(gap, 1)  # Ensure at least one sample is selected
        gap = min(gap, len(self.all_samples))  # Ensure gap does not exceed length
        if gap == 1:
            uniform_sampled = uniform_sample(len(self.all_samples), n)
            selected = [self.all_samples[i] for i in uniform_sampled]
        else:
            selected = self.all_samples[::gap]
            if len(selected) > n:
                uniform_sampled = uniform_sample(len(selected), n)
                selected = [selected[i] for i in uniform_sampled]
        if self.scene_name == "rgbd_dataset_freiburg1_floor":
            selected += self.all_samples[-20::5]
        return [self._load_sample(s) for s in selected]

    def _load_sample(self, rgb_path):
        img_name = os.path.basename(rgb_path)
        color = load_and_preprocess_images([rgb_path])[0]

        return dict(
            img=color,
            dataset="tnt_all",
            true_shape=torch.tensor([392, 518]),
            label=img_name,
            instance=img_name,
        )

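Both the TnT and TUM loaders pick reference frames with the same gap-based rule; a standalone sketch of that selection logic, using a made-up list of frame names:

from eval.utils.eval_utils import uniform_sample

def select_reference_frames(samples, n=4):
    gap = max(min(len(samples) // n, len(samples)), 1)
    if gap == 1:
        return [samples[i] for i in uniform_sample(len(samples), n)]
    selected = samples[::gap]
    if len(selected) > n:
        selected = [selected[i] for i in uniform_sample(len(selected), n)]
    return selected

frames = [f"{i:06d}.png" for i in range(100)]
print(select_reference_frames(frames, n=4))  # ['000000.png', '000025.png', '000050.png', '000075.png']
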
eval/readme.md
ADDED
|
@@ -0,0 +1,110 @@
## Tanks and Temples

### Image set
1. Data Preparation

Download the image data from [here](https://www.tanksandtemples.org/download/) (for the intermediate and advanced sets, please download from [here](https://github.com/isl-org/TanksAndTemples/issues/35)), and the COLMAP results from [here](https://storage.googleapis.com/niantic-lon-static/research/acezero/colmap_raw.tar.gz). We thank [ACE0](https://github.com/nianticlabs/acezero) again for providing the COLMAP results.

2. Adjust the parameters in `run_tnt.sh`

Specify the `dataset_root`, `colmap_dir`, `model_path` and `save_dir` in the file.

3. Get the inference results.

```
sh run_tnt.sh
```
### Video set
<details>
<summary>Click to expand</summary>

1. Data Preparation

Download the video sequences from [here](https://www.tanksandtemples.org/download/) and extract images from the videos following [this tutorial](https://www.tanksandtemples.org/tutorial/).

2. Run Inference

Replace `docs/demo_image` in `../demo.py` with the path storing the images extracted from the video.
</details>

## TUM-RGBD

1. Data Preparation

Download the corresponding sequences from [here](https://cvg.cit.tum.de/data/datasets/rgbd-dataset/download).

2. Adjust the parameters in `run_tum.sh`

Specify the `dataset_root`, `recon_img_num`, `model_path` and `save_dir` in the file.

3. Evaluate the results.

```
sh run_tum.sh
```
Note that we set `recon_img_num` to 50 or 100 according to the length of the dataset. Please refer to the supplementary material of the paper for details.

4. Use evo to evaluate the results

```
evo_ape tum gt_pose.txt pred_tum.txt -vas
```

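`evo_ape` expects `pred_tum.txt` to contain one line per frame in the form `timestamp tx ty tz qx qy qz qw`. A minimal sketch of exporting predicted cam2world matrices in that format (it mirrors `save_tum_poses` in `eval/utils/eval_utils.py`; the file name follows the command above):

```python
import numpy as np
from scipy.spatial.transform import Rotation

def write_tum(poses_c2w, timestamps, path="pred_tum.txt"):
    with open(path, "w") as f:
        for ts, T in zip(timestamps, poses_c2w):
            qx, qy, qz, qw = Rotation.from_matrix(T[:3, :3]).as_quat()
            tx, ty, tz = T[:3, 3]
            f.write(f"{ts:.6f} {tx:.6f} {ty:.6f} {tz:.6f} "
                    f"{qx:.6f} {qy:.6f} {qz:.6f} {qw:.6f}\n")
```
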
## 7 Scenes

1. Download the dataset from [here](https://www.microsoft.com/en-us/research/project/rgb-d-dataset-7-scenes/) and the _Pseudo Ground Truth (PGT)_
(see the [ICCV 2021 paper](https://openaccess.thecvf.com/content/ICCV2021/html/Brachmann_On_the_Limits_of_Pseudo_Ground_Truth_in_Visual_Camera_ICCV_2021_paper.html)
and the [associated code](https://github.com/tsattler/visloc_pseudo_gt_limitations/) for details).


2. Adjust the parameters in `run_7scenes.sh`

Specify the `dataset_root`, `recon_img_num`, `model_path` and `save_dir` in the file.

3. Evaluate the results.

```
sh run_7scenes.sh
```
You will see a `result.txt` file reporting the evaluation results.

## Mip-NeRF 360

1. Data Preparation

Download the data from [here](https://jonbarron.info/mipnerf360/).

2. Adjust the parameters in `run_mip.sh`

Specify the `dataset_root`, `model_path` and `save_dir` in the file.

3. Get the inference results.

```
sh run_mip.sh
```

## Co3D-V2

1. We thank VGGT for providing the evaluation code for the Co3D-V2 dataset. Please see the link [here](https://github.com/facebookresearch/vggt/tree/evaluation/evaluation#dataset-preparation) for data preparation and processing.

2. Adjust the parameters `co3d_dir` and `co3d_anno_dir` in `run_co3d.sh`

Specify the `dataset_root`, `recon_img_num`, `model_path`, `recon`, `reloc` and `fixed_rank` in the file.

3. Evaluate the results.

```
sh run_co3d.sh
```
You will see the evaluation result in the terminal.

eval/utils/cropping.py
ADDED
|
@@ -0,0 +1,289 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 10 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 11 |
+
#
|
| 12 |
+
# --------------------------------------------------------
|
| 13 |
+
# cropping utilities
|
| 14 |
+
# --------------------------------------------------------
|
| 15 |
+
import PIL.Image
|
| 16 |
+
|
| 17 |
+
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
|
| 18 |
+
import cv2 # noqa
|
| 19 |
+
import numpy as np # noqa
|
| 20 |
+
from eval.utils.device import to_numpy
|
| 21 |
+
from eval.utils.geometry import (  # noqa
|
| 22 |
+
colmap_to_opencv_intrinsics,
|
| 23 |
+
geotrf,
|
| 24 |
+
inv,
|
| 25 |
+
opencv_to_colmap_intrinsics,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
try:
|
| 29 |
+
lanczos = PIL.Image.Resampling.LANCZOS
|
| 30 |
+
bicubic = PIL.Image.Resampling.BICUBIC
|
| 31 |
+
except AttributeError:
|
| 32 |
+
lanczos = PIL.Image.LANCZOS
|
| 33 |
+
bicubic = PIL.Image.BICUBIC
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class ImageList:
|
| 37 |
+
"""Convenience class to aply the same operation to a whole set of images."""
|
| 38 |
+
|
| 39 |
+
def __init__(self, images):
|
| 40 |
+
if not isinstance(images, (tuple, list, set)):
|
| 41 |
+
images = [images]
|
| 42 |
+
self.images = []
|
| 43 |
+
for image in images:
|
| 44 |
+
if not isinstance(image, PIL.Image.Image):
|
| 45 |
+
image = PIL.Image.fromarray(image)
|
| 46 |
+
self.images.append(image)
|
| 47 |
+
|
| 48 |
+
def __len__(self):
|
| 49 |
+
return len(self.images)
|
| 50 |
+
|
| 51 |
+
def to_pil(self):
|
| 52 |
+
return tuple(self.images) if len(self.images) > 1 else self.images[0]
|
| 53 |
+
|
| 54 |
+
@property
|
| 55 |
+
def size(self):
|
| 56 |
+
sizes = [im.size for im in self.images]
|
| 57 |
+
assert all(sizes[0] == s for s in sizes)
|
| 58 |
+
return sizes[0]
|
| 59 |
+
|
| 60 |
+
def resize(self, *args, **kwargs):
|
| 61 |
+
return ImageList(self._dispatch("resize", *args, **kwargs))
|
| 62 |
+
|
| 63 |
+
def crop(self, *args, **kwargs):
|
| 64 |
+
return ImageList(self._dispatch("crop", *args, **kwargs))
|
| 65 |
+
|
| 66 |
+
def _dispatch(self, func, *args, **kwargs):
|
| 67 |
+
return [getattr(im, func)(*args, **kwargs) for im in self.images]
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def rescale_image_depthmap(
|
| 71 |
+
image, depthmap, camera_intrinsics, output_resolution, force=True
|
| 72 |
+
):
|
| 73 |
+
"""Jointly rescale a (image, depthmap)
|
| 74 |
+
so that (out_width, out_height) >= output_res
|
| 75 |
+
"""
|
| 76 |
+
image = ImageList(image)
|
| 77 |
+
input_resolution = np.array(image.size) # (W,H)
|
| 78 |
+
output_resolution = np.array(output_resolution)
|
| 79 |
+
if depthmap is not None:
|
| 80 |
+
# can also use this with masks instead of depthmaps
|
| 81 |
+
assert tuple(depthmap.shape[:2]) == image.size[::-1]
|
| 82 |
+
|
| 83 |
+
# define output resolution
|
| 84 |
+
assert output_resolution.shape == (2,)
|
| 85 |
+
scale_final = max(output_resolution / image.size) + 1e-8
|
| 86 |
+
if scale_final >= 1 and not force: # image is already smaller than what is asked
|
| 87 |
+
return (image.to_pil(), depthmap, camera_intrinsics)
|
| 88 |
+
output_resolution = np.floor(input_resolution * scale_final).astype(int)
|
| 89 |
+
|
| 90 |
+
# first rescale the image so that it contains the crop
|
| 91 |
+
image = image.resize(
|
| 92 |
+
tuple(output_resolution), resample=lanczos if scale_final < 1 else bicubic
|
| 93 |
+
)
|
| 94 |
+
if depthmap is not None:
|
| 95 |
+
depthmap = cv2.resize(
|
| 96 |
+
depthmap,
|
| 97 |
+
output_resolution,
|
| 98 |
+
fx=scale_final,
|
| 99 |
+
fy=scale_final,
|
| 100 |
+
interpolation=cv2.INTER_NEAREST,
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
# no offset here; simple rescaling
|
| 104 |
+
camera_intrinsics = camera_matrix_of_crop(
|
| 105 |
+
camera_intrinsics, input_resolution, output_resolution, scaling=scale_final
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
return image.to_pil(), depthmap, camera_intrinsics
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def camera_matrix_of_crop(
|
| 112 |
+
input_camera_matrix,
|
| 113 |
+
input_resolution,
|
| 114 |
+
output_resolution,
|
| 115 |
+
scaling=1,
|
| 116 |
+
offset_factor=0.5,
|
| 117 |
+
offset=None,
|
| 118 |
+
):
|
| 119 |
+
# Margins to offset the origin
|
| 120 |
+
margins = np.asarray(input_resolution) * scaling - output_resolution
|
| 121 |
+
assert np.all(margins >= 0.0)
|
| 122 |
+
if offset is None:
|
| 123 |
+
offset = offset_factor * margins
|
| 124 |
+
|
| 125 |
+
# Generate new camera parameters
|
| 126 |
+
output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix)
|
| 127 |
+
output_camera_matrix_colmap[:2, :] *= scaling
|
| 128 |
+
output_camera_matrix_colmap[:2, 2] -= offset
|
| 129 |
+
output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap)
|
| 130 |
+
|
| 131 |
+
return output_camera_matrix
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox):
|
| 135 |
+
"""
|
| 136 |
+
Return a crop of the input view.
|
| 137 |
+
"""
|
| 138 |
+
image = ImageList(image)
|
| 139 |
+
l, t, r, b = crop_bbox
|
| 140 |
+
|
| 141 |
+
image = image.crop((l, t, r, b))
|
| 142 |
+
if depthmap is not None:
|
| 143 |
+
depthmap = depthmap[t:b, l:r]
|
| 144 |
+
|
| 145 |
+
camera_intrinsics = camera_intrinsics.copy()
|
| 146 |
+
camera_intrinsics[0, 2] -= l
|
| 147 |
+
camera_intrinsics[1, 2] -= t
|
| 148 |
+
|
| 149 |
+
return image.to_pil(), depthmap, camera_intrinsics
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def bbox_from_intrinsics_in_out(
|
| 153 |
+
input_camera_matrix, output_camera_matrix, output_resolution
|
| 154 |
+
):
|
| 155 |
+
out_width, out_height = output_resolution
|
| 156 |
+
l, t = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]))
|
| 157 |
+
crop_bbox = (l, t, l + out_width, t + out_height)
|
| 158 |
+
return crop_bbox
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def reciprocal_1d(corres_1_to_2, corres_2_to_1, ret_recip=False):
|
| 162 |
+
is_reciprocal1 = corres_2_to_1[corres_1_to_2] == np.arange(len(corres_1_to_2))
|
| 163 |
+
pos1 = is_reciprocal1.nonzero()[0]
|
| 164 |
+
pos2 = corres_1_to_2[pos1]
|
| 165 |
+
if ret_recip:
|
| 166 |
+
return is_reciprocal1, pos1, pos2
|
| 167 |
+
return pos1, pos2
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def generate_non_self_pairs(n):
|
| 171 |
+
i, j = np.meshgrid(np.arange(n), np.arange(n), indexing="ij")
|
| 172 |
+
|
| 173 |
+
pairs = np.stack([i.ravel(), j.ravel()], axis=1)
|
| 174 |
+
|
| 175 |
+
mask = pairs[:, 0] != pairs[:, 1]
|
| 176 |
+
filtered_pairs = pairs[mask]
|
| 177 |
+
|
| 178 |
+
return filtered_pairs
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def unravel_xy(pos, shape):
|
| 182 |
+
# convert (x+W*y) back to 2d (x,y) coordinates
|
| 183 |
+
return np.unravel_index(pos, shape)[0].base[:, ::-1].copy()
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def ravel_xy(pos, shape):
|
| 187 |
+
H, W = shape
|
| 188 |
+
with np.errstate(invalid="ignore"):
|
| 189 |
+
qx, qy = pos.reshape(-1, 2).round().astype(np.int32).T
|
| 190 |
+
quantized_pos = qx.clip(min=0, max=W - 1, out=qx) + W * qy.clip(
|
| 191 |
+
min=0, max=H - 1, out=qy
|
| 192 |
+
)
|
| 193 |
+
return quantized_pos
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def extract_correspondences_from_pts3d(
|
| 197 |
+
view1, view2, target_n_corres, rng=np.random, ret_xy=True, nneg=0
|
| 198 |
+
):
|
| 199 |
+
view1, view2 = to_numpy((view1, view2))
|
| 200 |
+
# project pixels from image1 --> 3d points --> image2 pixels
|
| 201 |
+
shape1, corres1_to_2 = reproject_view(view1["pts3d"], view2)
|
| 202 |
+
shape2, corres2_to_1 = reproject_view(view2["pts3d"], view1)
|
| 203 |
+
|
| 204 |
+
# compute reciprocal correspondences:
|
| 205 |
+
# pos1 == valid pixels (correspondences) in image1
|
| 206 |
+
is_reciprocal1, pos1, pos2 = reciprocal_1d(
|
| 207 |
+
corres1_to_2, corres2_to_1, ret_recip=True
|
| 208 |
+
)
|
| 209 |
+
is_reciprocal2 = corres1_to_2[corres2_to_1] == np.arange(len(corres2_to_1))
|
| 210 |
+
|
| 211 |
+
if target_n_corres is None:
|
| 212 |
+
if ret_xy:
|
| 213 |
+
pos1 = unravel_xy(pos1, shape1)
|
| 214 |
+
pos2 = unravel_xy(pos2, shape2)
|
| 215 |
+
return pos1, pos2
|
| 216 |
+
|
| 217 |
+
available_negatives = min((~is_reciprocal1).sum(), (~is_reciprocal2).sum())
|
| 218 |
+
target_n_positives = int(target_n_corres * (1 - nneg))
|
| 219 |
+
n_positives = min(len(pos1), target_n_positives)
|
| 220 |
+
n_negatives = min(target_n_corres - n_positives, available_negatives)
|
| 221 |
+
|
| 222 |
+
if n_negatives + n_positives != target_n_corres:
|
| 223 |
+
# should be really rare => when there are not enough negatives
|
| 224 |
+
# in that case, break nneg and add a few more positives ?
|
| 225 |
+
n_positives = target_n_corres - n_negatives
|
| 226 |
+
assert n_positives <= len(pos1)
|
| 227 |
+
|
| 228 |
+
assert n_positives <= len(pos1)
|
| 229 |
+
assert n_positives <= len(pos2)
|
| 230 |
+
assert n_negatives <= (~is_reciprocal1).sum()
|
| 231 |
+
assert n_negatives <= (~is_reciprocal2).sum()
|
| 232 |
+
assert n_positives + n_negatives == target_n_corres
|
| 233 |
+
|
| 234 |
+
valid = np.ones(n_positives, dtype=bool)
|
| 235 |
+
if n_positives < len(pos1):
|
| 236 |
+
# random sub-sampling of valid correspondences
|
| 237 |
+
perm = rng.permutation(len(pos1))[:n_positives]
|
| 238 |
+
pos1 = pos1[perm]
|
| 239 |
+
pos2 = pos2[perm]
|
| 240 |
+
|
| 241 |
+
if n_negatives > 0:
|
| 242 |
+
# add false correspondences if not enough
|
| 243 |
+
def norm(p):
|
| 244 |
+
return p / p.sum()
|
| 245 |
+
|
| 246 |
+
pos1 = np.r_[
|
| 247 |
+
pos1,
|
| 248 |
+
rng.choice(
|
| 249 |
+
shape1[0] * shape1[1],
|
| 250 |
+
size=n_negatives,
|
| 251 |
+
replace=False,
|
| 252 |
+
p=norm(~is_reciprocal1),
|
| 253 |
+
),
|
| 254 |
+
]
|
| 255 |
+
pos2 = np.r_[
|
| 256 |
+
pos2,
|
| 257 |
+
rng.choice(
|
| 258 |
+
shape2[0] * shape2[1],
|
| 259 |
+
size=n_negatives,
|
| 260 |
+
replace=False,
|
| 261 |
+
p=norm(~is_reciprocal2),
|
| 262 |
+
),
|
| 263 |
+
]
|
| 264 |
+
valid = np.r_[valid, np.zeros(n_negatives, dtype=bool)]
|
| 265 |
+
|
| 266 |
+
# convert (x+W*y) back to 2d (x,y) coordinates
|
| 267 |
+
if ret_xy:
|
| 268 |
+
pos1 = unravel_xy(pos1, shape1)
|
| 269 |
+
pos2 = unravel_xy(pos2, shape2)
|
| 270 |
+
return pos1, pos2, valid
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def reproject_view(pts3d, view2):
|
| 274 |
+
shape = view2["pts3d"].shape[:2]
|
| 275 |
+
return reproject(
|
| 276 |
+
pts3d, view2["camera_intrinsics"], inv(view2["camera_pose"]), shape
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def reproject(pts3d, K, world2cam, shape):
|
| 281 |
+
H, W, THREE = pts3d.shape
|
| 282 |
+
assert THREE == 3
|
| 283 |
+
|
| 284 |
+
# reproject in camera2 space
|
| 285 |
+
with np.errstate(divide="ignore", invalid="ignore"):
|
| 286 |
+
pos = geotrf(K @ world2cam[:3], pts3d, norm=1, ncol=2)
|
| 287 |
+
|
| 288 |
+
# quantize to pixel positions
|
| 289 |
+
return (H, W), ravel_xy(pos, shape)
|
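A minimal sketch of how these helpers are typically chained (mirroring DUSt3R-style preprocessing), assuming an OpenCV-convention 3x3 intrinsics matrix and a PIL image; all values below are placeholders:

import numpy as np
import PIL.Image

image = PIL.Image.new("RGB", (640, 480))
K = np.array([[500.0, 0.0, 320.0], [0.0, 500.0, 240.0], [0.0, 0.0, 1.0]])

# rescale so the image covers the target resolution, then crop around the principal point
image, _, K = rescale_image_depthmap(image, None, K, output_resolution=(518, 392))
K_crop = camera_matrix_of_crop(K, image.size, (518, 392))
bbox = bbox_from_intrinsics_in_out(K, K_crop, (518, 392))
image, _, K_crop = crop_image_depthmap(image, None, K, bbox)
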
eval/utils/device.py
ADDED
|
@@ -0,0 +1,95 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 8 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 9 |
+
#
|
| 10 |
+
# --------------------------------------------------------
|
| 11 |
+
# utility functions for DUSt3R
|
| 12 |
+
# --------------------------------------------------------
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def todevice(batch, device, callback=None, non_blocking=False):
|
| 18 |
+
"""Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).
|
| 19 |
+
|
| 20 |
+
batch: list, tuple, dict of tensors or other things
|
| 21 |
+
device: pytorch device or 'numpy'
|
| 22 |
+
callback: function that would be called on every sub-elements.
|
| 23 |
+
"""
|
| 24 |
+
if callback:
|
| 25 |
+
batch = callback(batch)
|
| 26 |
+
|
| 27 |
+
if isinstance(batch, dict):
|
| 28 |
+
return {k: todevice(v, device) for k, v in batch.items()}
|
| 29 |
+
|
| 30 |
+
if isinstance(batch, (tuple, list)):
|
| 31 |
+
return type(batch)(todevice(x, device) for x in batch)
|
| 32 |
+
|
| 33 |
+
x = batch
|
| 34 |
+
if device == "numpy":
|
| 35 |
+
if isinstance(x, torch.Tensor):
|
| 36 |
+
x = x.detach().cpu().numpy()
|
| 37 |
+
elif x is not None:
|
| 38 |
+
if isinstance(x, np.ndarray):
|
| 39 |
+
x = torch.from_numpy(x)
|
| 40 |
+
if torch.is_tensor(x):
|
| 41 |
+
x = x.to(device, non_blocking=non_blocking)
|
| 42 |
+
return x
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
to_device = todevice # alias
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def to_numpy(x):
|
| 49 |
+
return todevice(x, "numpy")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def to_cpu(x):
|
| 53 |
+
return todevice(x, "cpu")
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def to_cuda(x):
|
| 57 |
+
return todevice(x, "cuda")
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def collate_with_cat(whatever, lists=False):
|
| 61 |
+
if isinstance(whatever, dict):
|
| 62 |
+
return {k: collate_with_cat(vals, lists=lists) for k, vals in whatever.items()}
|
| 63 |
+
|
| 64 |
+
elif isinstance(whatever, (tuple, list)):
|
| 65 |
+
if len(whatever) == 0:
|
| 66 |
+
return whatever
|
| 67 |
+
elem = whatever[0]
|
| 68 |
+
T = type(whatever)
|
| 69 |
+
|
| 70 |
+
if elem is None:
|
| 71 |
+
return None
|
| 72 |
+
if isinstance(elem, (bool, float, int, str)):
|
| 73 |
+
return whatever
|
| 74 |
+
if isinstance(elem, tuple):
|
| 75 |
+
return T(collate_with_cat(x, lists=lists) for x in zip(*whatever))
|
| 76 |
+
if isinstance(elem, dict):
|
| 77 |
+
return {
|
| 78 |
+
k: collate_with_cat([e[k] for e in whatever], lists=lists) for k in elem
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
if isinstance(elem, torch.Tensor):
|
| 82 |
+
return listify(whatever) if lists else torch.cat(whatever)
|
| 83 |
+
if isinstance(elem, np.ndarray):
|
| 84 |
+
return (
|
| 85 |
+
listify(whatever)
|
| 86 |
+
if lists
|
| 87 |
+
else torch.cat([torch.from_numpy(x) for x in whatever])
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
# otherwise, we just chain lists
|
| 91 |
+
return sum(whatever, T())
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def listify(elems):
|
| 95 |
+
return [x for e in elems for x in e]
|
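A small sketch of what `todevice` / `to_numpy` do on nested containers; the batch below is made up:

import torch

batch = {"img": torch.rand(1, 3, 392, 518), "meta": [torch.arange(4), "frame-000000.png"]}
np_batch = to_numpy(batch)           # tensors become numpy arrays, other values pass through
cpu_batch = todevice(batch, "cpu")   # same structure, tensors moved to the CPU
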
eval/utils/eval_pose_ransac.py
ADDED
|
@@ -0,0 +1,315 @@
| 1 |
+
import logging
|
| 2 |
+
import math
|
| 3 |
+
import random
|
| 4 |
+
from collections import namedtuple
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
from scipy.spatial.transform import Rotation
|
| 9 |
+
|
| 10 |
+
_logger = logging.getLogger(__name__)
|
| 11 |
+
_logger.setLevel(logging.DEBUG)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def kabsch(pts1, pts2, estimate_scale=False):
|
| 15 |
+
c_pts1 = pts1 - pts1.mean(axis=0)
|
| 16 |
+
c_pts2 = pts2 - pts2.mean(axis=0)
|
| 17 |
+
|
| 18 |
+
covariance = np.matmul(c_pts1.T, c_pts2) / c_pts1.shape[0]
|
| 19 |
+
|
| 20 |
+
U, S, VT = np.linalg.svd(covariance)
|
| 21 |
+
|
| 22 |
+
d = np.sign(np.linalg.det(np.matmul(VT.T, U.T)))
|
| 23 |
+
correction = np.eye(3)
|
| 24 |
+
correction[2, 2] = d
|
| 25 |
+
|
| 26 |
+
if estimate_scale:
|
| 27 |
+
pts_var = np.mean(np.linalg.norm(c_pts2, axis=1) ** 2)
|
| 28 |
+
scale_factor = pts_var / np.trace(S * correction)
|
| 29 |
+
else:
|
| 30 |
+
scale_factor = 1.0
|
| 31 |
+
|
| 32 |
+
R = scale_factor * np.matmul(np.matmul(VT.T, correction), U.T)
|
| 33 |
+
t = pts2.mean(axis=0) - np.matmul(R, pts1.mean(axis=0))
|
| 34 |
+
|
| 35 |
+
T = np.eye(4)
|
| 36 |
+
T[:3, :3] = R
|
| 37 |
+
T[:3, 3] = t
|
| 38 |
+
|
| 39 |
+
return T, scale_factor
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def get_inliers(h_T, poses_gt, poses_est, inlier_threshold_t, inlier_threshold_r):
|
| 43 |
+
# h_T aligns ground-truth poses with estimated poses
|
| 44 |
+
poses_gt_transformed = h_T @ poses_gt
|
| 45 |
+
|
| 46 |
+
# calculate differences in position and rotations
|
| 47 |
+
translations_delta = poses_gt_transformed[:, :3, 3] - poses_est[:, :3, 3]
|
| 48 |
+
rotations_delta = poses_gt_transformed[:, :3, :3] @ poses_est[:, :3, :3].transpose(
|
| 49 |
+
[0, 2, 1]
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
# translation inliers
|
| 53 |
+
inliers_t = np.linalg.norm(translations_delta, axis=1) < inlier_threshold_t
|
| 54 |
+
# rotation inliers
|
| 55 |
+
inliers_r = Rotation.from_matrix(rotations_delta).magnitude() < (
|
| 56 |
+
inlier_threshold_r / 180 * math.pi
|
| 57 |
+
)
|
| 58 |
+
# intersection of both
|
| 59 |
+
return np.logical_and(inliers_r, inliers_t)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def print_hyp(hypothesis, hyp_name):
|
| 63 |
+
h_translation = np.linalg.norm(hypothesis["transformation"][:3, 3])
|
| 64 |
+
h_angle = (
|
| 65 |
+
np.linalg.norm(
|
| 66 |
+
Rotation.from_matrix(hypothesis["transformation"][:3, :3]).as_rotvec()
|
| 67 |
+
)
|
| 68 |
+
* 180
|
| 69 |
+
/ math.pi
|
| 70 |
+
)
|
| 71 |
+
print(
|
| 72 |
+
f"{hyp_name}: score={hypothesis['score']}, translation={h_translation:.2f}m, "
|
| 73 |
+
f"rotation={h_angle:.1f}deg."
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def estimated_alignment(
|
| 78 |
+
pose_est,
|
| 79 |
+
pose_gt,
|
| 80 |
+
inlier_threshold_t=0.05,
|
| 81 |
+
inlier_threshold_r=5,
|
| 82 |
+
ransac_iterations=1000,
|
| 83 |
+
refinement_max_hyp=12,
|
| 84 |
+
refinement_max_it=8,
|
| 85 |
+
estimate_scale=False,
|
| 86 |
+
):
|
| 87 |
+
n_pose = len(pose_est)
|
| 88 |
+
ransac_hypotheses = []
|
| 89 |
+
for i in range(ransac_iterations):
|
| 90 |
+
min_sample_size = 3
|
| 91 |
+
samples = random.sample(range(n_pose), min_sample_size)
|
| 92 |
+
h_pts1 = pose_gt[samples, :3, 3]
|
| 93 |
+
h_pts2 = pose_est[samples, :3, 3]
|
| 94 |
+
|
| 95 |
+
h_T, h_scale = kabsch(h_pts1, h_pts2, estimate_scale=estimate_scale)
|
| 96 |
+
|
| 97 |
+
inliers = get_inliers(
|
| 98 |
+
h_T, pose_gt, pose_est, inlier_threshold_t, inlier_threshold_r
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
if inliers[samples].sum() >= 3:
|
| 102 |
+
# only keep hypotheses if minimal sample is all inliers
|
| 103 |
+
ransac_hypotheses.append(
|
| 104 |
+
{
|
| 105 |
+
"transformation": h_T,
|
| 106 |
+
"inliers": inliers,
|
| 107 |
+
"score": inliers.sum(),
|
| 108 |
+
"scale": h_scale,
|
| 109 |
+
}
|
| 110 |
+
)
|
| 111 |
+
if len(ransac_hypotheses) == 0:
|
| 112 |
+
print(
|
| 113 |
+
f"Did not fine a single valid RANSAC hypothesis, abort alignment estimation."
|
| 114 |
+
)
|
| 115 |
+
return None, 1
|
| 116 |
+
|
| 117 |
+
# sort according to score
|
| 118 |
+
ransac_hypotheses = sorted(
|
| 119 |
+
ransac_hypotheses, key=lambda x: x["score"], reverse=True
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
# for hyp_idx, hyp in enumerate(ransac_hypotheses):
|
| 123 |
+
# print_hyp(hyp, f"Hypothesis {hyp_idx}")
|
| 124 |
+
|
| 125 |
+
# create shortlist of best hypotheses for refinement
|
| 126 |
+
# print(f"Starting refinement of {refinement_max_hyp} best hypotheses.")
|
| 127 |
+
ransac_hypotheses = ransac_hypotheses[:refinement_max_hyp]
|
| 128 |
+
|
| 129 |
+
# refine all hypotheses in the short list
|
| 130 |
+
for ref_hyp in ransac_hypotheses:
|
| 131 |
+
# print_hyp(ref_hyp, "Pre-Refinement")
|
| 132 |
+
|
| 133 |
+
# refinement loop
|
| 134 |
+
for ref_it in range(refinement_max_it):
|
| 135 |
+
# re-solve alignment on all inliers
|
| 136 |
+
h_pts1 = pose_gt[ref_hyp["inliers"], :3, 3]
|
| 137 |
+
h_pts2 = pose_est[ref_hyp["inliers"], :3, 3]
|
| 138 |
+
|
| 139 |
+
h_T, h_scale = kabsch(h_pts1, h_pts2, estimate_scale)
|
| 140 |
+
|
| 141 |
+
# calculate new inliers
|
| 142 |
+
inliers = get_inliers(
|
| 143 |
+
h_T, pose_gt, pose_est, inlier_threshold_t, inlier_threshold_r
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
# check whether hypothesis score improved
|
| 147 |
+
refined_score = inliers.sum()
|
| 148 |
+
|
| 149 |
+
if refined_score > ref_hyp["score"]:
|
| 150 |
+
ref_hyp["transformation"] = h_T
|
| 151 |
+
ref_hyp["inliers"] = inliers
|
| 152 |
+
ref_hyp["score"] = refined_score
|
| 153 |
+
ref_hyp["scale"] = h_scale
|
| 154 |
+
|
| 155 |
+
# print_hyp(ref_hyp, f"Refinement iteration {ref_it}")
|
| 156 |
+
|
| 157 |
+
else:
|
| 158 |
+
# print(f"Stopping refinement. Score did not improve: New score={refined_score}, "
|
| 159 |
+
# f"Old score={ref_hyp['score']}")
|
| 160 |
+
break
|
| 161 |
+
|
| 162 |
+
# re-sort refined hypotheses
|
| 163 |
+
ransac_hypotheses = sorted(
|
| 164 |
+
ransac_hypotheses, key=lambda x: x["score"], reverse=True
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
# for hyp_idx, hyp in enumerate(ransac_hypotheses):
|
| 168 |
+
# print_hyp(hyp, f"Hypothesis {hyp_idx}")
|
| 169 |
+
|
| 170 |
+
return ransac_hypotheses[0]["transformation"], ransac_hypotheses[0]["scale"]
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def eval_pose_ransac(gt, est, t_thres=0.05, r_thres=5, aligned=True, save_dir=None):
|
| 174 |
+
if aligned:
|
| 175 |
+
alignment_transformation, alignment_scale = estimated_alignment(
|
| 176 |
+
est,
|
| 177 |
+
gt,
|
| 178 |
+
inlier_threshold_t=0.05,
|
| 179 |
+
inlier_threshold_r=5,
|
| 180 |
+
ransac_iterations=1000,
|
| 181 |
+
refinement_max_hyp=12,
|
| 182 |
+
refinement_max_it=8,
|
| 183 |
+
estimate_scale=True,
|
| 184 |
+
)
|
| 185 |
+
if alignment_transformation is None:
|
| 186 |
+
_logger.info(
|
| 187 |
+
f"Alignment requested but failed. Setting all pose errors to {math.inf}."
|
| 188 |
+
)
|
| 189 |
+
else:
|
| 190 |
+
alignment_transformation = np.eye(4)
|
| 191 |
+
alignment_scale = 1.0
|
| 192 |
+
# Evaluation Loop
|
| 193 |
+
|
| 194 |
+
rErrs = []
|
| 195 |
+
tErrs = []
|
| 196 |
+
accuracy = 0
|
| 197 |
+
r_acc_5 = 0
|
| 198 |
+
r_acc_2 = 0
|
| 199 |
+
r_acc_1 = 0
|
| 200 |
+
t_acc_15 = 0
|
| 201 |
+
t_acc_10 = 0
|
| 202 |
+
t_acc_5 = 0
|
| 203 |
+
t_acc_2 = 0
|
| 204 |
+
t_acc_1 = 0
|
| 205 |
+
acc_10 = 0
|
| 206 |
+
acc_5 = 0
|
| 207 |
+
acc_2 = 0
|
| 208 |
+
acc_1 = 0
|
| 209 |
+
|
| 210 |
+
for pred_pose, gt_pose in zip(est, gt):
|
| 211 |
+
if alignment_transformation is not None:
|
| 212 |
+
# Apply alignment transformation to GT pose
|
| 213 |
+
gt_pose = alignment_transformation @ gt_pose
|
| 214 |
+
|
| 215 |
+
# Calculate translation error.
|
| 216 |
+
t_err = float(np.linalg.norm(gt_pose[0:3, 3] - pred_pose[0:3, 3]))
|
| 217 |
+
|
| 218 |
+
# Correct translation scale with the inverse alignment scale (since we align GT with estimates)
|
| 219 |
+
t_err = t_err / alignment_scale
|
| 220 |
+
|
| 221 |
+
# Rotation error.
|
| 222 |
+
gt_R = gt_pose[0:3, 0:3]
|
| 223 |
+
out_R = pred_pose[0:3, 0:3]
|
| 224 |
+
|
| 225 |
+
r_err = np.matmul(out_R, np.transpose(gt_R))
|
| 226 |
+
# Compute angle-axis representation.
|
| 227 |
+
r_err = cv2.Rodrigues(r_err)[0]
|
| 228 |
+
# Extract the angle.
|
| 229 |
+
r_err = np.linalg.norm(r_err) * 180 / math.pi
|
| 230 |
+
else:
|
| 231 |
+
pose_gt = None
|
| 232 |
+
t_err, r_err = math.inf, math.inf
|
| 233 |
+
|
| 234 |
+
# _logger.info(f"Rotation Error: {r_err:.2f}deg, Translation Error: {t_err * 100:.1f}cm")
|
| 235 |
+
|
| 236 |
+
# Save the errors.
|
| 237 |
+
rErrs.append(r_err)
|
| 238 |
+
tErrs.append(t_err * 100) # in cm
|
| 239 |
+
|
| 240 |
+
# Check various thresholds.
|
| 241 |
+
if r_err < r_thres and t_err < t_thres:
|
| 242 |
+
accuracy += 1
|
| 243 |
+
if r_err < 5:
|
| 244 |
+
r_acc_5 += 1
|
| 245 |
+
if r_err < 2:
|
| 246 |
+
r_acc_2 += 1
|
| 247 |
+
if r_err < 1:
|
| 248 |
+
r_acc_1 += 1
|
| 249 |
+
if t_err < 0.15:
|
| 250 |
+
t_acc_15 += 1
|
| 251 |
+
if t_err < 0.10:
|
| 252 |
+
t_acc_10 += 1
|
| 253 |
+
if t_err < 0.05:
|
| 254 |
+
t_acc_5 += 1
|
| 255 |
+
if t_err < 0.02:
|
| 256 |
+
t_acc_2 += 1
|
| 257 |
+
if t_err < 0.01:
|
| 258 |
+
t_acc_1 += 1
|
| 259 |
+
if r_err < 10 and t_err < 0.10:
|
| 260 |
+
acc_10 += 1
|
| 261 |
+
if r_err < 5 and t_err < 0.05:
|
| 262 |
+
acc_5 += 1
|
| 263 |
+
if r_err < 2 and t_err < 0.02:
|
| 264 |
+
acc_2 += 1
|
| 265 |
+
if r_err < 1 and t_err < 0.01:
|
| 266 |
+
acc_1 += 1
|
| 267 |
+
|
| 268 |
+
total_frames = len(rErrs)
|
| 269 |
+
assert total_frames == len(est)
|
| 270 |
+
|
| 271 |
+
# Compute median errors.
|
| 272 |
+
tErrs.sort()
|
| 273 |
+
rErrs.sort()
|
| 274 |
+
median_idx = total_frames // 2
|
| 275 |
+
median_rErr = rErrs[median_idx]
|
| 276 |
+
median_tErr = tErrs[median_idx]
|
| 277 |
+
|
| 278 |
+
# Compute final precision.
|
| 279 |
+
accuracy = accuracy / total_frames * 100
|
| 280 |
+
r_acc_5 = r_acc_5 / total_frames * 100
|
| 281 |
+
r_acc_2 = r_acc_2 / total_frames * 100
|
| 282 |
+
r_acc_1 = r_acc_1 / total_frames * 100
|
| 283 |
+
t_acc_15 = t_acc_15 / total_frames * 100
|
| 284 |
+
t_acc_10 = t_acc_10 / total_frames * 100
|
| 285 |
+
t_acc_5 = t_acc_5 / total_frames * 100
|
| 286 |
+
t_acc_2 = t_acc_2 / total_frames * 100
|
| 287 |
+
t_acc_1 = t_acc_1 / total_frames * 100
|
| 288 |
+
acc_10 = acc_10 / total_frames * 100
|
| 289 |
+
acc_5 = acc_5 / total_frames * 100
|
| 290 |
+
acc_2 = acc_2 / total_frames * 100
|
| 291 |
+
acc_1 = acc_1 / total_frames * 100
|
| 292 |
+
|
| 293 |
+
# _logger.info("===================================================")
|
| 294 |
+
# _logger.info("Test complete.")
|
| 295 |
+
|
| 296 |
+
# _logger.info(f'Accuracy: {accuracy:.1f}%')
|
| 297 |
+
# _logger.info(f"Median Error: {median_rErr:.1f}deg, {median_tErr:.1f}cm")
|
| 298 |
+
# print("===================================================")
|
| 299 |
+
# print("Test complete.")
|
| 300 |
+
|
| 301 |
+
with open(save_dir, "w") as f:
|
| 302 |
+
f.write(f"Accuracy: {accuracy:.1f}%\n\n")
|
| 303 |
+
f.write(f"Median Error: {median_rErr:.1f}deg, {median_tErr:.1f}cm\n")
|
| 304 |
+
f.write(f"R acc 5: {r_acc_5:.1f}%\n")
|
| 305 |
+
f.write(f"R acc 2: {r_acc_2:.1f}%\n")
|
| 306 |
+
f.write(f"R acc 1: {r_acc_1:.1f}%\n")
|
| 307 |
+
f.write(f"T acc 15: {t_acc_15:.1f}%\n")
|
| 308 |
+
f.write(f"T acc 10: {t_acc_10:.1f}%\n")
|
| 309 |
+
f.write(f"T acc 5: {t_acc_5:.1f}%\n")
|
| 310 |
+
f.write(f"T acc 2: {t_acc_2:.1f}%\n")
|
| 311 |
+
f.write(f"T acc 1: {t_acc_1:.1f}%\n")
|
| 312 |
+
f.write(f"Acc 10: {acc_10:.1f}%\n")
|
| 313 |
+
f.write(f"Acc 5: {acc_5:.1f}%\n")
|
| 314 |
+
f.write(f"Acc 2: {acc_2:.1f}%\n")
|
| 315 |
+
f.write(f"Acc 1: {acc_1:.1f}%\n")
|
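A minimal sketch of driving the alignment and evaluation above, assuming predicted and ground-truth cam2world poses stacked as (N, 4, 4) numpy arrays; the trajectories and output path are placeholders:

import numpy as np

rng = np.random.default_rng(0)
poses_gt = np.tile(np.eye(4), (50, 1, 1))
poses_gt[:, :3, 3] = rng.normal(size=(50, 3))          # ground-truth camera centers
poses_est = poses_gt.copy()
poses_est[:, :3, 3] = 2.0 * poses_gt[:, :3, 3] + 1.0   # same trajectory, scaled and shifted

T_align, scale = estimated_alignment(poses_est, poses_gt, estimate_scale=True)
eval_pose_ransac(poses_gt, poses_est, t_thres=0.05, r_thres=5, save_dir="result.txt")
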
eval/utils/eval_utils.py
ADDED
|
@@ -0,0 +1,74 @@
import math
import sys

import numpy as np
from evo.core.trajectory import PosePath3D, PoseTrajectory3D


def save_kitti_poses(poses, save_path):
    with open(save_path, "w") as f:
        for pose in poses:  # pose: 4x4 numpy array
            pose_line = pose[:3].reshape(-1)  # flatten first 3 rows
            f.write(" ".join(map(str, pose_line)) + "\n")


def save_tum_poses(poses, timestamps, save_path):
    """
    Save poses in TUM RGB-D format.
    Args:
        poses: list or array of 4x4 numpy arrays (T_w_c)
        timestamps: list or array of float timestamps (same length as poses)
        save_path: output file path
    """
    assert len(poses) == len(timestamps), "poses and timestamps length mismatch"

    with open(save_path, "w") as f:
        for ts, pose in zip(timestamps, poses):
            tx, ty, tz = pose[0, 3], pose[1, 3], pose[2, 3]

            R = pose[:3, :3]
            qw = np.sqrt(max(0, 1 + R[0, 0] + R[1, 1] + R[2, 2])) / 2
            qx = np.sqrt(max(0, 1 + R[0, 0] - R[1, 1] - R[2, 2])) / 2
            qy = np.sqrt(max(0, 1 - R[0, 0] + R[1, 1] - R[2, 2])) / 2
            qz = np.sqrt(max(0, 1 - R[0, 0] - R[1, 1] + R[2, 2])) / 2
            qx = math.copysign(qx, R[2, 1] - R[1, 2])
            qy = math.copysign(qy, R[0, 2] - R[2, 0])
            qz = math.copysign(qz, R[1, 0] - R[0, 1])

            f.write(
                f"{ts:.6f} {tx:.6f} {ty:.6f} {tz:.6f} {qx:.6f} {qy:.6f} {qz:.6f} {qw:.6f}\n"
            )


def align_gt_pred(gt_views, poses_c2w_estimated):
    poses_c2w_gt = [view["camera_pose"][0] for view in gt_views]
    gt = PosePath3D(poses_se3=poses_c2w_gt)
    pred = PosePath3D(poses_se3=poses_c2w_estimated[0])
    r_a, t_a, s = pred.align(gt, correct_scale=True)

    return pred.poses_se3, gt.poses_se3


def align_gt_pred_2(poses_c2w_gt, poses_c2w_estimated):
    # poses_c2w_gt = [view["camera_pose"][0] for view in gt_views]
    gt = PosePath3D(poses_se3=poses_c2w_gt)
    pred = PosePath3D(poses_se3=poses_c2w_estimated)
    r_a, t_a, s = pred.align(gt, correct_scale=True)

    return pred.poses_se3, gt.poses_se3


def save_all_intrinsics_to_txt(result, filename="all_intrinsics.txt"):
    with open(filename, "w") as f:
        for i, res in enumerate(result):
            intrinsic = res["intrinsic"].squeeze(0).reshape(-1).cpu().numpy()  # (9,)
            line = "\t".join([f"{v:.6f}" for v in intrinsic])
            f.write(line + "\n")
    print(f"[TXT] Saved {len(result)} intrinsics to {filename}")


def uniform_sample(total: int, select: int) -> list:
    if select > total:
        raise ValueError("select cannot be greater than total")
    step = total / select
    return [int(i * step) for i in range(select)]

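For reference, `uniform_sample` spreads the selected indices evenly over the index range; a quick sanity check of its behaviour:

assert uniform_sample(10, 4) == [0, 2, 5, 7]
assert uniform_sample(4, 4) == [0, 1, 2, 3]
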
eval/utils/geometry.py
ADDED
|
@@ -0,0 +1,572 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 10 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 11 |
+
#
|
| 12 |
+
# --------------------------------------------------------
|
| 13 |
+
# geometry utility functions
|
| 14 |
+
# --------------------------------------------------------
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
import torchvision.utils as vutils
|
| 19 |
+
from plyfile import PlyData, PlyElement
|
| 20 |
+
from scipy.spatial import cKDTree as KDTree
|
| 21 |
+
from tqdm import tqdm
|
| 22 |
+
|
| 23 |
+
from eval.utils.device import to_numpy
|
| 24 |
+
from eval.utils.misc import invalid_to_nans, invalid_to_zeros
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def xy_grid(
|
| 28 |
+
W,
|
| 29 |
+
H,
|
| 30 |
+
device=None,
|
| 31 |
+
origin=(0, 0),
|
| 32 |
+
unsqueeze=None,
|
| 33 |
+
cat_dim=-1,
|
| 34 |
+
homogeneous=False,
|
| 35 |
+
**arange_kw,
|
| 36 |
+
):
|
| 37 |
+
"""Output a (H,W,2) array of int32
|
| 38 |
+
with output[j,i,0] = i + origin[0]
|
| 39 |
+
output[j,i,1] = j + origin[1]
|
| 40 |
+
"""
|
| 41 |
+
if device is None:
|
| 42 |
+
# numpy
|
| 43 |
+
arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
|
| 44 |
+
else:
|
| 45 |
+
# torch
|
| 46 |
+
arange = lambda *a, **kw: torch.arange(*a, device=device, **kw)
|
| 47 |
+
meshgrid, stack = torch.meshgrid, torch.stack
|
| 48 |
+
ones = lambda *a: torch.ones(*a, device=device)
|
| 49 |
+
|
| 50 |
+
tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)]
|
| 51 |
+
grid = meshgrid(tw, th, indexing="xy")
|
| 52 |
+
if homogeneous:
|
| 53 |
+
grid = grid + (ones((H, W)),)
|
| 54 |
+
if unsqueeze is not None:
|
| 55 |
+
grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze))
|
| 56 |
+
if cat_dim is not None:
|
| 57 |
+
grid = stack(grid, cat_dim)
|
| 58 |
+
return grid
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def geotrf(Trf, pts, ncol=None, norm=False):
|
| 62 |
+
"""Apply a geometric transformation to a list of 3-D points.
|
| 63 |
+
|
| 64 |
+
H: 3x3 or 4x4 projection matrix (typically a Homography)
|
| 65 |
+
p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)
|
| 66 |
+
|
| 67 |
+
ncol: int. number of columns of the result (2 or 3)
|
| 68 |
+
norm: float. if != 0, the result is projected on the z=norm plane.
|
| 69 |
+
|
| 70 |
+
Returns an array of projected 2d points.
|
| 71 |
+
"""
|
| 72 |
+
assert Trf.ndim >= 2
|
| 73 |
+
if isinstance(Trf, np.ndarray):
|
| 74 |
+
pts = np.asarray(pts)
|
| 75 |
+
elif isinstance(Trf, torch.Tensor):
|
| 76 |
+
pts = torch.as_tensor(pts, dtype=Trf.dtype)
|
| 77 |
+
|
| 78 |
+
# adapt shape if necessary
|
| 79 |
+
output_reshape = pts.shape[:-1]
|
| 80 |
+
ncol = ncol or pts.shape[-1]
|
| 81 |
+
|
| 82 |
+
# optimized code
|
| 83 |
+
if (
|
| 84 |
+
isinstance(Trf, torch.Tensor)
|
| 85 |
+
and isinstance(pts, torch.Tensor)
|
| 86 |
+
and Trf.ndim == 3
|
| 87 |
+
and pts.ndim == 4
|
| 88 |
+
):
|
| 89 |
+
d = pts.shape[3]
|
| 90 |
+
if Trf.shape[-1] == d:
|
| 91 |
+
pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
|
| 92 |
+
elif Trf.shape[-1] == d + 1:
|
| 93 |
+
pts = (
|
| 94 |
+
torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
|
| 95 |
+
+ Trf[:, None, None, :d, d]
|
| 96 |
+
)
|
| 97 |
+
else:
|
| 98 |
+
raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
|
| 99 |
+
else:
|
| 100 |
+
if Trf.ndim >= 3:
|
| 101 |
+
n = Trf.ndim - 2
|
| 102 |
+
assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
|
| 103 |
+
Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
|
| 104 |
+
|
| 105 |
+
if pts.ndim > Trf.ndim:
|
| 106 |
+
# Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
|
| 107 |
+
pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
|
| 108 |
+
elif pts.ndim == 2:
|
| 109 |
+
# Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
|
| 110 |
+
pts = pts[:, None, :]
|
| 111 |
+
|
| 112 |
+
if pts.shape[-1] + 1 == Trf.shape[-1]:
|
| 113 |
+
Trf = Trf.swapaxes(-1, -2) # transpose Trf
|
| 114 |
+
pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
|
| 115 |
+
elif pts.shape[-1] == Trf.shape[-1]:
|
| 116 |
+
Trf = Trf.swapaxes(-1, -2) # transpose Trf
|
| 117 |
+
pts = pts @ Trf
|
| 118 |
+
else:
|
| 119 |
+
pts = Trf @ pts.T
|
| 120 |
+
if pts.ndim >= 2:
|
| 121 |
+
pts = pts.swapaxes(-1, -2)
|
| 122 |
+
|
| 123 |
+
if norm:
|
| 124 |
+
pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
|
| 125 |
+
if norm != 1:
|
| 126 |
+
pts *= norm
|
| 127 |
+
|
| 128 |
+
res = pts[..., :ncol].reshape(*output_reshape, ncol)
|
| 129 |
+
return res
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def inv(mat):
|
| 133 |
+
"""Invert a torch or numpy matrix"""
|
| 134 |
+
if isinstance(mat, torch.Tensor):
|
| 135 |
+
return torch.linalg.inv(mat)
|
| 136 |
+
if isinstance(mat, np.ndarray):
|
| 137 |
+
return np.linalg.inv(mat)
|
| 138 |
+
raise ValueError(f"bad matrix type = {type(mat)}")
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_):
|
| 142 |
+
"""
|
| 143 |
+
Args:
|
| 144 |
+
- depthmap (BxHxW array):
|
| 145 |
+
- pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W]
|
| 146 |
+
Returns:
|
| 147 |
+
pointmap of absolute coordinates (BxHxWx3 array)
|
| 148 |
+
"""
|
| 149 |
+
|
| 150 |
+
if len(depth.shape) == 4:
|
| 151 |
+
B, H, W, n = depth.shape
|
| 152 |
+
else:
|
| 153 |
+
B, H, W = depth.shape
|
| 154 |
+
n = None
|
| 155 |
+
|
| 156 |
+
if len(pseudo_focal.shape) == 3: # [B,H,W]
|
| 157 |
+
pseudo_focalx = pseudo_focaly = pseudo_focal
|
| 158 |
+
elif len(pseudo_focal.shape) == 4: # [B,2,H,W] or [B,1,H,W]
|
| 159 |
+
pseudo_focalx = pseudo_focal[:, 0]
|
| 160 |
+
if pseudo_focal.shape[1] == 2:
|
| 161 |
+
pseudo_focaly = pseudo_focal[:, 1]
|
| 162 |
+
else:
|
| 163 |
+
pseudo_focaly = pseudo_focalx
|
| 164 |
+
else:
|
| 165 |
+
raise NotImplementedError("Error, unknown input focal shape format.")
|
| 166 |
+
|
| 167 |
+
assert pseudo_focalx.shape == depth.shape[:3]
|
| 168 |
+
assert pseudo_focaly.shape == depth.shape[:3]
|
| 169 |
+
grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None]
|
| 170 |
+
|
| 171 |
+
# set principal point
|
| 172 |
+
if pp is None:
|
| 173 |
+
grid_x = grid_x - (W - 1) / 2
|
| 174 |
+
grid_y = grid_y - (H - 1) / 2
|
| 175 |
+
else:
|
| 176 |
+
grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None]
|
| 177 |
+
grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None]
|
| 178 |
+
|
| 179 |
+
if n is None:
|
| 180 |
+
pts3d = torch.empty((B, H, W, 3), device=depth.device)
|
| 181 |
+
pts3d[..., 0] = depth * grid_x / pseudo_focalx
|
| 182 |
+
pts3d[..., 1] = depth * grid_y / pseudo_focaly
|
| 183 |
+
pts3d[..., 2] = depth
|
| 184 |
+
else:
|
| 185 |
+
pts3d = torch.empty((B, H, W, 3, n), device=depth.device)
|
| 186 |
+
pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None]
|
| 187 |
+
pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None]
|
| 188 |
+
pts3d[..., 2, :] = depth
|
| 189 |
+
return pts3d
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
|
| 193 |
+
"""
|
| 194 |
+
Args:
|
| 195 |
+
- depthmap (HxW array):
|
| 196 |
+
- camera_intrinsics: a 3x3 matrix
|
| 197 |
+
Returns:
|
| 198 |
+
pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
|
| 199 |
+
"""
|
| 200 |
+
camera_intrinsics = np.float32(camera_intrinsics)
|
| 201 |
+
H, W = depthmap.shape
|
| 202 |
+
|
| 203 |
+
# Compute 3D ray associated with each pixel
|
| 204 |
+
# Strong assumption: there are no skew terms
|
| 205 |
+
assert camera_intrinsics[0, 1] == 0.0
|
| 206 |
+
assert camera_intrinsics[1, 0] == 0.0
|
| 207 |
+
if pseudo_focal is None:
|
| 208 |
+
fu = camera_intrinsics[0, 0]
|
| 209 |
+
fv = camera_intrinsics[1, 1]
|
| 210 |
+
else:
|
| 211 |
+
assert pseudo_focal.shape == (H, W)
|
| 212 |
+
fu = fv = pseudo_focal
|
| 213 |
+
cu = camera_intrinsics[0, 2]
|
| 214 |
+
cv = camera_intrinsics[1, 2]
|
| 215 |
+
|
| 216 |
+
u, v = np.meshgrid(np.arange(W), np.arange(H))
|
| 217 |
+
z_cam = depthmap
|
| 218 |
+
x_cam = (u - cu) * z_cam / fu
|
| 219 |
+
y_cam = (v - cv) * z_cam / fv
|
| 220 |
+
X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
|
| 221 |
+
|
| 222 |
+
# Mask for valid coordinates
|
| 223 |
+
valid_mask = depthmap > 0.0
|
| 224 |
+
return X_cam, valid_mask
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def depthmap_to_absolute_camera_coordinates(
|
| 228 |
+
depthmap, camera_intrinsics, camera_pose, **kw
|
| 229 |
+
):
|
| 230 |
+
"""
|
| 231 |
+
Args:
|
| 232 |
+
- depthmap (HxW array):
|
| 233 |
+
- camera_intrinsics: a 3x3 matrix
|
| 234 |
+
- camera_pose: a 4x3 or 4x4 cam2world matrix
|
| 235 |
+
Returns:
|
| 236 |
+
pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
|
| 237 |
+
"""
|
| 238 |
+
X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)
|
| 239 |
+
|
| 240 |
+
# R_cam2world = np.float32(camera_params["R_cam2world"])
|
| 241 |
+
# t_cam2world = np.float32(camera_params["t_cam2world"]).squeeze()
|
| 242 |
+
R_cam2world = camera_pose[:3, :3]
|
| 243 |
+
t_cam2world = camera_pose[:3, 3]
|
| 244 |
+
|
| 245 |
+
# Express in absolute coordinates (invalid depth values)
|
| 246 |
+
X_world = (
|
| 247 |
+
np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]
|
| 248 |
+
)
|
| 249 |
+
return X_world, valid_mask
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def colmap_to_opencv_intrinsics(K):
|
| 253 |
+
"""
|
| 254 |
+
Modify camera intrinsics to follow a different convention.
|
| 255 |
+
Coordinates of the center of the top-left pixels are by default:
|
| 256 |
+
- (0.5, 0.5) in Colmap
|
| 257 |
+
- (0,0) in OpenCV
|
| 258 |
+
"""
|
| 259 |
+
K = K.copy()
|
| 260 |
+
K[0, 2] -= 0.5
|
| 261 |
+
K[1, 2] -= 0.5
|
| 262 |
+
return K
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def opencv_to_colmap_intrinsics(K):
|
| 266 |
+
"""
|
| 267 |
+
Modify camera intrinsics to follow a different convention.
|
| 268 |
+
Coordinates of the center of the top-left pixels are by default:
|
| 269 |
+
- (0.5, 0.5) in Colmap
|
| 270 |
+
- (0,0) in OpenCV
|
| 271 |
+
"""
|
| 272 |
+
K = K.copy()
|
| 273 |
+
K[0, 2] += 0.5
|
| 274 |
+
K[1, 2] += 0.5
|
| 275 |
+
return K
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def normalize_pointcloud(pts1, pts2, norm_mode="avg_dis", valid1=None, valid2=None):
|
| 279 |
+
"""renorm pointmaps pts1, pts2 with norm_mode"""
|
| 280 |
+
assert pts1.ndim >= 3 and pts1.shape[-1] == 3
|
| 281 |
+
assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3)
|
| 282 |
+
norm_mode, dis_mode = norm_mode.split("_")
|
| 283 |
+
|
| 284 |
+
if norm_mode == "avg":
|
| 285 |
+
# gather all points together (joint normalization)
|
| 286 |
+
nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3)
|
| 287 |
+
nan_pts2, nnz2 = (
|
| 288 |
+
invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0)
|
| 289 |
+
)
|
| 290 |
+
all_pts = (
|
| 291 |
+
torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
# compute distance to origin
|
| 295 |
+
all_dis = all_pts.norm(dim=-1)
|
| 296 |
+
if dis_mode == "dis":
|
| 297 |
+
pass # do nothing
|
| 298 |
+
elif dis_mode == "log1p":
|
| 299 |
+
all_dis = torch.log1p(all_dis)
|
| 300 |
+
elif dis_mode == "warp-log1p":
|
| 301 |
+
# actually warp input points before normalizing them
|
| 302 |
+
log_dis = torch.log1p(all_dis)
|
| 303 |
+
warp_factor = log_dis / all_dis.clip(min=1e-8)
|
| 304 |
+
H1, W1 = pts1.shape[1:-1]
|
| 305 |
+
pts1 = pts1 * warp_factor[:, : W1 * H1].view(-1, H1, W1, 1)
|
| 306 |
+
if pts2 is not None:
|
| 307 |
+
H2, W2 = pts2.shape[1:-1]
|
| 308 |
+
pts2 = pts2 * warp_factor[:, W1 * H1 :].view(-1, H2, W2, 1)
|
| 309 |
+
all_dis = log_dis # this is their true distance afterwards
|
| 310 |
+
else:
|
| 311 |
+
raise ValueError(f"bad {dis_mode=}")
|
| 312 |
+
|
| 313 |
+
norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8)
|
| 314 |
+
else:
|
| 315 |
+
# gather all points together (joint normalization)
|
| 316 |
+
nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3)
|
| 317 |
+
nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None
|
| 318 |
+
all_pts = (
|
| 319 |
+
torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
# compute distance to origin
|
| 323 |
+
all_dis = all_pts.norm(dim=-1)
|
| 324 |
+
|
| 325 |
+
if norm_mode == "avg":
|
| 326 |
+
norm_factor = all_dis.nanmean(dim=1)
|
| 327 |
+
elif norm_mode == "median":
|
| 328 |
+
norm_factor = all_dis.nanmedian(dim=1).values.detach()
|
| 329 |
+
elif norm_mode == "sqrt":
|
| 330 |
+
norm_factor = all_dis.sqrt().nanmean(dim=1) ** 2
|
| 331 |
+
else:
|
| 332 |
+
raise ValueError(f"bad {norm_mode=}")
|
| 333 |
+
|
| 334 |
+
norm_factor = norm_factor.clip(min=1e-8)
|
| 335 |
+
while norm_factor.ndim < pts1.ndim:
|
| 336 |
+
norm_factor.unsqueeze_(-1)
|
| 337 |
+
|
| 338 |
+
res = pts1 / norm_factor
|
| 339 |
+
if pts2 is not None:
|
| 340 |
+
res = (res, pts2 / norm_factor)
|
| 341 |
+
return res
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
@torch.no_grad()
|
| 345 |
+
def get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, quantile=0.5):
|
| 346 |
+
# set invalid points to NaN
|
| 347 |
+
_z1 = invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1)
|
| 348 |
+
_z2 = (
|
| 349 |
+
invalid_to_nans(z2, valid_mask2).reshape(len(z2), -1)
|
| 350 |
+
if z2 is not None
|
| 351 |
+
else None
|
| 352 |
+
)
|
| 353 |
+
_z = torch.cat((_z1, _z2), dim=-1) if z2 is not None else _z1
|
| 354 |
+
|
| 355 |
+
# compute median depth overall (ignoring nans)
|
| 356 |
+
if quantile == 0.5:
|
| 357 |
+
shift_z = torch.nanmedian(_z, dim=-1).values
|
| 358 |
+
else:
|
| 359 |
+
shift_z = torch.nanquantile(_z, quantile, dim=-1)
|
| 360 |
+
return shift_z # (B,)
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
@torch.no_grad()
|
| 364 |
+
def get_joint_pointcloud_center_scale(
|
| 365 |
+
pts1, pts2, valid_mask1=None, valid_mask2=None, z_only=False, center=True
|
| 366 |
+
):
|
| 367 |
+
# set invalid points to NaN
|
| 368 |
+
_pts1 = invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3)
|
| 369 |
+
_pts2 = (
|
| 370 |
+
invalid_to_nans(pts2, valid_mask2).reshape(len(pts2), -1, 3)
|
| 371 |
+
if pts2 is not None
|
| 372 |
+
else None
|
| 373 |
+
)
|
| 374 |
+
_pts = torch.cat((_pts1, _pts2), dim=1) if pts2 is not None else _pts1
|
| 375 |
+
|
| 376 |
+
# compute median center
|
| 377 |
+
_center = torch.nanmedian(_pts, dim=1, keepdim=True).values # (B,1,3)
|
| 378 |
+
if z_only:
|
| 379 |
+
_center[..., :2] = 0 # do not center X and Y
|
| 380 |
+
|
| 381 |
+
# compute median norm
|
| 382 |
+
_norm = ((_pts - _center) if center else _pts).norm(dim=-1)
|
| 383 |
+
scale = torch.nanmedian(_norm, dim=1).values
|
| 384 |
+
return _center[:, None, :, :], scale[:, None, None, None]
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
def find_reciprocal_matches(P1, P2):
|
| 388 |
+
"""
|
| 389 |
+
returns 3 values:
|
| 390 |
+
1 - reciprocal_in_P2: a boolean array of size P2.shape[0], a "True" value indicates a match
|
| 391 |
+
2 - nn2_in_P1: an int array of size P2.shape[0], containing the indices of the closest points in P1
|
| 392 |
+
3 - reciprocal_in_P2.sum(): the number of matches
|
| 393 |
+
"""
|
| 394 |
+
tree1 = KDTree(P1)
|
| 395 |
+
tree2 = KDTree(P2)
|
| 396 |
+
|
| 397 |
+
_, nn1_in_P2 = tree2.query(P1, workers=8)
|
| 398 |
+
_, nn2_in_P1 = tree1.query(P2, workers=8)
|
| 399 |
+
|
| 400 |
+
reciprocal_in_P1 = nn2_in_P1[nn1_in_P2] == np.arange(len(nn1_in_P2))
|
| 401 |
+
reciprocal_in_P2 = nn1_in_P2[nn2_in_P1] == np.arange(len(nn2_in_P1))
|
| 402 |
+
assert reciprocal_in_P1.sum() == reciprocal_in_P2.sum()
|
| 403 |
+
return reciprocal_in_P2, nn2_in_P1, reciprocal_in_P2.sum()
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def get_med_dist_between_poses(poses):
|
| 407 |
+
from scipy.spatial.distance import pdist
|
| 408 |
+
|
| 409 |
+
return np.median(pdist([to_numpy(p[:3, 3]) for p in poses]))
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def save_pointcloud_with_plyfile(result, filename="output.ply", downsample_ratio=10):
|
| 413 |
+
all_points = []
|
| 414 |
+
all_colors = []
|
| 415 |
+
|
| 416 |
+
for view in result:
|
| 417 |
+
pts = view["point_map_by_unprojection"] # (1, H, W, 3)
|
| 418 |
+
rgbs = view["rgbs"] # (1, 3, H, W)
|
| 419 |
+
dpt_cnf = view["dpt_cnf"]
|
| 420 |
+
|
| 421 |
+
# Remove batch dimension
|
| 422 |
+
pts = pts.squeeze(0) # (H, W, 3)
|
| 423 |
+
rgbs = rgbs.squeeze(0).permute(1, 2, 0) # (3, H, W) -> (H, W, 3)
|
| 424 |
+
|
| 425 |
+
# Flatten
|
| 426 |
+
pts = pts.reshape(-1, 3) # (N, 3)
|
| 427 |
+
rgbs = rgbs.reshape(-1, 3) # (N, 3)
|
| 428 |
+
|
| 429 |
+
# Remove invalid points
|
| 430 |
+
valid = torch.isfinite(pts).all(dim=1) & (pts.norm(dim=1) > 0)
|
| 431 |
+
valid = valid & (dpt_cnf > torch.quantile(view["dpt_cnf"], 0.5)).flatten()
|
| 432 |
+
pts = pts[valid]
|
| 433 |
+
rgbs = rgbs[valid]
|
| 434 |
+
|
| 435 |
+
# Downsample this view
|
| 436 |
+
N = pts.shape[0]
|
| 437 |
+
if downsample_ratio > 1 and N >= downsample_ratio:
|
| 438 |
+
idx = torch.randperm(N)[: N // downsample_ratio]
|
| 439 |
+
pts = pts[idx]
|
| 440 |
+
rgbs = rgbs[idx]
|
| 441 |
+
|
| 442 |
+
all_points.append(pts)
|
| 443 |
+
all_colors.append(rgbs)
|
| 444 |
+
|
| 445 |
+
# Merge all views
|
| 446 |
+
all_points = torch.cat(all_points, dim=0).cpu().numpy()
|
| 447 |
+
all_colors = torch.cat(all_colors, dim=0).cpu().numpy()
|
| 448 |
+
|
| 449 |
+
# Normalize color
|
| 450 |
+
if all_colors.max() <= 1.0:
|
| 451 |
+
all_colors = (all_colors * 255).astype(np.uint8)
|
| 452 |
+
else:
|
| 453 |
+
all_colors = all_colors.astype(np.uint8)
|
| 454 |
+
|
| 455 |
+
# Build structured array
|
| 456 |
+
vertex_data = np.empty(
|
| 457 |
+
len(all_points),
|
| 458 |
+
dtype=[
|
| 459 |
+
("x", "f4"),
|
| 460 |
+
("y", "f4"),
|
| 461 |
+
("z", "f4"),
|
| 462 |
+
("red", "u1"),
|
| 463 |
+
("green", "u1"),
|
| 464 |
+
("blue", "u1"),
|
| 465 |
+
],
|
| 466 |
+
)
|
| 467 |
+
vertex_data["x"] = all_points[:, 0]
|
| 468 |
+
vertex_data["y"] = all_points[:, 1]
|
| 469 |
+
vertex_data["z"] = all_points[:, 2]
|
| 470 |
+
vertex_data["red"] = all_colors[:, 0]
|
| 471 |
+
vertex_data["green"] = all_colors[:, 1]
|
| 472 |
+
vertex_data["blue"] = all_colors[:, 2]
|
| 473 |
+
|
| 474 |
+
# Save with plyfile
|
| 475 |
+
el = PlyElement.describe(vertex_data, "vertex")
|
| 476 |
+
PlyData([el], text=False).write(filename)
|
| 477 |
+
|
| 478 |
+
print(f"[PLY] Saved {len(all_points)} points to {filename}")
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def save_pointcloud_with_plyfile_each_frame(
|
| 482 |
+
result, filename="output.ply", downsample_ratio=10
|
| 483 |
+
):
|
| 484 |
+
for frame_number, view in enumerate(result):
|
| 485 |
+
all_points = []
|
| 486 |
+
all_colors = []
|
| 487 |
+
pts = view["point_map_by_unprojection"] # (1, H, W, 3)
|
| 488 |
+
rgbs = view["rgbs"] # (1, 3, H, W)
|
| 489 |
+
dpt_cnf = view["dpt_cnf"]
|
| 490 |
+
|
| 491 |
+
# Remove batch dimension
|
| 492 |
+
pts = pts.squeeze(0) # (H, W, 3)
|
| 493 |
+
rgbs = rgbs.squeeze(0).permute(1, 2, 0) # (3, H, W) -> (H, W, 3)
|
| 494 |
+
|
| 495 |
+
# Flatten
|
| 496 |
+
pts = pts.reshape(-1, 3) # (N, 3)
|
| 497 |
+
rgbs = rgbs.reshape(-1, 3) # (N, 3)
|
| 498 |
+
|
| 499 |
+
# Remove invalid points
|
| 500 |
+
valid = torch.isfinite(pts).all(dim=1) & (pts.norm(dim=1) > 0)
|
| 501 |
+
valid = valid & (dpt_cnf > torch.quantile(view["dpt_cnf"], 0.5)).flatten()
|
| 502 |
+
pts = pts[valid]
|
| 503 |
+
rgbs = rgbs[valid]
|
| 504 |
+
|
| 505 |
+
# Downsample this view
|
| 506 |
+
N = pts.shape[0]
|
| 507 |
+
if downsample_ratio > 1 and N >= downsample_ratio:
|
| 508 |
+
idx = torch.randperm(N)[: N // downsample_ratio]
|
| 509 |
+
pts = pts[idx]
|
| 510 |
+
rgbs = rgbs[idx]
|
| 511 |
+
|
| 512 |
+
all_points.append(pts)
|
| 513 |
+
all_colors.append(rgbs)
|
| 514 |
+
|
| 515 |
+
# Merge all views
|
| 516 |
+
all_points = torch.cat(all_points, dim=0).cpu().numpy()
|
| 517 |
+
all_colors = torch.cat(all_colors, dim=0).cpu().numpy()
|
| 518 |
+
|
| 519 |
+
# Normalize color
|
| 520 |
+
if all_colors.max() <= 1.0:
|
| 521 |
+
all_colors = (all_colors * 255).astype(np.uint8)
|
| 522 |
+
else:
|
| 523 |
+
all_colors = all_colors.astype(np.uint8)
|
| 524 |
+
|
| 525 |
+
# Build structured array
|
| 526 |
+
vertex_data = np.empty(
|
| 527 |
+
len(all_points),
|
| 528 |
+
dtype=[
|
| 529 |
+
("x", "f4"),
|
| 530 |
+
("y", "f4"),
|
| 531 |
+
("z", "f4"),
|
| 532 |
+
("red", "u1"),
|
| 533 |
+
("green", "u1"),
|
| 534 |
+
("blue", "u1"),
|
| 535 |
+
],
|
| 536 |
+
)
|
| 537 |
+
vertex_data["x"] = all_points[:, 0]
|
| 538 |
+
vertex_data["y"] = all_points[:, 1]
|
| 539 |
+
vertex_data["z"] = all_points[:, 2]
|
| 540 |
+
vertex_data["red"] = all_colors[:, 0]
|
| 541 |
+
vertex_data["green"] = all_colors[:, 1]
|
| 542 |
+
vertex_data["blue"] = all_colors[:, 2]
|
| 543 |
+
|
| 544 |
+
# Save with plyfile
|
| 545 |
+
el = PlyElement.describe(vertex_data, "vertex")
|
| 546 |
+
PlyData([el], text=False).write(
|
| 547 |
+
filename.split(".ply")[0] + f"_{frame_number:05d}.ply"
|
| 548 |
+
)
|
| 549 |
+
|
| 550 |
+
print(
|
| 551 |
+
f"[PLY] Saved {len(all_points)} points to {filename} idx {frame_number:05d}"
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
def save_concatenated_images(samples, save_path):
|
| 556 |
+
imgs = []
|
| 557 |
+
|
| 558 |
+
for sample in samples:
|
| 559 |
+
img = sample["img"] # (1, C, H, W)
|
| 560 |
+
img = F.interpolate(
|
| 561 |
+
img, scale_factor=0.25, mode="bilinear", align_corners=False
|
| 562 |
+
)
|
| 563 |
+
imgs.append(img)
|
| 564 |
+
|
| 565 |
+
imgs = torch.cat(imgs, dim=0).cpu() # (N, C, H, W)
|
| 566 |
+
|
| 567 |
+
save_path = Path(save_path)
|
| 568 |
+
save_path.parent.mkdir(parents=True, exist_ok=True)
|
| 569 |
+
|
| 570 |
+
vutils.save_image(imgs, save_path, normalize=True)
|
| 571 |
+
|
| 572 |
+
print(f"[Image] Saved concatenated image to {save_path}")
|
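As a quick orientation for the helpers defined in eval/utils/geometry.py above, here is a minimal illustrative sketch (not part of the committed file) that unprojects a synthetic depth map with depthmap_to_absolute_camera_coordinates. The sizes, intrinsics, identity pose, and import path are assumptions for the example only.

# Illustrative sketch; assumes the repo root is on PYTHONPATH.
import numpy as np
from eval.utils.geometry import depthmap_to_absolute_camera_coordinates

# Hypothetical toy inputs: a flat depth map at z = 1, simple pinhole
# intrinsics with no skew, and an identity cam2world pose.
H, W = 4, 6
depth = np.ones((H, W), dtype=np.float32)
K = np.array([[100.0, 0.0, W / 2.0],
              [0.0, 100.0, H / 2.0],
              [0.0, 0.0, 1.0]], dtype=np.float32)
cam_pose = np.eye(4, dtype=np.float32)

pts_world, valid = depthmap_to_absolute_camera_coordinates(depth, K, cam_pose)
# pts_world is (H, W, 3); with an identity pose it equals the camera-frame
# points, and valid is True wherever depth > 0.
assert pts_world.shape == (H, W, 3) and valid.all()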
eval/utils/image.py
ADDED
|
@@ -0,0 +1,232 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 8 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 9 |
+
#
|
| 10 |
+
# --------------------------------------------------------
|
| 11 |
+
# utility functions for images (loading/converting...)
|
| 12 |
+
# --------------------------------------------------------
|
| 13 |
+
import os
|
| 14 |
+
from typing import Dict, Optional
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import PIL.Image
|
| 18 |
+
import torch
|
| 19 |
+
import torchvision.transforms as tvf
|
| 20 |
+
from PIL.ImageOps import exif_transpose
|
| 21 |
+
|
| 22 |
+
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
|
| 23 |
+
import cv2
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
from pillow_heif import register_heif_opener
|
| 27 |
+
|
| 28 |
+
register_heif_opener()
|
| 29 |
+
heif_support_enabled = True
|
| 30 |
+
except ImportError:
|
| 31 |
+
heif_support_enabled = False
|
| 32 |
+
|
| 33 |
+
ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def imread_cv2(path, options=cv2.IMREAD_COLOR):
|
| 37 |
+
"""Open an image or a depthmap with opencv-python."""
|
| 38 |
+
if path.endswith((".exr", "EXR")):
|
| 39 |
+
options = cv2.IMREAD_ANYDEPTH
|
| 40 |
+
img = cv2.imread(path, options)
|
| 41 |
+
if img is None:
|
| 42 |
+
raise IOError(f"Could not load image={path} with {options=}")
|
| 43 |
+
if img.ndim == 3:
|
| 44 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 45 |
+
return img
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def rgb(ftensor, true_shape=None):
|
| 49 |
+
if isinstance(ftensor, list):
|
| 50 |
+
return [rgb(x, true_shape=true_shape) for x in ftensor]
|
| 51 |
+
if isinstance(ftensor, torch.Tensor):
|
| 52 |
+
ftensor = ftensor.detach().cpu().numpy() # H,W,3
|
| 53 |
+
if ftensor.ndim == 3 and ftensor.shape[0] == 3:
|
| 54 |
+
ftensor = ftensor.transpose(1, 2, 0)
|
| 55 |
+
elif ftensor.ndim == 4 and ftensor.shape[1] == 3:
|
| 56 |
+
ftensor = ftensor.transpose(0, 2, 3, 1)
|
| 57 |
+
if true_shape is not None:
|
| 58 |
+
H, W = true_shape
|
| 59 |
+
ftensor = ftensor[:H, :W]
|
| 60 |
+
if ftensor.dtype == np.uint8:
|
| 61 |
+
img = np.float32(ftensor) / 255
|
| 62 |
+
else:
|
| 63 |
+
img = (ftensor * 0.5) + 0.5
|
| 64 |
+
return img.clip(min=0, max=1)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _resize_pil_image(img, long_edge_size):
|
| 68 |
+
S = max(img.size)
|
| 69 |
+
if S > long_edge_size:
|
| 70 |
+
interp = PIL.Image.LANCZOS
|
| 71 |
+
elif S <= long_edge_size:
|
| 72 |
+
interp = PIL.Image.BICUBIC
|
| 73 |
+
new_size = tuple(int(round(x * long_edge_size / S)) for x in img.size)
|
| 74 |
+
return img.resize(new_size, interp)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def load_images(
|
| 78 |
+
folder_or_list,
|
| 79 |
+
size,
|
| 80 |
+
square_ok=False,
|
| 81 |
+
verbose=True,
|
| 82 |
+
rotate_clockwise_90=False,
|
| 83 |
+
crop_to_landscape=False,
|
| 84 |
+
):
|
| 85 |
+
"""open and convert all images in a list or folder to proper input format for DUSt3R"""
|
| 86 |
+
if isinstance(folder_or_list, str):
|
| 87 |
+
if verbose:
|
| 88 |
+
print(f">> Loading images from {folder_or_list}")
|
| 89 |
+
root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))
|
| 90 |
+
|
| 91 |
+
elif isinstance(folder_or_list, list):
|
| 92 |
+
if verbose:
|
| 93 |
+
print(f">> Loading a list of {len(folder_or_list)} images")
|
| 94 |
+
root, folder_content = "", folder_or_list
|
| 95 |
+
|
| 96 |
+
else:
|
| 97 |
+
raise ValueError(f"bad {folder_or_list=} ({type(folder_or_list)})")
|
| 98 |
+
|
| 99 |
+
supported_images_extensions = [".jpg", ".jpeg", ".png"]
|
| 100 |
+
if heif_support_enabled:
|
| 101 |
+
supported_images_extensions += [".heic", ".heif"]
|
| 102 |
+
supported_images_extensions = tuple(supported_images_extensions)
|
| 103 |
+
|
| 104 |
+
imgs = []
|
| 105 |
+
for path in folder_content:
|
| 106 |
+
if not path.lower().endswith(supported_images_extensions):
|
| 107 |
+
continue
|
| 108 |
+
img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert("RGB")
|
| 109 |
+
if rotate_clockwise_90:
|
| 110 |
+
img = img.rotate(-90, expand=True)
|
| 111 |
+
if crop_to_landscape:
|
| 112 |
+
# Crop to a landscape aspect ratio (4:3)
|
| 113 |
+
desired_aspect_ratio = 4 / 3
|
| 114 |
+
width, height = img.size
|
| 115 |
+
current_aspect_ratio = width / height
|
| 116 |
+
|
| 117 |
+
if current_aspect_ratio > desired_aspect_ratio:
|
| 118 |
+
# Wider than landscape: crop width
|
| 119 |
+
new_width = int(height * desired_aspect_ratio)
|
| 120 |
+
left = (width - new_width) // 2
|
| 121 |
+
right = left + new_width
|
| 122 |
+
top = 0
|
| 123 |
+
bottom = height
|
| 124 |
+
else:
|
| 125 |
+
# Taller than landscape: crop height
|
| 126 |
+
new_height = int(width / desired_aspect_ratio)
|
| 127 |
+
top = (height - new_height) // 2
|
| 128 |
+
bottom = top + new_height
|
| 129 |
+
left = 0
|
| 130 |
+
right = width
|
| 131 |
+
|
| 132 |
+
img = img.crop((left, top, right, bottom))
|
| 133 |
+
|
| 134 |
+
W1, H1 = img.size
|
| 135 |
+
if size == 224:
|
| 136 |
+
# resize short side to 224 (then crop)
|
| 137 |
+
img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1)))
|
| 138 |
+
else:
|
| 139 |
+
# resize long side to 512
|
| 140 |
+
img = _resize_pil_image(img, size)
|
| 141 |
+
W, H = img.size
|
| 142 |
+
cx, cy = W // 2, H // 2
|
| 143 |
+
if size == 224:
|
| 144 |
+
half = min(cx, cy)
|
| 145 |
+
img = img.crop((cx - half, cy - half, cx + half, cy + half))
|
| 146 |
+
else:
|
| 147 |
+
halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8
|
| 148 |
+
if not (square_ok) and W == H:
|
| 149 |
+
halfh = 3 * halfw / 4
|
| 150 |
+
img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh))
|
| 151 |
+
|
| 152 |
+
W2, H2 = img.size
|
| 153 |
+
if verbose:
|
| 154 |
+
print(f" - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}")
|
| 155 |
+
imgs.append(
|
| 156 |
+
dict(
|
| 157 |
+
img=ImgNorm(img)[None],
|
| 158 |
+
true_shape=np.int32([img.size[::-1]]),
|
| 159 |
+
idx=len(imgs),
|
| 160 |
+
instance=str(len(imgs)),
|
| 161 |
+
)
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
assert imgs, "no images found at " + root
|
| 165 |
+
if verbose:
|
| 166 |
+
print(f" (Found {len(imgs)} images)")
|
| 167 |
+
return imgs
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def get_image_vggt_augmentation(
|
| 171 |
+
color_jitter: Optional[Dict[str, float]] = None,
|
| 172 |
+
gray_scale: bool = True,
|
| 173 |
+
gau_blur: bool = False,
|
| 174 |
+
) -> Optional[tvf.Compose]:
|
| 175 |
+
"""Create a composition of image augmentations.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
color_jitter: Dictionary containing color jitter parameters:
|
| 179 |
+
- brightness: float (default: 0.5)
|
| 180 |
+
- contrast: float (default: 0.5)
|
| 181 |
+
- saturation: float (default: 0.5)
|
| 182 |
+
- hue: float (default: 0.1)
|
| 183 |
+
- p: probability of applying (default: 0.9)
|
| 184 |
+
If None, uses default values
|
| 185 |
+
gray_scale: Whether to apply random grayscale (default: True)
|
| 186 |
+
gau_blur: Whether to apply gaussian blur (default: False)
|
| 187 |
+
|
| 188 |
+
Returns:
|
| 189 |
+
A Compose object of transforms or None if no transforms are added
|
| 190 |
+
"""
|
| 191 |
+
transform_list = []
|
| 192 |
+
default_jitter = {
|
| 193 |
+
"brightness": 0.5,
|
| 194 |
+
"contrast": 0.5,
|
| 195 |
+
"saturation": 0.5,
|
| 196 |
+
"hue": 0.1,
|
| 197 |
+
"p": 0.9,
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
# Handle color jitter
|
| 201 |
+
if color_jitter is not None:
|
| 202 |
+
if not isinstance(color_jitter, dict):
|
| 203 |
+
raise ValueError("color_jitter must be a dictionary or None")
|
| 204 |
+
# Merge with defaults for missing keys
|
| 205 |
+
effective_jitter = {**default_jitter, **color_jitter}
|
| 206 |
+
else:
|
| 207 |
+
effective_jitter = default_jitter
|
| 208 |
+
|
| 209 |
+
transform_list.append(
|
| 210 |
+
tvf.RandomApply(
|
| 211 |
+
[
|
| 212 |
+
tvf.ColorJitter(
|
| 213 |
+
brightness=effective_jitter["brightness"],
|
| 214 |
+
contrast=effective_jitter["contrast"],
|
| 215 |
+
saturation=effective_jitter["saturation"],
|
| 216 |
+
hue=effective_jitter["hue"],
|
| 217 |
+
)
|
| 218 |
+
],
|
| 219 |
+
p=effective_jitter["p"],
|
| 220 |
+
)
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
if gray_scale:
|
| 224 |
+
transform_list.append(tvf.RandomGrayscale(p=0.05))
|
| 225 |
+
|
| 226 |
+
if gau_blur:
|
| 227 |
+
transform_list.append(
|
| 228 |
+
tvf.RandomApply([tvf.GaussianBlur(5, sigma=(0.1, 1.0))], p=0.05)
|
| 229 |
+
)
|
| 230 |
+
# transform_list.append(tvf.ToTensor())
|
| 231 |
+
|
| 232 |
+
return tvf.Compose(transform_list) if transform_list else None
|
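A minimal usage sketch for load_images from eval/utils/image.py above; the folder path and import path are placeholders, and the snippet is illustrative only, not part of the committed file.

# Illustrative sketch; assumes the repo root is on PYTHONPATH and that
# "path/to/images" is replaced by a real folder of .jpg/.png files.
from eval.utils.image import load_images

views = load_images("path/to/images", size=512, verbose=False)
img0 = views[0]["img"]          # (1, 3, H, W) tensor, normalized to [-1, 1] by ImgNorm
print(img0.shape, views[0]["true_shape"])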
eval/utils/load_fn.py
ADDED
|
@@ -0,0 +1,155 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from torchvision import transforms as TF
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def load_and_preprocess_images(image_path_list, mode="crop"):
|
| 13 |
+
"""
|
| 14 |
+
A quick start function to load and preprocess images for model input.
|
| 15 |
+
This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
image_path_list (list): List of paths to image files
|
| 19 |
+
mode (str, optional): Preprocessing mode, either "crop" or "pad".
|
| 20 |
+
- "crop" (default): Sets width to 518px and center crops height if needed.
|
| 21 |
+
- "pad": Preserves all pixels by making the largest dimension 518px
|
| 22 |
+
and padding the smaller dimension to reach a square shape.
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)
|
| 26 |
+
|
| 27 |
+
Raises:
|
| 28 |
+
ValueError: If the input list is empty or if mode is invalid
|
| 29 |
+
|
| 30 |
+
Notes:
|
| 31 |
+
- Images with different dimensions will be padded with white (value=1.0)
|
| 32 |
+
- A warning is printed when images have different shapes
|
| 33 |
+
- When mode="crop": The function ensures width=518px while maintaining aspect ratio
|
| 34 |
+
and height is center-cropped if larger than 518px
|
| 35 |
+
- When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio
|
| 36 |
+
and the smaller dimension is padded to reach a square shape (518x518)
|
| 37 |
+
- Dimensions are adjusted to be divisible by 14 for compatibility with model requirements
|
| 38 |
+
"""
|
| 39 |
+
# Check for empty list
|
| 40 |
+
if len(image_path_list) == 0:
|
| 41 |
+
raise ValueError("At least 1 image is required")
|
| 42 |
+
|
| 43 |
+
# Validate mode
|
| 44 |
+
if mode not in ["crop", "pad"]:
|
| 45 |
+
raise ValueError("Mode must be either 'crop' or 'pad'")
|
| 46 |
+
|
| 47 |
+
images = []
|
| 48 |
+
shapes = set()
|
| 49 |
+
to_tensor = TF.ToTensor()
|
| 50 |
+
target_size = 518
|
| 51 |
+
|
| 52 |
+
# First process all images and collect their shapes
|
| 53 |
+
for image_path in image_path_list:
|
| 54 |
+
# Open image
|
| 55 |
+
img = Image.open(image_path)
|
| 56 |
+
|
| 57 |
+
# If there's an alpha channel, blend onto white background:
|
| 58 |
+
if img.mode == "RGBA":
|
| 59 |
+
# Create white background
|
| 60 |
+
background = Image.new("RGBA", img.size, (255, 255, 255, 255))
|
| 61 |
+
# Alpha composite onto the white background
|
| 62 |
+
img = Image.alpha_composite(background, img)
|
| 63 |
+
|
| 64 |
+
# Now convert to "RGB" (this step assigns white for transparent areas)
|
| 65 |
+
img = img.convert("RGB")
|
| 66 |
+
|
| 67 |
+
width, height = img.size
|
| 68 |
+
|
| 69 |
+
if mode == "pad":
|
| 70 |
+
# Make the largest dimension 518px while maintaining aspect ratio
|
| 71 |
+
if width >= height:
|
| 72 |
+
new_width = target_size
|
| 73 |
+
new_height = (
|
| 74 |
+
round(height * (new_width / width) / 14) * 14
|
| 75 |
+
) # Make divisible by 14
|
| 76 |
+
else:
|
| 77 |
+
new_height = target_size
|
| 78 |
+
new_width = (
|
| 79 |
+
round(width * (new_height / height) / 14) * 14
|
| 80 |
+
) # Make divisible by 14
|
| 81 |
+
else: # mode == "crop"
|
| 82 |
+
# Original behavior: set width to 518px
|
| 83 |
+
new_width = target_size
|
| 84 |
+
# Calculate height maintaining aspect ratio, divisible by 14
|
| 85 |
+
new_height = round(height * (new_width / width) / 14) * 14
|
| 86 |
+
|
| 87 |
+
# Resize with new dimensions (width, height)
|
| 88 |
+
img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
|
| 89 |
+
img = to_tensor(img) # Convert to tensor (0, 1)
|
| 90 |
+
|
| 91 |
+
# Center crop height if it's larger than 518 (only in crop mode)
|
| 92 |
+
if mode == "crop" and new_height > target_size:
|
| 93 |
+
start_y = (new_height - target_size) // 2
|
| 94 |
+
img = img[:, start_y : start_y + target_size, :]
|
| 95 |
+
|
| 96 |
+
# For pad mode, pad to make a square of target_size x target_size
|
| 97 |
+
if mode == "pad":
|
| 98 |
+
h_padding = target_size - img.shape[1]
|
| 99 |
+
w_padding = target_size - img.shape[2]
|
| 100 |
+
|
| 101 |
+
if h_padding > 0 or w_padding > 0:
|
| 102 |
+
pad_top = h_padding // 2
|
| 103 |
+
pad_bottom = h_padding - pad_top
|
| 104 |
+
pad_left = w_padding // 2
|
| 105 |
+
pad_right = w_padding - pad_left
|
| 106 |
+
|
| 107 |
+
# Pad with white (value=1.0)
|
| 108 |
+
img = torch.nn.functional.pad(
|
| 109 |
+
img,
|
| 110 |
+
(pad_left, pad_right, pad_top, pad_bottom),
|
| 111 |
+
mode="constant",
|
| 112 |
+
value=1.0,
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
shapes.add((img.shape[1], img.shape[2]))
|
| 116 |
+
images.append(img)
|
| 117 |
+
|
| 118 |
+
# Check if we have different shapes
|
| 119 |
+
# In theory our model can also work well with different shapes
|
| 120 |
+
if len(shapes) > 1:
|
| 121 |
+
print(f"Warning: Found images with different shapes: {shapes}")
|
| 122 |
+
# Find maximum dimensions
|
| 123 |
+
max_height = max(shape[0] for shape in shapes)
|
| 124 |
+
max_width = max(shape[1] for shape in shapes)
|
| 125 |
+
|
| 126 |
+
# Pad images if necessary
|
| 127 |
+
padded_images = []
|
| 128 |
+
for img in images:
|
| 129 |
+
h_padding = max_height - img.shape[1]
|
| 130 |
+
w_padding = max_width - img.shape[2]
|
| 131 |
+
|
| 132 |
+
if h_padding > 0 or w_padding > 0:
|
| 133 |
+
pad_top = h_padding // 2
|
| 134 |
+
pad_bottom = h_padding - pad_top
|
| 135 |
+
pad_left = w_padding // 2
|
| 136 |
+
pad_right = w_padding - pad_left
|
| 137 |
+
|
| 138 |
+
img = torch.nn.functional.pad(
|
| 139 |
+
img,
|
| 140 |
+
(pad_left, pad_right, pad_top, pad_bottom),
|
| 141 |
+
mode="constant",
|
| 142 |
+
value=1.0,
|
| 143 |
+
)
|
| 144 |
+
padded_images.append(img)
|
| 145 |
+
images = padded_images
|
| 146 |
+
|
| 147 |
+
images = torch.stack(images) # concatenate images
|
| 148 |
+
|
| 149 |
+
# Ensure correct shape when single image
|
| 150 |
+
if len(image_path_list) == 1:
|
| 151 |
+
# Verify shape is (1, C, H, W)
|
| 152 |
+
if images.dim() == 3:
|
| 153 |
+
images = images.unsqueeze(0)
|
| 154 |
+
|
| 155 |
+
return images
|
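A small sketch of load_and_preprocess_images in "pad" mode, following the docstring above; the file names and import path are assumed placeholders and the snippet is illustrative only, not part of the committed file.

# Illustrative sketch; assumes the repo root is on PYTHONPATH and that
# a.jpg / b.jpg are real images on disk.
from eval.utils.load_fn import load_and_preprocess_images

# In "pad" mode every image is resized so its longest side is 518 px and
# then white-padded to a 518x518 square before batching.
images = load_and_preprocess_images(["a.jpg", "b.jpg"], mode="pad")
print(images.shape)             # expected: torch.Size([2, 3, 518, 518])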
eval/utils/misc.py
ADDED
|
@@ -0,0 +1,131 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
| 8 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
| 9 |
+
#
|
| 10 |
+
# --------------------------------------------------------
|
| 11 |
+
# utility functions for DUSt3R
|
| 12 |
+
# --------------------------------------------------------
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def fill_default_args(kwargs, func):
|
| 17 |
+
import inspect # a bit hacky but it works reliably
|
| 18 |
+
|
| 19 |
+
signature = inspect.signature(func)
|
| 20 |
+
|
| 21 |
+
for k, v in signature.parameters.items():
|
| 22 |
+
if v.default is inspect.Parameter.empty:
|
| 23 |
+
continue
|
| 24 |
+
kwargs.setdefault(k, v.default)
|
| 25 |
+
|
| 26 |
+
return kwargs
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def freeze_all_params(modules):
|
| 30 |
+
for module in modules:
|
| 31 |
+
try:
|
| 32 |
+
for n, param in module.named_parameters():
|
| 33 |
+
param.requires_grad = False
|
| 34 |
+
except AttributeError:
|
| 35 |
+
# module is directly a parameter
|
| 36 |
+
module.requires_grad = False
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def is_symmetrized(gt1, gt2):
|
| 40 |
+
x = gt1["instance"]
|
| 41 |
+
y = gt2["instance"]
|
| 42 |
+
if len(x) == len(y) and len(x) == 1:
|
| 43 |
+
return False # special case of batchsize 1
|
| 44 |
+
ok = True
|
| 45 |
+
for i in range(0, len(x), 2):
|
| 46 |
+
ok = ok and (x[i] == y[i + 1]) and (x[i + 1] == y[i])
|
| 47 |
+
return ok
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def flip(tensor):
|
| 51 |
+
"""flip so that tensor[0::2] <=> tensor[1::2]"""
|
| 52 |
+
return torch.stack((tensor[1::2], tensor[0::2]), dim=1).flatten(0, 1)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def interleave(tensor1, tensor2):
|
| 56 |
+
res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1)
|
| 57 |
+
res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1)
|
| 58 |
+
return res1, res2
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def transpose_to_landscape(head, activate=True):
|
| 62 |
+
"""Predict in the correct aspect-ratio,
|
| 63 |
+
then transpose the result in landscape
|
| 64 |
+
and stack everything back together.
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
def wrapper_no(decout, true_shape):
|
| 68 |
+
B = len(true_shape)
|
| 69 |
+
assert true_shape[0:1].allclose(true_shape), "true_shape must be all identical"
|
| 70 |
+
H, W = true_shape[0].cpu().tolist()
|
| 71 |
+
res = head(decout, (H, W))
|
| 72 |
+
return res
|
| 73 |
+
|
| 74 |
+
def wrapper_yes(decout, true_shape):
|
| 75 |
+
B = len(true_shape)
|
| 76 |
+
# by definition, the batch is in landscape mode so W >= H
|
| 77 |
+
H, W = int(true_shape.min()), int(true_shape.max())
|
| 78 |
+
|
| 79 |
+
height, width = true_shape.T
|
| 80 |
+
is_landscape = width >= height
|
| 81 |
+
is_portrait = ~is_landscape
|
| 82 |
+
|
| 83 |
+
# true_shape = true_shape.cpu()
|
| 84 |
+
if is_landscape.all():
|
| 85 |
+
return head(decout, (H, W))
|
| 86 |
+
if is_portrait.all():
|
| 87 |
+
return transposed(head(decout, (W, H)))
|
| 88 |
+
|
| 89 |
+
# batch is a mix of both portrait & landscape
|
| 90 |
+
def selout(ar):
|
| 91 |
+
return [d[ar] for d in decout]
|
| 92 |
+
|
| 93 |
+
l_result = head(selout(is_landscape), (H, W))
|
| 94 |
+
p_result = transposed(head(selout(is_portrait), (W, H)))
|
| 95 |
+
|
| 96 |
+
# allocate full result
|
| 97 |
+
result = {}
|
| 98 |
+
for k in l_result | p_result:
|
| 99 |
+
x = l_result[k].new(B, *l_result[k].shape[1:])
|
| 100 |
+
x[is_landscape] = l_result[k]
|
| 101 |
+
x[is_portrait] = p_result[k]
|
| 102 |
+
result[k] = x
|
| 103 |
+
|
| 104 |
+
return result
|
| 105 |
+
|
| 106 |
+
return wrapper_yes if activate else wrapper_no
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def transposed(dic):
|
| 110 |
+
return {k: v.swapaxes(1, 2) for k, v in dic.items()}
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def invalid_to_nans(arr, valid_mask, ndim=999):
|
| 114 |
+
if valid_mask is not None:
|
| 115 |
+
arr = arr.clone()
|
| 116 |
+
arr[~valid_mask] = float("nan")
|
| 117 |
+
if arr.ndim > ndim:
|
| 118 |
+
arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
|
| 119 |
+
return arr
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def invalid_to_zeros(arr, valid_mask, ndim=999):
|
| 123 |
+
if valid_mask is not None:
|
| 124 |
+
arr = arr.clone()
|
| 125 |
+
arr[~valid_mask] = 0
|
| 126 |
+
nnz = valid_mask.view(len(valid_mask), -1).sum(1)
|
| 127 |
+
else:
|
| 128 |
+
nnz = arr.numel() // len(arr) if len(arr) else 0  # number of points per image
|
| 129 |
+
if arr.ndim > ndim:
|
| 130 |
+
arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
|
| 131 |
+
return arr, nnz
|
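To illustrate the masking helpers at the end of eval/utils/misc.py above (used by normalize_pointcloud in eval/utils/geometry.py), here is a toy sketch, not part of the committed file; the shapes, random mask, and import path are assumptions.

# Illustrative sketch; assumes the repo root is on PYTHONPATH.
import torch
from eval.utils.misc import invalid_to_zeros

pts = torch.randn(2, 8, 8, 3)            # (B, H, W, 3) point maps
valid = torch.rand(2, 8, 8) > 0.3        # boolean validity mask
flat, nnz = invalid_to_zeros(pts, valid, ndim=3)
# flat is (B, H*W, 3) with invalid points zeroed; nnz holds the per-image
# count of valid points, as used to average distances in normalize_pointcloud.
print(flat.shape, nnz)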
eval/utils/pose_enc.py
ADDED
|
@@ -0,0 +1,135 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from .rotation import mat_to_quat, quat_to_mat
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def extri_intri_to_pose_encoding(
|
| 13 |
+
extrinsics,
|
| 14 |
+
intrinsics,
|
| 15 |
+
image_size_hw=None, # e.g., (256, 512)
|
| 16 |
+
pose_encoding_type="absT_quaR_FoV",
|
| 17 |
+
):
|
| 18 |
+
"""Convert camera extrinsics and intrinsics to a compact pose encoding.
|
| 19 |
+
|
| 20 |
+
This function transforms camera parameters into a unified pose encoding format,
|
| 21 |
+
which can be used for various downstream tasks like pose prediction or representation.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4,
|
| 25 |
+
where B is batch size and S is sequence length.
|
| 26 |
+
In the OpenCV coordinate system (x-right, y-down, z-forward), representing the world-to-camera transformation.
|
| 27 |
+
The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector.
|
| 28 |
+
intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3.
|
| 29 |
+
Defined in pixels, with format:
|
| 30 |
+
[[fx, 0, cx],
|
| 31 |
+
[0, fy, cy],
|
| 32 |
+
[0, 0, 1]]
|
| 33 |
+
where fx, fy are focal lengths and (cx, cy) is the principal point
|
| 34 |
+
image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
|
| 35 |
+
Required for computing field of view values. For example: (256, 512).
|
| 36 |
+
pose_encoding_type (str): Type of pose encoding to use. Currently only
|
| 37 |
+
supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
torch.Tensor: Encoded camera pose parameters with shape BxSx9.
|
| 41 |
+
For "absT_quaR_FoV" type, the 9 dimensions are:
|
| 42 |
+
- [:3] = absolute translation vector T (3D)
|
| 43 |
+
- [3:7] = rotation as quaternion quat (4D)
|
| 44 |
+
- [7:] = field of view (2D)
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
# extrinsics: BxSx3x4
|
| 48 |
+
# intrinsics: BxSx3x3
|
| 49 |
+
|
| 50 |
+
if pose_encoding_type == "absT_quaR_FoV":
|
| 51 |
+
R = extrinsics[:, :, :3, :3] # BxSx3x3
|
| 52 |
+
T = extrinsics[:, :, :3, 3] # BxSx3
|
| 53 |
+
|
| 54 |
+
quat = mat_to_quat(R)
|
| 55 |
+
# Note the order of h and w here
|
| 56 |
+
H, W = image_size_hw
|
| 57 |
+
fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
|
| 58 |
+
fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
|
| 59 |
+
pose_encoding = torch.cat(
|
| 60 |
+
[T, quat, fov_h[..., None], fov_w[..., None]], dim=-1
|
| 61 |
+
).float()
|
| 62 |
+
else:
|
| 63 |
+
raise NotImplementedError
|
| 64 |
+
|
| 65 |
+
return pose_encoding
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def pose_encoding_to_extri_intri(
|
| 69 |
+
pose_encoding,
|
| 70 |
+
image_size_hw=None, # e.g., (256, 512)
|
| 71 |
+
pose_encoding_type="absT_quaR_FoV",
|
| 72 |
+
build_intrinsics=True,
|
| 73 |
+
):
|
| 74 |
+
"""Convert a pose encoding back to camera extrinsics and intrinsics.
|
| 75 |
+
|
| 76 |
+
This function performs the inverse operation of extri_intri_to_pose_encoding,
|
| 77 |
+
reconstructing the full camera parameters from the compact encoding.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9,
|
| 81 |
+
where B is batch size and S is sequence length.
|
| 82 |
+
For "absT_quaR_FoV" type, the 9 dimensions are:
|
| 83 |
+
- [:3] = absolute translation vector T (3D)
|
| 84 |
+
- [3:7] = rotation as quaternion quat (4D)
|
| 85 |
+
- [7:] = field of view (2D)
|
| 86 |
+
image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
|
| 87 |
+
Required for reconstructing intrinsics from field of view values.
|
| 88 |
+
For example: (256, 512).
|
| 89 |
+
pose_encoding_type (str): Type of pose encoding used. Currently only
|
| 90 |
+
supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
|
| 91 |
+
build_intrinsics (bool): Whether to reconstruct the intrinsics matrix.
|
| 92 |
+
If False, only extrinsics are returned and intrinsics will be None.
|
| 93 |
+
|
| 94 |
+
Returns:
|
| 95 |
+
tuple: (extrinsics, intrinsics)
|
| 96 |
+
- extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4.
|
| 97 |
+
In the OpenCV coordinate system (x-right, y-down, z-forward), representing the world-to-camera
|
| 98 |
+
transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is
|
| 99 |
+
a 3x1 translation vector.
|
| 100 |
+
- intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3,
|
| 101 |
+
or None if build_intrinsics is False. Defined in pixels, with format:
|
| 102 |
+
[[fx, 0, cx],
|
| 103 |
+
[0, fy, cy],
|
| 104 |
+
[0, 0, 1]]
|
| 105 |
+
where fx, fy are focal lengths and (cx, cy) is the principal point,
|
| 106 |
+
assumed to be at the center of the image (W/2, H/2).
|
| 107 |
+
"""
|
| 108 |
+
|
| 109 |
+
intrinsics = None
|
| 110 |
+
|
| 111 |
+
if pose_encoding_type == "absT_quaR_FoV":
|
| 112 |
+
T = pose_encoding[..., :3]
|
| 113 |
+
quat = pose_encoding[..., 3:7]
|
| 114 |
+
fov_h = pose_encoding[..., 7]
|
| 115 |
+
fov_w = pose_encoding[..., 8]
|
| 116 |
+
|
| 117 |
+
R = quat_to_mat(quat)
|
| 118 |
+
extrinsics = torch.cat([R, T[..., None]], dim=-1)
|
| 119 |
+
|
| 120 |
+
if build_intrinsics:
|
| 121 |
+
H, W = image_size_hw
|
| 122 |
+
fy = (H / 2.0) / torch.tan(fov_h / 2.0)
|
| 123 |
+
fx = (W / 2.0) / torch.tan(fov_w / 2.0)
|
| 124 |
+
intrinsics = torch.zeros(
|
| 125 |
+
pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device
|
| 126 |
+
)
|
| 127 |
+
intrinsics[..., 0, 0] = fx
|
| 128 |
+
intrinsics[..., 1, 1] = fy
|
| 129 |
+
intrinsics[..., 0, 2] = W / 2
|
| 130 |
+
intrinsics[..., 1, 2] = H / 2
|
| 131 |
+
intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1
|
| 132 |
+
else:
|
| 133 |
+
raise NotImplementedError
|
| 134 |
+
|
| 135 |
+
return extrinsics, intrinsics
|
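A round-trip sketch for the two conversions in eval/utils/pose_enc.py above; the batch size, sequence length, image size, intrinsics, and import path are arbitrary toy assumptions, and the snippet is not part of the committed file.

# Illustrative sketch; assumes the repo root is on PYTHONPATH.
import torch
from eval.utils.pose_enc import (
    extri_intri_to_pose_encoding,
    pose_encoding_to_extri_intri,
)

B, S, hw = 1, 2, (256, 512)
extrinsics = torch.eye(3, 4).expand(B, S, 3, 4).clone()      # identity world-to-camera poses
intrinsics = torch.tensor([[256.0, 0.0, 256.0],
                           [0.0, 256.0, 128.0],
                           [0.0, 0.0, 1.0]]).expand(B, S, 3, 3).clone()
enc = extri_intri_to_pose_encoding(extrinsics, intrinsics, image_size_hw=hw)   # (B, S, 9)
extrinsics2, intrinsics2 = pose_encoding_to_extri_intri(enc, image_size_hw=hw)
# The decoded parameters should match the inputs up to numerical precision.
print(enc.shape, torch.allclose(extrinsics, extrinsics2, atol=1e-5))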
eval/utils/rotation.py
ADDED
|
@@ -0,0 +1,142 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
|
| 15 |
+
"""
|
| 16 |
+
Quaternion order: XYZW (i.e., ijkr, scalar-last)
|
| 17 |
+
|
| 18 |
+
Convert rotations given as quaternions to rotation matrices.
|
| 19 |
+
Args:
|
| 20 |
+
quaternions: quaternions with real part last,
|
| 21 |
+
as tensor of shape (..., 4).
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
| 25 |
+
"""
|
| 26 |
+
i, j, k, r = torch.unbind(quaternions, -1)
|
| 27 |
+
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
|
| 28 |
+
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
| 29 |
+
|
| 30 |
+
o = torch.stack(
|
| 31 |
+
(
|
| 32 |
+
1 - two_s * (j * j + k * k),
|
| 33 |
+
two_s * (i * j - k * r),
|
| 34 |
+
two_s * (i * k + j * r),
|
| 35 |
+
two_s * (i * j + k * r),
|
| 36 |
+
1 - two_s * (i * i + k * k),
|
| 37 |
+
two_s * (j * k - i * r),
|
| 38 |
+
two_s * (i * k - j * r),
|
| 39 |
+
two_s * (j * k + i * r),
|
| 40 |
+
1 - two_s * (i * i + j * j),
|
| 41 |
+
),
|
| 42 |
+
-1,
|
| 43 |
+
)
|
| 44 |
+
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
|
| 48 |
+
"""
|
| 49 |
+
Convert rotations given as rotation matrices to quaternions.
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
quaternions with real part last, as tensor of shape (..., 4).
|
| 56 |
+
Quaternion order: XYZW (i.e., ijkr, scalar-last)
|
| 57 |
+
"""
|
| 58 |
+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
| 59 |
+
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
|
| 60 |
+
|
| 61 |
+
batch_dim = matrix.shape[:-2]
|
| 62 |
+
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
|
| 63 |
+
matrix.reshape(batch_dim + (9,)), dim=-1
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
q_abs = _sqrt_positive_part(
|
| 67 |
+
torch.stack(
|
| 68 |
+
[
|
| 69 |
+
1.0 + m00 + m11 + m22,
|
| 70 |
+
1.0 + m00 - m11 - m22,
|
| 71 |
+
1.0 - m00 + m11 - m22,
|
| 72 |
+
1.0 - m00 - m11 + m22,
|
| 73 |
+
],
|
| 74 |
+
dim=-1,
|
| 75 |
+
)
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# we produce the desired quaternion multiplied by each of r, i, j, k
|
| 79 |
+
quat_by_rijk = torch.stack(
|
| 80 |
+
[
|
| 81 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
| 82 |
+
# `int`.
|
| 83 |
+
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
|
| 84 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
| 85 |
+
# `int`.
|
| 86 |
+
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
|
| 87 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
| 88 |
+
# `int`.
|
| 89 |
+
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
|
| 90 |
+
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
| 91 |
+
# `int`.
|
| 92 |
+
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
|
| 93 |
+
],
|
| 94 |
+
dim=-2,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
# We floor here at 0.1 but the exact level is not important; if q_abs is small,
|
| 98 |
+
# the candidate won't be picked.
|
| 99 |
+
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
|
| 100 |
+
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
|
| 101 |
+
|
| 102 |
+
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
|
| 103 |
+
# forall i; we pick the best-conditioned one (with the largest denominator)
|
| 104 |
+
out = quat_candidates[
|
| 105 |
+
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
|
| 106 |
+
].reshape(batch_dim + (4,))
|
| 107 |
+
|
| 108 |
+
# Convert from rijk to ijkr
|
| 109 |
+
out = out[..., [1, 2, 3, 0]]
|
| 110 |
+
|
| 111 |
+
out = standardize_quaternion(out)
|
| 112 |
+
|
| 113 |
+
return out
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
| 117 |
+
"""
|
| 118 |
+
Returns torch.sqrt(torch.max(0, x))
|
| 119 |
+
but with a zero subgradient where x is 0.
|
| 120 |
+
"""
|
| 121 |
+
ret = torch.zeros_like(x)
|
| 122 |
+
positive_mask = x > 0
|
| 123 |
+
if torch.is_grad_enabled():
|
| 124 |
+
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
| 125 |
+
else:
|
| 126 |
+
ret = torch.where(positive_mask, torch.sqrt(x), ret)
|
| 127 |
+
return ret
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
|
| 131 |
+
"""
|
| 132 |
+
Convert a unit quaternion to a standard form: one in which the real
|
| 133 |
+
part is non negative.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
quaternions: Quaternions with real part last,
|
| 137 |
+
as tensor of shape (..., 4).
|
| 138 |
+
|
| 139 |
+
Returns:
|
| 140 |
+
Standardized quaternions as tensor of shape (..., 4).
|
| 141 |
+
"""
|
| 142 |
+
return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
|
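A short sanity sketch of the scalar-last quaternion convention used in eval/utils/rotation.py above; the 90-degree rotation about z and the import path are illustrative assumptions, and the snippet is not part of the committed file.

# Illustrative sketch; assumes the repo root is on PYTHONPATH.
import math
import torch
from eval.utils.rotation import mat_to_quat, quat_to_mat

q = torch.tensor([0.0, 0.0, math.sin(math.pi / 4), math.cos(math.pi / 4)])  # xyzw, 90 deg about z
R = quat_to_mat(q)                  # (3, 3) rotation matrix
q_back = mat_to_quat(R)             # back to a standardized xyzw quaternion
print(torch.allclose(q, q_back, atol=1e-6))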
eval/utils/visual_track.py
ADDED
|
@@ -0,0 +1,244 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
import cv2
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def color_from_xy(x, y, W, H, cmap_name="hsv"):
|
| 15 |
+
"""
|
| 16 |
+
Map (x, y) -> color in (R, G, B).
|
| 17 |
+
1) Normalize x,y to [0,1].
|
| 18 |
+
2) Combine them into a single scalar c in [0,1].
|
| 19 |
+
3) Use matplotlib's colormap to convert c -> (R,G,B).
|
| 20 |
+
|
| 21 |
+
You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y).
|
| 22 |
+
"""
|
| 23 |
+
import matplotlib.cm
|
| 24 |
+
import matplotlib.colors
|
| 25 |
+
|
| 26 |
+
x_norm = x / max(W - 1, 1)
|
| 27 |
+
y_norm = y / max(H - 1, 1)
|
| 28 |
+
# Simple combination:
|
| 29 |
+
c = (x_norm + y_norm) / 2.0
|
| 30 |
+
|
| 31 |
+
cmap = matplotlib.cm.get_cmap(cmap_name)
|
| 32 |
+
# cmap(c) -> (r,g,b,a) in [0,1]
|
| 33 |
+
rgba = cmap(c)
|
| 34 |
+
r, g, b = rgba[0], rgba[1], rgba[2]
|
| 35 |
+
return (r, g, b) # in [0,1], RGB order
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_track_colors_by_position(
|
| 39 |
+
tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"
|
| 40 |
+
):
|
| 41 |
+
"""
|
| 42 |
+
Given all tracks in one sample (b), compute a (N,3) array of RGB color values
|
| 43 |
+
in [0,255]. The color is determined by the (x,y) position in the first
|
| 44 |
+
visible frame for each track.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame.
|
| 48 |
+
vis_mask_b: (S, N) boolean mask; if None, assume all are visible.
|
| 49 |
+
image_width, image_height: used for normalizing (x, y).
|
| 50 |
+
cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet').
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255].
|
| 54 |
+
"""
|
| 55 |
+
S, N, _ = tracks_b.shape
|
| 56 |
+
track_colors = np.zeros((N, 3), dtype=np.uint8)
|
| 57 |
+
|
| 58 |
+
if vis_mask_b is None:
|
| 59 |
+
# treat all as visible
|
| 60 |
+
vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device)
|
| 61 |
+
|
| 62 |
+
for i in range(N):
|
| 63 |
+
# Find first visible frame for track i
|
| 64 |
+
visible_frames = torch.where(vis_mask_b[:, i])[0]
|
| 65 |
+
if len(visible_frames) == 0:
|
| 66 |
+
# track is never visible; just assign black or something
|
| 67 |
+
track_colors[i] = (0, 0, 0)
|
| 68 |
+
continue
|
| 69 |
+
|
| 70 |
+
first_s = int(visible_frames[0].item())
|
| 71 |
+
# use that frame's (x,y)
|
| 72 |
+
x, y = tracks_b[first_s, i].tolist()
|
| 73 |
+
|
| 74 |
+
# map (x,y) -> (R,G,B) in [0,1]
|
| 75 |
+
r, g, b = color_from_xy(
|
| 76 |
+
x, y, W=image_width, H=image_height, cmap_name=cmap_name
|
| 77 |
+
)
|
| 78 |
+
# scale to [0,255]
|
| 79 |
+
r, g, b = int(r * 255), int(g * 255), int(b * 255)
|
| 80 |
+
track_colors[i] = (r, g, b)
|
| 81 |
+
|
| 82 |
+
return track_colors
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def visualize_tracks_on_images(
|
| 86 |
+
images,
|
| 87 |
+
tracks,
|
| 88 |
+
track_vis_mask=None,
|
| 89 |
+
out_dir="track_visuals_concat_by_xy",
|
| 90 |
+
image_format="CHW", # "CHW" or "HWC"
|
| 91 |
+
normalize_mode="[0,1]",
|
| 92 |
+
cmap_name="hsv", # e.g. "hsv", "rainbow", "jet"
|
| 93 |
+
frames_per_row=4, # New parameter for grid layout
|
| 94 |
+
save_grid=True, # Flag to control whether to save the grid image
|
| 95 |
+
):
|
| 96 |
+
"""
|
| 97 |
+
Visualizes frames in a grid layout with specified frames per row.
|
| 98 |
+
Each track's color is determined by its (x,y) position
|
| 99 |
+
in the first visible frame (or frame 0 if always visible).
|
| 100 |
+
Finally convert the BGR result to RGB before saving.
|
| 101 |
+
Also saves each individual frame as a separate PNG file.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC.
|
| 105 |
+
tracks: torch.Tensor (S, N, 2), last dim = (x, y).
|
| 106 |
+
track_vis_mask: torch.Tensor (S, N) or None.
|
| 107 |
+
out_dir: folder to save visualizations.
|
| 108 |
+
image_format: "CHW" or "HWC".
|
| 109 |
+
normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255
|
| 110 |
+
cmap_name: a matplotlib colormap name for color_from_xy.
|
| 111 |
+
frames_per_row: number of frames to display in each row of the grid.
|
| 112 |
+
save_grid: whether to save all frames in one grid image.
|
| 113 |
+
|
| 114 |
+
Returns:
|
| 115 |
+
None (saves images in out_dir).
|
| 116 |
+
"""
|
| 117 |
+
|
| 118 |
+
if len(tracks.shape) == 4:
|
| 119 |
+
tracks = tracks.squeeze(0)
|
| 120 |
+
images = images.squeeze(0)
|
| 121 |
+
if track_vis_mask is not None:
|
| 122 |
+
track_vis_mask = track_vis_mask.squeeze(0)
|
| 123 |
+
|
| 124 |
+
import matplotlib
|
| 125 |
+
|
| 126 |
+
matplotlib.use("Agg") # for non-interactive (optional)
|
| 127 |
+
|
| 128 |
+
os.makedirs(out_dir, exist_ok=True)
|
| 129 |
+
|
| 130 |
+
S = images.shape[0]
|
| 131 |
+
_, N, _ = tracks.shape # (S, N, 2)
|
| 132 |
+
|
| 133 |
+
# Move to CPU
|
| 134 |
+
images = images.cpu().clone()
|
| 135 |
+
tracks = tracks.cpu().clone()
|
| 136 |
+
if track_vis_mask is not None:
|
| 137 |
+
track_vis_mask = track_vis_mask.cpu().clone()
|
| 138 |
+
|
| 139 |
+
# Infer H, W from images shape
|
| 140 |
+
if image_format == "CHW":
|
| 141 |
+
# e.g. images[s].shape = (3, H, W)
|
| 142 |
+
H, W = images.shape[2], images.shape[3]
|
| 143 |
+
else:
|
| 144 |
+
# e.g. images[s].shape = (H, W, 3)
|
| 145 |
+
H, W = images.shape[1], images.shape[2]
|
| 146 |
+
|
| 147 |
+
# Pre-compute the color for each track i based on first visible position
|
| 148 |
+
track_colors_rgb = get_track_colors_by_position(
|
| 149 |
+
tracks, # shape (S, N, 2)
|
| 150 |
+
vis_mask_b=track_vis_mask if track_vis_mask is not None else None,
|
| 151 |
+
image_width=W,
|
| 152 |
+
image_height=H,
|
| 153 |
+
cmap_name=cmap_name,
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
# We'll accumulate each frame's drawn image in a list
|
| 157 |
+
frame_images = []
|
| 158 |
+
|
| 159 |
+
for s in range(S):
|
| 160 |
+
# shape => either (3, H, W) or (H, W, 3)
|
| 161 |
+
img = images[s]
|
| 162 |
+
|
| 163 |
+
# Convert to (H, W, 3)
|
| 164 |
+
if image_format == "CHW":
|
| 165 |
+
img = img.permute(1, 2, 0) # (H, W, 3)
|
| 166 |
+
# else "HWC", do nothing
|
| 167 |
+
|
| 168 |
+
img = img.numpy().astype(np.float32)
|
| 169 |
+
|
| 170 |
+
# Scale to [0,255] if needed
|
| 171 |
+
if normalize_mode == "[0,1]":
|
| 172 |
+
img = np.clip(img, 0, 1) * 255.0
|
| 173 |
+
elif normalize_mode == "[-1,1]":
|
| 174 |
+
img = (img + 1.0) * 0.5 * 255.0
|
| 175 |
+
img = np.clip(img, 0, 255.0)
|
| 176 |
+
# else no normalization
|
| 177 |
+
|
| 178 |
+
# Convert to uint8
|
| 179 |
+
img = img.astype(np.uint8)
|
| 180 |
+
|
| 181 |
+
# For drawing in OpenCV, convert to BGR
|
| 182 |
+
img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
| 183 |
+
|
| 184 |
+
# Draw each visible track
|
| 185 |
+
cur_tracks = tracks[s] # shape (N, 2)
|
| 186 |
+
if track_vis_mask is not None:
|
| 187 |
+
valid_indices = torch.where(track_vis_mask[s])[0]
|
| 188 |
+
else:
|
| 189 |
+
valid_indices = range(N)
|
| 190 |
+
|
| 191 |
+
cur_tracks_np = cur_tracks.numpy()
|
| 192 |
+
for i in valid_indices:
|
| 193 |
+
x, y = cur_tracks_np[i]
|
| 194 |
+
pt = (int(round(x)), int(round(y)))
|
| 195 |
+
|
| 196 |
+
# track_colors_rgb[i] is (R,G,B). For OpenCV circle, we need BGR
|
| 197 |
+
R, G, B = track_colors_rgb[i]
|
| 198 |
+
color_bgr = (int(B), int(G), int(R))
|
| 199 |
+
cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1)
|
| 200 |
+
|
| 201 |
+
# Convert back to RGB for consistent final saving:
|
| 202 |
+
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
|
| 203 |
+
|
| 204 |
+
# Save individual frame
|
| 205 |
+
frame_path = os.path.join(out_dir, f"frame_{s:04d}.png")
|
| 206 |
+
# Convert to BGR for OpenCV imwrite
|
| 207 |
+
frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
|
| 208 |
+
cv2.imwrite(frame_path, frame_bgr)
|
| 209 |
+
|
| 210 |
+
frame_images.append(img_rgb)
|
| 211 |
+
|
| 212 |
+
# Only create and save the grid image if save_grid is True
|
| 213 |
+
if save_grid:
|
| 214 |
+
# Calculate grid dimensions
|
| 215 |
+
num_rows = (S + frames_per_row - 1) // frames_per_row # Ceiling division
|
| 216 |
+
|
| 217 |
+
# Create a grid of images
|
| 218 |
+
grid_img = None
|
| 219 |
+
for row in range(num_rows):
|
| 220 |
+
start_idx = row * frames_per_row
|
| 221 |
+
end_idx = min(start_idx + frames_per_row, S)
|
| 222 |
+
|
| 223 |
+
# Concatenate this row horizontally
|
| 224 |
+
row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1)
|
| 225 |
+
|
| 226 |
+
# If this row has fewer than frames_per_row images, pad with black
|
| 227 |
+
if end_idx - start_idx < frames_per_row:
|
| 228 |
+
padding_width = (frames_per_row - (end_idx - start_idx)) * W
|
| 229 |
+
padding = np.zeros((H, padding_width, 3), dtype=np.uint8)
|
| 230 |
+
row_img = np.concatenate([row_img, padding], axis=1)
|
| 231 |
+
|
| 232 |
+
# Add this row to the grid
|
| 233 |
+
if grid_img is None:
|
| 234 |
+
grid_img = row_img
|
| 235 |
+
else:
|
| 236 |
+
grid_img = np.concatenate([grid_img, row_img], axis=0)
|
| 237 |
+
|
| 238 |
+
out_path = os.path.join(out_dir, "tracks_grid.png")
|
| 239 |
+
# Convert back to BGR for OpenCV imwrite
|
| 240 |
+
grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR)
|
| 241 |
+
cv2.imwrite(out_path, grid_img_bgr)
|
| 242 |
+
print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}")
|
| 243 |
+
|
| 244 |
+
print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png")
|
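A minimal, hedged usage sketch for visualize_tracks_on_images above. The import path eval.utils.visual_track is assumed from the file name, and the tensors are random stand-ins rather than real tracker output.

import torch
from eval.utils.visual_track import visualize_tracks_on_images  # assumed path

S, N, H, W = 8, 64, 240, 320
images = torch.rand(S, 3, H, W)                                  # CHW frames in [0, 1]
tracks = torch.rand(S, N, 2) * torch.tensor([W - 1.0, H - 1.0])  # (x, y) per frame
vis_mask = torch.ones(S, N, dtype=torch.bool)                    # pretend all tracks are visible

visualize_tracks_on_images(
    images,
    tracks,
    track_vis_mask=vis_mask,
    out_dir="track_visuals_concat_by_xy",
    image_format="CHW",
    normalize_mode="[0,1]",
    frames_per_row=4,
    save_grid=True,
)
# Expected output: frame_0000.png ... frame_0007.png plus tracks_grid.png in out_dir.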
pyproject.toml
ADDED
|
@@ -0,0 +1,58 @@
|
| 1 |
+
[project]
|
| 2 |
+
authors = [{name = "Junyuan DENG"},{name="Heng LI"}]
|
| 3 |
+
dependencies = [
|
| 4 |
+
"numpy<2",
|
| 5 |
+
"Pillow",
|
| 6 |
+
"huggingface_hub",
|
| 7 |
+
"einops",
|
| 8 |
+
"safetensors",
|
| 9 |
+
"opencv-python",
|
| 10 |
+
"torch>=2.3.1",
|
| 11 |
+
"torchvision>=0.18.1",
|
| 12 |
+
"numpy==1.26.1",
|
| 13 |
+
"evo",
|
| 14 |
+
"plyfile",
|
| 15 |
+
#"python-opencv",
|
| 16 |
+
]
|
| 17 |
+
name = "sailrecon"
|
| 18 |
+
requires-python = ">= 3.10"
|
| 19 |
+
version = "0.0.1"
|
| 20 |
+
|
| 21 |
+
[project.optional-dependencies]
|
| 22 |
+
demo = [
|
| 23 |
+
"gradio>=5.17.1",
|
| 24 |
+
"viser>=0.2.23",
|
| 25 |
+
"tqdm",
|
| 26 |
+
"hydra-core",
|
| 27 |
+
"omegaconf",
|
| 28 |
+
"opencv-python",
|
| 29 |
+
"scipy",
|
| 30 |
+
"onnxruntime",
|
| 31 |
+
"requests",
|
| 32 |
+
"trimesh",
|
| 33 |
+
"matplotlib",
|
| 34 |
+
]
|
| 35 |
+
|
| 36 |
+
# Using setuptools as the build backend
|
| 37 |
+
[build-system]
|
| 38 |
+
requires = ["setuptools>=61.0", "wheel"]
|
| 39 |
+
build-backend = "setuptools.build_meta"
|
| 40 |
+
|
| 41 |
+
# setuptools configuration
|
| 42 |
+
[tool.setuptools.packages.find]
|
| 43 |
+
where = ["."]
|
| 44 |
+
include = ["sailrecon*"]
|
| 45 |
+
|
| 46 |
+
# Pixi configuration
|
| 47 |
+
[tool.pixi.workspace]
|
| 48 |
+
channels = ["conda-forge"]
|
| 49 |
+
platforms = ["linux-64"]
|
| 50 |
+
|
| 51 |
+
[tool.pixi.pypi-dependencies]
|
| 52 |
+
sailrecon = { path = ".", editable = true }
|
| 53 |
+
|
| 54 |
+
[tool.pixi.environments]
|
| 55 |
+
default = { solve-group = "default" }
|
| 56 |
+
demo = { features = ["demo"], solve-group = "default" }
|
| 57 |
+
|
| 58 |
+
[tool.pixi.tasks]
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
| 1 |
+
torch==2.3.1
|
| 2 |
+
torchvision==0.18.1
|
| 3 |
+
numpy==1.26.1
|
| 4 |
+
Pillow
|
| 5 |
+
huggingface_hub
|
| 6 |
+
einops
|
| 7 |
+
safetensors
|
| 8 |
+
evo
|
| 9 |
+
plyfile
|
| 10 |
+
opencv-python
|
requirements_demo.txt
ADDED
|
@@ -0,0 +1,16 @@
|
| 1 |
+
gradio==5.17.1
|
| 2 |
+
viser==0.2.23
|
| 3 |
+
tqdm
|
| 4 |
+
hydra-core
|
| 5 |
+
omegaconf
|
| 6 |
+
opencv-python
|
| 7 |
+
scipy
|
| 8 |
+
onnxruntime
|
| 9 |
+
requests
|
| 10 |
+
trimesh
|
| 11 |
+
matplotlib
|
| 12 |
+
pydantic==2.10.6
|
| 13 |
+
# feel free to skip the dependencies below if you do not need demo_colmap.py
|
| 14 |
+
# pycolmap==3.10.0
|
| 15 |
+
# pyceres==2.3
|
| 16 |
+
# git+https://github.com/jytime/LightGlue.git#egg=lightglue
|
sailrecon/dependency/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
from .track_modules.base_track_predictor import BaseTrackerPredictor
|
| 2 |
+
from .track_modules.blocks import BasicEncoder, ShallowEncoder
|
| 3 |
+
from .track_modules.track_refine import refine_track
|
sailrecon/dependency/distortion.py
ADDED
|
@@ -0,0 +1,223 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import Union
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
ArrayLike = Union[np.ndarray, torch.Tensor]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _is_numpy(x: ArrayLike) -> bool:
|
| 16 |
+
return isinstance(x, np.ndarray)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _is_torch(x: ArrayLike) -> bool:
|
| 20 |
+
return isinstance(x, torch.Tensor)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _ensure_torch(x: ArrayLike) -> torch.Tensor:
|
| 24 |
+
"""Convert input to torch tensor if it's not already one."""
|
| 25 |
+
if _is_numpy(x):
|
| 26 |
+
return torch.from_numpy(x)
|
| 27 |
+
elif _is_torch(x):
|
| 28 |
+
return x
|
| 29 |
+
else:
|
| 30 |
+
return torch.tensor(x)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def single_undistortion(params, tracks_normalized):
|
| 34 |
+
"""
|
| 35 |
+
Apply undistortion to the normalized tracks using the given distortion parameters once.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN.
|
| 39 |
+
tracks_normalized (torch.Tensor or numpy.ndarray): Normalized tracks tensor of shape [batch_size, num_tracks, 2].
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
torch.Tensor: Undistorted normalized tracks tensor.
|
| 43 |
+
"""
|
| 44 |
+
params = _ensure_torch(params)
|
| 45 |
+
tracks_normalized = _ensure_torch(tracks_normalized)
|
| 46 |
+
|
| 47 |
+
u, v = tracks_normalized[..., 0].clone(), tracks_normalized[..., 1].clone()
|
| 48 |
+
u_undist, v_undist = apply_distortion(params, u, v)
|
| 49 |
+
return torch.stack([u_undist, v_undist], dim=-1)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def iterative_undistortion(
|
| 53 |
+
params,
|
| 54 |
+
tracks_normalized,
|
| 55 |
+
max_iterations=100,
|
| 56 |
+
max_step_norm=1e-10,
|
| 57 |
+
rel_step_size=1e-6,
|
| 58 |
+
):
|
| 59 |
+
"""
|
| 60 |
+
Iteratively undistort the normalized tracks using the given distortion parameters.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN.
|
| 64 |
+
tracks_normalized (torch.Tensor or numpy.ndarray): Normalized tracks tensor of shape [batch_size, num_tracks, 2].
|
| 65 |
+
max_iterations (int): Maximum number of iterations for the undistortion process.
|
| 66 |
+
max_step_norm (float): Maximum step norm for convergence.
|
| 67 |
+
rel_step_size (float): Relative step size for numerical differentiation.
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
torch.Tensor: Undistorted normalized tracks tensor.
|
| 71 |
+
"""
|
| 72 |
+
params = _ensure_torch(params)
|
| 73 |
+
tracks_normalized = _ensure_torch(tracks_normalized)
|
| 74 |
+
|
| 75 |
+
B, N, _ = tracks_normalized.shape
|
| 76 |
+
u, v = tracks_normalized[..., 0].clone(), tracks_normalized[..., 1].clone()
|
| 77 |
+
original_u, original_v = u.clone(), v.clone()
|
| 78 |
+
|
| 79 |
+
eps = torch.finfo(u.dtype).eps
|
| 80 |
+
for idx in range(max_iterations):
|
| 81 |
+
u_undist, v_undist = apply_distortion(params, u, v)
|
| 82 |
+
dx = original_u - u_undist
|
| 83 |
+
dy = original_v - v_undist
|
| 84 |
+
|
| 85 |
+
step_u = torch.clamp(torch.abs(u) * rel_step_size, min=eps)
|
| 86 |
+
step_v = torch.clamp(torch.abs(v) * rel_step_size, min=eps)
|
| 87 |
+
|
| 88 |
+
J_00 = (
|
| 89 |
+
apply_distortion(params, u + step_u, v)[0]
|
| 90 |
+
- apply_distortion(params, u - step_u, v)[0]
|
| 91 |
+
) / (2 * step_u)
|
| 92 |
+
J_01 = (
|
| 93 |
+
apply_distortion(params, u, v + step_v)[0]
|
| 94 |
+
- apply_distortion(params, u, v - step_v)[0]
|
| 95 |
+
) / (2 * step_v)
|
| 96 |
+
J_10 = (
|
| 97 |
+
apply_distortion(params, u + step_u, v)[1]
|
| 98 |
+
- apply_distortion(params, u - step_u, v)[1]
|
| 99 |
+
) / (2 * step_u)
|
| 100 |
+
J_11 = (
|
| 101 |
+
apply_distortion(params, u, v + step_v)[1]
|
| 102 |
+
- apply_distortion(params, u, v - step_v)[1]
|
| 103 |
+
) / (2 * step_v)
|
| 104 |
+
|
| 105 |
+
J = torch.stack(
|
| 106 |
+
[
|
| 107 |
+
torch.stack([J_00 + 1, J_01], dim=-1),
|
| 108 |
+
torch.stack([J_10, J_11 + 1], dim=-1),
|
| 109 |
+
],
|
| 110 |
+
dim=-2,
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
delta = torch.linalg.solve(J, torch.stack([dx, dy], dim=-1))
|
| 114 |
+
|
| 115 |
+
u += delta[..., 0]
|
| 116 |
+
v += delta[..., 1]
|
| 117 |
+
|
| 118 |
+
if torch.max((delta**2).sum(dim=-1)) < max_step_norm:
|
| 119 |
+
break
|
| 120 |
+
|
| 121 |
+
return torch.stack([u, v], dim=-1)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def apply_distortion(extra_params, u, v):
|
| 125 |
+
"""
|
| 126 |
+
Applies radial or OpenCV distortion to the given 2D points.
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
extra_params (torch.Tensor or numpy.ndarray): Distortion parameters of shape BxN, where N can be 1, 2, or 4.
|
| 130 |
+
u (torch.Tensor or numpy.ndarray): Normalized x coordinates of shape Bxnum_tracks.
|
| 131 |
+
v (torch.Tensor or numpy.ndarray): Normalized y coordinates of shape Bxnum_tracks.
|
| 132 |
+
|
| 133 |
+
Returns:
|
| 134 |
+
u, v (torch.Tensor): Distorted coordinates, each of shape Bxnum_tracks.
|
| 135 |
+
"""
|
| 136 |
+
extra_params = _ensure_torch(extra_params)
|
| 137 |
+
u = _ensure_torch(u)
|
| 138 |
+
v = _ensure_torch(v)
|
| 139 |
+
|
| 140 |
+
num_params = extra_params.shape[1]
|
| 141 |
+
|
| 142 |
+
if num_params == 1:
|
| 143 |
+
# Simple radial distortion
|
| 144 |
+
k = extra_params[:, 0]
|
| 145 |
+
u2 = u * u
|
| 146 |
+
v2 = v * v
|
| 147 |
+
r2 = u2 + v2
|
| 148 |
+
radial = k[:, None] * r2
|
| 149 |
+
du = u * radial
|
| 150 |
+
dv = v * radial
|
| 151 |
+
|
| 152 |
+
elif num_params == 2:
|
| 153 |
+
# RadialCameraModel distortion
|
| 154 |
+
k1, k2 = extra_params[:, 0], extra_params[:, 1]
|
| 155 |
+
u2 = u * u
|
| 156 |
+
v2 = v * v
|
| 157 |
+
r2 = u2 + v2
|
| 158 |
+
radial = k1[:, None] * r2 + k2[:, None] * r2 * r2
|
| 159 |
+
du = u * radial
|
| 160 |
+
dv = v * radial
|
| 161 |
+
|
| 162 |
+
elif num_params == 4:
|
| 163 |
+
# OpenCVCameraModel distortion
|
| 164 |
+
k1, k2, p1, p2 = (
|
| 165 |
+
extra_params[:, 0],
|
| 166 |
+
extra_params[:, 1],
|
| 167 |
+
extra_params[:, 2],
|
| 168 |
+
extra_params[:, 3],
|
| 169 |
+
)
|
| 170 |
+
u2 = u * u
|
| 171 |
+
v2 = v * v
|
| 172 |
+
uv = u * v
|
| 173 |
+
r2 = u2 + v2
|
| 174 |
+
radial = k1[:, None] * r2 + k2[:, None] * r2 * r2
|
| 175 |
+
du = u * radial + 2 * p1[:, None] * uv + p2[:, None] * (r2 + 2 * u2)
|
| 176 |
+
dv = v * radial + 2 * p2[:, None] * uv + p1[:, None] * (r2 + 2 * v2)
|
| 177 |
+
else:
|
| 178 |
+
raise ValueError("Unsupported number of distortion parameters")
|
| 179 |
+
|
| 180 |
+
u = u.clone() + du
|
| 181 |
+
v = v.clone() + dv
|
| 182 |
+
|
| 183 |
+
return u, v
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
if __name__ == "__main__":
|
| 187 |
+
import random
|
| 188 |
+
|
| 189 |
+
import pycolmap
|
| 190 |
+
|
| 191 |
+
max_diff = 0
|
| 192 |
+
for i in range(1000):
|
| 193 |
+
# Define distortion parameters (assuming 1 parameter for simplicity)
|
| 194 |
+
B = random.randint(1, 500)
|
| 195 |
+
track_num = random.randint(100, 1000)
|
| 196 |
+
params = torch.rand((B, 1), dtype=torch.float32)  # B cameras, one radial parameter each
|
| 197 |
+
tracks_normalized = torch.rand(
|
| 198 |
+
(B, track_num, 2), dtype=torch.float32
|
| 199 |
+
)  # B cameras, track_num normalized points each
|
| 200 |
+
|
| 201 |
+
# Undistort the tracks
|
| 202 |
+
undistorted_tracks = iterative_undistortion(params, tracks_normalized)
|
| 203 |
+
|
| 204 |
+
for b in range(B):
|
| 205 |
+
pycolmap_intri = np.array([1, 0, 0, params[b].item()])
|
| 206 |
+
pycam = pycolmap.Camera(
|
| 207 |
+
model="SIMPLE_RADIAL",
|
| 208 |
+
width=1,
|
| 209 |
+
height=1,
|
| 210 |
+
params=pycolmap_intri,
|
| 211 |
+
camera_id=0,
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
undistorted_tracks_pycolmap = pycam.cam_from_img(
|
| 215 |
+
tracks_normalized[b].numpy()
|
| 216 |
+
)
|
| 217 |
+
diff = (undistorted_tracks[b] - undistorted_tracks_pycolmap).abs().median()
|
| 218 |
+
max_diff = max(max_diff, diff)
|
| 219 |
+
print(f"diff: {diff}, max_diff: {max_diff}")
|
| 220 |
+
|
| 221 |
+
import pdb
|
| 222 |
+
|
| 223 |
+
pdb.set_trace()
|
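A small round-trip sketch for the helpers above, assuming the module imports as sailrecon.dependency.distortion; the coefficients and points are random placeholders.

import torch
from sailrecon.dependency.distortion import apply_distortion, iterative_undistortion

B, N = 4, 100
params = 0.05 * torch.rand(B, 1)     # one radial coefficient per camera
pts = 0.5 * torch.rand(B, N, 2)      # normalized (u, v) coordinates

# Distort the points, then invert the distortion iteratively.
u_d, v_d = apply_distortion(params, pts[..., 0], pts[..., 1])
distorted = torch.stack([u_d, v_d], dim=-1)
recovered = iterative_undistortion(params, distorted)

# The recovered points should match the originals up to numerical noise.
print((recovered - pts).abs().max())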
sailrecon/dependency/np_to_pycolmap.py
ADDED
|
@@ -0,0 +1,355 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pycolmap
|
| 9 |
+
|
| 10 |
+
from .projection import project_3D_points_np
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def batch_np_matrix_to_pycolmap(
|
| 14 |
+
points3d,
|
| 15 |
+
extrinsics,
|
| 16 |
+
intrinsics,
|
| 17 |
+
tracks,
|
| 18 |
+
image_size,
|
| 19 |
+
masks=None,
|
| 20 |
+
max_reproj_error=None,
|
| 21 |
+
max_points3D_val=3000,
|
| 22 |
+
shared_camera=False,
|
| 23 |
+
camera_type="SIMPLE_PINHOLE",
|
| 24 |
+
extra_params=None,
|
| 25 |
+
min_inlier_per_frame=64,
|
| 26 |
+
points_rgb=None,
|
| 27 |
+
):
|
| 28 |
+
"""
|
| 29 |
+
Convert Batched NumPy Arrays to PyCOLMAP
|
| 30 |
+
|
| 31 |
+
Check https://github.com/colmap/pycolmap for more details about its format
|
| 32 |
+
|
| 33 |
+
NOTE that colmap expects images/cameras/points3D to be 1-indexed
|
| 34 |
+
so there is a +1 offset between colmap index and batch index
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
NOTE: different from VGGSfM, this function:
|
| 38 |
+
1. Use np instead of torch
|
| 39 |
+
2. Frame index and camera id starts from 1 rather than 0 (to fit the format of PyCOLMAP)
|
| 40 |
+
"""
|
| 41 |
+
# points3d: Px3
|
| 42 |
+
# extrinsics: Nx3x4
|
| 43 |
+
# intrinsics: Nx3x3
|
| 44 |
+
# tracks: NxPx2
|
| 45 |
+
# masks: NxP
|
| 46 |
+
# image_size: 2, assume all the frames have been padded to the same size
|
| 47 |
+
# where N is the number of frames and P is the number of tracks
|
| 48 |
+
|
| 49 |
+
N, P, _ = tracks.shape
|
| 50 |
+
assert len(extrinsics) == N
|
| 51 |
+
assert len(intrinsics) == N
|
| 52 |
+
assert len(points3d) == P
|
| 53 |
+
assert image_size.shape[0] == 2
|
| 54 |
+
|
| 55 |
+
reproj_mask = None
|
| 56 |
+
|
| 57 |
+
if max_reproj_error is not None:
|
| 58 |
+
projected_points_2d, projected_points_cam = project_3D_points_np(
|
| 59 |
+
points3d, extrinsics, intrinsics
|
| 60 |
+
)
|
| 61 |
+
projected_diff = np.linalg.norm(projected_points_2d - tracks, axis=-1)
|
| 62 |
+
projected_points_2d[projected_points_cam[:, -1] <= 0] = 1e6
|
| 63 |
+
reproj_mask = projected_diff < max_reproj_error
|
| 64 |
+
|
| 65 |
+
if masks is not None and reproj_mask is not None:
|
| 66 |
+
masks = np.logical_and(masks, reproj_mask)
|
| 67 |
+
elif masks is not None:
|
| 68 |
+
masks = masks
|
| 69 |
+
else:
|
| 70 |
+
masks = reproj_mask
|
| 71 |
+
|
| 72 |
+
assert masks is not None
|
| 73 |
+
|
| 74 |
+
if masks.sum(1).min() < min_inlier_per_frame:
|
| 75 |
+
print(f"Not enough inliers per frame, skip BA.")
|
| 76 |
+
return None, None
|
| 77 |
+
|
| 78 |
+
# Reconstruction object, following the format of PyCOLMAP/COLMAP
|
| 79 |
+
reconstruction = pycolmap.Reconstruction()
|
| 80 |
+
|
| 81 |
+
inlier_num = masks.sum(0)
|
| 82 |
+
valid_mask = inlier_num >= 2 # a track is invalid if without two inliers
|
| 83 |
+
valid_idx = np.nonzero(valid_mask)[0]
|
| 84 |
+
|
| 85 |
+
# Only add 3D points that have sufficient 2D points
|
| 86 |
+
for vidx in valid_idx:
|
| 87 |
+
# Use RGB colors if provided, otherwise use zeros
|
| 88 |
+
rgb = points_rgb[vidx] if points_rgb is not None else np.zeros(3)
|
| 89 |
+
reconstruction.add_point3D(points3d[vidx], pycolmap.Track(), rgb)
|
| 90 |
+
|
| 91 |
+
num_points3D = len(valid_idx)
|
| 92 |
+
camera = None
|
| 93 |
+
# frame idx
|
| 94 |
+
for fidx in range(N):
|
| 95 |
+
# set camera
|
| 96 |
+
if camera is None or (not shared_camera):
|
| 97 |
+
pycolmap_intri = _build_pycolmap_intri(
|
| 98 |
+
fidx, intrinsics, camera_type, extra_params
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
camera = pycolmap.Camera(
|
| 102 |
+
model=camera_type,
|
| 103 |
+
width=image_size[0],
|
| 104 |
+
height=image_size[1],
|
| 105 |
+
params=pycolmap_intri,
|
| 106 |
+
camera_id=fidx + 1,
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
# add camera
|
| 110 |
+
reconstruction.add_camera(camera)
|
| 111 |
+
|
| 112 |
+
# set image
|
| 113 |
+
cam_from_world = pycolmap.Rigid3d(
|
| 114 |
+
pycolmap.Rotation3d(extrinsics[fidx][:3, :3]), extrinsics[fidx][:3, 3]
|
| 115 |
+
) # Rot and Trans
|
| 116 |
+
|
| 117 |
+
image = pycolmap.Image(
|
| 118 |
+
id=fidx + 1,
|
| 119 |
+
name=f"image_{fidx + 1}",
|
| 120 |
+
camera_id=camera.camera_id,
|
| 121 |
+
cam_from_world=cam_from_world,
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
points2D_list = []
|
| 125 |
+
|
| 126 |
+
point2D_idx = 0
|
| 127 |
+
|
| 128 |
+
# NOTE point3D_id start by 1
|
| 129 |
+
for point3D_id in range(1, num_points3D + 1):
|
| 130 |
+
original_track_idx = valid_idx[point3D_id - 1]
|
| 131 |
+
|
| 132 |
+
if (reconstruction.points3D[point3D_id].xyz < max_points3D_val).all():
|
| 133 |
+
if masks[fidx][original_track_idx]:
|
| 134 |
+
# It seems we don't need +0.5 for BA
|
| 135 |
+
point2D_xy = tracks[fidx][original_track_idx]
|
| 136 |
+
# Please note when adding the Point2D object
|
| 137 |
+
# It not only requires the 2D xy location, but also the id to 3D point
|
| 138 |
+
points2D_list.append(pycolmap.Point2D(point2D_xy, point3D_id))
|
| 139 |
+
|
| 140 |
+
# add element
|
| 141 |
+
track = reconstruction.points3D[point3D_id].track
|
| 142 |
+
track.add_element(fidx + 1, point2D_idx)
|
| 143 |
+
point2D_idx += 1
|
| 144 |
+
|
| 145 |
+
assert point2D_idx == len(points2D_list)
|
| 146 |
+
|
| 147 |
+
try:
|
| 148 |
+
image.points2D = pycolmap.ListPoint2D(points2D_list)
|
| 149 |
+
image.registered = True
|
| 150 |
+
except:
|
| 151 |
+
print(f"frame {fidx + 1} is out of BA")
|
| 152 |
+
image.registered = False
|
| 153 |
+
|
| 154 |
+
# add image
|
| 155 |
+
reconstruction.add_image(image)
|
| 156 |
+
|
| 157 |
+
return reconstruction, valid_mask
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def pycolmap_to_batch_np_matrix(
|
| 161 |
+
reconstruction, device="cpu", camera_type="SIMPLE_PINHOLE"
|
| 162 |
+
):
|
| 163 |
+
"""
|
| 164 |
+
Convert a PyCOLMAP Reconstruction Object to batched NumPy arrays.
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
reconstruction (pycolmap.Reconstruction): The reconstruction object from PyCOLMAP.
|
| 168 |
+
device (str): Ignored in NumPy version (kept for API compatibility).
|
| 169 |
+
camera_type (str): The type of camera model used (default: "SIMPLE_PINHOLE").
|
| 170 |
+
|
| 171 |
+
Returns:
|
| 172 |
+
tuple: A tuple containing points3D, extrinsics, intrinsics, and optionally extra_params.
|
| 173 |
+
"""
|
| 174 |
+
|
| 175 |
+
num_images = len(reconstruction.images)
|
| 176 |
+
max_points3D_id = max(reconstruction.point3D_ids())
|
| 177 |
+
points3D = np.zeros((max_points3D_id, 3))
|
| 178 |
+
|
| 179 |
+
for point3D_id in reconstruction.points3D:
|
| 180 |
+
points3D[point3D_id - 1] = reconstruction.points3D[point3D_id].xyz
|
| 181 |
+
|
| 182 |
+
extrinsics = []
|
| 183 |
+
intrinsics = []
|
| 184 |
+
|
| 185 |
+
extra_params = [] if camera_type == "SIMPLE_RADIAL" else None
|
| 186 |
+
|
| 187 |
+
for i in range(num_images):
|
| 188 |
+
# Extract and append extrinsics
|
| 189 |
+
pyimg = reconstruction.images[i + 1]
|
| 190 |
+
pycam = reconstruction.cameras[pyimg.camera_id]
|
| 191 |
+
matrix = pyimg.cam_from_world.matrix()
|
| 192 |
+
extrinsics.append(matrix)
|
| 193 |
+
|
| 194 |
+
# Extract and append intrinsics
|
| 195 |
+
calibration_matrix = pycam.calibration_matrix()
|
| 196 |
+
intrinsics.append(calibration_matrix)
|
| 197 |
+
|
| 198 |
+
if camera_type == "SIMPLE_RADIAL":
|
| 199 |
+
extra_params.append(pycam.params[-1])
|
| 200 |
+
|
| 201 |
+
# Convert lists to NumPy arrays instead of torch tensors
|
| 202 |
+
extrinsics = np.stack(extrinsics)
|
| 203 |
+
intrinsics = np.stack(intrinsics)
|
| 204 |
+
|
| 205 |
+
if camera_type == "SIMPLE_RADIAL":
|
| 206 |
+
extra_params = np.stack(extra_params)
|
| 207 |
+
extra_params = extra_params[:, None]
|
| 208 |
+
|
| 209 |
+
return points3D, extrinsics, intrinsics, extra_params
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
########################################################
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def batch_np_matrix_to_pycolmap_wo_track(
|
| 216 |
+
points3d,
|
| 217 |
+
points_xyf,
|
| 218 |
+
points_rgb,
|
| 219 |
+
extrinsics,
|
| 220 |
+
intrinsics,
|
| 221 |
+
image_size,
|
| 222 |
+
shared_camera=False,
|
| 223 |
+
camera_type="SIMPLE_PINHOLE",
|
| 224 |
+
):
|
| 225 |
+
"""
|
| 226 |
+
Convert Batched NumPy Arrays to PyCOLMAP
|
| 227 |
+
|
| 228 |
+
Different from batch_np_matrix_to_pycolmap, this function does not use tracks.
|
| 229 |
+
|
| 230 |
+
It saves points3d to colmap reconstruction format only to serve as init for Gaussians or other nvs methods.
|
| 231 |
+
|
| 232 |
+
Do NOT use this for BA.
|
| 233 |
+
"""
|
| 234 |
+
# points3d: Px3
|
| 235 |
+
# points_xyf: Px3, with x, y coordinates and frame indices
|
| 236 |
+
# points_rgb: Px3, rgb colors
|
| 237 |
+
# extrinsics: Nx3x4
|
| 238 |
+
# intrinsics: Nx3x3
|
| 239 |
+
# image_size: 2, assume all the frames have been padded to the same size
|
| 240 |
+
# where N is the number of frames and P is the number of tracks
|
| 241 |
+
|
| 242 |
+
N = len(extrinsics)
|
| 243 |
+
P = len(points3d)
|
| 244 |
+
|
| 245 |
+
# Reconstruction object, following the format of PyCOLMAP/COLMAP
|
| 246 |
+
reconstruction = pycolmap.Reconstruction()
|
| 247 |
+
|
| 248 |
+
for vidx in range(P):
|
| 249 |
+
reconstruction.add_point3D(points3d[vidx], pycolmap.Track(), points_rgb[vidx])
|
| 250 |
+
|
| 251 |
+
camera = None
|
| 252 |
+
# frame idx
|
| 253 |
+
for fidx in range(N):
|
| 254 |
+
# set camera
|
| 255 |
+
if camera is None or (not shared_camera):
|
| 256 |
+
pycolmap_intri = _build_pycolmap_intri(fidx, intrinsics, camera_type)
|
| 257 |
+
|
| 258 |
+
camera = pycolmap.Camera(
|
| 259 |
+
model=camera_type,
|
| 260 |
+
width=image_size[0],
|
| 261 |
+
height=image_size[1],
|
| 262 |
+
params=pycolmap_intri,
|
| 263 |
+
camera_id=fidx + 1,
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
# add camera
|
| 267 |
+
reconstruction.add_camera(camera)
|
| 268 |
+
|
| 269 |
+
# set image
|
| 270 |
+
cam_from_world = pycolmap.Rigid3d(
|
| 271 |
+
pycolmap.Rotation3d(extrinsics[fidx][:3, :3]), extrinsics[fidx][:3, 3]
|
| 272 |
+
) # Rot and Trans
|
| 273 |
+
|
| 274 |
+
image = pycolmap.Image(
|
| 275 |
+
id=fidx + 1,
|
| 276 |
+
name=f"image_{fidx + 1}",
|
| 277 |
+
camera_id=camera.camera_id,
|
| 278 |
+
cam_from_world=cam_from_world,
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
points2D_list = []
|
| 282 |
+
|
| 283 |
+
point2D_idx = 0
|
| 284 |
+
|
| 285 |
+
points_belong_to_fidx = points_xyf[:, 2].astype(np.int32) == fidx
|
| 286 |
+
points_belong_to_fidx = np.nonzero(points_belong_to_fidx)[0]
|
| 287 |
+
|
| 288 |
+
for point3D_batch_idx in points_belong_to_fidx:
|
| 289 |
+
point3D_id = point3D_batch_idx + 1
|
| 290 |
+
point2D_xyf = points_xyf[point3D_batch_idx]
|
| 291 |
+
point2D_xy = point2D_xyf[:2]
|
| 292 |
+
points2D_list.append(pycolmap.Point2D(point2D_xy, point3D_id))
|
| 293 |
+
|
| 294 |
+
# add element
|
| 295 |
+
track = reconstruction.points3D[point3D_id].track
|
| 296 |
+
track.add_element(fidx + 1, point2D_idx)
|
| 297 |
+
point2D_idx += 1
|
| 298 |
+
|
| 299 |
+
assert point2D_idx == len(points2D_list)
|
| 300 |
+
|
| 301 |
+
try:
|
| 302 |
+
image.points2D = pycolmap.ListPoint2D(points2D_list)
|
| 303 |
+
image.registered = True
|
| 304 |
+
except:
|
| 305 |
+
print(f"frame {fidx + 1} does not have any points")
|
| 306 |
+
image.registered = False
|
| 307 |
+
|
| 308 |
+
# add image
|
| 309 |
+
reconstruction.add_image(image)
|
| 310 |
+
|
| 311 |
+
return reconstruction
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def _build_pycolmap_intri(fidx, intrinsics, camera_type, extra_params=None):
|
| 315 |
+
"""
|
| 316 |
+
Helper function to get camera parameters based on camera type.
|
| 317 |
+
|
| 318 |
+
Args:
|
| 319 |
+
fidx: Frame index
|
| 320 |
+
intrinsics: Camera intrinsic parameters
|
| 321 |
+
camera_type: Type of camera model
|
| 322 |
+
extra_params: Additional parameters for certain camera types
|
| 323 |
+
|
| 324 |
+
Returns:
|
| 325 |
+
pycolmap_intri: NumPy array of camera parameters
|
| 326 |
+
"""
|
| 327 |
+
if camera_type == "PINHOLE":
|
| 328 |
+
pycolmap_intri = np.array(
|
| 329 |
+
[
|
| 330 |
+
intrinsics[fidx][0, 0],
|
| 331 |
+
intrinsics[fidx][1, 1],
|
| 332 |
+
intrinsics[fidx][0, 2],
|
| 333 |
+
intrinsics[fidx][1, 2],
|
| 334 |
+
]
|
| 335 |
+
)
|
| 336 |
+
elif camera_type == "SIMPLE_PINHOLE":
|
| 337 |
+
focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2
|
| 338 |
+
pycolmap_intri = np.array(
|
| 339 |
+
[focal, intrinsics[fidx][0, 2], intrinsics[fidx][1, 2]]
|
| 340 |
+
)
|
| 341 |
+
elif camera_type == "SIMPLE_RADIAL":
|
| 342 |
+
raise NotImplementedError("SIMPLE_RADIAL is not supported yet")
|
| 343 |
+
focal = (intrinsics[fidx][0, 0] + intrinsics[fidx][1, 1]) / 2
|
| 344 |
+
pycolmap_intri = np.array(
|
| 345 |
+
[
|
| 346 |
+
focal,
|
| 347 |
+
intrinsics[fidx][0, 2],
|
| 348 |
+
intrinsics[fidx][1, 2],
|
| 349 |
+
extra_params[fidx][0],
|
| 350 |
+
]
|
| 351 |
+
)
|
| 352 |
+
else:
|
| 353 |
+
raise ValueError(f"Camera type {camera_type} is not supported yet")
|
| 354 |
+
|
| 355 |
+
return pycolmap_intri
|
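A rough sketch of calling batch_np_matrix_to_pycolmap_wo_track above with placeholder arrays. It assumes pycolmap is installed and that the module imports as sailrecon.dependency.np_to_pycolmap; all values are random and only illustrate the expected shapes.

import numpy as np
from sailrecon.dependency.np_to_pycolmap import batch_np_matrix_to_pycolmap_wo_track

N_frames, P = 2, 500
points3d = np.random.rand(P, 3)                                   # world-space points
points_xyf = np.concatenate(
    [np.random.rand(P, 2) * 100.0,                                # pixel (x, y)
     np.random.randint(0, N_frames, (P, 1)).astype(np.float64)],  # owning frame index
    axis=1,
)
points_rgb = np.random.randint(0, 256, (P, 3), dtype=np.uint8)
extrinsics = np.tile(np.eye(3, 4), (N_frames, 1, 1))              # identity [R|t] per frame
intrinsics = np.tile(
    np.array([[100.0, 0.0, 50.0], [0.0, 100.0, 50.0], [0.0, 0.0, 1.0]]),
    (N_frames, 1, 1),
)
image_size = np.array([100, 100])                                 # (width, height)

reconstruction = batch_np_matrix_to_pycolmap_wo_track(
    points3d, points_xyf, points_rgb, extrinsics, intrinsics, image_size,
    shared_camera=False, camera_type="SIMPLE_PINHOLE",
)
# The result is a pycolmap.Reconstruction intended only as an initialization
# for Gaussians or other NVS methods, as the docstring above notes (not for BA).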
sailrecon/dependency/projection.py
ADDED
|
@@ -0,0 +1,249 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from .distortion import apply_distortion
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def img_from_cam_np(
|
| 14 |
+
intrinsics: np.ndarray,
|
| 15 |
+
points_cam: np.ndarray,
|
| 16 |
+
extra_params: np.ndarray | None = None,
|
| 17 |
+
default: float = 0.0,
|
| 18 |
+
) -> np.ndarray:
|
| 19 |
+
"""
|
| 20 |
+
Apply intrinsics (and optional radial distortion) to camera-space points.
|
| 21 |
+
|
| 22 |
+
Args
|
| 23 |
+
----
|
| 24 |
+
intrinsics : (B,3,3) camera matrix K.
|
| 25 |
+
points_cam : (B,3,N) homogeneous camera coords (x, y, z)ᵀ.
|
| 26 |
+
extra_params: (B, N) or (B, k) distortion params (k = 1,2,4) or None.
|
| 27 |
+
default : value used for np.nan replacement.
|
| 28 |
+
|
| 29 |
+
Returns
|
| 30 |
+
-------
|
| 31 |
+
points2D : (B,N,2) pixel coordinates.
|
| 32 |
+
"""
|
| 33 |
+
# 1. perspective divide ───────────────────────────────────────
|
| 34 |
+
z = points_cam[:, 2:3, :] # (B,1,N)
|
| 35 |
+
points_cam_norm = points_cam / z # (B,3,N)
|
| 36 |
+
uv = points_cam_norm[:, :2, :] # (B,2,N)
|
| 37 |
+
|
| 38 |
+
# 2. optional distortion ──────────────────────────────────────
|
| 39 |
+
if extra_params is not None:
|
| 40 |
+
uu, vv = apply_distortion(extra_params, uv[:, 0], uv[:, 1])
|
| 41 |
+
uv = np.stack([uu, vv], axis=1) # (B,2,N)
|
| 42 |
+
|
| 43 |
+
# 3. homogeneous coords then K multiplication ─────────────────
|
| 44 |
+
ones = np.ones_like(uv[:, :1, :]) # (B,1,N)
|
| 45 |
+
points_cam_h = np.concatenate([uv, ones], axis=1) # (B,3,N)
|
| 46 |
+
|
| 47 |
+
# batched mat-mul: K · [u v 1]ᵀ
|
| 48 |
+
points2D_h = np.einsum("bij,bjk->bik", intrinsics, points_cam_h) # (B,3,N)
|
| 49 |
+
points2D = np.nan_to_num(points2D_h[:, :2, :], nan=default) # (B,2,N)
|
| 50 |
+
|
| 51 |
+
return points2D.transpose(0, 2, 1) # (B,N,2)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def project_3D_points_np(
|
| 55 |
+
points3D: np.ndarray,
|
| 56 |
+
extrinsics: np.ndarray,
|
| 57 |
+
intrinsics: np.ndarray | None = None,
|
| 58 |
+
extra_params: np.ndarray | None = None,
|
| 59 |
+
*,
|
| 60 |
+
default: float = 0.0,
|
| 61 |
+
only_points_cam: bool = False,
|
| 62 |
+
):
|
| 63 |
+
"""
|
| 64 |
+
NumPy clone of ``project_3D_points``.
|
| 65 |
+
|
| 66 |
+
Parameters
|
| 67 |
+
----------
|
| 68 |
+
points3D : (N,3) world-space points.
|
| 69 |
+
extrinsics : (B,3,4) [R|t] matrix for each of B cameras.
|
| 70 |
+
intrinsics : (B,3,3) K matrix (optional if you only need cam-space).
|
| 71 |
+
extra_params : (B,k) or (B,N) distortion parameters (k ∈ {1,2,4}) or None.
|
| 72 |
+
default : value used to replace NaNs.
|
| 73 |
+
only_points_cam : if True, skip the projection and return points_cam with points2D as None.
|
| 74 |
+
|
| 75 |
+
Returns
|
| 76 |
+
-------
|
| 77 |
+
(points2D, points_cam) : A tuple where points2D is (B,N,2) pixel coords or None if only_points_cam=True,
|
| 78 |
+
and points_cam is (B,3,N) camera-space coordinates.
|
| 79 |
+
"""
|
| 80 |
+
# ----- 0. prep sizes -----------------------------------------------------
|
| 81 |
+
N = points3D.shape[0] # #points
|
| 82 |
+
B = extrinsics.shape[0] # #cameras
|
| 83 |
+
|
| 84 |
+
# ----- 1. world → homogeneous -------------------------------------------
|
| 85 |
+
w_h = np.ones((N, 1), dtype=points3D.dtype)
|
| 86 |
+
points3D_h = np.concatenate([points3D, w_h], axis=1) # (N,4)
|
| 87 |
+
|
| 88 |
+
# broadcast to every camera (no actual copying with np.broadcast_to) ------
|
| 89 |
+
points3D_h_B = np.broadcast_to(points3D_h, (B, N, 4)) # (B,N,4)
|
| 90 |
+
|
| 91 |
+
# ----- 2. apply extrinsics (camera frame) ------------------------------
|
| 92 |
+
# X_cam = E · X_hom
|
| 93 |
+
# einsum: E_(b i j) · X_(b n j) → (b n i)
|
| 94 |
+
points_cam = np.einsum("bij,bnj->bni", extrinsics, points3D_h_B) # (B,N,3)
|
| 95 |
+
points_cam = points_cam.transpose(0, 2, 1) # (B,3,N)
|
| 96 |
+
|
| 97 |
+
if only_points_cam:
|
| 98 |
+
return None, points_cam
|
| 99 |
+
|
| 100 |
+
# ----- 3. intrinsics + distortion ---------------------------------------
|
| 101 |
+
if intrinsics is None:
|
| 102 |
+
raise ValueError("`intrinsics` must be provided unless only_points_cam=True")
|
| 103 |
+
|
| 104 |
+
points2D = img_from_cam_np(
|
| 105 |
+
intrinsics, points_cam, extra_params=extra_params, default=default
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
return points2D, points_cam
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def project_3D_points(
|
| 112 |
+
points3D,
|
| 113 |
+
extrinsics,
|
| 114 |
+
intrinsics=None,
|
| 115 |
+
extra_params=None,
|
| 116 |
+
default=0,
|
| 117 |
+
only_points_cam=False,
|
| 118 |
+
):
|
| 119 |
+
"""
|
| 120 |
+
Transforms 3D points to 2D using extrinsic and intrinsic parameters.
|
| 121 |
+
Args:
|
| 122 |
+
points3D (torch.Tensor): 3D points of shape Px3.
|
| 123 |
+
extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4.
|
| 124 |
+
intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3.
|
| 125 |
+
extra_params (torch.Tensor): Extra parameters of shape BxN, used for radial distortion.
|
| 126 |
+
default (float): Default value to replace NaNs.
|
| 127 |
+
only_points_cam (bool): If True, skip the projection and return points2D as None.
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
tuple: (points2D, points_cam) where points2D is of shape BxNx2 or None if only_points_cam=True,
|
| 131 |
+
and points_cam is of shape Bx3xN.
|
| 132 |
+
"""
|
| 133 |
+
with torch.cuda.amp.autocast(dtype=torch.double):
|
| 134 |
+
N = points3D.shape[0] # Number of points
|
| 135 |
+
B = extrinsics.shape[0] # Batch size, i.e., number of cameras
|
| 136 |
+
points3D_homogeneous = torch.cat(
|
| 137 |
+
[points3D, torch.ones_like(points3D[..., 0:1])], dim=1
|
| 138 |
+
) # Nx4
|
| 139 |
+
# Reshape for batch processing
|
| 140 |
+
points3D_homogeneous = points3D_homogeneous.unsqueeze(0).expand(
|
| 141 |
+
B, -1, -1
|
| 142 |
+
) # BxNx4
|
| 143 |
+
|
| 144 |
+
# Step 1: Apply extrinsic parameters
|
| 145 |
+
# Transform 3D points to camera coordinate system for all cameras
|
| 146 |
+
points_cam = torch.bmm(extrinsics, points3D_homogeneous.transpose(-1, -2))
|
| 147 |
+
|
| 148 |
+
if only_points_cam:
|
| 149 |
+
return None, points_cam
|
| 150 |
+
|
| 151 |
+
# Step 2: Apply intrinsic parameters and (optional) distortion
|
| 152 |
+
points2D = img_from_cam(intrinsics, points_cam, extra_params, default)
|
| 153 |
+
|
| 154 |
+
return points2D, points_cam
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def img_from_cam(intrinsics, points_cam, extra_params=None, default=0.0):
|
| 158 |
+
"""
|
| 159 |
+
Applies intrinsic parameters and optional distortion to the given 3D points.
|
| 160 |
+
|
| 161 |
+
Args:
|
| 162 |
+
intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3.
|
| 163 |
+
points_cam (torch.Tensor): 3D points in camera coordinates of shape Bx3xN.
|
| 164 |
+
extra_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4.
|
| 165 |
+
default (float, optional): Default value to replace NaNs in the output.
|
| 166 |
+
|
| 167 |
+
Returns:
|
| 168 |
+
points2D (torch.Tensor): 2D points in pixel coordinates of shape BxNx2.
|
| 169 |
+
"""
|
| 170 |
+
|
| 171 |
+
# Normalize by the third coordinate (homogeneous division)
|
| 172 |
+
points_cam = points_cam / points_cam[:, 2:3, :]
|
| 173 |
+
# Extract uv
|
| 174 |
+
uv = points_cam[:, :2, :]
|
| 175 |
+
|
| 176 |
+
# Apply distortion if extra_params are provided
|
| 177 |
+
if extra_params is not None:
|
| 178 |
+
uu, vv = apply_distortion(extra_params, uv[:, 0], uv[:, 1])
|
| 179 |
+
uv = torch.stack([uu, vv], dim=1)
|
| 180 |
+
|
| 181 |
+
# Prepare points_cam for batch matrix multiplication
|
| 182 |
+
points_cam_homo = torch.cat((uv, torch.ones_like(uv[:, :1, :])), dim=1) # Bx3xN
|
| 183 |
+
# Apply intrinsic parameters using batch matrix multiplication
|
| 184 |
+
points2D_homo = torch.bmm(intrinsics, points_cam_homo) # Bx3xN
|
| 185 |
+
|
| 186 |
+
# Extract x and y coordinates
|
| 187 |
+
points2D = points2D_homo[:, :2, :] # Bx2xN
|
| 188 |
+
|
| 189 |
+
# Replace NaNs with default value
|
| 190 |
+
points2D = torch.nan_to_num(points2D, nan=default)
|
| 191 |
+
|
| 192 |
+
return points2D.transpose(1, 2) # BxNx2
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
if __name__ == "__main__":
|
| 196 |
+
# Set up example input
|
| 197 |
+
B, N = 24, 10240
|
| 198 |
+
|
| 199 |
+
for _ in range(100):
|
| 200 |
+
points3D = np.random.rand(N, 3).astype(np.float64)
|
| 201 |
+
extrinsics = np.random.rand(B, 3, 4).astype(np.float64)
|
| 202 |
+
intrinsics = np.random.rand(B, 3, 3).astype(np.float64)
|
| 203 |
+
|
| 204 |
+
# Convert to torch tensors
|
| 205 |
+
points3D_torch = torch.tensor(points3D)
|
| 206 |
+
extrinsics_torch = torch.tensor(extrinsics)
|
| 207 |
+
intrinsics_torch = torch.tensor(intrinsics)
|
| 208 |
+
|
| 209 |
+
# Run NumPy implementation
|
| 210 |
+
points2D_np, points_cam_np = project_3D_points_np(
|
| 211 |
+
points3D, extrinsics, intrinsics
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
# Run torch implementation
|
| 215 |
+
points2D_torch, points_cam_torch = project_3D_points(
|
| 216 |
+
points3D_torch, extrinsics_torch, intrinsics_torch
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
# Convert torch output to numpy
|
| 220 |
+
points2D_torch_np = points2D_torch.detach().numpy()
|
| 221 |
+
points_cam_torch_np = points_cam_torch.detach().numpy()
|
| 222 |
+
|
| 223 |
+
# Compute difference
|
| 224 |
+
diff = np.abs(points2D_np - points2D_torch_np)
|
| 225 |
+
print("Difference between NumPy and PyTorch implementations:")
|
| 226 |
+
print(diff)
|
| 227 |
+
|
| 228 |
+
# Check max error
|
| 229 |
+
max_diff = np.max(diff)
|
| 230 |
+
print(f"Maximum difference: {max_diff}")
|
| 231 |
+
|
| 232 |
+
if np.allclose(points2D_np, points2D_torch_np, atol=1e-6):
|
| 233 |
+
print("Implementations match closely.")
|
| 234 |
+
else:
|
| 235 |
+
print("Significant differences detected.")
|
| 236 |
+
|
| 237 |
+
if points_cam_np is not None:
|
| 238 |
+
points_cam_diff = np.abs(points_cam_np - points_cam_torch_np)
|
| 239 |
+
print("Difference between NumPy and PyTorch camera-space coordinates:")
|
| 240 |
+
print(points_cam_diff)
|
| 241 |
+
|
| 242 |
+
# Check max error
|
| 243 |
+
max_cam_diff = np.max(points_cam_diff)
|
| 244 |
+
print(f"Maximum camera-space coordinate difference: {max_cam_diff}")
|
| 245 |
+
|
| 246 |
+
if np.allclose(points_cam_np, points_cam_torch_np, atol=1e-6):
|
| 247 |
+
print("Camera-space coordinates match closely.")
|
| 248 |
+
else:
|
| 249 |
+
print("Significant differences detected in camera-space coordinates.")
|
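A minimal sketch of project_3D_points_np above with a single identity camera; the import path is assumed from the file name and the intrinsics are arbitrary round numbers.

import numpy as np
from sailrecon.dependency.projection import project_3D_points_np  # assumed path

# One camera at the origin looking down +z: focal length 100, principal point (50, 50).
K = np.array([[[100.0, 0.0, 50.0],
               [0.0, 100.0, 50.0],
               [0.0, 0.0, 1.0]]])          # (1, 3, 3)
E = np.eye(3, 4)[None]                     # (1, 3, 4) identity [R|t]
pts = np.array([[0.0, 0.0, 2.0],
                [0.5, -0.5, 4.0]])         # (N, 3) points in front of the camera

pts2d, pts_cam = project_3D_points_np(pts, E, K)
# The on-axis point lands on the principal point, pts2d[0, 0] == (50, 50);
# the second point projects to (62.5, 37.5).
print(pts2d)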
sailrecon/dependency/track_modules/__init__.py
ADDED
|
File without changes
|
sailrecon/dependency/track_modules/base_track_predictor.py
ADDED
|
@@ -0,0 +1,210 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from einops import rearrange, repeat
|
| 10 |
+
|
| 11 |
+
from .blocks import CorrBlock, EfficientUpdateFormer
|
| 12 |
+
from .utils import get_2d_embedding, get_2d_sincos_pos_embed, sample_features4d
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class BaseTrackerPredictor(nn.Module):
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
stride=4,
|
| 19 |
+
corr_levels=5,
|
| 20 |
+
corr_radius=4,
|
| 21 |
+
latent_dim=128,
|
| 22 |
+
hidden_size=384,
|
| 23 |
+
use_spaceatt=True,
|
| 24 |
+
depth=6,
|
| 25 |
+
fine=False,
|
| 26 |
+
):
|
| 27 |
+
super(BaseTrackerPredictor, self).__init__()
|
| 28 |
+
"""
|
| 29 |
+
The base template to create a track predictor
|
| 30 |
+
|
| 31 |
+
Modified from https://github.com/facebookresearch/co-tracker/
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
self.stride = stride
|
| 35 |
+
self.latent_dim = latent_dim
|
| 36 |
+
self.corr_levels = corr_levels
|
| 37 |
+
self.corr_radius = corr_radius
|
| 38 |
+
self.hidden_size = hidden_size
|
| 39 |
+
self.fine = fine
|
| 40 |
+
|
| 41 |
+
self.flows_emb_dim = latent_dim // 2
|
| 42 |
+
self.transformer_dim = (
|
| 43 |
+
self.corr_levels * (self.corr_radius * 2 + 1) ** 2 + self.latent_dim * 2
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
if self.fine:
|
| 47 |
+
# TODO this is the old dummy code, will remove this when we train next model
|
| 48 |
+
self.transformer_dim += 4 if self.transformer_dim % 2 == 0 else 5
|
| 49 |
+
else:
|
| 50 |
+
self.transformer_dim += (4 - self.transformer_dim % 4) % 4
|
| 51 |
+
|
| 52 |
+
space_depth = depth if use_spaceatt else 0
|
| 53 |
+
time_depth = depth
|
| 54 |
+
|
| 55 |
+
self.updateformer = EfficientUpdateFormer(
|
| 56 |
+
space_depth=space_depth,
|
| 57 |
+
time_depth=time_depth,
|
| 58 |
+
input_dim=self.transformer_dim,
|
| 59 |
+
hidden_size=self.hidden_size,
|
| 60 |
+
output_dim=self.latent_dim + 2,
|
| 61 |
+
mlp_ratio=4.0,
|
| 62 |
+
add_space_attn=use_spaceatt,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
self.norm = nn.GroupNorm(1, self.latent_dim)
|
| 66 |
+
|
| 67 |
+
# A linear layer to update track feats at each iteration
|
| 68 |
+
self.ffeat_updater = nn.Sequential(
|
| 69 |
+
nn.Linear(self.latent_dim, self.latent_dim), nn.GELU()
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
if not self.fine:
|
| 73 |
+
self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
|
| 74 |
+
|
| 75 |
+
def forward(
|
| 76 |
+
self, query_points, fmaps=None, iters=4, return_feat=False, down_ratio=1
|
| 77 |
+
):
|
| 78 |
+
"""
|
| 79 |
+
query_points: B x N x 2, the number of batches, tracks, and xy
|
| 80 |
+
fmaps: B x S x C x HH x WW, the number of batches, frames, and feature dimension.
|
| 81 |
+
note HH and WW is the size of feature maps instead of original images
|
| 82 |
+
"""
|
| 83 |
+
B, N, D = query_points.shape
|
| 84 |
+
B, S, C, HH, WW = fmaps.shape
|
| 85 |
+
|
| 86 |
+
assert D == 2
|
| 87 |
+
|
| 88 |
+
# Scale the input query_points because we may downsample the images
|
| 89 |
+
# by down_ratio or self.stride
|
| 90 |
+
# e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
|
| 91 |
+
# its query_points should be query_points/4
|
| 92 |
+
if down_ratio > 1:
|
| 93 |
+
query_points = query_points / float(down_ratio)
|
| 94 |
+
query_points = query_points / float(self.stride)
|
| 95 |
+
|
| 96 |
+
# Init with coords as the query points
|
| 97 |
+
# It means the search will start from the position of query points at the reference frames
|
| 98 |
+
coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)
|
| 99 |
+
|
| 100 |
+
# Sample/extract the features of the query points in the query frame
|
| 101 |
+
query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])
|
| 102 |
+
|
| 103 |
+
# init track feats by query feats
|
| 104 |
+
track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C
|
| 105 |
+
# back up the init coords
|
| 106 |
+
coords_backup = coords.clone()
|
| 107 |
+
|
| 108 |
+
# Construct the correlation block
|
| 109 |
+
|
| 110 |
+
fcorr_fn = CorrBlock(
|
| 111 |
+
fmaps, num_levels=self.corr_levels, radius=self.corr_radius
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
coord_preds = []
|
| 115 |
+
|
| 116 |
+
# Iterative Refinement
|
| 117 |
+
for itr in range(iters):
|
| 118 |
+
# Detach the gradients from the last iteration
|
| 119 |
+
# (in my experience, not very important for performance)
|
| 120 |
+
coords = coords.detach()
|
| 121 |
+
|
| 122 |
+
# Compute the correlation (check the implementation of CorrBlock)
|
| 123 |
+
|
| 124 |
+
fcorr_fn.corr(track_feats)
|
| 125 |
+
fcorrs = fcorr_fn.sample(coords) # B, S, N, corrdim
|
| 126 |
+
|
| 127 |
+
corrdim = fcorrs.shape[3]
|
| 128 |
+
|
| 129 |
+
fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corrdim)
|
| 130 |
+
|
| 131 |
+
# Movement of current coords relative to query points
|
| 132 |
+
flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
|
| 133 |
+
|
| 134 |
+
flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)
|
| 135 |
+
|
| 136 |
+
# (In my trials, it is also okay to just add the flows_emb instead of concat)
|
| 137 |
+
flows_emb = torch.cat([flows_emb, flows], dim=-1)
|
| 138 |
+
|
| 139 |
+
track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(
|
| 140 |
+
B * N, S, self.latent_dim
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
# Concatenate them as the input for the transformers
|
| 144 |
+
transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)
|
| 145 |
+
|
| 146 |
+
if transformer_input.shape[2] < self.transformer_dim:
|
| 147 |
+
# pad the features to match the dimension
|
| 148 |
+
pad_dim = self.transformer_dim - transformer_input.shape[2]
|
| 149 |
+
pad = torch.zeros_like(flows_emb[..., 0:pad_dim])
|
| 150 |
+
transformer_input = torch.cat([transformer_input, pad], dim=2)
|
| 151 |
+
|
| 152 |
+
# 2D positional embed
|
| 153 |
+
# TODO: this can be much simplified
|
| 154 |
+
pos_embed = get_2d_sincos_pos_embed(
|
| 155 |
+
self.transformer_dim, grid_size=(HH, WW)
|
| 156 |
+
).to(query_points.device)
|
| 157 |
+
sampled_pos_emb = sample_features4d(
|
| 158 |
+
pos_embed.expand(B, -1, -1, -1), coords[:, 0]
|
| 159 |
+
)
|
| 160 |
+
sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(
|
| 161 |
+
1
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
x = transformer_input + sampled_pos_emb
|
| 165 |
+
|
| 166 |
+
# B, N, S, C
|
| 167 |
+
x = rearrange(x, "(b n) s d -> b n s d", b=B)
|
| 168 |
+
|
| 169 |
+
# Compute the delta coordinates and delta track features
|
| 170 |
+
delta = self.updateformer(x)
|
| 171 |
+
# BN, S, C
|
| 172 |
+
delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
|
| 173 |
+
delta_coords_ = delta[:, :, :2]
|
| 174 |
+
delta_feats_ = delta[:, :, 2:]
|
| 175 |
+
|
| 176 |
+
track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
|
| 177 |
+
delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
|
| 178 |
+
|
| 179 |
+
# Update the track features
|
| 180 |
+
track_feats_ = self.ffeat_updater(self.norm(delta_feats_)) + track_feats_
|
| 181 |
+
track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(
|
| 182 |
+
0, 2, 1, 3
|
| 183 |
+
) # BxSxNxC
|
| 184 |
+
|
| 185 |
+
# B x S x N x 2
|
| 186 |
+
coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
|
| 187 |
+
|
| 188 |
+
# Force coord0 as query
|
| 189 |
+
# because we assume the query points should not be changed
|
| 190 |
+
coords[:, 0] = coords_backup[:, 0]
|
| 191 |
+
|
| 192 |
+
# The predicted tracks are in the original image scale
|
| 193 |
+
if down_ratio > 1:
|
| 194 |
+
coord_preds.append(coords * self.stride * down_ratio)
|
| 195 |
+
else:
|
| 196 |
+
coord_preds.append(coords * self.stride)
|
| 197 |
+
|
| 198 |
+
# B, S, N
|
| 199 |
+
if not self.fine:
|
| 200 |
+
vis_e = self.vis_predictor(
|
| 201 |
+
track_feats.reshape(B * S * N, self.latent_dim)
|
| 202 |
+
).reshape(B, S, N)
|
| 203 |
+
vis_e = torch.sigmoid(vis_e)
|
| 204 |
+
else:
|
| 205 |
+
vis_e = None
|
| 206 |
+
|
| 207 |
+
if return_feat:
|
| 208 |
+
return coord_preds, vis_e, track_feats, query_track_feat
|
| 209 |
+
else:
|
| 210 |
+
return coord_preds, vis_e
|
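Note on the coordinate bookkeeping in the forward pass above: query points arrive in original-image pixels, are divided by `down_ratio` and `self.stride` before indexing the feature maps, and the refined coordinates are multiplied back before being appended to `coord_preds`. A minimal sketch of that round trip (the `down_ratio`/`stride` values and the query point below are illustrative, not taken from the model):

```python
import torch

# Illustrative values only; in the tracker these come from the model config.
down_ratio, stride = 2, 4
query_points = torch.tensor([[[512.0, 256.0]]])  # B=1, N=1, (x, y) in image pixels

# Image space -> feature-map space (as done at the top of forward())
feat_coords = query_points / down_ratio / stride           # [[[64., 32.]]]

# ... iterative refinement runs at feature-map resolution ...
refined_feat_coords = feat_coords + 0.5                     # pretend delta from the update

# Feature-map space -> image space (as done when filling coord_preds)
refined_image_coords = refined_feat_coords * stride * down_ratio
print(refined_image_coords)                                 # back in original-image pixels
```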
sailrecon/dependency/track_modules/blocks.py
ADDED
|
@@ -0,0 +1,396 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Modified from https://github.com/facebookresearch/co-tracker/
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import collections
|
| 12 |
+
from functools import partial
|
| 13 |
+
from itertools import repeat
|
| 14 |
+
from typing import Callable
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
from torch import Tensor
|
| 20 |
+
|
| 21 |
+
from .modules import AttnBlock, CrossAttnBlock, Mlp, ResidualBlock
|
| 22 |
+
from .utils import bilinear_sampler
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class BasicEncoder(nn.Module):
|
| 26 |
+
def __init__(self, input_dim=3, output_dim=128, stride=4):
|
| 27 |
+
super(BasicEncoder, self).__init__()
|
| 28 |
+
|
| 29 |
+
self.stride = stride
|
| 30 |
+
self.norm_fn = "instance"
|
| 31 |
+
self.in_planes = output_dim // 2
|
| 32 |
+
|
| 33 |
+
self.norm1 = nn.InstanceNorm2d(self.in_planes)
|
| 34 |
+
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
|
| 35 |
+
|
| 36 |
+
self.conv1 = nn.Conv2d(
|
| 37 |
+
input_dim,
|
| 38 |
+
self.in_planes,
|
| 39 |
+
kernel_size=7,
|
| 40 |
+
stride=2,
|
| 41 |
+
padding=3,
|
| 42 |
+
padding_mode="zeros",
|
| 43 |
+
)
|
| 44 |
+
self.relu1 = nn.ReLU(inplace=True)
|
| 45 |
+
self.layer1 = self._make_layer(output_dim // 2, stride=1)
|
| 46 |
+
self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
|
| 47 |
+
self.layer3 = self._make_layer(output_dim, stride=2)
|
| 48 |
+
self.layer4 = self._make_layer(output_dim, stride=2)
|
| 49 |
+
|
| 50 |
+
self.conv2 = nn.Conv2d(
|
| 51 |
+
output_dim * 3 + output_dim // 4,
|
| 52 |
+
output_dim * 2,
|
| 53 |
+
kernel_size=3,
|
| 54 |
+
padding=1,
|
| 55 |
+
padding_mode="zeros",
|
| 56 |
+
)
|
| 57 |
+
self.relu2 = nn.ReLU(inplace=True)
|
| 58 |
+
self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
|
| 59 |
+
|
| 60 |
+
for m in self.modules():
|
| 61 |
+
if isinstance(m, nn.Conv2d):
|
| 62 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
| 63 |
+
elif isinstance(m, (nn.InstanceNorm2d)):
|
| 64 |
+
if m.weight is not None:
|
| 65 |
+
nn.init.constant_(m.weight, 1)
|
| 66 |
+
if m.bias is not None:
|
| 67 |
+
nn.init.constant_(m.bias, 0)
|
| 68 |
+
|
| 69 |
+
def _make_layer(self, dim, stride=1):
|
| 70 |
+
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
| 71 |
+
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
| 72 |
+
layers = (layer1, layer2)
|
| 73 |
+
|
| 74 |
+
self.in_planes = dim
|
| 75 |
+
return nn.Sequential(*layers)
|
| 76 |
+
|
| 77 |
+
def forward(self, x):
|
| 78 |
+
_, _, H, W = x.shape
|
| 79 |
+
|
| 80 |
+
x = self.conv1(x)
|
| 81 |
+
x = self.norm1(x)
|
| 82 |
+
x = self.relu1(x)
|
| 83 |
+
|
| 84 |
+
a = self.layer1(x)
|
| 85 |
+
b = self.layer2(a)
|
| 86 |
+
c = self.layer3(b)
|
| 87 |
+
d = self.layer4(c)
|
| 88 |
+
|
| 89 |
+
a = _bilinear_intepolate(a, self.stride, H, W)
|
| 90 |
+
b = _bilinear_intepolate(b, self.stride, H, W)
|
| 91 |
+
c = _bilinear_intepolate(c, self.stride, H, W)
|
| 92 |
+
d = _bilinear_intepolate(d, self.stride, H, W)
|
| 93 |
+
|
| 94 |
+
x = self.conv2(torch.cat([a, b, c, d], dim=1))
|
| 95 |
+
x = self.norm2(x)
|
| 96 |
+
x = self.relu2(x)
|
| 97 |
+
x = self.conv3(x)
|
| 98 |
+
return x
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class ShallowEncoder(nn.Module):
|
| 102 |
+
def __init__(self, input_dim=3, output_dim=32, stride=1, norm_fn="instance"):
|
| 103 |
+
super(ShallowEncoder, self).__init__()
|
| 104 |
+
self.stride = stride
|
| 105 |
+
self.norm_fn = norm_fn
|
| 106 |
+
self.in_planes = output_dim
|
| 107 |
+
|
| 108 |
+
if self.norm_fn == "group":
|
| 109 |
+
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes)
|
| 110 |
+
self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2)
|
| 111 |
+
elif self.norm_fn == "batch":
|
| 112 |
+
self.norm1 = nn.BatchNorm2d(self.in_planes)
|
| 113 |
+
self.norm2 = nn.BatchNorm2d(output_dim * 2)
|
| 114 |
+
elif self.norm_fn == "instance":
|
| 115 |
+
self.norm1 = nn.InstanceNorm2d(self.in_planes)
|
| 116 |
+
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
|
| 117 |
+
elif self.norm_fn == "none":
|
| 118 |
+
self.norm1 = nn.Sequential()
|
| 119 |
+
|
| 120 |
+
self.conv1 = nn.Conv2d(
|
| 121 |
+
input_dim,
|
| 122 |
+
self.in_planes,
|
| 123 |
+
kernel_size=3,
|
| 124 |
+
stride=2,
|
| 125 |
+
padding=1,
|
| 126 |
+
padding_mode="zeros",
|
| 127 |
+
)
|
| 128 |
+
self.relu1 = nn.ReLU(inplace=True)
|
| 129 |
+
|
| 130 |
+
self.layer1 = self._make_layer(output_dim, stride=2)
|
| 131 |
+
|
| 132 |
+
self.layer2 = self._make_layer(output_dim, stride=2)
|
| 133 |
+
self.conv2 = nn.Conv2d(output_dim, output_dim, kernel_size=1)
|
| 134 |
+
|
| 135 |
+
for m in self.modules():
|
| 136 |
+
if isinstance(m, nn.Conv2d):
|
| 137 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
| 138 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
| 139 |
+
if m.weight is not None:
|
| 140 |
+
nn.init.constant_(m.weight, 1)
|
| 141 |
+
if m.bias is not None:
|
| 142 |
+
nn.init.constant_(m.bias, 0)
|
| 143 |
+
|
| 144 |
+
def _make_layer(self, dim, stride=1):
|
| 145 |
+
self.in_planes = dim
|
| 146 |
+
|
| 147 |
+
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
| 148 |
+
return layer1
|
| 149 |
+
|
| 150 |
+
def forward(self, x):
|
| 151 |
+
_, _, H, W = x.shape
|
| 152 |
+
|
| 153 |
+
x = self.conv1(x)
|
| 154 |
+
x = self.norm1(x)
|
| 155 |
+
x = self.relu1(x)
|
| 156 |
+
|
| 157 |
+
tmp = self.layer1(x)
|
| 158 |
+
x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True)
|
| 159 |
+
tmp = self.layer2(tmp)
|
| 160 |
+
x = x + F.interpolate(tmp, (x.shape[-2:]), mode="bilinear", align_corners=True)
|
| 161 |
+
tmp = None
|
| 162 |
+
x = self.conv2(x) + x
|
| 163 |
+
|
| 164 |
+
x = F.interpolate(
|
| 165 |
+
x, (H // self.stride, W // self.stride), mode="bilinear", align_corners=True
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
return x
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def _bilinear_intepolate(x, stride, H, W):
|
| 172 |
+
return F.interpolate(
|
| 173 |
+
x, (H // stride, W // stride), mode="bilinear", align_corners=True
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class EfficientUpdateFormer(nn.Module):
|
| 178 |
+
"""
|
| 179 |
+
Transformer model that updates track estimates.
|
| 180 |
+
"""
|
| 181 |
+
|
| 182 |
+
def __init__(
|
| 183 |
+
self,
|
| 184 |
+
space_depth=6,
|
| 185 |
+
time_depth=6,
|
| 186 |
+
input_dim=320,
|
| 187 |
+
hidden_size=384,
|
| 188 |
+
num_heads=8,
|
| 189 |
+
output_dim=130,
|
| 190 |
+
mlp_ratio=4.0,
|
| 191 |
+
add_space_attn=True,
|
| 192 |
+
num_virtual_tracks=64,
|
| 193 |
+
):
|
| 194 |
+
super().__init__()
|
| 195 |
+
|
| 196 |
+
self.out_channels = 2
|
| 197 |
+
self.num_heads = num_heads
|
| 198 |
+
self.hidden_size = hidden_size
|
| 199 |
+
self.add_space_attn = add_space_attn
|
| 200 |
+
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
|
| 201 |
+
self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
|
| 202 |
+
self.num_virtual_tracks = num_virtual_tracks
|
| 203 |
+
|
| 204 |
+
if self.add_space_attn:
|
| 205 |
+
self.virual_tracks = nn.Parameter(
|
| 206 |
+
torch.randn(1, num_virtual_tracks, 1, hidden_size)
|
| 207 |
+
)
|
| 208 |
+
else:
|
| 209 |
+
self.virual_tracks = None
|
| 210 |
+
|
| 211 |
+
self.time_blocks = nn.ModuleList(
|
| 212 |
+
[
|
| 213 |
+
AttnBlock(
|
| 214 |
+
hidden_size,
|
| 215 |
+
num_heads,
|
| 216 |
+
mlp_ratio=mlp_ratio,
|
| 217 |
+
attn_class=nn.MultiheadAttention,
|
| 218 |
+
)
|
| 219 |
+
for _ in range(time_depth)
|
| 220 |
+
]
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
if add_space_attn:
|
| 224 |
+
self.space_virtual_blocks = nn.ModuleList(
|
| 225 |
+
[
|
| 226 |
+
AttnBlock(
|
| 227 |
+
hidden_size,
|
| 228 |
+
num_heads,
|
| 229 |
+
mlp_ratio=mlp_ratio,
|
| 230 |
+
attn_class=nn.MultiheadAttention,
|
| 231 |
+
)
|
| 232 |
+
for _ in range(space_depth)
|
| 233 |
+
]
|
| 234 |
+
)
|
| 235 |
+
self.space_point2virtual_blocks = nn.ModuleList(
|
| 236 |
+
[
|
| 237 |
+
CrossAttnBlock(
|
| 238 |
+
hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
|
| 239 |
+
)
|
| 240 |
+
for _ in range(space_depth)
|
| 241 |
+
]
|
| 242 |
+
)
|
| 243 |
+
self.space_virtual2point_blocks = nn.ModuleList(
|
| 244 |
+
[
|
| 245 |
+
CrossAttnBlock(
|
| 246 |
+
hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
|
| 247 |
+
)
|
| 248 |
+
for _ in range(space_depth)
|
| 249 |
+
]
|
| 250 |
+
)
|
| 251 |
+
assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
|
| 252 |
+
self.initialize_weights()
|
| 253 |
+
|
| 254 |
+
def initialize_weights(self):
|
| 255 |
+
def _basic_init(module):
|
| 256 |
+
if isinstance(module, nn.Linear):
|
| 257 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 258 |
+
if module.bias is not None:
|
| 259 |
+
nn.init.constant_(module.bias, 0)
|
| 260 |
+
|
| 261 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
|
| 262 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
| 263 |
+
if isinstance(module, nn.Linear):
|
| 264 |
+
trunc_normal_(module.weight, std=0.02)
|
| 265 |
+
if module.bias is not None:
|
| 266 |
+
nn.init.zeros_(module.bias)
|
| 267 |
+
|
| 268 |
+
def forward(self, input_tensor, mask=None):
|
| 269 |
+
tokens = self.input_transform(input_tensor)
|
| 270 |
+
|
| 271 |
+
init_tokens = tokens
|
| 272 |
+
|
| 273 |
+
B, _, T, _ = tokens.shape
|
| 274 |
+
|
| 275 |
+
if self.add_space_attn:
|
| 276 |
+
virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
|
| 277 |
+
tokens = torch.cat([tokens, virtual_tokens], dim=1)
|
| 278 |
+
|
| 279 |
+
_, N, _, _ = tokens.shape
|
| 280 |
+
|
| 281 |
+
j = 0
|
| 282 |
+
for i in range(len(self.time_blocks)):
|
| 283 |
+
time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
|
| 284 |
+
time_tokens = self.time_blocks[i](time_tokens)
|
| 285 |
+
|
| 286 |
+
tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
|
| 287 |
+
if self.add_space_attn and (
|
| 288 |
+
i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0
|
| 289 |
+
):
|
| 290 |
+
space_tokens = (
|
| 291 |
+
tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
|
| 292 |
+
) # B N T C -> (B T) N C
|
| 293 |
+
point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
|
| 294 |
+
virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
|
| 295 |
+
|
| 296 |
+
virtual_tokens = self.space_virtual2point_blocks[j](
|
| 297 |
+
virtual_tokens, point_tokens, mask=mask
|
| 298 |
+
)
|
| 299 |
+
virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
|
| 300 |
+
point_tokens = self.space_point2virtual_blocks[j](
|
| 301 |
+
point_tokens, virtual_tokens, mask=mask
|
| 302 |
+
)
|
| 303 |
+
space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
|
| 304 |
+
tokens = space_tokens.view(B, T, N, -1).permute(
|
| 305 |
+
0, 2, 1, 3
|
| 306 |
+
) # (B T) N C -> B N T C
|
| 307 |
+
j += 1
|
| 308 |
+
|
| 309 |
+
if self.add_space_attn:
|
| 310 |
+
tokens = tokens[:, : N - self.num_virtual_tracks]
|
| 311 |
+
|
| 312 |
+
tokens = tokens + init_tokens
|
| 313 |
+
|
| 314 |
+
flow = self.flow_head(tokens)
|
| 315 |
+
return flow
|
| 316 |
+
|
| 317 |
+
|
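For orientation, a shape-only sketch of how EfficientUpdateFormer above alternates time attention (per track) and space attention (per frame), and how the learnable virtual-track tokens are appended along the track dimension and sliced off again; all sizes below are illustrative:

```python
import torch

B, N, T, C = 2, 5, 8, 16                    # batch, tracks, frames, hidden size (toy values)
tokens = torch.randn(B, N, T, C)

# Time attention operates per track: B x N x T x C -> (B*N) x T x C
time_tokens = tokens.contiguous().view(B * N, T, C)

# Space attention operates per frame: B x N x T x C -> (B*T) x N x C
space_tokens = tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, C)

# With add_space_attn=True, K virtual-track tokens are appended along N ...
K = 3
virtual = torch.randn(1, K, 1, C).repeat(B, 1, T, 1)
tokens_aug = torch.cat([tokens, virtual], dim=1)            # B x (N+K) x T x C

# ... and removed again before the flow head
point_tokens = tokens_aug[:, : tokens_aug.shape[1] - K]
assert point_tokens.shape == tokens.shape
```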
| 318 |
+
class CorrBlock:
|
| 319 |
+
def __init__(
|
| 320 |
+
self,
|
| 321 |
+
fmaps,
|
| 322 |
+
num_levels=4,
|
| 323 |
+
radius=4,
|
| 324 |
+
multiple_track_feats=False,
|
| 325 |
+
padding_mode="zeros",
|
| 326 |
+
):
|
| 327 |
+
B, S, C, H, W = fmaps.shape
|
| 328 |
+
self.S, self.C, self.H, self.W = S, C, H, W
|
| 329 |
+
self.padding_mode = padding_mode
|
| 330 |
+
self.num_levels = num_levels
|
| 331 |
+
self.radius = radius
|
| 332 |
+
self.fmaps_pyramid = []
|
| 333 |
+
self.multiple_track_feats = multiple_track_feats
|
| 334 |
+
|
| 335 |
+
self.fmaps_pyramid.append(fmaps)
|
| 336 |
+
for i in range(self.num_levels - 1):
|
| 337 |
+
fmaps_ = fmaps.reshape(B * S, C, H, W)
|
| 338 |
+
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
|
| 339 |
+
_, _, H, W = fmaps_.shape
|
| 340 |
+
fmaps = fmaps_.reshape(B, S, C, H, W)
|
| 341 |
+
self.fmaps_pyramid.append(fmaps)
|
| 342 |
+
|
| 343 |
+
def sample(self, coords):
|
| 344 |
+
r = self.radius
|
| 345 |
+
B, S, N, D = coords.shape
|
| 346 |
+
assert D == 2
|
| 347 |
+
|
| 348 |
+
H, W = self.H, self.W
|
| 349 |
+
out_pyramid = []
|
| 350 |
+
for i in range(self.num_levels):
|
| 351 |
+
corrs = self.corrs_pyramid[i] # B, S, N, H, W
|
| 352 |
+
*_, H, W = corrs.shape
|
| 353 |
+
|
| 354 |
+
dx = torch.linspace(-r, r, 2 * r + 1)
|
| 355 |
+
dy = torch.linspace(-r, r, 2 * r + 1)
|
| 356 |
+
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
|
| 357 |
+
coords.device
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
|
| 361 |
+
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
|
| 362 |
+
coords_lvl = centroid_lvl + delta_lvl
|
| 363 |
+
|
| 364 |
+
corrs = bilinear_sampler(
|
| 365 |
+
corrs.reshape(B * S * N, 1, H, W),
|
| 366 |
+
coords_lvl,
|
| 367 |
+
padding_mode=self.padding_mode,
|
| 368 |
+
)
|
| 369 |
+
corrs = corrs.view(B, S, N, -1)
|
| 370 |
+
|
| 371 |
+
out_pyramid.append(corrs)
|
| 372 |
+
|
| 373 |
+
out = torch.cat(out_pyramid, dim=-1).contiguous() # B, S, N, LRR*2
|
| 374 |
+
return out
|
| 375 |
+
|
| 376 |
+
def corr(self, targets):
|
| 377 |
+
B, S, N, C = targets.shape
|
| 378 |
+
if self.multiple_track_feats:
|
| 379 |
+
targets_split = targets.split(C // self.num_levels, dim=-1)
|
| 380 |
+
B, S, N, C = targets_split[0].shape
|
| 381 |
+
|
| 382 |
+
assert C == self.C
|
| 383 |
+
assert S == self.S
|
| 384 |
+
|
| 385 |
+
fmap1 = targets
|
| 386 |
+
|
| 387 |
+
self.corrs_pyramid = []
|
| 388 |
+
for i, fmaps in enumerate(self.fmaps_pyramid):
|
| 389 |
+
*_, H, W = fmaps.shape
|
| 390 |
+
fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W)
|
| 391 |
+
if self.multiple_track_feats:
|
| 392 |
+
fmap1 = targets_split[i]
|
| 393 |
+
corrs = torch.matmul(fmap1, fmap2s)
|
| 394 |
+
corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W
|
| 395 |
+
corrs = corrs / torch.sqrt(torch.tensor(C).float())
|
| 396 |
+
self.corrs_pyramid.append(corrs)
|
sailrecon/dependency/track_modules/modules.py
ADDED
|
@@ -0,0 +1,216 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
import collections
|
| 9 |
+
from functools import partial
|
| 10 |
+
from itertools import repeat
|
| 11 |
+
from typing import Callable
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
from torch import Tensor
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# From PyTorch internals
|
| 20 |
+
def _ntuple(n):
|
| 21 |
+
def parse(x):
|
| 22 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
| 23 |
+
return tuple(x)
|
| 24 |
+
return tuple(repeat(x, n))
|
| 25 |
+
|
| 26 |
+
return parse
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def exists(val):
|
| 30 |
+
return val is not None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def default(val, d):
|
| 34 |
+
return val if exists(val) else d
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
to_2tuple = _ntuple(2)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class ResidualBlock(nn.Module):
|
| 41 |
+
"""
|
| 42 |
+
ResidualBlock: construct a block of two conv layers with residual connections
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
|
| 46 |
+
super(ResidualBlock, self).__init__()
|
| 47 |
+
|
| 48 |
+
self.conv1 = nn.Conv2d(
|
| 49 |
+
in_planes,
|
| 50 |
+
planes,
|
| 51 |
+
kernel_size=kernel_size,
|
| 52 |
+
padding=1,
|
| 53 |
+
stride=stride,
|
| 54 |
+
padding_mode="zeros",
|
| 55 |
+
)
|
| 56 |
+
self.conv2 = nn.Conv2d(
|
| 57 |
+
planes, planes, kernel_size=kernel_size, padding=1, padding_mode="zeros"
|
| 58 |
+
)
|
| 59 |
+
self.relu = nn.ReLU(inplace=True)
|
| 60 |
+
|
| 61 |
+
num_groups = planes // 8
|
| 62 |
+
|
| 63 |
+
if norm_fn == "group":
|
| 64 |
+
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
| 65 |
+
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
| 66 |
+
if not stride == 1:
|
| 67 |
+
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
| 68 |
+
|
| 69 |
+
elif norm_fn == "batch":
|
| 70 |
+
self.norm1 = nn.BatchNorm2d(planes)
|
| 71 |
+
self.norm2 = nn.BatchNorm2d(planes)
|
| 72 |
+
if not stride == 1:
|
| 73 |
+
self.norm3 = nn.BatchNorm2d(planes)
|
| 74 |
+
|
| 75 |
+
elif norm_fn == "instance":
|
| 76 |
+
self.norm1 = nn.InstanceNorm2d(planes)
|
| 77 |
+
self.norm2 = nn.InstanceNorm2d(planes)
|
| 78 |
+
if not stride == 1:
|
| 79 |
+
self.norm3 = nn.InstanceNorm2d(planes)
|
| 80 |
+
|
| 81 |
+
elif norm_fn == "none":
|
| 82 |
+
self.norm1 = nn.Sequential()
|
| 83 |
+
self.norm2 = nn.Sequential()
|
| 84 |
+
if not stride == 1:
|
| 85 |
+
self.norm3 = nn.Sequential()
|
| 86 |
+
else:
|
| 87 |
+
raise NotImplementedError
|
| 88 |
+
|
| 89 |
+
if stride == 1:
|
| 90 |
+
self.downsample = None
|
| 91 |
+
else:
|
| 92 |
+
self.downsample = nn.Sequential(
|
| 93 |
+
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
def forward(self, x):
|
| 97 |
+
y = x
|
| 98 |
+
y = self.relu(self.norm1(self.conv1(y)))
|
| 99 |
+
y = self.relu(self.norm2(self.conv2(y)))
|
| 100 |
+
|
| 101 |
+
if self.downsample is not None:
|
| 102 |
+
x = self.downsample(x)
|
| 103 |
+
|
| 104 |
+
return self.relu(x + y)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class Mlp(nn.Module):
|
| 108 |
+
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
| 109 |
+
|
| 110 |
+
def __init__(
|
| 111 |
+
self,
|
| 112 |
+
in_features,
|
| 113 |
+
hidden_features=None,
|
| 114 |
+
out_features=None,
|
| 115 |
+
act_layer=nn.GELU,
|
| 116 |
+
norm_layer=None,
|
| 117 |
+
bias=True,
|
| 118 |
+
drop=0.0,
|
| 119 |
+
use_conv=False,
|
| 120 |
+
):
|
| 121 |
+
super().__init__()
|
| 122 |
+
out_features = out_features or in_features
|
| 123 |
+
hidden_features = hidden_features or in_features
|
| 124 |
+
bias = to_2tuple(bias)
|
| 125 |
+
drop_probs = to_2tuple(drop)
|
| 126 |
+
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
| 127 |
+
|
| 128 |
+
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
|
| 129 |
+
self.act = act_layer()
|
| 130 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
| 131 |
+
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
|
| 132 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
| 133 |
+
|
| 134 |
+
def forward(self, x):
|
| 135 |
+
x = self.fc1(x)
|
| 136 |
+
x = self.act(x)
|
| 137 |
+
x = self.drop1(x)
|
| 138 |
+
x = self.fc2(x)
|
| 139 |
+
x = self.drop2(x)
|
| 140 |
+
return x
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class AttnBlock(nn.Module):
|
| 144 |
+
def __init__(
|
| 145 |
+
self,
|
| 146 |
+
hidden_size,
|
| 147 |
+
num_heads,
|
| 148 |
+
attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
|
| 149 |
+
mlp_ratio=4.0,
|
| 150 |
+
**block_kwargs,
|
| 151 |
+
):
|
| 152 |
+
"""
|
| 153 |
+
Self attention block
|
| 154 |
+
"""
|
| 155 |
+
super().__init__()
|
| 156 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 157 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 158 |
+
|
| 159 |
+
self.attn = attn_class(
|
| 160 |
+
embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 164 |
+
|
| 165 |
+
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
|
| 166 |
+
|
| 167 |
+
def forward(self, x, mask=None):
|
| 168 |
+
# Prepare the mask for PyTorch's attention (it expects a different format)
|
| 169 |
+
# attn_mask = mask if mask is not None else None
|
| 170 |
+
# Normalize before attention
|
| 171 |
+
x = self.norm1(x)
|
| 172 |
+
|
| 173 |
+
# PyTorch's MultiheadAttention returns attn_output, attn_output_weights
|
| 174 |
+
# attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)
|
| 175 |
+
|
| 176 |
+
attn_output, _ = self.attn(x, x, x)
|
| 177 |
+
|
| 178 |
+
# Add & Norm
|
| 179 |
+
x = x + attn_output
|
| 180 |
+
x = x + self.mlp(self.norm2(x))
|
| 181 |
+
return x
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class CrossAttnBlock(nn.Module):
|
| 185 |
+
def __init__(
|
| 186 |
+
self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs
|
| 187 |
+
):
|
| 188 |
+
"""
|
| 189 |
+
Cross attention block
|
| 190 |
+
"""
|
| 191 |
+
super().__init__()
|
| 192 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 193 |
+
self.norm_context = nn.LayerNorm(hidden_size)
|
| 194 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 195 |
+
|
| 196 |
+
self.cross_attn = nn.MultiheadAttention(
|
| 197 |
+
embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 201 |
+
|
| 202 |
+
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
|
| 203 |
+
|
| 204 |
+
def forward(self, x, context, mask=None):
|
| 205 |
+
# Normalize inputs
|
| 206 |
+
x = self.norm1(x)
|
| 207 |
+
context = self.norm_context(context)
|
| 208 |
+
|
| 209 |
+
# Apply cross attention
|
| 210 |
+
# Note: nn.MultiheadAttention returns attn_output, attn_output_weights
|
| 211 |
+
attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)
|
| 212 |
+
|
| 213 |
+
# Add & Norm
|
| 214 |
+
x = x + attn_output
|
| 215 |
+
x = x + self.mlp(self.norm2(x))
|
| 216 |
+
return x
|
sailrecon/dependency/track_modules/track_refine.py
ADDED
|
@@ -0,0 +1,493 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
from functools import partial
|
| 10 |
+
from typing import Tuple, Union
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
from einops import rearrange, repeat
|
| 17 |
+
from einops.layers.torch import Rearrange, Reduce
|
| 18 |
+
from PIL import Image
|
| 19 |
+
from torch import einsum, nn
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def refine_track(
|
| 23 |
+
images,
|
| 24 |
+
fine_fnet,
|
| 25 |
+
fine_tracker,
|
| 26 |
+
coarse_pred,
|
| 27 |
+
compute_score=False,
|
| 28 |
+
pradius=15,
|
| 29 |
+
sradius=2,
|
| 30 |
+
fine_iters=6,
|
| 31 |
+
chunk=40960,
|
| 32 |
+
):
|
| 33 |
+
"""
|
| 34 |
+
Refines the tracking of images using a fine track predictor and a fine feature network.
|
| 35 |
+
Check https://arxiv.org/abs/2312.04563 for more details.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
images (torch.Tensor): The images to be tracked.
|
| 39 |
+
fine_fnet (nn.Module): The fine feature network.
|
| 40 |
+
fine_tracker (nn.Module): The fine track predictor.
|
| 41 |
+
coarse_pred (torch.Tensor): The coarse predictions of tracks.
|
| 42 |
+
compute_score (bool, optional): Whether to compute the score. Defaults to False.
|
| 43 |
+
pradius (int, optional): The radius of a patch. Defaults to 15.
|
| 44 |
+
sradius (int, optional): The search radius. Defaults to 2.
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
torch.Tensor: The refined tracks.
|
| 48 |
+
torch.Tensor, optional: The score.
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
# coarse_pred shape: BxSxNx2,
|
| 52 |
+
# where B is the batch, S is the video/images length, and N is the number of tracks
|
| 53 |
+
# now we are going to extract patches with the center at coarse_pred
|
| 54 |
+
# Please note that the last dimension indicates x and y, and hence has a dim number of 2
|
| 55 |
+
B, S, N, _ = coarse_pred.shape
|
| 56 |
+
_, _, _, H, W = images.shape
|
| 57 |
+
|
| 58 |
+
# Given the radius of a patch, compute the patch size
|
| 59 |
+
psize = pradius * 2 + 1
|
| 60 |
+
|
| 61 |
+
# Note that we assume the first frame is the query frame
|
| 62 |
+
# so the 2D locations of the first frame are the query points
|
| 63 |
+
query_points = coarse_pred[:, 0]
|
| 64 |
+
|
| 65 |
+
# Given 2D positions, we can use grid_sample to extract patches
|
| 66 |
+
# but it takes too much memory.
|
| 67 |
+
# Instead, we use the floored track xy to sample patches.
|
| 68 |
+
|
| 69 |
+
# For example, if the query point xy is (128.16, 252.78),
|
| 70 |
+
# and the patch size is (31, 31),
|
| 71 |
+
# our goal is to extract the content of a rectangle
|
| 72 |
+
# with left top: (113.16, 237.78)
|
| 73 |
+
# and right bottom: (143.16, 267.78).
|
| 74 |
+
# However, we record the floored left top: (113, 237)
|
| 75 |
+
# and the offset (0.16, 0.78)
|
| 76 |
+
# Then what we need is just unfolding the images like in CNN,
|
| 77 |
+
# picking the content at [(113, 237), (143, 267)].
|
| 78 |
+
# Such operations are highly optimized in PyTorch
|
| 79 |
+
# (well if you really want to use interpolation, check the function extract_glimpse() below)
|
| 80 |
+
|
| 81 |
+
with torch.no_grad():
|
| 82 |
+
content_to_extract = images.reshape(B * S, 3, H, W)
|
| 83 |
+
C_in = content_to_extract.shape[1]
|
| 84 |
+
|
| 85 |
+
# Please refer to https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html
|
| 86 |
+
# for the detailed explanation of unfold()
|
| 87 |
+
# Here it runs sliding windows (psize x psize) to build patches
|
| 88 |
+
# The shape changes from
|
| 89 |
+
# (B*S)x C_in x H x W to (B*S)x C_in x H_new x W_new x Psize x Psize
|
| 90 |
+
# where Psize is the size of patch
|
| 91 |
+
content_to_extract = content_to_extract.unfold(2, psize, 1).unfold(3, psize, 1)
|
| 92 |
+
|
| 93 |
+
# Floor the coarse predictions to get integers and save the fractional/decimal
|
| 94 |
+
track_int = coarse_pred.floor().int()
|
| 95 |
+
track_frac = coarse_pred - track_int
|
| 96 |
+
|
| 97 |
+
# Note the points represent the center of patches
|
| 98 |
+
# now we get the location of the top left corner of patches
|
| 99 |
+
# because the output of PyTorch unfold is indexed by the top-left corner
|
| 100 |
+
topleft = track_int - pradius
|
| 101 |
+
topleft_BSN = topleft.clone()
|
| 102 |
+
|
| 103 |
+
# clamp the values so that we will not go out of indexes
|
| 104 |
+
# NOTE: (VERY IMPORTANT: This operation ASSUMES H=W).
|
| 105 |
+
# You need to separately clamp x and y if H != W
|
| 106 |
+
topleft = topleft.clamp(0, H - psize)
|
| 107 |
+
|
| 108 |
+
# Reshape from BxSxNx2 -> (B*S)xNx2
|
| 109 |
+
topleft = topleft.reshape(B * S, N, 2)
|
| 110 |
+
|
| 111 |
+
# Prepare batches for indexing, shape: (B*S)xN
|
| 112 |
+
batch_indices = (
|
| 113 |
+
torch.arange(B * S)[:, None].expand(-1, N).to(content_to_extract.device)
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
# extracted_patches: (B*S) x N x C_in x Psize x Psize
|
| 117 |
+
extracted_patches = content_to_extract[
|
| 118 |
+
batch_indices, :, topleft[..., 1], topleft[..., 0]
|
| 119 |
+
]
|
| 120 |
+
|
| 121 |
+
if chunk < 0:
|
| 122 |
+
# Extract image patches based on top left corners
|
| 123 |
+
# Feed patches to the fine fnet for features
|
| 124 |
+
patch_feat = fine_fnet(extracted_patches.reshape(B * S * N, C_in, psize, psize))
|
| 125 |
+
else:
|
| 126 |
+
patches = extracted_patches.reshape(B * S * N, C_in, psize, psize)
|
| 127 |
+
|
| 128 |
+
patch_feat_list = []
|
| 129 |
+
for p in torch.split(patches, chunk):
|
| 130 |
+
patch_feat_list += [fine_fnet(p)]
|
| 131 |
+
patch_feat = torch.cat(patch_feat_list, 0)
|
| 132 |
+
|
| 133 |
+
C_out = patch_feat.shape[1]
|
| 134 |
+
|
| 135 |
+
# Refine the coarse tracks by fine_tracker
|
| 136 |
+
# reshape back to B x S x N x C_out x Psize x Psize
|
| 137 |
+
patch_feat = patch_feat.reshape(B, S, N, C_out, psize, psize)
|
| 138 |
+
patch_feat = rearrange(patch_feat, "b s n c p q -> (b n) s c p q")
|
| 139 |
+
|
| 140 |
+
# Prepare for the query points for fine tracker
|
| 141 |
+
# They are relative to the patch left top corner,
|
| 142 |
+
# instead of the image top left corner now
|
| 143 |
+
# patch_query_points: N x 1 x 2
|
| 144 |
+
# only 1 here because for each patch we only have 1 query point
|
| 145 |
+
patch_query_points = track_frac[:, 0] + pradius
|
| 146 |
+
patch_query_points = patch_query_points.reshape(B * N, 2).unsqueeze(1)
|
| 147 |
+
|
| 148 |
+
# Feed the PATCH query points and tracks into fine tracker
|
| 149 |
+
fine_pred_track_lists, _, _, query_point_feat = fine_tracker(
|
| 150 |
+
query_points=patch_query_points,
|
| 151 |
+
fmaps=patch_feat,
|
| 152 |
+
iters=fine_iters,
|
| 153 |
+
return_feat=True,
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
# relative to the patch top left
|
| 157 |
+
fine_pred_track = fine_pred_track_lists[-1].clone()
|
| 158 |
+
|
| 159 |
+
# From (relative to the patch top left) to (relative to the image top left)
|
| 160 |
+
for idx in range(len(fine_pred_track_lists)):
|
| 161 |
+
fine_level = rearrange(
|
| 162 |
+
fine_pred_track_lists[idx], "(b n) s u v -> b s n u v", b=B, n=N
|
| 163 |
+
)
|
| 164 |
+
fine_level = fine_level.squeeze(-2)
|
| 165 |
+
fine_level = fine_level + topleft_BSN
|
| 166 |
+
fine_pred_track_lists[idx] = fine_level
|
| 167 |
+
|
| 168 |
+
# relative to the image top left
|
| 169 |
+
refined_tracks = fine_pred_track_lists[-1].clone()
|
| 170 |
+
refined_tracks[:, 0] = query_points
|
| 171 |
+
|
| 172 |
+
score = None
|
| 173 |
+
|
| 174 |
+
if compute_score:
|
| 175 |
+
score = compute_score_fn(
|
| 176 |
+
query_point_feat,
|
| 177 |
+
patch_feat,
|
| 178 |
+
fine_pred_track,
|
| 179 |
+
sradius,
|
| 180 |
+
psize,
|
| 181 |
+
B,
|
| 182 |
+
N,
|
| 183 |
+
S,
|
| 184 |
+
C_out,
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
return refined_tracks, score
|
| 188 |
+
|
| 189 |
+
|
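A standalone sketch of the unfold-based patch extraction that the comments in `refine_track()` above describe: floor the track position, record the fractional offset, unfold the image into all (psize x psize) windows, and index the window by its top-left corner. The image size, radius, and track position below are made up, and the clamp assumes H == W, exactly as the comments warn:

```python
import torch

H = W = 64
pradius = 3
psize = 2 * pradius + 1
image = torch.randn(1, 3, H, W)                       # stands in for (B*S, C, H, W)

track_xy = torch.tensor([[20.16, 31.78]])             # one track, (x, y) in pixels
track_int = track_xy.floor().int()
track_frac = track_xy - track_int                     # kept as the patch-space query offset
topleft = (track_int - pradius).clamp(0, H - psize)   # ASSUMES H == W

# All sliding windows of size psize x psize, indexed by their top-left corner
windows = image.unfold(2, psize, 1).unfold(3, psize, 1)   # 1 x 3 x H' x W' x psize x psize
patch = windows[0, :, topleft[0, 1], topleft[0, 0]]       # index by (y, x)
print(patch.shape, track_frac)                            # torch.Size([3, 7, 7]), offset ~ (0.16, 0.78)
```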
| 190 |
+
def refine_track_v0(
|
| 191 |
+
images,
|
| 192 |
+
fine_fnet,
|
| 193 |
+
fine_tracker,
|
| 194 |
+
coarse_pred,
|
| 195 |
+
compute_score=False,
|
| 196 |
+
pradius=15,
|
| 197 |
+
sradius=2,
|
| 198 |
+
fine_iters=6,
|
| 199 |
+
):
|
| 200 |
+
"""
|
| 201 |
+
COPIED FROM VGGSfM
|
| 202 |
+
|
| 203 |
+
Refines the tracking of images using a fine track predictor and a fine feature network.
|
| 204 |
+
Check https://arxiv.org/abs/2312.04563 for more details.
|
| 205 |
+
|
| 206 |
+
Args:
|
| 207 |
+
images (torch.Tensor): The images to be tracked.
|
| 208 |
+
fine_fnet (nn.Module): The fine feature network.
|
| 209 |
+
fine_tracker (nn.Module): The fine track predictor.
|
| 210 |
+
coarse_pred (torch.Tensor): The coarse predictions of tracks.
|
| 211 |
+
compute_score (bool, optional): Whether to compute the score. Defaults to False.
|
| 212 |
+
pradius (int, optional): The radius of a patch. Defaults to 15.
|
| 213 |
+
sradius (int, optional): The search radius. Defaults to 2.
|
| 214 |
+
|
| 215 |
+
Returns:
|
| 216 |
+
torch.Tensor: The refined tracks.
|
| 217 |
+
torch.Tensor, optional: The score.
|
| 218 |
+
"""
|
| 219 |
+
|
| 220 |
+
# coarse_pred shape: BxSxNx2,
|
| 221 |
+
# where B is the batch, S is the video/images length, and N is the number of tracks
|
| 222 |
+
# now we are going to extract patches with the center at coarse_pred
|
| 223 |
+
# Please note that the last dimension indicates x and y, and hence has a dim number of 2
|
| 224 |
+
B, S, N, _ = coarse_pred.shape
|
| 225 |
+
_, _, _, H, W = images.shape
|
| 226 |
+
|
| 227 |
+
# Given the radius of a patch, compute the patch size
|
| 228 |
+
psize = pradius * 2 + 1
|
| 229 |
+
|
| 230 |
+
# Note that we assume the first frame is the query frame
|
| 231 |
+
# so the 2D locations of the first frame are the query points
|
| 232 |
+
query_points = coarse_pred[:, 0]
|
| 233 |
+
|
| 234 |
+
# Given 2D positions, we can use grid_sample to extract patches
|
| 235 |
+
# but it takes too much memory.
|
| 236 |
+
# Instead, we use the floored track xy to sample patches.
|
| 237 |
+
|
| 238 |
+
# For example, if the query point xy is (128.16, 252.78),
|
| 239 |
+
# and the patch size is (31, 31),
|
| 240 |
+
# our goal is to extract the content of a rectangle
|
| 241 |
+
# with left top: (113.16, 237.78)
|
| 242 |
+
# and right bottom: (143.16, 267.78).
|
| 243 |
+
# However, we record the floored left top: (113, 237)
|
| 244 |
+
# and the offset (0.16, 0.78)
|
| 245 |
+
# Then what we need is just unfolding the images like in CNN,
|
| 246 |
+
# picking the content at [(113, 237), (143, 267)].
|
| 247 |
+
# Such operations are highly optimized in PyTorch
|
| 248 |
+
# (well if you really want to use interpolation, check the function extract_glimpse() below)
|
| 249 |
+
|
| 250 |
+
with torch.no_grad():
|
| 251 |
+
content_to_extract = images.reshape(B * S, 3, H, W)
|
| 252 |
+
C_in = content_to_extract.shape[1]
|
| 253 |
+
|
| 254 |
+
# Please refer to https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html
|
| 255 |
+
# for the detailed explanation of unfold()
|
| 256 |
+
# Here it runs sliding windows (psize x psize) to build patches
|
| 257 |
+
# The shape changes from
|
| 258 |
+
# (B*S)x C_in x H x W to (B*S)x C_in x H_new x W_new x Psize x Psize
|
| 259 |
+
# where Psize is the size of patch
|
| 260 |
+
content_to_extract = content_to_extract.unfold(2, psize, 1).unfold(3, psize, 1)
|
| 261 |
+
|
| 262 |
+
# Floor the coarse predictions to get integers and save the fractional/decimal
|
| 263 |
+
track_int = coarse_pred.floor().int()
|
| 264 |
+
track_frac = coarse_pred - track_int
|
| 265 |
+
|
| 266 |
+
# Note the points represent the center of patches
|
| 267 |
+
# now we get the location of the top left corner of patches
|
| 268 |
+
# because the ouput of pytorch unfold are indexed by top left corner
|
| 269 |
+
topleft = track_int - pradius
|
| 270 |
+
topleft_BSN = topleft.clone()
|
| 271 |
+
|
| 272 |
+
# clamp the values so that we will not go out of indexes
|
| 273 |
+
# NOTE: (VERY IMPORTANT: This operation ASSUMES H=W).
|
| 274 |
+
# You need to separately clamp x and y if H != W
|
| 275 |
+
topleft = topleft.clamp(0, H - psize)
|
| 276 |
+
|
| 277 |
+
# Reshape from BxSxNx2 -> (B*S)xNx2
|
| 278 |
+
topleft = topleft.reshape(B * S, N, 2)
|
| 279 |
+
|
| 280 |
+
# Prepare batches for indexing, shape: (B*S)xN
|
| 281 |
+
batch_indices = (
|
| 282 |
+
torch.arange(B * S)[:, None].expand(-1, N).to(content_to_extract.device)
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
# Extract image patches based on top left corners
|
| 286 |
+
# extracted_patches: (B*S) x N x C_in x Psize x Psize
|
| 287 |
+
extracted_patches = content_to_extract[
|
| 288 |
+
batch_indices, :, topleft[..., 1], topleft[..., 0]
|
| 289 |
+
]
|
| 290 |
+
|
| 291 |
+
# Feed patches to the fine fnet for features
|
| 292 |
+
patch_feat = fine_fnet(extracted_patches.reshape(B * S * N, C_in, psize, psize))
|
| 293 |
+
|
| 294 |
+
C_out = patch_feat.shape[1]
|
| 295 |
+
|
| 296 |
+
# Refine the coarse tracks by fine_tracker
|
| 297 |
+
|
| 298 |
+
# reshape back to B x S x N x C_out x Psize x Psize
|
| 299 |
+
patch_feat = patch_feat.reshape(B, S, N, C_out, psize, psize)
|
| 300 |
+
patch_feat = rearrange(patch_feat, "b s n c p q -> (b n) s c p q")
|
| 301 |
+
|
| 302 |
+
# Prepare for the query points for fine tracker
|
| 303 |
+
# They are relative to the patch left top corner,
|
| 304 |
+
# instead of the image top left corner now
|
| 305 |
+
# patch_query_points: N x 1 x 2
|
| 306 |
+
# only 1 here because for each patch we only have 1 query point
|
| 307 |
+
patch_query_points = track_frac[:, 0] + pradius
|
| 308 |
+
patch_query_points = patch_query_points.reshape(B * N, 2).unsqueeze(1)
|
| 309 |
+
|
| 310 |
+
# Feed the PATCH query points and tracks into fine tracker
|
| 311 |
+
fine_pred_track_lists, _, _, query_point_feat = fine_tracker(
|
| 312 |
+
query_points=patch_query_points,
|
| 313 |
+
fmaps=patch_feat,
|
| 314 |
+
iters=fine_iters,
|
| 315 |
+
return_feat=True,
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
# relative to the patch top left
|
| 319 |
+
fine_pred_track = fine_pred_track_lists[-1].clone()
|
| 320 |
+
|
| 321 |
+
# From (relative to the patch top left) to (relative to the image top left)
|
| 322 |
+
for idx in range(len(fine_pred_track_lists)):
|
| 323 |
+
fine_level = rearrange(
|
| 324 |
+
fine_pred_track_lists[idx], "(b n) s u v -> b s n u v", b=B, n=N
|
| 325 |
+
)
|
| 326 |
+
fine_level = fine_level.squeeze(-2)
|
| 327 |
+
fine_level = fine_level + topleft_BSN
|
| 328 |
+
fine_pred_track_lists[idx] = fine_level
|
| 329 |
+
|
| 330 |
+
# relative to the image top left
|
| 331 |
+
refined_tracks = fine_pred_track_lists[-1].clone()
|
| 332 |
+
refined_tracks[:, 0] = query_points
|
| 333 |
+
|
| 334 |
+
score = None
|
| 335 |
+
|
| 336 |
+
if compute_score:
|
| 337 |
+
score = compute_score_fn(
|
| 338 |
+
query_point_feat,
|
| 339 |
+
patch_feat,
|
| 340 |
+
fine_pred_track,
|
| 341 |
+
sradius,
|
| 342 |
+
psize,
|
| 343 |
+
B,
|
| 344 |
+
N,
|
| 345 |
+
S,
|
| 346 |
+
C_out,
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
return refined_tracks, score
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
################################## NOTE: NOT USED ##################################
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def compute_score_fn(
|
| 356 |
+
query_point_feat, patch_feat, fine_pred_track, sradius, psize, B, N, S, C_out
|
| 357 |
+
):
|
| 358 |
+
"""
|
| 359 |
+
Compute the scores, i.e., the standard deviation of the 2D similarity heatmaps,
|
| 360 |
+
given the query point features and reference frame feature maps
|
| 361 |
+
"""
|
| 362 |
+
|
| 363 |
+
from kornia.geometry.subpix import dsnt
|
| 364 |
+
from kornia.utils.grid import create_meshgrid
|
| 365 |
+
|
| 366 |
+
# query_point_feat initial shape: B x N x C_out,
|
| 367 |
+
# query_point_feat indicates the features at the corresponding query points
|
| 368 |
+
# Therefore we don't have S dimension here
|
| 369 |
+
query_point_feat = query_point_feat.reshape(B, N, C_out)
|
| 370 |
+
# reshape and expand to B x (S-1) x N x C_out
|
| 371 |
+
query_point_feat = query_point_feat.unsqueeze(1).expand(-1, S - 1, -1, -1)
|
| 372 |
+
# and reshape to (B*(S-1)*N) x C_out
|
| 373 |
+
query_point_feat = query_point_feat.reshape(B * (S - 1) * N, C_out)
|
| 374 |
+
|
| 375 |
+
# Radius and size for computing the score
|
| 376 |
+
ssize = sradius * 2 + 1
|
| 377 |
+
|
| 378 |
+
# Reshape, you know it, so many reshaping operations
|
| 379 |
+
patch_feat = rearrange(patch_feat, "(b n) s c p q -> b s n c p q", b=B, n=N)
|
| 380 |
+
|
| 381 |
+
# Again, we unfold the patches to smaller patches
|
| 382 |
+
# so that we can then focus on smaller patches
|
| 383 |
+
# patch_feat_unfold shape:
|
| 384 |
+
# B x S x N x C_out x (psize - 2*sradius) x (psize - 2*sradius) x ssize x ssize
|
| 385 |
+
# well a bit scary, but actually not
|
| 386 |
+
patch_feat_unfold = patch_feat.unfold(4, ssize, 1).unfold(5, ssize, 1)
|
| 387 |
+
|
| 388 |
+
# Do the same stuffs above, i.e., the same as extracting patches
|
| 389 |
+
fine_prediction_floor = fine_pred_track.floor().int()
|
| 390 |
+
fine_level_floor_topleft = fine_prediction_floor - sradius
|
| 391 |
+
|
| 392 |
+
# Clamp to ensure the smaller patch is valid
|
| 393 |
+
fine_level_floor_topleft = fine_level_floor_topleft.clamp(0, psize - ssize)
|
| 394 |
+
fine_level_floor_topleft = fine_level_floor_topleft.squeeze(2)
|
| 395 |
+
|
| 396 |
+
# Prepare the batch indices and xy locations
|
| 397 |
+
|
| 398 |
+
batch_indices_score = torch.arange(B)[:, None, None].expand(-1, S, N) # BxSxN
|
| 399 |
+
batch_indices_score = batch_indices_score.reshape(-1).to(
|
| 400 |
+
patch_feat_unfold.device
|
| 401 |
+
) # B*S*N
|
| 402 |
+
y_indices = fine_level_floor_topleft[..., 0].flatten() # Flatten H indices
|
| 403 |
+
x_indices = fine_level_floor_topleft[..., 1].flatten() # Flatten W indices
|
| 404 |
+
|
| 405 |
+
reference_frame_feat = patch_feat_unfold.reshape(
|
| 406 |
+
B * S * N, C_out, psize - sradius * 2, psize - sradius * 2, ssize, ssize
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
# Note again, according to pytorch convention
|
| 410 |
+
# x_indices corresponds to [..., 1] and y_indices corresponds to [..., 0]
|
| 411 |
+
reference_frame_feat = reference_frame_feat[
|
| 412 |
+
batch_indices_score, :, x_indices, y_indices
|
| 413 |
+
]
|
| 414 |
+
reference_frame_feat = reference_frame_feat.reshape(B, S, N, C_out, ssize, ssize)
|
| 415 |
+
# pick the frames other than the first one, so we have S-1 frames here
|
| 416 |
+
reference_frame_feat = reference_frame_feat[:, 1:].reshape(
|
| 417 |
+
B * (S - 1) * N, C_out, ssize * ssize
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
# Compute similarity
|
| 421 |
+
sim_matrix = torch.einsum("mc,mcr->mr", query_point_feat, reference_frame_feat)
|
| 422 |
+
softmax_temp = 1.0 / C_out**0.5
|
| 423 |
+
heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1)
|
| 424 |
+
# 2D heatmaps
|
| 425 |
+
heatmap = heatmap.reshape(B * (S - 1) * N, ssize, ssize) # * x ssize x ssize
|
| 426 |
+
|
| 427 |
+
coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0]
|
| 428 |
+
grid_normalized = create_meshgrid(
|
| 429 |
+
ssize, ssize, normalized_coordinates=True, device=heatmap.device
|
| 430 |
+
).reshape(1, -1, 2)
|
| 431 |
+
|
| 432 |
+
var = (
|
| 433 |
+
torch.sum(grid_normalized**2 * heatmap.view(-1, ssize * ssize, 1), dim=1)
|
| 434 |
+
- coords_normalized**2
|
| 435 |
+
)
|
| 436 |
+
std = torch.sum(
|
| 437 |
+
torch.sqrt(torch.clamp(var, min=1e-10)), -1
|
| 438 |
+
) # clamp needed for numerical stability
|
| 439 |
+
|
| 440 |
+
score = std.reshape(B, S - 1, N)
|
| 441 |
+
# set score as 1 for the query frame
|
| 442 |
+
score = torch.cat([torch.ones_like(score[:, 0:1]), score], dim=1)
|
| 443 |
+
|
| 444 |
+
return score
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
def extract_glimpse(
|
| 448 |
+
tensor: torch.Tensor,
|
| 449 |
+
size: Tuple[int, int],
|
| 450 |
+
offsets,
|
| 451 |
+
mode="bilinear",
|
| 452 |
+
padding_mode="zeros",
|
| 453 |
+
debug=False,
|
| 454 |
+
orib=None,
|
| 455 |
+
):
|
| 456 |
+
B, C, W, H = tensor.shape
|
| 457 |
+
|
| 458 |
+
h, w = size
|
| 459 |
+
xs = torch.arange(0, w, dtype=tensor.dtype, device=tensor.device) - (w - 1) / 2.0
|
| 460 |
+
ys = torch.arange(0, h, dtype=tensor.dtype, device=tensor.device) - (h - 1) / 2.0
|
| 461 |
+
|
| 462 |
+
vy, vx = torch.meshgrid(ys, xs)
|
| 463 |
+
grid = torch.stack([vx, vy], dim=-1) # h, w, 2
|
| 464 |
+
grid = grid[None]
|
| 465 |
+
|
| 466 |
+
B, N, _ = offsets.shape
|
| 467 |
+
|
| 468 |
+
offsets = offsets.reshape((B * N), 1, 1, 2)
|
| 469 |
+
offsets_grid = offsets + grid
|
| 470 |
+
|
| 471 |
+
# normalised grid to [-1, 1]
|
| 472 |
+
offsets_grid = (
|
| 473 |
+
offsets_grid - offsets_grid.new_tensor([W / 2, H / 2])
|
| 474 |
+
) / offsets_grid.new_tensor([W / 2, H / 2])
|
| 475 |
+
|
| 476 |
+
# BxCxHxW -> Bx1xCxHxW
|
| 477 |
+
tensor = tensor[:, None]
|
| 478 |
+
|
| 479 |
+
# Bx1xCxHxW -> BxNxCxHxW
|
| 480 |
+
tensor = tensor.expand(-1, N, -1, -1, -1)
|
| 481 |
+
|
| 482 |
+
# BxNxCxHxW -> (B*N)xCxHxW
|
| 483 |
+
tensor = tensor.reshape((B * N), C, W, H)
|
| 484 |
+
|
| 485 |
+
sampled = torch.nn.functional.grid_sample(
|
| 486 |
+
tensor, offsets_grid, mode=mode, align_corners=False, padding_mode=padding_mode
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
# NOTE: I am not sure it should be h, w or w, h here
|
| 490 |
+
# but okay for squares
|
| 491 |
+
sampled = sampled.reshape(B, N, C, h, w)
|
| 492 |
+
|
| 493 |
+
return sampled
|
sailrecon/dependency/track_modules/utils.py
ADDED
|
@@ -0,0 +1,235 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Modified from https://github.com/facebookresearch/PoseDiffusion
|
| 8 |
+
# and https://github.com/facebookresearch/co-tracker/tree/main
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from typing import Optional, Tuple, Union
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
from einops import rearrange, repeat
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def get_2d_sincos_pos_embed(
|
| 20 |
+
embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False
|
| 21 |
+
) -> torch.Tensor:
|
| 22 |
+
"""
|
| 23 |
+
This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
|
| 24 |
+
It is a wrapper of get_2d_sincos_pos_embed_from_grid.
|
| 25 |
+
Args:
|
| 26 |
+
- embed_dim: The embedding dimension.
|
| 27 |
+
- grid_size: The grid size.
|
| 28 |
+
Returns:
|
| 29 |
+
- pos_embed: The generated 2D positional embedding.
|
| 30 |
+
"""
|
| 31 |
+
if isinstance(grid_size, tuple):
|
| 32 |
+
grid_size_h, grid_size_w = grid_size
|
| 33 |
+
else:
|
| 34 |
+
grid_size_h = grid_size_w = grid_size
|
| 35 |
+
grid_h = torch.arange(grid_size_h, dtype=torch.float)
|
| 36 |
+
grid_w = torch.arange(grid_size_w, dtype=torch.float)
|
| 37 |
+
grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
|
| 38 |
+
grid = torch.stack(grid, dim=0)
|
| 39 |
+
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
|
| 40 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 41 |
+
if return_grid:
|
| 42 |
+
return (
|
| 43 |
+
pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2),
|
| 44 |
+
grid,
|
| 45 |
+
)
|
| 46 |
+
return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def get_2d_sincos_pos_embed_from_grid(
|
| 50 |
+
embed_dim: int, grid: torch.Tensor
|
| 51 |
+
) -> torch.Tensor:
|
| 52 |
+
"""
|
| 53 |
+
This function generates a 2D positional embedding from a given grid using sine and cosine functions.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
- embed_dim: The embedding dimension.
|
| 57 |
+
- grid: The grid to generate the embedding from.
|
| 58 |
+
|
| 59 |
+
Returns:
|
| 60 |
+
- emb: The generated 2D positional embedding.
|
| 61 |
+
"""
|
| 62 |
+
assert embed_dim % 2 == 0
|
| 63 |
+
|
| 64 |
+
# use half of dimensions to encode grid_h
|
| 65 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
| 66 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
| 67 |
+
|
| 68 |
+
emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
|
| 69 |
+
return emb
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def get_1d_sincos_pos_embed_from_grid(
|
| 73 |
+
embed_dim: int, pos: torch.Tensor
|
| 74 |
+
) -> torch.Tensor:
|
| 75 |
+
"""
|
| 76 |
+
This function generates a 1D positional embedding from a given grid using sine and cosine functions.
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
- embed_dim: The embedding dimension.
|
| 80 |
+
- pos: The position to generate the embedding from.
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
- emb: The generated 1D positional embedding.
|
| 84 |
+
"""
|
| 85 |
+
assert embed_dim % 2 == 0
|
| 86 |
+
omega = torch.arange(embed_dim // 2, dtype=torch.double)
|
| 87 |
+
omega /= embed_dim / 2.0
|
| 88 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
| 89 |
+
|
| 90 |
+
pos = pos.reshape(-1) # (M,)
|
| 91 |
+
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
| 92 |
+
|
| 93 |
+
emb_sin = torch.sin(out) # (M, D/2)
|
| 94 |
+
emb_cos = torch.cos(out) # (M, D/2)
|
| 95 |
+
|
| 96 |
+
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
| 97 |
+
return emb[None].float()
|
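As a quick sanity check of the two embedding builders above, here is a minimal sketch (assuming the functions defined in this file are in scope; the grid size and channel count are arbitrary choices, not values from the pipeline):

import torch

# 2D sine/cosine table for an 8x8 grid with 64 channels: half the channels encode
# the row coordinate and half the column coordinate, each as sin/cos pairs.
pos = get_2d_sincos_pos_embed(embed_dim=64, grid_size=8)
assert pos.shape == (1, 64, 8, 8)  # (1, C, H, W), ready to add to a feature map

# The 1D helper underneath maps M positions to a (1, M, D) sin/cos table.
emb = get_1d_sincos_pos_embed_from_grid(32, torch.arange(10, dtype=torch.float))
assert emb.shape == (1, 10, 32)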
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
|
| 101 |
+
"""
|
| 102 |
+
This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
|
| 103 |
+
|
| 104 |
+
Args:
|
| 105 |
+
- xy: The coordinates to generate the embedding from.
|
| 106 |
+
- C: The size of the embedding.
|
| 107 |
+
- cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
|
| 108 |
+
|
| 109 |
+
Returns:
|
| 110 |
+
- pe: The generated 2D positional embedding.
|
| 111 |
+
"""
|
| 112 |
+
B, N, D = xy.shape
|
| 113 |
+
assert D == 2
|
| 114 |
+
|
| 115 |
+
x = xy[:, :, 0:1]
|
| 116 |
+
y = xy[:, :, 1:2]
|
| 117 |
+
div_term = (
|
| 118 |
+
torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
|
| 119 |
+
).reshape(1, 1, int(C / 2))
|
| 120 |
+
|
| 121 |
+
pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
| 122 |
+
pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
| 123 |
+
|
| 124 |
+
pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
| 125 |
+
pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
| 126 |
+
|
| 127 |
+
pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
| 128 |
+
pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
| 129 |
+
|
| 130 |
+
pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*2)
|
| 131 |
+
if cat_coords:
|
| 132 |
+
pe = torch.cat([xy, pe], dim=2) # (B, N, C*2+2)
|
| 133 |
+
return pe
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
|
| 137 |
+
r"""Sample a tensor using bilinear interpolation
|
| 138 |
+
|
| 139 |
+
`bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
|
| 140 |
+
coordinates :attr:`coords` using bilinear interpolation. It is the same
|
| 141 |
+
as `torch.nn.functional.grid_sample()` but with a different coordinate
|
| 142 |
+
convention.
|
| 143 |
+
|
| 144 |
+
The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
|
| 145 |
+
:math:`B` is the batch size, :math:`C` is the number of channels,
|
| 146 |
+
:math:`H` is the height of the image, and :math:`W` is the width of the
|
| 147 |
+
image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
|
| 148 |
+
interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
|
| 149 |
+
|
| 150 |
+
Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
|
| 151 |
+
in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
|
| 152 |
+
that in this case the order of the components is slightly different
|
| 153 |
+
from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
|
| 154 |
+
|
| 155 |
+
If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
|
| 156 |
+
in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
|
| 157 |
+
left-most image pixel and :math:`W-1` to the center of the right-most
|
| 158 |
+
pixel.
|
| 159 |
+
|
| 160 |
+
If `align_corners` is `False`, the coordinate :math:`x` is assumed to
|
| 161 |
+
be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
|
| 162 |
+
the left-most pixel and :math:`W` to the right edge of the right-most
|
| 163 |
+
pixel.
|
| 164 |
+
|
| 165 |
+
Similar conventions apply to the :math:`y` for the range
|
| 166 |
+
:math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
|
| 167 |
+
:math:`[0,T-1]` and :math:`[0,T]`.
|
| 168 |
+
|
| 169 |
+
Args:
|
| 170 |
+
input (Tensor): batch of input images.
|
| 171 |
+
coords (Tensor): batch of coordinates.
|
| 172 |
+
align_corners (bool, optional): Coordinate convention. Defaults to `True`.
|
| 173 |
+
padding_mode (str, optional): Padding mode. Defaults to `"border"`.
|
| 174 |
+
|
| 175 |
+
Returns:
|
| 176 |
+
Tensor: sampled points.
|
| 177 |
+
"""
|
| 178 |
+
|
| 179 |
+
sizes = input.shape[2:]
|
| 180 |
+
|
| 181 |
+
assert len(sizes) in [2, 3]
|
| 182 |
+
|
| 183 |
+
if len(sizes) == 3:
|
| 184 |
+
# t x y -> x y t to match dimensions T H W in grid_sample
|
| 185 |
+
coords = coords[..., [1, 2, 0]]
|
| 186 |
+
|
| 187 |
+
if align_corners:
|
| 188 |
+
coords = coords * torch.tensor(
|
| 189 |
+
[2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
|
| 190 |
+
)
|
| 191 |
+
else:
|
| 192 |
+
coords = coords * torch.tensor(
|
| 193 |
+
[2 / size for size in reversed(sizes)], device=coords.device
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
coords -= 1
|
| 197 |
+
|
| 198 |
+
return F.grid_sample(
|
| 199 |
+
input, coords, align_corners=align_corners, padding_mode=padding_mode
|
| 200 |
+
)
|
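A small worked example of the `align_corners=True` convention described in the docstring (only functions from this file are used; the tensor sizes are arbitrary): sampling at the integer coordinate (x=W-1, y=H-1) returns exactly the bottom-right pixel value.

import torch

feat = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(1, 2, 3, 4)  # (B, C, H=3, W=4)
coords = torch.tensor([[[[3.0, 2.0]]]])  # (B, H_o=1, W_o=1, 2): one (x, y) sample point
sampled = bilinear_sampler(feat, coords)  # (1, 2, 1, 1)
assert torch.allclose(sampled[0, :, 0, 0], feat[0, :, 2, 3])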
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def sample_features4d(input, coords):
|
| 204 |
+
r"""Sample spatial features
|
| 205 |
+
|
| 206 |
+
`sample_features4d(input, coords)` samples the spatial features
|
| 207 |
+
:attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
|
| 208 |
+
|
| 209 |
+
The field is sampled at coordinates :attr:`coords` using bilinear
|
| 210 |
+
interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
|
| 211 |
+
2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
|
| 212 |
+
same convention as :func:`bilinear_sampler` with `align_corners=True`.
|
| 213 |
+
|
| 214 |
+
The output tensor has one feature per point, and has shape :math:`(B,
|
| 215 |
+
R, C)`.
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
input (Tensor): spatial features.
|
| 219 |
+
coords (Tensor): points.
|
| 220 |
+
|
| 221 |
+
Returns:
|
| 222 |
+
Tensor: sampled features.
|
| 223 |
+
"""
|
| 224 |
+
|
| 225 |
+
B, _, _, _ = input.shape
|
| 226 |
+
|
| 227 |
+
# B R 2 -> B R 1 2
|
| 228 |
+
coords = coords.unsqueeze(2)
|
| 229 |
+
|
| 230 |
+
# B C R 1
|
| 231 |
+
feats = bilinear_sampler(input, coords)
|
| 232 |
+
|
| 233 |
+
return feats.permute(0, 2, 1, 3).view(
|
| 234 |
+
B, -1, feats.shape[1] * feats.shape[3]
|
| 235 |
+
) # B C R 1 -> B R C
|
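A shape-level sketch of the two remaining helpers, using arbitrary sizes and random placeholder inputs (not taken from the surrounding pipeline):

import torch

xy = torch.rand(2, 100, 2) * 63           # (B, N, 2) query coordinates in pixels
pe = get_2d_embedding(xy, C=32)           # per-axis sin/cos features plus the raw coords
assert pe.shape == (2, 100, 2 * 32 + 2)

feat_map = torch.rand(2, 128, 64, 64)     # (B, C, H, W) spatial features
feats = sample_features4d(feat_map, xy)   # one interpolated feature vector per point
assert feats.shape == (2, 100, 128)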
sailrecon/dependency/track_predict.py
ADDED
|
@@ -0,0 +1,349 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from .vggsfm_utils import *
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def predict_tracks(
|
| 14 |
+
images,
|
| 15 |
+
conf=None,
|
| 16 |
+
points_3d=None,
|
| 17 |
+
masks=None,
|
| 18 |
+
max_query_pts=2048,
|
| 19 |
+
query_frame_num=5,
|
| 20 |
+
keypoint_extractor="aliked+sp",
|
| 21 |
+
max_points_num=163840,
|
| 22 |
+
fine_tracking=True,
|
| 23 |
+
complete_non_vis=True,
|
| 24 |
+
):
|
| 25 |
+
"""
|
| 26 |
+
Predict tracks for the given images and masks.
|
| 27 |
+
|
| 28 |
+
TODO: support non-square images
|
| 29 |
+
TODO: support masks
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
This function predicts the tracks for the given images and masks using the specified query method
|
| 33 |
+
and track predictor. It finds query points, and predicts the tracks, visibility, and scores for the query frames.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
images: Tensor of shape [S, 3, H, W] containing the input images.
|
| 37 |
+
conf: Array of shape [S, H, W] containing per-pixel confidence scores, indexed per frame as conf[s][y, x]. Default is None.
|
| 38 |
+
points_3d: Tensor containing 3D points. Default is None.
|
| 39 |
+
masks: Optional tensor of shape [S, 1, H, W] containing masks. Default is None.
|
| 40 |
+
max_query_pts: Maximum number of query points. Default is 2048.
|
| 41 |
+
query_frame_num: Number of query frames to use. Default is 5.
|
| 42 |
+
keypoint_extractor: Method for keypoint extraction. Default is "aliked+sp".
|
| 43 |
+
max_points_num: Maximum number of points to process at once. Default is 163840.
|
| 44 |
+
fine_tracking: Whether to use fine tracking. Default is True.
|
| 45 |
+
complete_non_vis: Whether to augment non-visible frames. Default is True.
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
pred_tracks: Numpy array containing the predicted tracks.
|
| 49 |
+
pred_vis_scores: Numpy array containing the visibility scores for the tracks.
|
| 50 |
+
pred_confs: Numpy array containing the confidence scores for the tracks.
|
| 51 |
+
pred_points_3d: Numpy array containing the 3D points for the tracks.
|
| 52 |
+
pred_colors: Numpy array containing the point colors for the tracks. (0, 255)
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
device = images.device
|
| 56 |
+
dtype = images.dtype
|
| 57 |
+
tracker = build_vggsfm_tracker().to(device, dtype)
|
| 58 |
+
|
| 59 |
+
# Find query frames
|
| 60 |
+
query_frame_indexes = generate_rank_by_dino(
|
| 61 |
+
images, query_frame_num=query_frame_num, device=device
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
# Ensure frame 0 is always the first query frame
|
| 65 |
+
if 0 in query_frame_indexes:
|
| 66 |
+
query_frame_indexes.remove(0)
|
| 67 |
+
query_frame_indexes = [0, *query_frame_indexes]
|
| 68 |
+
|
| 69 |
+
# TODO: add the functionality to handle the masks
|
| 70 |
+
keypoint_extractors = initialize_feature_extractors(
|
| 71 |
+
max_query_pts, extractor_method=keypoint_extractor, device=device
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
pred_tracks = []
|
| 75 |
+
pred_vis_scores = []
|
| 76 |
+
pred_confs = []
|
| 77 |
+
pred_points_3d = []
|
| 78 |
+
pred_colors = []
|
| 79 |
+
|
| 80 |
+
fmaps_for_tracker = tracker.process_images_to_fmaps(images)
|
| 81 |
+
|
| 82 |
+
if fine_tracking:
|
| 83 |
+
print("For faster inference, consider disabling fine_tracking")
|
| 84 |
+
|
| 85 |
+
for query_index in query_frame_indexes:
|
| 86 |
+
print(f"Predicting tracks for query frame {query_index}")
|
| 87 |
+
pred_track, pred_vis, pred_conf, pred_point_3d, pred_color = _forward_on_query(
|
| 88 |
+
query_index,
|
| 89 |
+
images,
|
| 90 |
+
conf,
|
| 91 |
+
points_3d,
|
| 92 |
+
fmaps_for_tracker,
|
| 93 |
+
keypoint_extractors,
|
| 94 |
+
tracker,
|
| 95 |
+
max_points_num,
|
| 96 |
+
fine_tracking,
|
| 97 |
+
device,
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
pred_tracks.append(pred_track)
|
| 101 |
+
pred_vis_scores.append(pred_vis)
|
| 102 |
+
pred_confs.append(pred_conf)
|
| 103 |
+
pred_points_3d.append(pred_point_3d)
|
| 104 |
+
pred_colors.append(pred_color)
|
| 105 |
+
|
| 106 |
+
if complete_non_vis:
|
| 107 |
+
(
|
| 108 |
+
pred_tracks,
|
| 109 |
+
pred_vis_scores,
|
| 110 |
+
pred_confs,
|
| 111 |
+
pred_points_3d,
|
| 112 |
+
pred_colors,
|
| 113 |
+
) = _augment_non_visible_frames(
|
| 114 |
+
pred_tracks,
|
| 115 |
+
pred_vis_scores,
|
| 116 |
+
pred_confs,
|
| 117 |
+
pred_points_3d,
|
| 118 |
+
pred_colors,
|
| 119 |
+
images,
|
| 120 |
+
conf,
|
| 121 |
+
points_3d,
|
| 122 |
+
fmaps_for_tracker,
|
| 123 |
+
keypoint_extractors,
|
| 124 |
+
tracker,
|
| 125 |
+
max_points_num,
|
| 126 |
+
fine_tracking,
|
| 127 |
+
min_vis=500,
|
| 128 |
+
non_vis_thresh=0.1,
|
| 129 |
+
device=device,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
pred_tracks = np.concatenate(pred_tracks, axis=1)
|
| 133 |
+
pred_vis_scores = np.concatenate(pred_vis_scores, axis=1)
|
| 134 |
+
pred_confs = np.concatenate(pred_confs, axis=0) if pred_confs else None
|
| 135 |
+
pred_points_3d = np.concatenate(pred_points_3d, axis=0) if pred_points_3d else None
|
| 136 |
+
pred_colors = np.concatenate(pred_colors, axis=0) if pred_colors else None
|
| 137 |
+
|
| 138 |
+
# from vggt.utils.visual_track import visualize_tracks_on_images
|
| 139 |
+
# visualize_tracks_on_images(images[None], torch.from_numpy(pred_tracks[None]), torch.from_numpy(pred_vis_scores[None])>0.2, out_dir="track_visuals")
|
| 140 |
+
|
| 141 |
+
return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def _forward_on_query(
|
| 145 |
+
query_index,
|
| 146 |
+
images,
|
| 147 |
+
conf,
|
| 148 |
+
points_3d,
|
| 149 |
+
fmaps_for_tracker,
|
| 150 |
+
keypoint_extractors,
|
| 151 |
+
tracker,
|
| 152 |
+
max_points_num,
|
| 153 |
+
fine_tracking,
|
| 154 |
+
device,
|
| 155 |
+
):
|
| 156 |
+
"""
|
| 157 |
+
Process a single query frame for track prediction.
|
| 158 |
+
|
| 159 |
+
Args:
|
| 160 |
+
query_index: Index of the query frame
|
| 161 |
+
images: Tensor of shape [S, 3, H, W] containing the input images
|
| 162 |
+
conf: Confidence tensor
|
| 163 |
+
points_3d: 3D points tensor
|
| 164 |
+
fmaps_for_tracker: Feature maps for the tracker
|
| 165 |
+
keypoint_extractors: Initialized feature extractors
|
| 166 |
+
tracker: VGG-SFM tracker
|
| 167 |
+
max_points_num: Maximum number of points to process at once
|
| 168 |
+
fine_tracking: Whether to use fine tracking
|
| 169 |
+
device: Device to use for computation
|
| 170 |
+
|
| 171 |
+
Returns:
|
| 172 |
+
pred_track: Predicted tracks
|
| 173 |
+
pred_vis: Visibility scores for the tracks
|
| 174 |
+
pred_conf: Confidence scores for the tracks
|
| 175 |
+
pred_point_3d: 3D points for the tracks
|
| 176 |
+
pred_color: Point colors for the tracks, as uint8 values in [0, 255]
|
| 177 |
+
"""
|
| 178 |
+
frame_num, _, height, width = images.shape
|
| 179 |
+
|
| 180 |
+
query_image = images[query_index]
|
| 181 |
+
query_points = extract_keypoints(
|
| 182 |
+
query_image, keypoint_extractors, round_keypoints=False
|
| 183 |
+
)
|
| 184 |
+
query_points = query_points[:, torch.randperm(query_points.shape[1], device=device)]
|
| 185 |
+
|
| 186 |
+
# Extract the color at the keypoint locations
|
| 187 |
+
query_points_long = query_points.squeeze(0).round().long()
|
| 188 |
+
pred_color = images[query_index][
|
| 189 |
+
:, query_points_long[:, 1], query_points_long[:, 0]
|
| 190 |
+
]
|
| 191 |
+
pred_color = (pred_color.permute(1, 0).cpu().numpy() * 255).astype(np.uint8)
|
| 192 |
+
|
| 193 |
+
# Query the confidence and points_3d at the keypoint locations
|
| 194 |
+
if (conf is not None) and (points_3d is not None):
|
| 195 |
+
assert height == width
|
| 196 |
+
assert conf.shape[-2] == conf.shape[-1]
|
| 197 |
+
assert conf.shape[:3] == points_3d.shape[:3]
|
| 198 |
+
scale = conf.shape[-1] / width
|
| 199 |
+
|
| 200 |
+
query_points_scaled = (query_points.squeeze(0) * scale).round().long()
|
| 201 |
+
query_points_scaled = query_points_scaled.cpu().numpy()
|
| 202 |
+
|
| 203 |
+
pred_conf = conf[query_index][
|
| 204 |
+
query_points_scaled[:, 1], query_points_scaled[:, 0]
|
| 205 |
+
]
|
| 206 |
+
pred_point_3d = points_3d[query_index][
|
| 207 |
+
query_points_scaled[:, 1], query_points_scaled[:, 0]
|
| 208 |
+
]
|
| 209 |
+
|
| 210 |
+
# heuristic to remove low confidence points
|
| 211 |
+
# TODO: consider exposing this threshold as an input parameter
|
| 212 |
+
valid_mask = pred_conf > 1.2
|
| 213 |
+
if valid_mask.sum() > 512:
|
| 214 |
+
query_points = query_points[:, valid_mask] # Make sure shape is compatible
|
| 215 |
+
pred_conf = pred_conf[valid_mask]
|
| 216 |
+
pred_point_3d = pred_point_3d[valid_mask]
|
| 217 |
+
pred_color = pred_color[valid_mask]
|
| 218 |
+
else:
|
| 219 |
+
pred_conf = None
|
| 220 |
+
pred_point_3d = None
|
| 221 |
+
|
| 222 |
+
reorder_index = calculate_index_mappings(query_index, frame_num, device=device)
|
| 223 |
+
|
| 224 |
+
images_feed, fmaps_feed = switch_tensor_order(
|
| 225 |
+
[images, fmaps_for_tracker], reorder_index, dim=0
|
| 226 |
+
)
|
| 227 |
+
images_feed = images_feed[None] # add batch dimension
|
| 228 |
+
fmaps_feed = fmaps_feed[None] # add batch dimension
|
| 229 |
+
|
| 230 |
+
all_points_num = images_feed.shape[1] * query_points.shape[1]
|
| 231 |
+
|
| 232 |
+
# Chunk the query points so a single forward pass stays within GPU memory
|
| 233 |
+
if all_points_num > max_points_num:
|
| 234 |
+
num_splits = (all_points_num + max_points_num - 1) // max_points_num
|
| 235 |
+
query_points = torch.chunk(query_points, num_splits, dim=1)
|
| 236 |
+
else:
|
| 237 |
+
query_points = [query_points]
|
| 238 |
+
|
| 239 |
+
pred_track, pred_vis, _ = predict_tracks_in_chunks(
|
| 240 |
+
tracker, images_feed, query_points, fmaps_feed, fine_tracking=fine_tracking
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
pred_track, pred_vis = switch_tensor_order(
|
| 244 |
+
[pred_track, pred_vis], reorder_index, dim=1
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
pred_track = pred_track.squeeze(0).float().cpu().numpy()
|
| 248 |
+
pred_vis = pred_vis.squeeze(0).float().cpu().numpy()
|
| 249 |
+
|
| 250 |
+
return pred_track, pred_vis, pred_conf, pred_point_3d, pred_color
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def _augment_non_visible_frames(
|
| 254 |
+
pred_tracks: list, # ← running list of np.ndarrays
|
| 255 |
+
pred_vis_scores: list, # ← running list of np.ndarrays
|
| 256 |
+
pred_confs: list, # ← running list of np.ndarrays for confidence scores
|
| 257 |
+
pred_points_3d: list, # ← running list of np.ndarrays for 3D points
|
| 258 |
+
pred_colors: list, # ← running list of np.ndarrays for colors
|
| 259 |
+
images: torch.Tensor,
|
| 260 |
+
conf,
|
| 261 |
+
points_3d,
|
| 262 |
+
fmaps_for_tracker,
|
| 263 |
+
keypoint_extractors,
|
| 264 |
+
tracker,
|
| 265 |
+
max_points_num: int,
|
| 266 |
+
fine_tracking: bool,
|
| 267 |
+
*,
|
| 268 |
+
min_vis: int = 500,
|
| 269 |
+
non_vis_thresh: float = 0.1,
|
| 270 |
+
device: torch.device = None,
|
| 271 |
+
):
|
| 272 |
+
"""
|
| 273 |
+
Augment tracking for frames with insufficient visibility.
|
| 274 |
+
|
| 275 |
+
Args:
|
| 276 |
+
pred_tracks: List of numpy arrays containing predicted tracks.
|
| 277 |
+
pred_vis_scores: List of numpy arrays containing visibility scores.
|
| 278 |
+
pred_confs: List of numpy arrays containing confidence scores.
|
| 279 |
+
pred_points_3d: List of numpy arrays containing 3D points.
|
| 280 |
+
pred_colors: List of numpy arrays containing point colors.
|
| 281 |
+
images: Tensor of shape [S, 3, H, W] containing the input images.
|
| 282 |
+
conf: Array of shape [S, H, W] containing per-pixel confidence scores
|
| 283 |
+
points_3d: Tensor containing 3D points
|
| 284 |
+
fmaps_for_tracker: Feature maps for the tracker
|
| 285 |
+
keypoint_extractors: Initialized feature extractors
|
| 286 |
+
tracker: VGG-SFM tracker
|
| 287 |
+
max_points_num: Maximum number of points to process at once
|
| 288 |
+
fine_tracking: Whether to use fine tracking
|
| 289 |
+
min_vis: Minimum visibility threshold
|
| 290 |
+
non_vis_thresh: Non-visibility threshold
|
| 291 |
+
device: Device to use for computation
|
| 292 |
+
|
| 293 |
+
Returns:
|
| 294 |
+
Updated pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, and pred_colors lists.
|
| 295 |
+
"""
|
| 296 |
+
last_query = -1
|
| 297 |
+
final_trial = False
|
| 298 |
+
cur_extractors = keypoint_extractors # may be replaced on the final trial
|
| 299 |
+
|
| 300 |
+
while True:
|
| 301 |
+
# Visibility per frame
|
| 302 |
+
vis_array = np.concatenate(pred_vis_scores, axis=1)
|
| 303 |
+
|
| 304 |
+
# For each frame, count how many tracks are visible above the threshold
|
| 305 |
+
sufficient_vis_count = (vis_array > non_vis_thresh).sum(axis=-1)
|
| 306 |
+
non_vis_frames = np.where(sufficient_vis_count < min_vis)[0].tolist()
|
| 307 |
+
|
| 308 |
+
if len(non_vis_frames) == 0:
|
| 309 |
+
break
|
| 310 |
+
|
| 311 |
+
print("Processing non visible frames:", non_vis_frames)
|
| 312 |
+
|
| 313 |
+
# Decide the frames & extractor for this round
|
| 314 |
+
if non_vis_frames[0] == last_query:
|
| 315 |
+
# Same frame failed twice - final "all-in" attempt
|
| 316 |
+
final_trial = True
|
| 317 |
+
cur_extractors = initialize_feature_extractors(
|
| 318 |
+
2048, extractor_method="sp+sift+aliked", device=device
|
| 319 |
+
)
|
| 320 |
+
query_frame_list = non_vis_frames # blast them all at once
|
| 321 |
+
else:
|
| 322 |
+
query_frame_list = [non_vis_frames[0]] # Process one at a time
|
| 323 |
+
|
| 324 |
+
last_query = non_vis_frames[0]
|
| 325 |
+
|
| 326 |
+
# Run the tracker for every selected frame
|
| 327 |
+
for query_index in query_frame_list:
|
| 328 |
+
new_track, new_vis, new_conf, new_point_3d, new_color = _forward_on_query(
|
| 329 |
+
query_index,
|
| 330 |
+
images,
|
| 331 |
+
conf,
|
| 332 |
+
points_3d,
|
| 333 |
+
fmaps_for_tracker,
|
| 334 |
+
cur_extractors,
|
| 335 |
+
tracker,
|
| 336 |
+
max_points_num,
|
| 337 |
+
fine_tracking,
|
| 338 |
+
device,
|
| 339 |
+
)
|
| 340 |
+
pred_tracks.append(new_track)
|
| 341 |
+
pred_vis_scores.append(new_vis)
|
| 342 |
+
pred_confs.append(new_conf)
|
| 343 |
+
pred_points_3d.append(new_point_3d)
|
| 344 |
+
pred_colors.append(new_color)
|
| 345 |
+
|
| 346 |
+
if final_trial:
|
| 347 |
+
break # Stop after final attempt
|
| 348 |
+
|
| 349 |
+
return pred_tracks, pred_vis_scores, pred_confs, pred_points_3d, pred_colors
|
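Putting the entry point together, a hypothetical call could look like the sketch below. The frame count, resolution, and the confidence / 3D-point maps are placeholders (in practice they come from an earlier dense reconstruction); a GPU and the pretrained tracker, DINO, and keypoint-extractor weights are assumed to be available.

import numpy as np
import torch

S, H, W = 8, 518, 518
images = torch.rand(S, 3, H, W, device="cuda")              # RGB frames in [0, 1]
conf = 1.0 + np.random.rand(S, H, W).astype(np.float32)      # assumed per-pixel confidence
points_3d = np.random.rand(S, H, W, 3).astype(np.float32)    # assumed per-pixel 3D points

tracks, vis, confs, pts3d, colors = predict_tracks(
    images,
    conf=conf,
    points_3d=points_3d,
    max_query_pts=1024,
    query_frame_num=3,
    fine_tracking=False,   # faster, as the warning above suggests
)
# tracks: (S, N, 2) pixel coordinates, vis: (S, N) visibility scores,
# colors: (N, 3) uint8 colours sampled at the query keypoints.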
sailrecon/dependency/vggsfm_tracker.py
ADDED
|
@@ -0,0 +1,148 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
from functools import partial
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from einops import rearrange, repeat
|
| 15 |
+
from einops.layers.torch import Rearrange, Reduce
|
| 16 |
+
from hydra.utils import instantiate
|
| 17 |
+
from omegaconf import OmegaConf
|
| 18 |
+
from torch import einsum, nn
|
| 19 |
+
|
| 20 |
+
from .track_modules.base_track_predictor import BaseTrackerPredictor
|
| 21 |
+
from .track_modules.blocks import BasicEncoder, ShallowEncoder
|
| 22 |
+
from .track_modules.track_refine import refine_track
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class TrackerPredictor(nn.Module):
|
| 26 |
+
def __init__(self, **extra_args):
|
| 27 |
+
super(TrackerPredictor, self).__init__()
|
| 28 |
+
"""
|
| 29 |
+
Initializes the tracker predictor.
|
| 30 |
+
|
| 31 |
+
Both coarse_predictor and fine_predictor are constructed as a BaseTrackerPredictor,
|
| 32 |
+
check track_modules/base_track_predictor.py
|
| 33 |
+
|
| 34 |
+
Both coarse_fnet and fine_fnet are constructed as a 2D CNN network
|
| 35 |
+
check track_modules/blocks.py for BasicEncoder and ShallowEncoder
|
| 36 |
+
"""
|
| 37 |
+
# Define coarse predictor configuration
|
| 38 |
+
coarse_stride = 4
|
| 39 |
+
self.coarse_down_ratio = 2
|
| 40 |
+
|
| 41 |
+
# Create networks directly instead of using instantiate
|
| 42 |
+
self.coarse_fnet = BasicEncoder(stride=coarse_stride)
|
| 43 |
+
self.coarse_predictor = BaseTrackerPredictor(stride=coarse_stride)
|
| 44 |
+
|
| 45 |
+
# Create fine predictor with stride = 1
|
| 46 |
+
self.fine_fnet = ShallowEncoder(stride=1)
|
| 47 |
+
self.fine_predictor = BaseTrackerPredictor(
|
| 48 |
+
stride=1,
|
| 49 |
+
depth=4,
|
| 50 |
+
corr_levels=3,
|
| 51 |
+
corr_radius=3,
|
| 52 |
+
latent_dim=32,
|
| 53 |
+
hidden_size=256,
|
| 54 |
+
fine=True,
|
| 55 |
+
use_spaceatt=False,
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
def forward(
|
| 59 |
+
self,
|
| 60 |
+
images,
|
| 61 |
+
query_points,
|
| 62 |
+
fmaps=None,
|
| 63 |
+
coarse_iters=6,
|
| 64 |
+
inference=True,
|
| 65 |
+
fine_tracking=True,
|
| 66 |
+
fine_chunk=40960,
|
| 67 |
+
):
|
| 68 |
+
"""
|
| 69 |
+
Args:
|
| 70 |
+
images (torch.Tensor): Images as RGB, in the range of [0, 1], with a shape of B x S x 3 x H x W.
|
| 71 |
+
query_points (torch.Tensor): 2D xy of query points, relative to top left, with a shape of B x N x 2.
|
| 72 |
+
fmaps (torch.Tensor, optional): Precomputed feature maps. Defaults to None.
|
| 73 |
+
coarse_iters (int, optional): Number of iterations for coarse prediction. Defaults to 6.
|
| 74 |
+
inference (bool, optional): Whether to perform inference. Defaults to True.
|
| 75 |
+
fine_tracking (bool, optional): Whether to perform fine tracking. Defaults to True.
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
tuple: A tuple containing fine_pred_track, coarse_pred_track, pred_vis, and pred_score.
|
| 79 |
+
"""
|
| 80 |
+
|
| 81 |
+
if fmaps is None:
|
| 82 |
+
batch_num, frame_num, image_dim, height, width = images.shape
|
| 83 |
+
reshaped_image = images.reshape(
|
| 84 |
+
batch_num * frame_num, image_dim, height, width
|
| 85 |
+
)
|
| 86 |
+
fmaps = self.process_images_to_fmaps(reshaped_image)
|
| 87 |
+
fmaps = fmaps.reshape(
|
| 88 |
+
batch_num, frame_num, -1, fmaps.shape[-2], fmaps.shape[-1]
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
if inference:
|
| 92 |
+
torch.cuda.empty_cache()
|
| 93 |
+
|
| 94 |
+
# Coarse prediction
|
| 95 |
+
coarse_pred_track_lists, pred_vis = self.coarse_predictor(
|
| 96 |
+
query_points=query_points,
|
| 97 |
+
fmaps=fmaps,
|
| 98 |
+
iters=coarse_iters,
|
| 99 |
+
down_ratio=self.coarse_down_ratio,
|
| 100 |
+
)
|
| 101 |
+
coarse_pred_track = coarse_pred_track_lists[-1]
|
| 102 |
+
|
| 103 |
+
if inference:
|
| 104 |
+
torch.cuda.empty_cache()
|
| 105 |
+
|
| 106 |
+
if fine_tracking:
|
| 107 |
+
# Refine the coarse prediction
|
| 108 |
+
fine_pred_track, pred_score = refine_track(
|
| 109 |
+
images,
|
| 110 |
+
self.fine_fnet,
|
| 111 |
+
self.fine_predictor,
|
| 112 |
+
coarse_pred_track,
|
| 113 |
+
compute_score=False,
|
| 114 |
+
chunk=fine_chunk,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
if inference:
|
| 118 |
+
torch.cuda.empty_cache()
|
| 119 |
+
else:
|
| 120 |
+
fine_pred_track = coarse_pred_track
|
| 121 |
+
pred_score = torch.ones_like(pred_vis)
|
| 122 |
+
|
| 123 |
+
return fine_pred_track, coarse_pred_track, pred_vis, pred_score
|
| 124 |
+
|
| 125 |
+
def process_images_to_fmaps(self, images):
|
| 126 |
+
"""
|
| 127 |
+
This function processes images for inference.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
images (torch.Tensor): The images to be processed with shape S x 3 x H x W.
|
| 131 |
+
|
| 132 |
+
Returns:
|
| 133 |
+
torch.Tensor: The processed feature maps.
|
| 134 |
+
"""
|
| 135 |
+
if self.coarse_down_ratio > 1:
|
| 136 |
+
# Optionally scale down the input images to save memory
|
| 137 |
+
fmaps = self.coarse_fnet(
|
| 138 |
+
F.interpolate(
|
| 139 |
+
images,
|
| 140 |
+
scale_factor=1 / self.coarse_down_ratio,
|
| 141 |
+
mode="bilinear",
|
| 142 |
+
align_corners=True,
|
| 143 |
+
)
|
| 144 |
+
)
|
| 145 |
+
else:
|
| 146 |
+
fmaps = self.coarse_fnet(images)
|
| 147 |
+
|
| 148 |
+
return fmaps
|
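A shape-level sketch of the predictor's forward contract. The weights here are randomly initialised, so the outputs are meaningless; in practice the VGGSfM checkpoint is loaded by the helper in the next file. Image size, frame count, and point count are arbitrary assumptions.

import torch

tracker = TrackerPredictor().eval()
images = torch.rand(1, 4, 3, 256, 256)          # (B, S, 3, H, W), RGB in [0, 1]
query_points = torch.rand(1, 64, 2) * 255       # (B, N, 2) pixel coords in frame 0

with torch.no_grad():
    fine_track, coarse_track, vis, score = tracker(
        images, query_points, inference=False, fine_tracking=False
    )
# With fine_tracking=False the fine track simply equals the coarse track; the
# expected track shape is (B, S, N, 2), with one visibility score per frame and point.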
sailrecon/dependency/vggsfm_utils.py
ADDED
|
@@ -0,0 +1,341 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import warnings
|
| 9 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pycolmap
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from lightglue import ALIKED, SIFT, SuperPoint
|
| 16 |
+
|
| 17 |
+
from .vggsfm_tracker import TrackerPredictor
|
| 18 |
+
|
| 19 |
+
# Suppress verbose logging from dependencies
|
| 20 |
+
logging.getLogger("dinov2").setLevel(logging.WARNING)
|
| 21 |
+
warnings.filterwarnings("ignore", message="xFormers is available")
|
| 22 |
+
warnings.filterwarnings("ignore", message="dinov2")
|
| 23 |
+
|
| 24 |
+
# Constants
|
| 25 |
+
_RESNET_MEAN = [0.485, 0.456, 0.406]
|
| 26 |
+
_RESNET_STD = [0.229, 0.224, 0.225]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def build_vggsfm_tracker(model_path=None):
|
| 30 |
+
"""
|
| 31 |
+
Build and initialize the VGGSfM tracker.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
model_path: Path to the model weights file. If None, weights are downloaded from HuggingFace.
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
Initialized tracker model in eval mode.
|
| 38 |
+
"""
|
| 39 |
+
tracker = TrackerPredictor()
|
| 40 |
+
|
| 41 |
+
if model_path is None:
|
| 42 |
+
default_url = (
|
| 43 |
+
"https://huggingface.co/facebook/VGGSfM/resolve/main/vggsfm_v2_tracker.pt"
|
| 44 |
+
)
|
| 45 |
+
tracker.load_state_dict(torch.hub.load_state_dict_from_url(default_url))
|
| 46 |
+
else:
|
| 47 |
+
tracker.load_state_dict(torch.load(model_path))
|
| 48 |
+
|
| 49 |
+
tracker.eval()
|
| 50 |
+
return tracker
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def generate_rank_by_dino(
|
| 54 |
+
images,
|
| 55 |
+
query_frame_num,
|
| 56 |
+
image_size=336,
|
| 57 |
+
model_name="dinov2_vitb14_reg",
|
| 58 |
+
device="cuda",
|
| 59 |
+
spatial_similarity=False,
|
| 60 |
+
):
|
| 61 |
+
"""
|
| 62 |
+
Generate a ranking of frames using DINO ViT features.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
images: Tensor of shape (S, 3, H, W) with values in range [0, 1]
|
| 66 |
+
query_frame_num: Number of frames to select
|
| 67 |
+
image_size: Size to resize images to before processing
|
| 68 |
+
model_name: Name of the DINO model to use
|
| 69 |
+
device: Device to run the model on
|
| 70 |
+
spatial_similarity: Whether to use spatial token similarity or CLS token similarity
|
| 71 |
+
|
| 72 |
+
Returns:
|
| 73 |
+
List of frame indices ranked by their representativeness
|
| 74 |
+
"""
|
| 75 |
+
# Resize images to the target size
|
| 76 |
+
images = F.interpolate(
|
| 77 |
+
images, (image_size, image_size), mode="bilinear", align_corners=False
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
# Load DINO model
|
| 81 |
+
dino_v2_model = torch.hub.load("facebookresearch/dinov2", model_name)
|
| 82 |
+
dino_v2_model.eval()
|
| 83 |
+
dino_v2_model = dino_v2_model.to(device)
|
| 84 |
+
|
| 85 |
+
# Normalize images using ResNet normalization
|
| 86 |
+
resnet_mean = torch.tensor(_RESNET_MEAN, device=device).view(1, 3, 1, 1)
|
| 87 |
+
resnet_std = torch.tensor(_RESNET_STD, device=device).view(1, 3, 1, 1)
|
| 88 |
+
images_resnet_norm = (images - resnet_mean) / resnet_std
|
| 89 |
+
|
| 90 |
+
with torch.no_grad():
|
| 91 |
+
frame_feat = dino_v2_model(images_resnet_norm, is_training=True)
|
| 92 |
+
|
| 93 |
+
# Process features based on similarity type
|
| 94 |
+
if spatial_similarity:
|
| 95 |
+
frame_feat = frame_feat["x_norm_patchtokens"]
|
| 96 |
+
frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)
|
| 97 |
+
|
| 98 |
+
# Compute the similarity matrix
|
| 99 |
+
frame_feat_norm = frame_feat_norm.permute(1, 0, 2)
|
| 100 |
+
similarity_matrix = torch.bmm(
|
| 101 |
+
frame_feat_norm, frame_feat_norm.transpose(-1, -2)
|
| 102 |
+
)
|
| 103 |
+
similarity_matrix = similarity_matrix.mean(dim=0)
|
| 104 |
+
else:
|
| 105 |
+
frame_feat = frame_feat["x_norm_clstoken"]
|
| 106 |
+
frame_feat_norm = F.normalize(frame_feat, p=2, dim=1)
|
| 107 |
+
similarity_matrix = torch.mm(frame_feat_norm, frame_feat_norm.transpose(-1, -2))
|
| 108 |
+
|
| 109 |
+
distance_matrix = 100 - similarity_matrix.clone()
|
| 110 |
+
|
| 111 |
+
# Ignore self-pairing
|
| 112 |
+
similarity_matrix.fill_diagonal_(-100)
|
| 113 |
+
similarity_sum = similarity_matrix.sum(dim=1)
|
| 114 |
+
|
| 115 |
+
# Find the most common frame
|
| 116 |
+
most_common_frame_index = torch.argmax(similarity_sum).item()
|
| 117 |
+
|
| 118 |
+
# Conduct FPS sampling starting from the most common frame
|
| 119 |
+
fps_idx = farthest_point_sampling(
|
| 120 |
+
distance_matrix, query_frame_num, most_common_frame_index
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
# Clean up all tensors and models to free memory
|
| 124 |
+
del frame_feat, frame_feat_norm, similarity_matrix, distance_matrix
|
| 125 |
+
del dino_v2_model
|
| 126 |
+
torch.cuda.empty_cache()
|
| 127 |
+
|
| 128 |
+
return fps_idx
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def farthest_point_sampling(distance_matrix, num_samples, most_common_frame_index=0):
|
| 132 |
+
"""
|
| 133 |
+
Farthest point sampling algorithm to select diverse frames.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
distance_matrix: Matrix of distances between frames
|
| 137 |
+
num_samples: Number of frames to select
|
| 138 |
+
most_common_frame_index: Index of the first frame to select
|
| 139 |
+
|
| 140 |
+
Returns:
|
| 141 |
+
List of selected frame indices
|
| 142 |
+
"""
|
| 143 |
+
distance_matrix = distance_matrix.clamp(min=0)
|
| 144 |
+
N = distance_matrix.size(0)
|
| 145 |
+
|
| 146 |
+
# Initialize with the most common frame
|
| 147 |
+
selected_indices = [most_common_frame_index]
|
| 148 |
+
check_distances = distance_matrix[selected_indices]
|
| 149 |
+
|
| 150 |
+
while len(selected_indices) < num_samples:
|
| 151 |
+
# Find the farthest point from the current set of selected points
|
| 152 |
+
farthest_point = torch.argmax(check_distances)
|
| 153 |
+
selected_indices.append(farthest_point.item())
|
| 154 |
+
|
| 155 |
+
check_distances = distance_matrix[farthest_point]
|
| 156 |
+
# Mark already selected points to avoid selecting them again
|
| 157 |
+
check_distances[selected_indices] = 0
|
| 158 |
+
|
| 159 |
+
# Break if all points have been selected
|
| 160 |
+
if len(selected_indices) == N:
|
| 161 |
+
break
|
| 162 |
+
|
| 163 |
+
return selected_indices
|
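A tiny worked example of the sampler above on four "frames" laid out on a line at positions 0, 1, 9, 10. Note that this variant measures the distance to the most recently selected frame rather than to the whole selected set, which is cheaper and sufficient for picking diverse query frames.

import torch

pos = torch.tensor([0.0, 1.0, 9.0, 10.0])
dist = (pos[None, :] - pos[:, None]).abs()        # (4, 4) pairwise distances
picked = farthest_point_sampling(dist, num_samples=3, most_common_frame_index=0)
assert picked == [0, 3, 1]   # frame 3 is farthest from 0, then frame 1 is farthest from 3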
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def calculate_index_mappings(query_index, S, device=None):
|
| 167 |
+
"""
|
| 168 |
+
Construct an order that switches [query_index] and [0]
|
| 169 |
+
so that the content of query_index would be placed at [0].
|
| 170 |
+
|
| 171 |
+
Args:
|
| 172 |
+
query_index: Index to swap with 0
|
| 173 |
+
S: Total number of elements
|
| 174 |
+
device: Device to place the tensor on
|
| 175 |
+
|
| 176 |
+
Returns:
|
| 177 |
+
Tensor of indices with the swapped order
|
| 178 |
+
"""
|
| 179 |
+
new_order = torch.arange(S)
|
| 180 |
+
new_order[0] = query_index
|
| 181 |
+
new_order[query_index] = 0
|
| 182 |
+
if device is not None:
|
| 183 |
+
new_order = new_order.to(device)
|
| 184 |
+
return new_order
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def switch_tensor_order(tensors, order, dim=1):
|
| 188 |
+
"""
|
| 189 |
+
Reorder tensors along a specific dimension according to the given order.
|
| 190 |
+
|
| 191 |
+
Args:
|
| 192 |
+
tensors: List of tensors to reorder
|
| 193 |
+
order: Tensor of indices specifying the new order
|
| 194 |
+
dim: Dimension along which to reorder
|
| 195 |
+
|
| 196 |
+
Returns:
|
| 197 |
+
List of reordered tensors
|
| 198 |
+
"""
|
| 199 |
+
return [
|
| 200 |
+
torch.index_select(tensor, dim, order) if tensor is not None else None
|
| 201 |
+
for tensor in tensors
|
| 202 |
+
]
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def initialize_feature_extractors(
|
| 206 |
+
max_query_num, det_thres=0.005, extractor_method="aliked", device="cuda"
|
| 207 |
+
):
|
| 208 |
+
"""
|
| 209 |
+
Initialize feature extractors that can be reused based on a method string.
|
| 210 |
+
|
| 211 |
+
Args:
|
| 212 |
+
max_query_num: Maximum number of keypoints to extract
|
| 213 |
+
det_thres: Detection threshold for keypoint extraction
|
| 214 |
+
extractor_method: String specifying which extractors to use (e.g., "aliked", "sp+sift", "aliked+sp+sift")
|
| 215 |
+
device: Device to run extraction on
|
| 216 |
+
|
| 217 |
+
Returns:
|
| 218 |
+
Dictionary of initialized extractors
|
| 219 |
+
"""
|
| 220 |
+
extractors = {}
|
| 221 |
+
methods = extractor_method.lower().split("+")
|
| 222 |
+
|
| 223 |
+
for method in methods:
|
| 224 |
+
method = method.strip()
|
| 225 |
+
if method == "aliked":
|
| 226 |
+
aliked_extractor = ALIKED(
|
| 227 |
+
max_num_keypoints=max_query_num, detection_threshold=det_thres
|
| 228 |
+
)
|
| 229 |
+
extractors["aliked"] = aliked_extractor.to(device).eval()
|
| 230 |
+
elif method == "sp":
|
| 231 |
+
sp_extractor = SuperPoint(
|
| 232 |
+
max_num_keypoints=max_query_num, detection_threshold=det_thres
|
| 233 |
+
)
|
| 234 |
+
extractors["sp"] = sp_extractor.to(device).eval()
|
| 235 |
+
elif method == "sift":
|
| 236 |
+
sift_extractor = SIFT(max_num_keypoints=max_query_num)
|
| 237 |
+
extractors["sift"] = sift_extractor.to(device).eval()
|
| 238 |
+
else:
|
| 239 |
+
print(f"Warning: Unknown feature extractor '{method}', ignoring.")
|
| 240 |
+
|
| 241 |
+
if not extractors:
|
| 242 |
+
print(
|
| 243 |
+
f"Warning: No valid extractors found in '{extractor_method}'. Using ALIKED by default."
|
| 244 |
+
)
|
| 245 |
+
aliked_extractor = ALIKED(
|
| 246 |
+
max_num_keypoints=max_query_num, detection_threshold=det_thres
|
| 247 |
+
)
|
| 248 |
+
extractors["aliked"] = aliked_extractor.to(device).eval()
|
| 249 |
+
|
| 250 |
+
return extractors
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def extract_keypoints(query_image, extractors, round_keypoints=True):
|
| 254 |
+
"""
|
| 255 |
+
Extract keypoints using pre-initialized feature extractors.
|
| 256 |
+
|
| 257 |
+
Args:
|
| 258 |
+
query_image: Input image tensor (3xHxW, range [0, 1])
|
| 259 |
+
extractors: Dictionary of initialized extractors
|
| 260 |
+
|
| 261 |
+
Returns:
|
| 262 |
+
Tensor of keypoint coordinates (1xNx2)
|
| 263 |
+
"""
|
| 264 |
+
query_points = None
|
| 265 |
+
|
| 266 |
+
with torch.no_grad():
|
| 267 |
+
for extractor_name, extractor in extractors.items():
|
| 268 |
+
query_points_data = extractor.extract(query_image, invalid_mask=None)
|
| 269 |
+
extractor_points = query_points_data["keypoints"]
|
| 270 |
+
if round_keypoints:
|
| 271 |
+
extractor_points = extractor_points.round()
|
| 272 |
+
|
| 273 |
+
if query_points is not None:
|
| 274 |
+
query_points = torch.cat([query_points, extractor_points], dim=1)
|
| 275 |
+
else:
|
| 276 |
+
query_points = extractor_points
|
| 277 |
+
|
| 278 |
+
return query_points
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def predict_tracks_in_chunks(
|
| 282 |
+
track_predictor,
|
| 283 |
+
images_feed,
|
| 284 |
+
query_points_list,
|
| 285 |
+
fmaps_feed,
|
| 286 |
+
fine_tracking,
|
| 287 |
+
num_splits=None,
|
| 288 |
+
fine_chunk=40960,
|
| 289 |
+
):
|
| 290 |
+
"""
|
| 291 |
+
Process a list of query points to avoid memory issues.
|
| 292 |
+
|
| 293 |
+
Args:
|
| 294 |
+
track_predictor (object): The track predictor object used for predicting tracks.
|
| 295 |
+
images_feed (torch.Tensor): A tensor of shape (B, T, C, H, W) representing a batch of images.
|
| 296 |
+
query_points_list (list or tuple): A list/tuple of tensors, each of shape (B, Ni, 2) representing chunks of query points.
|
| 297 |
+
fmaps_feed (torch.Tensor): A tensor of feature maps for the tracker.
|
| 298 |
+
fine_tracking (bool): Whether to perform fine tracking.
|
| 299 |
+
num_splits (int, optional): Ignored when query_points_list is provided. Kept for backward compatibility.
|
| 300 |
+
|
| 301 |
+
Returns:
|
| 302 |
+
tuple: A tuple containing the concatenated predicted tracks, visibility, and scores.
|
| 303 |
+
"""
|
| 304 |
+
# If query_points_list is not a list or tuple but a single tensor, handle it like the old version for backward compatibility
|
| 305 |
+
if not isinstance(query_points_list, (list, tuple)):
|
| 306 |
+
query_points = query_points_list
|
| 307 |
+
if num_splits is None:
|
| 308 |
+
num_splits = 1
|
| 309 |
+
query_points_list = torch.chunk(query_points, num_splits, dim=1)
|
| 310 |
+
|
| 311 |
+
# Ensure query_points_list is a list for iteration (as torch.chunk returns a tuple)
|
| 312 |
+
if isinstance(query_points_list, tuple):
|
| 313 |
+
query_points_list = list(query_points_list)
|
| 314 |
+
|
| 315 |
+
fine_pred_track_list = []
|
| 316 |
+
pred_vis_list = []
|
| 317 |
+
pred_score_list = []
|
| 318 |
+
|
| 319 |
+
for split_points in query_points_list:
|
| 320 |
+
# Feed into track predictor for each split
|
| 321 |
+
fine_pred_track, _, pred_vis, pred_score = track_predictor(
|
| 322 |
+
images_feed,
|
| 323 |
+
split_points,
|
| 324 |
+
fmaps=fmaps_feed,
|
| 325 |
+
fine_tracking=fine_tracking,
|
| 326 |
+
fine_chunk=fine_chunk,
|
| 327 |
+
)
|
| 328 |
+
fine_pred_track_list.append(fine_pred_track)
|
| 329 |
+
pred_vis_list.append(pred_vis)
|
| 330 |
+
pred_score_list.append(pred_score)
|
| 331 |
+
|
| 332 |
+
# Concatenate the results from all splits
|
| 333 |
+
fine_pred_track = torch.cat(fine_pred_track_list, dim=2)
|
| 334 |
+
pred_vis = torch.cat(pred_vis_list, dim=2)
|
| 335 |
+
|
| 336 |
+
if pred_score is not None:
|
| 337 |
+
pred_score = torch.cat(pred_score_list, dim=2)
|
| 338 |
+
else:
|
| 339 |
+
pred_score = None
|
| 340 |
+
|
| 341 |
+
return fine_pred_track, pred_vis, pred_score
|
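A short sketch of the keypoint-extraction helpers above (arbitrary image size; the lightglue extractor weights are assumed to be downloadable, and a GPU is assumed to be present):

import torch

extractors = initialize_feature_extractors(
    max_query_num=512, extractor_method="aliked+sp", device="cuda"
)
image = torch.rand(3, 384, 384, device="cuda")     # single RGB image in [0, 1]
keypoints = extract_keypoints(image, extractors)   # (1, N, 2) with N <= 2 * 512
print(keypoints.shape)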
sailrecon/heads/camera_head.py
ADDED
|
@@ -0,0 +1,228 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
from sailrecon.heads.head_act import activate_pose
|
| 15 |
+
from sailrecon.layers import Mlp
|
| 16 |
+
from sailrecon.layers.block import Block
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class CameraHead(nn.Module):
|
| 20 |
+
"""
|
| 21 |
+
CameraHead predicts camera parameters from token representations using iterative refinement.
|
| 22 |
+
|
| 23 |
+
It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
dim_in: int = 2048,
|
| 29 |
+
trunk_depth: int = 4,
|
| 30 |
+
pose_encoding_type: str = "absT_quaR_FoV",
|
| 31 |
+
num_heads: int = 16,
|
| 32 |
+
mlp_ratio: int = 4,
|
| 33 |
+
init_values: float = 0.01,
|
| 34 |
+
trans_act: str = "linear",
|
| 35 |
+
quat_act: str = "linear",
|
| 36 |
+
fl_act: str = "relu", # Field of view activations: ensures FOV values are positive.
|
| 37 |
+
):
|
| 38 |
+
super().__init__()
|
| 39 |
+
|
| 40 |
+
if pose_encoding_type == "absT_quaR_FoV":
|
| 41 |
+
self.target_dim = 9
|
| 42 |
+
else:
|
| 43 |
+
raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
|
| 44 |
+
|
| 45 |
+
self.trans_act = trans_act
|
| 46 |
+
self.quat_act = quat_act
|
| 47 |
+
self.fl_act = fl_act
|
| 48 |
+
self.trunk_depth = trunk_depth
|
| 49 |
+
|
| 50 |
+
# Build the trunk using a sequence of transformer blocks.
|
| 51 |
+
self.trunk = nn.Sequential(
|
| 52 |
+
*[
|
| 53 |
+
Block(
|
| 54 |
+
dim=dim_in,
|
| 55 |
+
num_heads=num_heads,
|
| 56 |
+
mlp_ratio=mlp_ratio,
|
| 57 |
+
init_values=init_values,
|
| 58 |
+
)
|
| 59 |
+
for _ in range(trunk_depth)
|
| 60 |
+
]
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# Normalizations for camera token and trunk output.
|
| 64 |
+
self.token_norm = nn.LayerNorm(dim_in)
|
| 65 |
+
self.trunk_norm = nn.LayerNorm(dim_in)
|
| 66 |
+
|
| 67 |
+
# Learnable empty camera pose token.
|
| 68 |
+
self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
|
| 69 |
+
self.embed_pose = nn.Linear(self.target_dim, dim_in)
|
| 70 |
+
|
| 71 |
+
# Module for producing modulation parameters: shift, scale, and a gate.
|
| 72 |
+
self.poseLN_modulation = nn.Sequential(
|
| 73 |
+
nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True)
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
# Adaptive layer normalization without affine parameters.
|
| 77 |
+
self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
|
| 78 |
+
self.pose_branch = Mlp(
|
| 79 |
+
in_features=dim_in,
|
| 80 |
+
hidden_features=dim_in // 2,
|
| 81 |
+
out_features=self.target_dim,
|
| 82 |
+
drop=0,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
def forward(
|
| 86 |
+
self,
|
| 87 |
+
aggregated_tokens_list: list,
|
| 88 |
+
cam_token_last_layer: torch.Tensor | None,
|
| 89 |
+
num_iterations: int = 4,
|
| 90 |
+
) -> list:
|
| 91 |
+
"""
|
| 92 |
+
Forward pass to predict camera parameters.
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
aggregated_tokens_list (list): List of token tensors from the network;
|
| 96 |
+
the last tensor is used for prediction.
cam_token_last_layer (torch.Tensor | None): Camera tokens of the already-reconstructed (anchor) frames, prepended to the relocalization tokens before refinement.
|
| 97 |
+
num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
|
| 98 |
+
|
| 99 |
+
Returns:
|
| 100 |
+
list: A list of predicted camera encodings (post-activation) from each iteration.
|
| 101 |
+
"""
|
| 102 |
+
# Use tokens from the last block for camera prediction.
|
| 103 |
+
tokens = aggregated_tokens_list[-1]
|
| 104 |
+
|
| 105 |
+
# Extract the camera tokens
|
| 106 |
+
pose_tokens = tokens[:, :, 0]
|
| 107 |
+
num_recon = cam_token_last_layer.shape[1]
|
| 108 |
+
num_reloc = pose_tokens.shape[1]
|
| 109 |
+
pose_tokens = torch.cat([cam_token_last_layer, pose_tokens], dim=1)
|
| 110 |
+
pose_tokens = self.token_norm(pose_tokens)
|
| 111 |
+
|
| 112 |
+
attention_mask = build_lr_mask(
|
| 113 |
+
S=num_recon + num_reloc,
|
| 114 |
+
no_reloc_list=[i for i in range(num_recon)],
|
| 115 |
+
device=pose_tokens.device,
|
| 116 |
+
)
|
| 117 |
+
pred_pose_enc_list = self.trunk_fn(pose_tokens, attention_mask, num_iterations)
|
| 118 |
+
pred_pose_enc_list_returned = [[] for i in range(num_iterations)]
|
| 119 |
+
for i in range(len(pred_pose_enc_list)):
|
| 120 |
+
pred_pose_enc_list_returned[i] = pred_pose_enc_list[i][:, num_recon:]
|
| 121 |
+
return pred_pose_enc_list_returned
|
| 122 |
+
|
| 123 |
+
def trunk_fn(
|
| 124 |
+
self,
|
| 125 |
+
pose_tokens: torch.Tensor,
|
| 126 |
+
attention_mask: torch.Tensor,
|
| 127 |
+
num_iterations: int,
|
| 128 |
+
) -> list:
|
| 129 |
+
"""
|
| 130 |
+
Iteratively refine camera pose predictions.
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C], where S spans both reconstruction and relocalization tokens.
|
| 134 |
+
num_iterations (int): Number of refinement iterations.
|
| 135 |
+
|
| 136 |
+
Returns:
|
| 137 |
+
list: List of activated camera encodings from each iteration.
|
| 138 |
+
"""
|
| 139 |
+
B, S, C = pose_tokens.shape # S = number of reconstruction + relocalization camera tokens.
|
| 140 |
+
pred_pose_enc = None
|
| 141 |
+
pred_pose_enc_list = []
|
| 142 |
+
|
| 143 |
+
for _ in range(num_iterations):
|
| 144 |
+
# Use a learned empty pose for the first iteration.
|
| 145 |
+
if pred_pose_enc is None:
|
| 146 |
+
module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
|
| 147 |
+
else:
|
| 148 |
+
# Detach the previous prediction to avoid backprop through time.
|
| 149 |
+
pred_pose_enc = pred_pose_enc.detach()
|
| 150 |
+
module_input = self.embed_pose(pred_pose_enc)
|
| 151 |
+
|
| 152 |
+
# Generate modulation parameters and split them into shift, scale, and gate components.
|
| 153 |
+
shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(
|
| 154 |
+
3, dim=-1
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
# Adaptive layer normalization and modulation.
|
| 158 |
+
pose_tokens_modulated = gate_msa * modulate(
|
| 159 |
+
self.adaln_norm(pose_tokens), shift_msa, scale_msa
|
| 160 |
+
)
|
| 161 |
+
pose_tokens_modulated = pose_tokens_modulated + pose_tokens
|
| 162 |
+
|
| 163 |
+
for idx_, blk in enumerate(self.trunk):
|
| 164 |
+
pose_tokens_modulated = blk(
|
| 165 |
+
pose_tokens_modulated, None, ~attention_mask
|
| 166 |
+
)
|
| 167 |
+
# Compute the delta update for the pose encoding.
|
| 168 |
+
pred_pose_enc_delta = self.pose_branch(
|
| 169 |
+
self.trunk_norm(pose_tokens_modulated)
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
if pred_pose_enc is None:
|
| 173 |
+
pred_pose_enc = pred_pose_enc_delta
|
| 174 |
+
else:
|
| 175 |
+
pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
|
| 176 |
+
|
| 177 |
+
# Apply final activation functions for translation, quaternion, and field-of-view.
|
| 178 |
+
activated_pose = activate_pose(
|
| 179 |
+
pred_pose_enc,
|
| 180 |
+
trans_act=self.trans_act,
|
| 181 |
+
quat_act=self.quat_act,
|
| 182 |
+
fl_act=self.fl_act,
|
| 183 |
+
)
|
| 184 |
+
pred_pose_enc_list.append(activated_pose)
|
| 185 |
+
|
| 186 |
+
return pred_pose_enc_list
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
| 190 |
+
"""
|
| 191 |
+
Modulate the input tensor using scaling and shifting parameters.
|
| 192 |
+
"""
|
| 193 |
+
# modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
|
| 194 |
+
return x * (1 + scale) + shift
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def build_lr_mask(S: int, no_reloc_list, device="cpu"):
|
| 198 |
+
"""
|
| 199 |
+
Args:
|
| 200 |
+
S (int) : total number of tokens in the sequence.
|
| 201 |
+
no_reloc_list (Sequence) : indices of reconstruction (non-relocalization) tokens (unique, 0-based).
|
| 202 |
+
device (str) : target device for the mask tensor.
|
| 203 |
+
|
| 204 |
+
Returns:
|
| 205 |
+
attn_mask (torch.BoolTensor) of shape (1, 1, S, S)
|
| 206 |
+
— ready for F.scaled_dot_product_attention (True == masked).
|
| 207 |
+
"""
|
| 208 |
+
# ---- 1. Split token indices into relocalization (r_idx) and reconstruction (l_idx) sets
|
| 209 |
+
r_idx = torch.tensor(
|
| 210 |
+
[i for i in range(S) if i not in no_reloc_list], dtype=torch.long, device=device
|
| 211 |
+
)
|
| 212 |
+
l_idx = torch.as_tensor(no_reloc_list, dtype=torch.long, device=device).unique(
|
| 213 |
+
sorted=True
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
mask = torch.zeros(S, S, dtype=torch.bool, device=device)
|
| 217 |
+
|
| 218 |
+
# ---- 2. Reconstruction tokens (rows) are blocked from attending to relocalization tokens (columns)
|
| 219 |
+
if l_idx.numel() and r_idx.numel():
|
| 220 |
+
mask[l_idx[:, None], r_idx[None, :]] = True
|
| 221 |
+
|
| 222 |
+
# ---- Relocalization tokens are blocked from attending to each other (self-attention kept)
|
| 223 |
+
if r_idx.numel() > 1:
|
| 224 |
+
mask[r_idx[:, None], r_idx[None, :]] = True
|
| 225 |
+
mask[r_idx, r_idx] = False
|
| 226 |
+
|
| 227 |
+
# ---- 3. Add batch & head dimensions
|
| 228 |
+
return mask.unsqueeze(0).unsqueeze(0)
|
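The asymmetric mask built above is easiest to see on a toy case: with two reconstruction tokens (indices 0 and 1) and two relocalization tokens (indices 2 and 3), reconstruction tokens attend only to each other, while each relocalization token attends to the reconstruction tokens and to itself but not to the other relocalization token, so relocalized frames stay independent of one another.

import torch

mask = build_lr_mask(S=4, no_reloc_list=[0, 1])   # True == masked
print(mask[0, 0].int())
# tensor([[0, 0, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 1, 0]])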
sailrecon/heads/dpt_head.py
ADDED
|
@@ -0,0 +1,598 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Inspired by https://github.com/DepthAnything/Depth-Anything-V2
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
from typing import Dict, List, Tuple, Union
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
|
| 18 |
+
from .head_act import activate_head
|
| 19 |
+
from .utils import create_uv_grid, position_grid_to_embed
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class DPTHead(nn.Module):
|
| 23 |
+
"""
|
| 24 |
+
DPT Head for dense prediction tasks.
|
| 25 |
+
|
| 26 |
+
This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
|
| 27 |
+
(https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
|
| 28 |
+
backbone and produces dense predictions by fusing multi-scale features.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
dim_in (int): Input dimension (channels).
|
| 32 |
+
patch_size (int, optional): Patch size. Default is 14.
|
| 33 |
+
output_dim (int, optional): Number of output channels. Default is 4.
|
| 34 |
+
activation (str, optional): Activation type. Default is "inv_log".
|
| 35 |
+
conf_activation (str, optional): Confidence activation type. Default is "expp1".
|
| 36 |
+
features (int, optional): Feature channels for intermediate representations. Default is 256.
|
| 37 |
+
out_channels (List[int], optional): Output channels for each intermediate layer.
|
| 38 |
+
intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
|
| 39 |
+
pos_embed (bool, optional): Whether to use positional embedding. Default is True.
|
| 40 |
+
feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
|
| 41 |
+
down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def __init__(
|
| 45 |
+
self,
|
| 46 |
+
dim_in: int,
|
| 47 |
+
patch_size: int = 14,
|
| 48 |
+
output_dim: int = 4,
|
| 49 |
+
activation: str = "inv_log",
|
| 50 |
+
conf_activation: str = "expp1",
|
| 51 |
+
features: int = 256,
|
| 52 |
+
out_channels: List[int] = [256, 512, 1024, 1024],
|
| 53 |
+
intermediate_layer_idx: List[int] = [4, 11, 17, 23],
|
| 54 |
+
pos_embed: bool = True,
|
| 55 |
+
feature_only: bool = False,
|
| 56 |
+
down_ratio: int = 1,
|
| 57 |
+
) -> None:
|
| 58 |
+
super(DPTHead, self).__init__()
|
| 59 |
+
self.patch_size = patch_size
|
| 60 |
+
self.activation = activation
|
| 61 |
+
self.conf_activation = conf_activation
|
| 62 |
+
self.pos_embed = pos_embed
|
| 63 |
+
self.feature_only = feature_only
|
| 64 |
+
self.down_ratio = down_ratio
|
| 65 |
+
self.intermediate_layer_idx = intermediate_layer_idx
|
| 66 |
+
|
| 67 |
+
self.norm = nn.LayerNorm(dim_in)
|
| 68 |
+
|
| 69 |
+
# Projection layers for each output channel from tokens.
|
| 70 |
+
self.projects = nn.ModuleList(
|
| 71 |
+
[
|
| 72 |
+
nn.Conv2d(
|
| 73 |
+
in_channels=dim_in,
|
| 74 |
+
out_channels=oc,
|
| 75 |
+
kernel_size=1,
|
| 76 |
+
stride=1,
|
| 77 |
+
padding=0,
|
| 78 |
+
)
|
| 79 |
+
for oc in out_channels
|
| 80 |
+
]
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
# Resize layers for upsampling feature maps.
|
| 84 |
+
self.resize_layers = nn.ModuleList(
|
| 85 |
+
[
|
| 86 |
+
nn.ConvTranspose2d(
|
| 87 |
+
in_channels=out_channels[0],
|
| 88 |
+
out_channels=out_channels[0],
|
| 89 |
+
kernel_size=4,
|
| 90 |
+
stride=4,
|
| 91 |
+
padding=0,
|
| 92 |
+
),
|
| 93 |
+
nn.ConvTranspose2d(
|
| 94 |
+
in_channels=out_channels[1],
|
| 95 |
+
out_channels=out_channels[1],
|
| 96 |
+
kernel_size=2,
|
| 97 |
+
stride=2,
|
| 98 |
+
padding=0,
|
| 99 |
+
),
|
| 100 |
+
nn.Identity(),
|
| 101 |
+
nn.Conv2d(
|
| 102 |
+
in_channels=out_channels[3],
|
| 103 |
+
out_channels=out_channels[3],
|
| 104 |
+
kernel_size=3,
|
| 105 |
+
stride=2,
|
| 106 |
+
padding=1,
|
| 107 |
+
),
|
| 108 |
+
]
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
self.scratch = _make_scratch(out_channels, features, expand=False)
|
| 112 |
+
|
| 113 |
+
# Attach additional modules to scratch.
|
| 114 |
+
self.scratch.stem_transpose = None
|
| 115 |
+
self.scratch.refinenet1 = _make_fusion_block(features)
|
| 116 |
+
self.scratch.refinenet2 = _make_fusion_block(features)
|
| 117 |
+
self.scratch.refinenet3 = _make_fusion_block(features)
|
| 118 |
+
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
|
| 119 |
+
|
| 120 |
+
head_features_1 = features
|
| 121 |
+
head_features_2 = 32
|
| 122 |
+
|
| 123 |
+
if feature_only:
|
| 124 |
+
self.scratch.output_conv1 = nn.Conv2d(
|
| 125 |
+
head_features_1, head_features_1, kernel_size=3, stride=1, padding=1
|
| 126 |
+
)
|
| 127 |
+
else:
|
| 128 |
+
self.scratch.output_conv1 = nn.Conv2d(
|
| 129 |
+
head_features_1,
|
| 130 |
+
head_features_1 // 2,
|
| 131 |
+
kernel_size=3,
|
| 132 |
+
stride=1,
|
| 133 |
+
padding=1,
|
| 134 |
+
)
|
| 135 |
+
conv2_in_channels = head_features_1 // 2
|
| 136 |
+
|
| 137 |
+
self.scratch.output_conv2 = nn.Sequential(
|
| 138 |
+
nn.Conv2d(
|
| 139 |
+
conv2_in_channels,
|
| 140 |
+
head_features_2,
|
| 141 |
+
kernel_size=3,
|
| 142 |
+
stride=1,
|
| 143 |
+
padding=1,
|
| 144 |
+
),
|
| 145 |
+
nn.ReLU(inplace=True),
|
| 146 |
+
nn.Conv2d(
|
| 147 |
+
head_features_2, output_dim, kernel_size=1, stride=1, padding=0
|
| 148 |
+
),
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
def forward(
|
| 152 |
+
self,
|
| 153 |
+
aggregated_tokens_list: List[torch.Tensor],
|
| 154 |
+
images: torch.Tensor,
|
| 155 |
+
patch_start_idx: int,
|
| 156 |
+
frames_chunk_size: int = 8,
|
| 157 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 158 |
+
"""
|
| 159 |
+
Forward pass through the DPT head; optionally processes the frames in chunks.
|
| 160 |
+
Args:
|
| 161 |
+
aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
|
| 162 |
+
images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
|
| 163 |
+
patch_start_idx (int): Starting index for patch tokens in the token sequence.
|
| 164 |
+
Used to separate patch tokens from other tokens (e.g., camera or register tokens).
|
| 165 |
+
frames_chunk_size (int, optional): Number of frames to process in each chunk.
|
| 166 |
+
If None or larger than S, all frames are processed at once. Default: 8.
|
| 167 |
+
|
| 168 |
+
Returns:
|
| 169 |
+
Tensor or Tuple[Tensor, Tensor]:
|
| 170 |
+
- If feature_only=True: Feature maps with shape [B, S, C, H, W]
|
| 171 |
+
- Otherwise: Tuple of (predictions, confidence) with shapes [B, S, H, W, output_dim - 1] and [B, S, H, W] at the head's output resolution
|
| 172 |
+
"""
|
| 173 |
+
B, S, _, H, W = images.shape
|
| 174 |
+
|
| 175 |
+
# If frames_chunk_size is not specified or greater than S, process all frames at once
|
| 176 |
+
if frames_chunk_size is None or frames_chunk_size >= S:
|
| 177 |
+
return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
|
| 178 |
+
|
| 179 |
+
# Otherwise, process frames in chunks to manage memory usage
|
| 180 |
+
assert frames_chunk_size > 0
|
| 181 |
+
|
| 182 |
+
# Process frames in batches
|
| 183 |
+
all_preds = []
|
| 184 |
+
all_conf = []
|
| 185 |
+
|
| 186 |
+
for frames_start_idx in range(0, S, frames_chunk_size):
|
| 187 |
+
frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
|
| 188 |
+
|
| 189 |
+
# Process batch of frames
|
| 190 |
+
if self.feature_only:
|
| 191 |
+
chunk_output = self._forward_impl(
|
| 192 |
+
aggregated_tokens_list,
|
| 193 |
+
images,
|
| 194 |
+
patch_start_idx,
|
| 195 |
+
frames_start_idx,
|
| 196 |
+
frames_end_idx,
|
| 197 |
+
)
|
| 198 |
+
all_preds.append(chunk_output)
|
| 199 |
+
else:
|
| 200 |
+
chunk_preds, chunk_conf = self._forward_impl(
|
| 201 |
+
aggregated_tokens_list,
|
| 202 |
+
images,
|
| 203 |
+
patch_start_idx,
|
| 204 |
+
frames_start_idx,
|
| 205 |
+
frames_end_idx,
|
| 206 |
+
)
|
| 207 |
+
all_preds.append(chunk_preds)
|
| 208 |
+
all_conf.append(chunk_conf)
|
| 209 |
+
|
| 210 |
+
# Concatenate results along the sequence dimension
|
| 211 |
+
if self.feature_only:
|
| 212 |
+
return torch.cat(all_preds, dim=1)
|
| 213 |
+
else:
|
| 214 |
+
return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
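The chunking above is a generic memory-saving pattern: run the head on a slice of the frame sequence, collect the per-chunk outputs, and concatenate along the frame axis. A stand-alone sketch of the same control flow, with a toy function in place of _forward_impl:

import torch

def process_chunk(frames):                 # stand-in for _forward_impl
    return frames.mean(dim=2, keepdim=True), frames.std(dim=2, keepdim=True)

images = torch.randn(1, 10, 3, 32, 32)     # B, S, C, H, W
chunk_size = 4
preds, confs = [], []
for start in range(0, images.shape[1], chunk_size):
    p, c = process_chunk(images[:, start:start + chunk_size])
    preds.append(p)
    confs.append(c)
preds = torch.cat(preds, dim=1)            # concatenate along the frame (S) dimension
confs = torch.cat(confs, dim=1)
print(preds.shape, confs.shape)            # torch.Size([1, 10, 1, 32, 32]) twice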
|
| 215 |
+
|
| 216 |
+
def _forward_impl(
|
| 217 |
+
self,
|
| 218 |
+
aggregated_tokens_list: List[torch.Tensor],
|
| 219 |
+
images: torch.Tensor,
|
| 220 |
+
patch_start_idx: int,
|
| 221 |
+
frames_start_idx: int = None,
|
| 222 |
+
frames_end_idx: int = None,
|
| 223 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 224 |
+
"""
|
| 225 |
+
Implementation of the forward pass through the DPT head.
|
| 226 |
+
|
| 227 |
+
This method processes a specific chunk of frames from the sequence.
|
| 228 |
+
|
| 229 |
+
Args:
|
| 230 |
+
aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
|
| 231 |
+
images (Tensor): Input images with shape [B, S, 3, H, W].
|
| 232 |
+
patch_start_idx (int): Starting index for patch tokens.
|
| 233 |
+
frames_start_idx (int, optional): Starting index for frames to process.
|
| 234 |
+
frames_end_idx (int, optional): Ending index for frames to process.
|
| 235 |
+
|
| 236 |
+
Returns:
|
| 237 |
+
Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
|
| 238 |
+
"""
|
| 239 |
+
if frames_start_idx is not None and frames_end_idx is not None:
|
| 240 |
+
images = images[:, frames_start_idx:frames_end_idx].contiguous()
|
| 241 |
+
|
| 242 |
+
B, S, _, H, W = images.shape
|
| 243 |
+
|
| 244 |
+
patch_h, patch_w = H // self.patch_size, W // self.patch_size
|
| 245 |
+
|
| 246 |
+
out = []
|
| 247 |
+
dpt_idx = 0
|
| 248 |
+
|
| 249 |
+
for layer_idx in self.intermediate_layer_idx:
|
| 250 |
+
x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
|
| 251 |
+
|
| 252 |
+
# Select frames if processing a chunk
|
| 253 |
+
if frames_start_idx is not None and frames_end_idx is not None:
|
| 254 |
+
x = x[:, frames_start_idx:frames_end_idx]
|
| 255 |
+
|
| 256 |
+
x = x.reshape(B * S, -1, x.shape[-1])
|
| 257 |
+
|
| 258 |
+
x = self.norm(x)
|
| 259 |
+
|
| 260 |
+
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
|
| 261 |
+
|
| 262 |
+
x = self.projects[dpt_idx](x)
|
| 263 |
+
if self.pos_embed:
|
| 264 |
+
x = self._apply_pos_embed(x, W, H)
|
| 265 |
+
x = self.resize_layers[dpt_idx](x)
|
| 266 |
+
|
| 267 |
+
out.append(x)
|
| 268 |
+
dpt_idx += 1
|
| 269 |
+
|
| 270 |
+
# Fuse features from multiple layers.
|
| 271 |
+
out = self.scratch_forward(out)
|
| 272 |
+
# Interpolate fused output to match target image resolution.
|
| 273 |
+
out = custom_interpolate(
|
| 274 |
+
out,
|
| 275 |
+
(
|
| 276 |
+
int(patch_h * self.patch_size / self.down_ratio),
|
| 277 |
+
int(patch_w * self.patch_size / self.down_ratio),
|
| 278 |
+
),
|
| 279 |
+
mode="bilinear",
|
| 280 |
+
align_corners=True,
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
if self.pos_embed:
|
| 284 |
+
out = self._apply_pos_embed(out, W, H)
|
| 285 |
+
|
| 286 |
+
if self.feature_only:
|
| 287 |
+
return out.view(B, S, *out.shape[1:])
|
| 288 |
+
|
| 289 |
+
out = self.scratch.output_conv2(out)
|
| 290 |
+
preds, conf = activate_head(
|
| 291 |
+
out, activation=self.activation, conf_activation=self.conf_activation
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
preds = preds.view(B, S, *preds.shape[1:])
|
| 295 |
+
conf = conf.view(B, S, *conf.shape[1:])
|
| 296 |
+
return preds, conf
|
| 297 |
+
|
| 298 |
+
def _apply_pos_embed(
|
| 299 |
+
self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1
|
| 300 |
+
) -> torch.Tensor:
|
| 301 |
+
"""
|
| 302 |
+
Apply positional embedding to tensor x.
|
| 303 |
+
"""
|
| 304 |
+
patch_w = x.shape[-1]
|
| 305 |
+
patch_h = x.shape[-2]
|
| 306 |
+
pos_embed = create_uv_grid(
|
| 307 |
+
patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device
|
| 308 |
+
)
|
| 309 |
+
pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
|
| 310 |
+
pos_embed = pos_embed * ratio
|
| 311 |
+
pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
|
| 312 |
+
return x + pos_embed
|
| 313 |
+
|
| 314 |
+
def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
|
| 315 |
+
"""
|
| 316 |
+
Forward pass through the fusion blocks.
|
| 317 |
+
|
| 318 |
+
Args:
|
| 319 |
+
features (List[Tensor]): List of feature maps from different layers.
|
| 320 |
+
|
| 321 |
+
Returns:
|
| 322 |
+
Tensor: Fused feature map.
|
| 323 |
+
"""
|
| 324 |
+
layer_1, layer_2, layer_3, layer_4 = features
|
| 325 |
+
|
| 326 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
| 327 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
| 328 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
| 329 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
| 330 |
+
|
| 331 |
+
out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
|
| 332 |
+
del layer_4_rn, layer_4
|
| 333 |
+
|
| 334 |
+
out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
|
| 335 |
+
del layer_3_rn, layer_3
|
| 336 |
+
|
| 337 |
+
out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
|
| 338 |
+
del layer_2_rn, layer_2
|
| 339 |
+
|
| 340 |
+
out = self.scratch.refinenet1(out, layer_1_rn)
|
| 341 |
+
del layer_1_rn, layer_1
|
| 342 |
+
|
| 343 |
+
out = self.scratch.output_conv1(out)
|
| 344 |
+
return out
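A rough usage sketch for DPTHead. Every size below is hypothetical (token dimension, number of aggregated layers, number of special tokens, image resolution); it only illustrates the calling convention and the channel-last point map / confidence outputs.

import torch

B, S, H, W = 1, 2, 98, 98                   # image size must be a multiple of patch_size (14)
dim_in, patch_start_idx = 2048, 5           # hypothetical token dim and special-token count
num_patches = (H // 14) * (W // 14)

head = DPTHead(dim_in=dim_in, output_dim=4, features=64, out_channels=[32, 64, 128, 128])
tokens = [torch.randn(B, S, patch_start_idx + num_patches, dim_in) for _ in range(24)]
images = torch.rand(B, S, 3, H, W)

preds, conf = head(tokens, images, patch_start_idx=patch_start_idx)
print(preds.shape, conf.shape)              # (B, S, H, W, 3) points, (B, S, H, W) confidence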
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
################################################################################
|
| 348 |
+
# Modules
|
| 349 |
+
################################################################################
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def _make_fusion_block(
|
| 353 |
+
features: int, size: int = None, has_residual: bool = True, groups: int = 1
|
| 354 |
+
) -> nn.Module:
|
| 355 |
+
return FeatureFusionBlock(
|
| 356 |
+
features,
|
| 357 |
+
nn.ReLU(inplace=True),
|
| 358 |
+
deconv=False,
|
| 359 |
+
bn=False,
|
| 360 |
+
expand=False,
|
| 361 |
+
align_corners=True,
|
| 362 |
+
size=size,
|
| 363 |
+
has_residual=has_residual,
|
| 364 |
+
groups=groups,
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def _make_scratch(
|
| 369 |
+
in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False
|
| 370 |
+
) -> nn.Module:
|
| 371 |
+
scratch = nn.Module()
|
| 372 |
+
out_shape1 = out_shape
|
| 373 |
+
out_shape2 = out_shape
|
| 374 |
+
out_shape3 = out_shape
|
| 375 |
+
if len(in_shape) >= 4:
|
| 376 |
+
out_shape4 = out_shape
|
| 377 |
+
|
| 378 |
+
if expand:
|
| 379 |
+
out_shape1 = out_shape
|
| 380 |
+
out_shape2 = out_shape * 2
|
| 381 |
+
out_shape3 = out_shape * 4
|
| 382 |
+
if len(in_shape) >= 4:
|
| 383 |
+
out_shape4 = out_shape * 8
|
| 384 |
+
|
| 385 |
+
scratch.layer1_rn = nn.Conv2d(
|
| 386 |
+
in_shape[0],
|
| 387 |
+
out_shape1,
|
| 388 |
+
kernel_size=3,
|
| 389 |
+
stride=1,
|
| 390 |
+
padding=1,
|
| 391 |
+
bias=False,
|
| 392 |
+
groups=groups,
|
| 393 |
+
)
|
| 394 |
+
scratch.layer2_rn = nn.Conv2d(
|
| 395 |
+
in_shape[1],
|
| 396 |
+
out_shape2,
|
| 397 |
+
kernel_size=3,
|
| 398 |
+
stride=1,
|
| 399 |
+
padding=1,
|
| 400 |
+
bias=False,
|
| 401 |
+
groups=groups,
|
| 402 |
+
)
|
| 403 |
+
scratch.layer3_rn = nn.Conv2d(
|
| 404 |
+
in_shape[2],
|
| 405 |
+
out_shape3,
|
| 406 |
+
kernel_size=3,
|
| 407 |
+
stride=1,
|
| 408 |
+
padding=1,
|
| 409 |
+
bias=False,
|
| 410 |
+
groups=groups,
|
| 411 |
+
)
|
| 412 |
+
if len(in_shape) >= 4:
|
| 413 |
+
scratch.layer4_rn = nn.Conv2d(
|
| 414 |
+
in_shape[3],
|
| 415 |
+
out_shape4,
|
| 416 |
+
kernel_size=3,
|
| 417 |
+
stride=1,
|
| 418 |
+
padding=1,
|
| 419 |
+
bias=False,
|
| 420 |
+
groups=groups,
|
| 421 |
+
)
|
| 422 |
+
return scratch
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
class ResidualConvUnit(nn.Module):
|
| 426 |
+
"""Residual convolution module."""
|
| 427 |
+
|
| 428 |
+
def __init__(self, features, activation, bn, groups=1):
|
| 429 |
+
"""Init.
|
| 430 |
+
|
| 431 |
+
Args:
|
| 432 |
+
features (int): number of features
|
| 433 |
+
"""
|
| 434 |
+
super().__init__()
|
| 435 |
+
|
| 436 |
+
self.bn = bn
|
| 437 |
+
self.groups = groups
|
| 438 |
+
self.conv1 = nn.Conv2d(
|
| 439 |
+
features,
|
| 440 |
+
features,
|
| 441 |
+
kernel_size=3,
|
| 442 |
+
stride=1,
|
| 443 |
+
padding=1,
|
| 444 |
+
bias=True,
|
| 445 |
+
groups=self.groups,
|
| 446 |
+
)
|
| 447 |
+
self.conv2 = nn.Conv2d(
|
| 448 |
+
features,
|
| 449 |
+
features,
|
| 450 |
+
kernel_size=3,
|
| 451 |
+
stride=1,
|
| 452 |
+
padding=1,
|
| 453 |
+
bias=True,
|
| 454 |
+
groups=self.groups,
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
self.norm1 = None
|
| 458 |
+
self.norm2 = None
|
| 459 |
+
|
| 460 |
+
self.activation = activation
|
| 461 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
| 462 |
+
|
| 463 |
+
def forward(self, x):
|
| 464 |
+
"""Forward pass.
|
| 465 |
+
|
| 466 |
+
Args:
|
| 467 |
+
x (tensor): input
|
| 468 |
+
|
| 469 |
+
Returns:
|
| 470 |
+
tensor: output
|
| 471 |
+
"""
|
| 472 |
+
|
| 473 |
+
out = self.activation(x)
|
| 474 |
+
out = self.conv1(out)
|
| 475 |
+
if self.norm1 is not None:
|
| 476 |
+
out = self.norm1(out)
|
| 477 |
+
|
| 478 |
+
out = self.activation(out)
|
| 479 |
+
out = self.conv2(out)
|
| 480 |
+
if self.norm2 is not None:
|
| 481 |
+
out = self.norm2(out)
|
| 482 |
+
|
| 483 |
+
return self.skip_add.add(out, x)
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
class FeatureFusionBlock(nn.Module):
|
| 487 |
+
"""Feature fusion block."""
|
| 488 |
+
|
| 489 |
+
def __init__(
|
| 490 |
+
self,
|
| 491 |
+
features,
|
| 492 |
+
activation,
|
| 493 |
+
deconv=False,
|
| 494 |
+
bn=False,
|
| 495 |
+
expand=False,
|
| 496 |
+
align_corners=True,
|
| 497 |
+
size=None,
|
| 498 |
+
has_residual=True,
|
| 499 |
+
groups=1,
|
| 500 |
+
):
|
| 501 |
+
"""Init.
|
| 502 |
+
|
| 503 |
+
Args:
|
| 504 |
+
features (int): number of features
|
| 505 |
+
"""
|
| 506 |
+
super(FeatureFusionBlock, self).__init__()
|
| 507 |
+
|
| 508 |
+
self.deconv = deconv
|
| 509 |
+
self.align_corners = align_corners
|
| 510 |
+
self.groups = groups
|
| 511 |
+
self.expand = expand
|
| 512 |
+
out_features = features
|
| 513 |
+
if self.expand:
|
| 514 |
+
out_features = features // 2
|
| 515 |
+
|
| 516 |
+
self.out_conv = nn.Conv2d(
|
| 517 |
+
features,
|
| 518 |
+
out_features,
|
| 519 |
+
kernel_size=1,
|
| 520 |
+
stride=1,
|
| 521 |
+
padding=0,
|
| 522 |
+
bias=True,
|
| 523 |
+
groups=self.groups,
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
if has_residual:
|
| 527 |
+
self.resConfUnit1 = ResidualConvUnit(
|
| 528 |
+
features, activation, bn, groups=self.groups
|
| 529 |
+
)
|
| 530 |
+
|
| 531 |
+
self.has_residual = has_residual
|
| 532 |
+
self.resConfUnit2 = ResidualConvUnit(
|
| 533 |
+
features, activation, bn, groups=self.groups
|
| 534 |
+
)
|
| 535 |
+
|
| 536 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
| 537 |
+
self.size = size
|
| 538 |
+
|
| 539 |
+
def forward(self, *xs, size=None):
|
| 540 |
+
"""Forward pass.
|
| 541 |
+
|
| 542 |
+
Returns:
|
| 543 |
+
tensor: output
|
| 544 |
+
"""
|
| 545 |
+
output = xs[0]
|
| 546 |
+
|
| 547 |
+
if self.has_residual:
|
| 548 |
+
res = self.resConfUnit1(xs[1])
|
| 549 |
+
output = self.skip_add.add(output, res)
|
| 550 |
+
|
| 551 |
+
output = self.resConfUnit2(output)
|
| 552 |
+
|
| 553 |
+
if (size is None) and (self.size is None):
|
| 554 |
+
modifier = {"scale_factor": 2}
|
| 555 |
+
elif size is None:
|
| 556 |
+
modifier = {"size": self.size}
|
| 557 |
+
else:
|
| 558 |
+
modifier = {"size": size}
|
| 559 |
+
|
| 560 |
+
output = custom_interpolate(
|
| 561 |
+
output, **modifier, mode="bilinear", align_corners=self.align_corners
|
| 562 |
+
)
|
| 563 |
+
output = self.out_conv(output)
|
| 564 |
+
|
| 565 |
+
return output
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
def custom_interpolate(
|
| 569 |
+
x: torch.Tensor,
|
| 570 |
+
size: Tuple[int, int] = None,
|
| 571 |
+
scale_factor: float = None,
|
| 572 |
+
mode: str = "bilinear",
|
| 573 |
+
align_corners: bool = True,
|
| 574 |
+
) -> torch.Tensor:
|
| 575 |
+
"""
|
| 576 |
+
Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
|
| 577 |
+
"""
|
| 578 |
+
if size is None:
|
| 579 |
+
size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
|
| 580 |
+
|
| 581 |
+
INT_MAX = 1610612736
|
| 582 |
+
|
| 583 |
+
input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
|
| 584 |
+
|
| 585 |
+
if input_elements > INT_MAX:
|
| 586 |
+
chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
|
| 587 |
+
interpolated_chunks = [
|
| 588 |
+
nn.functional.interpolate(
|
| 589 |
+
chunk, size=size, mode=mode, align_corners=align_corners
|
| 590 |
+
)
|
| 591 |
+
for chunk in chunks
|
| 592 |
+
]
|
| 593 |
+
x = torch.cat(interpolated_chunks, dim=0)
|
| 594 |
+
return x.contiguous()
|
| 595 |
+
else:
|
| 596 |
+
return nn.functional.interpolate(
|
| 597 |
+
x, size=size, mode=mode, align_corners=align_corners
|
| 598 |
+
)
|
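custom_interpolate only deviates from nn.functional.interpolate when the output would exceed the INT_MAX element budget, in which case it splits the batch, interpolates each chunk, and concatenates the results. A small sanity check with shapes far below that threshold, where both paths must agree:

import torch
import torch.nn.functional as F

x = torch.randn(4, 8, 16, 16)
a = custom_interpolate(x, size=(32, 32), mode="bilinear", align_corners=True)
b = F.interpolate(x, size=(32, 32), mode="bilinear", align_corners=True)
print(a.shape, torch.allclose(a, b))        # torch.Size([4, 8, 32, 32]) True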
sailrecon/heads/head_act.py
ADDED
|
@@ -0,0 +1,127 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def activate_pose(
|
| 13 |
+
pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"
|
| 14 |
+
):
|
| 15 |
+
"""
|
| 16 |
+
Activate pose parameters with specified activation functions.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
|
| 20 |
+
trans_act: Activation type for translation component
|
| 21 |
+
quat_act: Activation type for quaternion component
|
| 22 |
+
fl_act: Activation type for focal length component
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
Activated pose parameters tensor
|
| 26 |
+
"""
|
| 27 |
+
T = pred_pose_enc[..., :3]
|
| 28 |
+
quat = pred_pose_enc[..., 3:7]
|
| 29 |
+
fl = pred_pose_enc[..., 7:] # or fov
|
| 30 |
+
|
| 31 |
+
T = base_pose_act(T, trans_act)
|
| 32 |
+
quat = base_pose_act(quat, quat_act)
|
| 33 |
+
fl = base_pose_act(fl, fl_act) # or fov
|
| 34 |
+
|
| 35 |
+
pred_pose_enc = torch.cat([T, quat, fl], dim=-1)
|
| 36 |
+
|
| 37 |
+
return pred_pose_enc
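A quick illustration of the expected layout: the last dimension is split into 3 translation values, 4 quaternion values, and whatever remains for focal length / field of view. The activation choices below are only examples, not necessarily the ones the camera head is configured with.

import torch

pose_enc = torch.randn(2, 9)                # 3 translation + 4 quaternion + 2 focal/fov terms
out = activate_pose(pose_enc, trans_act="inv_log", quat_act="linear", fl_act="relu")
print(out.shape)                            # torch.Size([2, 9])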
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def base_pose_act(pose_enc, act_type="linear"):
|
| 41 |
+
"""
|
| 42 |
+
Apply basic activation function to pose parameters.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
pose_enc: Tensor containing encoded pose parameters
|
| 46 |
+
act_type: Activation type ("linear", "inv_log", "exp", "relu")
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
Activated pose parameters
|
| 50 |
+
"""
|
| 51 |
+
if act_type == "linear":
|
| 52 |
+
return pose_enc
|
| 53 |
+
elif act_type == "inv_log":
|
| 54 |
+
return inverse_log_transform(pose_enc)
|
| 55 |
+
elif act_type == "exp":
|
| 56 |
+
return torch.exp(pose_enc)
|
| 57 |
+
elif act_type == "relu":
|
| 58 |
+
return F.relu(pose_enc)
|
| 59 |
+
else:
|
| 60 |
+
raise ValueError(f"Unknown act_type: {act_type}")
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def activate_head(out, activation="norm_exp", conf_activation="expp1"):
|
| 64 |
+
"""
|
| 65 |
+
Process network output to extract 3D points and confidence values.
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
out: Network output tensor (B, C, H, W)
|
| 69 |
+
activation: Activation type for 3D points
|
| 70 |
+
conf_activation: Activation type for confidence values
|
| 71 |
+
|
| 72 |
+
Returns:
|
| 73 |
+
Tuple of (3D points tensor, confidence tensor)
|
| 74 |
+
"""
|
| 75 |
+
# Move channels from dim 1 to the last dimension => (B, H, W, C)
|
| 76 |
+
fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected
|
| 77 |
+
|
| 78 |
+
# Split into xyz (first C-1 channels) and confidence (last channel)
|
| 79 |
+
xyz = fmap[:, :, :, :-1]
|
| 80 |
+
conf = fmap[:, :, :, -1]
|
| 81 |
+
|
| 82 |
+
if activation == "norm_exp":
|
| 83 |
+
d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
|
| 84 |
+
xyz_normed = xyz / d
|
| 85 |
+
pts3d = xyz_normed * torch.expm1(d)
|
| 86 |
+
elif activation == "norm":
|
| 87 |
+
pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
|
| 88 |
+
elif activation == "exp":
|
| 89 |
+
pts3d = torch.exp(xyz)
|
| 90 |
+
elif activation == "relu":
|
| 91 |
+
pts3d = F.relu(xyz)
|
| 92 |
+
elif activation == "inv_log":
|
| 93 |
+
pts3d = inverse_log_transform(xyz)
|
| 94 |
+
elif activation == "xy_inv_log":
|
| 95 |
+
xy, z = xyz.split([2, 1], dim=-1)
|
| 96 |
+
z = inverse_log_transform(z)
|
| 97 |
+
pts3d = torch.cat([xy * z, z], dim=-1)
|
| 98 |
+
elif activation == "sigmoid":
|
| 99 |
+
pts3d = torch.sigmoid(xyz)
|
| 100 |
+
elif activation == "linear":
|
| 101 |
+
pts3d = xyz
|
| 102 |
+
else:
|
| 103 |
+
raise ValueError(f"Unknown activation: {activation}")
|
| 104 |
+
|
| 105 |
+
if conf_activation == "expp1":
|
| 106 |
+
conf_out = 1 + conf.exp()
|
| 107 |
+
elif conf_activation == "expp0":
|
| 108 |
+
conf_out = conf.exp()
|
| 109 |
+
elif conf_activation == "sigmoid":
|
| 110 |
+
conf_out = torch.sigmoid(conf)
|
| 111 |
+
else:
|
| 112 |
+
raise ValueError(f"Unknown conf_activation: {conf_activation}")
|
| 113 |
+
|
| 114 |
+
return pts3d, conf_out
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def inverse_log_transform(y):
|
| 118 |
+
"""
|
| 119 |
+
Apply inverse log transform: sign(y) * (exp(|y|) - 1)
|
| 120 |
+
|
| 121 |
+
Args:
|
| 122 |
+
y: Input tensor
|
| 123 |
+
|
| 124 |
+
Returns:
|
| 125 |
+
Transformed tensor
|
| 126 |
+
"""
|
| 127 |
+
return torch.sign(y) * (torch.expm1(torch.abs(y)))
|
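Two quick checks of the functions in this file: inverse_log_transform undoes a sign-preserving log1p transform, and activate_head splits a (B, C, H, W) output map into a channel-last point map plus a confidence map.

import torch

x = torch.randn(5)
y = torch.sign(x) * torch.log1p(torch.abs(x))        # forward log transform
print(torch.allclose(inverse_log_transform(y), x))   # True

out = torch.randn(2, 4, 8, 8)                        # B, C, H, W network output
pts3d, conf = activate_head(out, activation="norm_exp", conf_activation="expp1")
print(pts3d.shape, conf.shape)                       # torch.Size([2, 8, 8, 3]) torch.Size([2, 8, 8])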
sailrecon/heads/track_head.py
ADDED
|
@@ -0,0 +1,116 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
|
| 9 |
+
from .dpt_head import DPTHead
|
| 10 |
+
from .track_modules.base_track_predictor import BaseTrackerPredictor
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class TrackHead(nn.Module):
|
| 14 |
+
"""
|
| 15 |
+
Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking.
|
| 16 |
+
The tracking is performed iteratively, refining predictions over multiple iterations.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
dim_in,
|
| 22 |
+
patch_size=14,
|
| 23 |
+
features=128,
|
| 24 |
+
iters=4,
|
| 25 |
+
predict_conf=True,
|
| 26 |
+
stride=2,
|
| 27 |
+
corr_levels=7,
|
| 28 |
+
corr_radius=4,
|
| 29 |
+
hidden_size=384,
|
| 30 |
+
):
|
| 31 |
+
"""
|
| 32 |
+
Initialize the TrackHead module.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
dim_in (int): Input dimension of tokens from the backbone.
|
| 36 |
+
patch_size (int): Size of image patches used in the vision transformer.
|
| 37 |
+
features (int): Number of feature channels in the feature extractor output.
|
| 38 |
+
iters (int): Number of refinement iterations for tracking predictions.
|
| 39 |
+
predict_conf (bool): Whether to predict confidence scores for tracked points.
|
| 40 |
+
stride (int): Stride value for the tracker predictor.
|
| 41 |
+
corr_levels (int): Number of correlation pyramid levels
|
| 42 |
+
corr_radius (int): Radius for correlation computation, controlling the search area.
|
| 43 |
+
hidden_size (int): Size of hidden layers in the tracker network.
|
| 44 |
+
"""
|
| 45 |
+
super().__init__()
|
| 46 |
+
|
| 47 |
+
self.patch_size = patch_size
|
| 48 |
+
|
| 49 |
+
# Feature extractor based on DPT architecture
|
| 50 |
+
# Processes tokens into feature maps for tracking
|
| 51 |
+
self.feature_extractor = DPTHead(
|
| 52 |
+
dim_in=dim_in,
|
| 53 |
+
patch_size=patch_size,
|
| 54 |
+
features=features,
|
| 55 |
+
feature_only=True, # Only output features, no activation
|
| 56 |
+
down_ratio=2, # Reduces spatial dimensions by factor of 2
|
| 57 |
+
pos_embed=False,
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# Tracker module that predicts point trajectories
|
| 61 |
+
# Takes feature maps and predicts coordinates and visibility
|
| 62 |
+
self.tracker = BaseTrackerPredictor(
|
| 63 |
+
latent_dim=features, # Match the output_dim of feature extractor
|
| 64 |
+
predict_conf=predict_conf,
|
| 65 |
+
stride=stride,
|
| 66 |
+
corr_levels=corr_levels,
|
| 67 |
+
corr_radius=corr_radius,
|
| 68 |
+
hidden_size=hidden_size,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
self.iters = iters
|
| 72 |
+
|
| 73 |
+
def forward(
|
| 74 |
+
self,
|
| 75 |
+
aggregated_tokens_list,
|
| 76 |
+
images,
|
| 77 |
+
patch_start_idx,
|
| 78 |
+
query_points=None,
|
| 79 |
+
iters=None,
|
| 80 |
+
):
|
| 81 |
+
"""
|
| 82 |
+
Forward pass of the TrackHead.
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
aggregated_tokens_list (list): List of aggregated tokens from the backbone.
|
| 86 |
+
images (torch.Tensor): Input images of shape (B, S, C, H, W) where:
|
| 87 |
+
B = batch size, S = sequence length.
|
| 88 |
+
patch_start_idx (int): Starting index for patch tokens.
|
| 89 |
+
query_points (torch.Tensor, optional): Initial query points to track.
|
| 90 |
+
If None, points are initialized by the tracker.
|
| 91 |
+
iters (int, optional): Number of refinement iterations. If None, uses self.iters.
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
tuple:
|
| 95 |
+
- coord_preds (torch.Tensor): Predicted coordinates for tracked points.
|
| 96 |
+
- vis_scores (torch.Tensor): Visibility scores for tracked points.
|
| 97 |
+
- conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True).
|
| 98 |
+
"""
|
| 99 |
+
B, S, _, H, W = images.shape
|
| 100 |
+
|
| 101 |
+
# Extract features from tokens
|
| 102 |
+
# feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2
|
| 103 |
+
feature_maps = self.feature_extractor(
|
| 104 |
+
aggregated_tokens_list, images, patch_start_idx
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
# Use default iterations if not specified
|
| 108 |
+
if iters is None:
|
| 109 |
+
iters = self.iters
|
| 110 |
+
|
| 111 |
+
# Perform tracking using the extracted features
|
| 112 |
+
coord_preds, vis_scores, conf_scores = self.tracker(
|
| 113 |
+
query_points=query_points, fmaps=feature_maps, iters=iters
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
return coord_preds, vis_scores, conf_scores
|
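A note on coordinate bookkeeping: the feature extractor runs at half resolution (down_ratio=2) and the tracker's stride=2 compensates for it, so query points are given and returned in original image pixels while sampling happens on the half-resolution feature maps. A tiny sketch with hypothetical numbers:

import torch

H = W = 224
stride = 2                                        # matches the feature extractor's down_ratio
query_points = torch.tensor([[[112.0, 56.0]]])    # B x N x 2, in original image pixels

fmap_coords = query_points / stride               # where the tracker samples the (H/2, W/2) features
pred_fmap_coords = fmap_coords + 3.0              # stand-in for the tracker's refined estimate
pred_pixels = pred_fmap_coords * stride           # predictions are scaled back to image pixels
print(fmap_coords, pred_pixels)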
sailrecon/heads/track_modules/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
sailrecon/heads/track_modules/base_track_predictor.py
ADDED
|
@@ -0,0 +1,242 @@
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from einops import rearrange, repeat
|
| 10 |
+
|
| 11 |
+
from .blocks import CorrBlock, EfficientUpdateFormer
|
| 12 |
+
from .modules import Mlp
|
| 13 |
+
from .utils import get_2d_embedding, get_2d_sincos_pos_embed, sample_features4d
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class BaseTrackerPredictor(nn.Module):
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
stride=1,
|
| 20 |
+
corr_levels=5,
|
| 21 |
+
corr_radius=4,
|
| 22 |
+
latent_dim=128,
|
| 23 |
+
hidden_size=384,
|
| 24 |
+
use_spaceatt=True,
|
| 25 |
+
depth=6,
|
| 26 |
+
max_scale=518,
|
| 27 |
+
predict_conf=True,
|
| 28 |
+
):
|
| 29 |
+
super(BaseTrackerPredictor, self).__init__()
|
| 30 |
+
"""
|
| 31 |
+
The base template to create a track predictor
|
| 32 |
+
|
| 33 |
+
Modified from https://github.com/facebookresearch/co-tracker/
|
| 34 |
+
and https://github.com/facebookresearch/vggsfm
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
self.stride = stride
|
| 38 |
+
self.latent_dim = latent_dim
|
| 39 |
+
self.corr_levels = corr_levels
|
| 40 |
+
self.corr_radius = corr_radius
|
| 41 |
+
self.hidden_size = hidden_size
|
| 42 |
+
self.max_scale = max_scale
|
| 43 |
+
self.predict_conf = predict_conf
|
| 44 |
+
|
| 45 |
+
self.flows_emb_dim = latent_dim // 2
|
| 46 |
+
|
| 47 |
+
self.corr_mlp = Mlp(
|
| 48 |
+
in_features=self.corr_levels * (self.corr_radius * 2 + 1) ** 2,
|
| 49 |
+
hidden_features=self.hidden_size,
|
| 50 |
+
out_features=self.latent_dim,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
self.transformer_dim = self.latent_dim + self.latent_dim + self.latent_dim + 4
|
| 54 |
+
|
| 55 |
+
self.query_ref_token = nn.Parameter(torch.randn(1, 2, self.transformer_dim))
|
| 56 |
+
|
| 57 |
+
space_depth = depth if use_spaceatt else 0
|
| 58 |
+
time_depth = depth
|
| 59 |
+
|
| 60 |
+
self.updateformer = EfficientUpdateFormer(
|
| 61 |
+
space_depth=space_depth,
|
| 62 |
+
time_depth=time_depth,
|
| 63 |
+
input_dim=self.transformer_dim,
|
| 64 |
+
hidden_size=self.hidden_size,
|
| 65 |
+
output_dim=self.latent_dim + 2,
|
| 66 |
+
mlp_ratio=4.0,
|
| 67 |
+
add_space_attn=use_spaceatt,
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
self.fmap_norm = nn.LayerNorm(self.latent_dim)
|
| 71 |
+
self.ffeat_norm = nn.GroupNorm(1, self.latent_dim)
|
| 72 |
+
|
| 73 |
+
# A linear layer to update track feats at each iteration
|
| 74 |
+
self.ffeat_updater = nn.Sequential(
|
| 75 |
+
nn.Linear(self.latent_dim, self.latent_dim), nn.GELU()
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
self.vis_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
|
| 79 |
+
|
| 80 |
+
if predict_conf:
|
| 81 |
+
self.conf_predictor = nn.Sequential(nn.Linear(self.latent_dim, 1))
|
| 82 |
+
|
| 83 |
+
def forward(
|
| 84 |
+
self,
|
| 85 |
+
query_points,
|
| 86 |
+
fmaps=None,
|
| 87 |
+
iters=6,
|
| 88 |
+
return_feat=False,
|
| 89 |
+
down_ratio=1,
|
| 90 |
+
apply_sigmoid=True,
|
| 91 |
+
):
|
| 92 |
+
"""
|
| 93 |
+
query_points: B x N x 2 (batch size, number of tracks, xy coordinates)
|
| 94 |
+
fmaps: B x S x C x HH x WW (batch size, frames, feature dimension, feature map height and width).
|
| 95 |
+
note that HH and WW are the feature map sizes, not the original image sizes
|
| 96 |
+
"""
|
| 97 |
+
B, N, D = query_points.shape
|
| 98 |
+
B, S, C, HH, WW = fmaps.shape
|
| 99 |
+
|
| 100 |
+
assert D == 2, "Input points must be 2D coordinates"
|
| 101 |
+
|
| 102 |
+
# apply a layernorm to fmaps here
|
| 103 |
+
fmaps = self.fmap_norm(fmaps.permute(0, 1, 3, 4, 2))
|
| 104 |
+
fmaps = fmaps.permute(0, 1, 4, 2, 3)
|
| 105 |
+
|
| 106 |
+
# Scale the input query_points because we may downsample the images
|
| 107 |
+
# by down_ratio or self.stride
|
| 108 |
+
# e.g., if a 3x1024x1024 image is processed to a 128x256x256 feature map
|
| 109 |
+
# its query_points should be query_points/4
|
| 110 |
+
if down_ratio > 1:
|
| 111 |
+
query_points = query_points / float(down_ratio)
|
| 112 |
+
|
| 113 |
+
query_points = query_points / float(self.stride)
|
| 114 |
+
|
| 115 |
+
# Init with coords as the query points
|
| 116 |
+
# It means the search will start from the position of query points at the reference frames
|
| 117 |
+
coords = query_points.clone().reshape(B, 1, N, 2).repeat(1, S, 1, 1)
|
| 118 |
+
|
| 119 |
+
# Sample/extract the features of the query points in the query frame
|
| 120 |
+
query_track_feat = sample_features4d(fmaps[:, 0], coords[:, 0])
|
| 121 |
+
|
| 122 |
+
# init track feats by query feats
|
| 123 |
+
track_feats = query_track_feat.unsqueeze(1).repeat(1, S, 1, 1) # B, S, N, C
|
| 124 |
+
# back up the init coords
|
| 125 |
+
coords_backup = coords.clone()
|
| 126 |
+
|
| 127 |
+
fcorr_fn = CorrBlock(
|
| 128 |
+
fmaps, num_levels=self.corr_levels, radius=self.corr_radius
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
coord_preds = []
|
| 132 |
+
|
| 133 |
+
# Iterative Refinement
|
| 134 |
+
for _ in range(iters):
|
| 135 |
+
# Detach the gradients from the last iteration
|
| 136 |
+
# (in my experience, not very important for performance)
|
| 137 |
+
coords = coords.detach()
|
| 138 |
+
|
| 139 |
+
fcorrs = fcorr_fn.corr_sample(track_feats, coords)
|
| 140 |
+
|
| 141 |
+
corr_dim = fcorrs.shape[3]
|
| 142 |
+
fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, corr_dim)
|
| 143 |
+
fcorrs_ = self.corr_mlp(fcorrs_)
|
| 144 |
+
|
| 145 |
+
# Movement of current coords relative to query points
|
| 146 |
+
flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
|
| 147 |
+
|
| 148 |
+
flows_emb = get_2d_embedding(flows, self.flows_emb_dim, cat_coords=False)
|
| 149 |
+
|
| 150 |
+
# (In my trials, it is also okay to just add the flows_emb instead of concat)
|
| 151 |
+
flows_emb = torch.cat(
|
| 152 |
+
[flows_emb, flows / self.max_scale, flows / self.max_scale], dim=-1
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
track_feats_ = track_feats.permute(0, 2, 1, 3).reshape(
|
| 156 |
+
B * N, S, self.latent_dim
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
# Concatenate them as the input for the transformers
|
| 160 |
+
transformer_input = torch.cat([flows_emb, fcorrs_, track_feats_], dim=2)
|
| 161 |
+
|
| 162 |
+
# 2D positional embed
|
| 163 |
+
# TODO: this can be much simplified
|
| 164 |
+
pos_embed = get_2d_sincos_pos_embed(
|
| 165 |
+
self.transformer_dim, grid_size=(HH, WW)
|
| 166 |
+
).to(query_points.device)
|
| 167 |
+
sampled_pos_emb = sample_features4d(
|
| 168 |
+
pos_embed.expand(B, -1, -1, -1), coords[:, 0]
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
sampled_pos_emb = rearrange(sampled_pos_emb, "b n c -> (b n) c").unsqueeze(
|
| 172 |
+
1
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
x = transformer_input + sampled_pos_emb
|
| 176 |
+
|
| 177 |
+
# Add the query ref token to the track feats
|
| 178 |
+
query_ref_token = torch.cat(
|
| 179 |
+
[
|
| 180 |
+
self.query_ref_token[:, 0:1],
|
| 181 |
+
self.query_ref_token[:, 1:2].expand(-1, S - 1, -1),
|
| 182 |
+
],
|
| 183 |
+
dim=1,
|
| 184 |
+
)
|
| 185 |
+
x = x + query_ref_token.to(x.device).to(x.dtype)
|
| 186 |
+
|
| 187 |
+
# B, N, S, C
|
| 188 |
+
x = rearrange(x, "(b n) s d -> b n s d", b=B)
|
| 189 |
+
|
| 190 |
+
# Compute the delta coordinates and delta track features
|
| 191 |
+
delta, _ = self.updateformer(x)
|
| 192 |
+
|
| 193 |
+
# BN, S, C
|
| 194 |
+
delta = rearrange(delta, " b n s d -> (b n) s d", b=B)
|
| 195 |
+
delta_coords_ = delta[:, :, :2]
|
| 196 |
+
delta_feats_ = delta[:, :, 2:]
|
| 197 |
+
|
| 198 |
+
track_feats_ = track_feats_.reshape(B * N * S, self.latent_dim)
|
| 199 |
+
delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
|
| 200 |
+
|
| 201 |
+
# Update the track features
|
| 202 |
+
track_feats_ = (
|
| 203 |
+
self.ffeat_updater(self.ffeat_norm(delta_feats_)) + track_feats_
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
track_feats = track_feats_.reshape(B, N, S, self.latent_dim).permute(
|
| 207 |
+
0, 2, 1, 3
|
| 208 |
+
) # BxSxNxC
|
| 209 |
+
|
| 210 |
+
# B x S x N x 2
|
| 211 |
+
coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
|
| 212 |
+
|
| 213 |
+
# Force coord0 as query
|
| 214 |
+
# because we assume the query points should not be changed
|
| 215 |
+
coords[:, 0] = coords_backup[:, 0]
|
| 216 |
+
|
| 217 |
+
# The predicted tracks are in the original image scale
|
| 218 |
+
if down_ratio > 1:
|
| 219 |
+
coord_preds.append(coords * self.stride * down_ratio)
|
| 220 |
+
else:
|
| 221 |
+
coord_preds.append(coords * self.stride)
|
| 222 |
+
|
| 223 |
+
# B, S, N
|
| 224 |
+
vis_e = self.vis_predictor(
|
| 225 |
+
track_feats.reshape(B * S * N, self.latent_dim)
|
| 226 |
+
).reshape(B, S, N)
|
| 227 |
+
if apply_sigmoid:
|
| 228 |
+
vis_e = torch.sigmoid(vis_e)
|
| 229 |
+
|
| 230 |
+
if self.predict_conf:
|
| 231 |
+
conf_e = self.conf_predictor(
|
| 232 |
+
track_feats.reshape(B * S * N, self.latent_dim)
|
| 233 |
+
).reshape(B, S, N)
|
| 234 |
+
if apply_sigmoid:
|
| 235 |
+
conf_e = torch.sigmoid(conf_e)
|
| 236 |
+
else:
|
| 237 |
+
conf_e = None
|
| 238 |
+
|
| 239 |
+
if return_feat:
|
| 240 |
+
return coord_preds, vis_e, track_feats, query_track_feat, conf_e
|
| 241 |
+
else:
|
| 242 |
+
return coord_preds, vis_e, conf_e
|
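The refinement loop above follows a common recipe: detach the current estimate, predict a delta from local correlation features, accumulate it, and keep the query-frame coordinates pinned. A stripped-down version of that control flow with a toy delta predictor, purely to make the structure explicit (nothing here is the real tracker):

import torch

B, S, N = 1, 4, 3
target = torch.randn(B, S, N, 2)                     # pretend ground-truth trajectories
coords = target[:, 0:1].repeat(1, S, 1, 1)           # init every frame from the query frame
coords_backup = coords.clone()

for _ in range(6):                                   # iterative refinement
    coords = coords.detach()
    delta = 0.5 * (target - coords)                  # toy "predictor": step halfway to the target
    coords = coords + delta
    coords[:, 0] = coords_backup[:, 0]               # the query frame is never updated
print((coords[:, 1:] - target[:, 1:]).abs().max())   # small residual after a few iterations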