ruili3 committed on
Commit 860c6b0 · 1 Parent(s): 0cd3872

init LaRI demo

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. .gitignore +195 -0
  2. app.py +282 -0
  3. demo.py +136 -0
  4. requirements.txt +19 -0
  5. src/lari/model/__init__.py +2 -0
  6. src/lari/model/blocks.py +209 -0
  7. src/lari/model/dinoseg_model.py +153 -0
  8. src/lari/model/dinov2/__init__.py +6 -0
  9. src/lari/model/dinov2/hub/__init__.py +4 -0
  10. src/lari/model/dinov2/hub/backbones.py +156 -0
  11. src/lari/model/dinov2/hub/utils.py +39 -0
  12. src/lari/model/dinov2/layers/__init__.py +11 -0
  13. src/lari/model/dinov2/layers/attention.py +89 -0
  14. src/lari/model/dinov2/layers/block.py +259 -0
  15. src/lari/model/dinov2/layers/dino_head.py +58 -0
  16. src/lari/model/dinov2/layers/drop_path.py +34 -0
  17. src/lari/model/dinov2/layers/layer_scale.py +27 -0
  18. src/lari/model/dinov2/layers/mlp.py +40 -0
  19. src/lari/model/dinov2/layers/patch_embed.py +88 -0
  20. src/lari/model/dinov2/layers/swiglu_ffn.py +72 -0
  21. src/lari/model/dinov2/models/__init__.py +43 -0
  22. src/lari/model/dinov2/models/vision_transformer.py +396 -0
  23. src/lari/model/dinov2/utils/__init__.py +4 -0
  24. src/lari/model/dinov2/utils/cluster.py +95 -0
  25. src/lari/model/dinov2/utils/config.py +72 -0
  26. src/lari/model/dinov2/utils/dtype.py +37 -0
  27. src/lari/model/dinov2/utils/param_groups.py +103 -0
  28. src/lari/model/dinov2/utils/utils.py +95 -0
  29. src/lari/model/dpt_seg_head.py +158 -0
  30. src/lari/model/heads.py +104 -0
  31. src/lari/model/lari_model.py +177 -0
  32. src/lari/model/utils.py +38 -0
  33. src/lari/utils/__init__.py +0 -0
  34. src/lari/utils/geometry_numpy.py +187 -0
  35. src/lari/utils/geometry_torch.py +221 -0
  36. src/utils/__init__.py +2 -0
  37. src/utils/vis.py +105 -0
  38. src/utils3d/README.md +3 -0
  39. src/utils3d/__init__.py +20 -0
  40. src/utils3d/_helpers.py +35 -0
  41. src/utils3d/_unified/__init__.py +934 -0
  42. src/utils3d/_unified/__init__.pyi +0 -0
  43. src/utils3d/io/__init__.py +3 -0
  44. src/utils3d/io/colmap.py +139 -0
  45. src/utils3d/io/obj.py +146 -0
  46. src/utils3d/io/ply.py +104 -0
  47. src/utils3d/numpy/__init__.py +142 -0
  48. src/utils3d/numpy/_helpers.py +93 -0
  49. src/utils3d/numpy/mesh.py +355 -0
  50. src/utils3d/numpy/quadmesh.py +472 -0
.gitignore ADDED
@@ -0,0 +1,195 @@
1
+ scripts/rendering/blender-*
2
+
3
+ # Byte-compiled / optimized / DLL files
4
+ __pycache__/
5
+ *.py[cod]
6
+ *$py.class
7
+ .vscode/
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py,cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ # For a library or package, you might want to ignore these files since the code is
90
+ # intended to run in multiple environments; otherwise, check them in:
91
+ # .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ #Pipfile.lock
99
+
100
+ # poetry
101
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
103
+ # commonly ignored for libraries.
104
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105
+ #poetry.lock
106
+
107
+ # pdm
108
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109
+ #pdm.lock
110
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111
+ # in version control.
112
+ # https://pdm.fming.dev/#use-with-ide
113
+ .pdm.toml
114
+
115
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116
+ __pypackages__/
117
+
118
+ # Celery stuff
119
+ celerybeat-schedule
120
+ celerybeat.pid
121
+
122
+ # SageMath parsed files
123
+ *.sage.py
124
+
125
+ # Environments
126
+ .env
127
+ .venv
128
+ env/
129
+ venv/
130
+ ENV/
131
+ env.bak/
132
+ venv.bak/
133
+
134
+ # Spyder project settings
135
+ .spyderproject
136
+ .spyproject
137
+
138
+ # Rope project settings
139
+ .ropeproject
140
+
141
+ # mkdocs documentation
142
+ /site
143
+
144
+ # mypy
145
+ .mypy_cache/
146
+ .dmypy.json
147
+ dmypy.json
148
+
149
+ # Pyre type checker
150
+ .pyre/
151
+
152
+ # pytype static type analyzer
153
+ .pytype/
154
+
155
+ # Cython debug symbols
156
+ cython_debug/
157
+
158
+ # PyCharm
159
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
162
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163
+ #.idea/
164
+ *.sif
165
+ blender-4.2.5-linux-x64*/
166
+ *.zip
167
+ *.png
168
+ *.jpg
169
+ *.log
170
+ intermediate/
171
+ __pycache__
172
+ *.ply
173
+ *.npy
174
+ *.npz
175
+ *.obj
176
+ *.mtl
177
+ *.json.gz
178
+ dcgm/
179
+ wandb/
180
+ # *.json
181
+
182
+
183
+
184
+ # Exception to add training list
185
+ !lgm_leq20Kpts_simtopo25_train.json.gz
186
+ !lgm_leq20Kpts_simtopo25_test.json.gz
187
+ !lgm_leq20Kpts_train.json.gz
188
+ !lgm_leq20Kpts_train_same_size_wrt_simtopo.json.gz
189
+ *.mp4
190
+ *.gif
191
+ *.glb
192
+ !lgm_leq20Kpts_plus_3Kremain_train.json.gz
193
+ !lgm_leq20Kpts_test_cleaned.json.gz
194
+ !lgm_leq20Kpts_train_cleaned.json.gz
195
+ test_metrics.json
app.py ADDED
@@ -0,0 +1,282 @@
1
+ import argparse
2
+ import gradio
3
+ import torch
4
+ import torch.backends.cudnn as cudnn
5
+ from src.utils.vis import prob_to_mask
6
+ from src.lari.model import LaRIModel, DinoSegModel
7
+ from tools import load_model, process_image, post_process_output, get_masked_depth, save_to_glb, get_point_cloud, removebg_crop
8
+ from huggingface_hub import hf_hub_download
9
+
10
+ parser = argparse.ArgumentParser("Arguments for deploying a LaRI Demo")
11
+
12
+ parser.add_argument(
13
+ "--model_info_pm",
14
+ type=str,
15
+ default="LaRIModel(use_pretrained = 'moge_full', num_output_layer = 5, head_type = 'point')",
16
+ help="Network parameters to load the model",
17
+ )
18
+
19
+ parser.add_argument(
20
+ "--model_info_mask",
21
+ type=str,
22
+ default="DinoSegModel(use_pretrained = 'dinov2', dim_proj = 256, pretrained_path = '', num_output_layer = 4, output_type = 'ray_stop')",
23
+ help="Network parameters to load the model",
24
+ )
25
+
26
+ parser.add_argument(
27
+ "--ckpt_path_pm",
28
+ type=str,
29
+ default="lari_obj_16k_pointmap.pth",
30
+ help="Path to pre-trained weights",
31
+ )
32
+
33
+ parser.add_argument(
34
+ "--ckpt_path_mask",
35
+ type=str,
36
+ default="lari_obj_16k_seg.pth",
37
+ help="Path to pre-trained weights",
38
+ )
39
+
40
+ parser.add_argument(
41
+ "--resolution", type=int, default=512, help="Default model resolution"
42
+ )
43
+ args = parser.parse_args()
44
+
45
+
46
+
47
+ def model_forward(pil_input, layered_id, rembg_checkbox):
48
+ """
49
+ Perform LaRI estimation by:
50
+ 1. image processing
51
+ 2. network forward
52
+ 3. save masked layered depth image
53
+ 4. save point cloud
54
+ """
55
+ if pil_input is None:
56
+ return (None, None, None, None, None, None, None, None, None)
57
+
58
+ if rembg_checkbox:
59
+ pil_input = removebg_crop(pil_input)
60
+
61
+ # Process the input image.
62
+ input_tensor, ori_img_tensor, crop_coords, original_size = process_image(
63
+ pil_input, resolution=args.resolution
64
+ )
65
+ input_tensor = input_tensor.to(device)
66
+
67
+ # Run inference.
68
+ with torch.no_grad():
69
+ # lari map
70
+ pred_dict = model_pm(input_tensor)
71
+ lari_map = -pred_dict["pts3d"].squeeze(
72
+ 0
73
+ ) # Expected output shape: (H_reso, W_reso, L, 3)
74
+ # mask
75
+ if model_mask:
76
+ pred_dict = model_mask(input_tensor)
77
+ assert "seg_prob" in pred_dict
78
+ valid_mask = prob_to_mask(pred_dict["seg_prob"].squeeze(0)) # H W L 1
79
+ else:
80
+ h, w, l, _ = lari_map.shape
81
+ valid_mask = lari_map.new_ones((h, w, l, 1), device=lari_map.device)
82
+
83
+ # crop & resize the output to the original resolution.
84
+ if original_size[0] != args.resolution or original_size[1] != args.resolution:
85
+ lari_map = post_process_output(lari_map, crop_coords, original_size) # H W L 3
86
+ valid_mask = post_process_output(
87
+ valid_mask.float(), crop_coords, original_size
88
+ ).bool() # H W L 1
89
+
90
+ max_n_layer = min(valid_mask.shape[-2], lari_map.shape[-2])
91
+ valid_mask = valid_mask[:, :, :max_n_layer, :]
92
+ lari_map = lari_map[:, :, :max_n_layer, :]
93
+
94
+ curr_layer_id = min(max_n_layer - 1, layered_id - 1)
95
+
96
+ # masked depth list
97
+ depth_image = get_masked_depth(
98
+ lari_map=lari_map, valid_mask=valid_mask, layer_id=curr_layer_id
99
+ )
100
+ # point cloud
101
+ glb_path, ply_path = get_point_cloud(
102
+ lari_map, ori_img_tensor, valid_mask, first_layer_color="pseudo"
103
+ )
104
+
105
+ return (
106
+ depth_image,
107
+ glb_path,
108
+ lari_map,
109
+ valid_mask,
110
+ 0,
111
+ max_n_layer - 1,
112
+ glb_path,
113
+ ply_path,
114
+ pil_input,
115
+ )
116
+
117
+
118
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
119
+ cudnn.benchmark = True
120
+
121
+
122
+ # Download the file
123
+ model_path_pm = hf_hub_download(repo_id="ruili3/LaRI", filename=args.ckpt_path_pm, repo_type="model")
124
+ model_path_mask = hf_hub_download(repo_id="ruili3/LaRI", filename=args.ckpt_path_mask, repo_type="model")
125
+
126
+
127
+ # Load the model with pretrained weights.
128
+ model_pm = load_model(args.model_info_pm, model_path_pm, device)
129
+ model_mask = (
130
+ load_model(args.model_info_mask, model_path_mask, device)
131
+ if args.model_info_mask is not None
132
+ else None
133
+ )
134
+
135
+
136
+ def change_layer(slider_layer_id, lari_map, valid_mask, min_layer_id, max_layer_id):
137
+
138
+ if lari_map is None:
139
+ return
140
+
141
+ slider_layer_id = slider_layer_id - 1
142
+ curr_layer_id = min(slider_layer_id, max_layer_id)
143
+ curr_layer_id = max(curr_layer_id, min_layer_id)
144
+
145
+ # masked depth list
146
+ depth_image = get_masked_depth(
147
+ lari_map=lari_map, valid_mask=valid_mask, layer_id=curr_layer_id
148
+ )
149
+
150
+ return depth_image
151
+
152
+
153
+ def clear_everything():
154
+ return (
155
+ gradio.update(value=None),
156
+ gradio.update(value=None),
157
+ gradio.update(value=None),
158
+ gradio.update(value=None),
159
+ gradio.update(value=None),
160
+ gradio.update(value=None),
161
+ gradio.update(value=None),
162
+ )
163
+
164
+
165
+ with gradio.Blocks(
166
+ css=""".gradio-container {margin: 0 !important; min-width: 100%};""",
167
+ title="LaRI Demo",
168
+ ) as demo:
169
+
170
+ gradio.Markdown(
171
+ "<h1 style='text-align: center;'>LaRI: Layered Ray Intersections for Single-view 3D Geometric Reasoning</h1>",
172
+ elem_id="title",
173
+ )
174
+
175
+ gradio.Markdown(
176
+ """
177
+ This is the official demo of Layered Ray Intersection (<a href="https://ruili3.github.io/lari/index.html" target="_blank" style="color: #2a9d8f;">LaRI</a>). For a quick start, click one of the images in 'Examples' and then click the 'Process' button.
178
+
179
+ You can try your own images with the following steps:
180
+ - Load an image;
181
+ - Click the 'Process' button;
182
+ - Browse the layered depth maps (the z-channel of the resulting LaRI point map) by tuning 'Layer ID'.
183
+
184
+ Note that in '3D Point Cloud', different colors denote different intersection layers, i.e., <b style="color: #FFBD1C;">layer 1</b>, <b style="color: #FB5607;">layer 2</b>, <b style="color: #F15BB5;">layer 3</b>, <b style="color: #8338EC;">layer 4</b>.
185
+ """
186
+ )
187
+
188
+ # , <b style="color: #3A86FF;">layer 5</b>.
189
+ lari_map = gradio.State(None)
190
+ valid_mask = gradio.State(None)
191
+ min_layer_id = gradio.State(None)
192
+ max_layer_id = gradio.State(None)
193
+
194
+ with gradio.Column():
195
+ with gradio.Row(equal_height=True):
196
+ with gradio.Column(scale=1):
197
+ image_input = gradio.Image(
198
+ label="Upload an Image", type="pil", height=350
199
+ )
200
+ with gradio.Row():
201
+ rembg_checkbox = gradio.Checkbox(label="Remove background")
202
+ clear_button = gradio.Button("Clear")
203
+ submit_btn = gradio.Button("Process")
204
+ with gradio.Column(scale=1):
205
+ depth_output = gradio.Image(
206
+ label="LaRI Map at Z-axis (depth)",
207
+ type="pil",
208
+ interactive=False,
209
+ height=300,
210
+ )
211
+ slider_layer_id = gradio.Slider(
212
+ minimum=1,
213
+ maximum=4,
214
+ step=1,
215
+ value=1,
216
+ label="Layer ID",
217
+ interactive=True,
218
+ )
219
+
220
+ with gradio.Row(scale=1):
221
+ outmodel = gradio.Model3D(
222
+ label="3D Point Cloud (Color denotes different layers)",
223
+ interactive=False,
224
+ zoom_speed=0.5,
225
+ pan_speed=0.5,
226
+ height=450,
227
+ )
228
+
229
+ with gradio.Row():
230
+ ply_file_output = gradio.File(label="ply output", elem_classes="small-file")
231
+ glb_file_output = gradio.File(label="glb output", elem_classes="small-file")
232
+
233
+ submit_btn.click(
234
+ fn=model_forward,
235
+ inputs=[image_input, slider_layer_id, rembg_checkbox],
236
+ outputs=[
237
+ depth_output,
238
+ outmodel,
239
+ lari_map,
240
+ valid_mask,
241
+ min_layer_id,
242
+ max_layer_id,
243
+ glb_file_output,
244
+ ply_file_output,
245
+ image_input,
246
+ ],
247
+ )
248
+
249
+ clear_button.click(
250
+ fn=clear_everything,
251
+ outputs=[
252
+ lari_map,
253
+ valid_mask,
254
+ min_layer_id,
255
+ max_layer_id,
256
+ image_input,
257
+ depth_output,
258
+ outmodel,
259
+ ],
260
+ )
261
+
262
+ slider_layer_id.change(
263
+ fn=change_layer,
264
+ inputs=[slider_layer_id, lari_map, valid_mask, min_layer_id, max_layer_id],
265
+ outputs=depth_output,
266
+ )
267
+
268
+ gradio.Examples(examples=["assets/cole_hardware.png",
269
+ "assets/3m_tape.png",
270
+ "assets/horse.png",
271
+ "assets/rhino.png",
272
+ "assets/alphabet.png",
273
+ "assets/martin_wedge.png",
274
+ "assets/d_rose.png",
275
+ "assets/ace.png",
276
+ "assets/bifidus.png",
277
+ "assets/fem.png",
278
+ ],
279
+ inputs=image_input)
280
+
281
+
282
+ demo.launch(share=False)
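Note on the mask path in `model_forward` above: `pred_dict["seg_prob"]` is squeezed to an (L+1, H, W) probability map and converted by `prob_to_mask` (defined in `src/utils/vis.py`, whose diff does not appear in this section) into the (H, W, L, 1) validity mask the rest of the function expects. The snippet below is only a plausible sketch of that conversion under the 'ray_stop' reading described in `dinoseg_model.py` — a layer counts as valid if it lies in front of the arg-max stop index — and is not the repository's implementation.

```python
import torch

def prob_to_mask_sketch(seg_prob: torch.Tensor) -> torch.Tensor:
    # seg_prob: (L+1, H, W) after squeezing the batch dimension, as in model_forward.
    stop_idx = seg_prob.argmax(dim=0)                              # (H, W): stop layer per pixel (assumption)
    num_layers = seg_prob.shape[0] - 1
    layer_ids = torch.arange(num_layers, device=seg_prob.device)   # (L,)
    valid = layer_ids.view(1, 1, -1) < stop_idx.unsqueeze(-1)      # (H, W, L): layers before the stop index
    return valid.unsqueeze(-1)                                     # (H, W, L, 1)

seg_prob = torch.randn(5, 8, 8).softmax(dim=0)   # toy input: 4 layers + 1 "stop" class
print(prob_to_mask_sketch(seg_prob).shape)       # torch.Size([8, 8, 4, 1])
```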
demo.py ADDED
@@ -0,0 +1,136 @@
1
+ import argparse
2
+ import os
3
+ import torch
4
+ import torch.backends.cudnn as cudnn
5
+ from PIL import Image
6
+ from src.utils.vis import prob_to_mask
7
+ from huggingface_hub import hf_hub_download
8
+ from tools import load_model, process_image, post_process_output, get_masked_depth, get_point_cloud, removebg_crop
9
+
10
+ parser = argparse.ArgumentParser("Arguments for deploying a LaRI Demo")
11
+ parser.add_argument(
12
+ "--image_path",
13
+ type=str,
14
+ default="assets/cole_hardware.png",
15
+ help="input image name",
16
+ )
17
+
18
+ parser.add_argument(
19
+ "--output_path",
20
+ type=str,
21
+ default="./results",
22
+ help="path to save the image",
23
+ )
24
+
25
+ parser.add_argument(
26
+ "--model_info_pm",
27
+ type=str,
28
+ default="LaRIModel(use_pretrained = 'moge_full', num_output_layer = 5, head_type = 'point')",
29
+ help="Network parameters to load the model",
30
+ )
31
+
32
+ parser.add_argument(
33
+ "--model_info_mask",
34
+ type=str,
35
+ default="DinoSegModel(use_pretrained = 'dinov2', dim_proj = 256, pretrained_path = '', num_output_layer = 4, output_type = 'ray_stop')",
36
+ help="Network parameters to load the model",
37
+ )
38
+
39
+ parser.add_argument(
40
+ "--ckpt_path_pm",
41
+ type=str,
42
+ default="lari_obj_16k_pointmap.pth",
43
+ help="Path to pre-trained weights",
44
+ )
45
+
46
+ parser.add_argument(
47
+ "--ckpt_path_mask",
48
+ type=str,
49
+ default="lari_obj_16k_seg.pth",
50
+ help="Path to pre-trained weights",
51
+ )
52
+
53
+ parser.add_argument(
54
+ "--resolution", type=int, default=512, help="Default model resolution"
55
+ )
56
+
57
+ parser.add_argument(
58
+ "--is_remove_background", action="store_true", help="Automatically remove the background."
59
+ )
60
+
61
+ args = parser.parse_args()
62
+
63
+
64
+
65
+
66
+
67
+
68
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
69
+ cudnn.benchmark = True
70
+
71
+ # === Load the model
72
+
73
+ model_path_pm = hf_hub_download(repo_id="ruili3/LaRI", filename=args.ckpt_path_pm, repo_type="model")
74
+ model_path_mask = hf_hub_download(repo_id="ruili3/LaRI", filename=args.ckpt_path_mask, repo_type="model")
75
+ # Load the model with pretrained weights.
76
+ model_pm = load_model(args.model_info_pm, model_path_pm, device)
77
+ model_mask = (
78
+ load_model(args.model_info_mask, model_path_mask, device)
79
+ if args.model_info_mask is not None
80
+ else None
81
+ )
82
+
83
+ # === Image pre-processing
84
+ pil_input = Image.open(args.image_path)
85
+ if args.is_remove_background:
86
+ pil_input = removebg_crop(pil_input) # remove background
87
+ input_tensor, ori_img_tensor, crop_coords, original_size = process_image(
88
+ pil_input, resolution=args.resolution) # crop & resize to fit the model input size
89
+ input_tensor = input_tensor.to(device)
90
+
91
+
92
+ # === Run inference
93
+ with torch.no_grad():
94
+ # lari map
95
+ pred_dict = model_pm(input_tensor)
96
+ lari_map = -pred_dict["pts3d"].squeeze(
97
+ 0
98
+ )
99
+ # mask
100
+ if model_mask:
101
+ pred_dict = model_mask(input_tensor)
102
+ assert "seg_prob" in pred_dict
103
+ valid_mask = prob_to_mask(pred_dict["seg_prob"].squeeze(0)) # H W L 1
104
+ else:
105
+ h, w, l, _ = lari_map.shape
106
+ valid_mask = torch.new_ones((h, w, l, 1), device=lari_map.device)
107
+
108
+ # === crop & resize back to the original resolution
109
+ if original_size[0] != args.resolution or original_size[1] != args.resolution:
110
+ lari_map = post_process_output(lari_map, crop_coords, original_size) # H W L 3
111
+ valid_mask = post_process_output(
112
+ valid_mask.float(), crop_coords, original_size
113
+ ).bool() # H W L 1
114
+
115
+ max_n_layer = min(valid_mask.shape[-2], lari_map.shape[-2])
116
+ valid_mask = valid_mask[:, :, :max_n_layer, :]
117
+ lari_map = lari_map[:, :, :max_n_layer, :]
118
+
119
+
120
+ # === save output
121
+ os.makedirs(args.output_path, exist_ok=True)
122
+
123
+ for layer_id in range(max_n_layer):
124
+ depth_pil = get_masked_depth(
125
+ lari_map=lari_map, valid_mask=valid_mask, layer_id=layer_id
126
+ )
127
+ depth_pil.save(os.path.join(args.output_path, f"layered_depth_{layer_id}.jpg"))
128
+
129
+
130
+ # point cloud
131
+ glb_path, ply_path = get_point_cloud(
132
+ lari_map, ori_img_tensor, valid_mask, first_layer_color="pseudo",
133
+ target_folder=args.output_path
134
+ )
135
+
136
+ print("All results saved to `{}`.".format(args.output_path))
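`get_masked_depth` above comes from `tools.py`, which is not among the diffs shown in this view. As a rough, self-contained illustration of what a masked per-layer depth image amounts to — the z-channel of the (negated) LaRI point map at one layer, restricted to that layer's validity mask — here is a sketch; the min-max normalization is an assumption made purely for display, not the repository's exact behavior.

```python
import torch

def layered_depth_sketch(lari_map: torch.Tensor, valid_mask: torch.Tensor, layer_id: int) -> torch.Tensor:
    depth = lari_map[..., layer_id, 2]                 # z-channel of the chosen layer, (H, W)
    mask = valid_mask[..., layer_id, 0].bool()         # validity of that layer, (H, W)
    out = torch.zeros_like(depth)
    if mask.any():
        d = depth[mask]
        out[mask] = (d - d.min()) / (d.max() - d.min() + 1e-8)  # min-max normalize for display (assumption)
    return out

lari_map = torch.randn(16, 16, 4, 3)     # (H, W, L, 3) toy LaRI point map
valid_mask = torch.ones(16, 16, 4, 1)    # (H, W, L, 1) all-valid toy mask
print(layered_depth_sketch(lari_map, valid_mask, layer_id=0).shape)  # torch.Size([16, 16])
```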
requirements.txt ADDED
@@ -0,0 +1,19 @@
1
+ gradio==5.23.3
2
+ huggingface_hub==0.30.1
3
+ imageio==2.37.0
4
+ matplotlib==3.10.1
5
+ moderngl==5.12.0
6
+ omegaconf==2.3.0
7
+ opencv_python==4.11.0.86
8
+ opencv_python_headless==4.11.0.86
9
+ Pillow==11.1.0
10
+ piqp==0.5.0
11
+ plyfile==1.1
12
+ rembg==2.0.65
13
+ scipy==1.15.2
14
+ torchvision==0.21.0
15
+ trimesh==4.6.4
16
+ xformers==0.0.29.post3
17
+ numpy==1.26.4
18
+ torch==2.6.0
19
+ # opencv-python is already pinned above as opencv_python==4.11.0.86; the duplicate, conflicting pin was removed
src/lari/model/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .lari_model import LaRIModel
2
+ from .dinoseg_model import DinoSegModel
src/lari/model/blocks.py ADDED
@@ -0,0 +1,209 @@
1
+ from typing import *
2
+ import torch.nn as nn
3
+
4
+ class ResidualConvBlock(nn.Module):
5
+ def __init__(self, in_channels: int, out_channels: int = None, hidden_channels: int = None, padding_mode: str = 'replicate', activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu', norm: Literal['group_norm', 'layer_norm'] = 'group_norm'):
6
+ super(ResidualConvBlock, self).__init__()
7
+ if out_channels is None:
8
+ out_channels = in_channels
9
+ if hidden_channels is None:
10
+ hidden_channels = in_channels
11
+
12
+ if activation =='relu':
13
+ activation_cls = lambda: nn.ReLU(inplace=True)
14
+ elif activation == 'leaky_relu':
15
+ activation_cls = lambda: nn.LeakyReLU(negative_slope=0.2, inplace=True)
16
+ elif activation =='silu':
17
+ activation_cls = lambda: nn.SiLU(inplace=True)
18
+ elif activation == 'elu':
19
+ activation_cls = lambda: nn.ELU(inplace=True)
20
+ else:
21
+ raise ValueError(f'Unsupported activation function: {activation}')
22
+
23
+ self.layers = nn.Sequential(
24
+ nn.GroupNorm(1, in_channels),
25
+ activation_cls(),
26
+ nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1, padding_mode=padding_mode),
27
+ nn.GroupNorm(hidden_channels // 32 if norm == 'group_norm' else 1, hidden_channels),
28
+ activation_cls(),
29
+ nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode)
30
+ )
31
+
32
+ self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) if in_channels != out_channels else nn.Identity()
33
+
34
+ def forward(self, x):
35
+ skip = self.skip_connection(x)
36
+ x = self.layers(x)
37
+ x = x + skip
38
+ return x
39
+
40
+
41
+
42
+
43
+ def make_upsampler(in_channels: int, out_channels: int):
44
+ upsampler = nn.Sequential(
45
+ nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
46
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
47
+ )
48
+ upsampler[0].weight.data[:] = upsampler[0].weight.data[:, :, :1, :1]
49
+ return upsampler
50
+
51
+ def make_output_block(dim_in: int, dim_out: int, dim_times_res_block_hidden: int, last_res_blocks: int, last_conv_channels: int, last_conv_size: int, res_block_norm: Literal['group_norm', 'layer_norm']):
52
+ return nn.Sequential(
53
+ nn.Conv2d(dim_in, last_conv_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
54
+ *(ResidualConvBlock(last_conv_channels, last_conv_channels, dim_times_res_block_hidden * last_conv_channels, activation='relu', norm=res_block_norm) for _ in range(last_res_blocks)),
55
+ nn.ReLU(inplace=True),
56
+ nn.Conv2d(last_conv_channels, dim_out, kernel_size=last_conv_size, stride=1, padding=last_conv_size // 2, padding_mode='replicate'),
57
+ )
58
+
59
+
60
+
61
+ # ---- the following are from Depth Anything ----
62
+ import torch.nn as nn
63
+
64
+
65
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
66
+ scratch = nn.Module()
67
+
68
+ out_shape1 = out_shape
69
+ out_shape2 = out_shape
70
+ out_shape3 = out_shape
71
+ if len(in_shape) >= 4:
72
+ out_shape4 = out_shape
73
+
74
+ if expand:
75
+ out_shape1 = out_shape
76
+ out_shape2 = out_shape * 2
77
+ out_shape3 = out_shape * 4
78
+ if len(in_shape) >= 4:
79
+ out_shape4 = out_shape * 8
80
+
81
+ scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
82
+ scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
83
+ scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
84
+ if len(in_shape) >= 4:
85
+ scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
86
+
87
+ return scratch
88
+
89
+
90
+ class ResidualConvUnit(nn.Module):
91
+ """Residual convolution module.
92
+ """
93
+
94
+ def __init__(self, features, activation, bn):
95
+ """Init.
96
+
97
+ Args:
98
+ features (int): number of features
99
+ """
100
+ super().__init__()
101
+
102
+ self.bn = bn
103
+
104
+ self.groups=1
105
+
106
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
107
+
108
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
109
+
110
+ if self.bn == True:
111
+ self.bn1 = nn.BatchNorm2d(features)
112
+ self.bn2 = nn.BatchNorm2d(features)
113
+
114
+ self.activation = activation
115
+
116
+ self.skip_add = nn.quantized.FloatFunctional()
117
+
118
+ def forward(self, x):
119
+ """Forward pass.
120
+
121
+ Args:
122
+ x (tensor): input
123
+
124
+ Returns:
125
+ tensor: output
126
+ """
127
+
128
+ out = self.activation(x)
129
+ out = self.conv1(out)
130
+ if self.bn == True:
131
+ out = self.bn1(out)
132
+
133
+ out = self.activation(out)
134
+ out = self.conv2(out)
135
+ if self.bn == True:
136
+ out = self.bn2(out)
137
+
138
+ if self.groups > 1:
139
+ out = self.conv_merge(out)
140
+
141
+ return self.skip_add.add(out, x)
142
+
143
+
144
+ class FeatureFusionBlock(nn.Module):
145
+ """Feature fusion block.
146
+ """
147
+
148
+ def __init__(
149
+ self,
150
+ features,
151
+ activation,
152
+ deconv=False,
153
+ bn=False,
154
+ expand=False,
155
+ align_corners=True,
156
+ size=None
157
+ ):
158
+ """Init.
159
+
160
+ Args:
161
+ features (int): number of features
162
+ """
163
+ super(FeatureFusionBlock, self).__init__()
164
+
165
+ self.deconv = deconv
166
+ self.align_corners = align_corners
167
+
168
+ self.groups=1
169
+
170
+ self.expand = expand
171
+ out_features = features
172
+ if self.expand == True:
173
+ out_features = features // 2
174
+
175
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
176
+
177
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
178
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
179
+
180
+ self.skip_add = nn.quantized.FloatFunctional()
181
+
182
+ self.size=size
183
+
184
+ def forward(self, *xs, size=None):
185
+ """Forward pass.
186
+
187
+ Returns:
188
+ tensor: output
189
+ """
190
+ output = xs[0]
191
+
192
+ if len(xs) == 2:
193
+ res = self.resConfUnit1(xs[1])
194
+ output = self.skip_add.add(output, res)
195
+
196
+ output = self.resConfUnit2(output)
197
+
198
+ if (size is None) and (self.size is None):
199
+ modifier = {"scale_factor": 2}
200
+ elif size is None:
201
+ modifier = {"size": self.size}
202
+ else:
203
+ modifier = {"size": size}
204
+
205
+ output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
206
+
207
+ output = self.out_conv(output)
208
+
209
+ return output
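A quick shape check for the Depth-Anything-style `FeatureFusionBlock` defined above: with `expand=False` it fuses an optional skip feature into the running decoder feature, upsamples by a factor of 2 (bilinear) by default, and keeps the channel count. This assumes the repository root is on `PYTHONPATH` so `src.lari.model.blocks` is importable.

```python
import torch
import torch.nn as nn
from src.lari.model.blocks import FeatureFusionBlock

fuse = FeatureFusionBlock(features=64, activation=nn.ReLU(False), bn=False)
coarse = torch.randn(1, 64, 16, 16)   # decoder feature at the current scale
skip = torch.randn(1, 64, 16, 16)     # encoder feature to be fused in
out = fuse(coarse, skip)
print(out.shape)  # torch.Size([1, 64, 32, 32]) -- default scale_factor=2 upsampling, channels preserved
```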
src/lari/model/dinoseg_model.py ADDED
@@ -0,0 +1,153 @@
1
+ from typing import *
2
+ from numbers import Number
3
+ from functools import partial
4
+ from pathlib import Path
5
+ import importlib
6
+ import warnings
7
+ import json
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.utils
13
+ import torch.utils.checkpoint
14
+ import torch.version
15
+ from huggingface_hub import hf_hub_download
16
+
17
+
18
+ from src.lari.model.utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing
19
+ from src.lari.model.dpt_seg_head import DPTSegHead
20
+
21
+
22
+
23
+ class DinoSegModel(nn.Module):
24
+
25
+ def __init__(self,
26
+ encoder: str = 'dinov2_vitl14',
27
+ intermediate_layers: Union[int, List[int]] = 4,
28
+ dim_proj: int = 512,
29
+ use_pretrained: Literal["dinov2", "moge_full", "moge_backbone", None] = None,
30
+ pretrained_path: str = None,
31
+ num_output_layer: int = None,
32
+ output_type: str = "ray_stop", # "seg_sep"
33
+ **deprecated_kwargs
34
+ ):
35
+ super(DinoSegModel, self).__init__()
36
+ if deprecated_kwargs:
37
+ warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}")
38
+
39
+ self.encoder = encoder
40
+ self.intermediate_layers = intermediate_layers
41
+ self.use_pretrained = use_pretrained
42
+ self.pretrained_path = pretrained_path
43
+ self.num_output_layer = num_output_layer
44
+ self.output_type = output_type
45
+ assert self.output_type in ["seg_sep", "ray_stop"]
46
+
47
+ hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), encoder)
48
+
49
+ self.backbone = hub_loader(pretrained=True if self.use_pretrained == "dinov2" else False)
50
+ dim_feature = self.backbone.blocks[0].attn.qkv.in_features
51
+
52
+
53
+
54
+
55
+ self.head = DPTSegHead(in_channels=dim_feature,
56
+ features=dim_proj,
57
+ use_bn=True,
58
+ out_channels=[256, 512, 1024, 1024],
59
+ use_clstoken=False,
60
+ num_classes = num_output_layer,
61
+ output_type = self.output_type
62
+ )
63
+
64
+
65
+ if torch.__version__ >= '2.0':
66
+ self.enable_pytorch_native_sdpa()
67
+
68
+ self._load_pretrained()
69
+
70
+
71
+ def _load_pretrained(self):
72
+ '''
73
+ Load data from MoGe model
74
+ '''
75
+ return
76
+
77
+
78
+
79
+
80
+ @classmethod
81
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path, IO[bytes]], model_kwargs: Optional[Dict[str, Any]] = None, **hf_kwargs) -> 'DinoSegModel':
82
+ """
83
+ Load a model from a checkpoint file.
84
+
85
+ ### Parameters:
86
+ - `pretrained_model_name_or_path`: path to the checkpoint file or repo id.
87
+ - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint.
88
+ - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path.
89
+
90
+ ### Returns:
91
+ - A new instance of `DinoSegModel` with the parameters loaded from the checkpoint.
92
+ """
93
+ if Path(pretrained_model_name_or_path).exists():
94
+ checkpoint = torch.load(pretrained_model_name_or_path, map_location='cpu', weights_only=True)
95
+ else:
96
+ cached_checkpoint_path = hf_hub_download(
97
+ repo_id=pretrained_model_name_or_path,
98
+ repo_type="model",
99
+ filename="model.pt",
100
+ **hf_kwargs
101
+ )
102
+ checkpoint = torch.load(cached_checkpoint_path, map_location='cpu', weights_only=True)
103
+ model_config = checkpoint['model_config']
104
+ if model_kwargs is not None:
105
+ model_config.update(model_kwargs)
106
+ model = cls(**model_config)
107
+ model.load_state_dict(checkpoint['model'])
108
+ return model
109
+
110
+ @staticmethod
111
+ def cache_pretrained_backbone(encoder: str, pretrained: bool):
112
+ _ = torch.hub.load('facebookresearch/dinov2', encoder, pretrained=pretrained)
113
+
114
+ def load_pretrained_backbone(self):
115
+ "Load the backbone with pretrained dinov2 weights from torch hub"
116
+ state_dict = torch.hub.load('facebookresearch/dinov2', self.encoder, pretrained=True).state_dict()
117
+ self.backbone.load_state_dict(state_dict)
118
+
119
+ def enable_backbone_gradient_checkpointing(self):
120
+ for i in range(len(self.backbone.blocks)):
121
+ self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing(self.backbone.blocks[i])
122
+
123
+ def enable_pytorch_native_sdpa(self):
124
+ for i in range(len(self.backbone.blocks)):
125
+ self.backbone.blocks[i].attn = wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn)
126
+
127
+
128
+
129
+ def forward(self, image: torch.Tensor, mixed_precision: bool = False) -> Dict[str, torch.Tensor]:
130
+ raw_img_h, raw_img_w = image.shape[-2:]
131
+ patch_h, patch_w = raw_img_h // 14, raw_img_w // 14
132
+ # Apply image transformation for DINOv2
133
+ image_14 = F.interpolate(image, (patch_h * 14, patch_w * 14), mode="bilinear", align_corners=False, antialias=True)
134
+
135
+ # Get intermediate layers from the backbone
136
+ with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=mixed_precision):
137
+ features = self.backbone.get_intermediate_layers(image_14, self.intermediate_layers, return_class_token=True)
138
+
139
+ # Predict points and mask (mask scores)
140
+ mask = self.head(features, patch_h, patch_w)
141
+
142
+ # b c h w
143
+ mask = F.interpolate(mask, (raw_img_h, raw_img_w), mode="bilinear", align_corners=False)
144
+
145
+ out_dict = {}
146
+
147
+ if self.output_type == "seg_sep":
148
+ # mask = torch.nn.functional.sigmoid(mask) # for binary segmentation
149
+ out_dict["mask"] = mask.permute(0, 2, 3, 1).unsqueeze(-1) # B H W L 1
150
+ elif self.output_type == "ray_stop":
151
+ out_dict["seg_prob"] = mask # B L+1 H W
152
+
153
+ return out_dict
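A hedged usage sketch for `DinoSegModel.from_pretrained` above: it accepts either a local checkpoint path or a Hugging Face repo id, and expects a checkpoint dict with `model_config` and `model` keys. The path and repo id below are placeholders, not assets shipped with this commit.

```python
from src.lari.model import DinoSegModel

# Local checkpoint (placeholder path): the file must contain {"model_config": ..., "model": ...}.
model = DinoSegModel.from_pretrained("checkpoints/lari_seg.pt")

# Hugging Face Hub repo id (placeholder): from_pretrained downloads "model.pt" from that repo.
model = DinoSegModel.from_pretrained(
    "your-namespace/your-lari-seg-checkpoint",
    model_kwargs={"output_type": "ray_stop"},  # overrides entries stored in model_config
)
```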
src/lari/model/dinov2/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ __version__ = "0.0.1"
src/lari/model/dinov2/hub/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
src/lari/model/dinov2/hub/backbones.py ADDED
@@ -0,0 +1,156 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from enum import Enum
7
+ from typing import Union
8
+
9
+ import torch
10
+
11
+ from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
12
+
13
+
14
+ class Weights(Enum):
15
+ LVD142M = "LVD142M"
16
+
17
+
18
+ def _make_dinov2_model(
19
+ *,
20
+ arch_name: str = "vit_large",
21
+ img_size: int = 518,
22
+ patch_size: int = 14,
23
+ init_values: float = 1.0,
24
+ ffn_layer: str = "mlp",
25
+ block_chunks: int = 0,
26
+ num_register_tokens: int = 0,
27
+ interpolate_antialias: bool = False,
28
+ interpolate_offset: float = 0.1,
29
+ pretrained: bool = True,
30
+ weights: Union[Weights, str] = Weights.LVD142M,
31
+ **kwargs,
32
+ ):
33
+ from ..models import vision_transformer as vits
34
+
35
+ if isinstance(weights, str):
36
+ try:
37
+ weights = Weights[weights]
38
+ except KeyError:
39
+ raise AssertionError(f"Unsupported weights: {weights}")
40
+
41
+ model_base_name = _make_dinov2_model_name(arch_name, patch_size)
42
+ vit_kwargs = dict(
43
+ img_size=img_size,
44
+ patch_size=patch_size,
45
+ init_values=init_values,
46
+ ffn_layer=ffn_layer,
47
+ block_chunks=block_chunks,
48
+ num_register_tokens=num_register_tokens,
49
+ interpolate_antialias=interpolate_antialias,
50
+ interpolate_offset=interpolate_offset,
51
+ )
52
+ vit_kwargs.update(**kwargs)
53
+ model = vits.__dict__[arch_name](**vit_kwargs)
54
+
55
+ if pretrained:
56
+ model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
57
+ url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
58
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
59
+ model.load_state_dict(state_dict, strict=True)
60
+
61
+ return model
62
+
63
+
64
+ def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
65
+ """
66
+ DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
67
+ """
68
+ return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
69
+
70
+
71
+ def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
72
+ """
73
+ DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
74
+ """
75
+ return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
76
+
77
+
78
+ def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
79
+ """
80
+ DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
81
+ """
82
+ return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
83
+
84
+
85
+ def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
86
+ """
87
+ DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
88
+ """
89
+ return _make_dinov2_model(
90
+ arch_name="vit_giant2",
91
+ ffn_layer="swiglufused",
92
+ weights=weights,
93
+ pretrained=pretrained,
94
+ **kwargs,
95
+ )
96
+
97
+
98
+ def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
99
+ """
100
+ DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
101
+ """
102
+ return _make_dinov2_model(
103
+ arch_name="vit_small",
104
+ pretrained=pretrained,
105
+ weights=weights,
106
+ num_register_tokens=4,
107
+ interpolate_antialias=True,
108
+ interpolate_offset=0.0,
109
+ **kwargs,
110
+ )
111
+
112
+
113
+ def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
114
+ """
115
+ DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
116
+ """
117
+ return _make_dinov2_model(
118
+ arch_name="vit_base",
119
+ pretrained=pretrained,
120
+ weights=weights,
121
+ num_register_tokens=4,
122
+ interpolate_antialias=True,
123
+ interpolate_offset=0.0,
124
+ **kwargs,
125
+ )
126
+
127
+
128
+ def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
129
+ """
130
+ DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
131
+ """
132
+ return _make_dinov2_model(
133
+ arch_name="vit_large",
134
+ pretrained=pretrained,
135
+ weights=weights,
136
+ num_register_tokens=4,
137
+ interpolate_antialias=True,
138
+ interpolate_offset=0.0,
139
+ **kwargs,
140
+ )
141
+
142
+
143
+ def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
144
+ """
145
+ DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
146
+ """
147
+ return _make_dinov2_model(
148
+ arch_name="vit_giant2",
149
+ ffn_layer="swiglufused",
150
+ weights=weights,
151
+ pretrained=pretrained,
152
+ num_register_tokens=4,
153
+ interpolate_antialias=True,
154
+ interpolate_offset=0.0,
155
+ **kwargs,
156
+ )
src/lari/model/dinov2/hub/utils.py ADDED
@@ -0,0 +1,39 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import itertools
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
15
+
16
+
17
+ def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
18
+ compact_arch_name = arch_name.replace("_", "")[:4]
19
+ registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
20
+ return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
21
+
22
+
23
+ class CenterPadding(nn.Module):
24
+ def __init__(self, multiple):
25
+ super().__init__()
26
+ self.multiple = multiple
27
+
28
+ def _get_pad(self, size):
29
+ new_size = math.ceil(size / self.multiple) * self.multiple
30
+ pad_size = new_size - size
31
+ pad_size_left = pad_size // 2
32
+ pad_size_right = pad_size - pad_size_left
33
+ return pad_size_left, pad_size_right
34
+
35
+ @torch.inference_mode()
36
+ def forward(self, x):
37
+ pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
38
+ output = F.pad(x, pads)
39
+ return output
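`CenterPadding` above pads the last two (spatial) dimensions up to the next multiple of `multiple` (14 for these ViT backbones), splitting the padding as evenly as possible on both sides. A minimal check, assuming the repository root is importable:

```python
import torch
from src.lari.model.dinov2.hub.utils import CenterPadding

pad = CenterPadding(multiple=14)
x = torch.randn(1, 3, 100, 130)
print(pad(x).shape)  # torch.Size([1, 3, 112, 140]) -- 100 -> 112 (pad 6+6) and 130 -> 140 (pad 5+5)
```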
src/lari/model/dinov2/layers/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .dino_head import DINOHead
7
+ from .mlp import Mlp
8
+ from .patch_embed import PatchEmbed
9
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10
+ from .block import NestedTensorBlock
11
+ from .attention import MemEffAttention
src/lari/model/dinov2/layers/attention.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import logging
11
+ import os
12
+ import warnings
13
+
14
+ from torch import Tensor
15
+ from torch import nn
16
+
17
+
18
+ logger = logging.getLogger("dinov2")
19
+
20
+
21
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
22
+ try:
23
+ if XFORMERS_ENABLED:
24
+ from xformers.ops import memory_efficient_attention, unbind
25
+
26
+ XFORMERS_AVAILABLE = True
27
+ # warnings.warn("xFormers is available (Attention)")
28
+ else:
29
+ # warnings.warn("xFormers is disabled (Attention)")
30
+ raise ImportError
31
+ except ImportError:
32
+ XFORMERS_AVAILABLE = False
33
+ # warnings.warn("xFormers is not available (Attention)")
34
+
35
+
36
+ class Attention(nn.Module):
37
+ def __init__(
38
+ self,
39
+ dim: int,
40
+ num_heads: int = 8,
41
+ qkv_bias: bool = False,
42
+ proj_bias: bool = True,
43
+ attn_drop: float = 0.0,
44
+ proj_drop: float = 0.0,
45
+ ) -> None:
46
+ super().__init__()
47
+ self.num_heads = num_heads
48
+ head_dim = dim // num_heads
49
+ self.scale = head_dim**-0.5
50
+
51
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
52
+ self.attn_drop = nn.Dropout(attn_drop)
53
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
54
+ self.proj_drop = nn.Dropout(proj_drop)
55
+
56
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
57
+ B, N, C = x.shape
58
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
59
+
60
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
61
+ attn = q @ k.transpose(-2, -1)
62
+
63
+ attn = attn.softmax(dim=-1)
64
+ attn = self.attn_drop(attn)
65
+
66
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
67
+ x = self.proj(x)
68
+ x = self.proj_drop(x)
69
+ return x
70
+
71
+
72
+ class MemEffAttention(Attention):
73
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
74
+ if not XFORMERS_AVAILABLE:
75
+ if attn_bias is not None:
76
+ raise AssertionError("xFormers is required for using nested tensors")
77
+ return super().forward(x)
78
+
79
+ B, N, C = x.shape
80
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
81
+
82
+ q, k, v = unbind(qkv, 2)
83
+
84
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
85
+ x = x.reshape([B, N, C])
86
+
87
+ x = self.proj(x)
88
+ x = self.proj_drop(x)
89
+ return x
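The plain `Attention` module above is the fallback used when xFormers is unavailable; `MemEffAttention` swaps in `memory_efficient_attention` but keeps the same input/output contract. A minimal shape check, assuming the repository root is importable:

```python
import torch
from src.lari.model.dinov2.layers.attention import Attention

attn = Attention(dim=64, num_heads=8, qkv_bias=True)
tokens = torch.randn(2, 10, 64)  # (batch, tokens, channels)
print(attn(tokens).shape)        # torch.Size([2, 10, 64]) -- attention preserves the token shape
```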
src/lari/model/dinov2/layers/block.py ADDED
@@ -0,0 +1,259 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ import logging
11
+ import os
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+ import warnings
14
+
15
+ import torch
16
+ from torch import nn, Tensor
17
+
18
+ from .attention import Attention, MemEffAttention
19
+ from .drop_path import DropPath
20
+ from .layer_scale import LayerScale
21
+ from .mlp import Mlp
22
+
23
+
24
+ logger = logging.getLogger("dinov2")
25
+
26
+
27
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
28
+ try:
29
+ if XFORMERS_ENABLED:
30
+ from xformers.ops import fmha, scaled_index_add, index_select_cat
31
+
32
+ XFORMERS_AVAILABLE = True
33
+ # warnings.warn("xFormers is available (Block)")
34
+ else:
35
+ # warnings.warn("xFormers is disabled (Block)")
36
+ raise ImportError
37
+ except ImportError:
38
+ XFORMERS_AVAILABLE = False
39
+ # warnings.warn("xFormers is not available (Block)")
40
+
41
+
42
+ class Block(nn.Module):
43
+ def __init__(
44
+ self,
45
+ dim: int,
46
+ num_heads: int,
47
+ mlp_ratio: float = 4.0,
48
+ qkv_bias: bool = False,
49
+ proj_bias: bool = True,
50
+ ffn_bias: bool = True,
51
+ drop: float = 0.0,
52
+ attn_drop: float = 0.0,
53
+ init_values=None,
54
+ drop_path: float = 0.0,
55
+ act_layer: Callable[..., nn.Module] = nn.GELU,
56
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
57
+ attn_class: Callable[..., nn.Module] = Attention,
58
+ ffn_layer: Callable[..., nn.Module] = Mlp,
59
+ ) -> None:
60
+ super().__init__()
61
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
62
+ self.norm1 = norm_layer(dim)
63
+ self.attn = attn_class(
64
+ dim,
65
+ num_heads=num_heads,
66
+ qkv_bias=qkv_bias,
67
+ proj_bias=proj_bias,
68
+ attn_drop=attn_drop,
69
+ proj_drop=drop,
70
+ )
71
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
72
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
73
+
74
+ self.norm2 = norm_layer(dim)
75
+ mlp_hidden_dim = int(dim * mlp_ratio)
76
+ self.mlp = ffn_layer(
77
+ in_features=dim,
78
+ hidden_features=mlp_hidden_dim,
79
+ act_layer=act_layer,
80
+ drop=drop,
81
+ bias=ffn_bias,
82
+ )
83
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
84
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
85
+
86
+ self.sample_drop_ratio = drop_path
87
+
88
+ def forward(self, x: Tensor) -> Tensor:
89
+ def attn_residual_func(x: Tensor) -> Tensor:
90
+ return self.ls1(self.attn(self.norm1(x)))
91
+
92
+ def ffn_residual_func(x: Tensor) -> Tensor:
93
+ return self.ls2(self.mlp(self.norm2(x)))
94
+
95
+ if self.training and self.sample_drop_ratio > 0.1:
96
+ # the overhead is compensated only for a drop path rate larger than 0.1
97
+ x = drop_add_residual_stochastic_depth(
98
+ x,
99
+ residual_func=attn_residual_func,
100
+ sample_drop_ratio=self.sample_drop_ratio,
101
+ )
102
+ x = drop_add_residual_stochastic_depth(
103
+ x,
104
+ residual_func=ffn_residual_func,
105
+ sample_drop_ratio=self.sample_drop_ratio,
106
+ )
107
+ elif self.training and self.sample_drop_ratio > 0.0:
108
+ x = x + self.drop_path1(attn_residual_func(x))
109
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
110
+ else:
111
+ x = x + attn_residual_func(x)
112
+ x = x + ffn_residual_func(x)
113
+ return x
114
+
115
+
116
+ def drop_add_residual_stochastic_depth(
117
+ x: Tensor,
118
+ residual_func: Callable[[Tensor], Tensor],
119
+ sample_drop_ratio: float = 0.0,
120
+ ) -> Tensor:
121
+ # 1) extract subset using permutation
122
+ b, n, d = x.shape
123
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
124
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
125
+ x_subset = x[brange]
126
+
127
+ # 2) apply residual_func to get residual
128
+ residual = residual_func(x_subset)
129
+
130
+ x_flat = x.flatten(1)
131
+ residual = residual.flatten(1)
132
+
133
+ residual_scale_factor = b / sample_subset_size
134
+
135
+ # 3) add the residual
136
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
137
+ return x_plus_residual.view_as(x)
138
+
139
+
140
+ def get_branges_scales(x, sample_drop_ratio=0.0):
141
+ b, n, d = x.shape
142
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
143
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
144
+ residual_scale_factor = b / sample_subset_size
145
+ return brange, residual_scale_factor
146
+
147
+
148
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
149
+ if scaling_vector is None:
150
+ x_flat = x.flatten(1)
151
+ residual = residual.flatten(1)
152
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
153
+ else:
154
+ x_plus_residual = scaled_index_add(
155
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
156
+ )
157
+ return x_plus_residual
158
+
159
+
160
+ attn_bias_cache: Dict[Tuple, Any] = {}
161
+
162
+
163
+ def get_attn_bias_and_cat(x_list, branges=None):
164
+ """
165
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
166
+ """
167
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
168
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
169
+ if all_shapes not in attn_bias_cache.keys():
170
+ seqlens = []
171
+ for b, x in zip(batch_sizes, x_list):
172
+ for _ in range(b):
173
+ seqlens.append(x.shape[1])
174
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
175
+ attn_bias._batch_sizes = batch_sizes
176
+ attn_bias_cache[all_shapes] = attn_bias
177
+
178
+ if branges is not None:
179
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
180
+ else:
181
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
182
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
183
+
184
+ return attn_bias_cache[all_shapes], cat_tensors
185
+
186
+
187
+ def drop_add_residual_stochastic_depth_list(
188
+ x_list: List[Tensor],
189
+ residual_func: Callable[[Tensor, Any], Tensor],
190
+ sample_drop_ratio: float = 0.0,
191
+ scaling_vector=None,
192
+ ) -> Tensor:
193
+ # 1) generate random set of indices for dropping samples in the batch
194
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
195
+ branges = [s[0] for s in branges_scales]
196
+ residual_scale_factors = [s[1] for s in branges_scales]
197
+
198
+ # 2) get attention bias and index+concat the tensors
199
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
200
+
201
+ # 3) apply residual_func to get residual, and split the result
202
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
203
+
204
+ outputs = []
205
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
206
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
207
+ return outputs
208
+
209
+
210
+ class NestedTensorBlock(Block):
211
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
212
+ """
213
+ x_list contains a list of tensors to nest together and run
214
+ """
215
+ assert isinstance(self.attn, MemEffAttention)
216
+
217
+ if self.training and self.sample_drop_ratio > 0.0:
218
+
219
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
220
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
221
+
222
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
223
+ return self.mlp(self.norm2(x))
224
+
225
+ x_list = drop_add_residual_stochastic_depth_list(
226
+ x_list,
227
+ residual_func=attn_residual_func,
228
+ sample_drop_ratio=self.sample_drop_ratio,
229
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
230
+ )
231
+ x_list = drop_add_residual_stochastic_depth_list(
232
+ x_list,
233
+ residual_func=ffn_residual_func,
234
+ sample_drop_ratio=self.sample_drop_ratio,
235
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
236
+ )
237
+ return x_list
238
+ else:
239
+
240
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
241
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
242
+
243
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
244
+ return self.ls2(self.mlp(self.norm2(x)))
245
+
246
+ attn_bias, x = get_attn_bias_and_cat(x_list)
247
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
248
+ x = x + ffn_residual_func(x)
249
+ return attn_bias.split(x)
250
+
251
+ def forward(self, x_or_x_list):
252
+ if isinstance(x_or_x_list, Tensor):
253
+ return super().forward(x_or_x_list)
254
+ elif isinstance(x_or_x_list, list):
255
+ if not XFORMERS_AVAILABLE:
256
+ raise AssertionError("xFormers is required for using nested tensors")
257
+ return self.forward_nested(x_or_x_list)
258
+ else:
259
+ raise AssertionError
src/lari/model/dinov2/layers/dino_head.py ADDED
@@ -0,0 +1,58 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.nn.init import trunc_normal_
9
+ from torch.nn.utils import weight_norm
10
+
11
+
12
+ class DINOHead(nn.Module):
13
+ def __init__(
14
+ self,
15
+ in_dim,
16
+ out_dim,
17
+ use_bn=False,
18
+ nlayers=3,
19
+ hidden_dim=2048,
20
+ bottleneck_dim=256,
21
+ mlp_bias=True,
22
+ ):
23
+ super().__init__()
24
+ nlayers = max(nlayers, 1)
25
+ self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
26
+ self.apply(self._init_weights)
27
+ self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
28
+ self.last_layer.weight_g.data.fill_(1)
29
+
30
+ def _init_weights(self, m):
31
+ if isinstance(m, nn.Linear):
32
+ trunc_normal_(m.weight, std=0.02)
33
+ if isinstance(m, nn.Linear) and m.bias is not None:
34
+ nn.init.constant_(m.bias, 0)
35
+
36
+ def forward(self, x):
37
+ x = self.mlp(x)
38
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
39
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
40
+ x = self.last_layer(x)
41
+ return x
42
+
43
+
44
+ def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
45
+ if nlayers == 1:
46
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
47
+ else:
48
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
49
+ if use_bn:
50
+ layers.append(nn.BatchNorm1d(hidden_dim))
51
+ layers.append(nn.GELU())
52
+ for _ in range(nlayers - 2):
53
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
54
+ if use_bn:
55
+ layers.append(nn.BatchNorm1d(hidden_dim))
56
+ layers.append(nn.GELU())
57
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
58
+ return nn.Sequential(*layers)
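
A quick, illustrative shape check for the projection head above (not part of the commit; the prototype count 65536 is only an example value):

```python
import torch

# DINOHead: MLP bottleneck, L2-normalisation, then a weight-normed prototype layer.
head = DINOHead(in_dim=384, out_dim=65536, nlayers=3, hidden_dim=2048, bottleneck_dim=256)
cls_tokens = torch.randn(8, 384)        # e.g. eight [CLS] embeddings from a ViT-S backbone
logits = head(cls_tokens)               # projected onto the prototype dimension
assert logits.shape == (8, 65536)
```
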
src/lari/model/dinov2/layers/drop_path.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
9
+
10
+
11
+ from torch import nn
12
+
13
+
14
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
15
+ if drop_prob == 0.0 or not training:
16
+ return x
17
+ keep_prob = 1 - drop_prob
18
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
19
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
20
+ if keep_prob > 0.0:
21
+ random_tensor.div_(keep_prob)
22
+ output = x * random_tensor
23
+ return output
24
+
25
+
26
+ class DropPath(nn.Module):
27
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
28
+
29
+ def __init__(self, drop_prob=None):
30
+ super(DropPath, self).__init__()
31
+ self.drop_prob = drop_prob
32
+
33
+ def forward(self, x):
34
+ return drop_path(x, self.drop_prob, self.training)
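
For reference (not part of the commit), stochastic depth as implemented above is an identity at inference time and a per-sample drop-or-rescale during training:

```python
import torch

dp = DropPath(drop_prob=0.2)
x = torch.ones(4, 16, 384)              # (batch, tokens, dim)

dp.eval()
assert torch.equal(dp(x), x)            # no-op outside training

dp.train()
y = dp(x)                               # each sample is zeroed with p=0.2, otherwise scaled by 1/0.8
assert y.shape == x.shape
```
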
src/lari/model/dinov2/layers/layer_scale.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
7
+
8
+ from typing import Union
9
+
10
+ import torch
11
+ from torch import Tensor
12
+ from torch import nn
13
+
14
+
15
+ class LayerScale(nn.Module):
16
+ def __init__(
17
+ self,
18
+ dim: int,
19
+ init_values: Union[float, Tensor] = 1e-5,
20
+ inplace: bool = False,
21
+ ) -> None:
22
+ super().__init__()
23
+ self.inplace = inplace
24
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
25
+
26
+ def forward(self, x: Tensor) -> Tensor:
27
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
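
Illustrative usage (not part of the commit): with the default init_values, LayerScale starts as a strongly damped per-channel scaling, which is what lets very deep ViTs train stably.

```python
import torch

ls = LayerScale(dim=384, init_values=1e-5)
x = torch.randn(2, 197, 384)
assert torch.allclose(ls(x), x * 1e-5)  # per-channel gamma, all initialised to 1e-5
```
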
src/lari/model/dinov2/layers/mlp.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
9
+
10
+
11
+ from typing import Callable, Optional
12
+
13
+ from torch import Tensor, nn
14
+
15
+
16
+ class Mlp(nn.Module):
17
+ def __init__(
18
+ self,
19
+ in_features: int,
20
+ hidden_features: Optional[int] = None,
21
+ out_features: Optional[int] = None,
22
+ act_layer: Callable[..., nn.Module] = nn.GELU,
23
+ drop: float = 0.0,
24
+ bias: bool = True,
25
+ ) -> None:
26
+ super().__init__()
27
+ out_features = out_features or in_features
28
+ hidden_features = hidden_features or in_features
29
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
30
+ self.act = act_layer()
31
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
32
+ self.drop = nn.Dropout(drop)
33
+
34
+ def forward(self, x: Tensor) -> Tensor:
35
+ x = self.fc1(x)
36
+ x = self.act(x)
37
+ x = self.drop(x)
38
+ x = self.fc2(x)
39
+ x = self.drop(x)
40
+ return x
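
A minimal sketch of the feed-forward block above (illustrative only; the 4x expansion mirrors the usual mlp_ratio=4):

```python
import torch

ffn = Mlp(in_features=384, hidden_features=4 * 384)
x = torch.randn(2, 197, 384)
assert ffn(x).shape == (2, 197, 384)    # token dimensionality is preserved
```
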
src/lari/model/dinov2/layers/patch_embed.py ADDED
@@ -0,0 +1,88 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ from typing import Callable, Optional, Tuple, Union
11
+
12
+ from torch import Tensor
13
+ import torch.nn as nn
14
+
15
+
16
+ def make_2tuple(x):
17
+ if isinstance(x, tuple):
18
+ assert len(x) == 2
19
+ return x
20
+
21
+ assert isinstance(x, int)
22
+ return (x, x)
23
+
24
+
25
+ class PatchEmbed(nn.Module):
26
+ """
27
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
28
+
29
+ Args:
30
+ img_size: Image size.
31
+ patch_size: Patch token size.
32
+ in_chans: Number of input image channels.
33
+ embed_dim: Number of linear projection output channels.
34
+ norm_layer: Normalization layer.
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ img_size: Union[int, Tuple[int, int]] = 224,
40
+ patch_size: Union[int, Tuple[int, int]] = 16,
41
+ in_chans: int = 3,
42
+ embed_dim: int = 768,
43
+ norm_layer: Optional[Callable] = None,
44
+ flatten_embedding: bool = True,
45
+ ) -> None:
46
+ super().__init__()
47
+
48
+ image_HW = make_2tuple(img_size)
49
+ patch_HW = make_2tuple(patch_size)
50
+ patch_grid_size = (
51
+ image_HW[0] // patch_HW[0],
52
+ image_HW[1] // patch_HW[1],
53
+ )
54
+
55
+ self.img_size = image_HW
56
+ self.patch_size = patch_HW
57
+ self.patches_resolution = patch_grid_size
58
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
59
+
60
+ self.in_chans = in_chans
61
+ self.embed_dim = embed_dim
62
+
63
+ self.flatten_embedding = flatten_embedding
64
+
65
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
66
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
67
+
68
+ def forward(self, x: Tensor) -> Tensor:
69
+ _, _, H, W = x.shape
70
+ patch_H, patch_W = self.patch_size
71
+
72
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
73
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width {patch_W}"
74
+
75
+ x = self.proj(x) # B C H W
76
+ H, W = x.size(2), x.size(3)
77
+ x = x.flatten(2).transpose(1, 2) # B HW C
78
+ x = self.norm(x)
79
+ if not self.flatten_embedding:
80
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
81
+ return x
82
+
83
+ def flops(self) -> float:
84
+ Ho, Wo = self.patches_resolution
85
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
86
+ if self.norm is not None:
87
+ flops += Ho * Wo * self.embed_dim
88
+ return flops
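
Illustrative only (not part of the commit): patchifying a 224x224 image with the 14x14 patches used by the DINOv2 backbones in this repo.

```python
import torch

embed = PatchEmbed(img_size=224, patch_size=14, in_chans=3, embed_dim=384)
img = torch.randn(1, 3, 224, 224)
tokens = embed(img)                     # (B, N, D) with N = (224 / 14) ** 2 = 256
assert tokens.shape == (1, 256, 384)
```
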
src/lari/model/dinov2/layers/swiglu_ffn.py ADDED
@@ -0,0 +1,72 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ from typing import Callable, Optional
8
+ import warnings
9
+
10
+ from torch import Tensor, nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class SwiGLUFFN(nn.Module):
15
+ def __init__(
16
+ self,
17
+ in_features: int,
18
+ hidden_features: Optional[int] = None,
19
+ out_features: Optional[int] = None,
20
+ act_layer: Callable[..., nn.Module] = None,
21
+ drop: float = 0.0,
22
+ bias: bool = True,
23
+ ) -> None:
24
+ super().__init__()
25
+ out_features = out_features or in_features
26
+ hidden_features = hidden_features or in_features
27
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
28
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
29
+
30
+ def forward(self, x: Tensor) -> Tensor:
31
+ x12 = self.w12(x)
32
+ x1, x2 = x12.chunk(2, dim=-1)
33
+ hidden = F.silu(x1) * x2
34
+ return self.w3(hidden)
35
+
36
+
37
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
38
+ try:
39
+ if XFORMERS_ENABLED:
40
+ from xformers.ops import SwiGLU
41
+
42
+ XFORMERS_AVAILABLE = True
43
+ # warnings.warn("xFormers is available (SwiGLU)")
44
+ else:
45
+ # warnings.warn("xFormers is disabled (SwiGLU)")
46
+ raise ImportError
47
+ except ImportError:
48
+ SwiGLU = SwiGLUFFN
49
+ XFORMERS_AVAILABLE = False
50
+
51
+ # warnings.warn("xFormers is not available (SwiGLU)")
52
+
53
+
54
+ class SwiGLUFFNFused(SwiGLU):
55
+ def __init__(
56
+ self,
57
+ in_features: int,
58
+ hidden_features: Optional[int] = None,
59
+ out_features: Optional[int] = None,
60
+ act_layer: Callable[..., nn.Module] = None,
61
+ drop: float = 0.0,
62
+ bias: bool = True,
63
+ ) -> None:
64
+ out_features = out_features or in_features
65
+ hidden_features = hidden_features or in_features
66
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
67
+ super().__init__(
68
+ in_features=in_features,
69
+ hidden_features=hidden_features,
70
+ out_features=out_features,
71
+ bias=bias,
72
+ )
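
A small sketch of the pure-PyTorch fallback path (illustrative only; when xFormers is installed, SwiGLUFFNFused dispatches to the fused xformers.ops.SwiGLU instead):

```python
import torch

ffn = SwiGLUFFN(in_features=384, hidden_features=1024)
x = torch.randn(2, 197, 384)
assert ffn(x).shape == (2, 197, 384)
```
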
src/lari/model/dinov2/models/__init__.py ADDED
@@ -0,0 +1,43 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+
8
+ from . import vision_transformer as vits
9
+
10
+
11
+ logger = logging.getLogger("dinov2")
12
+
13
+
14
+ def build_model(args, only_teacher=False, img_size=224):
15
+ args.arch = args.arch.removesuffix("_memeff")
16
+ if "vit" in args.arch:
17
+ vit_kwargs = dict(
18
+ img_size=img_size,
19
+ patch_size=args.patch_size,
20
+ init_values=args.layerscale,
21
+ ffn_layer=args.ffn_layer,
22
+ block_chunks=args.block_chunks,
23
+ qkv_bias=args.qkv_bias,
24
+ proj_bias=args.proj_bias,
25
+ ffn_bias=args.ffn_bias,
26
+ num_register_tokens=args.num_register_tokens,
27
+ interpolate_offset=args.interpolate_offset,
28
+ interpolate_antialias=args.interpolate_antialias,
29
+ )
30
+ teacher = vits.__dict__[args.arch](**vit_kwargs)
31
+ if only_teacher:
32
+ return teacher, teacher.embed_dim
33
+ student = vits.__dict__[args.arch](
34
+ **vit_kwargs,
35
+ drop_path_rate=args.drop_path_rate,
36
+ drop_path_uniform=args.drop_path_uniform,
37
+ )
38
+ embed_dim = student.embed_dim
39
+ return student, teacher, embed_dim
40
+
41
+
42
+ def build_model_from_cfg(cfg, only_teacher=False):
43
+ return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
src/lari/model/dinov2/models/vision_transformer.py ADDED
@@ -0,0 +1,396 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ from functools import partial
11
+ import math
12
+ import logging
13
+ from typing import Sequence, Tuple, Union, Callable
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.utils.checkpoint
18
+ from torch.nn.init import trunc_normal_
19
+
20
+ from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
21
+
22
+
23
+ logger = logging.getLogger("dinov2")
24
+
25
+
26
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
27
+ if not depth_first and include_root:
28
+ fn(module=module, name=name)
29
+ for child_name, child_module in module.named_children():
30
+ child_name = ".".join((name, child_name)) if name else child_name
31
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
32
+ if depth_first and include_root:
33
+ fn(module=module, name=name)
34
+ return module
35
+
36
+
37
+ class BlockChunk(nn.ModuleList):
38
+ def forward(self, x):
39
+ for b in self:
40
+ x = b(x)
41
+ return x
42
+
43
+
44
+ class DinoVisionTransformer(nn.Module):
45
+ def __init__(
46
+ self,
47
+ img_size=224,
48
+ patch_size=16,
49
+ in_chans=3,
50
+ embed_dim=768,
51
+ depth=12,
52
+ num_heads=12,
53
+ mlp_ratio=4.0,
54
+ qkv_bias=True,
55
+ ffn_bias=True,
56
+ proj_bias=True,
57
+ drop_path_rate=0.0,
58
+ drop_path_uniform=False,
59
+ init_values=None, # for layerscale: None or 0 => no layerscale
60
+ embed_layer=PatchEmbed,
61
+ act_layer=nn.GELU,
62
+ block_fn=Block,
63
+ ffn_layer="mlp",
64
+ block_chunks=1,
65
+ num_register_tokens=0,
66
+ interpolate_antialias=False,
67
+ interpolate_offset=0.1,
68
+ ):
69
+ """
70
+ Args:
71
+ img_size (int, tuple): input image size
72
+ patch_size (int, tuple): patch size
73
+ in_chans (int): number of input channels
74
+ embed_dim (int): embedding dimension
75
+ depth (int): depth of transformer
76
+ num_heads (int): number of attention heads
77
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
78
+ qkv_bias (bool): enable bias for qkv if True
79
+ proj_bias (bool): enable bias for proj in attn if True
80
+ ffn_bias (bool): enable bias for ffn if True
81
+ drop_path_rate (float): stochastic depth rate
82
+ drop_path_uniform (bool): apply uniform drop rate across blocks
83
+ weight_init (str): weight init scheme
84
+ init_values (float): layer-scale init values
85
+ embed_layer (nn.Module): patch embedding layer
86
+ act_layer (nn.Module): MLP activation layer
87
+ block_fn (nn.Module): transformer block class
88
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
89
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
90
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
91
+ interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
92
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
93
+ """
94
+ super().__init__()
95
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
96
+
97
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
98
+ self.num_tokens = 1
99
+ self.n_blocks = depth
100
+ self.num_heads = num_heads
101
+ self.patch_size = patch_size
102
+ self.num_register_tokens = num_register_tokens
103
+ self.interpolate_antialias = interpolate_antialias
104
+ self.interpolate_offset = interpolate_offset
105
+
106
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
107
+ num_patches = self.patch_embed.num_patches
108
+
109
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
110
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
111
+ assert num_register_tokens >= 0
112
+ self.register_tokens = (
113
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
114
+ )
115
+
116
+ if drop_path_uniform is True:
117
+ dpr = [drop_path_rate] * depth
118
+ else:
119
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
120
+
121
+ if ffn_layer == "mlp":
122
+ logger.info("using MLP layer as FFN")
123
+ ffn_layer = Mlp
124
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
125
+ logger.info("using SwiGLU layer as FFN")
126
+ ffn_layer = SwiGLUFFNFused
127
+ elif ffn_layer == "identity":
128
+ logger.info("using Identity layer as FFN")
129
+
130
+ def f(*args, **kwargs):
131
+ return nn.Identity()
132
+
133
+ ffn_layer = f
134
+ else:
135
+ raise NotImplementedError
136
+
137
+ blocks_list = [
138
+ block_fn(
139
+ dim=embed_dim,
140
+ num_heads=num_heads,
141
+ mlp_ratio=mlp_ratio,
142
+ qkv_bias=qkv_bias,
143
+ proj_bias=proj_bias,
144
+ ffn_bias=ffn_bias,
145
+ drop_path=dpr[i],
146
+ norm_layer=norm_layer,
147
+ act_layer=act_layer,
148
+ ffn_layer=ffn_layer,
149
+ init_values=init_values,
150
+ )
151
+ for i in range(depth)
152
+ ]
153
+ if block_chunks > 0:
154
+ self.chunked_blocks = True
155
+ chunked_blocks = []
156
+ chunksize = depth // block_chunks
157
+ for i in range(0, depth, chunksize):
158
+ # this is to keep the block index consistent if we chunk the block list
159
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
160
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
161
+ else:
162
+ self.chunked_blocks = False
163
+ self.blocks = nn.ModuleList(blocks_list)
164
+
165
+ self.norm = norm_layer(embed_dim)
166
+ self.head = nn.Identity()
167
+
168
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
169
+
170
+ self.init_weights()
171
+
172
+ def init_weights(self):
173
+ trunc_normal_(self.pos_embed, std=0.02)
174
+ nn.init.normal_(self.cls_token, std=1e-6)
175
+ if self.register_tokens is not None:
176
+ nn.init.normal_(self.register_tokens, std=1e-6)
177
+ named_apply(init_weights_vit_timm, self)
178
+
179
+ def interpolate_pos_encoding(self, x, w, h):
180
+ previous_dtype = x.dtype
181
+ npatch = x.shape[1] - 1
182
+ N = self.pos_embed.shape[1] - 1
183
+ if npatch == N and w == h:
184
+ return self.pos_embed
185
+ pos_embed = self.pos_embed.float()
186
+ class_pos_embed = pos_embed[:, 0]
187
+ patch_pos_embed = pos_embed[:, 1:]
188
+ dim = x.shape[-1]
189
+ w0 = w // self.patch_size
190
+ h0 = h // self.patch_size
191
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
192
+ assert N == M * M
193
+ kwargs = {}
194
+ if self.interpolate_offset:
195
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
196
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
197
+ sx = float(w0 + self.interpolate_offset) / M
198
+ sy = float(h0 + self.interpolate_offset) / M
199
+ kwargs["scale_factor"] = (sx, sy)
200
+ else:
201
+ # Simply specify an output size instead of a scale factor
202
+ kwargs["size"] = (w0, h0)
203
+ patch_pos_embed = nn.functional.interpolate(
204
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
205
+ mode="bicubic",
206
+ antialias=self.interpolate_antialias,
207
+ **kwargs,
208
+ )
209
+ assert (w0, h0) == patch_pos_embed.shape[-2:]
210
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
211
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
212
+
213
+ def prepare_tokens_with_masks(self, x, masks=None):
214
+ B, nc, w, h = x.shape
215
+ x = self.patch_embed(x)
216
+ if masks is not None:
217
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
218
+
219
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
220
+ x = x + self.interpolate_pos_encoding(x, w, h)
221
+
222
+ if self.register_tokens is not None:
223
+ x = torch.cat(
224
+ (
225
+ x[:, :1],
226
+ self.register_tokens.expand(x.shape[0], -1, -1),
227
+ x[:, 1:],
228
+ ),
229
+ dim=1,
230
+ )
231
+
232
+ return x
233
+
234
+ def forward_features_list(self, x_list, masks_list):
235
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
236
+ for blk in self.blocks:
237
+ x = blk(x)
238
+
239
+ all_x = x
240
+ output = []
241
+ for x, masks in zip(all_x, masks_list):
242
+ x_norm = self.norm(x)
243
+ output.append(
244
+ {
245
+ "x_norm_clstoken": x_norm[:, 0],
246
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
247
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
248
+ "x_prenorm": x,
249
+ "masks": masks,
250
+ }
251
+ )
252
+ return output
253
+
254
+ def forward_features(self, x, masks=None):
255
+ if isinstance(x, list):
256
+ return self.forward_features_list(x, masks)
257
+
258
+ x = self.prepare_tokens_with_masks(x, masks)
259
+
260
+ for blk in self.blocks:
261
+ x = blk(x)
262
+
263
+ x_norm = self.norm(x)
264
+ return {
265
+ "x_norm_clstoken": x_norm[:, 0],
266
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
267
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
268
+ "x_prenorm": x,
269
+ "masks": masks,
270
+ }
271
+
272
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
273
+ x = self.prepare_tokens_with_masks(x)
274
+ # If n is an int, take the n last blocks. If it's a list, take them
275
+ output, total_block_len = [], len(self.blocks)
276
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
277
+ for i, blk in enumerate(self.blocks):
278
+ x = blk(x)
279
+ if i in blocks_to_take:
280
+ output.append(x)
281
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
282
+ return output
283
+
284
+ def _get_intermediate_layers_chunked(self, x, n=1):
285
+ x = self.prepare_tokens_with_masks(x)
286
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
287
+ # If n is an int, take the n last blocks. If it's a list, take them
288
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
289
+ for block_chunk in self.blocks:
290
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
291
+ x = blk(x)
292
+ if i in blocks_to_take:
293
+ output.append(x)
294
+ i += 1
295
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
296
+ return output
297
+
298
+ def get_intermediate_layers(
299
+ self,
300
+ x: torch.Tensor,
301
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
302
+ reshape: bool = False,
303
+ return_class_token: bool = False,
304
+ norm=True,
305
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
306
+ if self.chunked_blocks:
307
+ outputs = self._get_intermediate_layers_chunked(x, n)
308
+ else:
309
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
310
+ if norm:
311
+ outputs = [self.norm(out) for out in outputs]
312
+ class_tokens = [out[:, 0] for out in outputs]
313
+ outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
314
+ if reshape:
315
+ B, _, w, h = x.shape
316
+ outputs = [
317
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
318
+ for out in outputs
319
+ ]
320
+ if return_class_token:
321
+ return tuple(zip(outputs, class_tokens))
322
+ return tuple(outputs)
323
+
324
+ def forward(self, *args, is_training=False, **kwargs):
325
+ ret = self.forward_features(*args, **kwargs)
326
+ if is_training:
327
+ return ret
328
+ else:
329
+ return self.head(ret["x_norm_clstoken"])
330
+
331
+
332
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
333
+ """ViT weight initialization, original timm impl (for reproducibility)"""
334
+ if isinstance(module, nn.Linear):
335
+ trunc_normal_(module.weight, std=0.02)
336
+ if module.bias is not None:
337
+ nn.init.zeros_(module.bias)
338
+
339
+
340
+ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
341
+ model = DinoVisionTransformer(
342
+ patch_size=patch_size,
343
+ embed_dim=384,
344
+ depth=12,
345
+ num_heads=6,
346
+ mlp_ratio=4,
347
+ block_fn=partial(Block, attn_class=MemEffAttention),
348
+ num_register_tokens=num_register_tokens,
349
+ **kwargs,
350
+ )
351
+ return model
352
+
353
+
354
+ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
355
+ model = DinoVisionTransformer(
356
+ patch_size=patch_size,
357
+ embed_dim=768,
358
+ depth=12,
359
+ num_heads=12,
360
+ mlp_ratio=4,
361
+ block_fn=partial(Block, attn_class=MemEffAttention),
362
+ num_register_tokens=num_register_tokens,
363
+ **kwargs,
364
+ )
365
+ return model
366
+
367
+
368
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
369
+ model = DinoVisionTransformer(
370
+ patch_size=patch_size,
371
+ embed_dim=1024,
372
+ depth=24,
373
+ num_heads=16,
374
+ mlp_ratio=4,
375
+ block_fn=partial(Block, attn_class=MemEffAttention),
376
+ num_register_tokens=num_register_tokens,
377
+ **kwargs,
378
+ )
379
+ return model
380
+
381
+
382
+ def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
383
+ """
384
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
385
+ """
386
+ model = DinoVisionTransformer(
387
+ patch_size=patch_size,
388
+ embed_dim=1536,
389
+ depth=40,
390
+ num_heads=24,
391
+ mlp_ratio=4,
392
+ block_fn=partial(Block, attn_class=MemEffAttention),
393
+ num_register_tokens=num_register_tokens,
394
+ **kwargs,
395
+ )
396
+ return model
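
Illustrative only (not part of the commit): this is roughly how the LaRI code queries the backbone, shown here with a randomly initialised ViT-S/14 and hypothetical sizes.

```python
import torch

vit = vit_small(patch_size=14, img_size=224, block_chunks=0, init_values=1.0)
vit.eval()

img = torch.randn(1, 3, 224, 224)       # H and W must be multiples of the patch size
with torch.no_grad():
    (patch_tokens, cls_token), = vit.get_intermediate_layers(img, n=1, return_class_token=True)
print(patch_tokens.shape)               # (1, 256, 384): a 16x16 patch grid
print(cls_token.shape)                  # (1, 384)
```
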
src/lari/model/dinov2/utils/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
src/lari/model/dinov2/utils/cluster.py ADDED
@@ -0,0 +1,95 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from enum import Enum
7
+ import os
8
+ from pathlib import Path
9
+ from typing import Any, Dict, Optional
10
+
11
+
12
+ class ClusterType(Enum):
13
+ AWS = "aws"
14
+ FAIR = "fair"
15
+ RSC = "rsc"
16
+
17
+
18
+ def _guess_cluster_type() -> ClusterType:
19
+ uname = os.uname()
20
+ if uname.sysname == "Linux":
21
+ if uname.release.endswith("-aws"):
22
+ # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
23
+ return ClusterType.AWS
24
+ elif uname.nodename.startswith("rsc"):
25
+ # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
26
+ return ClusterType.RSC
27
+
28
+ return ClusterType.FAIR
29
+
30
+
31
+ def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
32
+ if cluster_type is None:
33
+ return _guess_cluster_type()
34
+
35
+ return cluster_type
36
+
37
+
38
+ def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
39
+ cluster_type = get_cluster_type(cluster_type)
40
+ if cluster_type is None:
41
+ return None
42
+
43
+ CHECKPOINT_DIRNAMES = {
44
+ ClusterType.AWS: "checkpoints",
45
+ ClusterType.FAIR: "checkpoint",
46
+ ClusterType.RSC: "checkpoint/dino",
47
+ }
48
+ return Path("/") / CHECKPOINT_DIRNAMES[cluster_type]
49
+
50
+
51
+ def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
52
+ checkpoint_path = get_checkpoint_path(cluster_type)
53
+ if checkpoint_path is None:
54
+ return None
55
+
56
+ username = os.environ.get("USER")
57
+ assert username is not None
58
+ return checkpoint_path / username
59
+
60
+
61
+ def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
62
+ cluster_type = get_cluster_type(cluster_type)
63
+ if cluster_type is None:
64
+ return None
65
+
66
+ SLURM_PARTITIONS = {
67
+ ClusterType.AWS: "learnlab",
68
+ ClusterType.FAIR: "learnlab",
69
+ ClusterType.RSC: "learn",
70
+ }
71
+ return SLURM_PARTITIONS[cluster_type]
72
+
73
+
74
+ def get_slurm_executor_parameters(
75
+ nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
76
+ ) -> Dict[str, Any]:
77
+ # create default parameters
78
+ params = {
79
+ "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
80
+ "gpus_per_node": num_gpus_per_node,
81
+ "tasks_per_node": num_gpus_per_node, # one task per GPU
82
+ "cpus_per_task": 10,
83
+ "nodes": nodes,
84
+ "slurm_partition": get_slurm_partition(cluster_type),
85
+ }
86
+ # apply cluster-specific adjustments
87
+ cluster_type = get_cluster_type(cluster_type)
88
+ if cluster_type == ClusterType.AWS:
89
+ params["cpus_per_task"] = 12
90
+ del params["mem_gb"]
91
+ elif cluster_type == ClusterType.RSC:
92
+ params["cpus_per_task"] = 12
93
+ # set additional parameters / apply overrides
94
+ params.update(kwargs)
95
+ return params
src/lari/model/dinov2/utils/config.py ADDED
@@ -0,0 +1,72 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import logging
8
+ import os
9
+
10
+ from omegaconf import OmegaConf
11
+
12
+ import dinov2.distributed as distributed
13
+ from dinov2.logging import setup_logging
14
+ from dinov2.utils import utils
15
+ from dinov2.configs import dinov2_default_config
16
+
17
+
18
+ logger = logging.getLogger("dinov2")
19
+
20
+
21
+ def apply_scaling_rules_to_cfg(cfg): # to fix
22
+ if cfg.optim.scaling_rule == "sqrt_wrt_1024":
23
+ base_lr = cfg.optim.base_lr
24
+ cfg.optim.lr = base_lr
25
+ cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0)
26
+ logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
27
+ else:
28
+ raise NotImplementedError
29
+ return cfg
30
+
31
+
32
+ def write_config(cfg, output_dir, name="config.yaml"):
33
+ logger.info(OmegaConf.to_yaml(cfg))
34
+ saved_cfg_path = os.path.join(output_dir, name)
35
+ with open(saved_cfg_path, "w") as f:
36
+ OmegaConf.save(config=cfg, f=f)
37
+ return saved_cfg_path
38
+
39
+
40
+ def get_cfg_from_args(args):
41
+ args.output_dir = os.path.abspath(args.output_dir)
42
+ args.opts += [f"train.output_dir={args.output_dir}"]
43
+ default_cfg = OmegaConf.create(dinov2_default_config)
44
+ cfg = OmegaConf.load(args.config_file)
45
+ cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
46
+ return cfg
47
+
48
+
49
+ def default_setup(args):
50
+ distributed.enable(overwrite=True)
51
+ seed = getattr(args, "seed", 0)
52
+ rank = distributed.get_global_rank()
53
+
54
+ global logger
55
+ setup_logging(output=args.output_dir, level=logging.INFO)
56
+ logger = logging.getLogger("dinov2")
57
+
58
+ utils.fix_random_seeds(seed + rank)
59
+ logger.info("git:\n {}\n".format(utils.get_sha()))
60
+ logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
61
+
62
+
63
+ def setup(args):
64
+ """
65
+ Create configs and perform basic setups.
66
+ """
67
+ cfg = get_cfg_from_args(args)
68
+ os.makedirs(args.output_dir, exist_ok=True)
69
+ default_setup(args)
70
+ apply_scaling_rules_to_cfg(cfg)
71
+ write_config(cfg, args.output_dir)
72
+ return cfg
src/lari/model/dinov2/utils/dtype.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ from typing import Dict, Union
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+
13
+ TypeSpec = Union[str, np.dtype, torch.dtype]
14
+
15
+
16
+ _NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
17
+ np.dtype("bool"): torch.bool,
18
+ np.dtype("uint8"): torch.uint8,
19
+ np.dtype("int8"): torch.int8,
20
+ np.dtype("int16"): torch.int16,
21
+ np.dtype("int32"): torch.int32,
22
+ np.dtype("int64"): torch.int64,
23
+ np.dtype("float16"): torch.float16,
24
+ np.dtype("float32"): torch.float32,
25
+ np.dtype("float64"): torch.float64,
26
+ np.dtype("complex64"): torch.complex64,
27
+ np.dtype("complex128"): torch.complex128,
28
+ }
29
+
30
+
31
+ def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
32
+ if isinstance(dtype, torch.dtype):
33
+ return dtype
34
+ if isinstance(dtype, str):
35
+ dtype = np.dtype(dtype)
36
+ assert isinstance(dtype, np.dtype), f"Expected an instance of numpy dtype, got {type(dtype)}"
37
+ return _NUMPY_TO_TORCH_DTYPE[dtype]
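
Illustrative only: the helper above normalises string, NumPy, and torch dtype specs onto torch dtypes.

```python
import numpy as np
import torch

assert as_torch_dtype("float32") is torch.float32
assert as_torch_dtype(np.dtype("int64")) is torch.int64
assert as_torch_dtype(torch.bfloat16) is torch.bfloat16   # torch dtypes pass through unchanged
```
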
src/lari/model/dinov2/utils/param_groups.py ADDED
@@ -0,0 +1,103 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from collections import defaultdict
7
+ import logging
8
+
9
+
10
+ logger = logging.getLogger("dinov2")
11
+
12
+
13
+ def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False):
14
+ """
15
+ Calculate lr decay rate for different ViT blocks.
16
+ Args:
17
+ name (string): parameter name.
18
+ lr_decay_rate (float): base lr decay rate.
19
+ num_layers (int): number of ViT blocks.
20
+ Returns:
21
+ lr decay rate for the given parameter.
22
+ """
23
+ layer_id = num_layers + 1
24
+ if name.startswith("backbone") or force_is_backbone:
25
+ if (
26
+ ".pos_embed" in name
27
+ or ".patch_embed" in name
28
+ or ".mask_token" in name
29
+ or ".cls_token" in name
30
+ or ".register_tokens" in name
31
+ ):
32
+ layer_id = 0
33
+ elif force_is_backbone and (
34
+ "pos_embed" in name
35
+ or "patch_embed" in name
36
+ or "mask_token" in name
37
+ or "cls_token" in name
38
+ or "register_tokens" in name
39
+ ):
40
+ layer_id = 0
41
+ elif ".blocks." in name and ".residual." not in name:
42
+ layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
43
+ elif chunked_blocks and "blocks." in name and "residual." not in name:
44
+ layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1
45
+ elif "blocks." in name and "residual." not in name:
46
+ layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1
47
+
48
+ return lr_decay_rate ** (num_layers + 1 - layer_id)
49
+
50
+
51
+ def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
52
+ chunked_blocks = False
53
+ if hasattr(model, "n_blocks"):
54
+ logger.info("chunked fsdp")
55
+ n_blocks = model.n_blocks
56
+ chunked_blocks = model.chunked_blocks
57
+ elif hasattr(model, "blocks"):
58
+ logger.info("first code branch")
59
+ n_blocks = len(model.blocks)
60
+ elif hasattr(model, "backbone"):
61
+ logger.info("second code branch")
62
+ n_blocks = len(model.backbone.blocks)
63
+ else:
64
+ logger.info("else code branch")
65
+ n_blocks = 0
66
+ all_param_groups = []
67
+
68
+ for name, param in model.named_parameters():
69
+ name = name.replace("_fsdp_wrapped_module.", "")
70
+ if not param.requires_grad:
71
+ continue
72
+ decay_rate = get_vit_lr_decay_rate(
73
+ name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks
74
+ )
75
+ d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name}
76
+
77
+ if "last_layer" in name:
78
+ d.update({"is_last_layer": True})
79
+
80
+ if name.endswith(".bias") or "norm" in name or "gamma" in name:
81
+ d.update({"wd_multiplier": 0.0})
82
+
83
+ if "patch_embed" in name:
84
+ d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult})
85
+
86
+ all_param_groups.append(d)
87
+ logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""")
88
+
89
+ return all_param_groups
90
+
91
+
92
+ def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
93
+ fused_params_groups = defaultdict(lambda: {"params": []})
94
+ for d in all_params_groups:
95
+ identifier = ""
96
+ for k in keys:
97
+ identifier += k + str(d[k]) + "_"
98
+
99
+ for k in keys:
100
+ fused_params_groups[identifier][k] = d[k]
101
+ fused_params_groups[identifier]["params"].append(d["params"])
102
+
103
+ return fused_params_groups.values()
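
Illustrative only (hypothetical parameter names): the decay rate grows with block index, so the earliest layers get the smallest learning-rate multiplier while non-backbone parameters keep the full rate.

```python
r_embed = get_vit_lr_decay_rate("backbone.patch_embed.proj.weight", lr_decay_rate=0.9, num_layers=12)
r_block0 = get_vit_lr_decay_rate("backbone.blocks.0.attn.qkv.weight", lr_decay_rate=0.9, num_layers=12)
r_head = get_vit_lr_decay_rate("head.weight", lr_decay_rate=0.9, num_layers=12)
print(r_embed, r_block0, r_head)        # 0.9**13, 0.9**12, 1.0
```
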
src/lari/model/dinov2/utils/utils.py ADDED
@@ -0,0 +1,95 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import os
8
+ import random
9
+ import subprocess
10
+ from urllib.parse import urlparse
11
+
12
+ import numpy as np
13
+ import torch
14
+ from torch import nn
15
+
16
+
17
+ logger = logging.getLogger("dinov2")
18
+
19
+
20
+ def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
21
+ if urlparse(pretrained_weights).scheme: # If it looks like an URL
22
+ state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
23
+ else:
24
+ state_dict = torch.load(pretrained_weights, map_location="cpu")
25
+ if checkpoint_key is not None and checkpoint_key in state_dict:
26
+ logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
27
+ state_dict = state_dict[checkpoint_key]
28
+ # remove `module.` prefix
29
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
30
+ # remove `backbone.` prefix induced by multicrop wrapper
31
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
32
+ msg = model.load_state_dict(state_dict, strict=False)
33
+ logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
34
+
35
+
36
+ def fix_random_seeds(seed=31):
37
+ """
38
+ Fix random seeds.
39
+ """
40
+ torch.manual_seed(seed)
41
+ torch.cuda.manual_seed_all(seed)
42
+ np.random.seed(seed)
43
+ random.seed(seed)
44
+
45
+
46
+ def get_sha():
47
+ cwd = os.path.dirname(os.path.abspath(__file__))
48
+
49
+ def _run(command):
50
+ return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
51
+
52
+ sha = "N/A"
53
+ diff = "clean"
54
+ branch = "N/A"
55
+ try:
56
+ sha = _run(["git", "rev-parse", "HEAD"])
57
+ subprocess.check_output(["git", "diff"], cwd=cwd)
58
+ diff = _run(["git", "diff-index", "HEAD"])
59
+ diff = "has uncommitted changes" if diff else "clean"
60
+ branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
61
+ except Exception:
62
+ pass
63
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
64
+ return message
65
+
66
+
67
+ class CosineScheduler(object):
68
+ def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
69
+ super().__init__()
70
+ self.final_value = final_value
71
+ self.total_iters = total_iters
72
+
73
+ freeze_schedule = np.zeros((freeze_iters))
74
+
75
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
76
+
77
+ iters = np.arange(total_iters - warmup_iters - freeze_iters)
78
+ schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
79
+ self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule))
80
+
81
+ assert len(self.schedule) == self.total_iters
82
+
83
+ def __getitem__(self, it):
84
+ if it >= self.total_iters:
85
+ return self.final_value
86
+ else:
87
+ return self.schedule[it]
88
+
89
+
90
+ def has_batchnorms(model):
91
+ bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
92
+ for name, module in model.named_modules():
93
+ if isinstance(module, bn_types):
94
+ return True
95
+ return False
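
Illustrative only (hypothetical numbers): the scheduler above indexes like an array and clamps to final_value past total_iters.

```python
sched = CosineScheduler(base_value=1e-3, final_value=1e-5, total_iters=1000, warmup_iters=100)
print(sched[0])      # 0.0   -> start of the linear warmup
print(sched[100])    # 1e-3  -> warmup done, cosine decay starts from the base value
print(sched[2000])   # 1e-5  -> clamped to final_value beyond total_iters
```
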
src/lari/model/dpt_seg_head.py ADDED
@@ -0,0 +1,158 @@
1
+ '''
2
+ This code is adapted from Depth Anything and DPT.
3
+ '''
4
+ from src.lari.model.blocks import FeatureFusionBlock, _make_scratch
5
+ import cv2
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torchvision.transforms import Compose
10
+
11
+
12
+
13
+
14
+ def _make_fusion_block(features, use_bn, size=None):
15
+ return FeatureFusionBlock(
16
+ features,
17
+ nn.ReLU(False),
18
+ deconv=False,
19
+ bn=use_bn,
20
+ expand=False,
21
+ align_corners=True,
22
+ size=size,
23
+ )
24
+
25
+
26
+ class ConvBlock(nn.Module):
27
+ def __init__(self, in_feature, out_feature):
28
+ super().__init__()
29
+
30
+ self.conv_block = nn.Sequential(
31
+ nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
32
+ nn.BatchNorm2d(out_feature),
33
+ nn.ReLU(True)
34
+ )
35
+
36
+ def forward(self, x):
37
+ return self.conv_block(x)
38
+
39
+
40
+ class DPTSegHead(nn.Module):
41
+ def __init__(
42
+ self,
43
+ in_channels,
44
+ features=256,
45
+ use_bn=False,
46
+ out_channels=[256, 512, 1024, 1024],
47
+ use_clstoken=False,
48
+ num_classes = 5,
49
+ output_type = "ray_stop" # "seg_sep"
50
+ ):
51
+ super(DPTSegHead, self).__init__()
52
+
53
+ self.use_clstoken = use_clstoken
54
+ self.output_type = output_type
55
+
56
+ # output one extra class (index 0) to mark invalid ray-stopping points
57
+ self.num_classes = num_classes + 1 if self.output_type == "ray_stop" else num_classes
58
+
59
+
60
+ self.projects = nn.ModuleList([
61
+ nn.Conv2d(
62
+ in_channels=in_channels,
63
+ out_channels=out_channel,
64
+ kernel_size=1,
65
+ stride=1,
66
+ padding=0,
67
+ ) for out_channel in out_channels
68
+ ])
69
+
70
+ self.resize_layers = nn.ModuleList([
71
+ nn.ConvTranspose2d(
72
+ in_channels=out_channels[0],
73
+ out_channels=out_channels[0],
74
+ kernel_size=4,
75
+ stride=4,
76
+ padding=0),
77
+ nn.ConvTranspose2d(
78
+ in_channels=out_channels[1],
79
+ out_channels=out_channels[1],
80
+ kernel_size=2,
81
+ stride=2,
82
+ padding=0),
83
+ nn.Identity(),
84
+ nn.Conv2d(
85
+ in_channels=out_channels[3],
86
+ out_channels=out_channels[3],
87
+ kernel_size=3,
88
+ stride=2,
89
+ padding=1)
90
+ ])
91
+
92
+ if use_clstoken:
93
+ self.readout_projects = nn.ModuleList()
94
+ for _ in range(len(self.projects)):
95
+ self.readout_projects.append(
96
+ nn.Sequential(
97
+ nn.Linear(2 * in_channels, in_channels),
98
+ nn.GELU()))
99
+
100
+ self.scratch = _make_scratch(
101
+ out_channels,
102
+ features,
103
+ groups=1,
104
+ expand=False,
105
+ )
106
+
107
+ self.scratch.stem_transpose = None
108
+
109
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
110
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
111
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
112
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
113
+
114
+ self.scratch.output_conv1 = nn.Sequential(
115
+ nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
116
+ nn.BatchNorm2d(features),
117
+ nn.ReLU(True),
118
+ nn.Dropout(0.1, False),
119
+ nn.Conv2d(features, self.num_classes, kernel_size=1),
120
+ )
121
+
122
+
123
+
124
+ def forward(self, out_features, patch_h, patch_w):
125
+ out = []
126
+ for i, x in enumerate(out_features):
127
+ if self.use_clstoken:
128
+ x, cls_token = x[0], x[1]
129
+ readout = cls_token.unsqueeze(1).expand_as(x)
130
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
131
+ else:
132
+ x = x[0]
133
+
134
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
135
+
136
+ x = self.projects[i](x)
137
+ x = self.resize_layers[i](x)
138
+
139
+ out.append(x)
140
+
141
+ layer_1, layer_2, layer_3, layer_4 = out
142
+
143
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
144
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
145
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
146
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
147
+
148
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
149
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
150
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
151
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
152
+
153
+ out = self.scratch.output_conv1(path_1)
154
+
155
+ # B C H W - segmentaton logits
156
+ out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
157
+
158
+ return out
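
A hedged sketch of the head's expected interface (not part of the commit; it assumes FeatureFusionBlock and _make_scratch in blocks.py follow the Depth Anything conventions, and the sizes are hypothetical):

```python
import torch

B, C, ph, pw = 1, 1024, 16, 16          # ViT-L/14 features for a 224x224 input
head = DPTSegHead(in_channels=C, features=256, out_channels=[256, 512, 1024, 1024], num_classes=5)
head.eval()

# Four (patch_tokens, cls_token) pairs, as returned by get_intermediate_layers(..., return_class_token=True)
feats = [(torch.randn(B, ph * pw, C), torch.randn(B, C)) for _ in range(4)]
with torch.no_grad():
    logits = head(feats, ph, pw)
print(logits.shape)                      # (1, 6, 224, 224): num_classes + 1 for the "ray_stop" output
```
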
src/lari/model/heads.py ADDED
@@ -0,0 +1,104 @@
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.utils
6
+ import torch.utils.checkpoint
7
+ import torch.version
8
+ from typing import *
9
+ import os
10
+ import sys
11
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
12
+ from src.lari.model.blocks import ResidualConvBlock, make_upsampler, make_output_block
13
+ from src.lari.utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift, gaussian_blur_2d
14
+
15
+
16
+ class PointHead(nn.Module):
17
+ def __init__(
18
+ self,
19
+ num_features: int,
20
+ dim_in: int,
21
+ dim_out: int,
22
+ dim_proj: int = 512,
23
+ dim_upsample: List[int] = [256, 128, 128],
24
+ dim_times_res_block_hidden: int = 1,
25
+ num_res_blocks: int = 1,
26
+ res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm',
27
+ last_res_blocks: int = 0,
28
+ last_conv_channels: int = 32,
29
+ last_conv_size: int = 1,
30
+ num_output_layer: int = 5
31
+ ):
32
+ super().__init__()
33
+
34
+ self.num_output_layer = num_output_layer
35
+
36
+ self.projects = nn.ModuleList([
37
+ nn.Conv2d(in_channels=dim_in, out_channels=dim_proj, kernel_size=1, stride=1, padding=0,) for _ in range(num_features)
38
+ ])
39
+
40
+ self.upsample_blocks = nn.ModuleList([
41
+ nn.Sequential(
42
+ make_upsampler(in_ch + 2, out_ch),
43
+ *(ResidualConvBlock(out_ch, out_ch, dim_times_res_block_hidden * out_ch, activation="relu", norm=res_block_norm) for _ in range(num_res_blocks))
44
+ ) for in_ch, out_ch in zip([dim_proj] + dim_upsample[:-1], dim_upsample)
45
+ ])
46
+
47
+ # layer iterations
48
+ self.first_layer_block = make_output_block(dim_upsample[-1] + 2, dim_out,
49
+ dim_times_res_block_hidden, last_res_blocks, last_conv_channels, last_conv_size, res_block_norm,)
50
+
51
+ self.remaining_layer_block = nn.ModuleList([make_output_block(dim_upsample[-1] + 2, dim_out,
52
+ dim_times_res_block_hidden, last_res_blocks, last_conv_channels, last_conv_size, res_block_norm,)
53
+ for _ in range(self.num_output_layer - 1)])
54
+
55
+
56
+
57
+ def forward(self, hidden_states: torch.Tensor, image: torch.Tensor):
58
+ img_h, img_w = image.shape[-2:]
59
+ patch_h, patch_w = img_h // 14, img_w // 14
60
+
61
+ # Process the hidden states
62
+ x = torch.stack([
63
+ proj(feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous())
64
+ for proj, (feat, clstoken) in zip(self.projects, hidden_states)
65
+ ], dim=1).sum(dim=1)
66
+
67
+ # Upsample stage
68
+ # (patch_h, patch_w) -> (patch_h * 2, patch_w * 2) -> (patch_h * 4, patch_w * 4) -> (patch_h * 8, patch_w * 8)
69
+ for i, block in enumerate(self.upsample_blocks):
70
+ # UV coordinates make the network aware of the image aspect ratio
71
+ uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device)
72
+ uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
73
+ x = torch.cat([x, uv], dim=1)
74
+ for layer in block:
75
+ x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
76
+
77
+ # (patch_h * 8, patch_w * 8) -> (img_h, img_w)
78
+ x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False)
79
+ uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device)
80
+ uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
81
+ x = torch.cat([x, uv], dim=1)
82
+
83
+
84
+ pts_list = []
85
+ for layer_id in range(self.num_output_layer):
86
+ if layer_id == 0:
87
+ blocks = self.first_layer_block
88
+ else:
89
+ blocks = self.remaining_layer_block[layer_id-1]
90
+
91
+ # for each block
92
+ if isinstance(blocks, nn.ModuleList):
93
+ raise NotImplementedError()
94
+ else:
95
+ res = torch.utils.checkpoint.checkpoint(blocks, x, use_reentrant=False)[:,:3, :,:]
96
+ pts_list.append(res[:, :3, :,:])
97
+
98
+ pts = torch.stack(pts_list, dim=-1)
99
+ seg = pts.new_zeros(pts.shape)[:, :1, ...]
100
+
101
+ # <b 3 h w l>, <b 1 h w l>
102
+ output = [pts, seg]
103
+
104
+ return output
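
For context (not part of the commit): the UV grid concatenated to the features at every scale comes from normalized_view_plane_uv in src/lari/utils/geometry_torch.py; as used above it is expected to return an (H, W, 2) view-plane coordinate grid.

```python
import torch

uv = normalized_view_plane_uv(width=64, height=48, aspect_ratio=4 / 3, dtype=torch.float32, device="cpu")
print(uv.shape)                          # (48, 64, 2), matching the permute(2, 0, 1) call above
```
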
src/lari/model/lari_model.py ADDED
@@ -0,0 +1,177 @@
1
+ from typing import *
2
+ from numbers import Number
3
+ from functools import partial
4
+ from pathlib import Path
5
+ import importlib
6
+ import warnings
7
+ import json
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.utils
13
+ import torch.utils.checkpoint
14
+ import torch.version
15
+ from huggingface_hub import hf_hub_download
16
+ from src.lari.model.utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing
17
+ from src.lari.model.heads import PointHead
18
+
19
+
20
+ class LaRIModel(nn.Module):
21
+ image_mean: torch.Tensor
22
+ image_std: torch.Tensor
23
+
24
+ def __init__(self,
25
+ encoder: str = 'dinov2_vitl14',
26
+ intermediate_layers: Union[int, List[int]] = 4,
27
+ dim_proj: int = 512,
28
+ dim_upsample: List[int] = [256, 128, 64],
29
+ dim_times_res_block_hidden: int = 2,
30
+ num_res_blocks: int = 2,
31
+ output_mask: bool = True,
32
+ split_head: bool = True,
33
+ remap_output: Literal[False, True, 'linear', 'sinh', 'exp', 'sinh_exp'] = 'exp',
34
+ res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm',
35
+ last_res_blocks: int = 0,
36
+ last_conv_channels: int = 32,
37
+ last_conv_size: int = 1,
38
+ use_pretrained: Literal["dinov2", "moge_full", "moge_backbone", None] = None,
39
+ pretrained_path: str = "",
40
+ num_output_layer: str = None,
41
+ head_type = None,
42
+ **deprecated_kwargs
43
+ ):
44
+ super(LaRIModel, self).__init__()
45
+ if deprecated_kwargs:
46
+ warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}")
47
+
48
+ self.encoder = encoder
49
+ self.remap_output = remap_output
50
+ self.intermediate_layers = intermediate_layers
51
+ self.head_type = head_type
52
+ self.output_mask = output_mask
53
+ self.split_head = split_head
54
+ self.use_pretrained = use_pretrained
55
+ self.pretrained_path = pretrained_path
56
+ self.num_output_layer = num_output_layer
57
+
58
+ hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), encoder)
59
+ # hub_loader = getattr(importlib.import_module("dinov2.hub.backbones", __package__), encoder)
60
+
61
+ self.backbone = hub_loader(pretrained=True if self.use_pretrained == "dinov2" else False)
62
+ dim_feature = self.backbone.blocks[0].attn.qkv.in_features
63
+
64
+ if self.head_type == "point":
65
+ self.head = PointHead(
66
+ num_features=intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers),
67
+ dim_in=dim_feature,
68
+ dim_out=3,
69
+ dim_proj=dim_proj,
70
+ dim_upsample=dim_upsample,
71
+ dim_times_res_block_hidden=dim_times_res_block_hidden,
72
+ num_res_blocks=num_res_blocks,
73
+ res_block_norm=res_block_norm,
74
+ last_res_blocks=last_res_blocks,
75
+ last_conv_channels=last_conv_channels,
76
+ last_conv_size=last_conv_size,
77
+ num_output_layer = num_output_layer
78
+ )
79
+ else:
80
+ raise NotImplementedError()
81
+
82
+
83
+ if torch.__version__ >= '2.0':
84
+ self.enable_pytorch_native_sdpa()
85
+
86
+ self._load_pretrained()
87
+
88
+
89
+ def _load_pretrained(self):
90
+ '''
91
+ Load pre-trained weights
92
+ '''
93
+ if self.use_pretrained == "dinov2" or self.use_pretrained is None: return
94
+
95
+ if self.use_pretrained == "moge_full" and self.pretrained_path != "":
96
+ checkpoint = torch.load(self.pretrained_path, map_location='cpu', weights_only=True)
97
+ if self.head_type == "point":
98
+ key_transition_map = {"output_block": "first_layer_block"}
99
+ model_state_dict = {}
100
+
101
+ # change the key name of the dict
102
+ for key, val in checkpoint['model'].items():
103
+ for trans_src, trans_target in key_transition_map.items():
104
+ if trans_src in key:
105
+ model_state_dict[key.replace(trans_src, trans_target)] = val
106
+ else:
107
+ model_state_dict[key] = val
108
+
109
+ self.load_state_dict(model_state_dict, strict=False)
110
+ del model_state_dict
111
+
112
+
113
+ else:
114
+ return
115
+
116
+
117
+ @staticmethod
118
+ def cache_pretrained_backbone(encoder: str, pretrained: bool):
119
+ _ = torch.hub.load('facebookresearch/dinov2', encoder, pretrained=pretrained)
120
+
121
+ def load_pretrained_backbone(self):
122
+ "Load the backbone with pretrained dinov2 weights from torch hub"
123
+ state_dict = torch.hub.load('facebookresearch/dinov2', self.encoder, pretrained=True).state_dict()
124
+ self.backbone.load_state_dict(state_dict)
125
+
126
+ def enable_backbone_gradient_checkpointing(self):
127
+ for i in range(len(self.backbone.blocks)):
128
+ self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing(self.backbone.blocks[i])
129
+
130
+ def enable_pytorch_native_sdpa(self):
131
+ for i in range(len(self.backbone.blocks)):
132
+ self.backbone.blocks[i].attn = wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn)
133
+
134
+ def forward(self, image: torch.Tensor, mixed_precision: bool = False) -> Dict[str, torch.Tensor]:
135
+ raw_img_h, raw_img_w = image.shape[-2:]
136
+ patch_h, patch_w = raw_img_h // 14, raw_img_w // 14
137
+
138
+ # Apply image transformation for DINOv2
139
+ image_14 = F.interpolate(image, (patch_h * 14, patch_w * 14), mode="bilinear", align_corners=False, antialias=True)
140
+
141
+ # Get intermediate layers from the backbone
142
+ with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=mixed_precision):
143
+ features = self.backbone.get_intermediate_layers(image_14, self.intermediate_layers, return_class_token=True)
144
+
145
+ # Predict points and mask (mask scores)
146
+ points, mask = self.head(features, image)
147
+
148
+ is_output_prob = False
149
+ if mask.ndim == 5:
150
+ # <b, h, w, layer, 3>, <b, h, w, layer, 1>
151
+ points, mask = points.permute(0, 2, 3, 4, 1), mask.permute(0,2,3,4,1)
152
+ elif mask.ndim == 4: # <b, h, w, layer, 3>, <b, layer, h, w>
153
+ points = points.permute(0, 2, 3, 4, 1)
154
+ is_output_prob = True
155
+
156
+ if self.remap_output == 'linear' or self.remap_output == False:
157
+ pass
158
+ elif self.remap_output =='sinh' or self.remap_output == True:
159
+ points = torch.sinh(points)
160
+ elif self.remap_output == 'exp':
161
+ xy, z = points.split([2, 1], dim=-1)
162
+ z = torch.exp(z)
163
+ points = torch.cat([xy * z, z], dim=-1)
164
+ elif self.remap_output =='sinh_exp':
165
+ xy, z = points.split([2, 1], dim=-1)
166
+ points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1)
167
+ else:
168
+ raise ValueError(f"Invalid remap output type: {self.remap_output}")
169
+
170
+ return_dict = {'pts3d': points}
171
+
172
+ if not is_output_prob:
173
+ return_dict['mask'] = mask
174
+ else:
175
+ return_dict["seg_prob"] = mask
176
+
177
+ return return_dict
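For reference, the `remap_output` branches above can be exercised in isolation; a minimal sketch with a dummy point tensor (shapes and values are illustrative only, not tied to the trained model):

```python
import torch

# Dummy layered point map: (batch, H, W, layers, xyz), matching the permuted output above.
points = torch.randn(1, 4, 4, 2, 3)

def remap(points: torch.Tensor, mode) -> torch.Tensor:
    if mode in ('linear', False):
        return points
    if mode in ('sinh', True):
        return torch.sinh(points)
    if mode == 'exp':
        xy, z = points.split([2, 1], dim=-1)
        z = torch.exp(z)
        return torch.cat([xy * z, z], dim=-1)
    if mode == 'sinh_exp':
        xy, z = points.split([2, 1], dim=-1)
        return torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1)
    raise ValueError(f"Invalid remap output type: {mode}")

for mode in ('linear', 'sinh', 'exp', 'sinh_exp'):
    # Shape (1, 4, 4, 2, 3) is preserved; the 'exp' variants force the z channel to be positive.
    print(mode, remap(points, mode).shape)
```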
src/lari/model/utils.py ADDED
@@ -0,0 +1,38 @@
1
+ from typing import *
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ def wrap_module_with_gradient_checkpointing(module: nn.Module):
8
+ from torch.utils.checkpoint import checkpoint
9
+ class _CheckpointingWrapper(module.__class__):
10
+ _restore_cls = module.__class__
11
+ def forward(self, *args, **kwargs):
12
+ return checkpoint(super().forward, *args, use_reentrant=False, **kwargs)
13
+
14
+ module.__class__ = _CheckpointingWrapper
15
+ return module
16
+
17
+
18
+ def unwrap_module_with_gradient_checkpointing(module: nn.Module):
19
+ module.__class__ = module.__class__._restore_cls
20
+
21
+
22
+ def wrap_dinov2_attention_with_sdpa(module: nn.Module):
23
+ assert torch.__version__ >= '2.0', "SDPA requires PyTorch 2.0 or later"
24
+ class _AttentionWrapper(module.__class__):
25
+ def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
26
+ B, N, C = x.shape
27
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3, B, H, N, C // H)
28
+
29
+ q, k, v = torch.unbind(qkv, 0) # (B, H, N, C // H)
30
+
31
+ x = F.scaled_dot_product_attention(q, k, v, attn_bias)
32
+ x = x.permute(0, 2, 1, 3).reshape(B, N, C)
33
+
34
+ x = self.proj(x)
35
+ x = self.proj_drop(x)
36
+ return x
37
+ module.__class__ = _AttentionWrapper
38
+ return module
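A small usage sketch for the two wrappers above. The import path and the DINOv2 hub entry point are assumptions for illustration (they depend on how the repository root is placed on `sys.path`):

```python
import torch
# Hypothetical import path, assuming the repository root is on sys.path.
from src.lari.model.utils import (
    wrap_dinov2_attention_with_sdpa,
    wrap_module_with_gradient_checkpointing,
)

backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14', pretrained=False)

for i in range(len(backbone.blocks)):
    # Swap the attention forward for the PyTorch-native SDPA path (requires torch >= 2.0) ...
    backbone.blocks[i].attn = wrap_dinov2_attention_with_sdpa(backbone.blocks[i].attn)
    # ... and checkpoint each block to trade recomputation for activation memory during training.
    backbone.blocks[i] = wrap_module_with_gradient_checkpointing(backbone.blocks[i])
```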
src/lari/utils/__init__.py ADDED
File without changes
src/lari/utils/geometry_numpy.py ADDED
@@ -0,0 +1,187 @@
1
+ from typing import *
2
+ from functools import partial
3
+ import math
4
+
5
+ import numpy as np
6
+ import utils3d
7
+
8
+ def weighted_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
9
+ if w is None:
10
+ return np.mean(x, axis=axis)
11
+ else:
12
+ w = w.astype(x.dtype)
13
+ return (x * w).mean(axis=axis) / np.clip(w.mean(axis=axis), eps, None)
14
+
15
+
16
+ def harmonic_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
17
+ if w is None:
18
+ return 1 / (1 / np.clip(x, eps, None)).mean(axis=axis)
19
+ else:
20
+ w = w.astype(x.dtype)
21
+ return 1 / (weighted_mean_numpy(1 / (x + eps), w, axis=axis, keepdims=keepdims, eps=eps) + eps)
22
+
23
+
24
+ def normalized_view_plane_uv_numpy(width: int, height: int, aspect_ratio: float = None, dtype: np.dtype = np.float32) -> np.ndarray:
25
+ "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
26
+ if aspect_ratio is None:
27
+ aspect_ratio = width / height
28
+
29
+ span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
30
+ span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5
31
+
32
+ u = np.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype)
33
+ v = np.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype)
34
+ u, v = np.meshgrid(u, v, indexing='xy')
35
+ uv = np.stack([u, v], axis=-1)
36
+ return uv
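As a quick numeric check of the convention stated in the docstring (illustrative only, assuming the function above is in scope): for a 640x480 image the diagonal is 800, so the half-spans are 640/800 = 0.8 and 480/800 = 0.6.

```python
import numpy as np

uv = normalized_view_plane_uv_numpy(width=640, height=480)
print(uv.shape)      # (480, 640, 2)
print(uv[0, 0])      # approx (-0.8, -0.6); slightly inside the corner because values sit at pixel centers
print(uv[-1, -1])    # approx ( 0.8,  0.6)
```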
37
+
38
+
39
+ def focal_to_fov_numpy(focal: np.ndarray):
40
+ return 2 * np.arctan(0.5 / focal)
41
+
42
+
43
+ def fov_to_focal_numpy(fov: np.ndarray):
44
+ return 0.5 / np.tan(fov / 2)
45
+
46
+
47
+ def intrinsics_to_fov_numpy(intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
48
+ fov_x = focal_to_fov_numpy(intrinsics[..., 0, 0])
49
+ fov_y = focal_to_fov_numpy(intrinsics[..., 1, 1])
50
+ return fov_x, fov_y
51
+
52
+
53
+ def point_map_to_depth_legacy_numpy(points: np.ndarray):
54
+ height, width = points.shape[-3:-1]
55
+ diagonal = (height ** 2 + width ** 2) ** 0.5
56
+ uv = normalized_view_plane_uv_numpy(width, height, dtype=points.dtype) # (H, W, 2)
57
+ _, uv = np.broadcast_arrays(points[..., :2], uv)
58
+
59
+ # Solve least squares problem
60
+ b = (uv * points[..., 2:]).reshape(*points.shape[:-3], -1) # (..., H * W * 2)
61
+ A = np.stack([points[..., :2], -uv], axis=-1).reshape(*points.shape[:-3], -1, 2) # (..., H * W * 2, 2)
62
+
63
+ M = A.swapaxes(-2, -1) @ A
64
+ solution = (np.linalg.inv(M + 1e-6 * np.eye(2)) @ (A.swapaxes(-2, -1) @ b[..., None])).squeeze(-1)
65
+ focal, shift = solution
66
+
67
+ depth = points[..., 2] + shift[..., None, None]
68
+ fov_x = np.arctan(width / diagonal / focal) * 2
69
+ fov_y = np.arctan(height / diagonal / focal) * 2
70
+ return depth, fov_x, fov_y, shift
71
+
72
+
73
+ def solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray):
74
+ "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift and focal"
75
+ from scipy.optimize import least_squares
76
+ uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)
77
+
78
+ def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
79
+ xy_proj = xy / (z + shift)[: , None]
80
+ f = (xy_proj * uv).sum() / np.square(xy_proj).sum()
81
+ err = (f * xy_proj - uv).ravel()
82
+ return err
83
+
84
+ solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm')
85
+ optim_shift = solution['x'].squeeze().astype(np.float32)
86
+
87
+ xy_proj = xy / (z + optim_shift)[: , None]
88
+ optim_focal = (xy_proj * uv).sum() / np.square(xy_proj).sum()
89
+
90
+ return optim_shift, optim_focal
91
+
92
+
93
+ def solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float):
94
+ "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift"
95
+ from scipy.optimize import least_squares
96
+ uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)
97
+
98
+ def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
99
+ xy_proj = xy/ (z + shift)[: , None]
100
+ err = (focal * xy_proj - uv).ravel()
101
+ return err
102
+
103
+ solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm')
104
+ optim_shift = solution['x'].squeeze().astype(np.float32)
105
+
106
+ return optim_shift
107
+
108
+
109
+ def recover_focal_shift_numpy(points: np.ndarray, mask: np.ndarray = None, focal: float = None, downsample_size: Tuple[int, int] = (64, 64)):
110
+ import cv2
111
+ assert points.shape[-1] == 3, "Points should be of shape (H, W, 3)"
112
+
113
+ height, width = points.shape[-3], points.shape[-2]
114
+ diagonal = (height ** 2 + width ** 2) ** 0.5
115
+
116
+ uv = normalized_view_plane_uv_numpy(width=width, height=height)
117
+
118
+ if mask is None:
119
+ points_lr = cv2.resize(points, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 3)
120
+ uv_lr = cv2.resize(uv, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 2)
121
+ else:
122
+ index, mask_lr = mask_aware_nearest_resize_numpy(mask, *downsample_size)
123
+ points_lr, uv_lr = points[index][mask_lr], uv[index][mask_lr]
124
+
125
+ if points_lr.size == 0:
126
+ return 0.0, 0.0  # no valid points: return zero focal and shift, consistent with the (focal, shift) return below
127
+
128
+ if focal is None:
129
+ focal, shift = solve_optimal_focal_shift(uv_lr, points_lr)
130
+ else:
131
+ shift = solve_optimal_shift(uv_lr, points_lr, focal)
132
+
133
+ return focal, shift
134
+
135
+
136
+ def mask_aware_nearest_resize_numpy(mask: np.ndarray, target_width: int, target_height: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
137
+ """
138
+ Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.
139
+
140
+ ### Parameters
141
+ - `mask`: Input 2D mask of shape (..., H, W)
142
+ - `target_width`: target width of the resized map
143
+ - `target_height`: target height of the resized map
144
+
145
+ ### Returns
146
+ - `nearest_idx`: Tuple of per-dimension index arrays, each of shape (..., target_height, target_width), giving the nearest valid source pixel for every target pixel; it can be used directly to index the original map, e.g. `points[nearest_idx]`.
147
+ - `target_mask`: Mask of the resized map of shape (..., target_height, target_width)
148
+ """
149
+ height, width = mask.shape[-2:]
150
+ filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
151
+ filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f)
152
+ filter_size = filter_h_i * filter_w_i
153
+ padding_h, padding_w = round(filter_h_f / 2), round(filter_w_f / 2)
154
+
155
+ # Window the original mask and uv
156
+ uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32)
157
+ indices = np.arange(height * width, dtype=np.int32).reshape(height, width)
158
+ padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32)
159
+ padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
160
+ padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool)
161
+ padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
162
+ padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32)
163
+ padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
164
+ windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1))
165
+ windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1))
166
+ windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, axis=(0, 1))
167
+
168
+ # Gather each target pixel's local window
169
+ target_uv = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32)
170
+ target_corner = target_uv - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
171
+ target_corner = np.round(target_corner - 0.5).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32)
172
+
173
+ target_window_uv = windowed_uv[target_corner[..., 1], target_corner[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size)
174
+ target_window_mask = windowed_mask[..., target_corner[..., 1], target_corner[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size)
175
+ target_window_indices = windowed_indices[target_corner[..., 1], target_corner[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size)
176
+
177
+ # Compute nearest neighbor in the local window for each pixel
178
+ dist = np.square(target_window_uv - target_uv[..., None])
179
+ dist = dist[..., 0, :] + dist[..., 1, :]
180
+ dist = np.where(target_window_mask, dist, np.inf) # (..., target_height, tgt_width, filter_size)
181
+ nearest_in_window = np.argmin(dist, axis=-1, keepdims=True) # (..., target_height, tgt_width, 1)
182
+ nearest_idx = np.take_along_axis(target_window_indices, nearest_in_window, axis=-1).squeeze(-1) # (..., target_height, tgt_width)
183
+ nearest_i, nearest_j = nearest_idx // width, nearest_idx % width
184
+ target_mask = np.any(target_window_mask, axis=-1)
185
+ batch_indices = [np.arange(n).reshape([1] * i + [n] + [1] * (mask.ndim - i - 1)) for i, n in enumerate(mask.shape[:-2])]
186
+
187
+ return (*batch_indices, nearest_i, nearest_j), target_mask
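A short usage sketch for the mask-aware resize above (assuming the function is in scope and the vendored `utils3d` package is importable): the returned index tuple can be used directly to gather values from the full-resolution map, exactly as `recover_focal_shift_numpy` does.

```python
import numpy as np

H, W = 128, 160
mask = np.random.rand(H, W) > 0.3          # per-pixel validity
points = np.random.rand(H, W, 3)           # any per-pixel map to downsample

index, mask_lr = mask_aware_nearest_resize_numpy(mask, target_width=64, target_height=64)
points_lr = points[index]                  # (64, 64, 3); meaningful only where mask_lr is True
valid_points = points_lr[mask_lr]          # (K, 3) nearest valid samples
print(points_lr.shape, valid_points.shape)
```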
src/lari/utils/geometry_torch.py ADDED
@@ -0,0 +1,221 @@
1
+ from typing import *
2
+ import math
3
+ from collections import namedtuple
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.types
10
+
11
+ import os
12
+ import sys
13
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
14
+ import utils3d
15
+ from .geometry_numpy import solve_optimal_focal_shift, solve_optimal_shift
16
+
17
+
18
+ def weighted_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
19
+ if w is None:
20
+ return x.mean(dim=dim, keepdim=keepdim)
21
+ else:
22
+ w = w.to(x.dtype)
23
+ return (x * w).mean(dim=dim, keepdim=keepdim) / w.mean(dim=dim, keepdim=keepdim).add(eps)
24
+
25
+
26
+ def harmonic_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
27
+ if w is None:
28
+ return x.add(eps).reciprocal().mean(dim=dim, keepdim=keepdim).reciprocal()
29
+ else:
30
+ w = w.to(x.dtype)
31
+ return weighted_mean(x.add(eps).reciprocal(), w, dim=dim, keepdim=keepdim, eps=eps).add(eps).reciprocal()
32
+
33
+
34
+ def geometric_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
35
+ if w is None:
36
+ return x.add(eps).log().mean(dim=dim).exp()
37
+ else:
38
+ w = w.to(x.dtype)
39
+ return weighted_mean(x.add(eps).log(), w, dim=dim, keepdim=keepdim, eps=eps).exp()
40
+
41
+
42
+ def normalized_view_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor:
43
+ "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
44
+ if aspect_ratio is None:
45
+ aspect_ratio = width / height
46
+
47
+ span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
48
+ span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5
49
+
50
+ u = torch.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype, device=device)
51
+ v = torch.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype, device=device)
52
+ u, v = torch.meshgrid(u, v, indexing='xy')
53
+ uv = torch.stack([u, v], dim=-1)
54
+ return uv
55
+
56
+
57
+ def gaussian_blur_2d(input: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
58
+ kernel = torch.exp(-(torch.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=input.dtype, device=input.device) ** 2) / (2 * sigma ** 2))
59
+ kernel = kernel / kernel.sum()
60
+ kernel = (kernel[:, None] * kernel[None, :]).reshape(1, 1, kernel_size, kernel_size)
61
+ input = F.pad(input, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), mode='replicate')
62
+ input = F.conv2d(input, kernel.expand(input.shape[1], -1, -1, -1), groups=input.shape[1])  # depthwise: the same kernel applied per channel
63
+ return input
64
+
65
+
66
+ def focal_to_fov(focal: torch.Tensor):
67
+ return 2 * torch.atan(0.5 / focal)
68
+
69
+
70
+ def fov_to_focal(fov: torch.Tensor):
71
+ return 0.5 / torch.tan(fov / 2)
72
+
73
+
74
+ def intrinsics_to_fov(intrinsics: torch.Tensor):
75
+ """
76
+ Returns field of view in radians from normalized intrinsics matrix.
77
+ ### Parameters:
78
+ - intrinsics: torch.Tensor of shape (..., 3, 3)
79
+
80
+ ### Returns:
81
+ - fov_x: torch.Tensor of shape (...)
82
+ - fov_y: torch.Tensor of shape (...)
83
+ """
84
+ focal_x = intrinsics[..., 0, 0]
85
+ focal_y = intrinsics[..., 1, 1]
86
+ return 2 * torch.atan(0.5 / focal_x), 2 * torch.atan(0.5 / focal_y)
87
+
88
+
89
+ def point_map_to_depth_legacy(points: torch.Tensor):
90
+ height, width = points.shape[-3:-1]
91
+ diagonal = (height ** 2 + width ** 2) ** 0.5
92
+ uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2)
93
+
94
+ # Solve least squares problem
95
+ b = (uv * points[..., 2:]).flatten(-3, -1) # (..., H * W * 2)
96
+ A = torch.stack([points[..., :2], -uv.expand_as(points[..., :2])], dim=-1).flatten(-4, -2) # (..., H * W * 2, 2)
97
+
98
+ M = A.transpose(-2, -1) @ A
99
+ solution = (torch.inverse(M + 1e-6 * torch.eye(2).to(A)) @ (A.transpose(-2, -1) @ b[..., None])).squeeze(-1)
100
+ focal, shift = solution.unbind(-1)
101
+
102
+ depth = points[..., 2] + shift[..., None, None]
103
+ fov_x = torch.atan(width / diagonal / focal) * 2
104
+ fov_y = torch.atan(height / diagonal / focal) * 2
105
+ return depth, fov_x, fov_y, shift
106
+
107
+
108
+ def view_plane_uv_to_focal(uv: torch.Tensor):
109
+ normed_uv = normalized_view_plane_uv(width=uv.shape[-2], height=uv.shape[-3], device=uv.device, dtype=uv.dtype)
110
+ focal = (uv * normed_uv).sum() / uv.square().sum().add(1e-12)
111
+ return focal
112
+
113
+
114
+ def recover_focal_shift(points: torch.Tensor, mask: torch.Tensor = None, focal: torch.Tensor = None, downsample_size: Tuple[int, int] = (64, 64)):
115
+ """
116
+ Recover the depth map and FoV from a point map with unknown z shift and focal.
117
+
118
+ Note that it assumes:
119
+ - the optical center is at the center of the map
120
+ - the map is undistorted
121
+ - the map is isometric in the x and y directions
122
+
123
+ ### Parameters:
124
+ - `points: torch.Tensor` of shape (..., H, W, 3)
125
+ - `mask: torch.Tensor` of shape (..., H, W). Optional.
126
+ - `focal: torch.Tensor` of shape (...). Optional.
127
+ - `downsample_size: Tuple[int, int]` in (height, width), the size of the downsampled map. Downsampling produces approximate solution and is efficient for large maps.
128
+
129
+ ### Returns:
130
+ - `focal`: torch.Tensor of shape (...) the estimated focal length, relative to the half diagonal of the map
131
+ - `shift`: torch.Tensor of shape (...) Z-axis shift to translate the point map to camera space
132
+ """
133
+ shape = points.shape
134
+ height, width = points.shape[-3], points.shape[-2]
135
+ diagonal = (height ** 2 + width ** 2) ** 0.5
136
+
137
+ points = points.reshape(-1, *shape[-3:])
138
+ mask = None if mask is None else mask.reshape(-1, *shape[-3:-1])
139
+ focal = focal.reshape(-1) if focal is not None else None
140
+ uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2)
141
+
142
+ points_lr = F.interpolate(points.permute(0, 3, 1, 2), downsample_size, mode='nearest').permute(0, 2, 3, 1)
143
+ uv_lr = F.interpolate(uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode='nearest').squeeze(0).permute(1, 2, 0)
144
+ mask_lr = None if mask is None else F.interpolate(mask.to(torch.float32).unsqueeze(1), downsample_size, mode='nearest').squeeze(1) > 0
145
+
146
+ uv_lr_np = uv_lr.cpu().numpy()
147
+ points_lr_np = points_lr.detach().cpu().numpy()
148
+ focal_np = focal.cpu().numpy() if focal is not None else None
149
+ mask_lr_np = None if mask is None else mask_lr.cpu().numpy()
150
+ optim_shift, optim_focal = [], []
151
+ for i in range(points.shape[0]):
152
+ points_lr_i_np = points_lr_np[i] if mask is None else points_lr_np[i][mask_lr_np[i]]
153
+ uv_lr_i_np = uv_lr_np if mask is None else uv_lr_np[mask_lr_np[i]]
154
+ if focal is None:
155
+ optim_shift_i, optim_focal_i = solve_optimal_focal_shift(uv_lr_i_np, points_lr_i_np)
156
+ optim_focal.append(float(optim_focal_i))
157
+ else:
158
+ optim_shift_i = solve_optimal_shift(uv_lr_i_np, points_lr_i_np, focal_np[i])
159
+ optim_shift.append(float(optim_shift_i))
160
+ optim_shift = torch.tensor(optim_shift, device=points.device, dtype=points.dtype).reshape(shape[:-3])
161
+
162
+ if focal is None:
163
+ optim_focal = torch.tensor(optim_focal, device=points.device, dtype=points.dtype).reshape(shape[:-3])
164
+ else:
165
+ optim_focal = focal.reshape(shape[:-3])
166
+
167
+ return optim_focal, optim_shift
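A minimal round-trip sketch for `recover_focal_shift` (assuming the functions in this file are in scope; values are illustrative): a synthetic point map is built with a known focal and z-shift, and both are estimated back.

```python
import torch

H, W, true_focal, true_shift = 48, 64, 1.2, 0.5
uv = normalized_view_plane_uv(W, H)                                    # (H, W, 2)
depth = 2.0 + torch.rand(H, W)                                         # arbitrary positive depth
xy = uv * depth[..., None] / true_focal                                # back-project with the known focal
points = torch.cat([xy, depth[..., None] - true_shift], dim=-1)[None]  # (1, H, W, 3), z shifted by -0.5

focal, shift = recover_focal_shift(points)
print(focal.item(), shift.item())  # expected to be close to 1.2 and 0.5
```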
168
+
169
+
170
+ def mask_aware_nearest_resize(mask: torch.BoolTensor, target_width: int, target_height: int) -> Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]:
171
+ """
172
+ Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.
173
+
174
+ ### Parameters
175
+ - `mask`: Input 2D mask of shape (..., H, W)
176
+ - `target_width`: target width of the resized map
177
+ - `target_height`: target height of the resized map
178
+
179
+ ### Returns
180
+ - `nearest_idx`: Nearest neighbor index of the resized map of shape (..., target_height, target_width) for each dimension
181
+ - `target_mask`: Mask of the resized map of shape (..., target_height, target_width)
182
+ """
183
+ height, width = mask.shape[-2:]
184
+ device = mask.device
185
+ filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
186
+ filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f)
187
+ filter_size = filter_h_i * filter_w_i
188
+ padding_h, padding_w = round(filter_h_f / 2), round(filter_w_f / 2)
189
+
190
+ # Window the original mask and uv
191
+ uv = utils3d.torch.image_pixel_center(width=width, height=height, dtype=torch.float32, device=device)
192
+ indices = torch.arange(height * width, dtype=torch.long, device=device).reshape(height, width)
193
+ padded_uv = torch.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=torch.float32, device=device)
194
+ padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
195
+ padded_mask = torch.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=torch.bool, device=device)
196
+ padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
197
+ padded_indices = torch.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=torch.long, device=device)
198
+ padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
199
+ windowed_uv = utils3d.torch.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, dim=(0, 1))
200
+ windowed_mask = utils3d.torch.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, dim=(-2, -1))
201
+ windowed_indices = utils3d.torch.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, dim=(0, 1))
202
+
203
+ # Gather each target pixel's local window
204
+ target_uv = utils3d.torch.image_uv(width=target_width, height=target_height, dtype=torch.float32, device=device) * torch.tensor([width, height], dtype=torch.float32, device=device)
205
+ target_corner = target_uv - torch.tensor((filter_w_f / 2, filter_h_f / 2), dtype=torch.float32, device=device)
206
+ target_corner = torch.round(target_corner - 0.5).long() + torch.tensor((padding_w, padding_h), dtype=torch.long, device=device)
207
+
208
+ target_window_uv = windowed_uv[target_corner[..., 1], target_corner[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size)
209
+ target_window_mask = windowed_mask[..., target_corner[..., 1], target_corner[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size)
210
+ target_window_indices = windowed_indices[target_corner[..., 1], target_corner[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size)
211
+ target_window_indices = target_window_indices.expand_as(target_window_mask)
212
+
213
+ # Compute nearest neighbor in the local window for each pixel
214
+ dist = torch.where(target_window_mask, torch.norm(target_window_uv - target_uv[..., None], dim=-2), torch.inf) # (..., target_height, tgt_width, filter_size)
215
+ nearest = torch.argmin(dist, dim=-1, keepdim=True) # (..., target_height, tgt_width, 1)
216
+ nearest_idx = torch.gather(target_window_indices, index=nearest, dim=-1).squeeze(-1) # (..., target_height, tgt_width)
217
+ target_mask = torch.any(target_window_mask, dim=-1)
218
+ nearest_i, nearest_j = nearest_idx // width, nearest_idx % width
219
+ batch_indices = [torch.arange(n, device=device).reshape([1] * i + [n] + [1] * (mask.dim() - i - 1)) for i, n in enumerate(mask.shape[:-2])]
220
+
221
+ return (*batch_indices, nearest_i, nearest_j), target_mask
src/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
src/utils/vis.py ADDED
@@ -0,0 +1,105 @@
1
+ # import torchvision.transforms as transforms
2
+ # import torch.nn.functional as F
3
+ # import cv2
4
+ # import os
5
+ # import logging
6
+ # from pathlib import Path
7
+ import numpy as np
8
+ # import os
9
+ import torch
10
+ import matplotlib
11
+ # import cv2
12
+ # import random
13
+ # from PIL import Image
14
+ # import imageio
15
+
16
+ def prob_to_mask(prob):
17
+ """
18
+ Transforms a probability map of stopping points (shape: (n_layer+1, H, W))
19
+ into a binary mask (shape: (H, W, n_layer, 1)) where for each pixel, layers
20
+ with index ≤ stopping index (as given by argmax) are marked valid.
21
+ """
22
+ num_layer_plus1, H, W = prob.shape
23
+ # Get stopping index for each pixel; values are in {0, 1, ..., n_layer}
24
+ stopping_indices = torch.argmax(prob, dim=0) # (H, W)
25
+
26
+ # Create a tensor with layer indices [1, 2, ..., n_layer]
27
+ layer_indices = torch.arange(1, num_layer_plus1, device=prob.device).view(-1, 1, 1)
28
+
29
+ # Compare: a layer is valid if its index is <= the stopping index.
30
+ pred_mask = (layer_indices <= stopping_indices.unsqueeze(0))
31
+
32
+ # Permute and unsqueeze to get shape (H, W, n_layer, 1)
33
+ pred_mask = pred_mask.permute(1, 2, 0).unsqueeze(-1)
34
+ return pred_mask
35
+
36
+
37
+
38
+
39
+ def colorize(value, vmin=None, vmax=None, cmap='rainbow', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None):
40
+ """Converts a depth map to a color image.
41
+
42
+ Args:
43
+ value (torch.Tensor, numpy.ndarray): Input depth map. Shape: (H, W) or (1, H, W) or (1, 1, H, W). All singular dimensions are squeezed
45
+ vmin (float, optional): vmin-valued entries are mapped to the start color of cmap. If None, the 2nd percentile of the valid values is used. Defaults to None.
46
+ vmax (float, optional): vmax-valued entries are mapped to the end color of cmap. If None, the 85th percentile of the valid values is used. Defaults to None.
47
+ cmap (str, optional): matplotlib colormap to use. Defaults to 'rainbow'.
47
+ invalid_val (int, optional): Specifies value of invalid pixels that should be colored as 'background_color'. Defaults to -99.
48
+ invalid_mask (numpy.ndarray, optional): Boolean mask for invalid regions. Defaults to None.
49
+ background_color (tuple[int], optional): 4-tuple RGB color to give to invalid pixels. Defaults to (128, 128, 128, 255).
50
+ gamma_corrected (bool, optional): Apply gamma correction to colored image. Defaults to False.
51
+ value_transform (Callable, optional): Apply transform function to valid pixels before coloring. Defaults to None.
52
+
53
+ Returns:
54
+ numpy.ndarray, dtype - uint8: Colored depth map. Shape: (H, W, 4)
55
+ """
56
+ if isinstance(value, torch.Tensor):
57
+ value = value.detach().cpu().numpy()
58
+
59
+ value = value.squeeze()
60
+ if invalid_mask is None:
61
+ invalid_mask = value == invalid_val
62
+ mask = np.logical_not(invalid_mask)
63
+
64
+ # normalize
65
+ vmin = np.percentile(value[mask],2) if vmin is None else vmin
66
+ vmax = np.percentile(value[mask],85) if vmax is None else vmax
67
+ if vmin != vmax:
68
+ value = (value - vmin) / (vmax - vmin) # vmin..vmax
69
+ else:
70
+ # Avoid 0-division
71
+ value = value * 0.
72
+
73
+ value[invalid_mask] = np.nan
74
+ cmapper = matplotlib.cm.get_cmap(cmap)
75
+ if value_transform:
76
+ value = value_transform(value)
77
+ # value = value / value.max()
78
+ value = cmapper(value, bytes=True) # (nxmx4)
79
+
80
+ # img = value[:, :, :]
81
+ img = value[...]
82
+ img[invalid_mask] = background_color
83
+
84
+ if gamma_corrected:
85
+ # gamma correction
86
+ img = img / 255
87
+ img = np.power(img, 2.2)
88
+ img = img * 255
89
+ img = img.astype(np.uint8)
90
+ return img
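A short usage sketch for `colorize` (parameters are illustrative):

```python
import numpy as np

depth = np.random.rand(240, 320).astype(np.float32) * 5.0
depth[:20, :20] = -99                       # mark a corner as invalid

colored = colorize(depth, vmin=0.0, vmax=5.0, cmap='rainbow', invalid_val=-99)
print(colored.shape, colored.dtype)         # (240, 320, 4) uint8; invalid pixels get the grey background color
```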
91
+
92
+
93
+
94
+ def denormalize(x):
95
+ """Reverses the imagenet normalization applied to the input.
96
+
97
+ Args:
98
+ x (torch.Tensor - shape(N,3,H,W)): input tensor
99
+
100
+ Returns:
101
+ torch.Tensor - shape(N,3,H,W): Denormalized input
102
+ """
103
+ mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
104
+ std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
105
+ return x * std + mean
src/utils3d/README.md ADDED
@@ -0,0 +1,3 @@
1
+ # utils3d
2
+
3
+ This is a collection of utility functions for 3D computer vision tasks copied from https://github.com/EasternJournalist/utils3d.
src/utils3d/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ """
2
+ A package for common utility functions in 3D computer graphics and vision, providing NumPy utilities in `utils3d.numpy`, PyTorch utilities in `utils3d.torch`, and IO utilities in `utils3d.io`.
3
+ """
4
+ import importlib
5
+ from typing import TYPE_CHECKING
6
+
7
+ try:
8
+ from ._unified import *
9
+ except ImportError:
10
+ pass
11
+
12
+ __all__ = ['numpy', 'torch', 'io']
13
+
14
+ def __getattr__(name: str):
15
+ return globals().get(name, importlib.import_module(f'.{name}', __package__))
16
+
17
+ if TYPE_CHECKING:
18
+ from . import torch
19
+ from . import numpy
20
+ from . import io
src/utils3d/_helpers.py ADDED
@@ -0,0 +1,35 @@
1
+ from functools import wraps
2
+ import warnings
3
+
4
+
5
+ def suppress_traceback(fn):
6
+ @wraps(fn)
7
+ def wrapper(*args, **kwargs):
8
+ try:
9
+ return fn(*args, **kwargs)
10
+ except Exception as e:
11
+ e.__traceback__ = e.__traceback__.tb_next.tb_next
12
+ raise
13
+ return wrapper
14
+
15
+
16
+ class no_warnings:
17
+ def __init__(self, action: str = 'ignore', **kwargs):
18
+ self.action = action
19
+ self.filter_kwargs = kwargs
20
+
21
+ def __call__(self, fn):
22
+ @wraps(fn)
23
+ def wrapper(*args, **kwargs):
24
+ with warnings.catch_warnings():
25
+ warnings.simplefilter(self.action, **self.filter_kwargs)
26
+ return fn(*args, **kwargs)
27
+ return wrapper
28
+
29
+ def __enter__(self):
30
+ self.warnings_manager = warnings.catch_warnings()
31
+ self.warnings_manager.__enter__()
32
+ warnings.simplefilter(self.action, **self.filter_kwargs)
33
+
34
+ def __exit__(self, exc_type, exc_val, exc_tb):
35
+ self.warnings_manager.__exit__(exc_type, exc_val, exc_tb)
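`no_warnings` can be used either as a decorator or as a context manager; a small sketch (assuming the class above is in scope):

```python
import warnings

@no_warnings(action='ignore')
def noisy() -> int:
    warnings.warn("silenced by the decorator")
    return 42

print(noisy())  # 42, without the warning being emitted

with no_warnings(action='ignore'):
    warnings.warn("silenced by the context manager")
```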
src/utils3d/_unified/__init__.py ADDED
@@ -0,0 +1,934 @@
1
+ # Auto-generated implementation redirecting to numpy/torch implementations
2
+ import sys
3
+ from typing import TYPE_CHECKING
4
+ import utils3d
5
+ from .._helpers import suppress_traceback
6
+
7
+ __all__ = ["triangulate",
8
+ "compute_face_normal",
9
+ "compute_face_angle",
10
+ "compute_vertex_normal",
11
+ "compute_vertex_normal_weighted",
12
+ "remove_corrupted_faces",
13
+ "merge_duplicate_vertices",
14
+ "remove_unreferenced_vertices",
15
+ "subdivide_mesh_simple",
16
+ "mesh_relations",
17
+ "flatten_mesh_indices",
18
+ "calc_quad_candidates",
19
+ "calc_quad_distortion",
20
+ "calc_quad_direction",
21
+ "calc_quad_smoothness",
22
+ "sovle_quad",
23
+ "sovle_quad_qp",
24
+ "tri_to_quad",
25
+ "sliding_window_1d",
26
+ "sliding_window_nd",
27
+ "sliding_window_2d",
28
+ "max_pool_1d",
29
+ "max_pool_2d",
30
+ "max_pool_nd",
31
+ "depth_edge",
32
+ "normals_edge",
33
+ "depth_aliasing",
34
+ "interpolate",
35
+ "image_scrcoord",
36
+ "image_uv",
37
+ "image_pixel_center",
38
+ "image_pixel",
39
+ "image_mesh",
40
+ "image_mesh_from_depth",
41
+ "depth_to_normals",
42
+ "points_to_normals",
43
+ "chessboard",
44
+ "cube",
45
+ "icosahedron",
46
+ "square",
47
+ "camera_frustum",
48
+ "perspective",
49
+ "perspective_from_fov",
50
+ "perspective_from_fov_xy",
51
+ "intrinsics_from_focal_center",
52
+ "intrinsics_from_fov",
53
+ "fov_to_focal",
54
+ "focal_to_fov",
55
+ "intrinsics_to_fov",
56
+ "view_look_at",
57
+ "extrinsics_look_at",
58
+ "perspective_to_intrinsics",
59
+ "perspective_to_near_far",
60
+ "intrinsics_to_perspective",
61
+ "extrinsics_to_view",
62
+ "view_to_extrinsics",
63
+ "normalize_intrinsics",
64
+ "crop_intrinsics",
65
+ "pixel_to_uv",
66
+ "pixel_to_ndc",
67
+ "uv_to_pixel",
68
+ "project_depth",
69
+ "depth_buffer_to_linear",
70
+ "unproject_cv",
71
+ "unproject_gl",
72
+ "project_cv",
73
+ "project_gl",
74
+ "quaternion_to_matrix",
75
+ "axis_angle_to_matrix",
76
+ "matrix_to_quaternion",
77
+ "extrinsics_to_essential",
78
+ "euler_axis_angle_rotation",
79
+ "euler_angles_to_matrix",
80
+ "skew_symmetric",
81
+ "rotation_matrix_from_vectors",
82
+ "ray_intersection",
83
+ "se3_matrix",
84
+ "slerp_quaternion",
85
+ "slerp_vector",
86
+ "lerp",
87
+ "lerp_se3_matrix",
88
+ "piecewise_lerp",
89
+ "piecewise_lerp_se3_matrix",
90
+ "apply_transform",
91
+ "linear_spline_interpolate",
92
+ "RastContext",
93
+ "rasterize_triangle_faces",
94
+ "rasterize_edges",
95
+ "texture",
96
+ "warp_image_by_depth",
97
+ "test_rasterization",
98
+ "compute_face_angles",
99
+ "compute_face_tbn",
100
+ "compute_vertex_tbn",
101
+ "laplacian",
102
+ "laplacian_smooth_mesh",
103
+ "taubin_smooth_mesh",
104
+ "laplacian_hc_smooth_mesh",
105
+ "get_rays",
106
+ "get_image_rays",
107
+ "get_mipnerf_cones",
108
+ "volume_rendering",
109
+ "bin_sample",
110
+ "importance_sample",
111
+ "nerf_render_rays",
112
+ "mipnerf_render_rays",
113
+ "nerf_render_view",
114
+ "mipnerf_render_view",
115
+ "InstantNGP",
116
+ "point_to_normal",
117
+ "depth_to_normal",
118
+ "masked_min",
119
+ "masked_max",
120
+ "bounding_rect",
121
+ "intrinsics_from_fov_xy",
122
+ "matrix_to_euler_angles",
123
+ "matrix_to_axis_angle",
124
+ "axis_angle_to_quaternion",
125
+ "quaternion_to_axis_angle",
126
+ "slerp",
127
+ "interpolate_extrinsics",
128
+ "interpolate_view",
129
+ "to4x4",
130
+ "rotation_matrix_2d",
131
+ "rotate_2d",
132
+ "translate_2d",
133
+ "scale_2d",
134
+ "apply_2d",
135
+ "warp_image_by_forward_flow"]
136
+
137
+ def _contains_tensor(obj):
138
+ if isinstance(obj, (list, tuple)):
139
+ return any(_contains_tensor(item) for item in obj)
140
+ elif isinstance(obj, dict):
141
+ return any(_contains_tensor(value) for value in obj.values())
142
+ else:
143
+ import torch
144
+ return isinstance(obj, torch.Tensor)
145
+
146
+
147
+ @suppress_traceback
148
+ def _call_based_on_args(fname, args, kwargs):
149
+ if 'torch' in sys.modules:
150
+ if any(_contains_tensor(arg) for arg in args) or any(_contains_tensor(v) for v in kwargs.values()):
151
+ fn = getattr(utils3d.torch, fname, None)
152
+ if fn is None:
153
+ raise NotImplementedError(f"Function {fname} has no torch implementation.")
154
+ return fn(*args, **kwargs)
155
+ fn = getattr(utils3d.numpy, fname, None)
156
+ if fn is None:
157
+ raise NotImplementedError(f"Function {fname} has no numpy implementation.")
158
+ return fn(*args, **kwargs)
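The effect of the dispatcher above, sketched below: the same unified entry point resolves to the NumPy implementation when no tensor is passed and to the PyTorch implementation when any argument (or keyword value) is a `torch.Tensor`. This assumes the vendored `src` directory is on `sys.path` so that `utils3d` is importable, and that both backends are installed.

```python
import torch
import utils3d

uv_np = utils3d.image_uv(width=4, height=3)          # no tensors -> numpy backend, returns np.ndarray
win_pt = utils3d.sliding_window_2d(                   # tensor argument -> torch backend
    torch.zeros(8, 8), (3, 3), 1, dim=(0, 1)
)
print(type(uv_np), type(win_pt))
```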
159
+
160
+
161
+ @suppress_traceback
162
+ def triangulate(*args, **kwargs):
163
+ if TYPE_CHECKING: # redirected to:
164
+ utils3d.numpy.triangulate, utils3d.torch.triangulate
165
+ return _call_based_on_args('triangulate', args, kwargs)
166
+
167
+ @suppress_traceback
168
+ def compute_face_normal(*args, **kwargs):
169
+ if TYPE_CHECKING: # redirected to:
170
+ utils3d.numpy.compute_face_normal, utils3d.torch.compute_face_normal
171
+ return _call_based_on_args('compute_face_normal', args, kwargs)
172
+
173
+ @suppress_traceback
174
+ def compute_face_angle(*args, **kwargs):
175
+ if TYPE_CHECKING: # redirected to:
176
+ utils3d.numpy.compute_face_angle, None
177
+ return _call_based_on_args('compute_face_angle', args, kwargs)
178
+
179
+ @suppress_traceback
180
+ def compute_vertex_normal(*args, **kwargs):
181
+ if TYPE_CHECKING: # redirected to:
182
+ utils3d.numpy.compute_vertex_normal, utils3d.torch.compute_vertex_normal
183
+ return _call_based_on_args('compute_vertex_normal', args, kwargs)
184
+
185
+ @suppress_traceback
186
+ def compute_vertex_normal_weighted(*args, **kwargs):
187
+ if TYPE_CHECKING: # redirected to:
188
+ utils3d.numpy.compute_vertex_normal_weighted, utils3d.torch.compute_vertex_normal_weighted
189
+ return _call_based_on_args('compute_vertex_normal_weighted', args, kwargs)
190
+
191
+ @suppress_traceback
192
+ def remove_corrupted_faces(*args, **kwargs):
193
+ if TYPE_CHECKING: # redirected to:
194
+ utils3d.numpy.remove_corrupted_faces, utils3d.torch.remove_corrupted_faces
195
+ return _call_based_on_args('remove_corrupted_faces', args, kwargs)
196
+
197
+ @suppress_traceback
198
+ def merge_duplicate_vertices(*args, **kwargs):
199
+ if TYPE_CHECKING: # redirected to:
200
+ utils3d.numpy.merge_duplicate_vertices, utils3d.torch.merge_duplicate_vertices
201
+ return _call_based_on_args('merge_duplicate_vertices', args, kwargs)
202
+
203
+ @suppress_traceback
204
+ def remove_unreferenced_vertices(*args, **kwargs):
205
+ if TYPE_CHECKING: # redirected to:
206
+ utils3d.numpy.remove_unreferenced_vertices, utils3d.torch.remove_unreferenced_vertices
207
+ return _call_based_on_args('remove_unreferenced_vertices', args, kwargs)
208
+
209
+ @suppress_traceback
210
+ def subdivide_mesh_simple(*args, **kwargs):
211
+ if TYPE_CHECKING: # redirected to:
212
+ utils3d.numpy.subdivide_mesh_simple, utils3d.torch.subdivide_mesh_simple
213
+ return _call_based_on_args('subdivide_mesh_simple', args, kwargs)
214
+
215
+ @suppress_traceback
216
+ def mesh_relations(*args, **kwargs):
217
+ if TYPE_CHECKING: # redirected to:
218
+ utils3d.numpy.mesh_relations, None
219
+ return _call_based_on_args('mesh_relations', args, kwargs)
220
+
221
+ @suppress_traceback
222
+ def flatten_mesh_indices(*args, **kwargs):
223
+ if TYPE_CHECKING: # redirected to:
224
+ utils3d.numpy.flatten_mesh_indices, None
225
+ return _call_based_on_args('flatten_mesh_indices', args, kwargs)
226
+
227
+ @suppress_traceback
228
+ def calc_quad_candidates(*args, **kwargs):
229
+ if TYPE_CHECKING: # redirected to:
230
+ utils3d.numpy.calc_quad_candidates, None
231
+ return _call_based_on_args('calc_quad_candidates', args, kwargs)
232
+
233
+ @suppress_traceback
234
+ def calc_quad_distortion(*args, **kwargs):
235
+ if TYPE_CHECKING: # redirected to:
236
+ utils3d.numpy.calc_quad_distortion, None
237
+ return _call_based_on_args('calc_quad_distortion', args, kwargs)
238
+
239
+ @suppress_traceback
240
+ def calc_quad_direction(*args, **kwargs):
241
+ if TYPE_CHECKING: # redirected to:
242
+ utils3d.numpy.calc_quad_direction, None
243
+ return _call_based_on_args('calc_quad_direction', args, kwargs)
244
+
245
+ @suppress_traceback
246
+ def calc_quad_smoothness(*args, **kwargs):
247
+ if TYPE_CHECKING: # redirected to:
248
+ utils3d.numpy.calc_quad_smoothness, None
249
+ return _call_based_on_args('calc_quad_smoothness', args, kwargs)
250
+
251
+ @suppress_traceback
252
+ def sovle_quad(*args, **kwargs):
253
+ if TYPE_CHECKING: # redirected to:
254
+ utils3d.numpy.sovle_quad, None
255
+ return _call_based_on_args('sovle_quad', args, kwargs)
256
+
257
+ @suppress_traceback
258
+ def sovle_quad_qp(*args, **kwargs):
259
+ if TYPE_CHECKING: # redirected to:
260
+ utils3d.numpy.sovle_quad_qp, None
261
+ return _call_based_on_args('sovle_quad_qp', args, kwargs)
262
+
263
+ @suppress_traceback
264
+ def tri_to_quad(*args, **kwargs):
265
+ if TYPE_CHECKING: # redirected to:
266
+ utils3d.numpy.tri_to_quad, None
267
+ return _call_based_on_args('tri_to_quad', args, kwargs)
268
+
269
+ @suppress_traceback
270
+ def sliding_window_1d(*args, **kwargs):
271
+ if TYPE_CHECKING: # redirected to:
272
+ utils3d.numpy.sliding_window_1d, utils3d.torch.sliding_window_1d
273
+ return _call_based_on_args('sliding_window_1d', args, kwargs)
274
+
275
+ @suppress_traceback
276
+ def sliding_window_nd(*args, **kwargs):
277
+ if TYPE_CHECKING: # redirected to:
278
+ utils3d.numpy.sliding_window_nd, utils3d.torch.sliding_window_nd
279
+ return _call_based_on_args('sliding_window_nd', args, kwargs)
280
+
281
+ @suppress_traceback
282
+ def sliding_window_2d(*args, **kwargs):
283
+ if TYPE_CHECKING: # redirected to:
284
+ utils3d.numpy.sliding_window_2d, utils3d.torch.sliding_window_2d
285
+ return _call_based_on_args('sliding_window_2d', args, kwargs)
286
+
287
+ @suppress_traceback
288
+ def max_pool_1d(*args, **kwargs):
289
+ if TYPE_CHECKING: # redirected to:
290
+ utils3d.numpy.max_pool_1d, None
291
+ return _call_based_on_args('max_pool_1d', args, kwargs)
292
+
293
+ @suppress_traceback
294
+ def max_pool_2d(*args, **kwargs):
295
+ if TYPE_CHECKING: # redirected to:
296
+ utils3d.numpy.max_pool_2d, None
297
+ return _call_based_on_args('max_pool_2d', args, kwargs)
298
+
299
+ @suppress_traceback
300
+ def max_pool_nd(*args, **kwargs):
301
+ if TYPE_CHECKING: # redirected to:
302
+ utils3d.numpy.max_pool_nd, None
303
+ return _call_based_on_args('max_pool_nd', args, kwargs)
304
+
305
+ @suppress_traceback
306
+ def depth_edge(*args, **kwargs):
307
+ if TYPE_CHECKING: # redirected to:
308
+ utils3d.numpy.depth_edge, utils3d.torch.depth_edge
309
+ return _call_based_on_args('depth_edge', args, kwargs)
310
+
311
+ @suppress_traceback
312
+ def normals_edge(*args, **kwargs):
313
+ if TYPE_CHECKING: # redirected to:
314
+ utils3d.numpy.normals_edge, None
315
+ return _call_based_on_args('normals_edge', args, kwargs)
316
+
317
+ @suppress_traceback
318
+ def depth_aliasing(*args, **kwargs):
319
+ if TYPE_CHECKING: # redirected to:
320
+ utils3d.numpy.depth_aliasing, utils3d.torch.depth_aliasing
321
+ return _call_based_on_args('depth_aliasing', args, kwargs)
322
+
323
+ @suppress_traceback
324
+ def interpolate(*args, **kwargs):
325
+ if TYPE_CHECKING: # redirected to:
326
+ utils3d.numpy.interpolate, None
327
+ return _call_based_on_args('interpolate', args, kwargs)
328
+
329
+ @suppress_traceback
330
+ def image_scrcoord(*args, **kwargs):
331
+ if TYPE_CHECKING: # redirected to:
332
+ utils3d.numpy.image_scrcoord, None
333
+ return _call_based_on_args('image_scrcoord', args, kwargs)
334
+
335
+ @suppress_traceback
336
+ def image_uv(*args, **kwargs):
337
+ if TYPE_CHECKING: # redirected to:
338
+ utils3d.numpy.image_uv, utils3d.torch.image_uv
339
+ return _call_based_on_args('image_uv', args, kwargs)
340
+
341
+ @suppress_traceback
342
+ def image_pixel_center(*args, **kwargs):
343
+ if TYPE_CHECKING: # redirected to:
344
+ utils3d.numpy.image_pixel_center, utils3d.torch.image_pixel_center
345
+ return _call_based_on_args('image_pixel_center', args, kwargs)
346
+
347
+ @suppress_traceback
348
+ def image_pixel(*args, **kwargs):
349
+ if TYPE_CHECKING: # redirected to:
350
+ utils3d.numpy.image_pixel, None
351
+ return _call_based_on_args('image_pixel', args, kwargs)
352
+
353
+ @suppress_traceback
354
+ def image_mesh(*args, **kwargs):
355
+ if TYPE_CHECKING: # redirected to:
356
+ utils3d.numpy.image_mesh, utils3d.torch.image_mesh
357
+ return _call_based_on_args('image_mesh', args, kwargs)
358
+
359
+ @suppress_traceback
360
+ def image_mesh_from_depth(*args, **kwargs):
361
+ if TYPE_CHECKING: # redirected to:
362
+ utils3d.numpy.image_mesh_from_depth, utils3d.torch.image_mesh_from_depth
363
+ return _call_based_on_args('image_mesh_from_depth', args, kwargs)
364
+
365
+ @suppress_traceback
366
+ def depth_to_normals(*args, **kwargs):
367
+ if TYPE_CHECKING: # redirected to:
368
+ utils3d.numpy.depth_to_normals, None
369
+ return _call_based_on_args('depth_to_normals', args, kwargs)
370
+
371
+ @suppress_traceback
372
+ def points_to_normals(*args, **kwargs):
373
+ if TYPE_CHECKING: # redirected to:
374
+ utils3d.numpy.points_to_normals, None
375
+ return _call_based_on_args('points_to_normals', args, kwargs)
376
+
377
+ @suppress_traceback
378
+ def chessboard(*args, **kwargs):
379
+ if TYPE_CHECKING: # redirected to:
380
+ utils3d.numpy.chessboard, utils3d.torch.chessboard
381
+ return _call_based_on_args('chessboard', args, kwargs)
382
+
383
+ @suppress_traceback
384
+ def cube(*args, **kwargs):
385
+ if TYPE_CHECKING: # redirected to:
386
+ utils3d.numpy.cube, None
387
+ return _call_based_on_args('cube', args, kwargs)
388
+
389
+ @suppress_traceback
390
+ def icosahedron(*args, **kwargs):
391
+ if TYPE_CHECKING: # redirected to:
392
+ utils3d.numpy.icosahedron, None
393
+ return _call_based_on_args('icosahedron', args, kwargs)
394
+
395
+ @suppress_traceback
396
+ def square(*args, **kwargs):
397
+ if TYPE_CHECKING: # redirected to:
398
+ utils3d.numpy.square, None
399
+ return _call_based_on_args('square', args, kwargs)
400
+
401
+ @suppress_traceback
402
+ def camera_frustum(*args, **kwargs):
403
+ if TYPE_CHECKING: # redirected to:
404
+ utils3d.numpy.camera_frustum, None
405
+ return _call_based_on_args('camera_frustum', args, kwargs)
406
+
407
+ @suppress_traceback
408
+ def perspective(*args, **kwargs):
409
+ if TYPE_CHECKING: # redirected to:
410
+ utils3d.numpy.perspective, utils3d.torch.perspective
411
+ return _call_based_on_args('perspective', args, kwargs)
412
+
413
+ @suppress_traceback
414
+ def perspective_from_fov(*args, **kwargs):
415
+ if TYPE_CHECKING: # redirected to:
416
+ utils3d.numpy.perspective_from_fov, utils3d.torch.perspective_from_fov
417
+ return _call_based_on_args('perspective_from_fov', args, kwargs)
418
+
419
+ @suppress_traceback
420
+ def perspective_from_fov_xy(*args, **kwargs):
421
+ if TYPE_CHECKING: # redirected to:
422
+ utils3d.numpy.perspective_from_fov_xy, utils3d.torch.perspective_from_fov_xy
423
+ return _call_based_on_args('perspective_from_fov_xy', args, kwargs)
424
+
425
+ @suppress_traceback
426
+ def intrinsics_from_focal_center(*args, **kwargs):
427
+ if TYPE_CHECKING: # redirected to:
428
+ utils3d.numpy.intrinsics_from_focal_center, utils3d.torch.intrinsics_from_focal_center
429
+ return _call_based_on_args('intrinsics_from_focal_center', args, kwargs)
430
+
431
+ @suppress_traceback
432
+ def intrinsics_from_fov(*args, **kwargs):
433
+ if TYPE_CHECKING: # redirected to:
434
+ utils3d.numpy.intrinsics_from_fov, utils3d.torch.intrinsics_from_fov
435
+ return _call_based_on_args('intrinsics_from_fov', args, kwargs)
436
+
437
+ @suppress_traceback
438
+ def fov_to_focal(*args, **kwargs):
439
+ if TYPE_CHECKING: # redirected to:
440
+ utils3d.numpy.fov_to_focal, None
441
+ return _call_based_on_args('fov_to_focal', args, kwargs)
442
+
443
+ @suppress_traceback
444
+ def focal_to_fov(*args, **kwargs):
445
+ if TYPE_CHECKING: # redirected to:
446
+ utils3d.numpy.focal_to_fov, None
447
+ return _call_based_on_args('focal_to_fov', args, kwargs)
448
+
449
+ @suppress_traceback
450
+ def intrinsics_to_fov(*args, **kwargs):
451
+ if TYPE_CHECKING: # redirected to:
452
+ utils3d.numpy.intrinsics_to_fov, None
453
+ return _call_based_on_args('intrinsics_to_fov', args, kwargs)
454
+
455
+ @suppress_traceback
456
+ def view_look_at(*args, **kwargs):
457
+ if TYPE_CHECKING: # redirected to:
458
+ utils3d.numpy.view_look_at, utils3d.torch.view_look_at
459
+ return _call_based_on_args('view_look_at', args, kwargs)
460
+
461
+ @suppress_traceback
462
+ def extrinsics_look_at(*args, **kwargs):
463
+ if TYPE_CHECKING: # redirected to:
464
+ utils3d.numpy.extrinsics_look_at, utils3d.torch.extrinsics_look_at
465
+ return _call_based_on_args('extrinsics_look_at', args, kwargs)
466
+
467
+ @suppress_traceback
468
+ def perspective_to_intrinsics(*args, **kwargs):
469
+ if TYPE_CHECKING: # redirected to:
470
+ utils3d.numpy.perspective_to_intrinsics, utils3d.torch.perspective_to_intrinsics
471
+ return _call_based_on_args('perspective_to_intrinsics', args, kwargs)
472
+
473
+ @suppress_traceback
474
+ def perspective_to_near_far(*args, **kwargs):
475
+ if TYPE_CHECKING: # redirected to:
476
+ utils3d.numpy.perspective_to_near_far, None
477
+ return _call_based_on_args('perspective_to_near_far', args, kwargs)
478
+
479
+ @suppress_traceback
480
+ def intrinsics_to_perspective(*args, **kwargs):
481
+ if TYPE_CHECKING: # redirected to:
482
+ utils3d.numpy.intrinsics_to_perspective, utils3d.torch.intrinsics_to_perspective
483
+ return _call_based_on_args('intrinsics_to_perspective', args, kwargs)
484
+
485
+ @suppress_traceback
486
+ def extrinsics_to_view(*args, **kwargs):
487
+ if TYPE_CHECKING: # redirected to:
488
+ utils3d.numpy.extrinsics_to_view, utils3d.torch.extrinsics_to_view
489
+ return _call_based_on_args('extrinsics_to_view', args, kwargs)
490
+
491
+ @suppress_traceback
492
+ def view_to_extrinsics(*args, **kwargs):
493
+ if TYPE_CHECKING: # redirected to:
494
+ utils3d.numpy.view_to_extrinsics, utils3d.torch.view_to_extrinsics
495
+ return _call_based_on_args('view_to_extrinsics', args, kwargs)
496
+
497
+ @suppress_traceback
498
+ def normalize_intrinsics(*args, **kwargs):
499
+ if TYPE_CHECKING: # redirected to:
500
+ utils3d.numpy.normalize_intrinsics, utils3d.torch.normalize_intrinsics
501
+ return _call_based_on_args('normalize_intrinsics', args, kwargs)
502
+
503
+ @suppress_traceback
504
+ def crop_intrinsics(*args, **kwargs):
505
+ if TYPE_CHECKING: # redirected to:
506
+ utils3d.numpy.crop_intrinsics, utils3d.torch.crop_intrinsics
507
+ return _call_based_on_args('crop_intrinsics', args, kwargs)
508
+
509
+ @suppress_traceback
510
+ def pixel_to_uv(*args, **kwargs):
511
+ if TYPE_CHECKING: # redirected to:
512
+ utils3d.numpy.pixel_to_uv, utils3d.torch.pixel_to_uv
513
+ return _call_based_on_args('pixel_to_uv', args, kwargs)
514
+
515
+ @suppress_traceback
516
+ def pixel_to_ndc(*args, **kwargs):
517
+ if TYPE_CHECKING: # redirected to:
518
+ utils3d.numpy.pixel_to_ndc, utils3d.torch.pixel_to_ndc
519
+ return _call_based_on_args('pixel_to_ndc', args, kwargs)
520
+
521
+ @suppress_traceback
522
+ def uv_to_pixel(*args, **kwargs):
523
+ if TYPE_CHECKING: # redirected to:
524
+ utils3d.numpy.uv_to_pixel, utils3d.torch.uv_to_pixel
525
+ return _call_based_on_args('uv_to_pixel', args, kwargs)
526
+
527
+ @suppress_traceback
528
+ def project_depth(*args, **kwargs):
529
+ if TYPE_CHECKING: # redirected to:
530
+ utils3d.numpy.project_depth, utils3d.torch.project_depth
531
+ return _call_based_on_args('project_depth', args, kwargs)
532
+
533
+ @suppress_traceback
534
+ def depth_buffer_to_linear(*args, **kwargs):
535
+ if TYPE_CHECKING: # redirected to:
536
+ utils3d.numpy.depth_buffer_to_linear, utils3d.torch.depth_buffer_to_linear
537
+ return _call_based_on_args('depth_buffer_to_linear', args, kwargs)
538
+
539
+ @suppress_traceback
540
+ def unproject_cv(*args, **kwargs):
541
+ if TYPE_CHECKING: # redirected to:
542
+ utils3d.numpy.unproject_cv, utils3d.torch.unproject_cv
543
+ return _call_based_on_args('unproject_cv', args, kwargs)
544
+
545
+ @suppress_traceback
546
+ def unproject_gl(*args, **kwargs):
547
+ if TYPE_CHECKING: # redirected to:
548
+ utils3d.numpy.unproject_gl, utils3d.torch.unproject_gl
549
+ return _call_based_on_args('unproject_gl', args, kwargs)
550
+
551
+ @suppress_traceback
552
+ def project_cv(*args, **kwargs):
553
+ if TYPE_CHECKING: # redirected to:
554
+ utils3d.numpy.project_cv, utils3d.torch.project_cv
555
+ return _call_based_on_args('project_cv', args, kwargs)
556
+
557
+ @suppress_traceback
558
+ def project_gl(*args, **kwargs):
559
+ if TYPE_CHECKING: # redirected to:
560
+ utils3d.numpy.project_gl, utils3d.torch.project_gl
561
+ return _call_based_on_args('project_gl', args, kwargs)
562
+
563
+ @suppress_traceback
564
+ def quaternion_to_matrix(*args, **kwargs):
565
+ if TYPE_CHECKING: # redirected to:
566
+ utils3d.numpy.quaternion_to_matrix, utils3d.torch.quaternion_to_matrix
567
+ return _call_based_on_args('quaternion_to_matrix', args, kwargs)
568
+
569
+ @suppress_traceback
570
+ def axis_angle_to_matrix(*args, **kwargs):
571
+ if TYPE_CHECKING: # redirected to:
572
+ utils3d.numpy.axis_angle_to_matrix, utils3d.torch.axis_angle_to_matrix
573
+ return _call_based_on_args('axis_angle_to_matrix', args, kwargs)
574
+
575
+ @suppress_traceback
576
+ def matrix_to_quaternion(*args, **kwargs):
577
+ if TYPE_CHECKING: # redirected to:
578
+ utils3d.numpy.matrix_to_quaternion, utils3d.torch.matrix_to_quaternion
579
+ return _call_based_on_args('matrix_to_quaternion', args, kwargs)
580
+
581
+ @suppress_traceback
582
+ def extrinsics_to_essential(*args, **kwargs):
583
+ if TYPE_CHECKING: # redirected to:
584
+ utils3d.numpy.extrinsics_to_essential, utils3d.torch.extrinsics_to_essential
585
+ return _call_based_on_args('extrinsics_to_essential', args, kwargs)
586
+
587
+ @suppress_traceback
588
+ def euler_axis_angle_rotation(*args, **kwargs):
589
+ if TYPE_CHECKING: # redirected to:
590
+ utils3d.numpy.euler_axis_angle_rotation, utils3d.torch.euler_axis_angle_rotation
591
+ return _call_based_on_args('euler_axis_angle_rotation', args, kwargs)
592
+
593
+ @suppress_traceback
594
+ def euler_angles_to_matrix(*args, **kwargs):
595
+ if TYPE_CHECKING: # redirected to:
596
+ utils3d.numpy.euler_angles_to_matrix, utils3d.torch.euler_angles_to_matrix
597
+ return _call_based_on_args('euler_angles_to_matrix', args, kwargs)
598
+
599
+ @suppress_traceback
600
+ def skew_symmetric(*args, **kwargs):
601
+ if TYPE_CHECKING: # redirected to:
602
+ utils3d.numpy.skew_symmetric, utils3d.torch.skew_symmetric
603
+ return _call_based_on_args('skew_symmetric', args, kwargs)
604
+
605
+ @suppress_traceback
606
+ def rotation_matrix_from_vectors(*args, **kwargs):
607
+ if TYPE_CHECKING: # redirected to:
608
+ utils3d.numpy.rotation_matrix_from_vectors, utils3d.torch.rotation_matrix_from_vectors
609
+ return _call_based_on_args('rotation_matrix_from_vectors', args, kwargs)
610
+
611
+ @suppress_traceback
612
+ def ray_intersection(*args, **kwargs):
613
+ if TYPE_CHECKING: # redirected to:
614
+ utils3d.numpy.ray_intersection, None
615
+ return _call_based_on_args('ray_intersection', args, kwargs)
616
+
617
+ @suppress_traceback
618
+ def se3_matrix(*args, **kwargs):
619
+ if TYPE_CHECKING: # redirected to:
620
+ utils3d.numpy.se3_matrix, None
621
+ return _call_based_on_args('se3_matrix', args, kwargs)
622
+
623
+ @suppress_traceback
624
+ def slerp_quaternion(*args, **kwargs):
625
+ if TYPE_CHECKING: # redirected to:
626
+ utils3d.numpy.slerp_quaternion, None
627
+ return _call_based_on_args('slerp_quaternion', args, kwargs)
628
+
629
+ @suppress_traceback
630
+ def slerp_vector(*args, **kwargs):
631
+ if TYPE_CHECKING: # redirected to:
632
+ utils3d.numpy.slerp_vector, None
633
+ return _call_based_on_args('slerp_vector', args, kwargs)
634
+
635
+ @suppress_traceback
636
+ def lerp(*args, **kwargs):
637
+ if TYPE_CHECKING: # redirected to:
638
+ utils3d.numpy.lerp, None
639
+ return _call_based_on_args('lerp', args, kwargs)
640
+
641
+ @suppress_traceback
642
+ def lerp_se3_matrix(*args, **kwargs):
643
+ if TYPE_CHECKING: # redirected to:
644
+ utils3d.numpy.lerp_se3_matrix, None
645
+ return _call_based_on_args('lerp_se3_matrix', args, kwargs)
646
+
647
+ @suppress_traceback
648
+ def piecewise_lerp(*args, **kwargs):
649
+ if TYPE_CHECKING: # redirected to:
650
+ utils3d.numpy.piecewise_lerp, None
651
+ return _call_based_on_args('piecewise_lerp', args, kwargs)
652
+
653
+ @suppress_traceback
654
+ def piecewise_lerp_se3_matrix(*args, **kwargs):
655
+ if TYPE_CHECKING: # redirected to:
656
+ utils3d.numpy.piecewise_lerp_se3_matrix, None
657
+ return _call_based_on_args('piecewise_lerp_se3_matrix', args, kwargs)
658
+
659
+ @suppress_traceback
660
+ def apply_transform(*args, **kwargs):
661
+ if TYPE_CHECKING: # redirected to:
662
+ utils3d.numpy.apply_transform, None
663
+ return _call_based_on_args('apply_transform', args, kwargs)
664
+
665
+ @suppress_traceback
666
+ def linear_spline_interpolate(*args, **kwargs):
667
+ if TYPE_CHECKING: # redirected to:
668
+ utils3d.numpy.linear_spline_interpolate, None
669
+ return _call_based_on_args('linear_spline_interpolate', args, kwargs)
670
+
671
+ @suppress_traceback
672
+ def RastContext(*args, **kwargs):
673
+ if TYPE_CHECKING: # redirected to:
674
+ utils3d.numpy.RastContext, utils3d.torch.RastContext
675
+ return _call_based_on_args('RastContext', args, kwargs)
676
+
677
+ @suppress_traceback
678
+ def rasterize_triangle_faces(*args, **kwargs):
679
+ if TYPE_CHECKING: # redirected to:
680
+ utils3d.numpy.rasterize_triangle_faces, utils3d.torch.rasterize_triangle_faces
681
+ return _call_based_on_args('rasterize_triangle_faces', args, kwargs)
682
+
683
+ @suppress_traceback
684
+ def rasterize_edges(*args, **kwargs):
685
+ if TYPE_CHECKING: # redirected to:
686
+ utils3d.numpy.rasterize_edges, None
687
+ return _call_based_on_args('rasterize_edges', args, kwargs)
688
+
689
+ @suppress_traceback
690
+ def texture(*args, **kwargs):
691
+ if TYPE_CHECKING: # redirected to:
692
+ utils3d.numpy.texture, None
693
+ return _call_based_on_args('texture', args, kwargs)
694
+
695
+ @suppress_traceback
696
+ def warp_image_by_depth(*args, **kwargs):
697
+ if TYPE_CHECKING: # redirected to:
698
+ utils3d.numpy.warp_image_by_depth, utils3d.torch.warp_image_by_depth
699
+ return _call_based_on_args('warp_image_by_depth', args, kwargs)
700
+
701
+ @suppress_traceback
702
+ def test_rasterization(*args, **kwargs):
703
+ if TYPE_CHECKING: # redirected to:
704
+ utils3d.numpy.test_rasterization, None
705
+ return _call_based_on_args('test_rasterization', args, kwargs)
706
+
707
+ @suppress_traceback
708
+ def compute_face_angles(*args, **kwargs):
709
+ if TYPE_CHECKING: # redirected to:
710
+ None, utils3d.torch.compute_face_angles
711
+ return _call_based_on_args('compute_face_angles', args, kwargs)
712
+
713
+ @suppress_traceback
714
+ def compute_face_tbn(*args, **kwargs):
715
+ if TYPE_CHECKING: # redirected to:
716
+ None, utils3d.torch.compute_face_tbn
717
+ return _call_based_on_args('compute_face_tbn', args, kwargs)
718
+
719
+ @suppress_traceback
720
+ def compute_vertex_tbn(*args, **kwargs):
721
+ if TYPE_CHECKING: # redirected to:
722
+ None, utils3d.torch.compute_vertex_tbn
723
+ return _call_based_on_args('compute_vertex_tbn', args, kwargs)
724
+
725
+ @suppress_traceback
726
+ def laplacian(*args, **kwargs):
727
+ if TYPE_CHECKING: # redirected to:
728
+ None, utils3d.torch.laplacian
729
+ return _call_based_on_args('laplacian', args, kwargs)
730
+
731
+ @suppress_traceback
732
+ def laplacian_smooth_mesh(*args, **kwargs):
733
+ if TYPE_CHECKING: # redirected to:
734
+ None, utils3d.torch.laplacian_smooth_mesh
735
+ return _call_based_on_args('laplacian_smooth_mesh', args, kwargs)
736
+
737
+ @suppress_traceback
738
+ def taubin_smooth_mesh(*args, **kwargs):
739
+ if TYPE_CHECKING: # redirected to:
740
+ None, utils3d.torch.taubin_smooth_mesh
741
+ return _call_based_on_args('taubin_smooth_mesh', args, kwargs)
742
+
743
+ @suppress_traceback
744
+ def laplacian_hc_smooth_mesh(*args, **kwargs):
745
+ if TYPE_CHECKING: # redirected to:
746
+ None, utils3d.torch.laplacian_hc_smooth_mesh
747
+ return _call_based_on_args('laplacian_hc_smooth_mesh', args, kwargs)
748
+
749
+ @suppress_traceback
750
+ def get_rays(*args, **kwargs):
751
+ if TYPE_CHECKING: # redirected to:
752
+ None, utils3d.torch.get_rays
753
+ return _call_based_on_args('get_rays', args, kwargs)
754
+
755
+ @suppress_traceback
756
+ def get_image_rays(*args, **kwargs):
757
+ if TYPE_CHECKING: # redirected to:
758
+ None, utils3d.torch.get_image_rays
759
+ return _call_based_on_args('get_image_rays', args, kwargs)
760
+
761
+ @suppress_traceback
762
+ def get_mipnerf_cones(*args, **kwargs):
763
+ if TYPE_CHECKING: # redirected to:
764
+ None, utils3d.torch.get_mipnerf_cones
765
+ return _call_based_on_args('get_mipnerf_cones', args, kwargs)
766
+
767
+ @suppress_traceback
768
+ def volume_rendering(*args, **kwargs):
769
+ if TYPE_CHECKING: # redirected to:
770
+ None, utils3d.torch.volume_rendering
771
+ return _call_based_on_args('volume_rendering', args, kwargs)
772
+
773
+ @suppress_traceback
774
+ def bin_sample(*args, **kwargs):
775
+ if TYPE_CHECKING: # redirected to:
776
+ None, utils3d.torch.bin_sample
777
+ return _call_based_on_args('bin_sample', args, kwargs)
778
+
779
+ @suppress_traceback
780
+ def importance_sample(*args, **kwargs):
781
+ if TYPE_CHECKING: # redirected to:
782
+ None, utils3d.torch.importance_sample
783
+ return _call_based_on_args('importance_sample', args, kwargs)
784
+
785
+ @suppress_traceback
786
+ def nerf_render_rays(*args, **kwargs):
787
+ if TYPE_CHECKING: # redirected to:
788
+ None, utils3d.torch.nerf_render_rays
789
+ return _call_based_on_args('nerf_render_rays', args, kwargs)
790
+
791
+ @suppress_traceback
792
+ def mipnerf_render_rays(*args, **kwargs):
793
+ if TYPE_CHECKING: # redirected to:
794
+ None, utils3d.torch.mipnerf_render_rays
795
+ return _call_based_on_args('mipnerf_render_rays', args, kwargs)
796
+
797
+ @suppress_traceback
798
+ def nerf_render_view(*args, **kwargs):
799
+ if TYPE_CHECKING: # redirected to:
800
+ None, utils3d.torch.nerf_render_view
801
+ return _call_based_on_args('nerf_render_view', args, kwargs)
802
+
803
+ @suppress_traceback
804
+ def mipnerf_render_view(*args, **kwargs):
805
+ if TYPE_CHECKING: # redirected to:
806
+ None, utils3d.torch.mipnerf_render_view
807
+ return _call_based_on_args('mipnerf_render_view', args, kwargs)
808
+
809
+ @suppress_traceback
810
+ def InstantNGP(*args, **kwargs):
811
+ if TYPE_CHECKING: # redirected to:
812
+ None, utils3d.torch.InstantNGP
813
+ return _call_based_on_args('InstantNGP', args, kwargs)
814
+
815
+ @suppress_traceback
816
+ def point_to_normal(*args, **kwargs):
817
+ if TYPE_CHECKING: # redirected to:
818
+ None, utils3d.torch.point_to_normal
819
+ return _call_based_on_args('point_to_normal', args, kwargs)
820
+
821
+ @suppress_traceback
822
+ def depth_to_normal(*args, **kwargs):
823
+ if TYPE_CHECKING: # redirected to:
824
+ None, utils3d.torch.depth_to_normal
825
+ return _call_based_on_args('depth_to_normal', args, kwargs)
826
+
827
+ @suppress_traceback
828
+ def masked_min(*args, **kwargs):
829
+ if TYPE_CHECKING: # redirected to:
830
+ None, utils3d.torch.masked_min
831
+ return _call_based_on_args('masked_min', args, kwargs)
832
+
833
+ @suppress_traceback
834
+ def masked_max(*args, **kwargs):
835
+ if TYPE_CHECKING: # redirected to:
836
+ None, utils3d.torch.masked_max
837
+ return _call_based_on_args('masked_max', args, kwargs)
838
+
839
+ @suppress_traceback
840
+ def bounding_rect(*args, **kwargs):
841
+ if TYPE_CHECKING: # redirected to:
842
+ None, utils3d.torch.bounding_rect
843
+ return _call_based_on_args('bounding_rect', args, kwargs)
844
+
845
+ @suppress_traceback
846
+ def intrinsics_from_fov_xy(*args, **kwargs):
847
+ if TYPE_CHECKING: # redirected to:
848
+ None, utils3d.torch.intrinsics_from_fov_xy
849
+ return _call_based_on_args('intrinsics_from_fov_xy', args, kwargs)
850
+
851
+ @suppress_traceback
852
+ def matrix_to_euler_angles(*args, **kwargs):
853
+ if TYPE_CHECKING: # redirected to:
854
+ None, utils3d.torch.matrix_to_euler_angles
855
+ return _call_based_on_args('matrix_to_euler_angles', args, kwargs)
856
+
857
+ @suppress_traceback
858
+ def matrix_to_axis_angle(*args, **kwargs):
859
+ if TYPE_CHECKING: # redirected to:
860
+ None, utils3d.torch.matrix_to_axis_angle
861
+ return _call_based_on_args('matrix_to_axis_angle', args, kwargs)
862
+
863
+ @suppress_traceback
864
+ def axis_angle_to_quaternion(*args, **kwargs):
865
+ if TYPE_CHECKING: # redirected to:
866
+ None, utils3d.torch.axis_angle_to_quaternion
867
+ return _call_based_on_args('axis_angle_to_quaternion', args, kwargs)
868
+
869
+ @suppress_traceback
870
+ def quaternion_to_axis_angle(*args, **kwargs):
871
+ if TYPE_CHECKING: # redirected to:
872
+ None, utils3d.torch.quaternion_to_axis_angle
873
+ return _call_based_on_args('quaternion_to_axis_angle', args, kwargs)
874
+
875
+ @suppress_traceback
876
+ def slerp(*args, **kwargs):
877
+ if TYPE_CHECKING: # redirected to:
878
+ None, utils3d.torch.slerp
879
+ return _call_based_on_args('slerp', args, kwargs)
880
+
881
+ @suppress_traceback
882
+ def interpolate_extrinsics(*args, **kwargs):
883
+ if TYPE_CHECKING: # redirected to:
884
+ None, utils3d.torch.interpolate_extrinsics
885
+ return _call_based_on_args('interpolate_extrinsics', args, kwargs)
886
+
887
+ @suppress_traceback
888
+ def interpolate_view(*args, **kwargs):
889
+ if TYPE_CHECKING: # redirected to:
890
+ None, utils3d.torch.interpolate_view
891
+ return _call_based_on_args('interpolate_view', args, kwargs)
892
+
893
+ @suppress_traceback
894
+ def to4x4(*args, **kwargs):
895
+ if TYPE_CHECKING: # redirected to:
896
+ None, utils3d.torch.to4x4
897
+ return _call_based_on_args('to4x4', args, kwargs)
898
+
899
+ @suppress_traceback
900
+ def rotation_matrix_2d(*args, **kwargs):
901
+ if TYPE_CHECKING: # redirected to:
902
+ None, utils3d.torch.rotation_matrix_2d
903
+ return _call_based_on_args('rotation_matrix_2d', args, kwargs)
904
+
905
+ @suppress_traceback
906
+ def rotate_2d(*args, **kwargs):
907
+ if TYPE_CHECKING: # redirected to:
908
+ None, utils3d.torch.rotate_2d
909
+ return _call_based_on_args('rotate_2d', args, kwargs)
910
+
911
+ @suppress_traceback
912
+ def translate_2d(*args, **kwargs):
913
+ if TYPE_CHECKING: # redirected to:
914
+ None, utils3d.torch.translate_2d
915
+ return _call_based_on_args('translate_2d', args, kwargs)
916
+
917
+ @suppress_traceback
918
+ def scale_2d(*args, **kwargs):
919
+ if TYPE_CHECKING: # redirected to:
920
+ None, utils3d.torch.scale_2d
921
+ return _call_based_on_args('scale_2d', args, kwargs)
922
+
923
+ @suppress_traceback
924
+ def apply_2d(*args, **kwargs):
925
+ if TYPE_CHECKING: # redirected to:
926
+ None, utils3d.torch.apply_2d
927
+ return _call_based_on_args('apply_2d', args, kwargs)
928
+
929
+ @suppress_traceback
930
+ def warp_image_by_forward_flow(*args, **kwargs):
931
+ if TYPE_CHECKING: # redirected to:
932
+ None, utils3d.torch.warp_image_by_forward_flow
933
+ return _call_based_on_args('warp_image_by_forward_flow', args, kwargs)
934
+
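Note on the wrappers above: at type-checking time each function only names its NumPy and PyTorch counterparts, and at run time everything is forwarded to _call_based_on_args, which picks a backend from the argument types (entries listed as None have no implementation for that backend). A minimal, self-contained sketch of this dispatch pattern (the names _numpy_normalize and registry below are illustrative, not part of the library):

    import numpy as np

    def _numpy_normalize(x: np.ndarray) -> np.ndarray:
        # Stand-in for a utils3d.numpy.* implementation.
        return x / np.linalg.norm(x, axis=-1, keepdims=True)

    registry = {'numpy': {'normalize': _numpy_normalize}, 'torch': {}}

    def call_based_on_args_sketch(name, args, kwargs):
        # Route to the torch backend if any argument is a torch.Tensor, else to numpy.
        backend = 'numpy'
        try:
            import torch
            if any(isinstance(a, torch.Tensor) for a in (*args, *kwargs.values())):
                backend = 'torch'
        except ImportError:
            pass
        return registry[backend][name](*args, **kwargs)

    print(call_based_on_args_sketch('normalize', (np.array([3.0, 4.0]),), {}))  # [0.6 0.8]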
src/utils3d/_unified/__init__.pyi ADDED
The diff for this file is too large to render. See raw diff
 
src/utils3d/io/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .obj import *
2
+ from .colmap import *
3
+ from .ply import *
src/utils3d/io/colmap.py ADDED
@@ -0,0 +1,139 @@
1
+ from typing import *
2
+ from pathlib import Path
3
+
4
+ import numpy as np
5
+ from scipy.spatial.transform import Rotation
6
+
7
+
8
+ __all__ = ['read_extrinsics_from_colmap', 'read_intrinsics_from_colmap', 'write_extrinsics_as_colmap', 'write_intrinsics_as_colmap']
9
+
10
+
11
+ def write_extrinsics_as_colmap(file: Union[str, Path], extrinsics: np.ndarray, image_names: Union[str, List[str]] = 'image_{i:04d}.png', camera_ids: List[int] = None):
12
+ """
13
+ Write extrinsics to colmap `images.txt` file.
14
+ Args:
15
+ file: Path to `images.txt` file.
16
+ extrinsics: (N, 4, 4) array of extrinsics.
17
+ image_names: str or List of str, image names. Length is N.
18
+ If str, it should be a format string with `i` as the index. (i starts from 1, in correspondence with IMAGE_ID in colmap)
19
+ camera_ids: List of int, camera ids. Length is N.
20
+ If None, it will be set to [1, 2, ..., N].
21
+ """
22
+ assert extrinsics.shape[1:] == (4, 4) and extrinsics.ndim == 3 or extrinsics.shape == (4, 4)
23
+ if extrinsics.ndim == 2:
24
+ extrinsics = extrinsics[np.newaxis, ...]
25
+ quats = Rotation.from_matrix(extrinsics[:, :3, :3]).as_quat()
26
+ trans = extrinsics[:, :3, 3]
27
+ if camera_ids is None:
28
+ camera_ids = list(range(1, len(extrinsics) + 1))
29
+ if isinstance(image_names, str):
30
+ image_names = [image_names.format(i=i) for i in range(1, len(extrinsics) + 1)]
31
+ assert len(extrinsics) == len(image_names) == len(camera_ids), \
32
+ f'Number of extrinsics ({len(extrinsics)}), image_names ({len(image_names)}), and camera_ids ({len(camera_ids)}) must be the same'
33
+ with open(file, 'w') as fp:
34
+ print("# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME", file=fp)
35
+ for i, (quat, t, name, camera_id) in enumerate(zip(quats.tolist(), trans.tolist(), image_names, camera_ids)):
36
+ # Colmap has wxyz order while scipy.spatial.transform.Rotation has xyzw order.
37
+ qx, qy, qz, qw = quat
38
+ tx, ty, tz = t
39
+ print(f'{i + 1} {qw:f} {qx:f} {qy:f} {qz:f} {tx:f} {ty:f} {tz:f} {camera_id:d} {name}', file=fp)
40
+ print()
41
+
42
+
43
+ def write_intrinsics_as_colmap(file: Union[str, Path], intrinsics: np.ndarray, width: int, height: int, normalized: bool = False):
44
+ """
45
+ Write intrinsics to colmap `cameras.txt` file. Currently only supports the PINHOLE model (no distortion).
46
+ Args:
47
+ file: Path to `cameras.txt` file.
48
+ intrinsics: (N, 3, 3) array of intrinsics.
49
+ width: Image width.
50
+ height: Image height.
51
+ normalized: Whether the intrinsics are normalized. If True, they will be unnormalized (scaled back to pixel units) before writing.
52
+ """
53
+ assert intrinsics.shape[1:] == (3, 3) and intrinsics.ndim == 3 or intrinsics.shape == (3, 3)
54
+ if intrinsics.ndim == 2:
55
+ intrinsics = intrinsics[np.newaxis, ...]
56
+ if normalized:
57
+ intrinsics = intrinsics * np.array([width, height, 1])[:, None]
58
+ with open(file, 'w') as fp:
59
+ print("# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]", file=fp)
60
+ for i, intr in enumerate(intrinsics):
61
+ fx, fy, cx, cy = intr[0, 0], intr[1, 1], intr[0, 2], intr[1, 2]
62
+ print(f'{i + 1} PINHOLE {width:d} {height:d} {fx:f} {fy:f} {cx:f} {cy:f}', file=fp)
63
+
64
+
65
+ def read_extrinsics_from_colmap(file: Union[str, Path]) -> Tuple[np.ndarray, List[int], List[str]]:
66
+ """
67
+ Read extrinsics from colmap `images.txt` file.
68
+ Args:
69
+ file: Path to `images.txt` file.
70
+ Returns:
71
+ extrinsics: (N, 4, 4) array of extrinsics.
72
+ camera_ids: List of int, camera ids. Length is N. Note that camera ids in colmap typically start from 1.
73
+ image_names: List of str, image names. Length is N.
74
+ """
75
+ with open(file) as fp:
76
+ lines = fp.readlines()
77
+ image_names, quats, trans, camera_ids = [], [], [], []
78
+ i_line = 0
79
+ for line in lines:
80
+ line = line.strip()
81
+ if line.startswith('#'):
82
+ continue
83
+ i_line += 1
84
+ if i_line % 2 == 0:
85
+ continue
86
+ image_id, qw, qx, qy, qz, tx, ty, tz, camera_id, name = line.split()
87
+ quats.append([float(qx), float(qy), float(qz), float(qw)])
88
+ trans.append([float(tx), float(ty), float(tz)])
89
+ camera_ids.append(int(camera_id))
90
+ image_names.append(name)
91
+
92
+ quats = np.array(quats, dtype=np.float32)
93
+ trans = np.array(trans, dtype=np.float32)
94
+ rotation = Rotation.from_quat(quats).as_matrix()
95
+ extrinsics = np.concatenate([
96
+ np.concatenate([rotation, trans[..., None]], axis=-1),
97
+ np.array([0, 0, 0, 1], dtype=np.float32)[None, None, :].repeat(len(quats), axis=0)
98
+ ], axis=-2)
99
+
100
+ return extrinsics, camera_ids, image_names
101
+
102
+
103
+ def read_intrinsics_from_colmap(file: Union[str, Path], normalize: bool = False) -> Tuple[List[int], np.ndarray, np.ndarray]:
104
+ """
105
+ Read intrinsics from colmap `cameras.txt` file.
106
+ Args:
107
+ file: Path to `cameras.txt` file.
108
+ normalize: Whether to normalize the intrinsics. If True, the intrinsics will be normalized. (mapping coordinates to [0, 1] range)
109
+ Returns:
110
+ camera_ids: List of int, camera ids. Length is N. Note that camera ids in colmap typically start from 1.
111
+ intrinsics: (N, 3, 3) array of intrinsics.
112
+ distortions: (N, 5) array of distortions.
113
+ """
114
+ with open(file) as fp:
115
+ lines = fp.readlines()
116
+ intrinsics, distortions, camera_ids = [], [], []
117
+ for line in lines:
118
+ line = line.strip()
119
+ if not line or line.startswith('#'):
120
+ continue
121
+ camera_id, model, width, height, *params = line.split()
122
+ camera_id, width, height = int(camera_id), int(width), int(height)
123
+ if model == 'PINHOLE':
124
+ fx, fy, cx, cy = map(float, params[:4])
125
+ k1 = k2 = k3 = p1 = p2 = 0.0
126
+ elif model == 'OPENCV':
127
+ fx, fy, cx, cy, k1, k2, p1, p2, k3 = *map(float, params[:8]), 0.0
128
+ elif model == 'SIMPLE_RADIAL':
129
+ f, cx, cy, k = map(float, params[:4])
130
+ fx = fy = f
131
+ k1, k2, p1, p2, k3 = k, 0.0, 0.0, 0.0, 0.0
132
+ camera_ids.append(camera_id)
133
+ if normalize:
134
+ fx, fy, cx, cy = fx / width, fy / height, cx / width, cy / height
135
+ intrinsics.append([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
136
+ distortions.append([k1, k2, p1, p2, k3])
137
+ intrinsics = np.array(intrinsics, dtype=np.float32)
138
+ distortions = np.array(distortions, dtype=np.float32)
139
+ return camera_ids, intrinsics, distortions
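A usage sketch for the COLMAP I/O helpers above (the import path src.utils3d.io.colmap is an assumption based on this repo's layout; note that write_extrinsics_as_colmap emits one pose line per image, whereas a full COLMAP images.txt also carries a POINTS2D line per image):

    import numpy as np
    from src.utils3d.io.colmap import (
        write_intrinsics_as_colmap, read_intrinsics_from_colmap, write_extrinsics_as_colmap,
    )

    # One pinhole camera and two identity world-to-camera poses.
    intrinsics = np.array([[500.0, 0.0, 320.0],
                           [0.0, 500.0, 240.0],
                           [0.0, 0.0, 1.0]], dtype=np.float32)
    extrinsics = np.repeat(np.eye(4, dtype=np.float32)[None], 2, axis=0)

    write_intrinsics_as_colmap('cameras.txt', intrinsics, width=640, height=480)
    write_extrinsics_as_colmap('images.txt', extrinsics, camera_ids=[1, 1])

    camera_ids, K, dist = read_intrinsics_from_colmap('cameras.txt')
    print(camera_ids, K.shape, dist.shape)  # [1] (1, 3, 3) (1, 5)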
src/utils3d/io/obj.py ADDED
@@ -0,0 +1,146 @@
1
+ from io import TextIOWrapper
2
+ from typing import Dict, Any, Union, Iterable
3
+ import numpy as np
4
+ from pathlib import Path
5
+
6
+ __all__ = [
7
+ 'read_obj',
8
+ 'write_obj',
9
+ 'simple_write_obj'
10
+ ]
11
+
12
+ def read_obj(
13
+ file : Union[str, Path, TextIOWrapper],
14
+ encoding: Union[str, None] = None,
15
+ ignore_unknown: bool = False
16
+ ):
17
+ """
18
+ Read wavefront .obj file, without preprocessing.
19
+
20
+ Why bother having this read_obj() when libraries like `trimesh` already exist?
21
+ This function reads the raw .obj content and keeps the original order of vertices and faces,
22
+ whereas trimesh may merge or split vertices, which can break that ordering.
23
+ Libraries like trimesh mainly target geometry processing and rendering and support many formats;
24
+ if you need mesh geometry processing, `trimesh` offers far more features.
25
+
26
+ ### Parameters
27
+ `file` (str, Path, TextIOWrapper): filepath or file object
28
+ encoding (str, optional): text encoding used to open the file. Defaults to None.
29
+
30
+ ### Returns
31
+ obj (dict): A dict containing .obj components
32
+ {
33
+ 'mtllib': [],
34
+ 'v': [[0.1, 0.2, 1.0], [1.2, 0.0, 0.0], ...],
35
+ 'vt': [[0.5, 0.5], ...],
36
+ 'vn': [[0., 0.7, 0.7], [0., -0.7, 0.7], ...],
37
+ 'f': [[0, 1, 2], [2, 3, 4],...],
38
+ 'usemtl': [('mtl1', 7)]
39
+ }
40
+ """
41
+ if hasattr(file,'read'):
42
+ lines = file.read().splitlines()
43
+ else:
44
+ with open(file, 'r', encoding=encoding) as fp:
45
+ lines = fp.read().splitlines()
46
+ mtllib = []
47
+ v, vt, vn, vp = [], [], [], [] # Vertex coordinates, Vertex texture coordinate, Vertex normal, Vertex parameter
48
+ f, ft, fn = [], [], [] # Face indices, Face texture indices, Face normal indices
49
+ o = []
50
+ s = []
51
+ usemtl = []
52
+
53
+ def pad(l: list, n: Any):
54
+ return l + [n] * (3 - len(l))
55
+
56
+ for i, line in enumerate(lines):
57
+ sq = line.strip().split()
58
+ if len(sq) == 0:
59
+ continue
60
+ if sq[0] == 'v':
61
+ assert 4 <= len(sq) <= 5, f'Invalid format of line {i}: {line}'
62
+ v.append([float(e) for e in sq[1:]][:3])
63
+ elif sq[0] == 'vt':
64
+ assert 3 <= len(sq) <= 4, f'Invalid format of line {i}: {line}'
65
+ vt.append([float(e) for e in sq[1:]][:2])
66
+ elif sq[0] == 'vn':
67
+ assert len(sq) == 4, f'Invalid format of line {i}: {line}'
68
+ vn.append([float(e) for e in sq[1:]])
69
+ elif sq[0] == 'vp':
70
+ assert 2 <= len(sq) <= 4, f'Invalid format of line {i}: {line}'
71
+ vp.append(pad([float(e) for e in sq[1:]], 0))
72
+ elif sq[0] == 'f':
73
+ spliting = [pad([int(j) - 1 for j in e.split('/')], -1) for e in sq[1:]]
74
+ f.append([e[0] for e in spliting])
75
+ ft.append([e[1] for e in spliting])
76
+ fn.append([e[2] for e in spliting])
77
+ elif sq[0] == 'usemtl':
78
+ assert len(sq) == 2
79
+ usemtl.append((sq[1], len(f)))
80
+ elif sq[0] == 'o':
81
+ assert len(sq) == 2
82
+ o.append((sq[1], len(f)))
83
+ elif sq[0] == 's':
84
+ s.append((sq[1], len(f)))
85
+ elif sq[0] == 'mtllib':
86
+ assert len(sq) == 2
87
+ mtllib.append(sq[1])
88
+ elif sq[0][0] == '#':
89
+ continue
90
+ else:
91
+ if not ignore_unknown:
92
+ raise Exception(f'Unknown keyword {sq[0]}')
93
+
94
+ min_poly_vertices = min(len(face) for face in f)
95
+ max_poly_vertices = max(len(face) for face in f)
96
+
97
+ return {
98
+ 'mtllib': mtllib,
99
+ 'v': np.array(v, dtype=np.float32),
100
+ 'vt': np.array(vt, dtype=np.float32),
101
+ 'vn': np.array(vn, dtype=np.float32),
102
+ 'vp': np.array(vp, dtype=np.float32),
103
+ 'f': np.array(f, dtype=np.int32) if min_poly_vertices == max_poly_vertices else f,
104
+ 'ft': np.array(ft, dtype=np.int32) if min_poly_vertices == max_poly_vertices else ft,
105
+ 'fn': np.array(fn, dtype=np.int32) if min_poly_vertices == max_poly_vertices else fn,
106
+ 'o': o,
107
+ 's': s,
108
+ 'usemtl': usemtl,
109
+ }
110
+
111
+
112
+ def write_obj(
113
+ file: Union[str, Path],
114
+ obj: Dict[str, Any],
115
+ encoding: Union[str, None] = None
116
+ ):
117
+ with open(file, 'w', encoding=encoding) as fp:
118
+ for k in ['v', 'vt', 'vn', 'vp']:
119
+ if k not in obj:
120
+ continue
121
+ for v in obj[k]:
122
+ print(k, *map(float, v), file=fp)
123
+ for f in obj['f']:
124
+ print('f', *('/'.join(str(int(j) + 1) for j in i) if isinstance(i, Iterable) else str(int(i) + 1) for i in f), file=fp)  # .obj face indices are 1-based
125
+
126
+
127
+ def simple_write_obj(
128
+ file: Union[str, Path],
129
+ vertices: np.ndarray,
130
+ faces: np.ndarray,
131
+ encoding: Union[str, None] = None
132
+ ):
133
+ """
134
+ Write wavefront .obj file, without preprocessing.
135
+
136
+ Args:
137
+ vertices (np.ndarray): [N, 3]
138
+ faces (np.ndarray): [T, 3]
139
+ file (Any): filepath
140
+ encoding (str, optional): text encoding used to open the file. Defaults to None.
141
+ """
142
+ with open(file, 'w', encoding=encoding) as fp:
143
+ for v in vertices:
144
+ print('v', *map(float, v), file=fp)
145
+ for f in faces:
146
+ print('f', *map(int, f + 1), file=fp)
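A round-trip sketch for the .obj helpers above (the import path src.utils3d.io.obj is an assumption based on this repo's layout):

    import numpy as np
    from src.utils3d.io.obj import simple_write_obj, read_obj

    # One triangle; simple_write_obj converts the 0-based indices to the 1-based .obj convention.
    vertices = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
    faces = np.array([[0, 1, 2]])
    simple_write_obj('triangle.obj', vertices, faces)

    # read_obj keeps the raw ordering and converts face indices back to 0-based.
    obj = read_obj('triangle.obj')
    print(obj['v'].shape, obj['f'].tolist())  # (3, 3) [[0, 1, 2]]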
src/utils3d/io/ply.py ADDED
@@ -0,0 +1,104 @@
1
+ import numpy as np
2
+
3
+ from typing import *
4
+ from pathlib import Path
5
+
6
+
7
+ def read_ply(
8
+ file: Union[str, Path],
9
+ encoding: Union[str, None] = None,
10
+ ignore_unknown: bool = False
11
+ ) -> Tuple[np.ndarray, np.ndarray]:
12
+ """
13
+ Read .ply file, without preprocessing.
14
+
15
+ Args:
16
+ file (Any): filepath
17
+ encoding (str, optional):
18
+
19
+ Returns:
20
+ Tuple[np.ndarray, np.ndarray]: vertices, faces
21
+ """
22
+ import plyfile
23
+ plydata = plyfile.PlyData.read(file)
24
+ vertices = np.stack([plydata['vertex'][k] for k in ['x', 'y', 'z']], axis=-1)
25
+ if 'face' in plydata:
26
+ faces = np.array(plydata['face']['vertex_indices'].tolist())
27
+ else:
28
+ faces = None
29
+ return vertices, faces
30
+
31
+
32
+ def write_ply(
33
+ file: Union[str, Path],
34
+ vertices: np.ndarray,
35
+ faces: np.ndarray = None,
36
+ edges: np.ndarray = None,
37
+ vertex_colors: np.ndarray = None,
38
+ edge_colors: np.ndarray = None,
39
+ text: bool = False
40
+ ):
41
+ """
42
+ Write .ply file, without preprocessing.
43
+
44
+ Args:
45
+ file (Any): filepath
46
+ vertices (np.ndarray): [N, 3]
47
+ faces (np.ndarray): [T, P] face indices
48
+ edges (np.ndarray): [E, 2]
49
+ vertex_colors (np.ndarray, optional): [N, 3]. Defaults to None.
50
+ edge_colors (np.ndarray, optional): [E, 3]. Defaults to None.
51
+ text (bool, optional): save data in text format. Defaults to False.
52
+ """
53
+ import plyfile
54
+ assert vertices.ndim == 2 and vertices.shape[1] == 3
55
+ vertices = vertices.astype(np.float32)
56
+ if faces is not None:
57
+ assert faces.ndim == 2
58
+ faces = faces.astype(np.int32)
59
+ if edges is not None:
60
+ assert edges.ndim == 2 and edges.shape[1] == 2
61
+ edges = edges.astype(np.int32)
62
+
63
+ if vertex_colors is not None:
64
+ assert vertex_colors.ndim == 2 and vertex_colors.shape[1] == 3
65
+ if vertex_colors.dtype in [np.float32, np.float64]:
66
+ vertex_colors = vertex_colors * 255
67
+ vertex_colors = np.clip(vertex_colors, 0, 255).astype(np.uint8)
68
+ vertices_data = np.zeros(len(vertices), dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])
69
+ vertices_data['x'] = vertices[:, 0]
70
+ vertices_data['y'] = vertices[:, 1]
71
+ vertices_data['z'] = vertices[:, 2]
72
+ vertices_data['red'] = vertex_colors[:, 0]
73
+ vertices_data['green'] = vertex_colors[:, 1]
74
+ vertices_data['blue'] = vertex_colors[:, 2]
75
+ else:
76
+ vertices_data = np.array([tuple(v) for v in vertices], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
77
+
78
+ if faces is not None:
79
+ faces_data = np.zeros(len(faces), dtype=[('vertex_indices', 'i4', (faces.shape[1],))])
80
+ faces_data['vertex_indices'] = faces
81
+
82
+ if edges is not None:
83
+ if edge_colors is not None:
84
+ assert edge_colors.ndim == 2 and edge_colors.shape[1] == 3
85
+ if edge_colors.dtype in [np.float32, np.float64]:
86
+ edge_colors = edge_colors * 255
87
+ edge_colors = np.clip(edge_colors, 0, 255).astype(np.uint8)
88
+ edges_data = np.zeros(len(edges), dtype=[('vertex1', 'i4'), ('vertex2', 'i4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])
89
+ edges_data['vertex1'] = edges[:, 0]
90
+ edges_data['vertex2'] = edges[:, 1]
91
+ edges_data['red'] = edge_colors[:, 0]
92
+ edges_data['green'] = edge_colors[:, 1]
93
+ edges_data['blue'] = edge_colors[:, 2]
94
+ else:
95
+ edges_data = np.array([tuple(e) for e in edges], dtype=[('vertex1', 'i4'), ('vertex2', 'i4')])
96
+
97
+ ply_data = [plyfile.PlyElement.describe(vertices_data, 'vertex')]
98
+ if faces is not None:
99
+ ply_data.append(plyfile.PlyElement.describe(faces_data, 'face'))
100
+ if edges is not None:
101
+ ply_data.append(plyfile.PlyElement.describe(edges_data, 'edge'))
102
+
103
+ plyfile.PlyData(ply_data, text=text).write(file)
104
+
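A usage sketch for write_ply/read_ply above (import path assumed from this repo's layout; the plyfile package must be installed, since both functions import it at call time):

    import numpy as np
    from src.utils3d.io.ply import write_ply, read_ply

    vertices = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
    faces = np.array([[0, 1, 2]])
    colors = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])  # float colors are scaled to 0..255

    write_ply('triangle.ply', vertices, faces=faces, vertex_colors=colors, text=True)
    v, f = read_ply('triangle.ply')
    print(v.shape, f.shape)  # (3, 3) (1, 3)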
src/utils3d/numpy/__init__.py ADDED
@@ -0,0 +1,142 @@
1
+ """
2
+ 3D utility functions working with NumPy.
3
+ """
4
+ import importlib
5
+ import itertools
6
+ import numpy
7
+ from typing import TYPE_CHECKING
8
+
9
+
10
+ __modules_all__ = {
11
+ 'mesh':[
12
+ 'triangulate',
13
+ 'compute_face_normal',
14
+ 'compute_face_angle',
15
+ 'compute_vertex_normal',
16
+ 'compute_vertex_normal_weighted',
17
+ 'remove_corrupted_faces',
18
+ 'merge_duplicate_vertices',
19
+ 'remove_unreferenced_vertices',
20
+ 'subdivide_mesh_simple',
21
+ 'mesh_relations',
22
+ 'flatten_mesh_indices'
23
+ ],
24
+ 'quadmesh': [
25
+ 'calc_quad_candidates',
26
+ 'calc_quad_distortion',
27
+ 'calc_quad_direction',
28
+ 'calc_quad_smoothness',
29
+ 'sovle_quad',
30
+ 'sovle_quad_qp',
31
+ 'tri_to_quad'
32
+ ],
33
+ 'utils': [
34
+ 'sliding_window_1d',
35
+ 'sliding_window_nd',
36
+ 'sliding_window_2d',
37
+ 'max_pool_1d',
38
+ 'max_pool_2d',
39
+ 'max_pool_nd',
40
+ 'depth_edge',
41
+ 'normals_edge',
42
+ 'depth_aliasing',
43
+ 'interpolate',
44
+ 'image_scrcoord',
45
+ 'image_uv',
46
+ 'image_pixel_center',
47
+ 'image_pixel',
48
+ 'image_mesh',
49
+ 'image_mesh_from_depth',
50
+ 'depth_to_normals',
51
+ 'points_to_normals',
52
+ 'chessboard',
53
+ 'cube',
54
+ 'icosahedron',
55
+ 'square',
56
+ 'camera_frustum',
57
+ ],
58
+ 'transforms': [
59
+ 'perspective',
60
+ 'perspective_from_fov',
61
+ 'perspective_from_fov_xy',
62
+ 'intrinsics_from_focal_center',
63
+ 'intrinsics_from_fov',
64
+ 'fov_to_focal',
65
+ 'focal_to_fov',
66
+ 'intrinsics_to_fov',
67
+ 'view_look_at',
68
+ 'extrinsics_look_at',
69
+ 'perspective_to_intrinsics',
70
+ 'perspective_to_near_far',
71
+ 'intrinsics_to_perspective',
72
+ 'extrinsics_to_view',
73
+ 'view_to_extrinsics',
74
+ 'normalize_intrinsics',
75
+ 'crop_intrinsics',
76
+ 'pixel_to_uv',
77
+ 'pixel_to_ndc',
78
+ 'uv_to_pixel',
79
+ 'project_depth',
80
+ 'depth_buffer_to_linear',
81
+ 'unproject_cv',
82
+ 'unproject_gl',
83
+ 'project_cv',
84
+ 'project_gl',
85
+ 'quaternion_to_matrix',
86
+ 'axis_angle_to_matrix',
87
+ 'matrix_to_quaternion',
88
+ 'extrinsics_to_essential',
89
+ 'euler_axis_angle_rotation',
90
+ 'euler_angles_to_matrix',
91
+ 'skew_symmetric',
92
+ 'rotation_matrix_from_vectors',
93
+ 'ray_intersection',
94
+ 'se3_matrix',
95
+ 'slerp_quaternion',
96
+ 'slerp_vector',
97
+ 'lerp',
98
+ 'lerp_se3_matrix',
99
+ 'piecewise_lerp',
100
+ 'piecewise_lerp_se3_matrix',
101
+ 'apply_transform'
102
+ ],
103
+ 'spline': [
104
+ 'linear_spline_interpolate',
105
+ ],
106
+ 'rasterization': [
107
+ 'RastContext',
108
+ 'rasterize_triangle_faces',
109
+ 'rasterize_edges',
110
+ 'texture',
111
+ 'warp_image_by_depth',
112
+ 'test_rasterization'
113
+ ],
114
+ }
115
+
116
+
117
+ __all__ = list(itertools.chain(*__modules_all__.values()))
118
+
119
+ def __getattr__(name):
120
+ try:
121
+ return globals()[name]
122
+ except KeyError:
123
+ pass
124
+
125
+ try:
126
+ module_name = next(m for m in __modules_all__ if name in __modules_all__[m])
127
+ except StopIteration:
128
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
129
+ module = importlib.import_module(f'.{module_name}', __name__)
130
+ for key in __modules_all__[module_name]:
131
+ globals()[key] = getattr(module, key)
132
+
133
+ return globals()[name]
134
+
135
+
136
+ if TYPE_CHECKING:
137
+ from .quadmesh import *
138
+ from .transforms import *
139
+ from .mesh import *
140
+ from .utils import *
141
+ from .rasterization import *
142
+ from .spline import *
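The module-level __getattr__ above makes this subpackage lazy: nothing is imported until one of the listed names is first accessed, at which point the owning submodule is imported and all of its names are cached in the package globals. A small sketch (import path assumed from this repo's layout):

    import numpy as np
    import src.utils3d.numpy as u3d_np   # importing the package does not yet import .mesh, .transforms, ...

    faces = np.array([[0, 1, 2, 3]])     # one quad
    tris = u3d_np.triangulate(faces)     # first access triggers the import of .mesh
    print(tris.tolist())                 # [[0, 1, 2], [0, 2, 3]]
    print('triangulate' in vars(u3d_np)) # True: the name is now cached in the package globals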
src/utils3d/numpy/_helpers.py ADDED
@@ -0,0 +1,93 @@
1
+ # decorator
2
+ import numpy as np
3
+ from numbers import Number
4
+ import inspect
5
+ from functools import wraps
6
+ from typing import *
7
+ from .._helpers import suppress_traceback
8
+
9
+
10
+ def get_args_order(func, args, kwargs):
11
+ """
12
+ Get the order of the arguments of a function.
13
+ """
14
+ names = inspect.getfullargspec(func).args
15
+ names_idx = {name: i for i, name in enumerate(names)}
16
+ args_order = []
17
+ kwargs_order = {}
18
+ for name, arg in kwargs.items():
19
+ if name in names:
20
+ kwargs_order[name] = names_idx[name]
21
+ names.remove(name)
22
+ for i, arg in enumerate(args):
23
+ if i < len(names):
24
+ args_order.append(names_idx[names[i]])
25
+ return args_order, kwargs_order
26
+
27
+
28
+ def broadcast_args(args, kwargs, args_dim, kwargs_dim):
29
+ spatial = []
30
+ for arg, arg_dim in zip(args + list(kwargs.values()), args_dim + list(kwargs_dim.values())):
31
+ if isinstance(arg, np.ndarray) and arg_dim is not None:
32
+ arg_spatial = arg.shape[:arg.ndim-arg_dim]
33
+ if len(arg_spatial) > len(spatial):
34
+ spatial = [1] * (len(arg_spatial) - len(spatial)) + spatial
35
+ for j in range(len(arg_spatial)):
36
+ if spatial[-j] < arg_spatial[-j]:
37
+ if spatial[-j] == 1:
38
+ spatial[-j] = arg_spatial[-j]
39
+ else:
40
+ raise ValueError("Cannot broadcast arguments.")
41
+ for i, arg in enumerate(args):
42
+ if isinstance(arg, np.ndarray) and args_dim[i] is not None:
43
+ args[i] = np.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-args_dim[i]:]])
44
+ for key, arg in kwargs.items():
45
+ if isinstance(arg, np.ndarray) and kwargs_dim[key] is not None:
46
+ kwargs[key] = np.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-kwargs_dim[key]:]])
47
+ return args, kwargs, spatial
48
+
49
+
50
+ def batched(*dims):
51
+ """
52
+ Decorator that allows a function to be called with batched arguments.
53
+ """
54
+ def decorator(func):
55
+ @wraps(func)
56
+ @suppress_traceback
57
+ def wrapper(*args, **kwargs):
58
+ args = list(args)
59
+ # get arguments dimensions
60
+ args_order, kwargs_order = get_args_order(func, args, kwargs)
61
+ args_dim = [dims[i] for i in args_order]
62
+ kwargs_dim = {key: dims[i] for key, i in kwargs_order.items()}
63
+ # convert to numpy array
64
+ for i, arg in enumerate(args):
65
+ if isinstance(arg, (Number, list, tuple)) and args_dim[i] is not None:
66
+ args[i] = np.array(arg)
67
+ for key, arg in kwargs.items():
68
+ if isinstance(arg, (Number, list, tuple)) and kwargs_dim[key] is not None:
69
+ kwargs[key] = np.array(arg)
70
+ # broadcast arguments
71
+ args, kwargs, spatial = broadcast_args(args, kwargs, args_dim, kwargs_dim)
72
+ for i, (arg, arg_dim) in enumerate(zip(args, args_dim)):
73
+ if isinstance(arg, np.ndarray) and arg_dim is not None:
74
+ args[i] = arg.reshape([-1, *arg.shape[arg.ndim-arg_dim:]])
75
+ for key, arg in kwargs.items():
76
+ if isinstance(arg, np.ndarray) and kwargs_dim[key] is not None:
77
+ kwargs[key] = arg.reshape([-1, *arg.shape[arg.ndim-kwargs_dim[key]:]])
78
+ # call function
79
+ results = func(*args, **kwargs)
80
+ type_results = type(results)
81
+ results = list(results) if isinstance(results, (tuple, list)) else [results]
82
+ # restore spatial dimensions
83
+ for i, result in enumerate(results):
84
+ results[i] = result.reshape([*spatial, *result.shape[1:]])
85
+ if type_results == tuple:
86
+ results = tuple(results)
87
+ elif type_results == list:
88
+ results = list(results)
89
+ else:
90
+ results = results[0]
91
+ return results
92
+ return wrapper
93
+ return decorator
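The batched decorator above declares, per positional argument, how many trailing "core" dimensions it has (None for non-array arguments); all leading dimensions are broadcast, flattened into a single batch axis before the call, and restored on the outputs. A small sketch (the function row_sums is illustrative; import path assumed from this repo's layout):

    import numpy as np
    from src.utils3d.numpy._helpers import batched

    @batched(2, None)
    def row_sums(points: np.ndarray, scale: float) -> np.ndarray:
        # Inside the wrapper, `points` arrives with its leading dims flattened
        # into one batch axis (here shape (35, 4, 3)).
        return scale * points.sum(axis=-1)

    pts = np.ones((5, 7, 4, 3))   # arbitrary leading (batch) dimensions
    out = row_sums(pts, 2.0)
    print(out.shape)              # (5, 7, 4): leading dimensions are restored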
src/utils3d/numpy/mesh.py ADDED
@@ -0,0 +1,355 @@
1
+ import numpy as np
2
+ from typing import *
3
+ from ._helpers import batched
4
+
5
+
6
+ __all__ = [
7
+ 'triangulate',
8
+ 'compute_face_normal',
9
+ 'compute_face_angle',
10
+ 'compute_vertex_normal',
11
+ 'compute_vertex_normal_weighted',
12
+ 'remove_corrupted_faces',
13
+ 'merge_duplicate_vertices',
14
+ 'remove_unreferenced_vertices',
15
+ 'subdivide_mesh_simple',
16
+ 'mesh_relations',
17
+ 'flatten_mesh_indices'
18
+ ]
19
+
20
+
21
+ def triangulate(
22
+ faces: np.ndarray,
23
+ vertices: np.ndarray = None,
24
+ backslash: np.ndarray = None
25
+ ) -> np.ndarray:
26
+ """
27
+ Triangulate a polygonal mesh.
28
+
29
+ Args:
30
+ faces (np.ndarray): [L, P] polygonal faces
31
+ vertices (np.ndarray, optional): [N, 3] 3-dimensional vertices.
32
+ If given, the triangulation is performed according to the distance
33
+ between vertices. Defaults to None.
34
+ backslash (np.ndarray, optional): [L] boolean array indicating
35
+ how to triangulate the quad faces. Defaults to None.
36
+
37
+ Returns:
38
+ (np.ndarray): [L * (P - 2), 3] triangular faces
39
+ """
40
+ if faces.shape[-1] == 3:
41
+ return faces
42
+ P = faces.shape[-1]
43
+ if vertices is not None:
44
+ assert faces.shape[-1] == 4, "now only support quad mesh"
45
+ if backslash is None:
46
+ backslash = np.linalg.norm(vertices[faces[:, 0]] - vertices[faces[:, 2]], axis=-1) < \
47
+ np.linalg.norm(vertices[faces[:, 1]] - vertices[faces[:, 3]], axis=-1)
48
+ if backslash is None:
49
+ loop_indice = np.stack([
50
+ np.zeros(P - 2, dtype=int),
51
+ np.arange(1, P - 1, 1, dtype=int),
52
+ np.arange(2, P, 1, dtype=int)
53
+ ], axis=1)
54
+ return faces[:, loop_indice].reshape((-1, 3))
55
+ else:
56
+ assert faces.shape[-1] == 4, "now only support quad mesh"
57
+ faces = np.where(
58
+ backslash[:, None],
59
+ faces[:, [0, 1, 2, 0, 2, 3]],
60
+ faces[:, [0, 1, 3, 3, 1, 2]]
61
+ ).reshape((-1, 3))
62
+ return faces
63
+
64
+
65
+ @batched(2, None)
66
+ def compute_face_normal(
67
+ vertices: np.ndarray,
68
+ faces: np.ndarray
69
+ ) -> np.ndarray:
70
+ """
71
+ Compute face normals of a triangular mesh
72
+
73
+ Args:
74
+ vertices (np.ndarray): [..., N, 3] 3-dimensional vertices
75
+ faces (np.ndarray): [T, 3] triangular face indices
76
+
77
+ Returns:
78
+ normals (np.ndarray): [..., T, 3] face normals
79
+ """
80
+ normal = np.cross(
81
+ vertices[..., faces[:, 1], :] - vertices[..., faces[:, 0], :],
82
+ vertices[..., faces[:, 2], :] - vertices[..., faces[:, 0], :]
83
+ )
84
+ normal_norm = np.linalg.norm(normal, axis=-1, keepdims=True)
85
+ normal_norm[normal_norm == 0] = 1
86
+ normal /= normal_norm
87
+ return normal
88
+
89
+
90
+ @batched(2, None)
91
+ def compute_face_angle(
92
+ vertices: np.ndarray,
93
+ faces: np.ndarray,
94
+ eps: float = 1e-12
95
+ ) -> np.ndarray:
96
+ """
97
+ Compute face angles of a triangular mesh
98
+
99
+ Args:
100
+ vertices (np.ndarray): [..., N, 3] 3-dimensional vertices
101
+ faces (np.ndarray): [T, 3] triangular face indices
102
+
103
+ Returns:
104
+ angles (np.ndarray): [..., T, 3] face angles
105
+ """
106
+ face_angle = np.zeros_like(faces, dtype=vertices.dtype)
107
+ for i in range(3):
108
+ edge1 = vertices[..., faces[:, (i + 1) % 3], :] - vertices[..., faces[:, i], :]
109
+ edge2 = vertices[..., faces[:, (i + 2) % 3], :] - vertices[..., faces[:, i], :]
110
+ face_angle[..., i] = np.arccos(np.sum(
111
+ edge1 / np.clip(np.linalg.norm(edge1, axis=-1, keepdims=True), eps, None) *
112
+ edge2 / np.clip(np.linalg.norm(edge2, axis=-1, keepdims=True), eps, None),
113
+ axis=-1
114
+ ))
115
+ return face_angle
116
+
117
+
118
+ @batched(2, None, 2)
119
+ def compute_vertex_normal(
120
+ vertices: np.ndarray,
121
+ faces: np.ndarray,
122
+ face_normal: np.ndarray = None
123
+ ) -> np.ndarray:
124
+ """
125
+ Compute vertex normals of a triangular mesh by averaging neighboring face normals
126
+ TODO: can be improved.
127
+
128
+ Args:
129
+ vertices (np.ndarray): [..., N, 3] 3-dimensional vertices
130
+ faces (np.ndarray): [T, 3] triangular face indices
131
+ face_normal (np.ndarray, optional): [..., T, 3] face normals.
132
+ None to compute face normals from vertices and faces. Defaults to None.
133
+
134
+ Returns:
135
+ normals (np.ndarray): [..., N, 3] vertex normals
136
+ """
137
+ if face_normal is None:
138
+ face_normal = compute_face_normal(vertices, faces)
139
+ vertex_normal = np.zeros_like(vertices, dtype=vertices.dtype)
140
+ for n in range(vertices.shape[0]):
141
+ for i in range(3):
142
+ vertex_normal[n, :, 0] += np.bincount(faces[:, i], weights=face_normal[n, :, 0], minlength=vertices.shape[1])
143
+ vertex_normal[n, :, 1] += np.bincount(faces[:, i], weights=face_normal[n, :, 1], minlength=vertices.shape[1])
144
+ vertex_normal[n, :, 2] += np.bincount(faces[:, i], weights=face_normal[n, :, 2], minlength=vertices.shape[1])
145
+ vertex_normal_norm = np.linalg.norm(vertex_normal, axis=-1, keepdims=True)
146
+ vertex_normal_norm[vertex_normal_norm == 0] = 1
147
+ vertex_normal /= vertex_normal_norm
148
+ return vertex_normal
149
+
150
+
151
+ @batched(2, None, 2)
152
+ def compute_vertex_normal_weighted(
153
+ vertices: np.ndarray,
154
+ faces: np.ndarray,
155
+ face_normal: np.ndarray = None
156
+ ) -> np.ndarray:
157
+ """
158
+ Compute vertex normals of a triangular mesh by weighted sum of neighboring face normals
159
+ according to the angles
160
+
161
+ Args:
162
+ vertices (np.ndarray): [..., N, 3] 3-dimensional vertices
163
+ faces (np.ndarray): [..., T, 3] triangular face indices
164
+ face_normal (np.ndarray, optional): [..., T, 3] face normals.
165
+ None to compute face normals from vertices and faces. Defaults to None.
166
+
167
+ Returns:
168
+ normals (np.ndarray): [..., N, 3] vertex normals
169
+ """
170
+ if face_normal is None:
171
+ face_normal = compute_face_normal(vertices, faces)
172
+ face_angle = compute_face_angle(vertices, faces)
173
+ vertex_normal = np.zeros_like(vertices)
174
+ for n in range(vertices.shape[0]):
175
+ for i in range(3):
176
+ vertex_normal[n, :, 0] += np.bincount(faces[n, :, i], weights=face_normal[n, :, 0] * face_angle[n, :, i], minlength=vertices.shape[1])
177
+ vertex_normal[n, :, 1] += np.bincount(faces[n, :, i], weights=face_normal[n, :, 1] * face_angle[n, :, i], minlength=vertices.shape[1])
178
+ vertex_normal[n, :, 2] += np.bincount(faces[n, :, i], weights=face_normal[n, :, 2] * face_angle[n, :, i], minlength=vertices.shape[1])
179
+ vertex_normal_norm = np.linalg.norm(vertex_normal, axis=-1, keepdims=True)
180
+ vertex_normal_norm[vertex_normal_norm == 0] = 1
181
+ vertex_normal /= vertex_normal_norm
182
+ return vertex_normal
183
+
184
+
185
+ def remove_corrupted_faces(
186
+ faces: np.ndarray
187
+ ) -> np.ndarray:
188
+ """
189
+ Remove corrupted faces (faces with duplicated vertices)
190
+
191
+ Args:
192
+ faces (np.ndarray): [T, 3] triangular face indices
193
+
194
+ Returns:
195
+ np.ndarray: [T_, 3] triangular face indices
196
+ """
197
+ corrupted = (faces[:, 0] == faces[:, 1]) | (faces[:, 1] == faces[:, 2]) | (faces[:, 2] == faces[:, 0])
198
+ return faces[~corrupted]
199
+
200
+
201
+ def merge_duplicate_vertices(
202
+ vertices: np.ndarray,
203
+ faces: np.ndarray,
204
+ tol: float = 1e-6
205
+ ) -> Tuple[np.ndarray, np.ndarray]:
206
+ """
207
+ Merge duplicate vertices of a triangular mesh.
208
+ Duplicate vertices are merged by selecting one of them, and the face indices are updated accordingly.
209
+
210
+ Args:
211
+ vertices (np.ndarray): [N, 3] 3-dimensional vertices
212
+ faces (np.ndarray): [T, 3] triangular face indices
213
+ tol (float, optional): tolerance for merging. Defaults to 1e-6.
214
+
215
+ Returns:
216
+ vertices (np.ndarray): [N_, 3] 3-dimensional vertices
217
+ faces (np.ndarray): [T, 3] triangular face indices
218
+ """
219
+ vertices_round = np.round(vertices / tol)
220
+ _, uni_i, uni_inv = np.unique(vertices_round, return_index=True, return_inverse=True, axis=0)
221
+ vertices = vertices[uni_i]
222
+ faces = uni_inv[faces]
223
+ return vertices, faces
224
+
225
+
226
+ def remove_unreferenced_vertices(
227
+ faces: np.ndarray,
228
+ *vertice_attrs,
229
+ return_indices: bool = False
230
+ ) -> Tuple[np.ndarray, ...]:
231
+ """
232
+ Remove unreferenced vertices of a mesh.
233
+ Unreferenced vertices are removed, and the face indices are updated accordingly.
234
+
235
+ Args:
236
+ faces (np.ndarray): [T, P] face indices
237
+ *vertice_attrs: vertex attributes
238
+
239
+ Returns:
240
+ faces (np.ndarray): [T, P] face indices
241
+ *vertice_attrs: vertex attributes
242
+ indices (np.ndarray, optional): [N] indices of vertices that are kept. Defaults to None.
243
+ """
244
+ P = faces.shape[-1]
245
+ fewer_indices, inv_map = np.unique(faces, return_inverse=True)
246
+ faces = inv_map.astype(np.int32).reshape(-1, P)
247
+ ret = [faces]
248
+ for attr in vertice_attrs:
249
+ ret.append(attr[fewer_indices])
250
+ if return_indices:
251
+ ret.append(fewer_indices)
252
+ return tuple(ret)
253
+
254
+
255
+ def subdivide_mesh_simple(
256
+ vertices: np.ndarray,
257
+ faces: np.ndarray,
258
+ n: int = 1
259
+ ) -> Tuple[np.ndarray, np.ndarray]:
260
+ """
261
+ Subdivide a triangular mesh by splitting each triangle into 4 smaller triangles.
262
+ NOTE: All original vertices are kept, and new vertices are appended to the end of the vertex list.
263
+
264
+ Args:
265
+ vertices (np.ndarray): [N, 3] 3-dimensional vertices
266
+ faces (np.ndarray): [T, 3] triangular face indices
267
+ n (int, optional): number of subdivisions. Defaults to 1.
268
+
269
+ Returns:
270
+ vertices (np.ndarray): [N_, 3] subdivided 3-dimensional vertices
271
+ faces (np.ndarray): [4 * T, 3] subdivided triangular face indices
272
+ """
273
+ for _ in range(n):
274
+ edges = np.stack([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], axis=0)
275
+ edges = np.sort(edges, axis=2)
276
+ uni_edges, uni_inv = np.unique(edges.reshape(-1, 2), return_inverse=True, axis=0)
277
+ uni_inv = uni_inv.reshape(3, -1)
278
+ midpoints = (vertices[uni_edges[:, 0]] + vertices[uni_edges[:, 1]]) / 2
279
+
280
+ n_vertices = vertices.shape[0]
281
+ vertices = np.concatenate([vertices, midpoints], axis=0)
282
+ faces = np.concatenate([
283
+ np.stack([faces[:, 0], n_vertices + uni_inv[0], n_vertices + uni_inv[2]], axis=1),
284
+ np.stack([faces[:, 1], n_vertices + uni_inv[1], n_vertices + uni_inv[0]], axis=1),
285
+ np.stack([faces[:, 2], n_vertices + uni_inv[2], n_vertices + uni_inv[1]], axis=1),
286
+ np.stack([n_vertices + uni_inv[0], n_vertices + uni_inv[1], n_vertices + uni_inv[2]], axis=1),
287
+ ], axis=0)
288
+ return vertices, faces
289
+
290
+
291
+ def mesh_relations(
292
+ faces: np.ndarray,
293
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
294
+ """
295
+ Calculate the relation between vertices and faces.
296
+ NOTE: The input mesh must be a manifold triangle mesh.
297
+
298
+ Args:
299
+ faces (np.ndarray): [T, 3] triangular face indices
300
+
301
+ Returns:
302
+ edges (np.ndarray): [E, 2] edge indices
303
+ edge2face (np.ndarray): [E, 2] edge to face relation. The second column is -1 if the edge is boundary.
304
+ face2edge (np.ndarray): [T, 3] face to edge relation
305
+ face2face (np.ndarray): [T, 3] face to face relation
306
+ """
307
+ T = faces.shape[0]
308
+ edges = np.stack([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], axis=1).reshape(-1, 2) # [3T, 2]
309
+ edges = np.sort(edges, axis=1) # [3T, 2]
310
+ edges, face2edge, occurence = np.unique(edges, axis=0, return_inverse=True, return_counts=True) # [E, 2], [3T], [E]
311
+ E = edges.shape[0]
312
+ assert np.all(occurence <= 2), "The input mesh is not a manifold mesh."
313
+
314
+ # Edge to face relation
315
+ padding = np.arange(E, dtype=np.int32)[occurence == 1]
316
+ padded_face2edge = np.concatenate([face2edge, padding], axis=0) # [2E]
317
+ edge2face = np.argsort(padded_face2edge, kind='stable').reshape(-1, 2) // 3 # [E, 2]
318
+ edge2face_valid = edge2face[:, 1] < T # [E]
319
+ edge2face[~edge2face_valid, 1] = -1
320
+
321
+ # Face to edge relation
322
+ face2edge = face2edge.reshape(-1, 3) # [T, 3]
323
+
324
+ # Face to face relation
325
+ face2face = edge2face[face2edge] # [T, 3, 2]
326
+ face2face = face2face[face2face != np.arange(T)[:, None, None]].reshape(T, 3) # [T, 3]
327
+
328
+ return edges, edge2face, face2edge, face2face
329
+
330
+
331
+ @overload
332
+ def flatten_mesh_indices(faces1: np.ndarray, attr1: np.ndarray, *other_faces_attrs_pairs: np.ndarray) -> Tuple[np.ndarray, ...]:
333
+ """
334
+ Rearrange the indices of a mesh to a flattened version. Vertices will be no longer shared.
335
+
336
+ ### Parameters:
337
+ - `faces1`: [T, P] face indices of the first attribute
338
+ - `attr1`: [N1, ...] attributes of the first mesh
339
+ - ...
340
+
341
+ ### Returns:
342
+ - `faces`: [T, P] flattened face indices, contiguous from 0 to T * P - 1
343
+ - `attr1`: [T * P, ...] attributes of the first mesh, where every P values correspond to a face
344
+ - ...
345
+ """
346
+ def flatten_mesh_indices(*args: np.ndarray) -> Tuple[np.ndarray, ...]:
347
+ assert len(args) % 2 == 0, "The number of arguments must be even."
348
+ T, P = args[0].shape
349
+ assert all(arg.shape[0] == T and arg.shape[1] == P for arg in args[::2]), "The faces must have the same shape."
350
+ attr_flat = []
351
+ for faces_, attr_ in zip(args[::2], args[1::2]):
352
+ attr_flat_ = attr_[faces_].reshape(-1, *attr_.shape[1:])
353
+ attr_flat.append(attr_flat_)
354
+ faces_flat = np.arange(T * P, dtype=np.int32).reshape(T, P)
355
+ return faces_flat, *attr_flat
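A usage sketch for the mesh helpers above (import path assumed from this repo's layout):

    import numpy as np
    from src.utils3d.numpy.mesh import triangulate, compute_face_normal, compute_vertex_normal

    # A unit quad in the z = 0 plane.
    vertices = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]], dtype=np.float32)
    quad = np.array([[0, 1, 2, 3]])

    tris = triangulate(quad, vertices=vertices)   # -> (2, 3) triangle indices
    face_n = compute_face_normal(vertices, tris)  # both normals are (0, 0, 1)
    vert_n = compute_vertex_normal(vertices, tris)
    print(tris.shape, face_n.tolist(), vert_n.shape)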
src/utils3d/numpy/quadmesh.py ADDED
@@ -0,0 +1,472 @@
1
+ import numpy as np
2
+ import scipy as sp
3
+ import scipy.optimize as spopt
4
+ from typing import *
5
+
6
+
7
+ __all__ = [
8
+ 'calc_quad_candidates',
9
+ 'calc_quad_distortion',
10
+ 'calc_quad_direction',
11
+ 'calc_quad_smoothness',
12
+ 'sovle_quad',
13
+ 'sovle_quad_qp',
14
+ 'tri_to_quad'
15
+ ]
16
+
17
+
18
+ def calc_quad_candidates(
19
+ edges: np.ndarray,
20
+ face2edge: np.ndarray,
21
+ edge2face: np.ndarray,
22
+ ):
23
+ """
24
+ Calculate the candidate quad faces.
25
+
26
+ Args:
27
+ edges (np.ndarray): [E, 2] edge indices
28
+ face2edge (np.ndarray): [T, 3] face to edge relation
29
+ edge2face (np.ndarray): [E, 2] edge to face relation
30
+
31
+ Returns:
32
+ quads (np.ndarray): [Q, 4] quad candidate indices
33
+ quad2edge (np.ndarray): [Q, 4] edge to quad candidate relation
34
+ quad2adj (np.ndarray): [Q, 8] adjacent quad candidates of each quad candidate
35
+ quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid
36
+ """
37
+ E = edges.shape[0]
38
+ T = face2edge.shape[0]
39
+
40
+ quads_valid = edge2face[:, 1] != -1
41
+ Q = quads_valid.sum()
42
+ quad2face = edge2face[quads_valid] # [Q, 2]
43
+ quad2edge = face2edge[quad2face] # [Q, 2, 3]
44
+ flag = quad2edge == np.arange(E)[quads_valid][:, None, None] # [Q, 2, 3]
45
+ flag = flag.argmax(axis=-1) # [Q, 2]
46
+ quad2edge = np.stack([
47
+ quad2edge[np.arange(Q)[:, None], np.arange(2)[None, :], (flag + 1) % 3],
48
+ quad2edge[np.arange(Q)[:, None], np.arange(2)[None, :], (flag + 2) % 3],
49
+ ], axis=-1).reshape(Q, 4) # [Q, 4]
50
+
51
+ quads = np.concatenate([
52
+ np.where(
53
+ (edges[quad2edge[:, 0:1], 1:] == edges[quad2edge[:, 1:2], :]).any(axis=-1),
54
+ edges[quad2edge[:, 0:1], [[0, 1]]],
55
+ edges[quad2edge[:, 0:1], [[1, 0]]],
56
+ ),
57
+ np.where(
58
+ (edges[quad2edge[:, 2:3], 1:] == edges[quad2edge[:, 3:4], :]).any(axis=-1),
59
+ edges[quad2edge[:, 2:3], [[0, 1]]],
60
+ edges[quad2edge[:, 2:3], [[1, 0]]],
61
+ ),
62
+ ], axis=1) # [Q, 4]
63
+
64
+ quad2adj = edge2face[quad2edge] # [Q, 4, 2]
65
+ quad2adj = quad2adj[quad2adj != quad2face[:, [0,0,1,1], None]].reshape(Q, 4) # [Q, 4]
66
+ quad2adj_valid = quad2adj != -1
67
+ quad2adj = face2edge[quad2adj] # [Q, 4, 3]
68
+ quad2adj[~quad2adj_valid, 0] = quad2edge[~quad2adj_valid]
69
+ quad2adj[~quad2adj_valid, 1:] = -1
70
+ quad2adj = quad2adj[quad2adj != quad2edge[..., None]].reshape(Q, 8) # [Q, 8]
71
+ edge_valid = -np.ones(E, dtype=np.int32)
72
+ edge_valid[quads_valid] = np.arange(Q)
73
+ quad2adj_valid = quad2adj != -1
74
+ quad2adj[quad2adj_valid] = edge_valid[quad2adj[quad2adj_valid]] # [Q, 8]
75
+
76
+ return quads, quad2edge, quad2adj, quads_valid
77
+
78
+
79
+ def calc_quad_distortion(
+     vertices: np.ndarray,
+     quads: np.ndarray,
+ ):
+     """
+     Calculate the distortion of each candidate quad face.
+
+     Args:
+         vertices (np.ndarray): [N, 3] 3-dimensional vertices
+         quads (np.ndarray): [Q, 4] quad face indices
+
+     Returns:
+         distortion (np.ndarray): [Q] distortion of each quad face
+     """
+     edge0 = vertices[quads[:, 1]] - vertices[quads[:, 0]] # [Q, 3]
+     edge1 = vertices[quads[:, 2]] - vertices[quads[:, 1]] # [Q, 3]
+     edge2 = vertices[quads[:, 3]] - vertices[quads[:, 2]] # [Q, 3]
+     edge3 = vertices[quads[:, 0]] - vertices[quads[:, 3]] # [Q, 3]
+     cross = vertices[quads[:, 0]] - vertices[quads[:, 2]] # [Q, 3]
+
+     len0 = np.maximum(np.linalg.norm(edge0, axis=-1), 1e-10) # [Q]
+     len1 = np.maximum(np.linalg.norm(edge1, axis=-1), 1e-10) # [Q]
+     len2 = np.maximum(np.linalg.norm(edge2, axis=-1), 1e-10) # [Q]
+     len3 = np.maximum(np.linalg.norm(edge3, axis=-1), 1e-10) # [Q]
+     len_cross = np.maximum(np.linalg.norm(cross, axis=-1), 1e-10) # [Q]
+
+     angle0 = np.arccos(np.clip(np.sum(-edge0 * edge1, axis=-1) / (len0 * len1), -1, 1)) # [Q]
+     angle1 = np.arccos(np.clip(np.sum(-edge1 * cross, axis=-1) / (len1 * len_cross), -1, 1)) \
+         + np.arccos(np.clip(np.sum(cross * edge2, axis=-1) / (len_cross * len2), -1, 1)) # [Q]
+     angle2 = np.arccos(np.clip(np.sum(-edge2 * edge3, axis=-1) / (len2 * len3), -1, 1)) # [Q]
+     angle3 = np.arccos(np.clip(np.sum(-edge3 * -cross, axis=-1) / (len3 * len_cross), -1, 1)) \
+         + np.arccos(np.clip(np.sum(-cross * edge0, axis=-1) / (len_cross * len0), -1, 1)) # [Q]
+
+     normal0 = np.cross(edge0, edge1) # [Q, 3]
+     normal1 = np.cross(edge2, edge3) # [Q, 3]
+     normal0 = normal0 / np.maximum(np.linalg.norm(normal0, axis=-1, keepdims=True), 1e-10) # [Q, 3]
+     normal1 = normal1 / np.maximum(np.linalg.norm(normal1, axis=-1, keepdims=True), 1e-10) # [Q, 3]
+     angle_normal = np.arccos(np.clip(np.sum(normal0 * normal1, axis=-1), -1, 1)) # [Q]
+
+     D90 = np.pi / 2
+     D180 = np.pi
+     D360 = np.pi * 2
+     ang_eng = (np.abs(angle0 - D90)**2 + np.abs(angle1 - D90)**2 + np.abs(angle2 - D90)**2 + np.abs(angle3 - D90)**2) / 4 # [Q]
+     dist_eng = np.abs(angle0 - angle2)**2 / np.minimum(np.maximum(np.minimum(angle0, angle2), 1e-10), np.maximum(D180 - np.maximum(angle0, angle2), 1e-10)) \
+         + np.abs(angle1 - angle3)**2 / np.minimum(np.maximum(np.minimum(angle1, angle3), 1e-10), np.maximum(D180 - np.maximum(angle1, angle3), 1e-10)) # [Q]
+     plane_eng = np.where(angle_normal < D90/2, np.abs(angle_normal)**2, 1e10) # [Q]
+     eng = ang_eng + 2 * dist_eng + 2 * plane_eng # [Q]
+
+     return eng
+
+
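A quick, self-contained sanity check of the distortion energy above (illustrative only, not part of the committed file): a planar unit square should score near zero, while a strongly skewed quad scores much higher. The import path assumes `utils3d` (under `src/` in this repo) is on `PYTHONPATH`.

import numpy as np
from utils3d.numpy import quadmesh  # assumed import path

vertices = np.array([
    [0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0],
    [1.0, 1.0, 0.0],
    [0.0, 1.0, 0.0],  # corners of a unit square
    [3.0, 0.2, 0.0],  # far-away point used to skew the second quad
])
quads = np.array([[0, 1, 2, 3], [0, 1, 4, 3]])
print(quadmesh.calc_quad_distortion(vertices, quads))
# first entry ~0 (ideal square), second entry much larger (angles far from 90 degrees)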
+ def calc_quad_direction(
+     vertices: np.ndarray,
+     quads: np.ndarray,
+ ):
+     """
+     Calculate the direction of each candidate quad face.
+
+     Args:
+         vertices (np.ndarray): [N, 3] 3-dimensional vertices
+         quads (np.ndarray): [Q, 4] quad face indices
+
+     Returns:
+         direction (np.ndarray): [Q, 4] direction of each quad face.
+             Represented by the angle between the crossing and each edge.
+     """
+     mid0 = (vertices[quads[:, 0]] + vertices[quads[:, 1]]) / 2 # [Q, 3]
+     mid1 = (vertices[quads[:, 1]] + vertices[quads[:, 2]]) / 2 # [Q, 3]
+     mid2 = (vertices[quads[:, 2]] + vertices[quads[:, 3]]) / 2 # [Q, 3]
+     mid3 = (vertices[quads[:, 3]] + vertices[quads[:, 0]]) / 2 # [Q, 3]
+
+     cross0 = mid2 - mid0 # [Q, 3]
+     cross1 = mid3 - mid1 # [Q, 3]
+     cross0 = cross0 / np.maximum(np.linalg.norm(cross0, axis=-1, keepdims=True), 1e-10) # [Q, 3]
+     cross1 = cross1 / np.maximum(np.linalg.norm(cross1, axis=-1, keepdims=True), 1e-10) # [Q, 3]
+
+     edge0 = vertices[quads[:, 1]] - vertices[quads[:, 0]] # [Q, 3]
+     edge1 = vertices[quads[:, 2]] - vertices[quads[:, 1]] # [Q, 3]
+     edge2 = vertices[quads[:, 3]] - vertices[quads[:, 2]] # [Q, 3]
+     edge3 = vertices[quads[:, 0]] - vertices[quads[:, 3]] # [Q, 3]
+     edge0 = edge0 / np.maximum(np.linalg.norm(edge0, axis=-1, keepdims=True), 1e-10) # [Q, 3]
+     edge1 = edge1 / np.maximum(np.linalg.norm(edge1, axis=-1, keepdims=True), 1e-10) # [Q, 3]
+     edge2 = edge2 / np.maximum(np.linalg.norm(edge2, axis=-1, keepdims=True), 1e-10) # [Q, 3]
+     edge3 = edge3 / np.maximum(np.linalg.norm(edge3, axis=-1, keepdims=True), 1e-10) # [Q, 3]
+
+     direction = np.stack([
+         np.arccos(np.clip(np.sum(cross0 * edge0, axis=-1), -1, 1)),
+         np.arccos(np.clip(np.sum(cross1 * edge1, axis=-1), -1, 1)),
+         np.arccos(np.clip(np.sum(-cross0 * edge2, axis=-1), -1, 1)),
+         np.arccos(np.clip(np.sum(-cross1 * edge3, axis=-1), -1, 1)),
+     ], axis=-1) # [Q, 4]
+
+     return direction
+
+
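A small illustration of what `calc_quad_direction` returns (again illustrative, not part of the diff): for a planar unit square, each midpoint connector is perpendicular to the two edges it crosses, so all four direction angles come out as pi/2.

import numpy as np
from utils3d.numpy import quadmesh  # assumed import path

square = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 1.0, 0.0]])
print(quadmesh.calc_quad_direction(square, np.array([[0, 1, 2, 3]])))
# -> approximately [[1.5708, 1.5708, 1.5708, 1.5708]]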
+ def calc_quad_smoothness(
+     quad2edge: np.ndarray,
+     quad2adj: np.ndarray,
+     quads_direction: np.ndarray,
+ ):
+     """
+     Calculate the smoothness of each candidate quad face connection.
+
+     Args:
+         quad2edge (np.ndarray): [Q, 4] edge indices of each quad face
+         quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face
+         quads_direction (np.ndarray): [Q, 4] direction of each quad face
+
+     Returns:
+         smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection
+     """
+     Q = quad2adj.shape[0]
+     quad2adj_valid = quad2adj != -1
+     connections = np.stack([
+         np.arange(Q)[:, None].repeat(8, axis=1),
+         quad2adj,
+     ], axis=-1)[quad2adj_valid] # [C, 2]
+     shared_edge_idx_0 = np.array([[0, 0, 1, 1, 2, 2, 3, 3]]).repeat(Q, axis=0)[quad2adj_valid] # [C]
+     shared_edge_idx_1 = np.argmax(quad2edge[quad2adj][quad2adj_valid] == quad2edge[connections[:, 0], shared_edge_idx_0][:, None], axis=-1) # [C]
+     valid_smoothness = np.abs(quads_direction[connections[:, 0], shared_edge_idx_0] - quads_direction[connections[:, 1], shared_edge_idx_1])**2 # [C]
+     smoothness = np.zeros([Q, 8], dtype=np.float32)
+     smoothness[quad2adj_valid] = valid_smoothness
+     return smoothness
+
+
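A minimal toy input for `calc_quad_smoothness` (illustrative only; the index arrays are hand-built here, whereas in practice they come from `calc_quad_candidates`): two candidate quads meet across edge 0, and the connection smoothness is the squared difference between the direction angles each quad assigns to that shared edge.

import numpy as np
from utils3d.numpy import quadmesh  # assumed import path

quad2edge = np.array([[0, 1, 2, 3],
                      [0, 4, 5, 6]])              # [Q, 4] edge indices of each quad
quad2adj = np.full((2, 8), -1)
quad2adj[0, 0] = 1                                # quad 0 meets quad 1 across its edge slot 0
quad2adj[1, 0] = 0
quads_direction = np.array([[0.8, 1.6, 1.6, 1.6],
                            [1.0, 1.6, 1.6, 1.6]])  # [Q, 4] direction angles (radians)
s = quadmesh.calc_quad_smoothness(quad2edge, quad2adj, quads_direction)
print(s[0, 0], s[1, 0])  # both ~ (0.8 - 1.0)**2 = 0.04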
+ def sovle_quad(
+     face2edge: np.ndarray,
+     edge2face: np.ndarray,
+     quad2adj: np.ndarray,
+     quads_distortion: np.ndarray,
+     quads_smoothness: np.ndarray,
+     quads_valid: np.ndarray,
+ ):
+     """
+     Solve the quad mesh from the candidate quad faces.
+
+     Args:
+         face2edge (np.ndarray): [T, 3] face to edge relation
+         edge2face (np.ndarray): [E, 2] edge to face relation
+         quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face
+         quads_distortion (np.ndarray): [Q] distortion of each quad face
+         quads_smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection
+         quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid
+
+     Returns:
+         quads_weight (np.ndarray): [Q] weight of each valid quad face
+         conn_min_weight (np.ndarray): [C] minimum weight of each quad face connection
+         conn_max_weight (np.ndarray): [C] maximum weight of each quad face connection
+     """
+     T = face2edge.shape[0]
+     E = edge2face.shape[0]
+     Q = quads_distortion.shape[0]
+     edge_valid = -np.ones(E, dtype=np.int32)
+     edge_valid[quads_valid] = np.arange(Q)
+
+     quads_connection = np.stack([
+         np.arange(Q)[:, None].repeat(8, axis=1),
+         quad2adj,
+     ], axis=-1)[quad2adj != -1] # [C, 2]
+     quads_connection = np.sort(quads_connection, axis=-1) # [C, 2]
+     quads_connection, quads_connection_idx = np.unique(quads_connection, axis=0, return_index=True) # [C, 2], [C]
+     quads_smoothness = quads_smoothness[quad2adj != -1] # [C]
+     quads_smoothness = quads_smoothness[quads_connection_idx] # [C]
+     C = quads_connection.shape[0]
+
+     # Construct the linear programming problem
+
+     # Variables:
+     #   quads_weight: [Q] weight of each quad face
+     #   conn_min_weight: [C] minimum weight of each quad face connection
+     #   conn_max_weight: [C] maximum weight of each quad face connection
+     # Objective:
+     #   minimize (quads_distortion - 3) @ quads_weight
+     #       + (4 * quads_smoothness - 2) @ conn_min_weight
+     #       + 4 * quads_smoothness @ conn_max_weight
+
+     c = np.concatenate([
+         quads_distortion - 3,
+         quads_smoothness*4 - 2,
+         quads_smoothness*4,
+     ], axis=0) # [Q+2C]
+
+     A_ub_triplet = np.concatenate([
+         np.stack([np.arange(T), edge_valid[face2edge[:, 0]], np.ones(T)], axis=1), # [T, 3]
+         np.stack([np.arange(T), edge_valid[face2edge[:, 1]], np.ones(T)], axis=1), # [T, 3]
+         np.stack([np.arange(T), edge_valid[face2edge[:, 2]], np.ones(T)], axis=1), # [T, 3]
+         np.stack([np.arange(T, T+C), np.arange(Q, Q+C), np.ones(C)], axis=1), # [C, 3]
+         np.stack([np.arange(T, T+C), quads_connection[:, 0], -np.ones(C)], axis=1), # [C, 3]
+         np.stack([np.arange(T, T+C), quads_connection[:, 1], -np.ones(C)], axis=1), # [C, 3]
+         np.stack([np.arange(T+C, T+2*C), np.arange(Q+C, Q+2*C), -np.ones(C)], axis=1), # [C, 3]
+         np.stack([np.arange(T+C, T+2*C), quads_connection[:, 0], np.ones(C)], axis=1), # [C, 3]
+         np.stack([np.arange(T+C, T+2*C), quads_connection[:, 1], np.ones(C)], axis=1), # [C, 3]
+     ], axis=0) # [3T+6C, 3]
+     A_ub_triplet = A_ub_triplet[A_ub_triplet[:, 1] != -1] # [3T'+6C, 3]
+     A_ub = sp.sparse.coo_matrix((A_ub_triplet[:, 2], (A_ub_triplet[:, 0], A_ub_triplet[:, 1])), shape=[T+2*C, Q+2*C]) # [T+2C, Q+2C]
+     b_ub = np.concatenate([np.ones(T), -np.ones(C), np.ones(C)], axis=0) # [T+2C]
+     bound = np.stack([
+         np.concatenate([np.zeros(Q), -np.ones(C), np.zeros(C)], axis=0),
+         np.concatenate([np.ones(Q), np.ones(C), np.ones(C)], axis=0),
+     ], axis=1) # [Q+2C, 2]
+     A_eq = None
+     b_eq = None
+
+     print('Solver statistics:')
+     print(f' #T = {T}')
+     print(f' #Q = {Q}')
+     print(f' #C = {C}')
+
+     # Solve the linear programming problem
+     last_num_valid = 0
+     for i in range(100):
+         res_ = spopt.linprog(c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, bounds=bound)
+         if not res_.success:
+             print(f' Iter {i} | Failed with {res_.message}')
+             break
+         res = res_
+         weights = res.x[:Q]
+         valid = (weights > 0.5)
+         num_valid = valid.sum()
+         print(f' Iter {i} | #Q_valid = {num_valid}')
+         if num_valid == last_num_valid:
+             break
+         last_num_valid = num_valid
+         A_eq_triplet = np.stack([
+             np.arange(num_valid),
+             np.arange(Q)[valid],
+             np.ones(num_valid),
+         ], axis=1) # [num_valid, 3]
+         A_eq = sp.sparse.coo_matrix((A_eq_triplet[:, 2], (A_eq_triplet[:, 0], A_eq_triplet[:, 1])), shape=[num_valid, Q+2*C]) # [num_valid, Q+2C]
+         b_eq = np.where(weights[valid] > 0.5, 1, 0) # [num_valid]
+
+     # Return the result
+     quads_weight = res.x[:Q]
+     conn_min_weight = res.x[Q:Q+C]
+     conn_max_weight = res.x[Q+C:Q+2*C]
+     return quads_weight, conn_min_weight, conn_max_weight
+
+
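A possible post-check of the LP solution above (a sketch, not part of the committed file): the first T rows of `A_ub` say that the candidate quads keyed by a triangle's three edges may carry a total weight of at most 1, so after thresholding at 0.5 each triangle should be covered by at most one selected quad. The helper name `check_triangle_coverage` is made up for illustration.

import numpy as np

def check_triangle_coverage(face2edge, quads_valid, quads_weight, thresh=0.5):
    """Return True if no triangle has more than one selected candidate quad."""
    E = quads_valid.shape[0]
    Q = quads_weight.shape[0]
    edge_valid = -np.ones(E, dtype=np.int64)
    edge_valid[quads_valid] = np.arange(Q)            # edge index -> candidate quad index (or -1)
    selected = quads_weight > thresh                  # [Q] quads kept after thresholding
    quad_of_edge = edge_valid[face2edge]              # [T, 3] candidate quad keyed by each triangle edge
    per_tri = np.where(quad_of_edge >= 0, selected[quad_of_edge], False).sum(axis=-1)
    return bool((per_tri <= 1).all())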
+ def sovle_quad_qp(
+     face2edge: np.ndarray,
+     edge2face: np.ndarray,
+     quad2adj: np.ndarray,
+     quads_distortion: np.ndarray,
+     quads_smoothness: np.ndarray,
+     quads_valid: np.ndarray,
+ ):
+     """
+     Solve the quad mesh from the candidate quad faces.
+
+     Args:
+         face2edge (np.ndarray): [T, 3] face to edge relation
+         edge2face (np.ndarray): [E, 2] edge to face relation
+         quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face
+         quads_distortion (np.ndarray): [Q] distortion of each quad face
+         quads_smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection
+         quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid
+
+     Returns:
+         weights (np.ndarray): [Q] weight of each valid quad face
+     """
+     T = face2edge.shape[0]
+     E = edge2face.shape[0]
+     Q = quads_distortion.shape[0]
+     edge_valid = -np.ones(E, dtype=np.int32)
+     edge_valid[quads_valid] = np.arange(Q)
+
+     # Construct the quadratic programming problem
+     C_smoothness_triplet = np.stack([
+         np.arange(Q)[:, None].repeat(8, axis=1)[quad2adj != -1],
+         quad2adj[quad2adj != -1],
+         5 * quads_smoothness[quad2adj != -1],
+     ], axis=-1) # [C, 3]
+     # C_smoothness_triplet = np.concatenate([
+     #     C_smoothness_triplet,
+     #     np.stack([np.arange(Q), np.arange(Q), 20*np.ones(Q)], axis=1),
+     # ], axis=0) # [C+Q, 3]
+     C_smoothness = sp.sparse.coo_matrix((C_smoothness_triplet[:, 2], (C_smoothness_triplet[:, 0], C_smoothness_triplet[:, 1])), shape=[Q, Q]) # [Q, Q]
+     C_smoothness = C_smoothness.tocsc()
+     C_dist = quads_distortion - 20 # [Q]
+
+     A_eq = sp.sparse.coo_matrix((np.zeros(Q), (np.zeros(Q), np.arange(Q))), shape=[1, Q]) # [1, Q]
+     A_eq = A_eq.tocsc()
+     b_eq = np.array([0])
+
+     A_ub_triplet = np.concatenate([
+         np.stack([np.arange(T), edge_valid[face2edge[:, 0]], np.ones(T)], axis=1), # [T, 3]
+         np.stack([np.arange(T), edge_valid[face2edge[:, 1]], np.ones(T)], axis=1), # [T, 3]
+         np.stack([np.arange(T), edge_valid[face2edge[:, 2]], np.ones(T)], axis=1), # [T, 3]
+     ], axis=0) # [3T, 3]
+     A_ub_triplet = A_ub_triplet[A_ub_triplet[:, 1] != -1] # [3T', 3]
+     A_ub = sp.sparse.coo_matrix((A_ub_triplet[:, 2], (A_ub_triplet[:, 0], A_ub_triplet[:, 1])), shape=[T, Q]) # [T, Q]
+     A_ub = A_ub.tocsc()
+     b_ub = np.ones(T)
+
+     lb = np.zeros(Q)
+     ub = np.ones(Q)
+
+     import piqp
+     solver = piqp.SparseSolver()
+     solver.settings.verbose = True
+     solver.settings.compute_timings = True
+     solver.setup(C_smoothness, C_dist, A_eq, b_eq, A_ub, b_ub, lb, ub)
+
+     status = solver.solve()
+
+     # x = cp.Variable(Q)
+     # prob = cp.Problem(
+     #     cp.Minimize(cp.quad_form(x, C_smoothness) + C_dist.T @ x),
+     #     [
+     #         A_ub @ x <= b_ub,
+     #         x >= 0, x <= 1,
+     #     ]
+     # )
+
+     # # Solve the quadratic programming problem
+     # prob.solve(solver=cp.PIQP, verbose=True)
+
+     # Return the result
+     weights = solver.result.x
+     return weights
+
+
+ def tri_to_quad(
+     vertices: np.ndarray,
+     faces: np.ndarray,
+ ) -> Tuple[np.ndarray, np.ndarray]:
+     """
+     Convert a triangle mesh to a quad mesh.
+     NOTE: The input mesh must be a manifold mesh.
+
+     Args:
+         vertices (np.ndarray): [N, 3] 3-dimensional vertices
+         faces (np.ndarray): [T, 3] triangular face indices
+
+     Returns:
+         vertices (np.ndarray): [N_, 3] 3-dimensional vertices
+         faces (np.ndarray): [Q, 4] quad face indices
+     """
+     raise NotImplementedError
+
+
+ if __name__ == '__main__':
+     import os
+     import sys
+     sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')))
+     import utils3d
+     import numpy as np
+     import cv2
+     from vis import vis_edge_color
+
+     file = 'miku'
+
+     vertices, faces = utils3d.io.read_ply(f'test/assets/{file}.ply')
+     edges, edge2face, face2edge, face2face = calc_relations(faces)
+     quad_cands, quad2edge, quad2adj, quad_valid = calc_quad_candidates(edges, face2edge, edge2face)
+     distortion = calc_quad_distortion(vertices, quad_cands)
+     direction = calc_quad_direction(vertices, quad_cands)
+     smoothness = calc_quad_smoothness(quad2edge, quad2adj, direction)
+     boundary_edges = edges[edge2face[:, 1] == -1]
+     quads_weight, conn_min_weight, conn_max_weight = sovle_quad(face2edge, edge2face, quad2adj, distortion, smoothness, quad_valid)
+     quads = quad_cands[quads_weight > 0.5]
+     print('Mesh statistics')
+     print(f' #V = {vertices.shape[0]}')
+     print(f' #F = {faces.shape[0]}')
+     print(f' #E = {edges.shape[0]}')
+     print(f' #B = {boundary_edges.shape[0]}')
+     print(f' #Q_cand = {quad_cands.shape[0]}')
+     print(f' #Q = {quads.shape[0]}')
+
+     utils3d.io.write_ply(f'test/assets/{file}_boundary_edges.ply', vertices=vertices, edges=boundary_edges)
+     utils3d.io.write_ply(f'test/assets/{file}_quad_candidates.ply', vertices=vertices, faces=quads)
+
+     edge_colors = np.zeros([edges.shape[0], 3], dtype=np.uint8)
+     distortion = (distortion - distortion.min()) / (distortion.max() - distortion.min())
+     distortion = (distortion * 255).astype(np.uint8)
+     edge_colors[quad_valid] = cv2.cvtColor(cv2.applyColorMap(distortion, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3)
+     utils3d.io.write_ply(f'test/assets/{file}_quad_candidates_distortion.ply', **vis_edge_color(vertices, edges, edge_colors))
+
+     edge_colors = np.zeros([edges.shape[0], 3], dtype=np.uint8)
+     edge_colors[quad_valid] = cv2.cvtColor(cv2.applyColorMap((quads_weight * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3)
+     utils3d.io.write_ply(f'test/assets/{file}_quad_candidates_weights.ply', **vis_edge_color(vertices, edges, edge_colors))
+     utils3d.io.write_ply(f'test/assets/{file}_quad.ply', vertices=vertices, faces=quads)
+
+     quad_centers = vertices[quad_cands].mean(axis=1)
+     conns = np.stack([
+         np.arange(quad_cands.shape[0])[:, None].repeat(8, axis=1),
+         quad2adj,
+     ], axis=-1)[quad2adj != -1] # [C, 2]
+     conns, conns_idx = np.unique(np.sort(conns, axis=-1), axis=0, return_index=True) # [C, 2], [C]
+     smoothness = smoothness[quad2adj != -1][conns_idx] # [C]
+     conns_color = cv2.cvtColor(cv2.applyColorMap((smoothness * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3)
+     utils3d.io.write_ply(f'test/assets/{file}_quad_conn_smoothness.ply', **vis_edge_color(quad_centers, conns, conns_color))
+     conns_color = cv2.cvtColor(cv2.applyColorMap((conn_min_weight * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3)
+     utils3d.io.write_ply(f'test/assets/{file}_quad_conn_min.ply', **vis_edge_color(quad_centers, conns, conns_color))
+     conns_color = cv2.cvtColor(cv2.applyColorMap((conn_max_weight * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3)
+     utils3d.io.write_ply(f'test/assets/{file}_quad_conn_max.ply', **vis_edge_color(quad_centers, conns, conns_color))
+
+