Spaces: Running
init LaRI demo

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitignore +195 -0
- app.py +282 -0
- demo.py +136 -0
- requirements.txt +19 -0
- src/lari/model/__init__.py +2 -0
- src/lari/model/blocks.py +209 -0
- src/lari/model/dinoseg_model.py +153 -0
- src/lari/model/dinov2/__init__.py +6 -0
- src/lari/model/dinov2/hub/__init__.py +4 -0
- src/lari/model/dinov2/hub/backbones.py +156 -0
- src/lari/model/dinov2/hub/utils.py +39 -0
- src/lari/model/dinov2/layers/__init__.py +11 -0
- src/lari/model/dinov2/layers/attention.py +89 -0
- src/lari/model/dinov2/layers/block.py +259 -0
- src/lari/model/dinov2/layers/dino_head.py +58 -0
- src/lari/model/dinov2/layers/drop_path.py +34 -0
- src/lari/model/dinov2/layers/layer_scale.py +27 -0
- src/lari/model/dinov2/layers/mlp.py +40 -0
- src/lari/model/dinov2/layers/patch_embed.py +88 -0
- src/lari/model/dinov2/layers/swiglu_ffn.py +72 -0
- src/lari/model/dinov2/models/__init__.py +43 -0
- src/lari/model/dinov2/models/vision_transformer.py +396 -0
- src/lari/model/dinov2/utils/__init__.py +4 -0
- src/lari/model/dinov2/utils/cluster.py +95 -0
- src/lari/model/dinov2/utils/config.py +72 -0
- src/lari/model/dinov2/utils/dtype.py +37 -0
- src/lari/model/dinov2/utils/param_groups.py +103 -0
- src/lari/model/dinov2/utils/utils.py +95 -0
- src/lari/model/dpt_seg_head.py +158 -0
- src/lari/model/heads.py +104 -0
- src/lari/model/lari_model.py +177 -0
- src/lari/model/utils.py +38 -0
- src/lari/utils/__init__.py +0 -0
- src/lari/utils/geometry_numpy.py +187 -0
- src/lari/utils/geometry_torch.py +221 -0
- src/utils/__init__.py +2 -0
- src/utils/vis.py +105 -0
- src/utils3d/README.md +3 -0
- src/utils3d/__init__.py +20 -0
- src/utils3d/_helpers.py +35 -0
- src/utils3d/_unified/__init__.py +934 -0
- src/utils3d/_unified/__init__.pyi +0 -0
- src/utils3d/io/__init__.py +3 -0
- src/utils3d/io/colmap.py +139 -0
- src/utils3d/io/obj.py +146 -0
- src/utils3d/io/ply.py +104 -0
- src/utils3d/numpy/__init__.py +142 -0
- src/utils3d/numpy/_helpers.py +93 -0
- src/utils3d/numpy/mesh.py +355 -0
- src/utils3d/numpy/quadmesh.py +472 -0
.gitignore
ADDED
@@ -0,0 +1,195 @@
scripts/rendering/blender-*

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
.vscode/

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#  For a library or package, you might want to ignore these files since the code is
#  intended to run in multiple environments; otherwise, check them in:
#  .python-version

# pipenv
#  According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#  However, in case of collaboration, if having platform-specific dependencies or dependencies
#  having no cross-platform support, pipenv may install dependencies that don't work, or not
#  install all needed dependencies.
#Pipfile.lock

# poetry
#  Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#  This is especially recommended for binary packages to ensure reproducibility, and is more
#  commonly ignored for libraries.
#  https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#  Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#  pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#  in version control.
#  https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
*.sif
blender-4.2.5-linux-x64*/
*.zip
*.png
*.jpg
*.log
intermediate/
__pycache__
*.ply
*.npy
*.npz
*.obj
*.mtl
*.json.gz
dcgm/
wandb/
# *.json


# Exception to add training list
!lgm_leq20Kpts_simtopo25_train.json.gz
!lgm_leq20Kpts_simtopo25_test.json.gz
!lgm_leq20Kpts_train.json.gz
!lgm_leq20Kpts_train_same_size_wrt_simtopo.json.gz
*.mp4
*.gif
*.glb
!lgm_leq20Kpts_plus_3Kremain_train.json.gz
!lgm_leq20Kpts_test_cleaned.json.gz
!lgm_leq20Kpts_train_cleaned.json.gz
test_metrics.json
app.py
ADDED
@@ -0,0 +1,282 @@
import argparse
import gradio
import torch
import torch.backends.cudnn as cudnn
from src.utils.vis import prob_to_mask
from src.lari.model import LaRIModel, DinoSegModel
from tools import load_model, process_image, post_process_output, get_masked_depth, save_to_glb, get_point_cloud, removebg_crop
from huggingface_hub import hf_hub_download

parser = argparse.ArgumentParser("Arguments for deploying a LaRI Demo")

parser.add_argument(
    "--model_info_pm",
    type=str,
    default="LaRIModel(use_pretrained = 'moge_full', num_output_layer = 5, head_type = 'point')",
    help="Network parameters to load the model",
)

parser.add_argument(
    "--model_info_mask",
    type=str,
    default="DinoSegModel(use_pretrained = 'dinov2', dim_proj = 256, pretrained_path = '', num_output_layer = 4, output_type = 'ray_stop')",
    help="Network parameters to load the model",
)

parser.add_argument(
    "--ckpt_path_pm",
    type=str,
    default="lari_obj_16k_pointmap.pth",
    help="Path to pre-trained weights",
)

parser.add_argument(
    "--ckpt_path_mask",
    type=str,
    default="lari_obj_16k_seg.pth",
    help="Path to pre-trained weights",
)

parser.add_argument(
    "--resolution", type=int, default=512, help="Default model resolution"
)
args = parser.parse_args()


def model_forward(pil_input, layered_id, rembg_checkbox):
    """
    Perform LaRI estimation by:
    1. image processing
    2. network forward
    3. save masked layered depth image
    4. save point cloud
    """
    if pil_input is None:
        # One None per registered output component (see submit_btn.click below).
        return (None, None, None, None, None, None, None, None, None)

    if rembg_checkbox:
        pil_input = removebg_crop(pil_input)

    # Process the input image.
    input_tensor, ori_img_tensor, crop_coords, original_size = process_image(
        pil_input, resolution=args.resolution
    )
    input_tensor = input_tensor.to(device)

    # Run inference.
    with torch.no_grad():
        # LaRI map
        pred_dict = model_pm(input_tensor)
        lari_map = -pred_dict["pts3d"].squeeze(0)  # Expected output shape: (H_reso, W_reso, L, 3)
        # mask
        if model_mask:
            pred_dict = model_mask(input_tensor)
            assert "seg_prob" in pred_dict
            valid_mask = prob_to_mask(pred_dict["seg_prob"].squeeze(0))  # H W L 1
        else:
            h, w, l, _ = lari_map.shape
            # All-ones mask when no segmentation model is given
            # (`new_ones` is a tensor method; there is no `torch.new_ones`).
            valid_mask = lari_map.new_ones((h, w, l, 1), dtype=torch.bool)

    # Crop & resize the output to the original resolution.
    if original_size[0] != args.resolution or original_size[1] != args.resolution:
        lari_map = post_process_output(lari_map, crop_coords, original_size)  # H W L 3
        valid_mask = post_process_output(
            valid_mask.float(), crop_coords, original_size
        ).bool()  # H W L 1

    max_n_layer = min(valid_mask.shape[-2], lari_map.shape[-2])
    valid_mask = valid_mask[:, :, :max_n_layer, :]
    lari_map = lari_map[:, :, :max_n_layer, :]

    curr_layer_id = min(max_n_layer - 1, layered_id - 1)

    # masked depth list
    depth_image = get_masked_depth(
        lari_map=lari_map, valid_mask=valid_mask, layer_id=curr_layer_id
    )
    # point cloud
    glb_path, ply_path = get_point_cloud(
        lari_map, ori_img_tensor, valid_mask, first_layer_color="pseudo"
    )

    return (
        depth_image,
        glb_path,
        lari_map,
        valid_mask,
        0,
        max_n_layer - 1,
        glb_path,
        ply_path,
        pil_input,
    )


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cudnn.benchmark = True


# Download the checkpoint files.
model_path_pm = hf_hub_download(repo_id="ruili3/LaRI", filename=args.ckpt_path_pm, repo_type="model")
model_path_mask = hf_hub_download(repo_id="ruili3/LaRI", filename=args.ckpt_path_mask, repo_type="model")


# Load the models with pretrained weights.
model_pm = load_model(args.model_info_pm, model_path_pm, device)
model_mask = (
    load_model(args.model_info_mask, model_path_mask, device)
    if args.model_info_mask is not None
    else None
)


def change_layer(slider_layer_id, lari_map, valid_mask, min_layer_id, max_layer_id):

    if lari_map is None:
        return

    slider_layer_id = slider_layer_id - 1
    curr_layer_id = min(slider_layer_id, max_layer_id)
    curr_layer_id = max(curr_layer_id, min_layer_id)

    # masked depth list
    depth_image = get_masked_depth(
        lari_map=lari_map, valid_mask=valid_mask, layer_id=curr_layer_id
    )

    return depth_image


def clear_everything():
    return (
        gradio.update(value=None),
        gradio.update(value=None),
        gradio.update(value=None),
        gradio.update(value=None),
        gradio.update(value=None),
        gradio.update(value=None),
        gradio.update(value=None),
    )


with gradio.Blocks(
    css=""".gradio-container {margin: 0 !important; min-width: 100%;}""",
    title="LaRI Demo",
) as demo:

    gradio.Markdown(
        "<h1 style='text-align: center;'>LaRI: Layered Ray Intersections for Single-view 3D Geometric Reasoning</h1>",
        elem_id="title",
    )

    gradio.Markdown(
        """
        This is the official demo of Layered Ray Intersection (<a href="https://ruili3.github.io/lari/index.html" target="_blank" style="color: #2a9d8f;">LaRI</a>). For a quick start, click the images in 'Examples' and then click the 'Process' button.

        You can try your own images with the following steps:
        - Load an image;
        - Click the 'Process' button;
        - Browse layered depth maps (the z-channel of the resulting LaRI point map) by tuning 'Layer ID'.

        Note that in '3D Point Cloud', different colors denote different intersection layers, i.e., <b style="color: #FFBD1C;">layer 1</b>, <b style="color: #FB5607;">layer 2</b>, <b style="color: #F15BB5;">layer 3</b>, <b style="color: #8338EC;">layer 4</b>.
        """
    )

    # , <b style="color: #3A86FF;">layer 5</b>.
    lari_map = gradio.State(None)
    valid_mask = gradio.State(None)
    min_layer_id = gradio.State(None)
    max_layer_id = gradio.State(None)

    with gradio.Column():
        with gradio.Row(equal_height=True):
            with gradio.Column(scale=1):
                image_input = gradio.Image(
                    label="Upload an Image", type="pil", height=350
                )
                with gradio.Row():
                    rembg_checkbox = gradio.Checkbox(label="Remove background")
                    clear_button = gradio.Button("Clear")
                    submit_btn = gradio.Button("Process")
            with gradio.Column(scale=1):
                depth_output = gradio.Image(
                    label="LaRI Map at Z-axis (depth)",
                    type="pil",
                    interactive=False,
                    height=300,
                )
                slider_layer_id = gradio.Slider(
                    minimum=1,
                    maximum=4,
                    step=1,
                    value=1,
                    label="Layer ID",
                    interactive=True,
                )

        with gradio.Row(scale=1):
            outmodel = gradio.Model3D(
                label="3D Point Cloud (Color denotes different layers)",
                interactive=False,
                zoom_speed=0.5,
                pan_speed=0.5,
                height=450,
            )

        with gradio.Row():
            ply_file_output = gradio.File(label="ply output", elem_classes="small-file")
            glb_file_output = gradio.File(label="glb output", elem_classes="small-file")

        submit_btn.click(
            fn=model_forward,
            inputs=[image_input, slider_layer_id, rembg_checkbox],
            outputs=[
                depth_output,
                outmodel,
                lari_map,
                valid_mask,
                min_layer_id,
                max_layer_id,
                glb_file_output,
                ply_file_output,
                image_input,
            ],
        )

        clear_button.click(
            fn=clear_everything,
            outputs=[
                lari_map,
                valid_mask,
                min_layer_id,
                max_layer_id,
                image_input,
                depth_output,
                outmodel,
            ],
        )

        slider_layer_id.change(
            fn=change_layer,
            inputs=[slider_layer_id, lari_map, valid_mask, min_layer_id, max_layer_id],
            outputs=depth_output,
        )

    gradio.Examples(examples=["assets/cole_hardware.png",
                              "assets/3m_tape.png",
                              "assets/horse.png",
                              "assets/rhino.png",
                              "assets/alphabet.png",
                              "assets/martin_wedge.png",
                              "assets/d_rose.png",
                              "assets/ace.png",
                              "assets/bifidus.png",
                              "assets/fem.png",
                              ],
                    inputs=image_input)


demo.launch(share=False)
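The demo hinges on one tensor layout: the point-map head emits an (H, W, L, 3) LaRI map, and the depth panel shows its z-channel for the selected layer under the validity mask. A minimal sketch of that slicing with hypothetical sizes (only `torch` is assumed; `get_masked_depth` in the repo's `tools` module does the actual rendering):

import torch

H, W, L = 64, 64, 4                          # hypothetical spatial size and layer count
lari_map = torch.randn(H, W, L, 3)           # per-pixel, per-layer 3D intersection points
valid_mask = torch.ones(H, W, L, 1, dtype=torch.bool)

layer_id = min(max(2 - 1, 0), L - 1)         # slider value 2, clamped like change_layer()
depth = lari_map[:, :, layer_id, 2]          # the z-channel is the layered depth
depth = torch.where(valid_mask[:, :, layer_id, 0], depth, torch.tensor(float("nan")))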
demo.py
ADDED
@@ -0,0 +1,136 @@
import argparse
import os
import torch
import torch.backends.cudnn as cudnn
from PIL import Image
from src.utils.vis import prob_to_mask
from huggingface_hub import hf_hub_download
from tools import load_model, process_image, post_process_output, get_masked_depth, get_point_cloud, removebg_crop

parser = argparse.ArgumentParser("Arguments for deploying a LaRI Demo")
parser.add_argument(
    "--image_path",
    type=str,
    default="assets/cole_hardware.png",
    help="input image name",
)

parser.add_argument(
    "--output_path",
    type=str,
    default="./results",
    help="path to save the image",
)

parser.add_argument(
    "--model_info_pm",
    type=str,
    default="LaRIModel(use_pretrained = 'moge_full', num_output_layer = 5, head_type = 'point')",
    help="Network parameters to load the model",
)

parser.add_argument(
    "--model_info_mask",
    type=str,
    default="DinoSegModel(use_pretrained = 'dinov2', dim_proj = 256, pretrained_path = '', num_output_layer = 4, output_type = 'ray_stop')",
    help="Network parameters to load the model",
)

parser.add_argument(
    "--ckpt_path_pm",
    type=str,
    default="lari_obj_16k_pointmap.pth",
    help="Path to pre-trained weights",
)

parser.add_argument(
    "--ckpt_path_mask",
    type=str,
    default="lari_obj_16k_seg.pth",
    help="Path to pre-trained weights",
)

parser.add_argument(
    "--resolution", type=int, default=512, help="Default model resolution"
)

parser.add_argument(
    "--is_remove_background", action="store_true", help="Automatically remove the background."
)

args = parser.parse_args()


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cudnn.benchmark = True

# === Load the model
model_path_pm = hf_hub_download(repo_id="ruili3/LaRI", filename=args.ckpt_path_pm, repo_type="model")
model_path_mask = hf_hub_download(repo_id="ruili3/LaRI", filename=args.ckpt_path_mask, repo_type="model")
# Load the model with pretrained weights.
model_pm = load_model(args.model_info_pm, model_path_pm, device)
model_mask = (
    load_model(args.model_info_mask, model_path_mask, device)
    if args.model_info_mask is not None
    else None
)

# === Image pre-processing
pil_input = Image.open(args.image_path)
if args.is_remove_background:
    pil_input = removebg_crop(pil_input)  # remove background
input_tensor, ori_img_tensor, crop_coords, original_size = process_image(
    pil_input, resolution=args.resolution)  # crop & resize to fit the model input size
input_tensor = input_tensor.to(device)


# === Run inference
with torch.no_grad():
    # LaRI map
    pred_dict = model_pm(input_tensor)
    lari_map = -pred_dict["pts3d"].squeeze(0)
    # mask
    if model_mask:
        pred_dict = model_mask(input_tensor)
        assert "seg_prob" in pred_dict
        valid_mask = prob_to_mask(pred_dict["seg_prob"].squeeze(0))  # H W L 1
    else:
        h, w, l, _ = lari_map.shape
        # All-ones mask when no segmentation model is given
        # (`new_ones` is a tensor method; there is no `torch.new_ones`).
        valid_mask = lari_map.new_ones((h, w, l, 1), dtype=torch.bool)

# === crop & resize back to the original resolution
if original_size[0] != args.resolution or original_size[1] != args.resolution:
    lari_map = post_process_output(lari_map, crop_coords, original_size)  # H W L 3
    valid_mask = post_process_output(
        valid_mask.float(), crop_coords, original_size
    ).bool()  # H W L 1

max_n_layer = min(valid_mask.shape[-2], lari_map.shape[-2])
valid_mask = valid_mask[:, :, :max_n_layer, :]
lari_map = lari_map[:, :, :max_n_layer, :]


# === save output
os.makedirs(args.output_path, exist_ok=True)

for layer_id in range(max_n_layer):
    depth_pil = get_masked_depth(
        lari_map=lari_map, valid_mask=valid_mask, layer_id=layer_id
    )
    depth_pil.save(os.path.join(args.output_path, f"layered_depth_{layer_id}.jpg"))


# point cloud
glb_path, ply_path = get_point_cloud(
    lari_map, ori_img_tensor, valid_mask, first_layer_color="pseudo",
    target_folder=args.output_path
)

print("All results saved to `{}`.".format(args.output_path))
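Once the script has run, the per-layer depth images it writes can be checked with plain PIL; a small sketch, assuming the default `./results` output folder:

import os
from PIL import Image

results = "./results"
for name in sorted(os.listdir(results)):
    if name.startswith("layered_depth_") and name.endswith(".jpg"):
        img = Image.open(os.path.join(results, name))
        print(name, img.size)  # one masked depth image per intersection layer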
requirements.txt
ADDED
@@ -0,0 +1,19 @@
gradio==5.23.3
huggingface_hub==0.30.1
imageio==2.37.0
matplotlib==3.10.1
moderngl==5.12.0
omegaconf==2.3.0
opencv_python==4.11.0.86
opencv_python_headless==4.11.0.86
Pillow==11.1.0
piqp==0.5.0
plyfile==1.1
rembg==2.0.65
scipy==1.15.2
torchvision==0.21.0
trimesh==4.6.4
xformers==0.0.29.post3
numpy==1.26.4
torch==2.6.0
# opencv-python==4.11.0  (duplicate of the opencv_python pin above; commented out to avoid a conflicting requirement)
src/lari/model/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .lari_model import LaRIModel
from .dinoseg_model import DinoSegModel
src/lari/model/blocks.py
ADDED
@@ -0,0 +1,209 @@
from typing import *
import torch.nn as nn

class ResidualConvBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int = None, hidden_channels: int = None, padding_mode: str = 'replicate', activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu', norm: Literal['group_norm', 'layer_norm'] = 'group_norm'):
        super(ResidualConvBlock, self).__init__()
        if out_channels is None:
            out_channels = in_channels
        if hidden_channels is None:
            hidden_channels = in_channels

        if activation == 'relu':
            activation_cls = lambda: nn.ReLU(inplace=True)
        elif activation == 'leaky_relu':
            activation_cls = lambda: nn.LeakyReLU(negative_slope=0.2, inplace=True)
        elif activation == 'silu':
            activation_cls = lambda: nn.SiLU(inplace=True)
        elif activation == 'elu':
            activation_cls = lambda: nn.ELU(inplace=True)
        else:
            raise ValueError(f'Unsupported activation function: {activation}')

        self.layers = nn.Sequential(
            nn.GroupNorm(1, in_channels),
            activation_cls(),
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1, padding_mode=padding_mode),
            nn.GroupNorm(hidden_channels // 32 if norm == 'group_norm' else 1, hidden_channels),
            activation_cls(),
            nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode)
        )

        self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) if in_channels != out_channels else nn.Identity()

    def forward(self, x):
        skip = self.skip_connection(x)
        x = self.layers(x)
        x = x + skip
        return x


def make_upsampler(in_channels: int, out_channels: int):
    upsampler = nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
    )
    upsampler[0].weight.data[:] = upsampler[0].weight.data[:, :, :1, :1]
    return upsampler

def make_output_block(dim_in: int, dim_out: int, dim_times_res_block_hidden: int, last_res_blocks: int, last_conv_channels: int, last_conv_size: int, res_block_norm: Literal['group_norm', 'layer_norm']):
    return nn.Sequential(
        nn.Conv2d(dim_in, last_conv_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
        *(ResidualConvBlock(last_conv_channels, last_conv_channels, dim_times_res_block_hidden * last_conv_channels, activation='relu', norm=res_block_norm) for _ in range(last_res_blocks)),
        nn.ReLU(inplace=True),
        nn.Conv2d(last_conv_channels, dim_out, kernel_size=last_conv_size, stride=1, padding=last_conv_size // 2, padding_mode='replicate'),
    )


# ---- the following are from Depth Anything ----
import torch.nn as nn


def _make_scratch(in_shape, out_shape, groups=1, expand=False):
    scratch = nn.Module()

    out_shape1 = out_shape
    out_shape2 = out_shape
    out_shape3 = out_shape
    if len(in_shape) >= 4:
        out_shape4 = out_shape

    if expand:
        out_shape1 = out_shape
        out_shape2 = out_shape * 2
        out_shape3 = out_shape * 4
        if len(in_shape) >= 4:
            out_shape4 = out_shape * 8

    scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
    scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
    scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
    if len(in_shape) >= 4:
        scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)

    return scratch


class ResidualConvUnit(nn.Module):
    """Residual convolution module.
    """

    def __init__(self, features, activation, bn):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.bn = bn

        self.groups = 1

        self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)

        self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)

        if self.bn == True:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """

        out = self.activation(x)
        out = self.conv1(out)
        if self.bn == True:
            out = self.bn1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.bn == True:
            out = self.bn2(out)

        if self.groups > 1:
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)


class FeatureFusionBlock(nn.Module):
    """Feature fusion block.
    """

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None
    ):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups = 1

        self.expand = expand
        out_features = features
        if self.expand == True:
            out_features = features // 2

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)

        self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn)

        self.skip_add = nn.quantized.FloatFunctional()

        self.size = size

    def forward(self, *xs, size=None):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            res = self.resConfUnit1(xs[1])
            output = self.skip_add.add(output, res)

        output = self.resConfUnit2(output)

        if (size is None) and (self.size is None):
            modifier = {"scale_factor": 2}
        elif size is None:
            modifier = {"size": self.size}
        else:
            modifier = {"size": size}

        output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)

        output = self.out_conv(output)

        return output
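A quick shape check for the helpers above (a sketch with arbitrary sizes; the channel counts must keep `hidden_channels // 32 >= 1` for the default group norm):

import torch
from src.lari.model.blocks import ResidualConvBlock, make_upsampler

x = torch.randn(1, 64, 32, 32)                    # dummy feature map
block = ResidualConvBlock(64, out_channels=128)   # 1x1 skip projection kicks in when channels differ
up = make_upsampler(128, 64)                      # transposed conv doubles the spatial size
print(up(block(x)).shape)                         # torch.Size([1, 64, 64, 64])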
src/lari/model/dinoseg_model.py
ADDED
@@ -0,0 +1,153 @@
from typing import *
from numbers import Number
from functools import partial
from pathlib import Path
import importlib
import warnings
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.utils.checkpoint
import torch.version
from huggingface_hub import hf_hub_download


from src.lari.model.utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing
from src.lari.model.dpt_seg_head import DPTSegHead


class DinoSegModel(nn.Module):

    def __init__(self,
                 encoder: str = 'dinov2_vitl14',
                 intermediate_layers: Union[int, List[int]] = 4,
                 dim_proj: int = 512,
                 use_pretrained: Literal["dinov2", "moge_full", "moge_backbone", None] = None,
                 pretrained_path: str = None,
                 num_output_layer: int = None,
                 output_type: str = "ray_stop",  # "seg_sep"
                 **deprecated_kwargs
                 ):
        super(DinoSegModel, self).__init__()
        if deprecated_kwargs:
            warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}")

        self.encoder = encoder
        self.intermediate_layers = intermediate_layers
        self.use_pretrained = use_pretrained
        self.pretrained_path = pretrained_path
        self.num_output_layer = num_output_layer
        self.output_type = output_type
        assert self.output_type in ["seg_sep", "ray_stop"]

        hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), encoder)

        self.backbone = hub_loader(pretrained=True if self.use_pretrained == "dinov2" else False)
        dim_feature = self.backbone.blocks[0].attn.qkv.in_features

        self.head = DPTSegHead(in_channels=dim_feature,
                               features=dim_proj,
                               use_bn=True,
                               out_channels=[256, 512, 1024, 1024],
                               use_clstoken=False,
                               num_classes=num_output_layer,
                               output_type=self.output_type
                               )

        if torch.__version__ >= '2.0':
            self.enable_pytorch_native_sdpa()

        self._load_pretrained()

    def _load_pretrained(self):
        '''
        Load data from MoGe model
        '''
        return

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path, IO[bytes]], model_kwargs: Optional[Dict[str, Any]] = None, **hf_kwargs) -> 'DinoSegModel':
        """
        Load a model from a checkpoint file.

        ### Parameters:
        - `pretrained_model_name_or_path`: path to the checkpoint file or repo id.
        - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint.
        - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path.

        ### Returns:
        - A new instance of `DinoSegModel` with the parameters loaded from the checkpoint.
        """
        if Path(pretrained_model_name_or_path).exists():
            checkpoint = torch.load(pretrained_model_name_or_path, map_location='cpu', weights_only=True)
        else:
            cached_checkpoint_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                repo_type="model",
                filename="model.pt",
                **hf_kwargs
            )
            checkpoint = torch.load(cached_checkpoint_path, map_location='cpu', weights_only=True)
        model_config = checkpoint['model_config']
        if model_kwargs is not None:
            model_config.update(model_kwargs)
        model = cls(**model_config)
        model.load_state_dict(checkpoint['model'])
        return model

    @staticmethod
    def cache_pretrained_backbone(encoder: str, pretrained: bool):
        _ = torch.hub.load('facebookresearch/dinov2', encoder, pretrained=pretrained)

    def load_pretrained_backbone(self):
        "Load the backbone with pretrained dinov2 weights from torch hub"
        state_dict = torch.hub.load('facebookresearch/dinov2', self.encoder, pretrained=True).state_dict()
        self.backbone.load_state_dict(state_dict)

    def enable_backbone_gradient_checkpointing(self):
        for i in range(len(self.backbone.blocks)):
            self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing(self.backbone.blocks[i])

    def enable_pytorch_native_sdpa(self):
        for i in range(len(self.backbone.blocks)):
            self.backbone.blocks[i].attn = wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn)

    def forward(self, image: torch.Tensor, mixed_precision: bool = False) -> Dict[str, torch.Tensor]:
        raw_img_h, raw_img_w = image.shape[-2:]
        patch_h, patch_w = raw_img_h // 14, raw_img_w // 14
        # Apply image transformation for DINOv2
        image_14 = F.interpolate(image, (patch_h * 14, patch_w * 14), mode="bilinear", align_corners=False, antialias=True)

        # Get intermediate layers from the backbone
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=mixed_precision):
            features = self.backbone.get_intermediate_layers(image_14, self.intermediate_layers, return_class_token=True)

        # Predict points and mask (mask scores)
        mask = self.head(features, patch_h, patch_w)

        # b c h w
        mask = F.interpolate(mask, (raw_img_h, raw_img_w), mode="bilinear", align_corners=False)

        out_dict = {}

        if self.output_type == "seg_sep":
            # mask = torch.nn.functional.sigmoid(mask)  # for binary segmentation
            out_dict["mask"] = mask.permute(0, 2, 3, 1).unsqueeze(-1)  # B H W L 1
        elif self.output_type == "ray_stop":
            out_dict["seg_prob"] = mask  # B L+1 H W

        return out_dict
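The forward contract, in a minimal sketch: the input resolution is snapped to a multiple of the 14-pixel DINOv2 patch size, and for `output_type="ray_stop"` the head returns per-pixel stopping scores at the input resolution. This assumes the repo's `DPTSegHead` and SDPA wrapper behave as used above and is run from the repo root; `use_pretrained=None` skips the DINOv2 weight download, so the outputs here are random:

import torch
from src.lari.model.dinoseg_model import DinoSegModel

model = DinoSegModel(use_pretrained=None, dim_proj=256,
                     num_output_layer=4, output_type="ray_stop").eval()
image = torch.randn(1, 3, 518, 518)     # 518 = 37 * 14, already patch-aligned
with torch.no_grad():
    out = model(image)
print(out["seg_prob"].shape)            # (B, C, H, W) stopping scores, C set by the head config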
src/lari/model/dinov2/__init__.py
ADDED
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

__version__ = "0.0.1"
src/lari/model/dinov2/hub/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
src/lari/model/dinov2/hub/backbones.py
ADDED
@@ -0,0 +1,156 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from enum import Enum
from typing import Union

import torch

from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name


class Weights(Enum):
    LVD142M = "LVD142M"


def _make_dinov2_model(
    *,
    arch_name: str = "vit_large",
    img_size: int = 518,
    patch_size: int = 14,
    init_values: float = 1.0,
    ffn_layer: str = "mlp",
    block_chunks: int = 0,
    num_register_tokens: int = 0,
    interpolate_antialias: bool = False,
    interpolate_offset: float = 0.1,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    from ..models import vision_transformer as vits

    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    model_base_name = _make_dinov2_model_name(arch_name, patch_size)
    vit_kwargs = dict(
        img_size=img_size,
        patch_size=patch_size,
        init_values=init_values,
        ffn_layer=ffn_layer,
        block_chunks=block_chunks,
        num_register_tokens=num_register_tokens,
        interpolate_antialias=interpolate_antialias,
        interpolate_offset=interpolate_offset,
    )
    vit_kwargs.update(**kwargs)
    model = vits.__dict__[arch_name](**vit_kwargs)

    if pretrained:
        model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
        url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)

    return model


def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        **kwargs,
    )


def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_small",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_base",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_large",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
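These entry points mirror the official DINOv2 hub API; a sketch that builds the ViT-L/14 backbone without downloading weights (run from the repo root; the output numbers are random since the model is untrained here):

import torch
from src.lari.model.dinov2.hub.backbones import dinov2_vitl14

backbone = dinov2_vitl14(pretrained=False)   # architecture only, no LVD-142M download
tokens = backbone.get_intermediate_layers(torch.randn(1, 3, 224, 224), n=1)[0]
print(tokens.shape)                          # (1, 256, 1024): 16x16 patches of size 14, embed dim 1024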
src/lari/model/dinov2/hub/utils.py
ADDED
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import itertools
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"


def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
    compact_arch_name = arch_name.replace("_", "")[:4]
    registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
    return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"


class CenterPadding(nn.Module):
    def __init__(self, multiple):
        super().__init__()
        self.multiple = multiple

    def _get_pad(self, size):
        new_size = math.ceil(size / self.multiple) * self.multiple
        pad_size = new_size - size
        pad_size_left = pad_size // 2
        pad_size_right = pad_size - pad_size_left
        return pad_size_left, pad_size_right

    @torch.inference_mode()
    def forward(self, x):
        pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
        output = F.pad(x, pads)
        return output
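`CenterPadding` rounds the two spatial dimensions up to the next multiple, splitting the padding evenly between the two sides; a quick sketch:

import torch
from src.lari.model.dinov2.hub.utils import CenterPadding

pad = CenterPadding(multiple=14)
y = pad(torch.randn(1, 3, 500, 333))
print(y.shape)   # torch.Size([1, 3, 504, 336]): H and W rounded up to multiples of 14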
src/lari/model/dinov2/layers/__init__.py
ADDED
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .dino_head import DINOHead
from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from .block import NestedTensorBlock
from .attention import MemEffAttention
src/lari/model/dinov2/layers/attention.py
ADDED
@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

import logging
import os
import warnings

from torch import Tensor
from torch import nn


logger = logging.getLogger("dinov2")


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import memory_efficient_attention, unbind

        XFORMERS_AVAILABLE = True
        # warnings.warn("xFormers is available (Attention)")
    else:
        # warnings.warn("xFormers is disabled (Attention)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    # warnings.warn("xFormers is not available (Attention)")


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
        attn = q @ k.transpose(-2, -1)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MemEffAttention(Attention):
    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
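Both attention classes are shape-preserving over (batch, tokens, dim); `MemEffAttention` falls back to the plain implementation when xFormers is absent. A minimal sketch:

import torch
from src.lari.model.dinov2.layers.attention import Attention

attn = Attention(dim=64, num_heads=8)
x = torch.randn(2, 16, 64)    # (batch, tokens, dim)
print(attn(x).shape)          # torch.Size([2, 16, 64])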
src/lari/model/dinov2/layers/block.py
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

import logging
import os
from typing import Callable, List, Any, Tuple, Dict
import warnings

import torch
from torch import nn, Tensor

from .attention import Attention, MemEffAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp


logger = logging.getLogger("dinov2")


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import fmha, scaled_index_add, index_select_cat

        XFORMERS_AVAILABLE = True
        # warnings.warn("xFormers is available (Block)")
    else:
        # warnings.warn("xFormers is disabled (Block)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    # warnings.warn("xFormers is not available (Block)")


class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x


def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    # 1) extract subset using permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    return x_plus_residual.view_as(x)


def get_branges_scales(x, sample_drop_ratio=0.0):
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor


def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    else:
        x_plus_residual = scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )
    return x_plus_residual


attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors


def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs


class NestedTensorBlock(Block):
    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
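Usage note (illustrative only, not part of this commit; requires xFormers): NestedTensorBlock accepts either a plain tensor or a list of tensors. In the list case the entries are concatenated into a single batch-of-one sequence and attended under a block-diagonal mask, so tokens never attend across list entries.

    # Minimal sketch: dims and sequence lengths are assumptions.
    import torch
    from src.lari.model.dinov2.layers import MemEffAttention, NestedTensorBlock

    block = NestedTensorBlock(dim=384, num_heads=6, attn_class=MemEffAttention)
    x_list = [torch.randn(2, 197, 384), torch.randn(1, 257, 384)]
    out_list = block(x_list)  # list of outputs, one per input, shapes preserved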
src/lari/model/dinov2/layers/dino_head.py
ADDED
@@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
from torch.nn.utils import weight_norm


class DINOHead(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
        self.apply(self._init_weights)
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        eps = 1e-6 if x.dtype == torch.float16 else 1e-12
        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
        x = self.last_layer(x)
        return x


def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
    if nlayers == 1:
        return nn.Linear(in_dim, bottleneck_dim, bias=bias)
    else:
        layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
        if use_bn:
            layers.append(nn.BatchNorm1d(hidden_dim))
        layers.append(nn.GELU())
        for _ in range(nlayers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
        return nn.Sequential(*layers)
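Usage note (illustrative only, not part of this commit; in_dim/out_dim are assumptions): the MLP output is L2-normalized before the weight-normalized last layer, so the prototype logits depend only on the direction of the bottleneck embedding, not its magnitude.

    # Minimal sketch: project features onto 65536 prototypes.
    import torch
    from src.lari.model.dinov2.layers.dino_head import DINOHead

    head = DINOHead(in_dim=384, out_dim=65536)
    logits = head(torch.randn(8, 384))  # shape: (8, 65536)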
src/lari/model/dinov2/layers/drop_path.py
ADDED
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py


from torch import nn


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    output = x * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
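Usage note (illustrative only, not part of this commit): the div_(keep_prob) rescaling keeps the expected value of the output equal to the input, so inference (training=False) needs no correction factor.

    # Minimal sketch: drop roughly half the samples; survivors are scaled by 2.
    import torch
    from src.lari.model.dinov2.layers.drop_path import drop_path

    x = torch.ones(1000, 4, 8)
    y = drop_path(x, drop_prob=0.5, training=True)
    print(y.mean().item())  # close to 1.0; each sample is either all 0.0 or all 2.0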
src/lari/model/dinov2/layers/layer_scale.py
ADDED
@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110

from typing import Union

import torch
from torch import Tensor
from torch import nn


class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma
src/lari/model/dinov2/layers/mlp.py
ADDED
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py


from typing import Callable, Optional

from torch import Tensor, nn


class Mlp(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
src/lari/model/dinov2/layers/patch_embed.py
ADDED
@@ -0,0 +1,88 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

from typing import Callable, Optional, Tuple, Union

from torch import Tensor
import torch.nn as nn


def make_2tuple(x):
    if isinstance(x, tuple):
        assert len(x) == 2
        return x

    assert isinstance(x, int)
    return (x, x)


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
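Usage note (illustrative only, not part of this commit): the token count is (H // patch) * (W // patch), and forward asserts that both image sides divide evenly by the patch size.

    # Minimal sketch: 224x224 with patch 14 gives a 16x16 grid = 256 tokens.
    import torch
    from src.lari.model.dinov2.layers.patch_embed import PatchEmbed

    embed = PatchEmbed(img_size=224, patch_size=14, in_chans=3, embed_dim=384)
    tokens = embed(torch.randn(1, 3, 224, 224))
    print(tokens.shape)  # torch.Size([1, 256, 384])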
src/lari/model/dinov2/layers/swiglu_ffn.py
ADDED
@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import os
from typing import Callable, Optional
import warnings

from torch import Tensor, nn
import torch.nn.functional as F


class SwiGLUFFN(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = F.silu(x1) * x2
        return self.w3(hidden)


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import SwiGLU

        XFORMERS_AVAILABLE = True
        # warnings.warn("xFormers is available (SwiGLU)")
    else:
        # warnings.warn("xFormers is disabled (SwiGLU)")
        raise ImportError
except ImportError:
    SwiGLU = SwiGLUFFN
    XFORMERS_AVAILABLE = False

    # warnings.warn("xFormers is not available (SwiGLU)")


class SwiGLUFFNFused(SwiGLU):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
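Usage note (illustrative only, not part of this commit): SwiGLUFFNFused shrinks the requested hidden width to 2/3 and rounds up to a multiple of 8, which roughly matches the parameter count of a plain MLP despite the extra gating projection in w12.

    # Minimal sketch of the rounding rule for a nominal hidden size of 4 * 384.
    hidden_features = 4 * 384                            # 1536
    print((int(hidden_features * 2 / 3) + 7) // 8 * 8)   # 1024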
src/lari/model/dinov2/models/__init__.py
ADDED
@@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import logging

from . import vision_transformer as vits


logger = logging.getLogger("dinov2")


def build_model(args, only_teacher=False, img_size=224):
    args.arch = args.arch.removesuffix("_memeff")
    if "vit" in args.arch:
        vit_kwargs = dict(
            img_size=img_size,
            patch_size=args.patch_size,
            init_values=args.layerscale,
            ffn_layer=args.ffn_layer,
            block_chunks=args.block_chunks,
            qkv_bias=args.qkv_bias,
            proj_bias=args.proj_bias,
            ffn_bias=args.ffn_bias,
            num_register_tokens=args.num_register_tokens,
            interpolate_offset=args.interpolate_offset,
            interpolate_antialias=args.interpolate_antialias,
        )
        teacher = vits.__dict__[args.arch](**vit_kwargs)
        if only_teacher:
            return teacher, teacher.embed_dim
        student = vits.__dict__[args.arch](
            **vit_kwargs,
            drop_path_rate=args.drop_path_rate,
            drop_path_uniform=args.drop_path_uniform,
        )
        embed_dim = student.embed_dim
    return student, teacher, embed_dim


def build_model_from_cfg(cfg, only_teacher=False):
    return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
src/lari/model/dinov2/models/vision_transformer.py
ADDED
@@ -0,0 +1,396 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

from functools import partial
import math
import logging
from typing import Sequence, Tuple, Union, Callable

import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn.init import trunc_normal_

from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block


logger = logging.getLogger("dinov2")


def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = ".".join((name, child_name)) if name else child_name
        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


class BlockChunk(nn.ModuleList):
    def forward(self, x):
        for b in self:
            x = b(x)
        return x


class DinoVisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            weight_init (str): weight init scheme
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sx, sy)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (w0, h0)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (w0, h0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        if self.register_tokens is not None:
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
        for blk in self.blocks:
            x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=False, **kwargs):
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])


def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model
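Usage note (illustrative only, not part of this commit; hyper-parameters are assumptions): get_intermediate_layers is how dense heads consume this backbone, returning per-layer patch-token feature maps. Input sides must be multiples of the patch size.

    # Minimal sketch: last 4 layers as (B, C, H/p, W/p) feature maps.
    import torch
    from src.lari.model.dinov2.models.vision_transformer import vit_small

    model = vit_small(patch_size=14, img_size=518, init_values=1e-5, block_chunks=0).eval()
    with torch.no_grad():
        feats = model.get_intermediate_layers(torch.randn(1, 3, 224, 224), n=4, reshape=True)
    print([tuple(f.shape) for f in feats])  # 4 x (1, 384, 16, 16)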
src/lari/model/dinov2/utils/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
src/lari/model/dinov2/utils/cluster.py
ADDED
@@ -0,0 +1,95 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from enum import Enum
import os
from pathlib import Path
from typing import Any, Dict, Optional


class ClusterType(Enum):
    AWS = "aws"
    FAIR = "fair"
    RSC = "rsc"


def _guess_cluster_type() -> ClusterType:
    uname = os.uname()
    if uname.sysname == "Linux":
        if uname.release.endswith("-aws"):
            # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
            return ClusterType.AWS
        elif uname.nodename.startswith("rsc"):
            # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
            return ClusterType.RSC

    return ClusterType.FAIR


def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
    if cluster_type is None:
        return _guess_cluster_type()

    return cluster_type


def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type is None:
        return None

    CHECKPOINT_DIRNAMES = {
        ClusterType.AWS: "checkpoints",
        ClusterType.FAIR: "checkpoint",
        ClusterType.RSC: "checkpoint/dino",
    }
    return Path("/") / CHECKPOINT_DIRNAMES[cluster_type]


def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    checkpoint_path = get_checkpoint_path(cluster_type)
    if checkpoint_path is None:
        return None

    username = os.environ.get("USER")
    assert username is not None
    return checkpoint_path / username


def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type is None:
        return None

    SLURM_PARTITIONS = {
        ClusterType.AWS: "learnlab",
        ClusterType.FAIR: "learnlab",
        ClusterType.RSC: "learn",
    }
    return SLURM_PARTITIONS[cluster_type]


def get_slurm_executor_parameters(
    nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
) -> Dict[str, Any]:
    # create default parameters
    params = {
        "mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
        "gpus_per_node": num_gpus_per_node,
        "tasks_per_node": num_gpus_per_node,  # one task per GPU
        "cpus_per_task": 10,
        "nodes": nodes,
        "slurm_partition": get_slurm_partition(cluster_type),
    }
    # apply cluster-specific adjustments
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type == ClusterType.AWS:
        params["cpus_per_task"] = 12
        del params["mem_gb"]
    elif cluster_type == ClusterType.RSC:
        params["cpus_per_task"] = 12
    # set additional parameters / apply overrides
    params.update(kwargs)
    return params
src/lari/model/dinov2/utils/config.py
ADDED
@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import math
import logging
import os

from omegaconf import OmegaConf

import dinov2.distributed as distributed
from dinov2.logging import setup_logging
from dinov2.utils import utils
from dinov2.configs import dinov2_default_config


logger = logging.getLogger("dinov2")


def apply_scaling_rules_to_cfg(cfg):  # to fix
    if cfg.optim.scaling_rule == "sqrt_wrt_1024":
        base_lr = cfg.optim.base_lr
        cfg.optim.lr = base_lr
        cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0)
        logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
    else:
        raise NotImplementedError
    return cfg


def write_config(cfg, output_dir, name="config.yaml"):
    logger.info(OmegaConf.to_yaml(cfg))
    saved_cfg_path = os.path.join(output_dir, name)
    with open(saved_cfg_path, "w") as f:
        OmegaConf.save(config=cfg, f=f)
    return saved_cfg_path


def get_cfg_from_args(args):
    args.output_dir = os.path.abspath(args.output_dir)
    args.opts += [f"train.output_dir={args.output_dir}"]
    default_cfg = OmegaConf.create(dinov2_default_config)
    cfg = OmegaConf.load(args.config_file)
    cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
    return cfg


def default_setup(args):
    distributed.enable(overwrite=True)
    seed = getattr(args, "seed", 0)
    rank = distributed.get_global_rank()

    global logger
    setup_logging(output=args.output_dir, level=logging.INFO)
    logger = logging.getLogger("dinov2")

    utils.fix_random_seeds(seed + rank)
    logger.info("git:\n  {}\n".format(utils.get_sha()))
    logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))


def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg_from_args(args)
    os.makedirs(args.output_dir, exist_ok=True)
    default_setup(args)
    apply_scaling_rules_to_cfg(cfg)
    write_config(cfg, args.output_dir)
    return cfg
src/lari/model/dinov2/utils/dtype.py
ADDED
@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.


from typing import Dict, Union

import numpy as np
import torch


TypeSpec = Union[str, np.dtype, torch.dtype]


_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
    np.dtype("bool"): torch.bool,
    np.dtype("uint8"): torch.uint8,
    np.dtype("int8"): torch.int8,
    np.dtype("int16"): torch.int16,
    np.dtype("int32"): torch.int32,
    np.dtype("int64"): torch.int64,
    np.dtype("float16"): torch.float16,
    np.dtype("float32"): torch.float32,
    np.dtype("float64"): torch.float64,
    np.dtype("complex64"): torch.complex64,
    np.dtype("complex128"): torch.complex128,
}


def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        dtype = np.dtype(dtype)
    assert isinstance(dtype, np.dtype), f"Expected an instance of numpy dtype, got {type(dtype)}"
    return _NUMPY_TO_TORCH_DTYPE[dtype]
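Usage note (illustrative only, not part of this commit): as_torch_dtype normalizes string, numpy, and torch dtype specs to a torch.dtype; torch dtypes are passed through unchanged.

    # Minimal sketch of the three accepted input kinds.
    import numpy as np
    import torch
    from src.lari.model.dinov2.utils.dtype import as_torch_dtype

    print(as_torch_dtype("float16"))           # torch.float16
    print(as_torch_dtype(np.dtype("int64")))   # torch.int64
    print(as_torch_dtype(torch.bfloat16))      # torch.bfloat16, passed through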
src/lari/model/dinov2/utils/param_groups.py
ADDED
@@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from collections import defaultdict
import logging


logger = logging.getLogger("dinov2")


def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False):
    """
    Calculate lr decay rate for different ViT blocks.
    Args:
        name (string): parameter name.
        lr_decay_rate (float): base lr decay rate.
        num_layers (int): number of ViT blocks.
    Returns:
        lr decay rate for the given parameter.
    """
    layer_id = num_layers + 1
    if name.startswith("backbone") or force_is_backbone:
        if (
            ".pos_embed" in name
            or ".patch_embed" in name
            or ".mask_token" in name
            or ".cls_token" in name
            or ".register_tokens" in name
        ):
            layer_id = 0
        elif force_is_backbone and (
            "pos_embed" in name
            or "patch_embed" in name
            or "mask_token" in name
            or "cls_token" in name
            or "register_tokens" in name
        ):
            layer_id = 0
        elif ".blocks." in name and ".residual." not in name:
            layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
        elif chunked_blocks and "blocks." in name and "residual." not in name:
            layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1
        elif "blocks." in name and "residual." not in name:
            layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1

    return lr_decay_rate ** (num_layers + 1 - layer_id)


def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
    chunked_blocks = False
    if hasattr(model, "n_blocks"):
        logger.info("chunked fsdp")
        n_blocks = model.n_blocks
        chunked_blocks = model.chunked_blocks
    elif hasattr(model, "blocks"):
        logger.info("first code branch")
        n_blocks = len(model.blocks)
    elif hasattr(model, "backbone"):
        logger.info("second code branch")
        n_blocks = len(model.backbone.blocks)
    else:
        logger.info("else code branch")
        n_blocks = 0
    all_param_groups = []

    for name, param in model.named_parameters():
        name = name.replace("_fsdp_wrapped_module.", "")
        if not param.requires_grad:
            continue
        decay_rate = get_vit_lr_decay_rate(
            name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks
        )
        d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name}

        if "last_layer" in name:
            d.update({"is_last_layer": True})

        if name.endswith(".bias") or "norm" in name or "gamma" in name:
            d.update({"wd_multiplier": 0.0})

        if "patch_embed" in name:
            d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult})

        all_param_groups.append(d)
        logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""")

    return all_param_groups


def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
    fused_params_groups = defaultdict(lambda: {"params": []})
    for d in all_params_groups:
        identifier = ""
        for k in keys:
            identifier += k + str(d[k]) + "_"

        for k in keys:
            fused_params_groups[identifier][k] = d[k]
        fused_params_groups[identifier]["params"].append(d["params"])

    return fused_params_groups.values()
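Usage note (illustrative only, not part of this commit): one group is built per parameter, then fuse_params_groups collapses them by (lr_multiplier, wd_multiplier, is_last_layer) so the optimizer sees only a handful of groups; PyTorch optimizers keep the extra keys, which a scheduler can read per step.

    # Minimal sketch: layer-wise lr decay on a small ViT; hyper-params are assumptions.
    import torch
    from src.lari.model.dinov2.models.vision_transformer import vit_small
    from src.lari.model.dinov2.utils.param_groups import (
        fuse_params_groups,
        get_params_groups_with_decay,
    )

    model = vit_small(patch_size=14, block_chunks=0)
    groups = get_params_groups_with_decay(model, lr_decay_rate=0.9)
    optimizer = torch.optim.AdamW(list(fuse_params_groups(groups)), lr=1e-3)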
src/lari/model/dinov2/utils/utils.py
ADDED
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+import random
+import subprocess
+from urllib.parse import urlparse
+
+import numpy as np
+import torch
+from torch import nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
+    if urlparse(pretrained_weights).scheme:  # If it looks like a URL
+        state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
+    else:
+        state_dict = torch.load(pretrained_weights, map_location="cpu")
+    if checkpoint_key is not None and checkpoint_key in state_dict:
+        logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
+        state_dict = state_dict[checkpoint_key]
+    # remove `module.` prefix
+    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+    # remove `backbone.` prefix induced by multicrop wrapper
+    state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
+    msg = model.load_state_dict(state_dict, strict=False)
+    logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
+
+
+def fix_random_seeds(seed=31):
+    """
+    Fix random seeds.
+    """
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+
+def get_sha():
+    cwd = os.path.dirname(os.path.abspath(__file__))
+
+    def _run(command):
+        return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
+
+    sha = "N/A"
+    diff = "clean"
+    branch = "N/A"
+    try:
+        sha = _run(["git", "rev-parse", "HEAD"])
+        subprocess.check_output(["git", "diff"], cwd=cwd)
+        diff = _run(["git", "diff-index", "HEAD"])
+        diff = "has uncommitted changes" if diff else "clean"
+        branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
+    except Exception:
+        pass
+    message = f"sha: {sha}, status: {diff}, branch: {branch}"
+    return message
+
+
+class CosineScheduler(object):
+    def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
+        super().__init__()
+        self.final_value = final_value
+        self.total_iters = total_iters
+
+        freeze_schedule = np.zeros((freeze_iters))
+
+        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
+
+        iters = np.arange(total_iters - warmup_iters - freeze_iters)
+        schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
+        self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule))
+
+        assert len(self.schedule) == self.total_iters
+
+    def __getitem__(self, it):
+        if it >= self.total_iters:
+            return self.final_value
+        else:
+            return self.schedule[it]
+
+
+def has_batchnorms(model):
+    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
+    for name, module in model.named_modules():
+        if isinstance(module, bn_types):
+            return True
+    return False
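A minimal sketch of CosineScheduler (illustrative values, not from the committed files): it precomputes the whole schedule and is indexed by iteration, returning the warmup ramp first, then the cosine decay, and clamping to final_value past total_iters.

sched = CosineScheduler(base_value=1e-3, final_value=1e-5, total_iters=1000, warmup_iters=100)
lr_at_0 = sched[0]        # start_warmup_value (0 by default)
lr_at_100 = sched[100]    # ~base_value: end of the linear warmup
lr_at_5000 = sched[5000]  # final_value: indices past total_iters are clamped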
src/lari/model/dpt_seg_head.py
ADDED
@@ -0,0 +1,158 @@
+'''
+The code is modified based on Depth Anything and DPT
+'''
+from src.lari.model.blocks import FeatureFusionBlock, _make_scratch
+import cv2
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision.transforms import Compose
+
+
+def _make_fusion_block(features, use_bn, size=None):
+    return FeatureFusionBlock(
+        features,
+        nn.ReLU(False),
+        deconv=False,
+        bn=use_bn,
+        expand=False,
+        align_corners=True,
+        size=size,
+    )
+
+
+class ConvBlock(nn.Module):
+    def __init__(self, in_feature, out_feature):
+        super().__init__()
+
+        self.conv_block = nn.Sequential(
+            nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm2d(out_feature),
+            nn.ReLU(True)
+        )
+
+    def forward(self, x):
+        return self.conv_block(x)
+
+
+class DPTSegHead(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        features=256,
+        use_bn=False,
+        out_channels=[256, 512, 1024, 1024],
+        use_clstoken=False,
+        num_classes=5,
+        output_type="ray_stop"  # or "seg_sep"
+    ):
+        super(DPTSegHead, self).__init__()
+
+        self.use_clstoken = use_clstoken
+        self.output_type = output_type
+
+        # output one more layer to indicate the invalid ray-stopping point using index 0
+        self.num_classes = num_classes + 1 if self.output_type == "ray_stop" else num_classes
+
+        self.projects = nn.ModuleList([
+            nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=out_channel,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+            ) for out_channel in out_channels
+        ])
+
+        self.resize_layers = nn.ModuleList([
+            nn.ConvTranspose2d(
+                in_channels=out_channels[0],
+                out_channels=out_channels[0],
+                kernel_size=4,
+                stride=4,
+                padding=0),
+            nn.ConvTranspose2d(
+                in_channels=out_channels[1],
+                out_channels=out_channels[1],
+                kernel_size=2,
+                stride=2,
+                padding=0),
+            nn.Identity(),
+            nn.Conv2d(
+                in_channels=out_channels[3],
+                out_channels=out_channels[3],
+                kernel_size=3,
+                stride=2,
+                padding=1)
+        ])
+
+        if use_clstoken:
+            self.readout_projects = nn.ModuleList()
+            for _ in range(len(self.projects)):
+                self.readout_projects.append(
+                    nn.Sequential(
+                        nn.Linear(2 * in_channels, in_channels),
+                        nn.GELU()))
+
+        self.scratch = _make_scratch(
+            out_channels,
+            features,
+            groups=1,
+            expand=False,
+        )
+
+        self.scratch.stem_transpose = None
+
+        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+        self.scratch.output_conv1 = nn.Sequential(
+            nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(features),
+            nn.ReLU(True),
+            nn.Dropout(0.1, False),
+            nn.Conv2d(features, self.num_classes, kernel_size=1),
+        )
+
+    def forward(self, out_features, patch_h, patch_w):
+        out = []
+        for i, x in enumerate(out_features):
+            if self.use_clstoken:
+                x, cls_token = x[0], x[1]
+                readout = cls_token.unsqueeze(1).expand_as(x)
+                x = self.readout_projects[i](torch.cat((x, readout), -1))
+            else:
+                x = x[0]
+
+            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
+
+            x = self.projects[i](x)
+            x = self.resize_layers[i](x)
+
+            out.append(x)
+
+        layer_1, layer_2, layer_3, layer_4 = out
+
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+        path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+        out = self.scratch.output_conv1(path_1)
+
+        # B C H W - segmentation logits
+        out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
+
+        return out
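For orientation, a hedged shape sketch of this head (not from the committed files): the sizes assume a ViT-L/14 backbone with feature dim 1024 and a 518x518 input, i.e. patch_h = patch_w = 37, so the "ray_stop" variant emits num_classes + 1 = 6 channels at full resolution.

head = DPTSegHead(in_channels=1024, num_classes=5, output_type="ray_stop")
# out_features: four (tokens, cls_token) pairs, tokens of shape (B, 37*37, 1024)
# logits = head(out_features, patch_h=37, patch_w=37)  # -> (B, 6, 518, 518)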
src/lari/model/heads.py
ADDED
@@ -0,0 +1,104 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils
+import torch.utils.checkpoint
+import torch.version
+from typing import *
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
+from src.lari.model.blocks import ResidualConvBlock, make_upsampler, make_output_block
+from src.lari.utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift, gaussian_blur_2d
+
+
+class PointHead(nn.Module):
+    def __init__(
+        self,
+        num_features: int,
+        dim_in: int,
+        dim_out: int,
+        dim_proj: int = 512,
+        dim_upsample: List[int] = [256, 128, 128],
+        dim_times_res_block_hidden: int = 1,
+        num_res_blocks: int = 1,
+        res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm',
+        last_res_blocks: int = 0,
+        last_conv_channels: int = 32,
+        last_conv_size: int = 1,
+        num_output_layer: int = 5
+    ):
+        super().__init__()
+
+        self.num_output_layer = num_output_layer
+
+        self.projects = nn.ModuleList([
+            nn.Conv2d(in_channels=dim_in, out_channels=dim_proj, kernel_size=1, stride=1, padding=0) for _ in range(num_features)
+        ])
+
+        self.upsample_blocks = nn.ModuleList([
+            nn.Sequential(
+                make_upsampler(in_ch + 2, out_ch),
+                *(ResidualConvBlock(out_ch, out_ch, dim_times_res_block_hidden * out_ch, activation="relu", norm=res_block_norm) for _ in range(num_res_blocks))
+            ) for in_ch, out_ch in zip([dim_proj] + dim_upsample[:-1], dim_upsample)
+        ])
+
+        # layer iterations
+        self.first_layer_block = make_output_block(
+            dim_upsample[-1] + 2, dim_out,
+            dim_times_res_block_hidden, last_res_blocks, last_conv_channels, last_conv_size, res_block_norm)
+
+        self.remaining_layer_block = nn.ModuleList([
+            make_output_block(
+                dim_upsample[-1] + 2, dim_out,
+                dim_times_res_block_hidden, last_res_blocks, last_conv_channels, last_conv_size, res_block_norm)
+            for _ in range(self.num_output_layer - 1)])
+
+    def forward(self, hidden_states: torch.Tensor, image: torch.Tensor):
+        img_h, img_w = image.shape[-2:]
+        patch_h, patch_w = img_h // 14, img_w // 14
+
+        # Process the hidden states
+        x = torch.stack([
+            proj(feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous())
+            for proj, (feat, clstoken) in zip(self.projects, hidden_states)
+        ], dim=1).sum(dim=1)
+
+        # Upsample stage
+        # (patch_h, patch_w) -> (patch_h * 2, patch_w * 2) -> (patch_h * 4, patch_w * 4) -> (patch_h * 8, patch_w * 8)
+        for i, block in enumerate(self.upsample_blocks):
+            # UV coordinates are concatenated for awareness of the image aspect ratio
+            uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device)
+            uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
+            x = torch.cat([x, uv], dim=1)
+            for layer in block:
+                x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
+
+        # (patch_h * 8, patch_w * 8) -> (img_h, img_w)
+        x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False)
+        uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device)
+        uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
+        x = torch.cat([x, uv], dim=1)
+
+        pts_list = []
+        for layer_id in range(self.num_output_layer):
+            if layer_id == 0:
+                blocks = self.first_layer_block
+            else:
+                blocks = self.remaining_layer_block[layer_id - 1]
+
+            # for each block
+            if isinstance(blocks, nn.ModuleList):
+                raise NotImplementedError()
+            else:
+                res = torch.utils.checkpoint.checkpoint(blocks, x, use_reentrant=False)[:, :3, :, :]
+            pts_list.append(res[:, :3, :, :])
+
+        pts = torch.stack(pts_list, dim=-1)
+        seg = pts.new_zeros(pts.shape)[:, :1, ...]
+
+        # <b 3 h w l>, <b 1 h w l>
+        output = [pts, seg]
+
+        return output
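A hedged shape sketch for PointHead (illustrative sizes, not from the committed files): each of the num_output_layer output blocks predicts a 3-channel point map at full resolution, and the maps are stacked along a trailing layer dimension, matching the <b 3 h w l> comment above.

head = PointHead(num_features=4, dim_in=1024, dim_out=3, num_output_layer=5)
# hidden_states: four (tokens, cls_token) pairs; image: (B, 3, 518, 518)
# pts, seg = head(hidden_states, image)
# pts: (B, 3, 518, 518, 5); seg: (B, 1, 518, 518, 5) zero placeholder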
src/lari/model/lari_model.py
ADDED
@@ -0,0 +1,177 @@
+from typing import *
+from numbers import Number
+from functools import partial
+from pathlib import Path
+import importlib
+import warnings
+import json
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils
+import torch.utils.checkpoint
+import torch.version
+from huggingface_hub import hf_hub_download
+from src.lari.model.utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing
+from src.lari.model.heads import PointHead
+
+
+class LaRIModel(nn.Module):
+    image_mean: torch.Tensor
+    image_std: torch.Tensor
+
+    def __init__(self,
+                 encoder: str = 'dinov2_vitl14',
+                 intermediate_layers: Union[int, List[int]] = 4,
+                 dim_proj: int = 512,
+                 dim_upsample: List[int] = [256, 128, 64],
+                 dim_times_res_block_hidden: int = 2,
+                 num_res_blocks: int = 2,
+                 output_mask: bool = True,
+                 split_head: bool = True,
+                 remap_output: Literal[False, True, 'linear', 'sinh', 'exp', 'sinh_exp'] = 'exp',
+                 res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm',
+                 last_res_blocks: int = 0,
+                 last_conv_channels: int = 32,
+                 last_conv_size: int = 1,
+                 use_pretrained: Literal["dinov2", "moge_full", "moge_backbone", None] = None,
+                 pretrained_path: str = "",
+                 num_output_layer: int = None,
+                 head_type=None,
+                 **deprecated_kwargs
+                 ):
+        super(LaRIModel, self).__init__()
+        if deprecated_kwargs:
+            warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}")
+
+        self.encoder = encoder
+        self.remap_output = remap_output
+        self.intermediate_layers = intermediate_layers
+        self.head_type = head_type
+        self.output_mask = output_mask
+        self.split_head = split_head
+        self.use_pretrained = use_pretrained
+        self.pretrained_path = pretrained_path
+        self.num_output_layer = num_output_layer
+
+        hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), encoder)
+        # hub_loader = getattr(importlib.import_module("dinov2.hub.backbones", __package__), encoder)
+
+        self.backbone = hub_loader(pretrained=True if self.use_pretrained == "dinov2" else False)
+        dim_feature = self.backbone.blocks[0].attn.qkv.in_features
+
+        if self.head_type == "point":
+            self.head = PointHead(
+                num_features=intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers),
+                dim_in=dim_feature,
+                dim_out=3,
+                dim_proj=dim_proj,
+                dim_upsample=dim_upsample,
+                dim_times_res_block_hidden=dim_times_res_block_hidden,
+                num_res_blocks=num_res_blocks,
+                res_block_norm=res_block_norm,
+                last_res_blocks=last_res_blocks,
+                last_conv_channels=last_conv_channels,
+                last_conv_size=last_conv_size,
+                num_output_layer=num_output_layer
+            )
+        else:
+            raise NotImplementedError()
+
+        if torch.__version__ >= '2.0':
+            self.enable_pytorch_native_sdpa()
+
+        self._load_pretrained()
+
+    def _load_pretrained(self):
+        '''
+        Load pre-trained weights
+        '''
+        if self.use_pretrained == "dinov2" or self.use_pretrained is None:
+            return
+
+        if self.use_pretrained == "moge_full" and self.pretrained_path != "":
+            checkpoint = torch.load(self.pretrained_path, map_location='cpu', weights_only=True)
+            if self.head_type == "point":
+                key_transition_map = {"output_block": "first_layer_block"}
+            model_state_dict = {}
+
+            # change the key name of the dict
+            for key, val in checkpoint['model'].items():
+                for trans_src, trans_target in key_transition_map.items():
+                    if trans_src in key:
+                        model_state_dict[key.replace(trans_src, trans_target)] = val
+                    else:
+                        model_state_dict[key] = val
+
+            self.load_state_dict(model_state_dict, strict=False)
+            del model_state_dict
+
+        else:
+            return
+
+    @staticmethod
+    def cache_pretrained_backbone(encoder: str, pretrained: bool):
+        _ = torch.hub.load('facebookresearch/dinov2', encoder, pretrained=pretrained)
+
+    def load_pretrained_backbone(self):
+        "Load the backbone with pretrained dinov2 weights from torch hub"
+        state_dict = torch.hub.load('facebookresearch/dinov2', self.encoder, pretrained=True).state_dict()
+        self.backbone.load_state_dict(state_dict)
+
+    def enable_backbone_gradient_checkpointing(self):
+        for i in range(len(self.backbone.blocks)):
+            self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing(self.backbone.blocks[i])
+
+    def enable_pytorch_native_sdpa(self):
+        for i in range(len(self.backbone.blocks)):
+            self.backbone.blocks[i].attn = wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn)
+
+    def forward(self, image: torch.Tensor, mixed_precision: bool = False) -> Dict[str, torch.Tensor]:
+        raw_img_h, raw_img_w = image.shape[-2:]
+        patch_h, patch_w = raw_img_h // 14, raw_img_w // 14
+
+        # Apply image transformation for DINOv2
+        image_14 = F.interpolate(image, (patch_h * 14, patch_w * 14), mode="bilinear", align_corners=False, antialias=True)
+
+        # Get intermediate layers from the backbone
+        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=mixed_precision):
+            features = self.backbone.get_intermediate_layers(image_14, self.intermediate_layers, return_class_token=True)
+
+        # Predict points and mask (mask scores)
+        points, mask = self.head(features, image)
+
+        is_output_prob = False
+        if mask.ndim == 5:
+            # <b, h, w, layer, 3>, <b, h, w, layer, 1>
+            points, mask = points.permute(0, 2, 3, 4, 1), mask.permute(0, 2, 3, 4, 1)
+        elif mask.ndim == 4:  # <b, h, w, layer, 3>, <b, layer, h, w>
+            points = points.permute(0, 2, 3, 4, 1)
+            is_output_prob = True
+
+        if self.remap_output == 'linear' or self.remap_output == False:
+            pass
+        elif self.remap_output == 'sinh' or self.remap_output == True:
+            points = torch.sinh(points)
+        elif self.remap_output == 'exp':
+            xy, z = points.split([2, 1], dim=-1)
+            z = torch.exp(z)
+            points = torch.cat([xy * z, z], dim=-1)
+        elif self.remap_output == 'sinh_exp':
+            xy, z = points.split([2, 1], dim=-1)
+            points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1)
+        else:
+            raise ValueError(f"Invalid remap output type: {self.remap_output}")
+
+        return_dict = {'pts3d': points}
+
+        if not is_output_prob:
+            return_dict['mask'] = mask
+        else:
+            return_dict["seg_prob"] = mask
+
+        return return_dict
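A minimal inference sketch for LaRIModel (not from the committed files; the constructor values are assumptions consistent with the defaults above, and the backbone is built with random weights since use_pretrained is None):

model = LaRIModel(encoder='dinov2_vitl14', head_type='point', num_output_layer=5).eval()
image = torch.rand(1, 3, 518, 518)  # RGB in [0, 1]; H and W should be multiples of 14
with torch.no_grad():
    out = model(image)
# out['pts3d']: (1, 518, 518, 5, 3) layered point maps after the 'exp' z remap;
# out['mask'] carries the head's mask output (a zero placeholder for PointHead).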
src/lari/model/utils.py
ADDED
@@ -0,0 +1,38 @@
+from typing import *
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+def wrap_module_with_gradient_checkpointing(module: nn.Module):
+    from torch.utils.checkpoint import checkpoint
+    class _CheckpointingWrapper(module.__class__):
+        _restore_cls = module.__class__
+        def forward(self, *args, **kwargs):
+            return checkpoint(super().forward, *args, use_reentrant=False, **kwargs)
+
+    module.__class__ = _CheckpointingWrapper
+    return module
+
+
+def unwrap_module_with_gradient_checkpointing(module: nn.Module):
+    module.__class__ = module.__class__._restore_cls
+
+
+def wrap_dinov2_attention_with_sdpa(module: nn.Module):
+    assert torch.__version__ >= '2.0', "SDPA requires PyTorch 2.0 or later"
+    class _AttentionWrapper(module.__class__):
+        def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+            B, N, C = x.shape
+            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # (3, B, H, N, C // H)
+
+            q, k, v = torch.unbind(qkv, 0)  # (B, H, N, C // H)
+
+            x = F.scaled_dot_product_attention(q, k, v, attn_bias)
+            x = x.permute(0, 2, 1, 3).reshape(B, N, C)
+
+            x = self.proj(x)
+            x = self.proj_drop(x)
+            return x
+    module.__class__ = _AttentionWrapper
+    return module
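A short sketch of how these wrappers behave (illustrative, not from the committed files): both swap the module's __class__ in place, so no weights are copied and the original class can be restored afterwards.

layer = nn.Linear(8, 8)
layer = wrap_module_with_gradient_checkpointing(layer)  # forward now runs under checkpointing
unwrap_module_with_gradient_checkpointing(layer)        # layer.__class__ is nn.Linear again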
src/lari/utils/__init__.py
ADDED
File without changes
src/lari/utils/geometry_numpy.py
ADDED
@@ -0,0 +1,187 @@
+from typing import *
+from functools import partial
+import math
+
+import numpy as np
+import utils3d
+
+def weighted_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int, ...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
+    if w is None:
+        return np.mean(x, axis=axis)
+    else:
+        w = w.astype(x.dtype)
+        return (x * w).mean(axis=axis) / np.clip(w.mean(axis=axis), eps, None)
+
+
+def harmonic_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int, ...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
+    if w is None:
+        return 1 / (1 / np.clip(x, eps, None)).mean(axis=axis)
+    else:
+        w = w.astype(x.dtype)
+        return 1 / (weighted_mean_numpy(1 / (x + eps), w, axis=axis, keepdims=keepdims, eps=eps) + eps)
+
+
+def normalized_view_plane_uv_numpy(width: int, height: int, aspect_ratio: float = None, dtype: np.dtype = np.float32) -> np.ndarray:
+    "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
+    if aspect_ratio is None:
+        aspect_ratio = width / height
+
+    span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
+    span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5
+
+    u = np.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype)
+    v = np.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype)
+    u, v = np.meshgrid(u, v, indexing='xy')
+    uv = np.stack([u, v], axis=-1)
+    return uv
+
+
+def focal_to_fov_numpy(focal: np.ndarray):
+    return 2 * np.arctan(0.5 / focal)
+
+
+def fov_to_focal_numpy(fov: np.ndarray):
+    return 0.5 / np.tan(fov / 2)
+
+
+def intrinsics_to_fov_numpy(intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+    fov_x = focal_to_fov_numpy(intrinsics[..., 0, 0])
+    fov_y = focal_to_fov_numpy(intrinsics[..., 1, 1])
+    return fov_x, fov_y
+
+
+def point_map_to_depth_legacy_numpy(points: np.ndarray):
+    height, width = points.shape[-3:-1]
+    diagonal = (height ** 2 + width ** 2) ** 0.5
+    uv = normalized_view_plane_uv_numpy(width, height, dtype=points.dtype)  # (H, W, 2)
+    _, uv = np.broadcast_arrays(points[..., :2], uv)
+
+    # Solve least squares problem
+    b = (uv * points[..., 2:]).reshape(*points.shape[:-3], -1)  # (..., H * W * 2)
+    A = np.stack([points[..., :2], -uv], axis=-1).reshape(*points.shape[:-3], -1, 2)  # (..., H * W * 2, 2)
+
+    M = A.swapaxes(-2, -1) @ A
+    solution = (np.linalg.inv(M + 1e-6 * np.eye(2)) @ (A.swapaxes(-2, -1) @ b[..., None])).squeeze(-1)
+    focal, shift = solution
+
+    depth = points[..., 2] + shift[..., None, None]
+    fov_x = np.arctan(width / diagonal / focal) * 2
+    fov_y = np.arctan(height / diagonal / focal) * 2
+    return depth, fov_x, fov_y, shift
+
+
+def solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray):
+    "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift and focal"
+    from scipy.optimize import least_squares
+    uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)
+
+    def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
+        xy_proj = xy / (z + shift)[:, None]
+        f = (xy_proj * uv).sum() / np.square(xy_proj).sum()
+        err = (f * xy_proj - uv).ravel()
+        return err
+
+    solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm')
+    optim_shift = solution['x'].squeeze().astype(np.float32)
+
+    xy_proj = xy / (z + optim_shift)[:, None]
+    optim_focal = (xy_proj * uv).sum() / np.square(xy_proj).sum()
+
+    return optim_shift, optim_focal
+
+
+def solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float):
+    "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift"
+    from scipy.optimize import least_squares
+    uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)
+
+    def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
+        xy_proj = xy / (z + shift)[:, None]
+        err = (focal * xy_proj - uv).ravel()
+        return err
+
+    solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm')
+    optim_shift = solution['x'].squeeze().astype(np.float32)
+
+    return optim_shift
+
+
+def recover_focal_shift_numpy(points: np.ndarray, mask: np.ndarray = None, focal: float = None, downsample_size: Tuple[int, int] = (64, 64)):
+    import cv2
+    assert points.shape[-1] == 3, "Points should be (H, W, 3)"
+
+    height, width = points.shape[-3], points.shape[-2]
+    diagonal = (height ** 2 + width ** 2) ** 0.5
+
+    uv = normalized_view_plane_uv_numpy(width=width, height=height)
+
+    if mask is None:
+        points_lr = cv2.resize(points, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 3)
+        uv_lr = cv2.resize(uv, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 2)
+    else:
+        index, mask_lr = mask_aware_nearest_resize_numpy(mask, *downsample_size)
+        points_lr, uv_lr = points[index][mask_lr], uv[index][mask_lr]
+
+    if points_lr.size == 0:
+        return np.zeros((height, width)), 0, 0, 0
+
+    if focal is None:
+        shift, focal = solve_optimal_focal_shift(uv_lr, points_lr)  # the solver returns (shift, focal)
+    else:
+        shift = solve_optimal_shift(uv_lr, points_lr, focal)
+
+    return focal, shift
+
+
+def mask_aware_nearest_resize_numpy(mask: np.ndarray, target_width: int, target_height: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """
+    Resize a 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.
+
+    ### Parameters
+    - `mask`: Input 2D mask of shape (..., H, W)
+    - `target_width`: target width of the resized map
+    - `target_height`: target height of the resized map
+
+    ### Returns
+    - `nearest_idx`: Nearest neighbor index of the resized map of shape (..., target_height, target_width). Indices are like j + i * W, where i is the row index and j is the column index.
+    - `target_mask`: Mask of the resized map of shape (..., target_height, target_width)
+    """
+    height, width = mask.shape[-2:]
+    filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
+    filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f)
+    filter_size = filter_h_i * filter_w_i
+    padding_h, padding_w = round(filter_h_f / 2), round(filter_w_f / 2)
+
+    # Window the original mask and uv
+    uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32)
+    indices = np.arange(height * width, dtype=np.int32).reshape(height, width)
+    padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32)
+    padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
+    padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool)
+    padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
+    padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32)
+    padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
+    windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1))
+    windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1))
+    windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, axis=(0, 1))
+
+    # Gather the target pixels' local window
+    target_uv = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32)
+    target_corner = target_uv - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
+    target_corner = np.round(target_corner - 0.5).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32)
+
+    target_window_uv = windowed_uv[target_corner[..., 1], target_corner[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size)  # (target_height, target_width, 2, filter_size)
+    target_window_mask = windowed_mask[..., target_corner[..., 1], target_corner[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size)  # (..., target_height, target_width, filter_size)
+    target_window_indices = windowed_indices[target_corner[..., 1], target_corner[..., 0], :, :].reshape(target_height, target_width, filter_size)  # (target_height, target_width, filter_size)
+
+    # Compute the nearest neighbor in the local window for each pixel
+    dist = np.square(target_window_uv - target_uv[..., None])
+    dist = dist[..., 0, :] + dist[..., 1, :]
+    dist = np.where(target_window_mask, dist, np.inf)  # (..., target_height, target_width, filter_size)
+    nearest_in_window = np.argmin(dist, axis=-1, keepdims=True)  # (..., target_height, target_width, 1)
+    nearest_idx = np.take_along_axis(target_window_indices, nearest_in_window, axis=-1).squeeze(-1)  # (..., target_height, target_width)
+    nearest_i, nearest_j = nearest_idx // width, nearest_idx % width
+    target_mask = np.any(target_window_mask, axis=-1)
+    batch_indices = [np.arange(n).reshape([1] * i + [n] + [1] * (mask.ndim - i - 1)) for i, n in enumerate(mask.shape[:-2])]
+
+    return (*batch_indices, nearest_i, nearest_j), target_mask
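A hedged usage sketch of recover_focal_shift_numpy (not from the committed files; the point map is a random placeholder): given a point map with unknown z shift and focal, it returns the focal (relative to the half diagonal) and the z shift that best project xy onto the view-plane UV grid.

points = np.random.rand(480, 640, 3).astype(np.float32)
points[..., 2] += 1.0                              # keep z away from the z + shift singularity
focal, shift = recover_focal_shift_numpy(points)   # mask=None: plain cv2 downsample path
depth = points[..., 2] + shift                     # shifted depth, as in the legacy solver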
src/lari/utils/geometry_torch.py
ADDED
@@ -0,0 +1,221 @@
+from typing import *
+import math
+from collections import namedtuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.types
+
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+import utils3d
+from .geometry_numpy import solve_optimal_focal_shift, solve_optimal_shift
+
+
+def weighted_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
+    if w is None:
+        return x.mean(dim=dim, keepdim=keepdim)
+    else:
+        w = w.to(x.dtype)
+        return (x * w).mean(dim=dim, keepdim=keepdim) / w.mean(dim=dim, keepdim=keepdim).add(eps)
+
+
+def harmonic_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
+    if w is None:
+        return x.add(eps).reciprocal().mean(dim=dim, keepdim=keepdim).reciprocal()
+    else:
+        w = w.to(x.dtype)
+        return weighted_mean(x.add(eps).reciprocal(), w, dim=dim, keepdim=keepdim, eps=eps).add(eps).reciprocal()
+
+
+def geometric_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
+    if w is None:
+        return x.add(eps).log().mean(dim=dim).exp()
+    else:
+        w = w.to(x.dtype)
+        return weighted_mean(x.add(eps).log(), w, dim=dim, keepdim=keepdim, eps=eps).exp()
+
+
+def normalized_view_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor:
+    "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
+    if aspect_ratio is None:
+        aspect_ratio = width / height
+
+    span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
+    span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5
+
+    u = torch.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype, device=device)
+    v = torch.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype, device=device)
+    u, v = torch.meshgrid(u, v, indexing='xy')
+    uv = torch.stack([u, v], dim=-1)
+    return uv
+
+
+def gaussian_blur_2d(input: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
+    kernel = torch.exp(-(torch.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=input.dtype, device=input.device) ** 2) / (2 * sigma ** 2))
+    kernel = kernel / kernel.sum()
+    kernel = (kernel[:, None] * kernel[None, :]).reshape(1, 1, kernel_size, kernel_size)
+    kernel = kernel.expand(input.shape[1], 1, kernel_size, kernel_size)  # one depthwise kernel per channel, as required by the grouped conv below
+    input = F.pad(input, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), mode='replicate')
+    input = F.conv2d(input, kernel, groups=input.shape[1])
+    return input
+
+
+def focal_to_fov(focal: torch.Tensor):
+    return 2 * torch.atan(0.5 / focal)
+
+
+def fov_to_focal(fov: torch.Tensor):
+    return 0.5 / torch.tan(fov / 2)
+
+
+def intrinsics_to_fov(intrinsics: torch.Tensor):
+    """
+    Returns field of view in radians from normalized intrinsics matrix.
+    ### Parameters:
+    - intrinsics: torch.Tensor of shape (..., 3, 3)
+
+    ### Returns:
+    - fov_x: torch.Tensor of shape (...)
+    - fov_y: torch.Tensor of shape (...)
+    """
+    focal_x = intrinsics[..., 0, 0]
+    focal_y = intrinsics[..., 1, 1]
+    return 2 * torch.atan(0.5 / focal_x), 2 * torch.atan(0.5 / focal_y)
+
+
+def point_map_to_depth_legacy(points: torch.Tensor):
+    height, width = points.shape[-3:-1]
+    diagonal = (height ** 2 + width ** 2) ** 0.5
+    uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device)  # (H, W, 2)
+
+    # Solve least squares problem
+    b = (uv * points[..., 2:]).flatten(-3, -1)  # (..., H * W * 2)
+    A = torch.stack([points[..., :2], -uv.expand_as(points[..., :2])], dim=-1).flatten(-4, -2)  # (..., H * W * 2, 2)
+
+    M = A.transpose(-2, -1) @ A
+    solution = (torch.inverse(M + 1e-6 * torch.eye(2).to(A)) @ (A.transpose(-2, -1) @ b[..., None])).squeeze(-1)
+    focal, shift = solution.unbind(-1)
+
+    depth = points[..., 2] + shift[..., None, None]
+    fov_x = torch.atan(width / diagonal / focal) * 2
+    fov_y = torch.atan(height / diagonal / focal) * 2
+    return depth, fov_x, fov_y, shift
+
+
+def view_plane_uv_to_focal(uv: torch.Tensor):
+    normed_uv = normalized_view_plane_uv(width=uv.shape[-2], height=uv.shape[-3], device=uv.device, dtype=uv.dtype)
+    focal = (uv * normed_uv).sum() / uv.square().sum().add(1e-12)
+    return focal
+
+
+def recover_focal_shift(points: torch.Tensor, mask: torch.Tensor = None, focal: torch.Tensor = None, downsample_size: Tuple[int, int] = (64, 64)):
+    """
+    Recover the depth map and FoV from a point map with unknown z shift and focal.
+
+    Note that it assumes:
+    - the optical center is at the center of the map
+    - the map is undistorted
+    - the map is isometric in the x and y directions
+
+    ### Parameters:
+    - `points: torch.Tensor` of shape (..., H, W, 3)
+    - `mask: torch.Tensor` of shape (..., H, W). Optional.
+    - `focal: torch.Tensor` of shape (...). Optional.
+    - `downsample_size: Tuple[int, int]` in (height, width), the size of the downsampled map. Downsampling produces an approximate solution and is efficient for large maps.
+
+    ### Returns:
+    - `focal`: torch.Tensor of shape (...), the estimated focal length, relative to the half diagonal of the map
+    - `shift`: torch.Tensor of shape (...), Z-axis shift to translate the point map to camera space
+    """
+    shape = points.shape
+    height, width = points.shape[-3], points.shape[-2]
+    diagonal = (height ** 2 + width ** 2) ** 0.5
+
+    points = points.reshape(-1, *shape[-3:])
+    mask = None if mask is None else mask.reshape(-1, *shape[-3:-1])
+    focal = focal.reshape(-1) if focal is not None else None
+    uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device)  # (H, W, 2)
+
+    points_lr = F.interpolate(points.permute(0, 3, 1, 2), downsample_size, mode='nearest').permute(0, 2, 3, 1)
+    uv_lr = F.interpolate(uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode='nearest').squeeze(0).permute(1, 2, 0)
+    mask_lr = None if mask is None else F.interpolate(mask.to(torch.float32).unsqueeze(1), downsample_size, mode='nearest').squeeze(1) > 0
+
+    uv_lr_np = uv_lr.cpu().numpy()
+    points_lr_np = points_lr.detach().cpu().numpy()
+    focal_np = focal.cpu().numpy() if focal is not None else None
+    mask_lr_np = None if mask is None else mask_lr.cpu().numpy()
+    optim_shift, optim_focal = [], []
+    for i in range(points.shape[0]):
+        points_lr_i_np = points_lr_np[i] if mask is None else points_lr_np[i][mask_lr_np[i]]
+        uv_lr_i_np = uv_lr_np if mask is None else uv_lr_np[mask_lr_np[i]]
+        if focal is None:
+            optim_shift_i, optim_focal_i = solve_optimal_focal_shift(uv_lr_i_np, points_lr_i_np)
+            optim_focal.append(float(optim_focal_i))
+        else:
+            optim_shift_i = solve_optimal_shift(uv_lr_i_np, points_lr_i_np, focal_np[i])
+        optim_shift.append(float(optim_shift_i))
+    optim_shift = torch.tensor(optim_shift, device=points.device, dtype=points.dtype).reshape(shape[:-3])
+
+    if focal is None:
+        optim_focal = torch.tensor(optim_focal, device=points.device, dtype=points.dtype).reshape(shape[:-3])
+    else:
+        optim_focal = focal.reshape(shape[:-3])
+
+    return optim_focal, optim_shift
+
+
+def mask_aware_nearest_resize(mask: torch.BoolTensor, target_width: int, target_height: int) -> Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]:
+    """
+    Resize a 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.
+
+    ### Parameters
+    - `mask`: Input 2D mask of shape (..., H, W)
+    - `target_width`: target width of the resized map
+    - `target_height`: target height of the resized map
+
+    ### Returns
+    - `nearest_idx`: Nearest neighbor index of the resized map of shape (..., target_height, target_width) for each dimension
+    - `target_mask`: Mask of the resized map of shape (..., target_height, target_width)
+    """
+    height, width = mask.shape[-2:]
+    device = mask.device
+    filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
+    filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f)
+    filter_size = filter_h_i * filter_w_i
+    padding_h, padding_w = round(filter_h_f / 2), round(filter_w_f / 2)
+
+    # Window the original mask and uv
+    uv = utils3d.torch.image_pixel_center(width=width, height=height, dtype=torch.float32, device=device)
+    indices = torch.arange(height * width, dtype=torch.long, device=device).reshape(height, width)
+    padded_uv = torch.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=torch.float32, device=device)
+    padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
+    padded_mask = torch.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=torch.bool, device=device)
+    padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
+    padded_indices = torch.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=torch.long, device=device)
+    padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
+    windowed_uv = utils3d.torch.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, dim=(0, 1))
+    windowed_mask = utils3d.torch.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, dim=(-2, -1))
+    windowed_indices = utils3d.torch.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, dim=(0, 1))
+
+    # Gather the target pixels' local window
+    target_uv = utils3d.torch.image_uv(width=target_width, height=target_height, dtype=torch.float32, device=device) * torch.tensor([width, height], dtype=torch.float32, device=device)
+    target_corner = target_uv - torch.tensor((filter_w_f / 2, filter_h_f / 2), dtype=torch.float32, device=device)
+    target_corner = torch.round(target_corner - 0.5).long() + torch.tensor((padding_w, padding_h), dtype=torch.long, device=device)
+
+    target_window_uv = windowed_uv[target_corner[..., 1], target_corner[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size)  # (target_height, target_width, 2, filter_size)
+    target_window_mask = windowed_mask[..., target_corner[..., 1], target_corner[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size)  # (..., target_height, target_width, filter_size)
+    target_window_indices = windowed_indices[target_corner[..., 1], target_corner[..., 0], :, :].reshape(target_height, target_width, filter_size)  # (target_height, target_width, filter_size)
+    target_window_indices = target_window_indices.expand_as(target_window_mask)
+
+    # Compute the nearest neighbor in the local window for each pixel
+    dist = torch.where(target_window_mask, torch.norm(target_window_uv - target_uv[..., None], dim=-2), torch.inf)  # (..., target_height, target_width, filter_size)
+    nearest = torch.argmin(dist, dim=-1, keepdim=True)  # (..., target_height, target_width, 1)
+    nearest_idx = torch.gather(target_window_indices, index=nearest, dim=-1).squeeze(-1)  # (..., target_height, target_width)
+    target_mask = torch.any(target_window_mask, dim=-1)
+    nearest_i, nearest_j = nearest_idx // width, nearest_idx % width
+    batch_indices = [torch.arange(n, device=device).reshape([1] * i + [n] + [1] * (mask.dim() - i - 1)) for i, n in enumerate(mask.shape[:-2])]
+
+    return (*batch_indices, nearest_i, nearest_j), target_mask
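A hedged sketch of mask_aware_nearest_resize (illustrative shapes, not from the committed files): the returned index tuple can be applied to any map with the same leading shape to downsample it while honoring the validity mask.

mask = torch.rand(2, 480, 640) > 0.3              # (B, H, W) boolean validity mask
index, mask_lr = mask_aware_nearest_resize(mask, 64, 64)
depth = torch.rand(2, 480, 640)
depth_lr = depth[index]                           # (2, 64, 64), nearest valid samples
depth_lr[~mask_lr] = 0.0                          # windows with no valid pixel at all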
src/utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
src/utils/vis.py
ADDED
@@ -0,0 +1,105 @@
+# import torchvision.transforms as transforms
+# import torch.nn.functional as F
+# import cv2
+# import os
+# import logging
+# from pathlib import Path
+import numpy as np
+# import os
+import torch
+import matplotlib
+# import cv2
+# import random
+# from PIL import Image
+# import imageio
+
+def prob_to_mask(prob):
+    """
+    Transforms a probability map of stopping points (shape: (n_layer+1, H, W))
+    into a binary mask (shape: (H, W, n_layer, 1)) where for each pixel, layers
+    with index ≤ stopping index (as given by argmax) are marked valid.
+    """
+    num_layer_plus1, H, W = prob.shape
+    # Get stopping index for each pixel; values are in {0, 1, ..., n_layer}
+    stopping_indices = torch.argmax(prob, dim=0)  # (H, W)
+
+    # Create a tensor with layer indices [1, 2, ..., n_layer]
+    layer_indices = torch.arange(1, num_layer_plus1, device=prob.device).view(-1, 1, 1)
+
+    # Compare: a layer is valid if its index is <= the stopping index.
+    pred_mask = (layer_indices <= stopping_indices.unsqueeze(0))
+
+    # Permute and unsqueeze to get shape (H, W, n_layer, 1)
+    pred_mask = pred_mask.permute(1, 2, 0).unsqueeze(-1)
+    return pred_mask
+
+
+def colorize(value, vmin=None, vmax=None, cmap='rainbow', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None):
+    """Converts a depth map to a color image.
+
+    Args:
+        value (torch.Tensor, numpy.ndarray): Input depth map. Shape: (H, W) or (1, H, W) or (1, 1, H, W). All singular dimensions are squeezed
+        vmin (float, optional): vmin-valued entries are mapped to start color of cmap. If None, value.min() is used. Defaults to None.
+        vmax (float, optional): vmax-valued entries are mapped to end color of cmap. If None, value.max() is used. Defaults to None.
+        cmap (str, optional): matplotlib colormap to use. Defaults to 'rainbow'.
+        invalid_val (int, optional): Specifies value of invalid pixels that should be colored as 'background_color'. Defaults to -99.
+        invalid_mask (numpy.ndarray, optional): Boolean mask for invalid regions. Defaults to None.
+        background_color (tuple[int], optional): 4-tuple RGBA color to give to invalid pixels. Defaults to (128, 128, 128, 255).
+        gamma_corrected (bool, optional): Apply gamma correction to colored image. Defaults to False.
+        value_transform (Callable, optional): Apply transform function to valid pixels before coloring. Defaults to None.
+
+    Returns:
+        numpy.ndarray, dtype - uint8: Colored depth map. Shape: (H, W, 4)
+    """
+    if isinstance(value, torch.Tensor):
+        value = value.detach().cpu().numpy()
+
+    value = value.squeeze()
+    if invalid_mask is None:
+        invalid_mask = value == invalid_val
+    mask = np.logical_not(invalid_mask)
+
+    # normalize
+    vmin = np.percentile(value[mask], 2) if vmin is None else vmin
+    vmax = np.percentile(value[mask], 85) if vmax is None else vmax
+    if vmin != vmax:
+        value = (value - vmin) / (vmax - vmin)  # vmin..vmax
+    else:
+        # Avoid 0-division
+        value = value * 0.
+
+    value[invalid_mask] = np.nan
+    cmapper = matplotlib.cm.get_cmap(cmap)
+    if value_transform:
+        value = value_transform(value)
+    # value = value / value.max()
+    value = cmapper(value, bytes=True)  # (n x m x 4)
+
+    # img = value[:, :, :]
+    img = value[...]
+    img[invalid_mask] = background_color
+
+    if gamma_corrected:
+        # gamma correction
+        img = img / 255
+        img = np.power(img, 2.2)
+        img = img * 255
+        img = img.astype(np.uint8)
+    return img
+
+
+def denormalize(x):
+    """Reverses the imagenet normalization applied to the input.
+
+    Args:
+        x (torch.Tensor - shape(N,3,H,W)): input tensor
+
+    Returns:
+        torch.Tensor - shape(N,3,H,W): Denormalized input
+    """
+    mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
+    std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
+    return x * std + mean
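A short usage sketch combining the two helpers (illustrative placeholder inputs, not from the committed files):

depth = np.random.rand(240, 320).astype(np.float32) + 0.5
depth[:16] = -99                            # marked invalid via invalid_val
colored = colorize(depth, cmap='rainbow')   # (240, 320, 4) uint8, gray background for invalid
prob = torch.rand(6, 240, 320)              # n_layer = 5, plus index 0 for "no stopping point"
layer_mask = prob_to_mask(prob)             # (240, 320, 5, 1) boolean per-layer validity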
src/utils3d/README.md
ADDED
@@ -0,0 +1,3 @@
+# utils3d
+
+This is a collection of utility functions for 3D computer vision tasks copied from https://github.com/EasternJournalist/utils3d.
src/utils3d/__init__.py
ADDED
@@ -0,0 +1,20 @@
+"""
+A package for common utility functions in 3D computer graphics and vision. Providing NumPy utilities in `utils3d.numpy`, PyTorch utilities in `utils3d.torch`, and IO utilities in `utils3d.io`.
+"""
+import importlib
+from typing import TYPE_CHECKING
+
+try:
+    from ._unified import *
+except ImportError:
+    pass
+
+__all__ = ['numpy', 'torch', 'io']
+
+def __getattr__(name: str):
+    return globals().get(name, importlib.import_module(f'.{name}', __package__))
+
+if TYPE_CHECKING:
+    from . import torch
+    from . import numpy
+    from . import io
src/utils3d/_helpers.py
ADDED
@@ -0,0 +1,35 @@
+from functools import wraps
+import warnings
+
+
+def suppress_traceback(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        try:
+            return fn(*args, **kwargs)
+        except Exception as e:
+            e.__traceback__ = e.__traceback__.tb_next.tb_next
+            raise
+    return wrapper
+
+
+class no_warnings:
+    def __init__(self, action: str = 'ignore', **kwargs):
+        self.action = action
+        self.filter_kwargs = kwargs
+
+    def __call__(self, fn):
+        @wraps(fn)
+        def wrapper(*args, **kwargs):
+            with warnings.catch_warnings():
+                warnings.simplefilter(self.action, **self.filter_kwargs)
+                return fn(*args, **kwargs)
+        return wrapper
+
+    def __enter__(self):
+        self.warnings_manager = warnings.catch_warnings()
+        self.warnings_manager.__enter__()
+        warnings.simplefilter(self.action, **self.filter_kwargs)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.warnings_manager.__exit__(exc_type, exc_val, exc_tb)
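A brief sketch of no_warnings in both of its modes (illustrative, not from the committed files):

@no_warnings(category=DeprecationWarning)
def noisy():
    warnings.warn("old API", DeprecationWarning)  # silenced by the decorator

with no_warnings():
    warnings.warn("anything")                     # silenced by the context manager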
src/utils3d/_unified/__init__.py
ADDED
@@ -0,0 +1,934 @@
# Auto-generated implementation redirecting to numpy/torch implementations
import sys
from typing import TYPE_CHECKING
import utils3d
from .._helpers import suppress_traceback

__all__ = ["triangulate",
"compute_face_normal",
"compute_face_angle",
"compute_vertex_normal",
"compute_vertex_normal_weighted",
"remove_corrupted_faces",
"merge_duplicate_vertices",
"remove_unreferenced_vertices",
"subdivide_mesh_simple",
"mesh_relations",
"flatten_mesh_indices",
"calc_quad_candidates",
"calc_quad_distortion",
"calc_quad_direction",
"calc_quad_smoothness",
"sovle_quad",
"sovle_quad_qp",
"tri_to_quad",
"sliding_window_1d",
"sliding_window_nd",
"sliding_window_2d",
"max_pool_1d",
"max_pool_2d",
"max_pool_nd",
"depth_edge",
"normals_edge",
"depth_aliasing",
"interpolate",
"image_scrcoord",
"image_uv",
"image_pixel_center",
"image_pixel",
"image_mesh",
"image_mesh_from_depth",
"depth_to_normals",
"points_to_normals",
"chessboard",
"cube",
"icosahedron",
"square",
"camera_frustum",
"perspective",
"perspective_from_fov",
"perspective_from_fov_xy",
"intrinsics_from_focal_center",
"intrinsics_from_fov",
"fov_to_focal",
"focal_to_fov",
"intrinsics_to_fov",
"view_look_at",
"extrinsics_look_at",
"perspective_to_intrinsics",
"perspective_to_near_far",
"intrinsics_to_perspective",
"extrinsics_to_view",
"view_to_extrinsics",
"normalize_intrinsics",
"crop_intrinsics",
"pixel_to_uv",
"pixel_to_ndc",
"uv_to_pixel",
"project_depth",
"depth_buffer_to_linear",
"unproject_cv",
"unproject_gl",
"project_cv",
"project_gl",
"quaternion_to_matrix",
"axis_angle_to_matrix",
"matrix_to_quaternion",
"extrinsics_to_essential",
"euler_axis_angle_rotation",
"euler_angles_to_matrix",
"skew_symmetric",
"rotation_matrix_from_vectors",
"ray_intersection",
"se3_matrix",
"slerp_quaternion",
"slerp_vector",
"lerp",
"lerp_se3_matrix",
"piecewise_lerp",
"piecewise_lerp_se3_matrix",
"apply_transform",
"linear_spline_interpolate",
"RastContext",
"rasterize_triangle_faces",
"rasterize_edges",
"texture",
"warp_image_by_depth",
"test_rasterization",
"compute_face_angles",
"compute_face_tbn",
"compute_vertex_tbn",
"laplacian",
"laplacian_smooth_mesh",
"taubin_smooth_mesh",
"laplacian_hc_smooth_mesh",
"get_rays",
"get_image_rays",
"get_mipnerf_cones",
"volume_rendering",
"bin_sample",
"importance_sample",
"nerf_render_rays",
"mipnerf_render_rays",
"nerf_render_view",
"mipnerf_render_view",
"InstantNGP",
"point_to_normal",
"depth_to_normal",
"masked_min",
"masked_max",
"bounding_rect",
"intrinsics_from_fov_xy",
"matrix_to_euler_angles",
"matrix_to_axis_angle",
"axis_angle_to_quaternion",
"quaternion_to_axis_angle",
"slerp",
"interpolate_extrinsics",
"interpolate_view",
"to4x4",
"rotation_matrix_2d",
"rotate_2d",
"translate_2d",
"scale_2d",
"apply_2d",
"warp_image_by_forward_flow"]

def _contains_tensor(obj):
    if isinstance(obj, (list, tuple)):
        return any(_contains_tensor(item) for item in obj)
    elif isinstance(obj, dict):
        return any(_contains_tensor(value) for value in obj.values())
    else:
        import torch
        return isinstance(obj, torch.Tensor)


@suppress_traceback
def _call_based_on_args(fname, args, kwargs):
    if 'torch' in sys.modules:
        if any(_contains_tensor(arg) for arg in args) or any(_contains_tensor(v) for v in kwargs.values()):
            fn = getattr(utils3d.torch, fname, None)
            if fn is None:
                raise NotImplementedError(f"Function {fname} has no torch implementation.")
            return fn(*args, **kwargs)
    fn = getattr(utils3d.numpy, fname, None)
    if fn is None:
        raise NotImplementedError(f"Function {fname} has no numpy implementation.")
    return fn(*args, **kwargs)


@suppress_traceback
def triangulate(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.triangulate, utils3d.torch.triangulate
    return _call_based_on_args('triangulate', args, kwargs)

@suppress_traceback
def compute_face_normal(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.compute_face_normal, utils3d.torch.compute_face_normal
    return _call_based_on_args('compute_face_normal', args, kwargs)

@suppress_traceback
def compute_face_angle(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.compute_face_angle, None
    return _call_based_on_args('compute_face_angle', args, kwargs)

@suppress_traceback
def compute_vertex_normal(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.compute_vertex_normal, utils3d.torch.compute_vertex_normal
    return _call_based_on_args('compute_vertex_normal', args, kwargs)

@suppress_traceback
def compute_vertex_normal_weighted(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.compute_vertex_normal_weighted, utils3d.torch.compute_vertex_normal_weighted
    return _call_based_on_args('compute_vertex_normal_weighted', args, kwargs)

@suppress_traceback
def remove_corrupted_faces(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.remove_corrupted_faces, utils3d.torch.remove_corrupted_faces
    return _call_based_on_args('remove_corrupted_faces', args, kwargs)

@suppress_traceback
def merge_duplicate_vertices(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.merge_duplicate_vertices, utils3d.torch.merge_duplicate_vertices
    return _call_based_on_args('merge_duplicate_vertices', args, kwargs)

@suppress_traceback
def remove_unreferenced_vertices(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.remove_unreferenced_vertices, utils3d.torch.remove_unreferenced_vertices
    return _call_based_on_args('remove_unreferenced_vertices', args, kwargs)

@suppress_traceback
def subdivide_mesh_simple(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.subdivide_mesh_simple, utils3d.torch.subdivide_mesh_simple
    return _call_based_on_args('subdivide_mesh_simple', args, kwargs)

@suppress_traceback
def mesh_relations(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.mesh_relations, None
    return _call_based_on_args('mesh_relations', args, kwargs)

@suppress_traceback
def flatten_mesh_indices(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.flatten_mesh_indices, None
    return _call_based_on_args('flatten_mesh_indices', args, kwargs)

@suppress_traceback
def calc_quad_candidates(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.calc_quad_candidates, None
    return _call_based_on_args('calc_quad_candidates', args, kwargs)

@suppress_traceback
def calc_quad_distortion(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.calc_quad_distortion, None
    return _call_based_on_args('calc_quad_distortion', args, kwargs)

@suppress_traceback
def calc_quad_direction(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.calc_quad_direction, None
    return _call_based_on_args('calc_quad_direction', args, kwargs)

@suppress_traceback
def calc_quad_smoothness(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.calc_quad_smoothness, None
    return _call_based_on_args('calc_quad_smoothness', args, kwargs)

@suppress_traceback
def sovle_quad(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.sovle_quad, None
    return _call_based_on_args('sovle_quad', args, kwargs)

@suppress_traceback
def sovle_quad_qp(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.sovle_quad_qp, None
    return _call_based_on_args('sovle_quad_qp', args, kwargs)

@suppress_traceback
def tri_to_quad(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.tri_to_quad, None
    return _call_based_on_args('tri_to_quad', args, kwargs)

@suppress_traceback
def sliding_window_1d(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.sliding_window_1d, utils3d.torch.sliding_window_1d
    return _call_based_on_args('sliding_window_1d', args, kwargs)

@suppress_traceback
def sliding_window_nd(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.sliding_window_nd, utils3d.torch.sliding_window_nd
    return _call_based_on_args('sliding_window_nd', args, kwargs)

@suppress_traceback
def sliding_window_2d(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.sliding_window_2d, utils3d.torch.sliding_window_2d
    return _call_based_on_args('sliding_window_2d', args, kwargs)

@suppress_traceback
def max_pool_1d(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.max_pool_1d, None
    return _call_based_on_args('max_pool_1d', args, kwargs)

@suppress_traceback
def max_pool_2d(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.max_pool_2d, None
    return _call_based_on_args('max_pool_2d', args, kwargs)

@suppress_traceback
def max_pool_nd(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.max_pool_nd, None
    return _call_based_on_args('max_pool_nd', args, kwargs)

@suppress_traceback
def depth_edge(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.depth_edge, utils3d.torch.depth_edge
    return _call_based_on_args('depth_edge', args, kwargs)

@suppress_traceback
def normals_edge(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.normals_edge, None
    return _call_based_on_args('normals_edge', args, kwargs)

@suppress_traceback
def depth_aliasing(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.depth_aliasing, utils3d.torch.depth_aliasing
    return _call_based_on_args('depth_aliasing', args, kwargs)

@suppress_traceback
def interpolate(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.interpolate, None
    return _call_based_on_args('interpolate', args, kwargs)

@suppress_traceback
def image_scrcoord(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.image_scrcoord, None
    return _call_based_on_args('image_scrcoord', args, kwargs)

@suppress_traceback
def image_uv(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.image_uv, utils3d.torch.image_uv
    return _call_based_on_args('image_uv', args, kwargs)

@suppress_traceback
def image_pixel_center(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.image_pixel_center, utils3d.torch.image_pixel_center
    return _call_based_on_args('image_pixel_center', args, kwargs)

@suppress_traceback
def image_pixel(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.image_pixel, None
    return _call_based_on_args('image_pixel', args, kwargs)

@suppress_traceback
def image_mesh(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.image_mesh, utils3d.torch.image_mesh
    return _call_based_on_args('image_mesh', args, kwargs)

@suppress_traceback
def image_mesh_from_depth(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.image_mesh_from_depth, utils3d.torch.image_mesh_from_depth
    return _call_based_on_args('image_mesh_from_depth', args, kwargs)

@suppress_traceback
def depth_to_normals(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.depth_to_normals, None
    return _call_based_on_args('depth_to_normals', args, kwargs)

@suppress_traceback
def points_to_normals(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.points_to_normals, None
    return _call_based_on_args('points_to_normals', args, kwargs)

@suppress_traceback
def chessboard(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.chessboard, utils3d.torch.chessboard
    return _call_based_on_args('chessboard', args, kwargs)

@suppress_traceback
def cube(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.cube, None
    return _call_based_on_args('cube', args, kwargs)

@suppress_traceback
def icosahedron(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.icosahedron, None
    return _call_based_on_args('icosahedron', args, kwargs)

@suppress_traceback
def square(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.square, None
    return _call_based_on_args('square', args, kwargs)

@suppress_traceback
def camera_frustum(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.camera_frustum, None
    return _call_based_on_args('camera_frustum', args, kwargs)

@suppress_traceback
def perspective(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.perspective, utils3d.torch.perspective
    return _call_based_on_args('perspective', args, kwargs)

@suppress_traceback
def perspective_from_fov(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.perspective_from_fov, utils3d.torch.perspective_from_fov
    return _call_based_on_args('perspective_from_fov', args, kwargs)

@suppress_traceback
def perspective_from_fov_xy(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.perspective_from_fov_xy, utils3d.torch.perspective_from_fov_xy
    return _call_based_on_args('perspective_from_fov_xy', args, kwargs)

@suppress_traceback
def intrinsics_from_focal_center(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.intrinsics_from_focal_center, utils3d.torch.intrinsics_from_focal_center
    return _call_based_on_args('intrinsics_from_focal_center', args, kwargs)

@suppress_traceback
def intrinsics_from_fov(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.intrinsics_from_fov, utils3d.torch.intrinsics_from_fov
    return _call_based_on_args('intrinsics_from_fov', args, kwargs)

@suppress_traceback
def fov_to_focal(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.fov_to_focal, None
    return _call_based_on_args('fov_to_focal', args, kwargs)

@suppress_traceback
def focal_to_fov(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.focal_to_fov, None
    return _call_based_on_args('focal_to_fov', args, kwargs)

@suppress_traceback
def intrinsics_to_fov(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.intrinsics_to_fov, None
    return _call_based_on_args('intrinsics_to_fov', args, kwargs)

@suppress_traceback
def view_look_at(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.view_look_at, utils3d.torch.view_look_at
    return _call_based_on_args('view_look_at', args, kwargs)

@suppress_traceback
def extrinsics_look_at(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.extrinsics_look_at, utils3d.torch.extrinsics_look_at
    return _call_based_on_args('extrinsics_look_at', args, kwargs)

@suppress_traceback
def perspective_to_intrinsics(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.perspective_to_intrinsics, utils3d.torch.perspective_to_intrinsics
    return _call_based_on_args('perspective_to_intrinsics', args, kwargs)

@suppress_traceback
def perspective_to_near_far(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.perspective_to_near_far, None
    return _call_based_on_args('perspective_to_near_far', args, kwargs)

@suppress_traceback
def intrinsics_to_perspective(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.intrinsics_to_perspective, utils3d.torch.intrinsics_to_perspective
    return _call_based_on_args('intrinsics_to_perspective', args, kwargs)

@suppress_traceback
def extrinsics_to_view(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.extrinsics_to_view, utils3d.torch.extrinsics_to_view
    return _call_based_on_args('extrinsics_to_view', args, kwargs)

@suppress_traceback
def view_to_extrinsics(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.view_to_extrinsics, utils3d.torch.view_to_extrinsics
    return _call_based_on_args('view_to_extrinsics', args, kwargs)

@suppress_traceback
def normalize_intrinsics(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.normalize_intrinsics, utils3d.torch.normalize_intrinsics
    return _call_based_on_args('normalize_intrinsics', args, kwargs)

@suppress_traceback
def crop_intrinsics(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.crop_intrinsics, utils3d.torch.crop_intrinsics
    return _call_based_on_args('crop_intrinsics', args, kwargs)

@suppress_traceback
def pixel_to_uv(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.pixel_to_uv, utils3d.torch.pixel_to_uv
    return _call_based_on_args('pixel_to_uv', args, kwargs)

@suppress_traceback
def pixel_to_ndc(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.pixel_to_ndc, utils3d.torch.pixel_to_ndc
    return _call_based_on_args('pixel_to_ndc', args, kwargs)

@suppress_traceback
def uv_to_pixel(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.uv_to_pixel, utils3d.torch.uv_to_pixel
    return _call_based_on_args('uv_to_pixel', args, kwargs)

@suppress_traceback
def project_depth(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.project_depth, utils3d.torch.project_depth
    return _call_based_on_args('project_depth', args, kwargs)

@suppress_traceback
def depth_buffer_to_linear(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.depth_buffer_to_linear, utils3d.torch.depth_buffer_to_linear
    return _call_based_on_args('depth_buffer_to_linear', args, kwargs)

@suppress_traceback
def unproject_cv(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.unproject_cv, utils3d.torch.unproject_cv
    return _call_based_on_args('unproject_cv', args, kwargs)

@suppress_traceback
def unproject_gl(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.unproject_gl, utils3d.torch.unproject_gl
    return _call_based_on_args('unproject_gl', args, kwargs)

@suppress_traceback
def project_cv(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.project_cv, utils3d.torch.project_cv
    return _call_based_on_args('project_cv', args, kwargs)

@suppress_traceback
def project_gl(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.project_gl, utils3d.torch.project_gl
    return _call_based_on_args('project_gl', args, kwargs)

@suppress_traceback
def quaternion_to_matrix(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.quaternion_to_matrix, utils3d.torch.quaternion_to_matrix
    return _call_based_on_args('quaternion_to_matrix', args, kwargs)

@suppress_traceback
def axis_angle_to_matrix(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.axis_angle_to_matrix, utils3d.torch.axis_angle_to_matrix
    return _call_based_on_args('axis_angle_to_matrix', args, kwargs)

@suppress_traceback
def matrix_to_quaternion(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.matrix_to_quaternion, utils3d.torch.matrix_to_quaternion
    return _call_based_on_args('matrix_to_quaternion', args, kwargs)

@suppress_traceback
def extrinsics_to_essential(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.extrinsics_to_essential, utils3d.torch.extrinsics_to_essential
    return _call_based_on_args('extrinsics_to_essential', args, kwargs)

@suppress_traceback
def euler_axis_angle_rotation(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.euler_axis_angle_rotation, utils3d.torch.euler_axis_angle_rotation
    return _call_based_on_args('euler_axis_angle_rotation', args, kwargs)

@suppress_traceback
def euler_angles_to_matrix(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.euler_angles_to_matrix, utils3d.torch.euler_angles_to_matrix
    return _call_based_on_args('euler_angles_to_matrix', args, kwargs)

@suppress_traceback
def skew_symmetric(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.skew_symmetric, utils3d.torch.skew_symmetric
    return _call_based_on_args('skew_symmetric', args, kwargs)

@suppress_traceback
def rotation_matrix_from_vectors(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.rotation_matrix_from_vectors, utils3d.torch.rotation_matrix_from_vectors
    return _call_based_on_args('rotation_matrix_from_vectors', args, kwargs)

@suppress_traceback
def ray_intersection(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.ray_intersection, None
    return _call_based_on_args('ray_intersection', args, kwargs)

@suppress_traceback
def se3_matrix(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.se3_matrix, None
    return _call_based_on_args('se3_matrix', args, kwargs)

@suppress_traceback
def slerp_quaternion(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.slerp_quaternion, None
    return _call_based_on_args('slerp_quaternion', args, kwargs)

@suppress_traceback
def slerp_vector(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.slerp_vector, None
    return _call_based_on_args('slerp_vector', args, kwargs)

@suppress_traceback
def lerp(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.lerp, None
    return _call_based_on_args('lerp', args, kwargs)

@suppress_traceback
def lerp_se3_matrix(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.lerp_se3_matrix, None
    return _call_based_on_args('lerp_se3_matrix', args, kwargs)

@suppress_traceback
def piecewise_lerp(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.piecewise_lerp, None
    return _call_based_on_args('piecewise_lerp', args, kwargs)

@suppress_traceback
def piecewise_lerp_se3_matrix(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.piecewise_lerp_se3_matrix, None
    return _call_based_on_args('piecewise_lerp_se3_matrix', args, kwargs)

@suppress_traceback
def apply_transform(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.apply_transform, None
    return _call_based_on_args('apply_transform', args, kwargs)

@suppress_traceback
def linear_spline_interpolate(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.linear_spline_interpolate, None
    return _call_based_on_args('linear_spline_interpolate', args, kwargs)

@suppress_traceback
def RastContext(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.RastContext, utils3d.torch.RastContext
    return _call_based_on_args('RastContext', args, kwargs)

@suppress_traceback
def rasterize_triangle_faces(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.rasterize_triangle_faces, utils3d.torch.rasterize_triangle_faces
    return _call_based_on_args('rasterize_triangle_faces', args, kwargs)

@suppress_traceback
def rasterize_edges(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.rasterize_edges, None
    return _call_based_on_args('rasterize_edges', args, kwargs)

@suppress_traceback
def texture(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.texture, None
    return _call_based_on_args('texture', args, kwargs)

@suppress_traceback
def warp_image_by_depth(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.warp_image_by_depth, utils3d.torch.warp_image_by_depth
    return _call_based_on_args('warp_image_by_depth', args, kwargs)

@suppress_traceback
def test_rasterization(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        utils3d.numpy.test_rasterization, None
    return _call_based_on_args('test_rasterization', args, kwargs)

@suppress_traceback
def compute_face_angles(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.compute_face_angles
    return _call_based_on_args('compute_face_angles', args, kwargs)

@suppress_traceback
def compute_face_tbn(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.compute_face_tbn
    return _call_based_on_args('compute_face_tbn', args, kwargs)

@suppress_traceback
def compute_vertex_tbn(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.compute_vertex_tbn
    return _call_based_on_args('compute_vertex_tbn', args, kwargs)

@suppress_traceback
def laplacian(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.laplacian
    return _call_based_on_args('laplacian', args, kwargs)

@suppress_traceback
def laplacian_smooth_mesh(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.laplacian_smooth_mesh
    return _call_based_on_args('laplacian_smooth_mesh', args, kwargs)

@suppress_traceback
def taubin_smooth_mesh(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.taubin_smooth_mesh
    return _call_based_on_args('taubin_smooth_mesh', args, kwargs)

@suppress_traceback
def laplacian_hc_smooth_mesh(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.laplacian_hc_smooth_mesh
    return _call_based_on_args('laplacian_hc_smooth_mesh', args, kwargs)

@suppress_traceback
def get_rays(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.get_rays
    return _call_based_on_args('get_rays', args, kwargs)

@suppress_traceback
def get_image_rays(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.get_image_rays
    return _call_based_on_args('get_image_rays', args, kwargs)

@suppress_traceback
def get_mipnerf_cones(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.get_mipnerf_cones
    return _call_based_on_args('get_mipnerf_cones', args, kwargs)

@suppress_traceback
def volume_rendering(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.volume_rendering
    return _call_based_on_args('volume_rendering', args, kwargs)

@suppress_traceback
def bin_sample(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.bin_sample
    return _call_based_on_args('bin_sample', args, kwargs)

@suppress_traceback
def importance_sample(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.importance_sample
    return _call_based_on_args('importance_sample', args, kwargs)

@suppress_traceback
def nerf_render_rays(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.nerf_render_rays
    return _call_based_on_args('nerf_render_rays', args, kwargs)

@suppress_traceback
def mipnerf_render_rays(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.mipnerf_render_rays
    return _call_based_on_args('mipnerf_render_rays', args, kwargs)

@suppress_traceback
def nerf_render_view(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.nerf_render_view
    return _call_based_on_args('nerf_render_view', args, kwargs)

@suppress_traceback
def mipnerf_render_view(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.mipnerf_render_view
    return _call_based_on_args('mipnerf_render_view', args, kwargs)

@suppress_traceback
def InstantNGP(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.InstantNGP
    return _call_based_on_args('InstantNGP', args, kwargs)

@suppress_traceback
def point_to_normal(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.point_to_normal
    return _call_based_on_args('point_to_normal', args, kwargs)

@suppress_traceback
def depth_to_normal(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.depth_to_normal
    return _call_based_on_args('depth_to_normal', args, kwargs)

@suppress_traceback
def masked_min(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.masked_min
    return _call_based_on_args('masked_min', args, kwargs)

@suppress_traceback
def masked_max(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.masked_max
    return _call_based_on_args('masked_max', args, kwargs)

@suppress_traceback
def bounding_rect(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.bounding_rect
    return _call_based_on_args('bounding_rect', args, kwargs)

@suppress_traceback
def intrinsics_from_fov_xy(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.intrinsics_from_fov_xy
    return _call_based_on_args('intrinsics_from_fov_xy', args, kwargs)

@suppress_traceback
def matrix_to_euler_angles(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.matrix_to_euler_angles
    return _call_based_on_args('matrix_to_euler_angles', args, kwargs)

@suppress_traceback
def matrix_to_axis_angle(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.matrix_to_axis_angle
    return _call_based_on_args('matrix_to_axis_angle', args, kwargs)

@suppress_traceback
def axis_angle_to_quaternion(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.axis_angle_to_quaternion
    return _call_based_on_args('axis_angle_to_quaternion', args, kwargs)

@suppress_traceback
def quaternion_to_axis_angle(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.quaternion_to_axis_angle
    return _call_based_on_args('quaternion_to_axis_angle', args, kwargs)

@suppress_traceback
def slerp(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.slerp
    return _call_based_on_args('slerp', args, kwargs)

@suppress_traceback
def interpolate_extrinsics(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.interpolate_extrinsics
    return _call_based_on_args('interpolate_extrinsics', args, kwargs)

@suppress_traceback
def interpolate_view(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.interpolate_view
    return _call_based_on_args('interpolate_view', args, kwargs)

@suppress_traceback
def to4x4(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.to4x4
    return _call_based_on_args('to4x4', args, kwargs)

@suppress_traceback
def rotation_matrix_2d(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.rotation_matrix_2d
    return _call_based_on_args('rotation_matrix_2d', args, kwargs)

@suppress_traceback
def rotate_2d(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.rotate_2d
    return _call_based_on_args('rotate_2d', args, kwargs)

@suppress_traceback
def translate_2d(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.translate_2d
    return _call_based_on_args('translate_2d', args, kwargs)

@suppress_traceback
def scale_2d(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.scale_2d
    return _call_based_on_args('scale_2d', args, kwargs)

@suppress_traceback
def apply_2d(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.apply_2d
    return _call_based_on_args('apply_2d', args, kwargs)

@suppress_traceback
def warp_image_by_forward_flow(*args, **kwargs):
    if TYPE_CHECKING:  # redirected to:
        None, utils3d.torch.warp_image_by_forward_flow
    return _call_based_on_args('warp_image_by_forward_flow', args, kwargs)
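Every public wrapper above funnels into `_call_based_on_args`: if torch has been imported and any positional or keyword argument contains a `torch.Tensor` (searched recursively through lists, tuples, and dicts), the call goes to `utils3d.torch.<name>`; otherwise it falls back to `utils3d.numpy.<name>`, raising `NotImplementedError` when the chosen backend lacks the function. A dispatch sketch (assumes both backends are installed and that `triangulate` accepts an (N, 4) face-index array, which this diff does not show):

import numpy as np
import torch
import utils3d

quads = np.array([[0, 1, 2, 3]])
utils3d.triangulate(quads)                    # ndarray args -> utils3d.numpy.triangulate
utils3d.triangulate(torch.from_numpy(quads))  # tensor args  -> utils3d.torch.triangulate
utils3d.laplacian(quads)                      # raises NotImplementedError: no numpy implementation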
src/utils3d/_unified/__init__.pyi
ADDED
The diff for this file is too large to render.
See raw diff
src/utils3d/io/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .obj import *
from .colmap import *
from .ply import *
src/utils3d/io/colmap.py
ADDED
@@ -0,0 +1,139 @@
from typing import *
from pathlib import Path

import numpy as np
from scipy.spatial.transform import Rotation


__all__ = ['read_extrinsics_from_colmap', 'read_intrinsics_from_colmap', 'write_extrinsics_as_colmap', 'write_intrinsics_as_colmap']


def write_extrinsics_as_colmap(file: Union[str, Path], extrinsics: np.ndarray, image_names: Union[str, List[str]] = 'image_{i:04d}.png', camera_ids: List[int] = None):
    """
    Write extrinsics to colmap `images.txt` file.
    Args:
        file: Path to `images.txt` file.
        extrinsics: (N, 4, 4) array of extrinsics.
        image_names: str or List of str, image names. Length is N.
            If str, it should be a format string with `i` as the index. (i starts from 1, in correspondence with IMAGE_ID in colmap)
        camera_ids: List of int, camera ids. Length is N.
            If None, it will be set to [1, 2, ..., N].
    """
    assert extrinsics.shape[1:] == (4, 4) and extrinsics.ndim == 3 or extrinsics.shape == (4, 4)
    if extrinsics.ndim == 2:
        extrinsics = extrinsics[np.newaxis, ...]
    quats = Rotation.from_matrix(extrinsics[:, :3, :3]).as_quat()
    trans = extrinsics[:, :3, 3]
    if camera_ids is None:
        camera_ids = list(range(1, len(extrinsics) + 1))
    if isinstance(image_names, str):
        image_names = [image_names.format(i=i) for i in range(1, len(extrinsics) + 1)]
    assert len(extrinsics) == len(image_names) == len(camera_ids), \
        f'Number of extrinsics ({len(extrinsics)}), image_names ({len(image_names)}), and camera_ids ({len(camera_ids)}) must be the same'
    with open(file, 'w') as fp:
        print("# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME", file=fp)
        for i, (quat, t, name, camera_id) in enumerate(zip(quats.tolist(), trans.tolist(), image_names, camera_ids)):
            # Colmap has wxyz order while scipy.spatial.transform.Rotation has xyzw order.
            qx, qy, qz, qw = quat
            tx, ty, tz = t
            print(f'{i + 1} {qw:f} {qx:f} {qy:f} {qz:f} {tx:f} {ty:f} {tz:f} {camera_id:d} {name}', file=fp)
            # COLMAP's images.txt has a second line per image (2D point observations); left empty here.
            print(file=fp)


def write_intrinsics_as_colmap(file: Union[str, Path], intrinsics: np.ndarray, width: int, height: int, normalized: bool = False):
    """
    Write intrinsics to colmap `cameras.txt` file. Currently only supports the PINHOLE model (no distortion).
    Args:
        file: Path to `cameras.txt` file.
        intrinsics: (N, 3, 3) array of intrinsics.
        width: Image width.
        height: Image height.
        normalized: Whether the intrinsics are normalized. If True, the intrinsics will be unnormalized for writing.
    """
    assert intrinsics.shape[1:] == (3, 3) and intrinsics.ndim == 3 or intrinsics.shape == (3, 3)
    if intrinsics.ndim == 2:
        intrinsics = intrinsics[np.newaxis, ...]
    if normalized:
        intrinsics = intrinsics * np.array([width, height, 1])[:, None]
    with open(file, 'w') as fp:
        print("# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]", file=fp)
        for i, intr in enumerate(intrinsics):
            fx, fy, cx, cy = intr[0, 0], intr[1, 1], intr[0, 2], intr[1, 2]
            print(f'{i + 1} PINHOLE {width:d} {height:d} {fx:f} {fy:f} {cx:f} {cy:f}', file=fp)


def read_extrinsics_from_colmap(file: Union[str, Path]) -> Tuple[np.ndarray, List[int], List[str]]:
    """
    Read extrinsics from colmap `images.txt` file.
    Args:
        file: Path to `images.txt` file.
    Returns:
        extrinsics: (N, 4, 4) array of extrinsics.
        camera_ids: List of int, camera ids. Length is N. Note that camera ids in colmap typically start from 1.
        image_names: List of str, image names. Length is N.
    """
    with open(file) as fp:
        lines = fp.readlines()
    image_names, quats, trans, camera_ids = [], [], [], []
    i_line = 0
    for line in lines:
        line = line.strip()
        if line.startswith('#'):
            continue
        i_line += 1
        if i_line % 2 == 0:
            continue
        image_id, qw, qx, qy, qz, tx, ty, tz, camera_id, name = line.split()
        quats.append([float(qx), float(qy), float(qz), float(qw)])
        trans.append([float(tx), float(ty), float(tz)])
        camera_ids.append(int(camera_id))
        image_names.append(name)

    quats = np.array(quats, dtype=np.float32)
    trans = np.array(trans, dtype=np.float32)
    rotation = Rotation.from_quat(quats).as_matrix()
    extrinsics = np.concatenate([
        np.concatenate([rotation, trans[..., None]], axis=-1),
        np.array([0, 0, 0, 1], dtype=np.float32)[None, None, :].repeat(len(quats), axis=0)
    ], axis=-2)

    return extrinsics, camera_ids, image_names


def read_intrinsics_from_colmap(file: Union[str, Path], normalize: bool = False) -> Tuple[List[int], np.ndarray, np.ndarray]:
    """
    Read intrinsics from colmap `cameras.txt` file.
    Args:
        file: Path to `cameras.txt` file.
        normalize: Whether to normalize the intrinsics. If True, the intrinsics will be normalized (mapping coordinates to the [0, 1] range).
    Returns:
        camera_ids: List of int, camera ids. Length is N. Note that camera ids in colmap typically start from 1.
        intrinsics: (N, 3, 3) array of intrinsics.
        distortions: (N, 5) array of distortions.
    """
    with open(file) as fp:
        lines = fp.readlines()
    intrinsics, distortions, camera_ids = [], [], []
    for line in lines:
        line = line.strip()
        if not line or line.startswith('#'):
            continue
        camera_id, model, width, height, *params = line.split()
        camera_id, width, height = int(camera_id), int(width), int(height)
        if model == 'PINHOLE':
            fx, fy, cx, cy = map(float, params[:4])
            k1 = k2 = k3 = p1 = p2 = 0.0
        elif model == 'OPENCV':
            fx, fy, cx, cy, k1, k2, p1, p2, k3 = *map(float, params[:8]), 0.0
        elif model == 'SIMPLE_RADIAL':
            f, cx, cy, k = map(float, params[:4])
            fx = fy = f
            k1, k2, p1, p2, k3 = k, 0.0, 0.0, 0.0, 0.0
        else:
            raise ValueError(f'Unsupported camera model: {model}')
        camera_ids.append(camera_id)
        if normalize:
            fx, fy, cx, cy = fx / width, fy / height, cx / width, cy / height
        intrinsics.append([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
        distortions.append([k1, k2, p1, p2, k3])
    intrinsics = np.array(intrinsics, dtype=np.float32)
    distortions = np.array(distortions, dtype=np.float32)
    return camera_ids, intrinsics, distortions
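Note the quaternion convention handled above: scipy's `Rotation` uses xyzw order while COLMAP stores wxyz, and the writer/reader reorder accordingly, so the pair round-trips. A round-trip sketch (assumes `scipy` is available; extrinsics are world-to-camera matrices, as in COLMAP, and the blank second line per image written above is what lets the reader skip every other data line):

import numpy as np

extr = np.eye(4, dtype=np.float32)[None]            # a single identity pose
write_extrinsics_as_colmap('images.txt', extr)      # image name defaults to image_0001.png
extr2, camera_ids, names = read_extrinsics_from_colmap('images.txt')
assert np.allclose(extr, extr2, atol=1e-6) and camera_ids == [1]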
src/utils3d/io/obj.py
ADDED
@@ -0,0 +1,146 @@
from io import TextIOWrapper
from typing import Dict, Any, Union, Iterable
import numpy as np
from pathlib import Path

__all__ = [
    'read_obj',
    'write_obj',
    'simple_write_obj'
]

def read_obj(
    file: Union[str, Path, TextIOWrapper],
    encoding: Union[str, None] = None,
    ignore_unknown: bool = False
):
    """
    Read a wavefront .obj file, without preprocessing.

    Why bother having this read_obj() when libraries like `trimesh` already exist?
    This function reads the raw .obj format and keeps the order of vertices and faces,
    whereas trimesh applies modifications such as merging/splitting vertices, which can break those orders.
    Such libraries mainly target geometry processing and rendering with support for many formats;
    if you want mesh geometry processing, `trimesh` offers more features.

    ### Parameters
        `file` (str, Path, TextIOWrapper): filepath or file object
        encoding (str, optional):

    ### Returns
        obj (dict): A dict containing .obj components
        {
            'mtllib': [],
            'v': [[0.1, 0.2, 1.0], [1.2, 0.0, 0.0], ...],
            'vt': [[0.5, 0.5], ...],
            'vn': [[0., 0.7, 0.7], [0., -0.7, 0.7], ...],
            'f': [[0, 1, 2], [2, 3, 4], ...],
            'usemtl': [('mtl1', 7)]
        }
    """
    if hasattr(file, 'read'):
        lines = file.read().splitlines()
    else:
        with open(file, 'r', encoding=encoding) as fp:
            lines = fp.read().splitlines()
    mtllib = []
    v, vt, vn, vp = [], [], [], []   # Vertex coordinates, vertex texture coordinates, vertex normals, vertex parameters
    f, ft, fn = [], [], []           # Face indices, face texture indices, face normal indices
    o = []
    s = []
    usemtl = []

    def pad(l: list, n: Any):
        return l + [n] * (3 - len(l))

    for i, line in enumerate(lines):
        sq = line.strip().split()
        if len(sq) == 0:
            continue
        if sq[0] == 'v':
            assert 4 <= len(sq) <= 5, f'Invalid format of line {i}: {line}'
            v.append([float(e) for e in sq[1:]][:3])
        elif sq[0] == 'vt':
            assert 3 <= len(sq) <= 4, f'Invalid format of line {i}: {line}'
            vt.append([float(e) for e in sq[1:]][:2])
        elif sq[0] == 'vn':
            assert len(sq) == 4, f'Invalid format of line {i}: {line}'
            vn.append([float(e) for e in sq[1:]])
        elif sq[0] == 'vp':
            assert 2 <= len(sq) <= 4, f'Invalid format of line {i}: {line}'
            vp.append(pad([float(e) for e in sq[1:]], 0))
        elif sq[0] == 'f':
            spliting = [pad([int(j) - 1 for j in e.split('/')], -1) for e in sq[1:]]
            f.append([e[0] for e in spliting])
            ft.append([e[1] for e in spliting])
            fn.append([e[2] for e in spliting])
        elif sq[0] == 'usemtl':
            assert len(sq) == 2
            usemtl.append((sq[1], len(f)))
        elif sq[0] == 'o':
            assert len(sq) == 2
            o.append((sq[1], len(f)))
        elif sq[0] == 's':
            s.append((sq[1], len(f)))
        elif sq[0] == 'mtllib':
            assert len(sq) == 2
            mtllib.append(sq[1])
        elif sq[0][0] == '#':
            continue
        else:
            if not ignore_unknown:
                raise Exception(f'Unknown keyword {sq[0]}')

    min_poly_vertices = min(len(f) for f in f)
    max_poly_vertices = max(len(f) for f in f)

    return {
        'mtllib': mtllib,
        'v': np.array(v, dtype=np.float32),
        'vt': np.array(vt, dtype=np.float32),
        'vn': np.array(vn, dtype=np.float32),
        'vp': np.array(vp, dtype=np.float32),
        'f': np.array(f, dtype=np.int32) if min_poly_vertices == max_poly_vertices else f,
        'ft': np.array(ft, dtype=np.int32) if min_poly_vertices == max_poly_vertices else ft,
        'fn': np.array(fn, dtype=np.int32) if min_poly_vertices == max_poly_vertices else fn,
        'o': o,
        's': s,
        'usemtl': usemtl,
    }


def write_obj(
    file: Union[str, Path],
    obj: Dict[str, Any],
    encoding: Union[str, None] = None
):
    with open(file, 'w', encoding=encoding) as fp:
        for k in ['v', 'vt', 'vn', 'vp']:
            if k not in obj:
                continue
            for v in obj[k]:
                print(k, *map(float, v), file=fp)
        for f in obj['f']:
            # Each face entry is either a flat vertex index or a (v, vt, vn) index tuple joined with '/'.
            print('f', *('/'.join(map(str, i)) if isinstance(i, Iterable) else i for i in f), file=fp)


def simple_write_obj(
    file: Union[str, Path],
    vertices: np.ndarray,
    faces: np.ndarray,
    encoding: Union[str, None] = None
):
    """
    Write wavefront .obj file, without preprocessing.

    Args:
        vertices (np.ndarray): [N, 3]
        faces (np.ndarray): [T, 3]
        file (Any): filepath
        encoding (str, optional):
    """
    with open(file, 'w', encoding=encoding) as fp:
        for v in vertices:
            print('v', *map(float, v), file=fp)
        for f in faces:
            print('f', *map(int, f + 1), file=fp)
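A usage sketch tying the writer back to the reader (hypothetical file name; note that `read_obj` converts indices to 0-based and returns `f` as an int32 array only when all polygons have the same vertex count):

import numpy as np

verts = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]], dtype=np.float32)
faces = np.array([[0, 1, 2], [0, 2, 3]])
simple_write_obj('square.obj', verts, faces)              # writes 1-based 'f' lines
obj = read_obj('square.obj')
assert obj['f'].shape == (2, 3) and obj['f'].min() == 0   # uniform triangles -> ndarray, 0-based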
src/utils3d/io/ply.py
ADDED
@@ -0,0 +1,104 @@
+import numpy as np
+
+from typing import *
+from pathlib import Path
+
+
+def read_ply(
+    file: Union[str, Path],
+    encoding: Union[str, None] = None,
+    ignore_unknown: bool = False
+) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Read .ply file, without preprocessing.
+
+    Args:
+        file (Any): filepath
+        encoding (str, optional):
+
+    Returns:
+        Tuple[np.ndarray, np.ndarray]: vertices, faces
+    """
+    import plyfile
+    plydata = plyfile.PlyData.read(file)
+    vertices = np.stack([plydata['vertex'][k] for k in ['x', 'y', 'z']], axis=-1)
+    if 'face' in plydata:
+        faces = np.array(plydata['face']['vertex_indices'].tolist())
+    else:
+        faces = None
+    return vertices, faces
+
+
+def write_ply(
+    file: Union[str, Path],
+    vertices: np.ndarray,
+    faces: np.ndarray = None,
+    edges: np.ndarray = None,
+    vertex_colors: np.ndarray = None,
+    edge_colors: np.ndarray = None,
+    text: bool = False
+):
+    """
+    Write .ply file, without preprocessing.
+
+    Args:
+        file (Any): filepath
+        vertices (np.ndarray): [N, 3]
+        faces (np.ndarray): [T, P] polygonal faces
+        edges (np.ndarray): [E, 2]
+        vertex_colors (np.ndarray, optional): [N, 3]. Defaults to None.
+        edge_colors (np.ndarray, optional): [E, 3]. Defaults to None.
+        text (bool, optional): save data in text format. Defaults to False.
+    """
+    import plyfile
+    assert vertices.ndim == 2 and vertices.shape[1] == 3
+    vertices = vertices.astype(np.float32)
+    if faces is not None:
+        assert faces.ndim == 2
+        faces = faces.astype(np.int32)
+    if edges is not None:
+        assert edges.ndim == 2 and edges.shape[1] == 2
+        edges = edges.astype(np.int32)
+
+    if vertex_colors is not None:
+        assert vertex_colors.ndim == 2 and vertex_colors.shape[1] == 3
+        if vertex_colors.dtype in [np.float32, np.float64]:
+            vertex_colors = vertex_colors * 255
+        vertex_colors = np.clip(vertex_colors, 0, 255).astype(np.uint8)
+        vertices_data = np.zeros(len(vertices), dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])
+        vertices_data['x'] = vertices[:, 0]
+        vertices_data['y'] = vertices[:, 1]
+        vertices_data['z'] = vertices[:, 2]
+        vertices_data['red'] = vertex_colors[:, 0]
+        vertices_data['green'] = vertex_colors[:, 1]
+        vertices_data['blue'] = vertex_colors[:, 2]
+    else:
+        vertices_data = np.array([tuple(v) for v in vertices], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
+
+    if faces is not None:
+        faces_data = np.zeros(len(faces), dtype=[('vertex_indices', 'i4', (faces.shape[1],))])
+        faces_data['vertex_indices'] = faces
+
+    if edges is not None:
+        if edge_colors is not None:
+            assert edge_colors.ndim == 2 and edge_colors.shape[1] == 3
+            if edge_colors.dtype in [np.float32, np.float64]:
+                edge_colors = edge_colors * 255
+            edge_colors = np.clip(edge_colors, 0, 255).astype(np.uint8)
+            edges_data = np.zeros(len(edges), dtype=[('vertex1', 'i4'), ('vertex2', 'i4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])
+            edges_data['vertex1'] = edges[:, 0]
+            edges_data['vertex2'] = edges[:, 1]
+            edges_data['red'] = edge_colors[:, 0]
+            edges_data['green'] = edge_colors[:, 1]
+            edges_data['blue'] = edge_colors[:, 2]
+        else:
+            edges_data = np.array([tuple(e) for e in edges], dtype=[('vertex1', 'i4'), ('vertex2', 'i4')])
+
+    ply_data = [plyfile.PlyElement.describe(vertices_data, 'vertex')]
+    if faces is not None:
+        ply_data.append(plyfile.PlyElement.describe(faces_data, 'face'))
+    if edges is not None:
+        ply_data.append(plyfile.PlyElement.describe(edges_data, 'edge'))
+
+    plyfile.PlyData(ply_data, text=text).write(file)
+
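
A small round-trip sketch of the reader/writer pair above (file name illustrative; both functions import the plyfile package lazily, so it must be installed):

import numpy as np
from utils3d.io.ply import read_ply, write_ply  # assumes src/ is on sys.path

vertices = np.random.rand(4, 3).astype(np.float32)
faces = np.array([[0, 1, 2], [0, 2, 3]], dtype=np.int32)
colors = np.ones((4, 3), dtype=np.float64)      # float colors are scaled to 0..255
write_ply('quad.ply', vertices, faces=faces, vertex_colors=colors, text=True)
v, f = read_ply('quad.ply')                     # colors are not read back
assert v.shape == (4, 3) and f.shape == (2, 3)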
src/utils3d/numpy/__init__.py
ADDED
@@ -0,0 +1,142 @@
+"""
+3D utility functions working with NumPy.
+"""
+import importlib
+import itertools
+import numpy
+from typing import TYPE_CHECKING
+
+
+__modules_all__ = {
+    'mesh': [
+        'triangulate',
+        'compute_face_normal',
+        'compute_face_angle',
+        'compute_vertex_normal',
+        'compute_vertex_normal_weighted',
+        'remove_corrupted_faces',
+        'merge_duplicate_vertices',
+        'remove_unreferenced_vertices',
+        'subdivide_mesh_simple',
+        'mesh_relations',
+        'flatten_mesh_indices'
+    ],
+    'quadmesh': [
+        'calc_quad_candidates',
+        'calc_quad_distortion',
+        'calc_quad_direction',
+        'calc_quad_smoothness',
+        'sovle_quad',
+        'sovle_quad_qp',
+        'tri_to_quad'
+    ],
+    'utils': [
+        'sliding_window_1d',
+        'sliding_window_nd',
+        'sliding_window_2d',
+        'max_pool_1d',
+        'max_pool_2d',
+        'max_pool_nd',
+        'depth_edge',
+        'normals_edge',
+        'depth_aliasing',
+        'interpolate',
+        'image_scrcoord',
+        'image_uv',
+        'image_pixel_center',
+        'image_pixel',
+        'image_mesh',
+        'image_mesh_from_depth',
+        'depth_to_normals',
+        'points_to_normals',
+        'chessboard',
+        'cube',
+        'icosahedron',
+        'square',
+        'camera_frustum',
+    ],
+    'transforms': [
+        'perspective',
+        'perspective_from_fov',
+        'perspective_from_fov_xy',
+        'intrinsics_from_focal_center',
+        'intrinsics_from_fov',
+        'fov_to_focal',
+        'focal_to_fov',
+        'intrinsics_to_fov',
+        'view_look_at',
+        'extrinsics_look_at',
+        'perspective_to_intrinsics',
+        'perspective_to_near_far',
+        'intrinsics_to_perspective',
+        'extrinsics_to_view',
+        'view_to_extrinsics',
+        'normalize_intrinsics',
+        'crop_intrinsics',
+        'pixel_to_uv',
+        'pixel_to_ndc',
+        'uv_to_pixel',
+        'project_depth',
+        'depth_buffer_to_linear',
+        'unproject_cv',
+        'unproject_gl',
+        'project_cv',
+        'project_gl',
+        'quaternion_to_matrix',
+        'axis_angle_to_matrix',
+        'matrix_to_quaternion',
+        'extrinsics_to_essential',
+        'euler_axis_angle_rotation',
+        'euler_angles_to_matrix',
+        'skew_symmetric',
+        'rotation_matrix_from_vectors',
+        'ray_intersection',
+        'se3_matrix',
+        'slerp_quaternion',
+        'slerp_vector',
+        'lerp',
+        'lerp_se3_matrix',
+        'piecewise_lerp',
+        'piecewise_lerp_se3_matrix',
+        'apply_transform'
+    ],
+    'spline': [
+        'linear_spline_interpolate',
+    ],
+    'rasterization': [
+        'RastContext',
+        'rasterize_triangle_faces',
+        'rasterize_edges',
+        'texture',
+        'warp_image_by_depth',
+        'test_rasterization'
+    ],
+}
+
+
+__all__ = list(itertools.chain(*__modules_all__.values()))
+
+def __getattr__(name):
+    try:
+        return globals()[name]
+    except KeyError:
+        pass
+
+    try:
+        module_name = next(m for m in __modules_all__ if name in __modules_all__[m])
+    except StopIteration:
+        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+    module = importlib.import_module(f'.{module_name}', __name__)
+    for key in __modules_all__[module_name]:
+        globals()[key] = getattr(module, key)
+
+    return globals()[name]
+
+
+if TYPE_CHECKING:
+    from .quadmesh import *
+    from .transforms import *
+    from .mesh import *
+    from .utils import *
+    from .rasterization import *
+    from .spline import *
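
The module-level __getattr__ above is the PEP 562 lazy-import pattern: no submodule of utils3d.numpy is imported until one of the names listed in __modules_all__ is first accessed, at which point the owning submodule is loaded and all of its exports are cached in globals(). A sketch of the observable behavior (assumes src/ is on sys.path):

import utils3d.numpy as u3d

fn = u3d.triangulate           # first access imports utils3d.numpy.mesh
fn2 = u3d.compute_face_normal  # already cached in globals(); no second import
try:
    u3d.no_such_name           # unknown names hit the StopIteration branch
except AttributeError:
    pass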
src/utils3d/numpy/_helpers.py
ADDED
@@ -0,0 +1,93 @@
+# decorator
+import numpy as np
+from numbers import Number
+import inspect
+from functools import wraps
+from typing import *
+from .._helpers import suppress_traceback
+
+
+def get_args_order(func, args, kwargs):
+    """
+    Get the order of the arguments of a function.
+    """
+    names = inspect.getfullargspec(func).args
+    names_idx = {name: i for i, name in enumerate(names)}
+    args_order = []
+    kwargs_order = {}
+    for name, arg in kwargs.items():
+        if name in names:
+            kwargs_order[name] = names_idx[name]
+            names.remove(name)
+    for i, arg in enumerate(args):
+        if i < len(names):
+            args_order.append(names_idx[names[i]])
+    return args_order, kwargs_order
+
+
+def broadcast_args(args, kwargs, args_dim, kwargs_dim):
+    spatial = []
+    for arg, arg_dim in zip(args + list(kwargs.values()), args_dim + list(kwargs_dim.values())):
+        if isinstance(arg, np.ndarray) and arg_dim is not None:
+            arg_spatial = arg.shape[:arg.ndim-arg_dim]
+            if len(arg_spatial) > len(spatial):
+                spatial = [1] * (len(arg_spatial) - len(spatial)) + spatial
+            # align shapes from the right, as in NumPy broadcasting
+            for j in range(1, len(arg_spatial) + 1):
+                if spatial[-j] < arg_spatial[-j]:
+                    if spatial[-j] == 1:
+                        spatial[-j] = arg_spatial[-j]
+                    else:
+                        raise ValueError("Cannot broadcast arguments.")
+    for i, arg in enumerate(args):
+        if isinstance(arg, np.ndarray) and args_dim[i] is not None:
+            args[i] = np.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-args_dim[i]:]])
+    for key, arg in kwargs.items():
+        if isinstance(arg, np.ndarray) and kwargs_dim[key] is not None:
+            kwargs[key] = np.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-kwargs_dim[key]:]])
+    return args, kwargs, spatial
+
+
+def batched(*dims):
+    """
+    Decorator that allows a function to be called with batched arguments.
+    """
+    def decorator(func):
+        @wraps(func)
+        @suppress_traceback
+        def wrapper(*args, **kwargs):
+            args = list(args)
+            # get arguments dimensions
+            args_order, kwargs_order = get_args_order(func, args, kwargs)
+            args_dim = [dims[i] for i in args_order]
+            kwargs_dim = {key: dims[i] for key, i in kwargs_order.items()}
+            # convert to numpy array
+            for i, arg in enumerate(args):
+                if isinstance(arg, (Number, list, tuple)) and args_dim[i] is not None:
+                    args[i] = np.array(arg)
+            for key, arg in kwargs.items():
+                if isinstance(arg, (Number, list, tuple)) and kwargs_dim[key] is not None:
+                    kwargs[key] = np.array(arg)
+            # broadcast arguments
+            args, kwargs, spatial = broadcast_args(args, kwargs, args_dim, kwargs_dim)
+            for i, (arg, arg_dim) in enumerate(zip(args, args_dim)):
+                if isinstance(arg, np.ndarray) and arg_dim is not None:
+                    args[i] = arg.reshape([-1, *arg.shape[arg.ndim-arg_dim:]])
+            for key, arg in kwargs.items():
+                if isinstance(arg, np.ndarray) and kwargs_dim[key] is not None:
+                    kwargs[key] = arg.reshape([-1, *arg.shape[arg.ndim-kwargs_dim[key]:]])
+            # call function
+            results = func(*args, **kwargs)
+            type_results = type(results)
+            results = list(results) if isinstance(results, (tuple, list)) else [results]
+            # restore spatial dimensions
+            for i, result in enumerate(results):
+                results[i] = result.reshape([*spatial, *result.shape[1:]])
+            if type_results == tuple:
+                results = tuple(results)
+            elif type_results == list:
+                results = list(results)
+            else:
+                results = results[0]
+            return results
+        return wrapper
+    return decorator
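
A sketch of what the batched decorator buys (the function below is illustrative, not part of the repo). Each positional value in @batched(...) declares how many trailing core dimensions the corresponding argument has, with None meaning the argument is passed through untouched; the leading spatial dimensions are broadcast, flattened into a single batch axis for the call, and restored on the outputs.

import numpy as np

@batched(2, None)
def face_centroids(vertices, faces):
    # inside, vertices always arrives flattened to [B, N, 3];
    # faces (dim None) is the shared [T, 3] index array
    return vertices[:, faces].mean(axis=2)  # [B, T, 3]

vertices = np.random.rand(5, 7, 100, 3)          # spatial batch (5, 7), core [100, 3]
faces = np.random.randint(0, 100, (20, 3))
assert face_centroids(vertices, faces).shape == (5, 7, 20, 3)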
src/utils3d/numpy/mesh.py
ADDED
@@ -0,0 +1,355 @@
+import numpy as np
+from typing import *
+from ._helpers import batched
+
+
+__all__ = [
+    'triangulate',
+    'compute_face_normal',
+    'compute_face_angle',
+    'compute_vertex_normal',
+    'compute_vertex_normal_weighted',
+    'remove_corrupted_faces',
+    'merge_duplicate_vertices',
+    'remove_unreferenced_vertices',
+    'subdivide_mesh_simple',
+    'mesh_relations',
+    'flatten_mesh_indices'
+]
+
+
+def triangulate(
+    faces: np.ndarray,
+    vertices: np.ndarray = None,
+    backslash: np.ndarray = None
+) -> np.ndarray:
+    """
+    Triangulate a polygonal mesh.
+
+    Args:
+        faces (np.ndarray): [L, P] polygonal faces
+        vertices (np.ndarray, optional): [N, 3] 3-dimensional vertices.
+            If given, the triangulation is performed according to the distance
+            between vertices. Defaults to None.
+        backslash (np.ndarray, optional): [L] boolean array indicating
+            how to triangulate the quad faces. Defaults to None.
+
+    Returns:
+        (np.ndarray): [L * (P - 2), 3] triangular faces
+    """
+    if faces.shape[-1] == 3:
+        return faces
+    P = faces.shape[-1]
+    if vertices is not None:
+        assert faces.shape[-1] == 4, "now only support quad mesh"
+        if backslash is None:
+            backslash = np.linalg.norm(vertices[faces[:, 0]] - vertices[faces[:, 2]], axis=-1) < \
+                np.linalg.norm(vertices[faces[:, 1]] - vertices[faces[:, 3]], axis=-1)
+    if backslash is None:
+        loop_indice = np.stack([
+            np.zeros(P - 2, dtype=int),
+            np.arange(1, P - 1, 1, dtype=int),
+            np.arange(2, P, 1, dtype=int)
+        ], axis=1)
+        return faces[:, loop_indice].reshape((-1, 3))
+    else:
+        assert faces.shape[-1] == 4, "now only support quad mesh"
+        faces = np.where(
+            backslash[:, None],
+            faces[:, [0, 1, 2, 0, 2, 3]],
+            faces[:, [0, 1, 3, 3, 1, 2]]
+        ).reshape((-1, 3))
+        return faces
+
+
+@batched(2, None)
+def compute_face_normal(
+    vertices: np.ndarray,
+    faces: np.ndarray
+) -> np.ndarray:
+    """
+    Compute face normals of a triangular mesh
+
+    Args:
+        vertices (np.ndarray): [..., N, 3] 3-dimensional vertices
+        faces (np.ndarray): [T, 3] triangular face indices
+
+    Returns:
+        normals (np.ndarray): [..., T, 3] face normals
+    """
+    normal = np.cross(
+        vertices[..., faces[:, 1], :] - vertices[..., faces[:, 0], :],
+        vertices[..., faces[:, 2], :] - vertices[..., faces[:, 0], :]
+    )
+    normal_norm = np.linalg.norm(normal, axis=-1, keepdims=True)
+    normal_norm[normal_norm == 0] = 1
+    normal /= normal_norm
+    return normal
+
+
+@batched(2, None)
+def compute_face_angle(
+    vertices: np.ndarray,
+    faces: np.ndarray,
+    eps: float = 1e-12
+) -> np.ndarray:
+    """
+    Compute face angles of a triangular mesh
+
+    Args:
+        vertices (np.ndarray): [..., N, 3] 3-dimensional vertices
+        faces (np.ndarray): [T, 3] triangular face indices
+
+    Returns:
+        angles (np.ndarray): [..., T, 3] face angles
+    """
+    face_angle = np.zeros_like(faces, dtype=vertices.dtype)
+    for i in range(3):
+        edge1 = vertices[..., faces[:, (i + 1) % 3], :] - vertices[..., faces[:, i], :]
+        edge2 = vertices[..., faces[:, (i + 2) % 3], :] - vertices[..., faces[:, i], :]
+        face_angle[..., i] = np.arccos(np.sum(
+            edge1 / np.clip(np.linalg.norm(edge1, axis=-1, keepdims=True), eps, None) *
+            edge2 / np.clip(np.linalg.norm(edge2, axis=-1, keepdims=True), eps, None),
+            axis=-1
+        ))
+    return face_angle
+
+
+@batched(2, None, 2)
+def compute_vertex_normal(
+    vertices: np.ndarray,
+    faces: np.ndarray,
+    face_normal: np.ndarray = None
+) -> np.ndarray:
+    """
+    Compute vertex normals of a triangular mesh by averaging neighboring face normals
+    TODO: can be improved.
+
+    Args:
+        vertices (np.ndarray): [..., N, 3] 3-dimensional vertices
+        faces (np.ndarray): [T, 3] triangular face indices
+        face_normal (np.ndarray, optional): [..., T, 3] face normals.
+            None to compute face normals from vertices and faces. Defaults to None.
+
+    Returns:
+        normals (np.ndarray): [..., N, 3] vertex normals
+    """
+    if face_normal is None:
+        face_normal = compute_face_normal(vertices, faces)
+    vertex_normal = np.zeros_like(vertices, dtype=vertices.dtype)
+    for n in range(vertices.shape[0]):
+        for i in range(3):
+            vertex_normal[n, :, 0] += np.bincount(faces[:, i], weights=face_normal[n, :, 0], minlength=vertices.shape[1])
+            vertex_normal[n, :, 1] += np.bincount(faces[:, i], weights=face_normal[n, :, 1], minlength=vertices.shape[1])
+            vertex_normal[n, :, 2] += np.bincount(faces[:, i], weights=face_normal[n, :, 2], minlength=vertices.shape[1])
+    vertex_normal_norm = np.linalg.norm(vertex_normal, axis=-1, keepdims=True)
+    vertex_normal_norm[vertex_normal_norm == 0] = 1
+    vertex_normal /= vertex_normal_norm
+    return vertex_normal
+
+
+@batched(2, None, 2)
+def compute_vertex_normal_weighted(
+    vertices: np.ndarray,
+    faces: np.ndarray,
+    face_normal: np.ndarray = None
+) -> np.ndarray:
+    """
+    Compute vertex normals of a triangular mesh by weighted sum of neighboring face normals
+    according to the angles
+
+    Args:
+        vertices (np.ndarray): [..., N, 3] 3-dimensional vertices
+        faces (np.ndarray): [..., T, 3] triangular face indices
+        face_normal (np.ndarray, optional): [..., T, 3] face normals.
+            None to compute face normals from vertices and faces. Defaults to None.
+
+    Returns:
+        normals (np.ndarray): [..., N, 3] vertex normals
+    """
+    if face_normal is None:
+        face_normal = compute_face_normal(vertices, faces)
+    face_angle = compute_face_angle(vertices, faces)
+    vertex_normal = np.zeros_like(vertices)
+    for n in range(vertices.shape[0]):
+        for i in range(3):
+            vertex_normal[n, :, 0] += np.bincount(faces[n, :, i], weights=face_normal[n, :, 0] * face_angle[n, :, i], minlength=vertices.shape[1])
+            vertex_normal[n, :, 1] += np.bincount(faces[n, :, i], weights=face_normal[n, :, 1] * face_angle[n, :, i], minlength=vertices.shape[1])
+            vertex_normal[n, :, 2] += np.bincount(faces[n, :, i], weights=face_normal[n, :, 2] * face_angle[n, :, i], minlength=vertices.shape[1])
+    vertex_normal_norm = np.linalg.norm(vertex_normal, axis=-1, keepdims=True)
+    vertex_normal_norm[vertex_normal_norm == 0] = 1
+    vertex_normal /= vertex_normal_norm
+    return vertex_normal
+
+
+def remove_corrupted_faces(
+    faces: np.ndarray
+) -> np.ndarray:
+    """
+    Remove corrupted faces (faces with duplicated vertices)
+
+    Args:
+        faces (np.ndarray): [T, 3] triangular face indices
+
+    Returns:
+        np.ndarray: [T_, 3] triangular face indices
+    """
+    corrupted = (faces[:, 0] == faces[:, 1]) | (faces[:, 1] == faces[:, 2]) | (faces[:, 2] == faces[:, 0])
+    return faces[~corrupted]
+
+
+def merge_duplicate_vertices(
+    vertices: np.ndarray,
+    faces: np.ndarray,
+    tol: float = 1e-6
+) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Merge duplicate vertices of a triangular mesh.
+    Duplicate vertices are merged by selecting one of them, and the face indices are updated accordingly.
+
+    Args:
+        vertices (np.ndarray): [N, 3] 3-dimensional vertices
+        faces (np.ndarray): [T, 3] triangular face indices
+        tol (float, optional): tolerance for merging. Defaults to 1e-6.
+
+    Returns:
+        vertices (np.ndarray): [N_, 3] 3-dimensional vertices
+        faces (np.ndarray): [T, 3] triangular face indices
+    """
+    vertices_round = np.round(vertices / tol)
+    _, uni_i, uni_inv = np.unique(vertices_round, return_index=True, return_inverse=True, axis=0)
+    vertices = vertices[uni_i]
+    faces = uni_inv[faces]
+    return vertices, faces
+
+
+def remove_unreferenced_vertices(
+    faces: np.ndarray,
+    *vertice_attrs,
+    return_indices: bool = False
+) -> Tuple[np.ndarray, ...]:
+    """
+    Remove unreferenced vertices of a mesh.
+    Unreferenced vertices are removed, and the face indices are updated accordingly.
+
+    Args:
+        faces (np.ndarray): [T, P] face indices
+        *vertice_attrs: vertex attributes
+
+    Returns:
+        faces (np.ndarray): [T, P] face indices
+        *vertice_attrs: vertex attributes
+        indices (np.ndarray, optional): [N] indices of vertices that are kept. Defaults to None.
+    """
+    P = faces.shape[-1]
+    fewer_indices, inv_map = np.unique(faces, return_inverse=True)
+    faces = inv_map.astype(np.int32).reshape(-1, P)
+    ret = [faces]
+    for attr in vertice_attrs:
+        ret.append(attr[fewer_indices])
+    if return_indices:
+        ret.append(fewer_indices)
+    return tuple(ret)
+
+
+def subdivide_mesh_simple(
+    vertices: np.ndarray,
+    faces: np.ndarray,
+    n: int = 1
+) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Subdivide a triangular mesh by splitting each triangle into 4 smaller triangles.
+    NOTE: All original vertices are kept, and new vertices are appended to the end of the vertex list.
+
+    Args:
+        vertices (np.ndarray): [N, 3] 3-dimensional vertices
+        faces (np.ndarray): [T, 3] triangular face indices
+        n (int, optional): number of subdivisions. Defaults to 1.
+
+    Returns:
+        vertices (np.ndarray): [N_, 3] subdivided 3-dimensional vertices
+        faces (np.ndarray): [4 * T, 3] subdivided triangular face indices
+    """
+    for _ in range(n):
+        edges = np.stack([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], axis=0)
+        edges = np.sort(edges, axis=2)
+        uni_edges, uni_inv = np.unique(edges.reshape(-1, 2), return_inverse=True, axis=0)
+        uni_inv = uni_inv.reshape(3, -1)
+        midpoints = (vertices[uni_edges[:, 0]] + vertices[uni_edges[:, 1]]) / 2
+
+        n_vertices = vertices.shape[0]
+        vertices = np.concatenate([vertices, midpoints], axis=0)
+        faces = np.concatenate([
+            np.stack([faces[:, 0], n_vertices + uni_inv[0], n_vertices + uni_inv[2]], axis=1),
+            np.stack([faces[:, 1], n_vertices + uni_inv[1], n_vertices + uni_inv[0]], axis=1),
+            np.stack([faces[:, 2], n_vertices + uni_inv[2], n_vertices + uni_inv[1]], axis=1),
+            np.stack([n_vertices + uni_inv[0], n_vertices + uni_inv[1], n_vertices + uni_inv[2]], axis=1),
+        ], axis=0)
+    return vertices, faces
+
+
+def mesh_relations(
+    faces: np.ndarray,
+) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Calculate the relation between vertices and faces.
+    NOTE: The input mesh must be a manifold triangle mesh.
+
+    Args:
+        faces (np.ndarray): [T, 3] triangular face indices
+
+    Returns:
+        edges (np.ndarray): [E, 2] edge indices
+        edge2face (np.ndarray): [E, 2] edge to face relation. The second column is -1 if the edge is boundary.
+        face2edge (np.ndarray): [T, 3] face to edge relation
+        face2face (np.ndarray): [T, 3] face to face relation
+    """
+    T = faces.shape[0]
+    edges = np.stack([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], axis=1).reshape(-1, 2)  # [3T, 2]
+    edges = np.sort(edges, axis=1)  # [3T, 2]
+    edges, face2edge, occurence = np.unique(edges, axis=0, return_inverse=True, return_counts=True)  # [E, 2], [3T], [E]
+    E = edges.shape[0]
+    assert np.all(occurence <= 2), "The input mesh is not a manifold mesh."
+
+    # Edge to face relation
+    padding = np.arange(E, dtype=np.int32)[occurence == 1]
+    padded_face2edge = np.concatenate([face2edge, padding], axis=0)  # [2E]
+    edge2face = np.argsort(padded_face2edge, kind='stable').reshape(-1, 2) // 3  # [E, 2]
+    edge2face_valid = edge2face[:, 1] < T  # [E]
+    edge2face[~edge2face_valid, 1] = -1
+
+    # Face to edge relation
+    face2edge = face2edge.reshape(-1, 3)  # [T, 3]
+
+    # Face to face relation
+    face2face = edge2face[face2edge]  # [T, 3, 2]
+    face2face = face2face[face2face != np.arange(T)[:, None, None]].reshape(T, 3)  # [T, 3]
+
+    return edges, edge2face, face2edge, face2face
+
+
+@overload
+def flatten_mesh_indices(faces1: np.ndarray, attr1: np.ndarray, *other_faces_attrs_pairs: np.ndarray) -> Tuple[np.ndarray, ...]:
+    """
+    Rearrange the indices of a mesh to a flattened version. Vertices will be no longer shared.
+
+    ### Parameters:
+    - `faces1`: [T, P] face indices of the first attribute
+    - `attr1`: [N1, ...] attributes of the first mesh
+    - ...
+
+    ### Returns:
+    - `faces`: [T, P] flattened face indices, contiguous from 0 to T * P - 1
+    - `attr1`: [T * P, ...] attributes of the first mesh, where every P values correspond to a face
+    - ...
+    """
+def flatten_mesh_indices(*args: np.ndarray) -> Tuple[np.ndarray, ...]:
+    assert len(args) % 2 == 0, "The number of arguments must be even."
+    T, P = args[0].shape
+    assert all(arg.shape[0] == T and arg.shape[1] == P for arg in args[::2]), "The faces must have the same shape."
+    attr_flat = []
+    for faces_, attr_ in zip(args[::2], args[1::2]):
+        attr_flat_ = attr_[faces_].reshape(-1, *attr_.shape[1:])
+        attr_flat.append(attr_flat_)
+    faces_flat = np.arange(T * P, dtype=np.int32).reshape(T, P)
+    return faces_flat, *attr_flat
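
A short end-to-end sketch over the functions above (geometry illustrative; assumes src/ is on sys.path): split a planar quad along its shorter diagonal, then average face normals into vertex normals.

import numpy as np
from utils3d.numpy.mesh import triangulate, compute_vertex_normal

# one unit quad in the z = 0 plane
vertices = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]], dtype=np.float64)
quads = np.array([[0, 1, 2, 3]])

tris = triangulate(quads, vertices=vertices)     # [2, 3] triangles
normals = compute_vertex_normal(vertices, tris)  # [4, 3], all (0, 0, 1) here
assert tris.shape == (2, 3) and np.allclose(normals, [0, 0, 1])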
src/utils3d/numpy/quadmesh.py
ADDED
@@ -0,0 +1,472 @@
+import numpy as np
+import scipy as sp
+import scipy.optimize as spopt
+from typing import *
+
+
+__all__ = [
+    'calc_quad_candidates',
+    'calc_quad_distortion',
+    'calc_quad_direction',
+    'calc_quad_smoothness',
+    'sovle_quad',
+    'sovle_quad_qp',
+    'tri_to_quad'
+]
+
+
+def calc_quad_candidates(
+    edges: np.ndarray,
+    face2edge: np.ndarray,
+    edge2face: np.ndarray,
+):
+    """
+    Calculate the candidate quad faces.
+
+    Args:
+        edges (np.ndarray): [E, 2] edge indices
+        face2edge (np.ndarray): [T, 3] face to edge relation
+        edge2face (np.ndarray): [E, 2] edge to face relation
+
+    Returns:
+        quads (np.ndarray): [Q, 4] quad candidate indices
+        quad2edge (np.ndarray): [Q, 4] edge to quad candidate relation
+        quad2adj (np.ndarray): [Q, 8] adjacent quad candidates of each quad candidate
+        quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid
+    """
+    E = edges.shape[0]
+    T = face2edge.shape[0]
+
+    quads_valid = edge2face[:, 1] != -1
+    Q = quads_valid.sum()
+    quad2face = edge2face[quads_valid]  # [Q, 2]
+    quad2edge = face2edge[quad2face]  # [Q, 2, 3]
+    flag = quad2edge == np.arange(E)[quads_valid][:, None, None]  # [Q, 2, 3]
+    flag = flag.argmax(axis=-1)  # [Q, 2]
+    quad2edge = np.stack([
+        quad2edge[np.arange(Q)[:, None], np.arange(2)[None, :], (flag + 1) % 3],
+        quad2edge[np.arange(Q)[:, None], np.arange(2)[None, :], (flag + 2) % 3],
+    ], axis=-1).reshape(Q, 4)  # [Q, 4]
+
+    quads = np.concatenate([
+        np.where(
+            (edges[quad2edge[:, 0:1], 1:] == edges[quad2edge[:, 1:2], :]).any(axis=-1),
+            edges[quad2edge[:, 0:1], [[0, 1]]],
+            edges[quad2edge[:, 0:1], [[1, 0]]],
+        ),
+        np.where(
+            (edges[quad2edge[:, 2:3], 1:] == edges[quad2edge[:, 3:4], :]).any(axis=-1),
+            edges[quad2edge[:, 2:3], [[0, 1]]],
+            edges[quad2edge[:, 2:3], [[1, 0]]],
+        ),
+    ], axis=1)  # [Q, 4]
+
+    quad2adj = edge2face[quad2edge]  # [Q, 4, 2]
+    quad2adj = quad2adj[quad2adj != quad2face[:, [0, 0, 1, 1], None]].reshape(Q, 4)  # [Q, 4]
+    quad2adj_valid = quad2adj != -1
+    quad2adj = face2edge[quad2adj]  # [Q, 4, 3]
+    quad2adj[~quad2adj_valid, 0] = quad2edge[~quad2adj_valid]
+    quad2adj[~quad2adj_valid, 1:] = -1
+    quad2adj = quad2adj[quad2adj != quad2edge[..., None]].reshape(Q, 8)  # [Q, 8]
+    edge_valid = -np.ones(E, dtype=np.int32)
+    edge_valid[quads_valid] = np.arange(Q)
+    quad2adj_valid = quad2adj != -1
+    quad2adj[quad2adj_valid] = edge_valid[quad2adj[quad2adj_valid]]  # [Q, 8]
+
+    return quads, quad2edge, quad2adj, quads_valid
+
+
+def calc_quad_distortion(
+    vertices: np.ndarray,
+    quads: np.ndarray,
+):
+    """
+    Calculate the distortion of each candidate quad face.
+
+    Args:
+        vertices (np.ndarray): [N, 3] 3-dimensional vertices
+        quads (np.ndarray): [Q, 4] quad face indices
+
+    Returns:
+        distortion (np.ndarray): [Q] distortion of each quad face
+    """
+    edge0 = vertices[quads[:, 1]] - vertices[quads[:, 0]]  # [Q, 3]
+    edge1 = vertices[quads[:, 2]] - vertices[quads[:, 1]]  # [Q, 3]
+    edge2 = vertices[quads[:, 3]] - vertices[quads[:, 2]]  # [Q, 3]
+    edge3 = vertices[quads[:, 0]] - vertices[quads[:, 3]]  # [Q, 3]
+    cross = vertices[quads[:, 0]] - vertices[quads[:, 2]]  # [Q, 3]
+
+    len0 = np.maximum(np.linalg.norm(edge0, axis=-1), 1e-10)  # [Q]
+    len1 = np.maximum(np.linalg.norm(edge1, axis=-1), 1e-10)  # [Q]
+    len2 = np.maximum(np.linalg.norm(edge2, axis=-1), 1e-10)  # [Q]
+    len3 = np.maximum(np.linalg.norm(edge3, axis=-1), 1e-10)  # [Q]
+    len_cross = np.maximum(np.linalg.norm(cross, axis=-1), 1e-10)  # [Q]
+
+    angle0 = np.arccos(np.clip(np.sum(-edge0 * edge1, axis=-1) / (len0 * len1), -1, 1))  # [Q]
+    angle1 = np.arccos(np.clip(np.sum(-edge1 * cross, axis=-1) / (len1 * len_cross), -1, 1)) \
+        + np.arccos(np.clip(np.sum(cross * edge2, axis=-1) / (len_cross * len2), -1, 1))  # [Q]
+    angle2 = np.arccos(np.clip(np.sum(-edge2 * edge3, axis=-1) / (len2 * len3), -1, 1))  # [Q]
+    angle3 = np.arccos(np.clip(np.sum(-edge3 * -cross, axis=-1) / (len3 * len_cross), -1, 1)) \
+        + np.arccos(np.clip(np.sum(-cross * edge0, axis=-1) / (len_cross * len0), -1, 1))  # [Q]
+
+    normal0 = np.cross(edge0, edge1)  # [Q, 3]
+    normal1 = np.cross(edge2, edge3)  # [Q, 3]
+    normal0 = normal0 / np.maximum(np.linalg.norm(normal0, axis=-1, keepdims=True), 1e-10)  # [Q, 3]
+    normal1 = normal1 / np.maximum(np.linalg.norm(normal1, axis=-1, keepdims=True), 1e-10)  # [Q, 3]
+    angle_normal = np.arccos(np.clip(np.sum(normal0 * normal1, axis=-1), -1, 1))  # [Q]
+
+    D90 = np.pi / 2
+    D180 = np.pi
+    D360 = np.pi * 2
+    ang_eng = (np.abs(angle0 - D90)**2 + np.abs(angle1 - D90)**2 + np.abs(angle2 - D90)**2 + np.abs(angle3 - D90)**2) / 4  # [Q]
+    dist_eng = np.abs(angle0 - angle2)**2 / np.minimum(np.maximum(np.minimum(angle0, angle2), 1e-10), np.maximum(D180 - np.maximum(angle0, angle2), 1e-10)) \
+        + np.abs(angle1 - angle3)**2 / np.minimum(np.maximum(np.minimum(angle1, angle3), 1e-10), np.maximum(D180 - np.maximum(angle1, angle3), 1e-10))  # [Q]
+    plane_eng = np.where(angle_normal < D90/2, np.abs(angle_normal)**2, 1e10)  # [Q]
+    eng = ang_eng + 2 * dist_eng + 2 * plane_eng  # [Q]
+
+    return eng
+
+
+def calc_quad_direction(
+    vertices: np.ndarray,
+    quads: np.ndarray,
+):
+    """
+    Calculate the direction of each candidate quad face.
+
+    Args:
+        vertices (np.ndarray): [N, 3] 3-dimensional vertices
+        quads (np.ndarray): [Q, 4] quad face indices
+
+    Returns:
+        direction (np.ndarray): [Q, 4] direction of each quad face.
+            Represented by the angle between the crossing and each edge.
+    """
+    mid0 = (vertices[quads[:, 0]] + vertices[quads[:, 1]]) / 2  # [Q, 3]
+    mid1 = (vertices[quads[:, 1]] + vertices[quads[:, 2]]) / 2  # [Q, 3]
+    mid2 = (vertices[quads[:, 2]] + vertices[quads[:, 3]]) / 2  # [Q, 3]
+    mid3 = (vertices[quads[:, 3]] + vertices[quads[:, 0]]) / 2  # [Q, 3]
+
+    cross0 = mid2 - mid0  # [Q, 3]
+    cross1 = mid3 - mid1  # [Q, 3]
+    cross0 = cross0 / np.maximum(np.linalg.norm(cross0, axis=-1, keepdims=True), 1e-10)  # [Q, 3]
+    cross1 = cross1 / np.maximum(np.linalg.norm(cross1, axis=-1, keepdims=True), 1e-10)  # [Q, 3]
+
+    edge0 = vertices[quads[:, 1]] - vertices[quads[:, 0]]  # [Q, 3]
+    edge1 = vertices[quads[:, 2]] - vertices[quads[:, 1]]  # [Q, 3]
+    edge2 = vertices[quads[:, 3]] - vertices[quads[:, 2]]  # [Q, 3]
+    edge3 = vertices[quads[:, 0]] - vertices[quads[:, 3]]  # [Q, 3]
+    edge0 = edge0 / np.maximum(np.linalg.norm(edge0, axis=-1, keepdims=True), 1e-10)  # [Q, 3]
+    edge1 = edge1 / np.maximum(np.linalg.norm(edge1, axis=-1, keepdims=True), 1e-10)  # [Q, 3]
+    edge2 = edge2 / np.maximum(np.linalg.norm(edge2, axis=-1, keepdims=True), 1e-10)  # [Q, 3]
+    edge3 = edge3 / np.maximum(np.linalg.norm(edge3, axis=-1, keepdims=True), 1e-10)  # [Q, 3]
+
+    direction = np.stack([
+        np.arccos(np.clip(np.sum(cross0 * edge0, axis=-1), -1, 1)),
+        np.arccos(np.clip(np.sum(cross1 * edge1, axis=-1), -1, 1)),
+        np.arccos(np.clip(np.sum(-cross0 * edge2, axis=-1), -1, 1)),
+        np.arccos(np.clip(np.sum(-cross1 * edge3, axis=-1), -1, 1)),
+    ], axis=-1)  # [Q, 4]
+
+    return direction
+
+
+def calc_quad_smoothness(
+    quad2edge: np.ndarray,
+    quad2adj: np.ndarray,
+    quads_direction: np.ndarray,
+):
+    """
+    Calculate the smoothness of each candidate quad face connection.
+
+    Args:
+        quad2edge (np.ndarray): [Q, 4] edges of each quad face
+        quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face
+        quads_direction (np.ndarray): [Q, 4] direction of each quad face
+
+    Returns:
+        smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection
+    """
+    Q = quad2adj.shape[0]
+    quad2adj_valid = quad2adj != -1
+    connections = np.stack([
+        np.arange(Q)[:, None].repeat(8, axis=1),
+        quad2adj,
+    ], axis=-1)[quad2adj_valid]  # [C, 2]
+    shared_edge_idx_0 = np.array([[0, 0, 1, 1, 2, 2, 3, 3]]).repeat(Q, axis=0)[quad2adj_valid]  # [C]
+    shared_edge_idx_1 = np.argmax(quad2edge[quad2adj][quad2adj_valid] == quad2edge[connections[:, 0], shared_edge_idx_0][:, None], axis=-1)  # [C]
+    valid_smoothness = np.abs(quads_direction[connections[:, 0], shared_edge_idx_0] - quads_direction[connections[:, 1], shared_edge_idx_1])**2  # [C]
+    smoothness = np.zeros([Q, 8], dtype=np.float32)
+    smoothness[quad2adj_valid] = valid_smoothness
+    return smoothness
+
+
+def sovle_quad(
+    face2edge: np.ndarray,
+    edge2face: np.ndarray,
+    quad2adj: np.ndarray,
+    quads_distortion: np.ndarray,
+    quads_smoothness: np.ndarray,
+    quads_valid: np.ndarray,
+):
+    """
+    Solve the quad mesh from the candidate quad faces.
+
+    Args:
+        face2edge (np.ndarray): [T, 3] face to edge relation
+        edge2face (np.ndarray): [E, 2] edge to face relation
+        quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face
+        quads_distortion (np.ndarray): [Q] distortion of each quad face
+        quads_smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection
+        quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid
+
+    Returns:
+        weights (np.ndarray): [Q] weight of each valid quad face
+    """
+    T = face2edge.shape[0]
+    E = edge2face.shape[0]
+    Q = quads_distortion.shape[0]
+    edge_valid = -np.ones(E, dtype=np.int32)
+    edge_valid[quads_valid] = np.arange(Q)
+
+    quads_connection = np.stack([
+        np.arange(Q)[:, None].repeat(8, axis=1),
+        quad2adj,
+    ], axis=-1)[quad2adj != -1]  # [C, 2]
+    quads_connection = np.sort(quads_connection, axis=-1)  # [C, 2]
+    quads_connection, quads_connection_idx = np.unique(quads_connection, axis=0, return_index=True)  # [C, 2], [C]
+    quads_smoothness = quads_smoothness[quad2adj != -1]  # [C]
+    quads_smoothness = quads_smoothness[quads_connection_idx]  # [C]
+    C = quads_connection.shape[0]
+
+    # Construct the linear programming problem
+
+    # Variables:
+    #   quads_weight: [Q] weight of each quad face
+    #   tri_min_weight: [T] minimum weight of each triangle face
+    #   conn_min_weight: [C] minimum weight of each quad face connection
+    #   conn_max_weight: [C] maximum weight of each quad face connection
+    # Objective:
+    #   minimize the total distortion of the selected quads plus the
+    #   smoothness cost of their connections (coefficient vector c below)
+
+    c = np.concatenate([
+        quads_distortion - 3,
+        quads_smoothness*4 - 2,
+        quads_smoothness*4,
+    ], axis=0)  # [Q+2C]
+
+    A_ub_triplet = np.concatenate([
+        np.stack([np.arange(T), edge_valid[face2edge[:, 0]], np.ones(T)], axis=1),  # [T, 3]
+        np.stack([np.arange(T), edge_valid[face2edge[:, 1]], np.ones(T)], axis=1),  # [T, 3]
+        np.stack([np.arange(T), edge_valid[face2edge[:, 2]], np.ones(T)], axis=1),  # [T, 3]
+        np.stack([np.arange(T, T+C), np.arange(Q, Q+C), np.ones(C)], axis=1),  # [C, 3]
+        np.stack([np.arange(T, T+C), quads_connection[:, 0], -np.ones(C)], axis=1),  # [C, 3]
+        np.stack([np.arange(T, T+C), quads_connection[:, 1], -np.ones(C)], axis=1),  # [C, 3]
+        np.stack([np.arange(T+C, T+2*C), np.arange(Q+C, Q+2*C), -np.ones(C)], axis=1),  # [C, 3]
+        np.stack([np.arange(T+C, T+2*C), quads_connection[:, 0], np.ones(C)], axis=1),  # [C, 3]
+        np.stack([np.arange(T+C, T+2*C), quads_connection[:, 1], np.ones(C)], axis=1),  # [C, 3]
+    ], axis=0)  # [3T+6C, 3]
+    A_ub_triplet = A_ub_triplet[A_ub_triplet[:, 1] != -1]  # [3T', 3]
+    A_ub = sp.sparse.coo_matrix((A_ub_triplet[:, 2], (A_ub_triplet[:, 0], A_ub_triplet[:, 1])), shape=[T+2*C, Q+2*C])  # [T+2C, Q+2C]
+    b_ub = np.concatenate([np.ones(T), -np.ones(C), np.ones(C)], axis=0)  # [T+2C]
+    bound = np.stack([
+        np.concatenate([np.zeros(Q), -np.ones(C), np.zeros(C)], axis=0),
+        np.concatenate([np.ones(Q), np.ones(C), np.ones(C)], axis=0),
+    ], axis=1)  # [Q+2C, 2]
+    A_eq = None
+    b_eq = None
+
+    print('Solver statistics:')
+    print(f'  #T = {T}')
+    print(f'  #Q = {Q}')
+    print(f'  #C = {C}')
+
+    # Solve the linear programming problem
+    last_num_valid = 0
+    for i in range(100):
+        res_ = spopt.linprog(c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, bounds=bound)
+        if not res_.success:
+            print(f'  Iter {i} | Failed with {res_.message}')
+            break
+        res = res_
+        weights = res.x[:Q]
+        valid = (weights > 0.5)
+        num_valid = valid.sum()
+        print(f'  Iter {i} | #Q_valid = {num_valid}')
+        if num_valid == last_num_valid:
+            break
+        last_num_valid = num_valid
+        A_eq_triplet = np.stack([
+            np.arange(num_valid),
+            np.arange(Q)[valid],
+            np.ones(num_valid),
+        ], axis=1)  # [num_valid, 3]
+        A_eq = sp.sparse.coo_matrix((A_eq_triplet[:, 2], (A_eq_triplet[:, 0], A_eq_triplet[:, 1])), shape=[num_valid, Q+2*C])  # [num_valid, Q+2C]
+        b_eq = np.where(weights[valid] > 0.5, 1, 0)  # [num_valid]
+
+    # Return the result
+    quads_weight = res.x[:Q]
+    conn_min_weight = res.x[Q:Q+C]
+    conn_max_weight = res.x[Q+C:Q+2*C]
+    return quads_weight, conn_min_weight, conn_max_weight
+
+
+def sovle_quad_qp(
+    face2edge: np.ndarray,
+    edge2face: np.ndarray,
+    quad2adj: np.ndarray,
+    quads_distortion: np.ndarray,
+    quads_smoothness: np.ndarray,
+    quads_valid: np.ndarray,
+):
+    """
+    Solve the quad mesh from the candidate quad faces.
+
+    Args:
+        face2edge (np.ndarray): [T, 3] face to edge relation
+        edge2face (np.ndarray): [E, 2] edge to face relation
+        quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face
+        quads_distortion (np.ndarray): [Q] distortion of each quad face
+        quads_smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection
+        quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid
+
+    Returns:
+        weights (np.ndarray): [Q] weight of each valid quad face
+    """
+    T = face2edge.shape[0]
+    E = edge2face.shape[0]
+    Q = quads_distortion.shape[0]
+    edge_valid = -np.ones(E, dtype=np.int32)
+    edge_valid[quads_valid] = np.arange(Q)
+
+    # Construct the quadratic programming problem
+    C_smoothness_triplet = np.stack([
+        np.arange(Q)[:, None].repeat(8, axis=1)[quad2adj != -1],
+        quad2adj[quad2adj != -1],
+        5 * quads_smoothness[quad2adj != -1],
+    ], axis=-1)  # [C, 3]
+    # C_smoothness_triplet = np.concatenate([
+    #     C_smoothness_triplet,
+    #     np.stack([np.arange(Q), np.arange(Q), 20*np.ones(Q)], axis=1),
+    # ], axis=0)  # [C+Q, 3]
+    C_smoothness = sp.sparse.coo_matrix((C_smoothness_triplet[:, 2], (C_smoothness_triplet[:, 0], C_smoothness_triplet[:, 1])), shape=[Q, Q])  # [Q, Q]
+    C_smoothness = C_smoothness.tocsc()
+    C_dist = quads_distortion - 20  # [Q]
+
+    A_eq = sp.sparse.coo_matrix((np.zeros(Q), (np.zeros(Q), np.arange(Q))), shape=[1, Q])  # [1, Q]
+    A_eq = A_eq.tocsc()
+    b_eq = np.array([0])
+
+    A_ub_triplet = np.concatenate([
+        np.stack([np.arange(T), edge_valid[face2edge[:, 0]], np.ones(T)], axis=1),  # [T, 3]
+        np.stack([np.arange(T), edge_valid[face2edge[:, 1]], np.ones(T)], axis=1),  # [T, 3]
+        np.stack([np.arange(T), edge_valid[face2edge[:, 2]], np.ones(T)], axis=1),  # [T, 3]
+    ], axis=0)  # [3T, 3]
+    A_ub_triplet = A_ub_triplet[A_ub_triplet[:, 1] != -1]  # [3T', 3]
+    A_ub = sp.sparse.coo_matrix((A_ub_triplet[:, 2], (A_ub_triplet[:, 0], A_ub_triplet[:, 1])), shape=[T, Q])  # [T, Q]
+    A_ub = A_ub.tocsc()
+    b_ub = np.ones(T)
+
+    lb = np.zeros(Q)
+    ub = np.ones(Q)
+
+    import piqp
+    solver = piqp.SparseSolver()
+    solver.settings.verbose = True
+    solver.settings.compute_timings = True
+    solver.setup(C_smoothness, C_dist, A_eq, b_eq, A_ub, b_ub, lb, ub)
+
+    status = solver.solve()
+
+    # x = cp.Variable(Q)
+    # prob = cp.Problem(
+    #     cp.Minimize(cp.quad_form(x, C_smoothness) + C_dist.T @ x),
+    #     [
+    #         A_ub @ x <= b_ub,
+    #         x >= 0, x <= 1,
+    #     ]
+    # )
+
+    # # Solve the quadratic programming problem
+    # prob.solve(solver=cp.PIQP, verbose=True)
+
+    # Return the result
+    weights = solver.result.x
+    return weights
+
+
+def tri_to_quad(
+    vertices: np.ndarray,
+    faces: np.ndarray,
+) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Convert a triangle mesh to a quad mesh.
+    NOTE: The input mesh must be a manifold mesh.
+
+    Args:
+        vertices (np.ndarray): [N, 3] 3-dimensional vertices
+        faces (np.ndarray): [T, 3] triangular face indices
+
+    Returns:
+        vertices (np.ndarray): [N_, 3] 3-dimensional vertices
+        faces (np.ndarray): [Q, 4] quad face indices
+    """
+    raise NotImplementedError
+
+
+if __name__ == '__main__':
+    import os
+    import sys
+    sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')))
+    import utils3d
+    import numpy as np
+    import cv2
+    from vis import vis_edge_color
+
+    file = 'miku'
+
+    vertices, faces = utils3d.io.read_ply(f'test/assets/{file}.ply')
+    edges, edge2face, face2edge, face2face = utils3d.mesh_relations(faces)
+    quad_cands, quad2edge, quad2adj, quad_valid = calc_quad_candidates(edges, face2edge, edge2face)
+    distortion = calc_quad_distortion(vertices, quad_cands)
+    direction = calc_quad_direction(vertices, quad_cands)
+    smoothness = calc_quad_smoothness(quad2edge, quad2adj, direction)
+    boundary_edges = edges[edge2face[:, 1] == -1]
+    quads_weight, conn_min_weight, conn_max_weight = sovle_quad(face2edge, edge2face, quad2adj, distortion, smoothness, quad_valid)
+    quads = quad_cands[quads_weight > 0.5]
+    print('Mesh statistics')
+    print(f'  #V = {vertices.shape[0]}')
+    print(f'  #F = {faces.shape[0]}')
+    print(f'  #E = {edges.shape[0]}')
+    print(f'  #B = {boundary_edges.shape[0]}')
+    print(f'  #Q_cand = {quad_cands.shape[0]}')
+    print(f'  #Q = {quads.shape[0]}')
+
+    utils3d.io.write_ply(f'test/assets/{file}_boundary_edges.ply', vertices=vertices, edges=boundary_edges)
+    utils3d.io.write_ply(f'test/assets/{file}_quad_candidates.ply', vertices=vertices, faces=quads)
+
+    edge_colors = np.zeros([edges.shape[0], 3], dtype=np.uint8)
+    distortion = (distortion - distortion.min()) / (distortion.max() - distortion.min())
+    distortion = (distortion * 255).astype(np.uint8)
+    edge_colors[quad_valid] = cv2.cvtColor(cv2.applyColorMap(distortion, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3)
+    utils3d.io.write_ply(f'test/assets/{file}_quad_candidates_distortion.ply', **vis_edge_color(vertices, edges, edge_colors))
+
+    edge_colors = np.zeros([edges.shape[0], 3], dtype=np.uint8)
+    edge_colors[quad_valid] = cv2.cvtColor(cv2.applyColorMap((quads_weight * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3)
+    utils3d.io.write_ply(f'test/assets/{file}_quad_candidates_weights.ply', **vis_edge_color(vertices, edges, edge_colors))
+    utils3d.io.write_ply(f'test/assets/{file}_quad.ply', vertices=vertices, faces=quads)
+
+    quad_centers = vertices[quad_cands].mean(axis=1)
+    conns = np.stack([
+        np.arange(quad_cands.shape[0])[:, None].repeat(8, axis=1),
+        quad2adj,
+    ], axis=-1)[quad2adj != -1]  # [C, 2]
+    conns, conns_idx = np.unique(np.sort(conns, axis=-1), axis=0, return_index=True)  # [C, 2], [C]
+    smoothness = smoothness[quad2adj != -1][conns_idx]  # [C]
+    conns_color = cv2.cvtColor(cv2.applyColorMap((smoothness * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3)
+    utils3d.io.write_ply(f'test/assets/{file}_quad_conn_smoothness.ply', **vis_edge_color(quad_centers, conns, conns_color))
+    conns_color = cv2.cvtColor(cv2.applyColorMap((conn_min_weight * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3)
+    utils3d.io.write_ply(f'test/assets/{file}_quad_conn_min.ply', **vis_edge_color(quad_centers, conns, conns_color))
+    conns_color = cv2.cvtColor(cv2.applyColorMap((conn_max_weight * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3)
+    utils3d.io.write_ply(f'test/assets/{file}_quad_conn_max.ply', **vis_edge_color(quad_centers, conns, conns_color))
+
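
The LP inside sovle_quad follows the standard scipy.optimize.linprog form: negative objective coefficients make good quads attractive, the first T rows of A_ub cap the total quad weight touching each triangle at 1, and the remaining rows tie the connection variables to quad weights. A toy instance with the same triangle-capacity structure (all numbers illustrative):

import numpy as np
import scipy.optimize as spopt

# two candidate quads that share a triangle: at most one may be selected
c = np.array([-1.0, -2.0])            # negated scores; lower cost = better quad
A_ub = np.array([[1.0, 1.0]])         # shared-triangle capacity row
b_ub = np.array([1.0])
res = spopt.linprog(c, A_ub=A_ub, b_ub=b_ub, bounds=[(0, 1), (0, 1)])
assert res.success and np.allclose(res.x, [0.0, 1.0])  # the better quad wins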