diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..b75c5c5c8b54dd7eb396efd67261396b063f16b6
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,195 @@
+scripts/rendering/blender-*
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+.vscode/
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+*.sif
+blender-4.2.5-linux-x64*/
+*.zip
+*.png
+*.jpg
+*.log
+intermediate/
+__pycache__
+*.ply
+*.npy
+*.npz
+*.obj
+*.mtl
+*.json.gz
+dcgm/
+wandb/
+# *.json
+
+
+
+# Exception to add training list
+!lgm_leq20Kpts_simtopo25_train.json.gz
+!lgm_leq20Kpts_simtopo25_test.json.gz
+!lgm_leq20Kpts_train.json.gz
+!lgm_leq20Kpts_train_same_size_wrt_simtopo.json.gz
+*.mp4
+*.gif
+*.glb
+!lgm_leq20Kpts_plus_3Kremain_train.json.gz
+!lgm_leq20Kpts_test_cleaned.json.gz
+!lgm_leq20Kpts_train_cleaned.json.gz
+test_metrics.json
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ca3d21fd05e09ce8861f4882e7a37a625167d73
--- /dev/null
+++ b/app.py
@@ -0,0 +1,282 @@
+import argparse
+import gradio
+import torch
+import torch.backends.cudnn as cudnn
+from src.utils.vis import prob_to_mask
+from src.lari.model import LaRIModel, DinoSegModel
+from tools import load_model, process_image, post_process_output, get_masked_depth, save_to_glb, get_point_cloud, removebg_crop
+from huggingface_hub import hf_hub_download
+
+parser = argparse.ArgumentParser(description="Arguments for deploying a LaRI demo")
+
+parser.add_argument(
+ "--model_info_pm",
+ type=str,
+ default="LaRIModel(use_pretrained = 'moge_full', num_output_layer = 5, head_type = 'point')",
+ help="Network parameters to load the model",
+)
+
+parser.add_argument(
+ "--model_info_mask",
+ type=str,
+ default="DinoSegModel(use_pretrained = 'dinov2', dim_proj = 256, pretrained_path = '', num_output_layer = 4, output_type = 'ray_stop')",
+ help="Network parameters to load the model",
+)
+
+parser.add_argument(
+ "--ckpt_path_pm",
+ type=str,
+ default="lari_obj_16k_pointmap.pth",
+ help="Path to pre-trained weights",
+)
+
+parser.add_argument(
+ "--ckpt_path_mask",
+ type=str,
+ default="lari_obj_16k_seg.pth",
+ help="Path to pre-trained weights",
+)
+
+parser.add_argument(
+ "--resolution", type=int, default=512, help="Default model resolution"
+)
+args = parser.parse_args()
+
+
+
+def model_forward(pil_input, layered_id, rembg_checkbox):
+ """
+ Perform LaRI estimation by:
+ 1. image processing
+ 2. network forward
+ 3. save masked layered depth image
+ 4. save point cloud
+ """
+ if pil_input is None:
+        return (None, None, None, None, None, None, None, None, None)
+
+ if rembg_checkbox:
+ pil_input = removebg_crop(pil_input)
+
+ # Process the input image.
+ input_tensor, ori_img_tensor, crop_coords, original_size = process_image(
+        pil_input, resolution=args.resolution
+ )
+ input_tensor = input_tensor.to(device)
+
+ # Run inference.
+ with torch.no_grad():
+ # lari map
+ pred_dict = model_pm(input_tensor)
+ lari_map = -pred_dict["pts3d"].squeeze(
+ 0
+ ) # Expected output shape: (H_reso, W_reso, L, 3)
+ # mask
+ if model_mask:
+ pred_dict = model_mask(input_tensor)
+ assert "seg_prob" in pred_dict
+ valid_mask = prob_to_mask(pred_dict["seg_prob"].squeeze(0)) # H W L 1
+ else:
+ h, w, l, _ = lari_map.shape
+            valid_mask = lari_map.new_ones((h, w, l, 1))
+
+ # crop & resize the output to the original resolution.
+ if original_size[0] != args.resolution or original_size[1] != args.resolution:
+ lari_map = post_process_output(lari_map, crop_coords, original_size) # H W L 3
+ valid_mask = post_process_output(
+ valid_mask.float(), crop_coords, original_size
+ ).bool() # H W L 1
+
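+    # Keep only the layers predicted by both the point map and the mask.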
+ max_n_layer = min(valid_mask.shape[-2], lari_map.shape[-2])
+ valid_mask = valid_mask[:, :, :max_n_layer, :]
+ lari_map = lari_map[:, :, :max_n_layer, :]
+
+ curr_layer_id = min(max_n_layer - 1, layered_id - 1)
+
+ # masked depth list
+ depth_image = get_masked_depth(
+ lari_map=lari_map, valid_mask=valid_mask, layer_id=curr_layer_id
+ )
+ # point cloud
+ glb_path, ply_path = get_point_cloud(
+ lari_map, ori_img_tensor, valid_mask, first_layer_color="pseudo"
+ )
+
+ return (
+ depth_image,
+ glb_path,
+ lari_map,
+ valid_mask,
+ 0,
+ max_n_layer - 1,
+ glb_path,
+ ply_path,
+ pil_input,
+ )
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+cudnn.benchmark = True
+
+
+# Download the file
+model_path_pm = hf_hub_download(repo_id="ruili3/LaRI", filename=args.ckpt_path_pm, repo_type="model")
+model_path_mask = hf_hub_download(repo_id="ruili3/LaRI", filename=args.ckpt_path_mask, repo_type="model")
+
+
+# Load the model with pretrained weights.
+model_pm = load_model(args.model_info_pm, model_path_pm, device)
+model_mask = (
+ load_model(args.model_info_mask, model_path_mask, device)
+ if args.model_info_mask is not None
+ else None
+)
+
+
+def change_layer(slider_layer_id, lari_map, valid_mask, min_layer_id, max_layer_id):
+
+ if lari_map is None:
+ return
+
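+    # Convert the 1-based slider value to a 0-based layer index and clamp it to the valid range.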
+ slider_layer_id = slider_layer_id - 1
+ curr_layer_id = min(slider_layer_id, max_layer_id)
+ curr_layer_id = max(curr_layer_id, min_layer_id)
+
+ # masked depth list
+ depth_image = get_masked_depth(
+ lari_map=lari_map, valid_mask=valid_mask, layer_id=curr_layer_id
+ )
+
+ return depth_image
+
+
+def clear_everything():
+ return (
+ gradio.update(value=None),
+ gradio.update(value=None),
+ gradio.update(value=None),
+ gradio.update(value=None),
+ gradio.update(value=None),
+ gradio.update(value=None),
+ gradio.update(value=None),
+ )
+
+
+with gradio.Blocks(
+ css=""".gradio-container {margin: 0 !important; min-width: 100%};""",
+ title="LaRI Demo",
+) as demo:
+
+ gradio.Markdown(
+ "
LaRI: Layered Ray Intersections for Single-view 3D Geometric Reasoning
",
+ elem_id="title",
+ )
+
+ gradio.Markdown(
+ """
+    This is the official demo of Layered Ray Intersection (LaRI). For a quick start, click an image under 'Examples' and then click the 'Process' button.
+
+    You can also try your own images with the following steps:
+ - Load an image;
+ - Click the 'Process' button;
+    - Browse the layered depth maps (the z-channel of the resulting LaRI point map) by tuning 'Layer ID'.
+
+    Note that in '3D Point Cloud', different colors denote different intersection layers, i.e., layer 1, layer 2, layer 3, layer 4.
+ """
+ )
+
+ # , layer 5.
+ lari_map = gradio.State(None)
+ valid_mask = gradio.State(None)
+ min_layer_id = gradio.State(None)
+ max_layer_id = gradio.State(None)
+
+ with gradio.Column():
+ with gradio.Row(equal_height=True):
+ with gradio.Column(scale=1):
+ image_input = gradio.Image(
+ label="Upload an Image", type="pil", height=350
+ )
+ with gradio.Row():
+ rembg_checkbox = gradio.Checkbox(label="Remove background")
+ clear_button = gradio.Button("Clear")
+ submit_btn = gradio.Button("Process")
+ with gradio.Column(scale=1):
+ depth_output = gradio.Image(
+ label="LaRI Map at Z-axis (depth)",
+ type="pil",
+ interactive=False,
+ height=300,
+ )
+ slider_layer_id = gradio.Slider(
+ minimum=1,
+ maximum=4,
+ step=1,
+ value=1,
+ label="Layer ID",
+ interactive=True,
+ )
+
+ with gradio.Row(scale=1):
+ outmodel = gradio.Model3D(
+ label="3D Point Cloud (Color denotes different layers)",
+ interactive=False,
+ zoom_speed=0.5,
+ pan_speed=0.5,
+ height=450,
+ )
+
+ with gradio.Row():
+ ply_file_output = gradio.File(label="ply output", elem_classes="small-file")
+ glb_file_output = gradio.File(label="glb output", elem_classes="small-file")
+
+ submit_btn.click(
+ fn=model_forward,
+ inputs=[image_input, slider_layer_id, rembg_checkbox],
+ outputs=[
+ depth_output,
+ outmodel,
+ lari_map,
+ valid_mask,
+ min_layer_id,
+ max_layer_id,
+ glb_file_output,
+ ply_file_output,
+ image_input,
+ ],
+ )
+
+ clear_button.click(
+ fn=clear_everything,
+ outputs=[
+ lari_map,
+ valid_mask,
+ min_layer_id,
+ max_layer_id,
+ image_input,
+ depth_output,
+ outmodel,
+ ],
+ )
+
+ slider_layer_id.change(
+ fn=change_layer,
+ inputs=[slider_layer_id, lari_map, valid_mask, min_layer_id, max_layer_id],
+ outputs=depth_output,
+ )
+
+ gradio.Examples(examples=["assets/cole_hardware.png",
+ "assets/3m_tape.png",
+ "assets/horse.png",
+ "assets/rhino.png",
+ "assets/alphabet.png",
+ "assets/martin_wedge.png",
+ "assets/d_rose.png",
+ "assets/ace.png",
+ "assets/bifidus.png",
+ "assets/fem.png",
+ ],
+ inputs=image_input)
+
+
+demo.launch(share=False)
diff --git a/demo.py b/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..389fd4ddda46554037b6a6138124ec77c07ac2d5
--- /dev/null
+++ b/demo.py
@@ -0,0 +1,136 @@
+import argparse
+import os
+import torch
+import torch.backends.cudnn as cudnn
+from PIL import Image
+from src.utils.vis import prob_to_mask
+from huggingface_hub import hf_hub_download
+from tools import load_model, process_image, post_process_output, get_masked_depth, get_point_cloud, removebg_crop
+
+parser = argparse.ArgumentParser(description="Arguments for deploying a LaRI demo")
+parser.add_argument(
+ "--image_path",
+ type=str,
+ default="assets/cole_hardware.png",
+ help="input image name",
+)
+
+parser.add_argument(
+ "--output_path",
+ type=str,
+ default="./results",
+ help="path to save the image",
+)
+
+parser.add_argument(
+ "--model_info_pm",
+ type=str,
+ default="LaRIModel(use_pretrained = 'moge_full', num_output_layer = 5, head_type = 'point')",
+ help="Network parameters to load the model",
+)
+
+parser.add_argument(
+ "--model_info_mask",
+ type=str,
+ default="DinoSegModel(use_pretrained = 'dinov2', dim_proj = 256, pretrained_path = '', num_output_layer = 4, output_type = 'ray_stop')",
+ help="Network parameters to load the model",
+)
+
+parser.add_argument(
+ "--ckpt_path_pm",
+ type=str,
+ default="lari_obj_16k_pointmap.pth",
+ help="Path to pre-trained weights",
+)
+
+parser.add_argument(
+ "--ckpt_path_mask",
+ type=str,
+ default="lari_obj_16k_seg.pth",
+ help="Path to pre-trained weights",
+)
+
+parser.add_argument(
+ "--resolution", type=int, default=512, help="Default model resolution"
+)
+
+parser.add_argument(
+ "--is_remove_background", action="store_true", help="Automatically remove the background."
+)
+
+args = parser.parse_args()
+
+
+
+
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+cudnn.benchmark = True
+
+# === Load the model
+
+model_path_pm = hf_hub_download(repo_id="ruili3/LaRI", filename=args.ckpt_path_pm, repo_type="model")
+model_path_mask = hf_hub_download(repo_id="ruili3/LaRI", filename=args.ckpt_path_mask, repo_type="model")
+# Load the model with pretrained weights.
+model_pm = load_model(args.model_info_pm, model_path_pm, device)
+model_mask = (
+ load_model(args.model_info_mask, model_path_mask, device)
+ if args.model_info_mask is not None
+ else None
+)
+
+# === Image pre-processing
+pil_input = Image.open(args.image_path)
+if args.is_remove_background:
+ pil_input = removebg_crop(pil_input) # remove background
+input_tensor, ori_img_tensor, crop_coords, original_size = process_image(
+    pil_input, resolution=args.resolution)  # crop & resize to fit the model input size
+input_tensor = input_tensor.to(device)
+
+
+# === Run inference
+with torch.no_grad():
+ # lari map
+ pred_dict = model_pm(input_tensor)
+ lari_map = -pred_dict["pts3d"].squeeze(
+ 0
+ )
+ # mask
+ if model_mask:
+ pred_dict = model_mask(input_tensor)
+ assert "seg_prob" in pred_dict
+ valid_mask = prob_to_mask(pred_dict["seg_prob"].squeeze(0)) # H W L 1
+ else:
+ h, w, l, _ = lari_map.shape
+        valid_mask = lari_map.new_ones((h, w, l, 1))
+
+# === crop & resize back to the original resolution
+if original_size[0] != args.resolution or original_size[1] != args.resolution:
+ lari_map = post_process_output(lari_map, crop_coords, original_size) # H W L 3
+ valid_mask = post_process_output(
+ valid_mask.float(), crop_coords, original_size
+ ).bool() # H W L 1
+
+max_n_layer = min(valid_mask.shape[-2], lari_map.shape[-2])
+valid_mask = valid_mask[:, :, :max_n_layer, :]
+lari_map = lari_map[:, :, :max_n_layer, :]
+
+
+# === save output
+os.makedirs(args.output_path, exist_ok=True)
+
+for layer_id in range(max_n_layer):
+ depth_pil = get_masked_depth(
+ lari_map=lari_map, valid_mask=valid_mask, layer_id=layer_id
+ )
+ depth_pil.save(os.path.join(args.output_path, f"layered_depth_{layer_id}.jpg"))
+
+
+# point cloud
+glb_path, ply_path = get_point_cloud(
+ lari_map, ori_img_tensor, valid_mask, first_layer_color="pseudo",
+ target_folder=args.output_path
+)
+
+print("All results saved to `{}`.".format(args.output_path))
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..26d85c62b9f1d7294e4c0738272cedc83f2cc6ce
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,19 @@
+gradio==5.23.3
+huggingface_hub==0.30.1
+imageio==2.37.0
+matplotlib==3.10.1
+moderngl==5.12.0
+omegaconf==2.3.0
+opencv_python==4.11.0.86
+opencv_python_headless==4.11.0.86
+Pillow==11.1.0
+piqp==0.5.0
+plyfile==1.1
+rembg==2.0.65
+scipy==1.15.2
+torchvision==0.21.0
+trimesh==4.6.4
+xformers==0.0.29.post3
+numpy==1.26.4
+torch==2.6.0
\ No newline at end of file
diff --git a/src/lari/model/__init__.py b/src/lari/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d64ed1ea63d542c9852a5ef5cd933a8dd018d94
--- /dev/null
+++ b/src/lari/model/__init__.py
@@ -0,0 +1,2 @@
+from .lari_model import LaRIModel
+from .dinoseg_model import DinoSegModel
\ No newline at end of file
diff --git a/src/lari/model/blocks.py b/src/lari/model/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..51d061093ef4fb1f25532ae3b0bf3cbcbcf709f1
--- /dev/null
+++ b/src/lari/model/blocks.py
@@ -0,0 +1,209 @@
+from typing import *
+import torch.nn as nn
+
+class ResidualConvBlock(nn.Module):
+ def __init__(self, in_channels: int, out_channels: int = None, hidden_channels: int = None, padding_mode: str = 'replicate', activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu', norm: Literal['group_norm', 'layer_norm'] = 'group_norm'):
+ super(ResidualConvBlock, self).__init__()
+ if out_channels is None:
+ out_channels = in_channels
+ if hidden_channels is None:
+ hidden_channels = in_channels
+
+ if activation =='relu':
+ activation_cls = lambda: nn.ReLU(inplace=True)
+ elif activation == 'leaky_relu':
+ activation_cls = lambda: nn.LeakyReLU(negative_slope=0.2, inplace=True)
+ elif activation =='silu':
+ activation_cls = lambda: nn.SiLU(inplace=True)
+ elif activation == 'elu':
+ activation_cls = lambda: nn.ELU(inplace=True)
+ else:
+ raise ValueError(f'Unsupported activation function: {activation}')
+
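+        # Pre-activation residual block (norm -> activation -> conv, twice); 'layer_norm' is emulated with a single-group GroupNorm.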
+ self.layers = nn.Sequential(
+ nn.GroupNorm(1, in_channels),
+ activation_cls(),
+ nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1, padding_mode=padding_mode),
+ nn.GroupNorm(hidden_channels // 32 if norm == 'group_norm' else 1, hidden_channels),
+ activation_cls(),
+ nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode)
+ )
+
+ self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) if in_channels != out_channels else nn.Identity()
+
+ def forward(self, x):
+ skip = self.skip_connection(x)
+ x = self.layers(x)
+ x = x + skip
+ return x
+
+
+
+
+def make_upsampler(in_channels: int, out_channels: int):
+ upsampler = nn.Sequential(
+ nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
+ )
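+    # Copy the top-left kernel value to all 2x2 positions so the transposed conv starts out as nearest-neighbor-style upsampling.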
+ upsampler[0].weight.data[:] = upsampler[0].weight.data[:, :, :1, :1]
+ return upsampler
+
+def make_output_block(dim_in: int, dim_out: int, dim_times_res_block_hidden: int, last_res_blocks: int, last_conv_channels: int, last_conv_size: int, res_block_norm: Literal['group_norm', 'layer_norm']):
+ return nn.Sequential(
+ nn.Conv2d(dim_in, last_conv_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
+ *(ResidualConvBlock(last_conv_channels, last_conv_channels, dim_times_res_block_hidden * last_conv_channels, activation='relu', norm=res_block_norm) for _ in range(last_res_blocks)),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(last_conv_channels, dim_out, kernel_size=last_conv_size, stride=1, padding=last_conv_size // 2, padding_mode='replicate'),
+ )
+
+
+
+# ---- the following are from Depth Anything ----
+import torch.nn as nn
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+ scratch = nn.Module()
+
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape
+
+ if expand:
+ out_shape1 = out_shape
+ out_shape2 = out_shape * 2
+ out_shape3 = out_shape * 4
+ if len(in_shape) >= 4:
+ out_shape4 = out_shape * 8
+
+ scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
+ scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
+ scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
+ if len(in_shape) >= 4:
+ scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
+
+ return scratch
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features, activation, bn):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+
+ self.groups=1
+
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
+
+ if self.bn == True:
+ self.bn1 = nn.BatchNorm2d(features)
+ self.bn2 = nn.BatchNorm2d(features)
+
+ self.activation = activation
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.bn == True:
+ out = self.bn1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.bn == True:
+ out = self.bn2(out)
+
+ if self.groups > 1:
+ out = self.conv_merge(out)
+
+ return self.skip_add.add(out, x)
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(
+ self,
+ features,
+ activation,
+ deconv=False,
+ bn=False,
+ expand=False,
+ align_corners=True,
+ size=None
+ ):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+
+ self.groups=1
+
+ self.expand = expand
+ out_features = features
+ if self.expand == True:
+ out_features = features // 2
+
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ self.size=size
+
+ def forward(self, *xs, size=None):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+
+ output = self.resConfUnit2(output)
+
+ if (size is None) and (self.size is None):
+ modifier = {"scale_factor": 2}
+ elif size is None:
+ modifier = {"size": self.size}
+ else:
+ modifier = {"size": size}
+
+ output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
+
+ output = self.out_conv(output)
+
+ return output
diff --git a/src/lari/model/dinoseg_model.py b/src/lari/model/dinoseg_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e6e218ac769a564ee297e1a466e030273e33374
--- /dev/null
+++ b/src/lari/model/dinoseg_model.py
@@ -0,0 +1,153 @@
+from typing import *
+from numbers import Number
+from functools import partial
+from pathlib import Path
+import importlib
+import warnings
+import json
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils
+import torch.utils.checkpoint
+import torch.version
+from huggingface_hub import hf_hub_download
+
+
+from src.lari.model.utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing
+from src.lari.model.dpt_seg_head import DPTSegHead
+
+
+
+class DinoSegModel(nn.Module):
+
+ def __init__(self,
+ encoder: str = 'dinov2_vitl14',
+ intermediate_layers: Union[int, List[int]] = 4,
+ dim_proj: int = 512,
+ use_pretrained: Literal["dinov2", "moge_full", "moge_backbone", None] = None,
+ pretrained_path: str = None,
+                 num_output_layer: int = None,
+ output_type: str = "ray_stop", # "seg_sep"
+ **deprecated_kwargs
+ ):
+ super(DinoSegModel, self).__init__()
+ if deprecated_kwargs:
+ warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}")
+
+ self.encoder = encoder
+ self.intermediate_layers = intermediate_layers
+ self.use_pretrained = use_pretrained
+ self.pretrained_path = pretrained_path
+ self.num_output_layer = num_output_layer
+ self.output_type = output_type
+ assert self.output_type in ["seg_sep", "ray_stop"]
+
+ hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), encoder)
+
+ self.backbone = hub_loader(pretrained=True if self.use_pretrained == "dinov2" else False)
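+        # Infer the backbone feature dimension from the qkv projection of the first attention block.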
+ dim_feature = self.backbone.blocks[0].attn.qkv.in_features
+
+
+
+
+ self.head = DPTSegHead(in_channels=dim_feature,
+ features=dim_proj,
+ use_bn=True,
+ out_channels=[256, 512, 1024, 1024],
+ use_clstoken=False,
+ num_classes = num_output_layer,
+ output_type = self.output_type
+ )
+
+
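+        # Use PyTorch's native scaled_dot_product_attention inside the DINOv2 attention blocks when available.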
+ if torch.__version__ >= '2.0':
+ self.enable_pytorch_native_sdpa()
+
+ self._load_pretrained()
+
+
+ def _load_pretrained(self):
+ '''
+        Load pretrained weights from a MoGe checkpoint (currently a no-op).
+ '''
+ return
+
+
+
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path, IO[bytes]], model_kwargs: Optional[Dict[str, Any]] = None, **hf_kwargs) -> 'DinoSegModel':
+ """
+ Load a model from a checkpoint file.
+
+ ### Parameters:
+ - `pretrained_model_name_or_path`: path to the checkpoint file or repo id.
+ - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint.
+ - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path.
+
+ ### Returns:
+    - A new instance of `DinoSegModel` with the parameters loaded from the checkpoint.
+ """
+ if Path(pretrained_model_name_or_path).exists():
+ checkpoint = torch.load(pretrained_model_name_or_path, map_location='cpu', weights_only=True)
+ else:
+ cached_checkpoint_path = hf_hub_download(
+ repo_id=pretrained_model_name_or_path,
+ repo_type="model",
+ filename="model.pt",
+ **hf_kwargs
+ )
+ checkpoint = torch.load(cached_checkpoint_path, map_location='cpu', weights_only=True)
+ model_config = checkpoint['model_config']
+ if model_kwargs is not None:
+ model_config.update(model_kwargs)
+ model = cls(**model_config)
+ model.load_state_dict(checkpoint['model'])
+ return model
+
+ @staticmethod
+ def cache_pretrained_backbone(encoder: str, pretrained: bool):
+ _ = torch.hub.load('facebookresearch/dinov2', encoder, pretrained=pretrained)
+
+ def load_pretrained_backbone(self):
+ "Load the backbone with pretrained dinov2 weights from torch hub"
+ state_dict = torch.hub.load('facebookresearch/dinov2', self.encoder, pretrained=True).state_dict()
+ self.backbone.load_state_dict(state_dict)
+
+ def enable_backbone_gradient_checkpointing(self):
+ for i in range(len(self.backbone.blocks)):
+ self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing(self.backbone.blocks[i])
+
+ def enable_pytorch_native_sdpa(self):
+ for i in range(len(self.backbone.blocks)):
+ self.backbone.blocks[i].attn = wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn)
+
+
+
+ def forward(self, image: torch.Tensor, mixed_precision: bool = False) -> Dict[str, torch.Tensor]:
+ raw_img_h, raw_img_w = image.shape[-2:]
+ patch_h, patch_w = raw_img_h // 14, raw_img_w // 14
+ # Apply image transformation for DINOv2
+ image_14 = F.interpolate(image, (patch_h * 14, patch_w * 14), mode="bilinear", align_corners=False, antialias=True)
+
+ # Get intermediate layers from the backbone
+ with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=mixed_precision):
+ features = self.backbone.get_intermediate_layers(image_14, self.intermediate_layers, return_class_token=True)
+
+ # Predict points and mask (mask scores)
+ mask = self.head(features, patch_h, patch_w)
+
+ # b c h w
+ mask = F.interpolate(mask, (raw_img_h, raw_img_w), mode="bilinear", align_corners=False)
+
+ out_dict = {}
+
+ if self.output_type == "seg_sep":
+ # mask = torch.nn.functional.sigmoid(mask) # for binary segmentation
+ out_dict["mask"] = mask.permute(0, 2, 3, 1).unsqueeze(-1) # B H W L 1
+ elif self.output_type == "ray_stop":
+ out_dict["seg_prob"] = mask # B L+1 H W
+
+ return out_dict
\ No newline at end of file
diff --git a/src/lari/model/dinov2/__init__.py b/src/lari/model/dinov2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae847e46898077fe3d8701b8a181d7b4e3d41cd9
--- /dev/null
+++ b/src/lari/model/dinov2/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+__version__ = "0.0.1"
diff --git a/src/lari/model/dinov2/hub/__init__.py b/src/lari/model/dinov2/hub/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9
--- /dev/null
+++ b/src/lari/model/dinov2/hub/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
diff --git a/src/lari/model/dinov2/hub/backbones.py b/src/lari/model/dinov2/hub/backbones.py
new file mode 100644
index 0000000000000000000000000000000000000000..53fe83719d5107eb77a8f25ef1814c3d73446002
--- /dev/null
+++ b/src/lari/model/dinov2/hub/backbones.py
@@ -0,0 +1,156 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+from typing import Union
+
+import torch
+
+from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
+
+
+class Weights(Enum):
+ LVD142M = "LVD142M"
+
+
+def _make_dinov2_model(
+ *,
+ arch_name: str = "vit_large",
+ img_size: int = 518,
+ patch_size: int = 14,
+ init_values: float = 1.0,
+ ffn_layer: str = "mlp",
+ block_chunks: int = 0,
+ num_register_tokens: int = 0,
+ interpolate_antialias: bool = False,
+ interpolate_offset: float = 0.1,
+ pretrained: bool = True,
+ weights: Union[Weights, str] = Weights.LVD142M,
+ **kwargs,
+):
+ from ..models import vision_transformer as vits
+
+ if isinstance(weights, str):
+ try:
+ weights = Weights[weights]
+ except KeyError:
+ raise AssertionError(f"Unsupported weights: {weights}")
+
+ model_base_name = _make_dinov2_model_name(arch_name, patch_size)
+ vit_kwargs = dict(
+ img_size=img_size,
+ patch_size=patch_size,
+ init_values=init_values,
+ ffn_layer=ffn_layer,
+ block_chunks=block_chunks,
+ num_register_tokens=num_register_tokens,
+ interpolate_antialias=interpolate_antialias,
+ interpolate_offset=interpolate_offset,
+ )
+ vit_kwargs.update(**kwargs)
+ model = vits.__dict__[arch_name](**vit_kwargs)
+
+ if pretrained:
+ model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
+ url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
+ state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+ model.load_state_dict(state_dict, strict=True)
+
+ return model
+
+
+def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
+
+
+def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_giant2",
+ ffn_layer="swiglufused",
+ weights=weights,
+ pretrained=pretrained,
+ **kwargs,
+ )
+
+
+def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_small",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_base",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_large",
+ pretrained=pretrained,
+ weights=weights,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
+
+
+def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+ """
+ DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+ """
+ return _make_dinov2_model(
+ arch_name="vit_giant2",
+ ffn_layer="swiglufused",
+ weights=weights,
+ pretrained=pretrained,
+ num_register_tokens=4,
+ interpolate_antialias=True,
+ interpolate_offset=0.0,
+ **kwargs,
+ )
diff --git a/src/lari/model/dinov2/hub/utils.py b/src/lari/model/dinov2/hub/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c6641404093652d5a2f19b4cf283d976ec39e64
--- /dev/null
+++ b/src/lari/model/dinov2/hub/utils.py
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import itertools
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
+
+
+def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
+ compact_arch_name = arch_name.replace("_", "")[:4]
+ registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
+ return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
+
+
+class CenterPadding(nn.Module):
+ def __init__(self, multiple):
+ super().__init__()
+ self.multiple = multiple
+
+ def _get_pad(self, size):
+ new_size = math.ceil(size / self.multiple) * self.multiple
+ pad_size = new_size - size
+ pad_size_left = pad_size // 2
+ pad_size_right = pad_size - pad_size_left
+ return pad_size_left, pad_size_right
+
+ @torch.inference_mode()
+ def forward(self, x):
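+        # Build the (left, right) pad pairs expected by F.pad, ordered from the last dimension backwards;
+        # every dim after batch and channel is padded up to a multiple of self.multiple.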
+ pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
+ output = F.pad(x, pads)
+ return output
diff --git a/src/lari/model/dinov2/layers/__init__.py b/src/lari/model/dinov2/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..05a0b61868e43abb821ca05a813bab2b8b43629e
--- /dev/null
+++ b/src/lari/model/dinov2/layers/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from .dino_head import DINOHead
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+from .block import NestedTensorBlock
+from .attention import MemEffAttention
diff --git a/src/lari/model/dinov2/layers/attention.py b/src/lari/model/dinov2/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fed573116d5c837be46a7525d8acf77422c2400
--- /dev/null
+++ b/src/lari/model/dinov2/layers/attention.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+import os
+import warnings
+
+from torch import Tensor
+from torch import nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import memory_efficient_attention, unbind
+
+ XFORMERS_AVAILABLE = True
+ # warnings.warn("xFormers is available (Attention)")
+ else:
+ # warnings.warn("xFormers is disabled (Attention)")
+ raise ImportError
+except ImportError:
+ XFORMERS_AVAILABLE = False
+ # warnings.warn("xFormers is not available (Attention)")
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
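+        # qkv has shape (3, B, num_heads, N, head_dim)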
+
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ attn = q @ k.transpose(-2, -1)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ if attn_bias is not None:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
diff --git a/src/lari/model/dinov2/layers/block.py b/src/lari/model/dinov2/layers/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd5b8a7bb8527b74186af7c1e060e37bdb52c73d
--- /dev/null
+++ b/src/lari/model/dinov2/layers/block.py
@@ -0,0 +1,259 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+import os
+from typing import Callable, List, Any, Tuple, Dict
+import warnings
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention, MemEffAttention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+
+logger = logging.getLogger("dinov2")
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import fmha, scaled_index_add, index_select_cat
+
+ XFORMERS_AVAILABLE = True
+ # warnings.warn("xFormers is available (Block)")
+ else:
+ # warnings.warn("xFormers is disabled (Block)")
+ raise ImportError
+except ImportError:
+ XFORMERS_AVAILABLE = False
+ # warnings.warn("xFormers is not available (Block)")
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor) -> Tensor:
+ def attn_residual_func(x: Tensor) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x)))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0,
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list)
+ elif isinstance(x_or_x_list, list):
+ if not XFORMERS_AVAILABLE:
+ raise AssertionError("xFormers is required for using nested tensors")
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
diff --git a/src/lari/model/dinov2/layers/dino_head.py b/src/lari/model/dinov2/layers/dino_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ace8ffd6297a1dd480b19db407b662a6ea0f565
--- /dev/null
+++ b/src/lari/model/dinov2/layers/dino_head.py
@@ -0,0 +1,58 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.nn.init import trunc_normal_
+from torch.nn.utils import weight_norm
+
+
+class DINOHead(nn.Module):
+ def __init__(
+ self,
+ in_dim,
+ out_dim,
+ use_bn=False,
+ nlayers=3,
+ hidden_dim=2048,
+ bottleneck_dim=256,
+ mlp_bias=True,
+ ):
+ super().__init__()
+ nlayers = max(nlayers, 1)
+ self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
+ self.apply(self._init_weights)
+ self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+ self.last_layer.weight_g.data.fill_(1)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ x = self.mlp(x)
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
+ x = self.last_layer(x)
+ return x
+
+
+def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
+ if nlayers == 1:
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
+ else:
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ for _ in range(nlayers - 2):
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
+ return nn.Sequential(*layers)
diff --git a/src/lari/model/dinov2/layers/drop_path.py b/src/lari/model/dinov2/layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5
--- /dev/null
+++ b/src/lari/model/dinov2/layers/drop_path.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
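+    # Zero out entire samples with probability drop_prob and rescale the survivors by 1 / keep_prob so the expected value is unchanged.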
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
diff --git a/src/lari/model/dinov2/layers/layer_scale.py b/src/lari/model/dinov2/layers/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..51df0d7ce61f2b41fa9e6369f52391dd7fe7d386
--- /dev/null
+++ b/src/lari/model/dinov2/layers/layer_scale.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
diff --git a/src/lari/model/dinov2/layers/mlp.py b/src/lari/model/dinov2/layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e
--- /dev/null
+++ b/src/lari/model/dinov2/layers/mlp.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
diff --git a/src/lari/model/dinov2/layers/patch_embed.py b/src/lari/model/dinov2/layers/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b7c0804784a42cf80c0297d110dcc68cc85b339
--- /dev/null
+++ b/src/lari/model/dinov2/layers/patch_embed.py
@@ -0,0 +1,88 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+import torch.nn as nn
+
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
diff --git a/src/lari/model/dinov2/layers/swiglu_ffn.py b/src/lari/model/dinov2/layers/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ce211515774d42e04c8b51003bae53b88f14b35
--- /dev/null
+++ b/src/lari/model/dinov2/layers/swiglu_ffn.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import os
+from typing import Callable, Optional
+import warnings
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Optional[Callable[..., nn.Module]] = None,  # unused; kept for signature compatibility
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+ if XFORMERS_ENABLED:
+ from xformers.ops import SwiGLU
+
+ XFORMERS_AVAILABLE = True
+ # warnings.warn("xFormers is available (SwiGLU)")
+ else:
+ # warnings.warn("xFormers is disabled (SwiGLU)")
+ raise ImportError
+except ImportError:
+ SwiGLU = SwiGLUFFN
+ XFORMERS_AVAILABLE = False
+
+ # warnings.warn("xFormers is not available (SwiGLU)")
+
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Optional[Callable[..., nn.Module]] = None,  # unused; kept for signature compatibility
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8  # 2/3 scaling for the gated FFN, rounded up to a multiple of 8
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
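+
+# Worked example (illustrative numbers, not taken from this file): for
+# in_features=1024 and a nominal hidden_features of 4096, the fused variant
+# uses int(4096 * 2 / 3) = 2730, rounded up to the next multiple of 8, i.e.
+# 2736 hidden units, so the gated SwiGLU keeps roughly the same parameter
+# count as a plain 4x MLP.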
diff --git a/src/lari/model/dinov2/models/__init__.py b/src/lari/model/dinov2/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fdff20badbd5244bf79f16bf18dd2cb73982265
--- /dev/null
+++ b/src/lari/model/dinov2/models/__init__.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+
+from . import vision_transformer as vits
+
+
+logger = logging.getLogger("dinov2")
+
+
+def build_model(args, only_teacher=False, img_size=224):
+ args.arch = args.arch.removesuffix("_memeff")
+ if "vit" in args.arch:
+ vit_kwargs = dict(
+ img_size=img_size,
+ patch_size=args.patch_size,
+ init_values=args.layerscale,
+ ffn_layer=args.ffn_layer,
+ block_chunks=args.block_chunks,
+ qkv_bias=args.qkv_bias,
+ proj_bias=args.proj_bias,
+ ffn_bias=args.ffn_bias,
+ num_register_tokens=args.num_register_tokens,
+ interpolate_offset=args.interpolate_offset,
+ interpolate_antialias=args.interpolate_antialias,
+ )
+ teacher = vits.__dict__[args.arch](**vit_kwargs)
+ if only_teacher:
+ return teacher, teacher.embed_dim
+ student = vits.__dict__[args.arch](
+ **vit_kwargs,
+ drop_path_rate=args.drop_path_rate,
+ drop_path_uniform=args.drop_path_uniform,
+ )
+ embed_dim = student.embed_dim
+ return student, teacher, embed_dim
+
+
+def build_model_from_cfg(cfg, only_teacher=False):
+ return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
diff --git a/src/lari/model/dinov2/models/vision_transformer.py b/src/lari/model/dinov2/models/vision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1007ba57ddb35109c91716f1f5bf203db346e7be
--- /dev/null
+++ b/src/lari/model/dinov2/models/vision_transformer.py
@@ -0,0 +1,396 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from torch.nn.init import trunc_normal_
+
+from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+
+logger = logging.getLogger("dinov2")
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x):
+ for b in self:
+ x = b(x)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=None, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=Block,
+ ffn_layer="mlp",
+ block_chunks=1,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1,
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+ interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.num_register_tokens = num_register_tokens
+ self.interpolate_antialias = interpolate_antialias
+ self.interpolate_offset = interpolate_offset
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ assert num_register_tokens >= 0
+ self.register_tokens = (
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
+ )
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ if self.register_tokens is not None:
+ nn.init.normal_(self.register_tokens, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
+ assert N == M * M
+ kwargs = {}
+ if self.interpolate_offset:
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
+ sx = float(w0 + self.interpolate_offset) / M
+ sy = float(h0 + self.interpolate_offset) / M
+ kwargs["scale_factor"] = (sx, sy)
+ else:
+ # Simply specify an output size instead of a scale factor
+ kwargs["size"] = (w0, h0)
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
+ mode="bicubic",
+ antialias=self.interpolate_antialias,
+ **kwargs,
+ )
+ assert (w0, h0) == patch_pos_embed.shape[-2:]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ if self.register_tokens is not None:
+ x = torch.cat(
+ (
+ x[:, :1],
+ self.register_tokens.expand(x.shape[0], -1, -1),
+ x[:, 1:],
+ ),
+ dim=1,
+ )
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+ for blk in self.blocks:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x_norm = self.norm(x)
+ return {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # skip the leading nn.Identity() placeholders
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=False, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ if is_training:
+ return ret
+ else:
+ return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
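+
+# Minimal sketch of building a backbone and extracting features (illustrative
+# only; the exact sizes are assumptions, not values mandated by this file):
+#   model = vit_large(patch_size=14, img_size=518, block_chunks=0)
+#   feats = model.get_intermediate_layers(
+#       torch.randn(1, 3, 224, 224), n=4, return_class_token=True
+#   )  # tuple of 4 (patch_tokens, class_token) pairs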
diff --git a/src/lari/model/dinov2/utils/__init__.py b/src/lari/model/dinov2/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9
--- /dev/null
+++ b/src/lari/model/dinov2/utils/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
diff --git a/src/lari/model/dinov2/utils/cluster.py b/src/lari/model/dinov2/utils/cluster.py
new file mode 100644
index 0000000000000000000000000000000000000000..3df87dc3e1eb4f0f8a280dc3137cfef031886314
--- /dev/null
+++ b/src/lari/model/dinov2/utils/cluster.py
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from enum import Enum
+import os
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+
+class ClusterType(Enum):
+ AWS = "aws"
+ FAIR = "fair"
+ RSC = "rsc"
+
+
+def _guess_cluster_type() -> ClusterType:
+ uname = os.uname()
+ if uname.sysname == "Linux":
+ if uname.release.endswith("-aws"):
+ # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
+ return ClusterType.AWS
+ elif uname.nodename.startswith("rsc"):
+ # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
+ return ClusterType.RSC
+
+ return ClusterType.FAIR
+
+
+def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
+ if cluster_type is None:
+ return _guess_cluster_type()
+
+ return cluster_type
+
+
+def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
+ cluster_type = get_cluster_type(cluster_type)
+ if cluster_type is None:
+ return None
+
+ CHECKPOINT_DIRNAMES = {
+ ClusterType.AWS: "checkpoints",
+ ClusterType.FAIR: "checkpoint",
+ ClusterType.RSC: "checkpoint/dino",
+ }
+ return Path("/") / CHECKPOINT_DIRNAMES[cluster_type]
+
+
+def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
+ checkpoint_path = get_checkpoint_path(cluster_type)
+ if checkpoint_path is None:
+ return None
+
+ username = os.environ.get("USER")
+ assert username is not None
+ return checkpoint_path / username
+
+
+def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
+ cluster_type = get_cluster_type(cluster_type)
+ if cluster_type is None:
+ return None
+
+ SLURM_PARTITIONS = {
+ ClusterType.AWS: "learnlab",
+ ClusterType.FAIR: "learnlab",
+ ClusterType.RSC: "learn",
+ }
+ return SLURM_PARTITIONS[cluster_type]
+
+
+def get_slurm_executor_parameters(
+ nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
+) -> Dict[str, Any]:
+ # create default parameters
+ params = {
+ "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
+ "gpus_per_node": num_gpus_per_node,
+ "tasks_per_node": num_gpus_per_node, # one task per GPU
+ "cpus_per_task": 10,
+ "nodes": nodes,
+ "slurm_partition": get_slurm_partition(cluster_type),
+ }
+ # apply cluster-specific adjustments
+ cluster_type = get_cluster_type(cluster_type)
+ if cluster_type == ClusterType.AWS:
+ params["cpus_per_task"] = 12
+ del params["mem_gb"]
+ elif cluster_type == ClusterType.RSC:
+ params["cpus_per_task"] = 12
+ # set additional parameters / apply overrides
+ params.update(kwargs)
+ return params
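+
+# Example (illustrative only):
+#   params = get_slurm_executor_parameters(nodes=2, num_gpus_per_node=8, timeout_min=60)
+#   # -> dict with gpus_per_node=8, tasks_per_node=8, nodes=2, the detected
+#   #    partition, and the extra "timeout_min" override merged in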
diff --git a/src/lari/model/dinov2/utils/config.py b/src/lari/model/dinov2/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9de578787bbcb376f8bd5a782206d0eb7ec1f52
--- /dev/null
+++ b/src/lari/model/dinov2/utils/config.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import math
+import logging
+import os
+
+from omegaconf import OmegaConf
+
+import dinov2.distributed as distributed
+from dinov2.logging import setup_logging
+from dinov2.utils import utils
+from dinov2.configs import dinov2_default_config
+
+
+logger = logging.getLogger("dinov2")
+
+
+def apply_scaling_rules_to_cfg(cfg): # to fix
+ if cfg.optim.scaling_rule == "sqrt_wrt_1024":
+ base_lr = cfg.optim.base_lr
+ cfg.optim.lr = base_lr
+ cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0)
+ logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
+ else:
+ raise NotImplementedError
+ return cfg
+
+
+def write_config(cfg, output_dir, name="config.yaml"):
+ logger.info(OmegaConf.to_yaml(cfg))
+ saved_cfg_path = os.path.join(output_dir, name)
+ with open(saved_cfg_path, "w") as f:
+ OmegaConf.save(config=cfg, f=f)
+ return saved_cfg_path
+
+
+def get_cfg_from_args(args):
+ args.output_dir = os.path.abspath(args.output_dir)
+ args.opts += [f"train.output_dir={args.output_dir}"]
+ default_cfg = OmegaConf.create(dinov2_default_config)
+ cfg = OmegaConf.load(args.config_file)
+ cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
+ return cfg
+
+
+def default_setup(args):
+ distributed.enable(overwrite=True)
+ seed = getattr(args, "seed", 0)
+ rank = distributed.get_global_rank()
+
+ global logger
+ setup_logging(output=args.output_dir, level=logging.INFO)
+ logger = logging.getLogger("dinov2")
+
+ utils.fix_random_seeds(seed + rank)
+ logger.info("git:\n {}\n".format(utils.get_sha()))
+ logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
+
+
+def setup(args):
+ """
+ Create configs and perform basic setups.
+ """
+ cfg = get_cfg_from_args(args)
+ os.makedirs(args.output_dir, exist_ok=True)
+ default_setup(args)
+ apply_scaling_rules_to_cfg(cfg)
+ write_config(cfg, args.output_dir)
+ return cfg
diff --git a/src/lari/model/dinov2/utils/dtype.py b/src/lari/model/dinov2/utils/dtype.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f4cd74d99faa2731dbe9f8d3a13d71b3f8e3a8
--- /dev/null
+++ b/src/lari/model/dinov2/utils/dtype.py
@@ -0,0 +1,37 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+
+from typing import Dict, Union
+
+import numpy as np
+import torch
+
+
+TypeSpec = Union[str, np.dtype, torch.dtype]
+
+
+_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
+ np.dtype("bool"): torch.bool,
+ np.dtype("uint8"): torch.uint8,
+ np.dtype("int8"): torch.int8,
+ np.dtype("int16"): torch.int16,
+ np.dtype("int32"): torch.int32,
+ np.dtype("int64"): torch.int64,
+ np.dtype("float16"): torch.float16,
+ np.dtype("float32"): torch.float32,
+ np.dtype("float64"): torch.float64,
+ np.dtype("complex64"): torch.complex64,
+ np.dtype("complex128"): torch.complex128,
+}
+
+
+def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
+ if isinstance(dtype, torch.dtype):
+ return dtype
+ if isinstance(dtype, str):
+ dtype = np.dtype(dtype)
+ assert isinstance(dtype, np.dtype), f"Expected an instance of numpy dtype, got {type(dtype)}"
+ return _NUMPY_TO_TORCH_DTYPE[dtype]
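+
+# Example (illustrative only): each of these resolves to torch.float32:
+#   as_torch_dtype("float32"); as_torch_dtype(np.dtype("float32")); as_torch_dtype(torch.float32)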
diff --git a/src/lari/model/dinov2/utils/param_groups.py b/src/lari/model/dinov2/utils/param_groups.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a5d2ff627cddadc222e5f836864ee39c865208f
--- /dev/null
+++ b/src/lari/model/dinov2/utils/param_groups.py
@@ -0,0 +1,103 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+from collections import defaultdict
+import logging
+
+
+logger = logging.getLogger("dinov2")
+
+
+def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False):
+ """
+ Calculate lr decay rate for different ViT blocks.
+ Args:
+ name (string): parameter name.
+ lr_decay_rate (float): base lr decay rate.
+ num_layers (int): number of ViT blocks.
+ Returns:
+ lr decay rate for the given parameter.
+ """
+ layer_id = num_layers + 1
+ if name.startswith("backbone") or force_is_backbone:
+ if (
+ ".pos_embed" in name
+ or ".patch_embed" in name
+ or ".mask_token" in name
+ or ".cls_token" in name
+ or ".register_tokens" in name
+ ):
+ layer_id = 0
+ elif force_is_backbone and (
+ "pos_embed" in name
+ or "patch_embed" in name
+ or "mask_token" in name
+ or "cls_token" in name
+ or "register_tokens" in name
+ ):
+ layer_id = 0
+ elif ".blocks." in name and ".residual." not in name:
+ layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
+ elif chunked_blocks and "blocks." in name and "residual." not in name:
+ layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1
+ elif "blocks." in name and "residual." not in name:
+ layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1
+
+ return lr_decay_rate ** (num_layers + 1 - layer_id)
+
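+# Illustrative numbers (not from this file): with lr_decay_rate=0.9 and
+# num_layers=12, an embedding parameter (layer_id 0) gets 0.9 ** 13 ~= 0.25,
+# block 0 gets 0.9 ** 12 ~= 0.28, and the last block gets 0.9 ** 1 = 0.9.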
+
+def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
+ chunked_blocks = False
+ if hasattr(model, "n_blocks"):
+ logger.info("chunked fsdp")
+ n_blocks = model.n_blocks
+ chunked_blocks = model.chunked_blocks
+ elif hasattr(model, "blocks"):
+ logger.info("first code branch")
+ n_blocks = len(model.blocks)
+ elif hasattr(model, "backbone"):
+ logger.info("second code branch")
+ n_blocks = len(model.backbone.blocks)
+ else:
+ logger.info("else code branch")
+ n_blocks = 0
+ all_param_groups = []
+
+ for name, param in model.named_parameters():
+ name = name.replace("_fsdp_wrapped_module.", "")
+ if not param.requires_grad:
+ continue
+ decay_rate = get_vit_lr_decay_rate(
+ name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks
+ )
+ d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name}
+
+ if "last_layer" in name:
+ d.update({"is_last_layer": True})
+
+ if name.endswith(".bias") or "norm" in name or "gamma" in name:
+ d.update({"wd_multiplier": 0.0})
+
+ if "patch_embed" in name:
+ d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult})
+
+ all_param_groups.append(d)
+ logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""")
+
+ return all_param_groups
+
+
+def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
+ fused_params_groups = defaultdict(lambda: {"params": []})
+ for d in all_params_groups:
+ identifier = ""
+ for k in keys:
+ identifier += k + str(d[k]) + "_"
+
+ for k in keys:
+ fused_params_groups[identifier][k] = d[k]
+ fused_params_groups[identifier]["params"].append(d["params"])
+
+ return fused_params_groups.values()
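+
+# Illustrative only: parameters sharing the same (lr_multiplier, wd_multiplier,
+# is_last_layer) triple end up in one fused group, keyed by an identifier such
+# as "lr_multiplier1.0_wd_multiplier0.0_is_last_layerFalse_".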
diff --git a/src/lari/model/dinov2/utils/utils.py b/src/lari/model/dinov2/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..68f8e2c3be5f780bbb7e00359b5ac4fd0ba0785f
--- /dev/null
+++ b/src/lari/model/dinov2/utils/utils.py
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+import random
+import subprocess
+from urllib.parse import urlparse
+
+import numpy as np
+import torch
+from torch import nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
+ if urlparse(pretrained_weights).scheme: # If it looks like a URL
+ state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
+ else:
+ state_dict = torch.load(pretrained_weights, map_location="cpu")
+ if checkpoint_key is not None and checkpoint_key in state_dict:
+ logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
+ state_dict = state_dict[checkpoint_key]
+ # remove `module.` prefix
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+ # remove `backbone.` prefix induced by multicrop wrapper
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
+ msg = model.load_state_dict(state_dict, strict=False)
+ logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
+
+
+def fix_random_seeds(seed=31):
+ """
+ Fix random seeds.
+ """
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+
+
+def get_sha():
+ cwd = os.path.dirname(os.path.abspath(__file__))
+
+ def _run(command):
+ return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
+
+ sha = "N/A"
+ diff = "clean"
+ branch = "N/A"
+ try:
+ sha = _run(["git", "rev-parse", "HEAD"])
+ subprocess.check_output(["git", "diff"], cwd=cwd)
+ diff = _run(["git", "diff-index", "HEAD"])
+ diff = "has uncommitted changes" if diff else "clean"
+ branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
+ except Exception:
+ pass
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
+ return message
+
+
+class CosineScheduler(object):
+ def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
+ super().__init__()
+ self.final_value = final_value
+ self.total_iters = total_iters
+
+ freeze_schedule = np.zeros((freeze_iters))
+
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
+
+ iters = np.arange(total_iters - warmup_iters - freeze_iters)
+ schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
+ self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule))
+
+ assert len(self.schedule) == self.total_iters
+
+ def __getitem__(self, it):
+ if it >= self.total_iters:
+ return self.final_value
+ else:
+ return self.schedule[it]
+
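+# Small usage sketch (illustrative values): warm up for 2 iterations, then
+# cosine-decay from 1e-3 to 1e-5 over the remaining 8:
+#   sched = CosineScheduler(base_value=1e-3, final_value=1e-5, total_iters=10, warmup_iters=2)
+#   lr_at_step_5 = sched[5]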
+
+def has_batchnorms(model):
+ bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
+ for name, module in model.named_modules():
+ if isinstance(module, bn_types):
+ return True
+ return False
diff --git a/src/lari/model/dpt_seg_head.py b/src/lari/model/dpt_seg_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..45a1f82d2bf6e01dad32346f5e925be5e580c3b4
--- /dev/null
+++ b/src/lari/model/dpt_seg_head.py
@@ -0,0 +1,158 @@
+'''
+This code is adapted from Depth Anything and DPT.
+'''
+from src.lari.model.blocks import FeatureFusionBlock, _make_scratch
+import cv2
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision.transforms import Compose
+
+
+
+
+def _make_fusion_block(features, use_bn, size=None):
+ return FeatureFusionBlock(
+ features,
+ nn.ReLU(False),
+ deconv=False,
+ bn=use_bn,
+ expand=False,
+ align_corners=True,
+ size=size,
+ )
+
+
+class ConvBlock(nn.Module):
+ def __init__(self, in_feature, out_feature):
+ super().__init__()
+
+ self.conv_block = nn.Sequential(
+ nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
+ nn.BatchNorm2d(out_feature),
+ nn.ReLU(True)
+ )
+
+ def forward(self, x):
+ return self.conv_block(x)
+
+
+class DPTSegHead(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ features=256,
+ use_bn=False,
+ out_channels=[256, 512, 1024, 1024],
+ use_clstoken=False,
+ num_classes = 5,
+ output_type = "ray_stop" # "seg_sep"
+ ):
+ super(DPTSegHead, self).__init__()
+
+ self.use_clstoken = use_clstoken
+ self.output_type = output_type
+
+ # add one extra class so that index 0 can mark an invalid ray-stopping point
+ self.num_classes = num_classes + 1 if self.output_type == "ray_stop" else num_classes
+
+
+ self.projects = nn.ModuleList([
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channel,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ) for out_channel in out_channels
+ ])
+
+ self.resize_layers = nn.ModuleList([
+ nn.ConvTranspose2d(
+ in_channels=out_channels[0],
+ out_channels=out_channels[0],
+ kernel_size=4,
+ stride=4,
+ padding=0),
+ nn.ConvTranspose2d(
+ in_channels=out_channels[1],
+ out_channels=out_channels[1],
+ kernel_size=2,
+ stride=2,
+ padding=0),
+ nn.Identity(),
+ nn.Conv2d(
+ in_channels=out_channels[3],
+ out_channels=out_channels[3],
+ kernel_size=3,
+ stride=2,
+ padding=1)
+ ])
+
+ if use_clstoken:
+ self.readout_projects = nn.ModuleList()
+ for _ in range(len(self.projects)):
+ self.readout_projects.append(
+ nn.Sequential(
+ nn.Linear(2 * in_channels, in_channels),
+ nn.GELU()))
+
+ self.scratch = _make_scratch(
+ out_channels,
+ features,
+ groups=1,
+ expand=False,
+ )
+
+ self.scratch.stem_transpose = None
+
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+ self.scratch.output_conv1 = nn.Sequential(
+ nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(features),
+ nn.ReLU(True),
+ nn.Dropout(0.1, False),
+ nn.Conv2d(features, self.num_classes, kernel_size=1),
+ )
+
+
+
+ def forward(self, out_features, patch_h, patch_w):
+ out = []
+ for i, x in enumerate(out_features):
+ if self.use_clstoken:
+ x, cls_token = x[0], x[1]
+ readout = cls_token.unsqueeze(1).expand_as(x)
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
+ else:
+ x = x[0]
+
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
+
+ x = self.projects[i](x)
+ x = self.resize_layers[i](x)
+
+ out.append(x)
+
+ layer_1, layer_2, layer_3, layer_4 = out
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv1(path_1)
+
+ # B C H W - segmentation logits
+ out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
+
+ return out
\ No newline at end of file
diff --git a/src/lari/model/heads.py b/src/lari/model/heads.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9559d07992b1c54e5a12b0efc7238c2fabaa410
--- /dev/null
+++ b/src/lari/model/heads.py
@@ -0,0 +1,104 @@
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils
+import torch.utils.checkpoint
+import torch.version
+from typing import *
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
+from src.lari.model.blocks import ResidualConvBlock, make_upsampler, make_output_block
+from src.lari.utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift, gaussian_blur_2d
+
+
+class PointHead(nn.Module):
+ def __init__(
+ self,
+ num_features: int,
+ dim_in: int,
+ dim_out: int,
+ dim_proj: int = 512,
+ dim_upsample: List[int] = [256, 128, 128],
+ dim_times_res_block_hidden: int = 1,
+ num_res_blocks: int = 1,
+ res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm',
+ last_res_blocks: int = 0,
+ last_conv_channels: int = 32,
+ last_conv_size: int = 1,
+ num_output_layer: int = 5
+ ):
+ super().__init__()
+
+ self.num_output_layer = num_output_layer
+
+ self.projects = nn.ModuleList([
+ nn.Conv2d(in_channels=dim_in, out_channels=dim_proj, kernel_size=1, stride=1, padding=0,) for _ in range(num_features)
+ ])
+
+ self.upsample_blocks = nn.ModuleList([
+ nn.Sequential(
+ make_upsampler(in_ch + 2, out_ch),
+ *(ResidualConvBlock(out_ch, out_ch, dim_times_res_block_hidden * out_ch, activation="relu", norm=res_block_norm) for _ in range(num_res_blocks))
+ ) for in_ch, out_ch in zip([dim_proj] + dim_upsample[:-1], dim_upsample)
+ ])
+
+ # layer iterations
+ self.first_layer_block = make_output_block(dim_upsample[-1] + 2, dim_out,
+ dim_times_res_block_hidden, last_res_blocks, last_conv_channels, last_conv_size, res_block_norm,)
+
+ self.remaining_layer_block = nn.ModuleList([make_output_block(dim_upsample[-1] + 2, dim_out,
+ dim_times_res_block_hidden, last_res_blocks, last_conv_channels, last_conv_size, res_block_norm,)
+ for _ in range(self.num_output_layer - 1)])
+
+
+
+ def forward(self, hidden_states: torch.Tensor, image: torch.Tensor):
+ img_h, img_w = image.shape[-2:]
+ patch_h, patch_w = img_h // 14, img_w // 14
+
+ # Process the hidden states
+ x = torch.stack([
+ proj(feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous())
+ for proj, (feat, clstoken) in zip(self.projects, hidden_states)
+ ], dim=1).sum(dim=1)
+
+ # Upsample stage
+ # (patch_h, patch_w) -> (patch_h * 2, patch_w * 2) -> (patch_h * 4, patch_w * 4) -> (patch_h * 8, patch_w * 8)
+ for i, block in enumerate(self.upsample_blocks):
+ # UV coordinates make the head aware of the image aspect ratio
+ uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device)
+ uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
+ x = torch.cat([x, uv], dim=1)
+ for layer in block:
+ x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
+
+ # (patch_h * 8, patch_w * 8) -> (img_h, img_w)
+ x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False)
+ uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device)
+ uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
+ x = torch.cat([x, uv], dim=1)
+
+
+ pts_list = []
+ for layer_id in range(self.num_output_layer):
+ if layer_id == 0:
+ blocks = self.first_layer_block
+ else:
+ blocks = self.remaining_layer_block[layer_id-1]
+
+ # for each block
+ if isinstance(blocks, nn.ModuleList):
+ raise NotImplementedError()
+ else:
+ res = torch.utils.checkpoint.checkpoint(blocks, x, use_reentrant=False)[:, :3, :, :]
+ pts_list.append(res)
+
+ pts = torch.stack(pts_list, dim=-1)
+ seg = pts.new_zeros(pts.shape)[:, :1, ...]
+
+ # pts: (B, 3, H, W, L); seg is a placeholder of shape (B, 1, H, W, L)
+ output = [pts, seg]
+
+ return output
\ No newline at end of file
diff --git a/src/lari/model/lari_model.py b/src/lari/model/lari_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..4351c6b021e3a15ecb424a7f404ca08b39173305
--- /dev/null
+++ b/src/lari/model/lari_model.py
@@ -0,0 +1,177 @@
+from typing import *
+from numbers import Number
+from functools import partial
+from pathlib import Path
+import importlib
+import warnings
+import json
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils
+import torch.utils.checkpoint
+import torch.version
+from huggingface_hub import hf_hub_download
+from src.lari.model.utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing
+from src.lari.model.heads import PointHead
+
+
+class LaRIModel(nn.Module):
+ image_mean: torch.Tensor
+ image_std: torch.Tensor
+
+ def __init__(self,
+ encoder: str = 'dinov2_vitl14',
+ intermediate_layers: Union[int, List[int]] = 4,
+ dim_proj: int = 512,
+ dim_upsample: List[int] = [256, 128, 64],
+ dim_times_res_block_hidden: int = 2,
+ num_res_blocks: int = 2,
+ output_mask: bool = True,
+ split_head: bool = True,
+ remap_output: Literal[False, True, 'linear', 'sinh', 'exp', 'sinh_exp'] = 'exp',
+ res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm',
+ last_res_blocks: int = 0,
+ last_conv_channels: int = 32,
+ last_conv_size: int = 1,
+ use_pretrained: Literal["dinov2", "moge_full", "moge_backbone", None] = None,
+ pretrained_path: str = "",
+ num_output_layer: Optional[int] = None,
+ head_type: Optional[str] = None,
+ **deprecated_kwargs
+ ):
+ super(LaRIModel, self).__init__()
+ if deprecated_kwargs:
+ warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}")
+
+ self.encoder = encoder
+ self.remap_output = remap_output
+ self.intermediate_layers = intermediate_layers
+ self.head_type = head_type
+ self.output_mask = output_mask
+ self.split_head = split_head
+ self.use_pretrained = use_pretrained
+ self.pretrained_path = pretrained_path
+ self.num_output_layer = num_output_layer
+
+ hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), encoder)
+ # hub_loader = getattr(importlib.import_module("dinov2.hub.backbones", __package__), encoder)
+
+ self.backbone = hub_loader(pretrained=True if self.use_pretrained == "dinov2" else False)
+ dim_feature = self.backbone.blocks[0].attn.qkv.in_features
+
+ if self.head_type == "point":
+ self.head = PointHead(
+ num_features=intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers),
+ dim_in=dim_feature,
+ dim_out=3,
+ dim_proj=dim_proj,
+ dim_upsample=dim_upsample,
+ dim_times_res_block_hidden=dim_times_res_block_hidden,
+ num_res_blocks=num_res_blocks,
+ res_block_norm=res_block_norm,
+ last_res_blocks=last_res_blocks,
+ last_conv_channels=last_conv_channels,
+ last_conv_size=last_conv_size,
+ num_output_layer = num_output_layer
+ )
+ else:
+ raise NotImplementedError()
+
+
+ if torch.__version__ >= '2.0':
+ self.enable_pytorch_native_sdpa()
+
+ self._load_pretrained()
+
+
+ def _load_pretrained(self):
+ '''
+ Load pre-trained weights
+ '''
+ if self.use_pretrained == "dinov2" or self.use_pretrained is None: return
+
+ if self.use_pretrained == "moge_full" and self.pretrained_path != "":
+ checkpoint = torch.load(self.pretrained_path, map_location='cpu', weights_only=True)
+ if self.head_type == "point":
+ key_transition_map = {"output_block": "first_layer_block"}
+ model_state_dict = {}
+
+ # change the key name of the dict
+ for key, val in checkpoint['model'].items():
+ for trans_src, trans_target in key_transition_map.items():
+ if trans_src in key:
+ model_state_dict[key.replace(trans_src, trans_target)] = val
+ else:
+ model_state_dict[key] = val
+
+ self.load_state_dict(model_state_dict, strict=False)
+ del model_state_dict
+
+
+ else:
+ return
+
+
+ @staticmethod
+ def cache_pretrained_backbone(encoder: str, pretrained: bool):
+ _ = torch.hub.load('facebookresearch/dinov2', encoder, pretrained=pretrained)
+
+ def load_pretrained_backbone(self):
+ "Load the backbone with pretrained dinov2 weights from torch hub"
+ state_dict = torch.hub.load('facebookresearch/dinov2', self.encoder, pretrained=True).state_dict()
+ self.backbone.load_state_dict(state_dict)
+
+ def enable_backbone_gradient_checkpointing(self):
+ for i in range(len(self.backbone.blocks)):
+ self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing(self.backbone.blocks[i])
+
+ def enable_pytorch_native_sdpa(self):
+ for i in range(len(self.backbone.blocks)):
+ self.backbone.blocks[i].attn = wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn)
+
+ def forward(self, image: torch.Tensor, mixed_precision: bool = False) -> Dict[str, torch.Tensor]:
+ raw_img_h, raw_img_w = image.shape[-2:]
+ patch_h, patch_w = raw_img_h // 14, raw_img_w // 14
+
+ # Apply image transformation for DINOv2
+ image_14 = F.interpolate(image, (patch_h * 14, patch_w * 14), mode="bilinear", align_corners=False, antialias=True)
+
+ # Get intermediate layers from the backbone
+ with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=mixed_precision):
+ features = self.backbone.get_intermediate_layers(image_14, self.intermediate_layers, return_class_token=True)
+
+ # Predict points and mask (mask scores)
+ points, mask = self.head(features, image)
+
+ is_output_prob = False
+ if mask.ndim == 5:
+ # channel-last: (B, 3, H, W, L) -> (B, H, W, L, 3), (B, 1, H, W, L) -> (B, H, W, L, 1)
+ points, mask = points.permute(0, 2, 3, 4, 1), mask.permute(0, 2, 3, 4, 1)
+ elif mask.ndim == 4: # mask holds segmentation logits of shape (B, C, H, W)
+ points = points.permute(0, 2, 3, 4, 1)
+ is_output_prob = True
+
+ if self.remap_output == 'linear' or self.remap_output == False:
+ pass
+ elif self.remap_output =='sinh' or self.remap_output == True:
+ points = torch.sinh(points)
+ elif self.remap_output == 'exp':
+ xy, z = points.split([2, 1], dim=-1)
+ z = torch.exp(z)
+ points = torch.cat([xy * z, z], dim=-1)
+ elif self.remap_output =='sinh_exp':
+ xy, z = points.split([2, 1], dim=-1)
+ points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1)
+ else:
+ raise ValueError(f"Invalid remap output type: {self.remap_output}")
+
+ return_dict = {'pts3d': points}
+
+ if not is_output_prob:
+ return_dict['mask'] = mask
+ else:
+ return_dict["seg_prob"] = mask
+
+ return return_dict
\ No newline at end of file
diff --git a/src/lari/model/utils.py b/src/lari/model/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..af0d0a042209ed87cb60f340529940359fdfa900
--- /dev/null
+++ b/src/lari/model/utils.py
@@ -0,0 +1,38 @@
+from typing import *
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+def wrap_module_with_gradient_checkpointing(module: nn.Module):
+ from torch.utils.checkpoint import checkpoint
+ class _CheckpointingWrapper(module.__class__):
+ _restore_cls = module.__class__
+ def forward(self, *args, **kwargs):
+ return checkpoint(super().forward, *args, use_reentrant=False, **kwargs)
+
+ module.__class__ = _CheckpointingWrapper
+ return module
+
+
+def unwrap_module_with_gradient_checkpointing(module: nn.Module):
+ module.__class__ = module.__class__._restore_cls
+
+
+def wrap_dinov2_attention_with_sdpa(module: nn.Module):
+ assert torch.__version__ >= '2.0', "SDPA requires PyTorch 2.0 or later"
+ class _AttentionWrapper(module.__class__):
+ def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3, B, H, N, C // H)
+
+ q, k, v = torch.unbind(qkv, 0) # (B, H, N, C // H)
+
+ x = F.scaled_dot_product_attention(q, k, v, attn_bias)
+ x = x.permute(0, 2, 1, 3).reshape(B, N, C)
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+ module.__class__ = _AttentionWrapper
+ return module
\ No newline at end of file
diff --git a/src/lari/utils/__init__.py b/src/lari/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/lari/utils/geometry_numpy.py b/src/lari/utils/geometry_numpy.py
new file mode 100644
index 0000000000000000000000000000000000000000..b553af48e1d866f03089a90bb9d3d6d3e5a5337d
--- /dev/null
+++ b/src/lari/utils/geometry_numpy.py
@@ -0,0 +1,187 @@
+from typing import *
+from functools import partial
+import math
+
+import numpy as np
+import utils3d
+
+def weighted_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
+ if w is None:
+ return np.mean(x, axis=axis, keepdims=keepdims)
+ else:
+ w = w.astype(x.dtype)
+ return (x * w).mean(axis=axis, keepdims=keepdims) / np.clip(w.mean(axis=axis, keepdims=keepdims), eps, None)
+
+
+def harmonic_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
+ if w is None:
+ return 1 / (1 / np.clip(x, eps, None)).mean(axis=axis, keepdims=keepdims)
+ else:
+ w = w.astype(x.dtype)
+ return 1 / (weighted_mean_numpy(1 / (x + eps), w, axis=axis, keepdims=keepdims, eps=eps) + eps)
+
+
+def normalized_view_plane_uv_numpy(width: int, height: int, aspect_ratio: float = None, dtype: np.dtype = np.float32) -> np.ndarray:
+ "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
+ if aspect_ratio is None:
+ aspect_ratio = width / height
+
+ span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
+ span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5
+
+ u = np.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype)
+ v = np.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype)
+ u, v = np.meshgrid(u, v, indexing='xy')
+ uv = np.stack([u, v], axis=-1)
+ return uv
+
+
+def focal_to_fov_numpy(focal: np.ndarray):
+ return 2 * np.arctan(0.5 / focal)
+
+
+def fov_to_focal_numpy(fov: np.ndarray):
+ return 0.5 / np.tan(fov / 2)
+
+
+def intrinsics_to_fov_numpy(intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ fov_x = focal_to_fov_numpy(intrinsics[..., 0, 0])
+ fov_y = focal_to_fov_numpy(intrinsics[..., 1, 1])
+ return fov_x, fov_y
+
+
+def point_map_to_depth_legacy_numpy(points: np.ndarray):
+ height, width = points.shape[-3:-1]
+ diagonal = (height ** 2 + width ** 2) ** 0.5
+ uv = normalized_view_plane_uv_numpy(width, height, dtype=points.dtype) # (H, W, 2)
+ _, uv = np.broadcast_arrays(points[..., :2], uv)
+
+ # Solve least squares problem
+ b = (uv * points[..., 2:]).reshape(*points.shape[:-3], -1) # (..., H * W * 2)
+ A = np.stack([points[..., :2], -uv], axis=-1).reshape(*points.shape[:-3], -1, 2) # (..., H * W * 2, 2)
+
+ M = A.swapaxes(-2, -1) @ A
+ solution = (np.linalg.inv(M + 1e-6 * np.eye(2)) @ (A.swapaxes(-2, -1) @ b[..., None])).squeeze(-1)
+ focal, shift = solution[..., 0], solution[..., 1]
+
+ depth = points[..., 2] + shift[..., None, None]
+ fov_x = np.arctan(width / diagonal / focal) * 2
+ fov_y = np.arctan(height / diagonal / focal) * 2
+ return depth, fov_x, fov_y, shift
+
+
+def solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray):
+ "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift and focal"
+ from scipy.optimize import least_squares
+ uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)
+
+ def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
+ xy_proj = xy / (z + shift)[: , None]
+ f = (xy_proj * uv).sum() / np.square(xy_proj).sum()
+ err = (f * xy_proj - uv).ravel()
+ return err
+
+ solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm')
+ optim_shift = solution['x'].squeeze().astype(np.float32)
+
+ xy_proj = xy / (z + optim_shift)[: , None]
+ optim_focal = (xy_proj * uv).sum() / np.square(xy_proj).sum()
+
+ return optim_shift, optim_focal
+
+
+def solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float):
+ "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift"
+ from scipy.optimize import least_squares
+ uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)
+
+ def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
+ xy_proj = xy/ (z + shift)[: , None]
+ err = (focal * xy_proj - uv).ravel()
+ return err
+
+ solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm')
+ optim_shift = solution['x'].squeeze().astype(np.float32)
+
+ return optim_shift
+
+
+def recover_focal_shift_numpy(points: np.ndarray, mask: np.ndarray = None, focal: float = None, downsample_size: Tuple[int, int] = (64, 64)):
+ import cv2
+ assert points.shape[-1] == 3, "Points should be of shape (..., H, W, 3)"
+
+ height, width = points.shape[-3], points.shape[-2]
+ diagonal = (height ** 2 + width ** 2) ** 0.5
+
+ uv = normalized_view_plane_uv_numpy(width=width, height=height)
+
+ if mask is None:
+ points_lr = cv2.resize(points, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 3)
+ uv_lr = cv2.resize(uv, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 2)
+ else:
+ index, mask_lr = mask_aware_nearest_resize_numpy(mask, *downsample_size)
+ points_lr, uv_lr = points[index][mask_lr], uv[index][mask_lr]
+
+ if points_lr.size == 0:
+ # degenerate case: no valid points to fit; fall back to a neutral focal and zero shift
+ return 1.0, 0.0
+
+ if focal is None:
+ focal, shift = solve_optimal_focal_shift(uv_lr, points_lr)
+ else:
+ shift = solve_optimal_shift(uv_lr, points_lr, focal)
+
+ return focal, shift
+
+
+def mask_aware_nearest_resize_numpy(mask: np.ndarray, target_width: int, target_height: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """
+ Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.
+
+ ### Parameters
+ - `mask`: Input 2D mask of shape (..., H, W)
+ - `target_width`: target width of the resized map
+ - `target_height`: target height of the resized map
+
+ ### Returns
+ - `nearest_idx`: Tuple of index arrays for the nearest neighbors in the original map, each of shape (..., target_height, target_width): batch indices (if any), followed by row indices and column indices.
+ - `target_mask`: Mask of the resized map of shape (..., target_height, target_width)
+ """
+ height, width = mask.shape[-2:]
+ filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
+ filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f)
+ filter_size = filter_h_i * filter_w_i
+ padding_h, padding_w = round(filter_h_f / 2), round(filter_w_f / 2)
+
+ # Window the original mask and uv
+ uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32)
+ indices = np.arange(height * width, dtype=np.int32).reshape(height, width)
+ padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32)
+ padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
+ padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool)
+ padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
+ padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32)
+ padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
+ windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1))
+ windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1))
+ windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, axis=(0, 1))
+
+ # Gather the target pixels's local window
+ target_uv = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32)
+ target_corner = target_uv - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
+ target_corner = np.round(target_corner - 0.5).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32)
+
+ target_window_uv = windowed_uv[target_corner[..., 1], target_corner[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size)
+ target_window_mask = windowed_mask[..., target_corner[..., 1], target_corner[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size)
+ target_window_indices = windowed_indices[target_corner[..., 1], target_corner[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size)
+
+ # Compute nearest neighbor in the local window for each pixel
+ dist = np.square(target_window_uv - target_uv[..., None])
+ dist = dist[..., 0, :] + dist[..., 1, :]
+ dist = np.where(target_window_mask, dist, np.inf) # (..., target_height, tgt_width, filter_size)
+ nearest_in_window = np.argmin(dist, axis=-1, keepdims=True) # (..., target_height, tgt_width, 1)
+ nearest_idx = np.take_along_axis(target_window_indices, nearest_in_window, axis=-1).squeeze(-1) # (..., target_height, tgt_width)
+ nearest_i, nearest_j = nearest_idx // width, nearest_idx % width
+ target_mask = np.any(target_window_mask, axis=-1)
+ batch_indices = [np.arange(n).reshape([1] * i + [n] + [1] * (mask.ndim - i - 1)) for i, n in enumerate(mask.shape[:-2])]
+
+ return (*batch_indices, nearest_i, nearest_j), target_mask
\ No newline at end of file
diff --git a/src/lari/utils/geometry_torch.py b/src/lari/utils/geometry_torch.py
new file mode 100644
index 0000000000000000000000000000000000000000..c28cd4de1b37fb61afdf12640c306a22d7796ce9
--- /dev/null
+++ b/src/lari/utils/geometry_torch.py
@@ -0,0 +1,221 @@
+from typing import *
+import math
+from collections import namedtuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.types
+
+import os
+import sys
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+import utils3d
+from .geometry_numpy import solve_optimal_focal_shift, solve_optimal_shift
+
+
+def weighted_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
+ if w is None:
+ return x.mean(dim=dim, keepdim=keepdim)
+ else:
+ w = w.to(x.dtype)
+ return (x * w).mean(dim=dim, keepdim=keepdim) / w.mean(dim=dim, keepdim=keepdim).add(eps)
+
+
+def harmonic_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
+ if w is None:
+ return x.add(eps).reciprocal().mean(dim=dim, keepdim=keepdim).reciprocal()
+ else:
+ w = w.to(x.dtype)
+ return weighted_mean(x.add(eps).reciprocal(), w, dim=dim, keepdim=keepdim, eps=eps).add(eps).reciprocal()
+
+
+def geometric_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
+ if w is None:
+ return x.add(eps).log().mean(dim=dim, keepdim=keepdim).exp()
+ else:
+ w = w.to(x.dtype)
+ return weighted_mean(x.add(eps).log(), w, dim=dim, keepdim=keepdim, eps=eps).exp()
+
+
+def normalized_view_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor:
+ "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
+ if aspect_ratio is None:
+ aspect_ratio = width / height
+
+ span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
+ span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5
+
+ u = torch.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype, device=device)
+ v = torch.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype, device=device)
+ u, v = torch.meshgrid(u, v, indexing='xy')
+ uv = torch.stack([u, v], dim=-1)
+ return uv
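+
+# Usage sketch (illustrative, not part of the original utils3d/MoGe code): the returned
+# grid has shape (height, width, 2), spanning the normalized view plane.
+#
+#   uv = normalized_view_plane_uv(width=640, height=480)
+#   assert uv.shape == (480, 640, 2)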
+
+
+def gaussian_blur_2d(input: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
+ kernel = torch.exp(-(torch.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=input.dtype, device=input.device) ** 2) / (2 * sigma ** 2))
+ kernel = kernel / kernel.sum()
+ kernel = (kernel[:, None] * kernel[None, :]).reshape(1, 1, kernel_size, kernel_size)
+ input = F.pad(input, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), mode='replicate')
+ input = F.conv2d(input, kernel.expand(input.shape[1], -1, -1, -1), groups=input.shape[1])  # depthwise: one blur kernel per channel
+ return input
+
+
+def focal_to_fov(focal: torch.Tensor):
+ return 2 * torch.atan(0.5 / focal)
+
+
+def fov_to_focal(fov: torch.Tensor):
+ return 0.5 / torch.tan(fov / 2)
+
+
+def intrinsics_to_fov(intrinsics: torch.Tensor):
+ """
+ Returns field of view in radians from normalized intrinsics matrix.
+ ### Parameters:
+ - intrinsics: torch.Tensor of shape (..., 3, 3)
+
+ ### Returns:
+ - fov_x: torch.Tensor of shape (...)
+ - fov_y: torch.Tensor of shape (...)
+ """
+ focal_x = intrinsics[..., 0, 0]
+ focal_y = intrinsics[..., 1, 1]
+ return 2 * torch.atan(0.5 / focal_x), 2 * torch.atan(0.5 / focal_y)
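+
+# Hedged sketch (not in the original file): with intrinsics normalized to the unit image
+# plane, fov_to_focal and focal_to_fov invert each other.
+#
+#   fov = torch.tensor(1.0)
+#   assert torch.allclose(focal_to_fov(fov_to_focal(fov)), fov)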
+
+
+def point_map_to_depth_legacy(points: torch.Tensor):
+ height, width = points.shape[-3:-1]
+ diagonal = (height ** 2 + width ** 2) ** 0.5
+ uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2)
+
+ # Solve least squares problem
+ b = (uv * points[..., 2:]).flatten(-3, -1) # (..., H * W * 2)
+ A = torch.stack([points[..., :2], -uv.expand_as(points[..., :2])], dim=-1).flatten(-4, -2) # (..., H * W * 2, 2)
+
+ M = A.transpose(-2, -1) @ A
+ solution = (torch.inverse(M + 1e-6 * torch.eye(2).to(A)) @ (A.transpose(-2, -1) @ b[..., None])).squeeze(-1)
+ focal, shift = solution.unbind(-1)
+
+ depth = points[..., 2] + shift[..., None, None]
+ fov_x = torch.atan(width / diagonal / focal) * 2
+ fov_y = torch.atan(height / diagonal / focal) * 2
+ return depth, fov_x, fov_y, shift
+
+
+def view_plane_uv_to_focal(uv: torch.Tensor):
+ normed_uv = normalized_view_plane_uv(width=uv.shape[-2], height=uv.shape[-3], device=uv.device, dtype=uv.dtype)
+ focal = (uv * normed_uv).sum() / uv.square().sum().add(1e-12)
+ return focal
+
+
+def recover_focal_shift(points: torch.Tensor, mask: torch.Tensor = None, focal: torch.Tensor = None, downsample_size: Tuple[int, int] = (64, 64)):
+ """
+ Recover the depth map and FoV from a point map with unknown z shift and focal.
+
+ Note that it assumes:
+ - the optical center is at the center of the map
+ - the map is undistorted
+ - the map is isometric in the x and y directions
+
+ ### Parameters:
+ - `points: torch.Tensor` of shape (..., H, W, 3)
+ - `mask: torch.Tensor` of shape (..., H, W). Optional.
+ - `focal: torch.Tensor` of shape (...). Optional.
+ - `downsample_size: Tuple[int, int]` in (height, width), the size of the downsampled map. Downsampling produces an approximate solution and is efficient for large maps.
+
+ ### Returns:
+ - `focal`: torch.Tensor of shape (...) the estimated focal length, relative to the half diagonal of the map
+ - `shift`: torch.Tensor of shape (...) Z-axis shift to translate the point map to camera space
+ """
+ shape = points.shape
+ height, width = points.shape[-3], points.shape[-2]
+ diagonal = (height ** 2 + width ** 2) ** 0.5
+
+ points = points.reshape(-1, *shape[-3:])
+ mask = None if mask is None else mask.reshape(-1, *shape[-3:-1])
+ focal = focal.reshape(-1) if focal is not None else None
+ uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2)
+
+ points_lr = F.interpolate(points.permute(0, 3, 1, 2), downsample_size, mode='nearest').permute(0, 2, 3, 1)
+ uv_lr = F.interpolate(uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode='nearest').squeeze(0).permute(1, 2, 0)
+ mask_lr = None if mask is None else F.interpolate(mask.to(torch.float32).unsqueeze(1), downsample_size, mode='nearest').squeeze(1) > 0
+
+ uv_lr_np = uv_lr.cpu().numpy()
+ points_lr_np = points_lr.detach().cpu().numpy()
+ focal_np = focal.cpu().numpy() if focal is not None else None
+ mask_lr_np = None if mask is None else mask_lr.cpu().numpy()
+ optim_shift, optim_focal = [], []
+ for i in range(points.shape[0]):
+ points_lr_i_np = points_lr_np[i] if mask is None else points_lr_np[i][mask_lr_np[i]]
+ uv_lr_i_np = uv_lr_np if mask is None else uv_lr_np[mask_lr_np[i]]
+ if focal is None:
+ optim_shift_i, optim_focal_i = solve_optimal_focal_shift(uv_lr_i_np, points_lr_i_np)
+ optim_focal.append(float(optim_focal_i))
+ else:
+ optim_shift_i = solve_optimal_shift(uv_lr_i_np, points_lr_i_np, focal_np[i])
+ optim_shift.append(float(optim_shift_i))
+ optim_shift = torch.tensor(optim_shift, device=points.device, dtype=points.dtype).reshape(shape[:-3])
+
+ if focal is None:
+ optim_focal = torch.tensor(optim_focal, device=points.device, dtype=points.dtype).reshape(shape[:-3])
+ else:
+ optim_focal = focal.reshape(shape[:-3])
+
+ return optim_focal, optim_shift
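+
+# Hedged usage sketch (illustrative only): `points` below is a hypothetical predicted point
+# map with unknown focal and Z shift; the recovered shift moves it into camera space.
+#
+#   points = torch.randn(2, 128, 160, 3)              # (..., H, W, 3)
+#   focal, shift = recover_focal_shift(points)        # both of shape (2,)
+#   depth = points[..., 2] + shift[..., None, None]   # Z after translation into camera space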
+
+
+def mask_aware_nearest_resize(mask: torch.BoolTensor, target_width: int, target_height: int) -> Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]:
+ """
+ Resize a 2D map by nearest-neighbor interpolation. Return the nearest-neighbor indices and mask of the resized map.
+
+ ### Parameters
+ - `mask`: Input 2D mask of shape (..., H, W)
+ - `target_width`: target width of the resized map
+ - `target_height`: target height of the resized map
+
+ ### Returns
+ - `nearest_idx`: Nearest neighbor index of the resized map of shape (..., target_height, target_width) for each dimension
+ - `target_mask`: Mask of the resized map of shape (..., target_height, target_width)
+ """
+ height, width = mask.shape[-2:]
+ device = mask.device
+ filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
+ filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f)
+ filter_size = filter_h_i * filter_w_i
+ padding_h, padding_w = round(filter_h_f / 2), round(filter_w_f / 2)
+
+ # Window the original mask and uv
+ uv = utils3d.torch.image_pixel_center(width=width, height=height, dtype=torch.float32, device=device)
+ indices = torch.arange(height * width, dtype=torch.long, device=device).reshape(height, width)
+ padded_uv = torch.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=torch.float32, device=device)
+ padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
+ padded_mask = torch.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=torch.bool, device=device)
+ padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
+ padded_indices = torch.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=torch.long, device=device)
+ padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
+ windowed_uv = utils3d.torch.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, dim=(0, 1))
+ windowed_mask = utils3d.torch.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, dim=(-2, -1))
+ windowed_indices = utils3d.torch.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, dim=(0, 1))
+
+ # Gather each target pixel's local window
+ target_uv = utils3d.torch.image_uv(width=target_width, height=target_height, dtype=torch.float32, device=device) * torch.tensor([width, height], dtype=torch.float32, device=device)
+ target_corner = target_uv - torch.tensor((filter_w_f / 2, filter_h_f / 2), dtype=torch.float32, device=device)
+ target_corner = torch.round(target_corner - 0.5).long() + torch.tensor((padding_w, padding_h), dtype=torch.long, device=device)
+
+ target_window_uv = windowed_uv[target_corner[..., 1], target_corner[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size)
+ target_window_mask = windowed_mask[..., target_corner[..., 1], target_corner[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size)
+ target_window_indices = windowed_indices[target_corner[..., 1], target_corner[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size)
+ target_window_indices = target_window_indices.expand_as(target_window_mask)
+
+ # Compute nearest neighbor in the local window for each pixel
+ dist = torch.where(target_window_mask, torch.norm(target_window_uv - target_uv[..., None], dim=-2), torch.inf) # (..., target_height, tgt_width, filter_size)
+ nearest = torch.argmin(dist, dim=-1, keepdim=True) # (..., target_height, tgt_width, 1)
+ nearest_idx = torch.gather(target_window_indices, index=nearest, dim=-1).squeeze(-1) # (..., target_height, tgt_width)
+ target_mask = torch.any(target_window_mask, dim=-1)
+ nearest_i, nearest_j = nearest_idx // width, nearest_idx % width
+ batch_indices = [torch.arange(n, device=device).reshape([1] * i + [n] + [1] * (mask.dim() - i - 1)) for i, n in enumerate(mask.shape[:-2])]
+
+ return (*batch_indices, nearest_i, nearest_j), target_mask
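+
+# Hedged usage sketch (not part of the original file): downsample a validity mask and gather
+# a point map at the surviving nearest-neighbour locations; `mask` and `points` are hypothetical.
+#
+#   mask = torch.rand(1, 480, 640) > 0.5
+#   points = torch.randn(1, 480, 640, 3)
+#   index, mask_lr = mask_aware_nearest_resize(mask, target_width=64, target_height=64)
+#   points_lr = points[index]                          # (1, 64, 64, 3)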
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a32692113d830ddc4af4e6ed608f222fbe062e6e
--- /dev/null
+++ b/src/utils/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
diff --git a/src/utils/vis.py b/src/utils/vis.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c833afb5ab0b98cf11fb202d3bd8d93debe5737
--- /dev/null
+++ b/src/utils/vis.py
@@ -0,0 +1,105 @@
+import numpy as np
+import torch
+import matplotlib
+import matplotlib.cm
+
+def prob_to_mask(prob):
+ """
+ Transforms a probability map of stopping points (shape: (n_layer+1, H, W))
+ into a binary mask (shape: (H, W, n_layer, 1)) where for each pixel, layers
+ with index ≤ stopping index (as given by argmax) are marked valid.
+ """
+ num_layer_plus1, H, W = prob.shape
+ # Get stopping index for each pixel; values are in {0, 1, ..., n_layer}
+ stopping_indices = torch.argmax(prob, dim=0) # (H, W)
+
+ # Create a tensor with layer indices [1, 2, ..., n_layer]
+ layer_indices = torch.arange(1, num_layer_plus1, device=prob.device).view(-1, 1, 1)
+
+ # Compare: a layer is valid if its index is <= the stopping index.
+ pred_mask = (layer_indices <= stopping_indices.unsqueeze(0))
+
+ # Permute and unsqueeze to get shape (H, W, n_layer, 1)
+ pred_mask = pred_mask.permute(1, 2, 0).unsqueeze(-1)
+ return pred_mask
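+
+# Hedged usage sketch: `prob` is a hypothetical stopping-probability volume over 4 layers.
+#
+#   prob = torch.softmax(torch.randn(5, 240, 320), dim=0)   # (n_layer + 1, H, W)
+#   valid = prob_to_mask(prob)                               # (240, 320, 4, 1) boolean mask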
+
+
+
+
+def colorize(value, vmin=None, vmax=None, cmap='rainbow', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None):
+ """Converts a depth map to a color image.
+
+ Args:
+ value (torch.Tensor, numpy.ndarray): Input depth map. Shape: (H, W) or (1, H, W) or (1, 1, H, W). All singular dimensions are squeezed.
+ vmin (float, optional): vmin-valued entries are mapped to start color of cmap. If None, value.min() is used. Defaults to None.
+ vmax (float, optional): vmax-valued entries are mapped to end color of cmap. If None, value.max() is used. Defaults to None.
+ cmap (str, optional): matplotlib colormap to use. Defaults to 'rainbow'.
+ invalid_val (int, optional): Specifies value of invalid pixels that should be colored as 'background_color'. Defaults to -99.
+ invalid_mask (numpy.ndarray, optional): Boolean mask for invalid regions. Defaults to None.
+ background_color (tuple[int], optional): 4-tuple RGB color to give to invalid pixels. Defaults to (128, 128, 128, 255).
+ gamma_corrected (bool, optional): Apply gamma correction to colored image. Defaults to False.
+ value_transform (Callable, optional): Apply transform function to valid pixels before coloring. Defaults to None.
+
+ Returns:
+ numpy.ndarray, dtype - uint8: Colored depth map. Shape: (H, W, 4)
+ """
+ if isinstance(value, torch.Tensor):
+ value = value.detach().cpu().numpy()
+
+ value = value.squeeze()
+ if invalid_mask is None:
+ invalid_mask = value == invalid_val
+ mask = np.logical_not(invalid_mask)
+
+ # normalize
+ vmin = np.percentile(value[mask],2) if vmin is None else vmin
+ vmax = np.percentile(value[mask],85) if vmax is None else vmax
+ if vmin != vmax:
+ value = (value - vmin) / (vmax - vmin) # vmin..vmax
+ else:
+ # Avoid 0-division
+ value = value * 0.
+
+ value[invalid_mask] = np.nan
+ cmapper = matplotlib.cm.get_cmap(cmap)
+ if value_transform:
+ value = value_transform(value)
+ # value = value / value.max()
+ value = cmapper(value, bytes=True) # (nxmx4)
+
+ img = value
+ img[invalid_mask] = background_color
+
+ if gamma_corrected:
+ # gamma correction
+ img = img / 255
+ img = np.power(img, 2.2)
+ img = img * 255
+ img = img.astype(np.uint8)
+ return img
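+
+# Hedged usage sketch (illustrative only): colorize a depth map, marking zero-depth pixels invalid.
+#
+#   depth = np.random.rand(240, 320).astype(np.float32)
+#   rgba = colorize(depth, invalid_mask=depth == 0)          # (240, 320, 4) uint8 image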
+
+
+
+def denormalize(x):
+ """Reverses the imagenet normalization applied to the input.
+
+ Args:
+ x (torch.Tensor - shape(N,3,H,W)): input tensor
+
+ Returns:
+ torch.Tensor - shape(N,3,H,W): Denormalized input
+ """
+ mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
+ std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
+ return x * std + mean
\ No newline at end of file
diff --git a/src/utils3d/README.md b/src/utils3d/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..25b9c737af1398819f077988b4dfca878f839205
--- /dev/null
+++ b/src/utils3d/README.md
@@ -0,0 +1,3 @@
+# utils3d
+
+This is a collection of utility functions for 3D computer vision tasks copied from https://github.com/EasternJournalist/utils3d.
diff --git a/src/utils3d/__init__.py b/src/utils3d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3291ba3a79a7f263c6208a546c20ca458a5058d9
--- /dev/null
+++ b/src/utils3d/__init__.py
@@ -0,0 +1,20 @@
+"""
+A package for common utility functions in 3D computer graphics and vision. It provides NumPy utilities in `utils3d.numpy`, PyTorch utilities in `utils3d.torch`, and IO utilities in `utils3d.io`.
+"""
+import importlib
+from typing import TYPE_CHECKING
+
+try:
+ from ._unified import *
+except ImportError:
+ pass
+
+__all__ = ['numpy', 'torch', 'io']
+
+def __getattr__(name: str):
+ return globals().get(name, importlib.import_module(f'.{name}', __package__))
+
+if TYPE_CHECKING:
+ from . import torch
+ from . import numpy
+ from . import io
\ No newline at end of file
diff --git a/src/utils3d/_helpers.py b/src/utils3d/_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d1d5b086dec88b8e2283e2546520d6f2a3d8505
--- /dev/null
+++ b/src/utils3d/_helpers.py
@@ -0,0 +1,35 @@
+from functools import wraps
+import warnings
+
+
+def suppress_traceback(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ try:
+ return fn(*args, **kwargs)
+ except Exception as e:
+ e.__traceback__ = e.__traceback__.tb_next.tb_next
+ raise
+ return wrapper
+
+
+class no_warnings:
+ def __init__(self, action: str = 'ignore', **kwargs):
+ self.action = action
+ self.filter_kwargs = kwargs
+
+ def __call__(self, fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ with warnings.catch_warnings():
+ warnings.simplefilter(self.action, **self.filter_kwargs)
+ return fn(*args, **kwargs)
+ return wrapper
+
+ def __enter__(self):
+ self.warnings_manager = warnings.catch_warnings()
+ self.warnings_manager.__enter__()
+ warnings.simplefilter(self.action, **self.filter_kwargs)
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.warnings_manager.__exit__(exc_type, exc_val, exc_tb)
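+
+# Usage sketch (illustrative only): `no_warnings` works both as a decorator and as a context
+# manager; `noisy_fn` is a hypothetical function that emits warnings.
+#
+#   @no_warnings(category=DeprecationWarning)
+#   def noisy_fn():
+#       ...
+#
+#   with no_warnings():
+#       noisy_fn()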
diff --git a/src/utils3d/_unified/__init__.py b/src/utils3d/_unified/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..84a675766935a51cbae7f3bcbb7378ca25227bd5
--- /dev/null
+++ b/src/utils3d/_unified/__init__.py
@@ -0,0 +1,934 @@
+# Auto-generated implementation redirecting to numpy/torch implementations
+import sys
+from typing import TYPE_CHECKING
+import utils3d
+from .._helpers import suppress_traceback
+
+__all__ = ["triangulate",
+"compute_face_normal",
+"compute_face_angle",
+"compute_vertex_normal",
+"compute_vertex_normal_weighted",
+"remove_corrupted_faces",
+"merge_duplicate_vertices",
+"remove_unreferenced_vertices",
+"subdivide_mesh_simple",
+"mesh_relations",
+"flatten_mesh_indices",
+"calc_quad_candidates",
+"calc_quad_distortion",
+"calc_quad_direction",
+"calc_quad_smoothness",
+"sovle_quad",
+"sovle_quad_qp",
+"tri_to_quad",
+"sliding_window_1d",
+"sliding_window_nd",
+"sliding_window_2d",
+"max_pool_1d",
+"max_pool_2d",
+"max_pool_nd",
+"depth_edge",
+"normals_edge",
+"depth_aliasing",
+"interpolate",
+"image_scrcoord",
+"image_uv",
+"image_pixel_center",
+"image_pixel",
+"image_mesh",
+"image_mesh_from_depth",
+"depth_to_normals",
+"points_to_normals",
+"chessboard",
+"cube",
+"icosahedron",
+"square",
+"camera_frustum",
+"perspective",
+"perspective_from_fov",
+"perspective_from_fov_xy",
+"intrinsics_from_focal_center",
+"intrinsics_from_fov",
+"fov_to_focal",
+"focal_to_fov",
+"intrinsics_to_fov",
+"view_look_at",
+"extrinsics_look_at",
+"perspective_to_intrinsics",
+"perspective_to_near_far",
+"intrinsics_to_perspective",
+"extrinsics_to_view",
+"view_to_extrinsics",
+"normalize_intrinsics",
+"crop_intrinsics",
+"pixel_to_uv",
+"pixel_to_ndc",
+"uv_to_pixel",
+"project_depth",
+"depth_buffer_to_linear",
+"unproject_cv",
+"unproject_gl",
+"project_cv",
+"project_gl",
+"quaternion_to_matrix",
+"axis_angle_to_matrix",
+"matrix_to_quaternion",
+"extrinsics_to_essential",
+"euler_axis_angle_rotation",
+"euler_angles_to_matrix",
+"skew_symmetric",
+"rotation_matrix_from_vectors",
+"ray_intersection",
+"se3_matrix",
+"slerp_quaternion",
+"slerp_vector",
+"lerp",
+"lerp_se3_matrix",
+"piecewise_lerp",
+"piecewise_lerp_se3_matrix",
+"apply_transform",
+"linear_spline_interpolate",
+"RastContext",
+"rasterize_triangle_faces",
+"rasterize_edges",
+"texture",
+"warp_image_by_depth",
+"test_rasterization",
+"compute_face_angles",
+"compute_face_tbn",
+"compute_vertex_tbn",
+"laplacian",
+"laplacian_smooth_mesh",
+"taubin_smooth_mesh",
+"laplacian_hc_smooth_mesh",
+"get_rays",
+"get_image_rays",
+"get_mipnerf_cones",
+"volume_rendering",
+"bin_sample",
+"importance_sample",
+"nerf_render_rays",
+"mipnerf_render_rays",
+"nerf_render_view",
+"mipnerf_render_view",
+"InstantNGP",
+"point_to_normal",
+"depth_to_normal",
+"masked_min",
+"masked_max",
+"bounding_rect",
+"intrinsics_from_fov_xy",
+"matrix_to_euler_angles",
+"matrix_to_axis_angle",
+"axis_angle_to_quaternion",
+"quaternion_to_axis_angle",
+"slerp",
+"interpolate_extrinsics",
+"interpolate_view",
+"to4x4",
+"rotation_matrix_2d",
+"rotate_2d",
+"translate_2d",
+"scale_2d",
+"apply_2d",
+"warp_image_by_forward_flow"]
+
+def _contains_tensor(obj):
+ if isinstance(obj, (list, tuple)):
+ return any(_contains_tensor(item) for item in obj)
+ elif isinstance(obj, dict):
+ return any(_contains_tensor(value) for value in obj.values())
+ else:
+ import torch
+ return isinstance(obj, torch.Tensor)
+
+
+@suppress_traceback
+def _call_based_on_args(fname, args, kwargs):
+ if 'torch' in sys.modules:
+ if any(_contains_tensor(arg) for arg in args) or any(_contains_tensor(v) for v in kwargs.values()):
+ fn = getattr(utils3d.torch, fname, None)
+ if fn is None:
+ raise NotImplementedError(f"Function {fname} has no torch implementation.")
+ return fn(*args, **kwargs)
+ fn = getattr(utils3d.numpy, fname, None)
+ if fn is None:
+ raise NotImplementedError(f"Function {fname} has no numpy implementation.")
+ return fn(*args, **kwargs)
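+
+# Dispatch sketch (illustrative, not part of the generated file): the wrappers below route to
+# the torch backend when any argument is a torch.Tensor and to the numpy backend otherwise.
+#
+#   utils3d.depth_edge(np.ones((8, 8), np.float32))   # -> utils3d.numpy.depth_edge
+#   utils3d.depth_edge(torch.ones(8, 8))              # -> utils3d.torch.depth_edge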
+
+
+@suppress_traceback
+def triangulate(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.triangulate, utils3d.torch.triangulate
+ return _call_based_on_args('triangulate', args, kwargs)
+
+@suppress_traceback
+def compute_face_normal(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.compute_face_normal, utils3d.torch.compute_face_normal
+ return _call_based_on_args('compute_face_normal', args, kwargs)
+
+@suppress_traceback
+def compute_face_angle(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.compute_face_angle, None
+ return _call_based_on_args('compute_face_angle', args, kwargs)
+
+@suppress_traceback
+def compute_vertex_normal(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.compute_vertex_normal, utils3d.torch.compute_vertex_normal
+ return _call_based_on_args('compute_vertex_normal', args, kwargs)
+
+@suppress_traceback
+def compute_vertex_normal_weighted(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.compute_vertex_normal_weighted, utils3d.torch.compute_vertex_normal_weighted
+ return _call_based_on_args('compute_vertex_normal_weighted', args, kwargs)
+
+@suppress_traceback
+def remove_corrupted_faces(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.remove_corrupted_faces, utils3d.torch.remove_corrupted_faces
+ return _call_based_on_args('remove_corrupted_faces', args, kwargs)
+
+@suppress_traceback
+def merge_duplicate_vertices(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.merge_duplicate_vertices, utils3d.torch.merge_duplicate_vertices
+ return _call_based_on_args('merge_duplicate_vertices', args, kwargs)
+
+@suppress_traceback
+def remove_unreferenced_vertices(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.remove_unreferenced_vertices, utils3d.torch.remove_unreferenced_vertices
+ return _call_based_on_args('remove_unreferenced_vertices', args, kwargs)
+
+@suppress_traceback
+def subdivide_mesh_simple(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.subdivide_mesh_simple, utils3d.torch.subdivide_mesh_simple
+ return _call_based_on_args('subdivide_mesh_simple', args, kwargs)
+
+@suppress_traceback
+def mesh_relations(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.mesh_relations, None
+ return _call_based_on_args('mesh_relations', args, kwargs)
+
+@suppress_traceback
+def flatten_mesh_indices(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.flatten_mesh_indices, None
+ return _call_based_on_args('flatten_mesh_indices', args, kwargs)
+
+@suppress_traceback
+def calc_quad_candidates(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.calc_quad_candidates, None
+ return _call_based_on_args('calc_quad_candidates', args, kwargs)
+
+@suppress_traceback
+def calc_quad_distortion(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.calc_quad_distortion, None
+ return _call_based_on_args('calc_quad_distortion', args, kwargs)
+
+@suppress_traceback
+def calc_quad_direction(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.calc_quad_direction, None
+ return _call_based_on_args('calc_quad_direction', args, kwargs)
+
+@suppress_traceback
+def calc_quad_smoothness(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.calc_quad_smoothness, None
+ return _call_based_on_args('calc_quad_smoothness', args, kwargs)
+
+@suppress_traceback
+def sovle_quad(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.sovle_quad, None
+ return _call_based_on_args('sovle_quad', args, kwargs)
+
+@suppress_traceback
+def sovle_quad_qp(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.sovle_quad_qp, None
+ return _call_based_on_args('sovle_quad_qp', args, kwargs)
+
+@suppress_traceback
+def tri_to_quad(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.tri_to_quad, None
+ return _call_based_on_args('tri_to_quad', args, kwargs)
+
+@suppress_traceback
+def sliding_window_1d(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.sliding_window_1d, utils3d.torch.sliding_window_1d
+ return _call_based_on_args('sliding_window_1d', args, kwargs)
+
+@suppress_traceback
+def sliding_window_nd(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.sliding_window_nd, utils3d.torch.sliding_window_nd
+ return _call_based_on_args('sliding_window_nd', args, kwargs)
+
+@suppress_traceback
+def sliding_window_2d(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.sliding_window_2d, utils3d.torch.sliding_window_2d
+ return _call_based_on_args('sliding_window_2d', args, kwargs)
+
+@suppress_traceback
+def max_pool_1d(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.max_pool_1d, None
+ return _call_based_on_args('max_pool_1d', args, kwargs)
+
+@suppress_traceback
+def max_pool_2d(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.max_pool_2d, None
+ return _call_based_on_args('max_pool_2d', args, kwargs)
+
+@suppress_traceback
+def max_pool_nd(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.max_pool_nd, None
+ return _call_based_on_args('max_pool_nd', args, kwargs)
+
+@suppress_traceback
+def depth_edge(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.depth_edge, utils3d.torch.depth_edge
+ return _call_based_on_args('depth_edge', args, kwargs)
+
+@suppress_traceback
+def normals_edge(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.normals_edge, None
+ return _call_based_on_args('normals_edge', args, kwargs)
+
+@suppress_traceback
+def depth_aliasing(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.depth_aliasing, utils3d.torch.depth_aliasing
+ return _call_based_on_args('depth_aliasing', args, kwargs)
+
+@suppress_traceback
+def interpolate(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.interpolate, None
+ return _call_based_on_args('interpolate', args, kwargs)
+
+@suppress_traceback
+def image_scrcoord(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.image_scrcoord, None
+ return _call_based_on_args('image_scrcoord', args, kwargs)
+
+@suppress_traceback
+def image_uv(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.image_uv, utils3d.torch.image_uv
+ return _call_based_on_args('image_uv', args, kwargs)
+
+@suppress_traceback
+def image_pixel_center(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.image_pixel_center, utils3d.torch.image_pixel_center
+ return _call_based_on_args('image_pixel_center', args, kwargs)
+
+@suppress_traceback
+def image_pixel(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.image_pixel, None
+ return _call_based_on_args('image_pixel', args, kwargs)
+
+@suppress_traceback
+def image_mesh(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.image_mesh, utils3d.torch.image_mesh
+ return _call_based_on_args('image_mesh', args, kwargs)
+
+@suppress_traceback
+def image_mesh_from_depth(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.image_mesh_from_depth, utils3d.torch.image_mesh_from_depth
+ return _call_based_on_args('image_mesh_from_depth', args, kwargs)
+
+@suppress_traceback
+def depth_to_normals(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.depth_to_normals, None
+ return _call_based_on_args('depth_to_normals', args, kwargs)
+
+@suppress_traceback
+def points_to_normals(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.points_to_normals, None
+ return _call_based_on_args('points_to_normals', args, kwargs)
+
+@suppress_traceback
+def chessboard(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.chessboard, utils3d.torch.chessboard
+ return _call_based_on_args('chessboard', args, kwargs)
+
+@suppress_traceback
+def cube(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.cube, None
+ return _call_based_on_args('cube', args, kwargs)
+
+@suppress_traceback
+def icosahedron(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.icosahedron, None
+ return _call_based_on_args('icosahedron', args, kwargs)
+
+@suppress_traceback
+def square(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.square, None
+ return _call_based_on_args('square', args, kwargs)
+
+@suppress_traceback
+def camera_frustum(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.camera_frustum, None
+ return _call_based_on_args('camera_frustum', args, kwargs)
+
+@suppress_traceback
+def perspective(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.perspective, utils3d.torch.perspective
+ return _call_based_on_args('perspective', args, kwargs)
+
+@suppress_traceback
+def perspective_from_fov(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.perspective_from_fov, utils3d.torch.perspective_from_fov
+ return _call_based_on_args('perspective_from_fov', args, kwargs)
+
+@suppress_traceback
+def perspective_from_fov_xy(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.perspective_from_fov_xy, utils3d.torch.perspective_from_fov_xy
+ return _call_based_on_args('perspective_from_fov_xy', args, kwargs)
+
+@suppress_traceback
+def intrinsics_from_focal_center(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.intrinsics_from_focal_center, utils3d.torch.intrinsics_from_focal_center
+ return _call_based_on_args('intrinsics_from_focal_center', args, kwargs)
+
+@suppress_traceback
+def intrinsics_from_fov(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.intrinsics_from_fov, utils3d.torch.intrinsics_from_fov
+ return _call_based_on_args('intrinsics_from_fov', args, kwargs)
+
+@suppress_traceback
+def fov_to_focal(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.fov_to_focal, None
+ return _call_based_on_args('fov_to_focal', args, kwargs)
+
+@suppress_traceback
+def focal_to_fov(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.focal_to_fov, None
+ return _call_based_on_args('focal_to_fov', args, kwargs)
+
+@suppress_traceback
+def intrinsics_to_fov(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.intrinsics_to_fov, None
+ return _call_based_on_args('intrinsics_to_fov', args, kwargs)
+
+@suppress_traceback
+def view_look_at(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.view_look_at, utils3d.torch.view_look_at
+ return _call_based_on_args('view_look_at', args, kwargs)
+
+@suppress_traceback
+def extrinsics_look_at(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.extrinsics_look_at, utils3d.torch.extrinsics_look_at
+ return _call_based_on_args('extrinsics_look_at', args, kwargs)
+
+@suppress_traceback
+def perspective_to_intrinsics(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.perspective_to_intrinsics, utils3d.torch.perspective_to_intrinsics
+ return _call_based_on_args('perspective_to_intrinsics', args, kwargs)
+
+@suppress_traceback
+def perspective_to_near_far(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.perspective_to_near_far, None
+ return _call_based_on_args('perspective_to_near_far', args, kwargs)
+
+@suppress_traceback
+def intrinsics_to_perspective(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.intrinsics_to_perspective, utils3d.torch.intrinsics_to_perspective
+ return _call_based_on_args('intrinsics_to_perspective', args, kwargs)
+
+@suppress_traceback
+def extrinsics_to_view(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.extrinsics_to_view, utils3d.torch.extrinsics_to_view
+ return _call_based_on_args('extrinsics_to_view', args, kwargs)
+
+@suppress_traceback
+def view_to_extrinsics(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.view_to_extrinsics, utils3d.torch.view_to_extrinsics
+ return _call_based_on_args('view_to_extrinsics', args, kwargs)
+
+@suppress_traceback
+def normalize_intrinsics(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.normalize_intrinsics, utils3d.torch.normalize_intrinsics
+ return _call_based_on_args('normalize_intrinsics', args, kwargs)
+
+@suppress_traceback
+def crop_intrinsics(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.crop_intrinsics, utils3d.torch.crop_intrinsics
+ return _call_based_on_args('crop_intrinsics', args, kwargs)
+
+@suppress_traceback
+def pixel_to_uv(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.pixel_to_uv, utils3d.torch.pixel_to_uv
+ return _call_based_on_args('pixel_to_uv', args, kwargs)
+
+@suppress_traceback
+def pixel_to_ndc(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.pixel_to_ndc, utils3d.torch.pixel_to_ndc
+ return _call_based_on_args('pixel_to_ndc', args, kwargs)
+
+@suppress_traceback
+def uv_to_pixel(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.uv_to_pixel, utils3d.torch.uv_to_pixel
+ return _call_based_on_args('uv_to_pixel', args, kwargs)
+
+@suppress_traceback
+def project_depth(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.project_depth, utils3d.torch.project_depth
+ return _call_based_on_args('project_depth', args, kwargs)
+
+@suppress_traceback
+def depth_buffer_to_linear(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.depth_buffer_to_linear, utils3d.torch.depth_buffer_to_linear
+ return _call_based_on_args('depth_buffer_to_linear', args, kwargs)
+
+@suppress_traceback
+def unproject_cv(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.unproject_cv, utils3d.torch.unproject_cv
+ return _call_based_on_args('unproject_cv', args, kwargs)
+
+@suppress_traceback
+def unproject_gl(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.unproject_gl, utils3d.torch.unproject_gl
+ return _call_based_on_args('unproject_gl', args, kwargs)
+
+@suppress_traceback
+def project_cv(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.project_cv, utils3d.torch.project_cv
+ return _call_based_on_args('project_cv', args, kwargs)
+
+@suppress_traceback
+def project_gl(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.project_gl, utils3d.torch.project_gl
+ return _call_based_on_args('project_gl', args, kwargs)
+
+@suppress_traceback
+def quaternion_to_matrix(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.quaternion_to_matrix, utils3d.torch.quaternion_to_matrix
+ return _call_based_on_args('quaternion_to_matrix', args, kwargs)
+
+@suppress_traceback
+def axis_angle_to_matrix(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.axis_angle_to_matrix, utils3d.torch.axis_angle_to_matrix
+ return _call_based_on_args('axis_angle_to_matrix', args, kwargs)
+
+@suppress_traceback
+def matrix_to_quaternion(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.matrix_to_quaternion, utils3d.torch.matrix_to_quaternion
+ return _call_based_on_args('matrix_to_quaternion', args, kwargs)
+
+@suppress_traceback
+def extrinsics_to_essential(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.extrinsics_to_essential, utils3d.torch.extrinsics_to_essential
+ return _call_based_on_args('extrinsics_to_essential', args, kwargs)
+
+@suppress_traceback
+def euler_axis_angle_rotation(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.euler_axis_angle_rotation, utils3d.torch.euler_axis_angle_rotation
+ return _call_based_on_args('euler_axis_angle_rotation', args, kwargs)
+
+@suppress_traceback
+def euler_angles_to_matrix(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.euler_angles_to_matrix, utils3d.torch.euler_angles_to_matrix
+ return _call_based_on_args('euler_angles_to_matrix', args, kwargs)
+
+@suppress_traceback
+def skew_symmetric(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.skew_symmetric, utils3d.torch.skew_symmetric
+ return _call_based_on_args('skew_symmetric', args, kwargs)
+
+@suppress_traceback
+def rotation_matrix_from_vectors(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.rotation_matrix_from_vectors, utils3d.torch.rotation_matrix_from_vectors
+ return _call_based_on_args('rotation_matrix_from_vectors', args, kwargs)
+
+@suppress_traceback
+def ray_intersection(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.ray_intersection, None
+ return _call_based_on_args('ray_intersection', args, kwargs)
+
+@suppress_traceback
+def se3_matrix(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.se3_matrix, None
+ return _call_based_on_args('se3_matrix', args, kwargs)
+
+@suppress_traceback
+def slerp_quaternion(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.slerp_quaternion, None
+ return _call_based_on_args('slerp_quaternion', args, kwargs)
+
+@suppress_traceback
+def slerp_vector(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.slerp_vector, None
+ return _call_based_on_args('slerp_vector', args, kwargs)
+
+@suppress_traceback
+def lerp(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.lerp, None
+ return _call_based_on_args('lerp', args, kwargs)
+
+@suppress_traceback
+def lerp_se3_matrix(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.lerp_se3_matrix, None
+ return _call_based_on_args('lerp_se3_matrix', args, kwargs)
+
+@suppress_traceback
+def piecewise_lerp(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.piecewise_lerp, None
+ return _call_based_on_args('piecewise_lerp', args, kwargs)
+
+@suppress_traceback
+def piecewise_lerp_se3_matrix(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.piecewise_lerp_se3_matrix, None
+ return _call_based_on_args('piecewise_lerp_se3_matrix', args, kwargs)
+
+@suppress_traceback
+def apply_transform(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.apply_transform, None
+ return _call_based_on_args('apply_transform', args, kwargs)
+
+@suppress_traceback
+def linear_spline_interpolate(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.linear_spline_interpolate, None
+ return _call_based_on_args('linear_spline_interpolate', args, kwargs)
+
+@suppress_traceback
+def RastContext(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.RastContext, utils3d.torch.RastContext
+ return _call_based_on_args('RastContext', args, kwargs)
+
+@suppress_traceback
+def rasterize_triangle_faces(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.rasterize_triangle_faces, utils3d.torch.rasterize_triangle_faces
+ return _call_based_on_args('rasterize_triangle_faces', args, kwargs)
+
+@suppress_traceback
+def rasterize_edges(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.rasterize_edges, None
+ return _call_based_on_args('rasterize_edges', args, kwargs)
+
+@suppress_traceback
+def texture(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.texture, None
+ return _call_based_on_args('texture', args, kwargs)
+
+@suppress_traceback
+def warp_image_by_depth(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.warp_image_by_depth, utils3d.torch.warp_image_by_depth
+ return _call_based_on_args('warp_image_by_depth', args, kwargs)
+
+@suppress_traceback
+def test_rasterization(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ utils3d.numpy.test_rasterization, None
+ return _call_based_on_args('test_rasterization', args, kwargs)
+
+@suppress_traceback
+def compute_face_angles(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.compute_face_angles
+ return _call_based_on_args('compute_face_angles', args, kwargs)
+
+@suppress_traceback
+def compute_face_tbn(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.compute_face_tbn
+ return _call_based_on_args('compute_face_tbn', args, kwargs)
+
+@suppress_traceback
+def compute_vertex_tbn(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.compute_vertex_tbn
+ return _call_based_on_args('compute_vertex_tbn', args, kwargs)
+
+@suppress_traceback
+def laplacian(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.laplacian
+ return _call_based_on_args('laplacian', args, kwargs)
+
+@suppress_traceback
+def laplacian_smooth_mesh(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.laplacian_smooth_mesh
+ return _call_based_on_args('laplacian_smooth_mesh', args, kwargs)
+
+@suppress_traceback
+def taubin_smooth_mesh(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.taubin_smooth_mesh
+ return _call_based_on_args('taubin_smooth_mesh', args, kwargs)
+
+@suppress_traceback
+def laplacian_hc_smooth_mesh(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.laplacian_hc_smooth_mesh
+ return _call_based_on_args('laplacian_hc_smooth_mesh', args, kwargs)
+
+@suppress_traceback
+def get_rays(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.get_rays
+ return _call_based_on_args('get_rays', args, kwargs)
+
+@suppress_traceback
+def get_image_rays(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.get_image_rays
+ return _call_based_on_args('get_image_rays', args, kwargs)
+
+@suppress_traceback
+def get_mipnerf_cones(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.get_mipnerf_cones
+ return _call_based_on_args('get_mipnerf_cones', args, kwargs)
+
+@suppress_traceback
+def volume_rendering(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.volume_rendering
+ return _call_based_on_args('volume_rendering', args, kwargs)
+
+@suppress_traceback
+def bin_sample(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.bin_sample
+ return _call_based_on_args('bin_sample', args, kwargs)
+
+@suppress_traceback
+def importance_sample(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.importance_sample
+ return _call_based_on_args('importance_sample', args, kwargs)
+
+@suppress_traceback
+def nerf_render_rays(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.nerf_render_rays
+ return _call_based_on_args('nerf_render_rays', args, kwargs)
+
+@suppress_traceback
+def mipnerf_render_rays(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.mipnerf_render_rays
+ return _call_based_on_args('mipnerf_render_rays', args, kwargs)
+
+@suppress_traceback
+def nerf_render_view(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.nerf_render_view
+ return _call_based_on_args('nerf_render_view', args, kwargs)
+
+@suppress_traceback
+def mipnerf_render_view(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.mipnerf_render_view
+ return _call_based_on_args('mipnerf_render_view', args, kwargs)
+
+@suppress_traceback
+def InstantNGP(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.InstantNGP
+ return _call_based_on_args('InstantNGP', args, kwargs)
+
+@suppress_traceback
+def point_to_normal(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.point_to_normal
+ return _call_based_on_args('point_to_normal', args, kwargs)
+
+@suppress_traceback
+def depth_to_normal(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.depth_to_normal
+ return _call_based_on_args('depth_to_normal', args, kwargs)
+
+@suppress_traceback
+def masked_min(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.masked_min
+ return _call_based_on_args('masked_min', args, kwargs)
+
+@suppress_traceback
+def masked_max(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.masked_max
+ return _call_based_on_args('masked_max', args, kwargs)
+
+@suppress_traceback
+def bounding_rect(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.bounding_rect
+ return _call_based_on_args('bounding_rect', args, kwargs)
+
+@suppress_traceback
+def intrinsics_from_fov_xy(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.intrinsics_from_fov_xy
+ return _call_based_on_args('intrinsics_from_fov_xy', args, kwargs)
+
+@suppress_traceback
+def matrix_to_euler_angles(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.matrix_to_euler_angles
+ return _call_based_on_args('matrix_to_euler_angles', args, kwargs)
+
+@suppress_traceback
+def matrix_to_axis_angle(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.matrix_to_axis_angle
+ return _call_based_on_args('matrix_to_axis_angle', args, kwargs)
+
+@suppress_traceback
+def axis_angle_to_quaternion(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.axis_angle_to_quaternion
+ return _call_based_on_args('axis_angle_to_quaternion', args, kwargs)
+
+@suppress_traceback
+def quaternion_to_axis_angle(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.quaternion_to_axis_angle
+ return _call_based_on_args('quaternion_to_axis_angle', args, kwargs)
+
+@suppress_traceback
+def slerp(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.slerp
+ return _call_based_on_args('slerp', args, kwargs)
+
+@suppress_traceback
+def interpolate_extrinsics(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.interpolate_extrinsics
+ return _call_based_on_args('interpolate_extrinsics', args, kwargs)
+
+@suppress_traceback
+def interpolate_view(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.interpolate_view
+ return _call_based_on_args('interpolate_view', args, kwargs)
+
+@suppress_traceback
+def to4x4(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.to4x4
+ return _call_based_on_args('to4x4', args, kwargs)
+
+@suppress_traceback
+def rotation_matrix_2d(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.rotation_matrix_2d
+ return _call_based_on_args('rotation_matrix_2d', args, kwargs)
+
+@suppress_traceback
+def rotate_2d(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.rotate_2d
+ return _call_based_on_args('rotate_2d', args, kwargs)
+
+@suppress_traceback
+def translate_2d(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.translate_2d
+ return _call_based_on_args('translate_2d', args, kwargs)
+
+@suppress_traceback
+def scale_2d(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.scale_2d
+ return _call_based_on_args('scale_2d', args, kwargs)
+
+@suppress_traceback
+def apply_2d(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.apply_2d
+ return _call_based_on_args('apply_2d', args, kwargs)
+
+@suppress_traceback
+def warp_image_by_forward_flow(*args, **kwargs):
+ if TYPE_CHECKING: # redirected to:
+ None, utils3d.torch.warp_image_by_forward_flow
+ return _call_based_on_args('warp_image_by_forward_flow', args, kwargs)
+
diff --git a/src/utils3d/_unified/__init__.pyi b/src/utils3d/_unified/__init__.pyi
new file mode 100644
index 0000000000000000000000000000000000000000..28f662efdf95ed2894c9af693674e74b307109dc
--- /dev/null
+++ b/src/utils3d/_unified/__init__.pyi
@@ -0,0 +1,2431 @@
+# Auto-generated interface file
+from typing import List, Tuple, Dict, Union, Optional, Any, overload, Literal, Callable
+import numpy as numpy_
+import torch as torch_
+import nvdiffrast.torch
+import numbers
+from . import numpy, torch
+import utils3d.numpy, utils3d.torch
+
+__all__ = ["triangulate",
+"compute_face_normal",
+"compute_face_angle",
+"compute_vertex_normal",
+"compute_vertex_normal_weighted",
+"remove_corrupted_faces",
+"merge_duplicate_vertices",
+"remove_unreferenced_vertices",
+"subdivide_mesh_simple",
+"mesh_relations",
+"flatten_mesh_indices",
+"calc_quad_candidates",
+"calc_quad_distortion",
+"calc_quad_direction",
+"calc_quad_smoothness",
+"sovle_quad",
+"sovle_quad_qp",
+"tri_to_quad",
+"sliding_window_1d",
+"sliding_window_nd",
+"sliding_window_2d",
+"max_pool_1d",
+"max_pool_2d",
+"max_pool_nd",
+"depth_edge",
+"normals_edge",
+"depth_aliasing",
+"interpolate",
+"image_scrcoord",
+"image_uv",
+"image_pixel_center",
+"image_pixel",
+"image_mesh",
+"image_mesh_from_depth",
+"depth_to_normals",
+"points_to_normals",
+"chessboard",
+"cube",
+"icosahedron",
+"square",
+"camera_frustum",
+"perspective",
+"perspective_from_fov",
+"perspective_from_fov_xy",
+"intrinsics_from_focal_center",
+"intrinsics_from_fov",
+"fov_to_focal",
+"focal_to_fov",
+"intrinsics_to_fov",
+"view_look_at",
+"extrinsics_look_at",
+"perspective_to_intrinsics",
+"perspective_to_near_far",
+"intrinsics_to_perspective",
+"extrinsics_to_view",
+"view_to_extrinsics",
+"normalize_intrinsics",
+"crop_intrinsics",
+"pixel_to_uv",
+"pixel_to_ndc",
+"uv_to_pixel",
+"project_depth",
+"depth_buffer_to_linear",
+"unproject_cv",
+"unproject_gl",
+"project_cv",
+"project_gl",
+"quaternion_to_matrix",
+"axis_angle_to_matrix",
+"matrix_to_quaternion",
+"extrinsics_to_essential",
+"euler_axis_angle_rotation",
+"euler_angles_to_matrix",
+"skew_symmetric",
+"rotation_matrix_from_vectors",
+"ray_intersection",
+"se3_matrix",
+"slerp_quaternion",
+"slerp_vector",
+"lerp",
+"lerp_se3_matrix",
+"piecewise_lerp",
+"piecewise_lerp_se3_matrix",
+"apply_transform",
+"linear_spline_interpolate",
+"RastContext",
+"rasterize_triangle_faces",
+"rasterize_edges",
+"texture",
+"warp_image_by_depth",
+"test_rasterization",
+"compute_face_angles",
+"compute_face_tbn",
+"compute_vertex_tbn",
+"laplacian",
+"laplacian_smooth_mesh",
+"taubin_smooth_mesh",
+"laplacian_hc_smooth_mesh",
+"get_rays",
+"get_image_rays",
+"get_mipnerf_cones",
+"volume_rendering",
+"bin_sample",
+"importance_sample",
+"nerf_render_rays",
+"mipnerf_render_rays",
+"nerf_render_view",
+"mipnerf_render_view",
+"InstantNGP",
+"point_to_normal",
+"depth_to_normal",
+"masked_min",
+"masked_max",
+"bounding_rect",
+"intrinsics_from_fov_xy",
+"matrix_to_euler_angles",
+"matrix_to_axis_angle",
+"axis_angle_to_quaternion",
+"quaternion_to_axis_angle",
+"slerp",
+"interpolate_extrinsics",
+"interpolate_view",
+"to4x4",
+"rotation_matrix_2d",
+"rotate_2d",
+"translate_2d",
+"scale_2d",
+"apply_2d",
+"warp_image_by_forward_flow"]
+
+@overload
+def triangulate(faces: numpy_.ndarray, vertices: numpy_.ndarray = None, backslash: numpy_.ndarray = None) -> numpy_.ndarray:
+ """Triangulate a polygonal mesh.
+
+Args:
+ faces (np.ndarray): [L, P] polygonal faces
+ vertices (np.ndarray, optional): [N, 3] 3-dimensional vertices.
+ If given, the triangulation is performed according to the distance
+ between vertices. Defaults to None.
+ backslash (np.ndarray, optional): [L] boolean array indicating
+ how to triangulate the quad faces. Defaults to None.
+
+Returns:
+ (np.ndarray): [L * (P - 2), 3] triangular faces"""
+ utils3d.numpy.mesh.triangulate
+
+@overload
+def compute_face_normal(vertices: numpy_.ndarray, faces: numpy_.ndarray) -> numpy_.ndarray:
+ """Compute face normals of a triangular mesh
+
+Args:
+ vertices (np.ndarray): [..., N, 3] 3-dimensional vertices
+ faces (np.ndarray): [T, 3] triangular face indices
+
+Returns:
+ normals (np.ndarray): [..., T, 3] face normals"""
+ utils3d.numpy.mesh.compute_face_normal
+
+@overload
+def compute_face_angle(vertices: numpy_.ndarray, faces: numpy_.ndarray, eps: float = 1e-12) -> numpy_.ndarray:
+ """Compute face angles of a triangular mesh
+
+Args:
+ vertices (np.ndarray): [..., N, 3] 3-dimensional vertices
+ faces (np.ndarray): [T, 3] triangular face indices
+
+Returns:
+ angles (np.ndarray): [..., T, 3] face angles"""
+ utils3d.numpy.mesh.compute_face_angle
+
+@overload
+def compute_vertex_normal(vertices: numpy_.ndarray, faces: numpy_.ndarray, face_normal: numpy_.ndarray = None) -> numpy_.ndarray:
+ """Compute vertex normals of a triangular mesh by averaging neightboring face normals
+TODO: can be improved.
+
+Args:
+ vertices (np.ndarray): [..., N, 3] 3-dimensional vertices
+ faces (np.ndarray): [T, 3] triangular face indices
+ face_normal (np.ndarray, optional): [..., T, 3] face normals.
+ None to compute face normals from vertices and faces. Defaults to None.
+
+Returns:
+ normals (np.ndarray): [..., N, 3] vertex normals"""
+ utils3d.numpy.mesh.compute_vertex_normal
+
+@overload
+def compute_vertex_normal_weighted(vertices: numpy_.ndarray, faces: numpy_.ndarray, face_normal: numpy_.ndarray = None) -> numpy_.ndarray:
+ """Compute vertex normals of a triangular mesh by weighted sum of neightboring face normals
+according to the angles
+
+Args:
+ vertices (np.ndarray): [..., N, 3] 3-dimensional vertices
+ faces (np.ndarray): [..., T, 3] triangular face indices
+ face_normal (np.ndarray, optional): [..., T, 3] face normals.
+ None to compute face normals from vertices and faces. Defaults to None.
+
+Returns:
+ normals (np.ndarray): [..., N, 3] vertex normals"""
+ utils3d.numpy.mesh.compute_vertex_normal_weighted
+
+@overload
+def remove_corrupted_faces(faces: numpy_.ndarray) -> numpy_.ndarray:
+ """Remove corrupted faces (faces with duplicated vertices)
+
+Args:
+ faces (np.ndarray): [T, 3] triangular face indices
+
+Returns:
+ np.ndarray: [T_, 3] triangular face indices"""
+ utils3d.numpy.mesh.remove_corrupted_faces
+
+@overload
+def merge_duplicate_vertices(vertices: numpy_.ndarray, faces: numpy_.ndarray, tol: float = 1e-06) -> Tuple[numpy_.ndarray, numpy_.ndarray]:
+ """Merge duplicate vertices of a triangular mesh.
+Duplicate vertices are merged by selecting one of them, and the face indices are updated accordingly.
+
+Args:
+ vertices (np.ndarray): [N, 3] 3-dimensional vertices
+ faces (np.ndarray): [T, 3] triangular face indices
+ tol (float, optional): tolerance for merging. Defaults to 1e-6.
+
+Returns:
+ vertices (np.ndarray): [N_, 3] 3-dimensional vertices
+ faces (np.ndarray): [T, 3] triangular face indices"""
+ utils3d.numpy.mesh.merge_duplicate_vertices
+
+@overload
+def remove_unreferenced_vertices(faces: numpy_.ndarray, *vertice_attrs, return_indices: bool = False) -> Tuple[numpy_.ndarray, ...]:
+ """Remove unreferenced vertices of a mesh.
+Unreferenced vertices are removed, and the face indices are updated accordingly.
+
+Args:
+ faces (np.ndarray): [T, P] face indices
+ *vertice_attrs: vertex attributes
+
+Returns:
+ faces (np.ndarray): [T, P] face indices
+ *vertice_attrs: vertex attributes
+ indices (np.ndarray, optional): [N] indices of vertices that are kept. Defaults to None."""
+ utils3d.numpy.mesh.remove_unreferenced_vertices
+
+@overload
+def subdivide_mesh_simple(vertices: numpy_.ndarray, faces: numpy_.ndarray, n: int = 1) -> Tuple[numpy_.ndarray, numpy_.ndarray]:
+ """Subdivide a triangular mesh by splitting each triangle into 4 smaller triangles.
+NOTE: All original vertices are kept, and new vertices are appended to the end of the vertex list.
+
+Args:
+ vertices (np.ndarray): [N, 3] 3-dimensional vertices
+ faces (np.ndarray): [T, 3] triangular face indices
+ n (int, optional): number of subdivisions. Defaults to 1.
+
+Returns:
+ vertices (np.ndarray): [N_, 3] subdivided 3-dimensional vertices
+ faces (np.ndarray): [4 * T, 3] subdivided triangular face indices"""
+ utils3d.numpy.mesh.subdivide_mesh_simple
+
+@overload
+def mesh_relations(faces: numpy_.ndarray) -> Tuple[numpy_.ndarray, numpy_.ndarray]:
+ """Calculate the relation between vertices and faces.
+NOTE: The input mesh must be a manifold triangle mesh.
+
+Args:
+ faces (np.ndarray): [T, 3] triangular face indices
+
+Returns:
+ edges (np.ndarray): [E, 2] edge indices
+ edge2face (np.ndarray): [E, 2] edge to face relation. The second column is -1 if the edge is boundary.
+ face2edge (np.ndarray): [T, 3] face to edge relation
+ face2face (np.ndarray): [T, 3] face to face relation"""
+ utils3d.numpy.mesh.mesh_relations
+
+@overload
+def flatten_mesh_indices(*args: numpy_.ndarray) -> Tuple[numpy_.ndarray, ...]:
+ utils3d.numpy.mesh.flatten_mesh_indices
+
+@overload
+def calc_quad_candidates(edges: numpy_.ndarray, face2edge: numpy_.ndarray, edge2face: numpy_.ndarray):
+ """Calculate the candidate quad faces.
+
+Args:
+ edges (np.ndarray): [E, 2] edge indices
+ face2edge (np.ndarray): [T, 3] face to edge relation
+ edge2face (np.ndarray): [E, 2] edge to face relation
+
+Returns:
+ quads (np.ndarray): [Q, 4] quad candidate indices
+ quad2edge (np.ndarray): [Q, 4] edge to quad candidate relation
+ quad2adj (np.ndarray): [Q, 8] adjacent quad candidates of each quad candidate
+ quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid"""
+ utils3d.numpy.quadmesh.calc_quad_candidates
+
+@overload
+def calc_quad_distortion(vertices: numpy_.ndarray, quads: numpy_.ndarray):
+ """Calculate the distortion of each candidate quad face.
+
+Args:
+ vertices (np.ndarray): [N, 3] 3-dimensional vertices
+ quads (np.ndarray): [Q, 4] quad face indices
+
+Returns:
+ distortion (np.ndarray): [Q] distortion of each quad face"""
+ utils3d.numpy.quadmesh.calc_quad_distortion
+
+@overload
+def calc_quad_direction(vertices: numpy_.ndarray, quads: numpy_.ndarray):
+ """Calculate the direction of each candidate quad face.
+
+Args:
+ vertices (np.ndarray): [N, 3] 3-dimensional vertices
+ quads (np.ndarray): [Q, 4] quad face indices
+
+Returns:
+ direction (np.ndarray): [Q, 4] direction of each quad face.
+ Represented by the angle between the crossing and each edge."""
+ utils3d.numpy.quadmesh.calc_quad_direction
+
+@overload
+def calc_quad_smoothness(quad2edge: numpy_.ndarray, quad2adj: numpy_.ndarray, quads_direction: numpy_.ndarray):
+ """Calculate the smoothness of each candidate quad face connection.
+
+Args:
+ quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face
+ quads_direction (np.ndarray): [Q, 4] direction of each quad face
+
+Returns:
+ smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection"""
+ utils3d.numpy.quadmesh.calc_quad_smoothness
+
+@overload
+def sovle_quad(face2edge: numpy_.ndarray, edge2face: numpy_.ndarray, quad2adj: numpy_.ndarray, quads_distortion: numpy_.ndarray, quads_smoothness: numpy_.ndarray, quads_valid: numpy_.ndarray):
+ """Solve the quad mesh from the candidate quad faces.
+
+Args:
+ face2edge (np.ndarray): [T, 3] face to edge relation
+ edge2face (np.ndarray): [E, 2] edge to face relation
+ quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face
+ quads_distortion (np.ndarray): [Q] distortion of each quad face
+ quads_smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection
+ quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid
+
+Returns:
+ weights (np.ndarray): [Q] weight of each valid quad face"""
+ utils3d.numpy.quadmesh.sovle_quad
+
+@overload
+def sovle_quad_qp(face2edge: numpy_.ndarray, edge2face: numpy_.ndarray, quad2adj: numpy_.ndarray, quads_distortion: numpy_.ndarray, quads_smoothness: numpy_.ndarray, quads_valid: numpy_.ndarray):
+ """Solve the quad mesh from the candidate quad faces.
+
+Args:
+ face2edge (np.ndarray): [T, 3] face to edge relation
+ edge2face (np.ndarray): [E, 2] edge to face relation
+ quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face
+ quads_distortion (np.ndarray): [Q] distortion of each quad face
+ quads_smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection
+ quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid
+
+Returns:
+ weights (np.ndarray): [Q] weight of each valid quad face"""
+ utils3d.numpy.quadmesh.sovle_quad_qp
+
+@overload
+def tri_to_quad(vertices: numpy_.ndarray, faces: numpy_.ndarray) -> Tuple[numpy_.ndarray, numpy_.ndarray]:
+ """Convert a triangle mesh to a quad mesh.
+NOTE: The input mesh must be a manifold mesh.
+
+Args:
+ vertices (np.ndarray): [N, 3] 3-dimensional vertices
+ faces (np.ndarray): [T, 3] triangular face indices
+
+Returns:
+ vertices (np.ndarray): [N_, 3] 3-dimensional vertices
+ faces (np.ndarray): [Q, 4] quad face indices"""
+ utils3d.numpy.quadmesh.tri_to_quad
+
+@overload
+def sliding_window_1d(x: numpy_.ndarray, window_size: int, stride: int, axis: int = -1):
+ """Return x view of the input array with x sliding window of the given kernel size and stride.
+The sliding window is performed over the given axis, and the window dimension is append to the end of the output array's shape.
+
+Args:
+ x (np.ndarray): input array with shape (..., axis_size, ...)
+ window_size (int): size of the sliding window
+ stride (int): stride of the sliding window
+ axis (int): axis to perform sliding window over
+
+Returns:
+ a_sliding (np.ndarray): view of the input array with shape (..., n_windows, ..., kernel_size), where n_windows = (axis_size - kernel_size + 1) // stride"""
+ utils3d.numpy.utils.sliding_window_1d
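+# Sketch of the expected shapes (assuming the n_windows formula stated in the docstring
+# above, and `import numpy as np`):
+# >>> x = np.arange(10, dtype=np.float32)
+# >>> sliding_window_1d(x, window_size=3, stride=1, axis=-1).shape
+# (8, 3)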
+
+@overload
+def sliding_window_nd(x: numpy_.ndarray, window_size: Tuple[int, ...], stride: Tuple[int, ...], axis: Tuple[int, ...]) -> numpy_.ndarray:
+ utils3d.numpy.utils.sliding_window_nd
+
+@overload
+def sliding_window_2d(x: numpy_.ndarray, window_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]], axis: Tuple[int, int] = (-2, -1)) -> numpy_.ndarray:
+ utils3d.numpy.utils.sliding_window_2d
+
+@overload
+def max_pool_1d(x: numpy_.ndarray, kernel_size: int, stride: int, padding: int = 0, axis: int = -1):
+ utils3d.numpy.utils.max_pool_1d
+
+@overload
+def max_pool_2d(x: numpy_.ndarray, kernel_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]], padding: Union[int, Tuple[int, int]], axis: Tuple[int, int] = (-2, -1)):
+ utils3d.numpy.utils.max_pool_2d
+
+@overload
+def max_pool_nd(x: numpy_.ndarray, kernel_size: Tuple[int, ...], stride: Tuple[int, ...], padding: Tuple[int, ...], axis: Tuple[int, ...]) -> numpy_.ndarray:
+ utils3d.numpy.utils.max_pool_nd
+
+@overload
+def depth_edge(depth: numpy_.ndarray, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: numpy_.ndarray = None) -> numpy_.ndarray:
+ """Compute the edge mask from depth map. The edge is defined as the pixels whose neighbors have large difference in depth.
+
+Args:
+ depth (np.ndarray): shape (..., height, width), linear depth map
+ atol (float): absolute tolerance
+ rtol (float): relative tolerance
+
+Returns:
+ edge (np.ndarray): shape (..., height, width) of dtype bool"""
+ utils3d.numpy.utils.depth_edge
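+# Sketch of intended use (the `depth` array is a hypothetical (H, W) linear depth map;
+# tolerance values are illustrative): mark pixels whose neighborhood depth varies by more
+# than the absolute/relative tolerances described above.
+# >>> edge_mask = depth_edge(depth, atol=0.01, rtol=0.05, kernel_size=3)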
+
+@overload
+def normals_edge(normals: numpy_.ndarray, tol: float, kernel_size: int = 3, mask: numpy_.ndarray = None) -> numpy_.ndarray:
+ """Compute the edge mask from normal map.
+
+Args:
+ normal (np.ndarray): shape (..., height, width, 3), normal map
+ tol (float): tolerance in degrees
+
+Returns:
+ edge (np.ndarray): shape (..., height, width) of dtype bool"""
+ utils3d.numpy.utils.normals_edge
+
+@overload
+def depth_aliasing(depth: numpy_.ndarray, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: numpy_.ndarray = None) -> numpy_.ndarray:
+ """Compute the map that indicates the aliasing of x depth map. The aliasing is defined as the pixels which neither close to the maximum nor the minimum of its neighbors.
+Args:
+ depth (np.ndarray): shape (..., height, width), linear depth map
+ atol (float): absolute tolerance
+ rtol (float): relative tolerance
+
+Returns:
+ edge (np.ndarray): shape (..., height, width) of dtype bool"""
+ utils3d.numpy.utils.depth_aliasing
+
+@overload
+def interpolate(bary: numpy_.ndarray, tri_id: numpy_.ndarray, attr: numpy_.ndarray, faces: numpy_.ndarray) -> numpy_.ndarray:
+ """Interpolate with given barycentric coordinates and triangle indices
+
+Args:
+ bary (np.ndarray): shape (..., 3), barycentric coordinates
+ tri_id (np.ndarray): int array of shape (...), triangle indices
+ attr (np.ndarray): shape (N, M), vertices attributes
+ faces (np.ndarray): int array of shape (T, 3), face vertex indices
+
+Returns:
+ np.ndarray: shape (..., M) interpolated result"""
+ utils3d.numpy.utils.interpolate
+
+@overload
+def image_scrcoord(width: int, height: int) -> numpy_.ndarray:
+ """Get OpenGL's screen space coordinates, ranging in [0, 1].
+[0, 0] is the bottom-left corner of the image.
+
+Args:
+ width (int): image width
+ height (int): image height
+
+Returns:
+ (np.ndarray): shape (height, width, 2)"""
+ utils3d.numpy.utils.image_scrcoord
+
+@overload
+def image_uv(height: int, width: int, left: int = None, top: int = None, right: int = None, bottom: int = None, dtype: numpy_.dtype = numpy_.float32) -> numpy_.ndarray:
+ """Get image space UV grid, ranging in [0, 1].
+
+>>> image_uv(10, 10):
+[[[0.05, 0.05], [0.15, 0.05], ..., [0.95, 0.05]],
+ [[0.05, 0.15], [0.15, 0.15], ..., [0.95, 0.15]],
+ ... ... ...
+ [[0.05, 0.95], [0.15, 0.95], ..., [0.95, 0.95]]]
+
+Args:
+ width (int): image width
+ height (int): image height
+
+Returns:
+ np.ndarray: shape (height, width, 2)"""
+ utils3d.numpy.utils.image_uv
+
+@overload
+def image_pixel_center(height: int, width: int, left: int = None, top: int = None, right: int = None, bottom: int = None, dtype: numpy_.dtype = numpy_.float32) -> numpy_.ndarray:
+ """Get image pixel center coordinates, ranging in [0, width] and [0, height].
+`image[i, j]` has pixel center coordinates `(j + 0.5, i + 0.5)`.
+
+>>> image_pixel_center(10, 10):
+[[[0.5, 0.5], [1.5, 0.5], ..., [9.5, 0.5]],
+ [[0.5, 1.5], [1.5, 1.5], ..., [9.5, 1.5]],
+ ... ... ...
+ [[0.5, 9.5], [1.5, 9.5], ..., [9.5, 9.5]]]
+
+Args:
+ width (int): image width
+ height (int): image height
+
+Returns:
+ np.ndarray: shape (height, width, 2)"""
+ utils3d.numpy.utils.image_pixel_center
+
+@overload
+def image_pixel(height: int, width: int, left: int = None, top: int = None, right: int = None, bottom: int = None, dtype: numpy_.dtype = numpy_.int32) -> numpy_.ndarray:
+ """Get image pixel coordinates grid, ranging in [0, width - 1] and [0, height - 1].
+`image[i, j]` has pixel center coordinates `(j, i)`.
+
+>>> image_pixel(10, 10):
+[[[0, 0], [1, 0], ..., [9, 0]],
+ [[0, 1], [1, 1], ..., [9, 1]],
+ ... ... ...
+ [[0, 9], [1, 9], ..., [9, 9]]]
+
+Args:
+ width (int): image width
+ height (int): image height
+
+Returns:
+ np.ndarray: shape (height, width, 2)"""
+ utils3d.numpy.utils.image_pixel
+
+@overload
+def image_mesh(*image_attrs: numpy_.ndarray, mask: numpy_.ndarray = None, tri: bool = False, return_indices: bool = False) -> Tuple[numpy_.ndarray, ...]:
+ """Get a mesh regarding image pixel uv coordinates as vertices and image grid as faces.
+
+Args:
+ *image_attrs (np.ndarray): image attributes in shape (height, width, [channels])
+ mask (np.ndarray, optional): binary mask of shape (height, width), dtype=bool. Defaults to None.
+
+Returns:
+ faces (np.ndarray): faces connecting neighboring pixels. shape (T, 4) if tri is False, else (T, 3)
+ *vertex_attrs (np.ndarray): vertex attributes in corresponding order with input image_attrs
+ indices (np.ndarray, optional): indices of vertices in the original mesh"""
+ utils3d.numpy.utils.image_mesh
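+# Sketch (hypothetical `rgb` image and boolean `mask`): build a per-pixel mesh whose
+# vertices carry uv coordinates and colors, triangulated instead of quads.
+# >>> uv = image_uv(*rgb.shape[:2])
+# >>> faces, uv_attr, color_attr = image_mesh(uv, rgb, mask=mask, tri=True)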
+
+@overload
+def image_mesh_from_depth(depth: numpy_.ndarray, extrinsics: numpy_.ndarray = None, intrinsics: numpy_.ndarray = None, *vertice_attrs: numpy_.ndarray, atol: float = None, rtol: float = None, remove_by_depth: bool = False, return_uv: bool = False, return_indices: bool = False) -> Tuple[numpy_.ndarray, ...]:
+ """Get x triangle mesh by lifting depth map to 3D.
+
+Args:
+ depth (np.ndarray): [H, W] depth map
+ extrinsics (np.ndarray, optional): [4, 4] extrinsics matrix. Defaults to None.
+ intrinsics (np.ndarray, optional): [3, 3] intrinsics matrix. Defaults to None.
+ *vertice_attrs (np.ndarray): [H, W, C] vertex attributes. Defaults to None.
+ atol (float, optional): absolute tolerance. Defaults to None.
+ rtol (float, optional): relative tolerance. Defaults to None.
+ triangles with vertices having depth difference larger than atol + rtol * depth will be marked.
+ remove_by_depth (bool, optional): whether to remove triangles with large depth difference. Defaults to False.
+ return_uv (bool, optional): whether to return uv coordinates. Defaults to False.
+ return_indices (bool, optional): whether to return indices of vertices in the original mesh. Defaults to False.
+
+Returns:
+ vertices (np.ndarray): [N, 3] vertices
+ faces (np.ndarray): [T, 3] faces
+ *vertice_attrs (np.ndarray): [N, C] vertex attributes
+ image_uv (np.ndarray, optional): [N, 2] uv coordinates
+ ref_indices (np.ndarray, optional): [N] indices of vertices in the original mesh"""
+ utils3d.numpy.utils.image_mesh_from_depth
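+# Sketch (the `depth`, `extrinsics` and `intrinsics` arrays are hypothetical): lift a depth
+# map to a world-space triangle mesh, dropping triangles that span depth discontinuities.
+# >>> vertices, faces = image_mesh_from_depth(depth, extrinsics, intrinsics,
+# ...                                         rtol=0.05, remove_by_depth=True)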
+
+@overload
+def depth_to_normals(depth: numpy_.ndarray, intrinsics: numpy_.ndarray, mask: numpy_.ndarray = None) -> numpy_.ndarray:
+ """Calculate normal map from depth map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system.
+
+Args:
+ depth (np.ndarray): shape (height, width), linear depth map
+ intrinsics (np.ndarray): shape (3, 3), intrinsics matrix
+Returns:
+ normal (np.ndarray): shape (height, width, 3), normal map. """
+ utils3d.numpy.utils.depth_to_normals
+
+@overload
+def points_to_normals(point: numpy_.ndarray, mask: numpy_.ndarray = None) -> numpy_.ndarray:
+ """Calculate normal map from point map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system.
+
+Args:
+ point (np.ndarray): shape (height, width, 3), point map
+Returns:
+ normal (np.ndarray): shape (height, width, 3), normal map. """
+ utils3d.numpy.utils.points_to_normals
+
+@overload
+def chessboard(width: int, height: int, grid_size: int, color_a: numpy_.ndarray, color_b: numpy_.ndarray) -> numpy_.ndarray:
+ """get x chessboard image
+
+Args:
+ width (int): image width
+ height (int): image height
+ grid_size (int): size of chessboard grid
+ color_a (np.ndarray): color of the grid at the top-left corner
+ color_b (np.ndarray): color in complementary grid cells
+
+Returns:
+ image (np.ndarray): shape (height, width, channels), chessboard image"""
+ utils3d.numpy.utils.chessboard
+
+@overload
+def cube(tri: bool = False) -> Tuple[numpy_.ndarray, numpy_.ndarray]:
+ """Get x cube mesh of size 1 centered at origin.
+
+### Parameters
+ tri (bool, optional): return triangulated mesh. Defaults to False, which returns quad mesh.
+
+### Returns
+ vertices (np.ndarray): shape (8, 3)
+ faces (np.ndarray): shape (6, 4) if tri is False, else (12, 3)"""
+ utils3d.numpy.utils.cube
+
+@overload
+def icosahedron():
+ utils3d.numpy.utils.icosahedron
+
+@overload
+def square(tri: bool = False) -> Tuple[numpy_.ndarray, numpy_.ndarray]:
+ """Get a square mesh of area 1 centered at origin in the xy-plane.
+
+### Returns
+ vertices (np.ndarray): shape (4, 3)
+ faces (np.ndarray): shape (1, 4)"""
+ utils3d.numpy.utils.square
+
+@overload
+def camera_frustum(extrinsics: numpy_.ndarray, intrinsics: numpy_.ndarray, depth: float = 1.0) -> Tuple[numpy_.ndarray, numpy_.ndarray, numpy_.ndarray]:
+ """Get x triangle mesh of camera frustum."""
+ utils3d.numpy.utils.camera_frustum
+
+@overload
+def perspective(fov_y: Union[float, numpy_.ndarray], aspect: Union[float, numpy_.ndarray], near: Union[float, numpy_.ndarray], far: Union[float, numpy_.ndarray]) -> numpy_.ndarray:
+ """Get OpenGL perspective matrix
+
+Args:
+ fov_y (float | np.ndarray): field of view in y axis
+ aspect (float | np.ndarray): aspect ratio
+ near (float | np.ndarray): near plane to clip
+ far (float | np.ndarray): far plane to clip
+
+Returns:
+ (np.ndarray): [..., 4, 4] perspective matrix"""
+ utils3d.numpy.transforms.perspective
+
+@overload
+def perspective_from_fov(fov: Union[float, numpy_.ndarray], width: Union[int, numpy_.ndarray], height: Union[int, numpy_.ndarray], near: Union[float, numpy_.ndarray], far: Union[float, numpy_.ndarray]) -> numpy_.ndarray:
+ """Get OpenGL perspective matrix from field of view in largest dimension
+
+Args:
+ fov (float | np.ndarray): field of view in largest dimension
+ width (int | np.ndarray): image width
+ height (int | np.ndarray): image height
+ near (float | np.ndarray): near plane to clip
+ far (float | np.ndarray): far plane to clip
+
+Returns:
+ (np.ndarray): [..., 4, 4] perspective matrix"""
+ utils3d.numpy.transforms.perspective_from_fov
+
+@overload
+def perspective_from_fov_xy(fov_x: Union[float, numpy_.ndarray], fov_y: Union[float, numpy_.ndarray], near: Union[float, numpy_.ndarray], far: Union[float, numpy_.ndarray]) -> numpy_.ndarray:
+ """Get OpenGL perspective matrix from field of view in x and y axis
+
+Args:
+ fov_x (float | np.ndarray): field of view in x axis
+ fov_y (float | np.ndarray): field of view in y axis
+ near (float | np.ndarray): near plane to clip
+ far (float | np.ndarray): far plane to clip
+
+Returns:
+ (np.ndarray): [..., 4, 4] perspective matrix"""
+ utils3d.numpy.transforms.perspective_from_fov_xy
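+# Sketch: an OpenGL projection from a 60-degree vertical FOV at a 16:9 aspect ratio
+# (angles assumed to be in radians, consistent with the other helpers; numpy imported as np).
+# >>> proj = perspective(fov_y=np.deg2rad(60), aspect=16 / 9, near=0.1, far=100.0)
+# >>> proj.shape
+# (4, 4)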
+
+@overload
+def intrinsics_from_focal_center(fx: Union[float, numpy_.ndarray], fy: Union[float, numpy_.ndarray], cx: Union[float, numpy_.ndarray], cy: Union[float, numpy_.ndarray], dtype: Optional[numpy_.dtype] = numpy_.float32) -> numpy_.ndarray:
+ """Get OpenCV intrinsics matrix
+
+Returns:
+ (np.ndarray): [..., 3, 3] OpenCV intrinsics matrix"""
+ utils3d.numpy.transforms.intrinsics_from_focal_center
+
+@overload
+def intrinsics_from_fov(fov_max: Union[float, numpy_.ndarray] = None, fov_min: Union[float, numpy_.ndarray] = None, fov_x: Union[float, numpy_.ndarray] = None, fov_y: Union[float, numpy_.ndarray] = None, width: Union[int, numpy_.ndarray] = None, height: Union[int, numpy_.ndarray] = None) -> numpy_.ndarray:
+ """Get normalized OpenCV intrinsics matrix from given field of view.
+You can provide either fov_max, fov_min, fov_x or fov_y
+
+Args:
+ width (int | np.ndarray): image width
+ height (int | np.ndarray): image height
+ fov_max (float | np.ndarray): field of view in largest dimension
+ fov_min (float | np.ndarray): field of view in smallest dimension
+ fov_x (float | np.ndarray): field of view in x axis
+ fov_y (float | np.ndarray): field of view in y axis
+
+Returns:
+ (np.ndarray): [..., 3, 3] OpenCV intrinsics matrix"""
+ utils3d.numpy.transforms.intrinsics_from_fov
+
+@overload
+def fov_to_focal(fov: numpy_.ndarray):
+ utils3d.numpy.transforms.fov_to_focal
+
+@overload
+def focal_to_fov(focal: numpy_.ndarray):
+ utils3d.numpy.transforms.focal_to_fov
+
+@overload
+def intrinsics_to_fov(intrinsics: numpy_.ndarray) -> Tuple[numpy_.ndarray, numpy_.ndarray]:
+ utils3d.numpy.transforms.intrinsics_to_fov
+
+@overload
+def view_look_at(eye: numpy_.ndarray, look_at: numpy_.ndarray, up: numpy_.ndarray) -> numpy_.ndarray:
+ """Get OpenGL view matrix looking at something
+
+Args:
+ eye (np.ndarray): [..., 3] the eye position
+ look_at (np.ndarray): [..., 3] the position to look at
+ up (np.ndarray): [..., 3] head up direction (y axis in screen space). Not necessarily orthogonal to the view direction
+
+Returns:
+ (np.ndarray): [..., 4, 4], view matrix"""
+ utils3d.numpy.transforms.view_look_at
+
+@overload
+def extrinsics_look_at(eye: numpy_.ndarray, look_at: numpy_.ndarray, up: numpy_.ndarray) -> numpy_.ndarray:
+ """Get OpenCV extrinsics matrix looking at something
+
+Args:
+ eye (np.ndarray): [..., 3] the eye position
+ look_at (np.ndarray): [..., 3] the position to look at
+ up (np.ndarray): [..., 3] head up direction (-y axis in screen space). Not necessarily orthogonal to the view direction
+
+Returns:
+ (np.ndarray): [..., 4, 4], extrinsics matrix"""
+ utils3d.numpy.transforms.extrinsics_look_at
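+# Sketch (the `eye`, `target` and `up` arrays are hypothetical [..., 3] vectors): the same
+# pose expressed in both conventions documented above.
+# >>> view = view_look_at(eye, target, up)              # OpenGL view matrix
+# >>> extrinsics = extrinsics_look_at(eye, target, up)  # OpenCV extrinsics matrix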
+
+@overload
+def perspective_to_intrinsics(perspective: numpy_.ndarray) -> numpy_.ndarray:
+ """OpenGL perspective matrix to OpenCV intrinsics
+
+Args:
+ perspective (np.ndarray): [..., 4, 4] OpenGL perspective matrix
+
+Returns:
+ (np.ndarray): shape [..., 3, 3] OpenCV intrinsics"""
+ utils3d.numpy.transforms.perspective_to_intrinsics
+
+@overload
+def perspective_to_near_far(perspective: numpy_.ndarray) -> Tuple[numpy_.ndarray, numpy_.ndarray]:
+ """Get near and far planes from OpenGL perspective matrix
+
+Args:"""
+ utils3d.numpy.transforms.perspective_to_near_far
+
+@overload
+def intrinsics_to_perspective(intrinsics: numpy_.ndarray, near: Union[float, numpy_.ndarray], far: Union[float, numpy_.ndarray]) -> numpy_.ndarray:
+ """OpenCV intrinsics to OpenGL perspective matrix
+NOTE: does not currently work for tilt-shift intrinsics
+
+Args:
+ intrinsics (np.ndarray): [..., 3, 3] OpenCV intrinsics matrix
+ near (float | np.ndarray): [...] near plane to clip
+ far (float | np.ndarray): [...] far plane to clip
+Returns:
+ (np.ndarray): [..., 4, 4] OpenGL perspective matrix"""
+ utils3d.numpy.transforms.intrinsics_to_perspective
+
+@overload
+def extrinsics_to_view(extrinsics: numpy_.ndarray) -> numpy_.ndarray:
+ """OpenCV camera extrinsics to OpenGL view matrix
+
+Args:
+ extrinsics (np.ndarray): [..., 4, 4] OpenCV camera extrinsics matrix
+
+Returns:
+ (np.ndarray): [..., 4, 4] OpenGL view matrix"""
+ utils3d.numpy.transforms.extrinsics_to_view
+
+@overload
+def view_to_extrinsics(view: numpy_.ndarray) -> numpy_.ndarray:
+ """OpenGL view matrix to OpenCV camera extrinsics
+
+Args:
+ view (np.ndarray): [..., 4, 4] OpenGL view matrix
+
+Returns:
+ (np.ndarray): [..., 4, 4] OpenCV camera extrinsics matrix"""
+ utils3d.numpy.transforms.view_to_extrinsics
+
+@overload
+def normalize_intrinsics(intrinsics: numpy_.ndarray, width: Union[int, numpy_.ndarray], height: Union[int, numpy_.ndarray], integer_pixel_centers: bool = True) -> numpy_.ndarray:
+ """Normalize intrinsics from pixel cooridnates to uv coordinates
+
+Args:
+ intrinsics (np.ndarray): [..., 3, 3] camera intrinsics(s) to normalize
+ width (int | np.ndarray): [...] image width(s)
+ height (int | np.ndarray): [...] image height(s)
+ integer_pixel_centers (bool): whether the integer pixel coordinates are at the center of the pixel. If False, the integer coordinates are at the left-top corner of the pixel.
+
+Returns:
+ (np.ndarray): [..., 3, 3] normalized camera intrinsics(s)"""
+ utils3d.numpy.transforms.normalize_intrinsics
+
+@overload
+def crop_intrinsics(intrinsics: numpy_.ndarray, width: Union[int, numpy_.ndarray], height: Union[int, numpy_.ndarray], left: Union[int, numpy_.ndarray], top: Union[int, numpy_.ndarray], crop_width: Union[int, numpy_.ndarray], crop_height: Union[int, numpy_.ndarray]) -> numpy_.ndarray:
+ """Evaluate the new intrinsics(s) after crop the image: cropped_img = img[top:top+crop_height, left:left+crop_width]
+
+Args:
+ intrinsics (np.ndarray): [..., 3, 3] camera intrinsics(s) to crop
+ width (int | np.ndarray): [...] image width(s)
+ height (int | np.ndarray): [...] image height(s)
+ left (int | np.ndarray): [...] left crop boundary
+ top (int | np.ndarray): [...] top crop boundary
+ crop_width (int | np.ndarray): [...] crop width
+ crop_height (int | np.ndarray): [...] crop height
+
+Returns:
+ (np.ndarray): [..., 3, 3] cropped camera intrinsics(s)"""
+ utils3d.numpy.transforms.crop_intrinsics
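+# Sketch (`K` is a hypothetical pixel-space intrinsics matrix; the crop parameters are
+# illustrative): update intrinsics after a crop, then normalize to uv space.
+# >>> K_crop = crop_intrinsics(K, width, height, left, top, crop_width, crop_height)
+# >>> K_uv = normalize_intrinsics(K_crop, crop_width, crop_height)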
+
+@overload
+def pixel_to_uv(pixel: numpy_.ndarray, width: Union[int, numpy_.ndarray], height: Union[int, numpy_.ndarray]) -> numpy_.ndarray:
+ """Args:
+ pixel (np.ndarray): [..., 2] pixel coordinates defined in image space, x range is (0, W - 1), y range is (0, H - 1)
+ width (int | np.ndarray): [...] image width(s)
+ height (int | np.ndarray): [...] image height(s)
+
+Returns:
+ (np.ndarray): [..., 2] coordinates defined in uv space, the range is (0, 1)
+ utils3d.numpy.transforms.pixel_to_uv
+
+@overload
+def pixel_to_ndc(pixel: numpy_.ndarray, width: Union[int, numpy_.ndarray], height: Union[int, numpy_.ndarray]) -> numpy_.ndarray:
+ """Args:
+ pixel (np.ndarray): [..., 2] pixel coordinates defined in image space, x range is (0, W - 1), y range is (0, H - 1)
+ width (int | np.ndarray): [...] image width(s)
+ height (int | np.ndarray): [...] image height(s)
+
+Returns:
+ (np.ndarray): [..., 2] coordinates defined in ndc space, the range is (-1, 1)
+ utils3d.numpy.transforms.pixel_to_ndc
+
+@overload
+def uv_to_pixel(uv: numpy_.ndarray, width: Union[int, numpy_.ndarray], height: Union[int, numpy_.ndarray]) -> numpy_.ndarray:
+ """Args:
+ uv (np.ndarray): [..., 2] coordinates defined in uv space, the range is (0, 1)
+ width (int | np.ndarray): [...] image width(s)
+ height (int | np.ndarray): [...] image height(s)
+
+Returns:
+ (np.ndarray): [..., 2] pixel coordinates defined in image space, x range is (0, W - 1), y range is (0, H - 1)
+ utils3d.numpy.transforms.uv_to_pixel
+
+@overload
+def project_depth(depth: numpy_.ndarray, near: Union[float, numpy_.ndarray], far: Union[float, numpy_.ndarray]) -> numpy_.ndarray:
+ """Project linear depth to depth value in screen space
+
+Args:
+ depth (np.ndarray): [...] depth value
+ near (float | np.ndarray): [...] near plane to clip
+ far (float | np.ndarray): [...] far plane to clip
+
+Returns:
+ (np.ndarray): [..., 1] depth value in screen space, value ranging in [0, 1]"""
+ utils3d.numpy.transforms.project_depth
+
+@overload
+def depth_buffer_to_linear(depth_buffer: numpy_.ndarray, near: Union[float, numpy_.ndarray], far: Union[float, numpy_.ndarray]) -> numpy_.ndarray:
+ """OpenGL depth buffer to linear depth
+
+Args:
+ depth_buffer (np.ndarray): [...] depth value
+ near (float | np.ndarray): [...] near plane to clip
+ far (float | np.ndarray): [...] far plane to clip
+
+Returns:
+ (np.ndarray): [..., 1] linear depth"""
+ utils3d.numpy.transforms.depth_buffer_to_linear
+
+@overload
+def unproject_cv(uv_coord: numpy_.ndarray, depth: numpy_.ndarray = None, extrinsics: numpy_.ndarray = None, intrinsics: numpy_.ndarray = None) -> numpy_.ndarray:
+ """Unproject uv coordinates to 3D view space following the OpenCV convention
+
+Args:
+ uv_coord (np.ndarray): [..., N, 2] uv coordinates, value ranging in [0, 1].
+ The origin (0., 0.) corresponds to the left & top
+ depth (np.ndarray): [..., N] depth value
+ extrinsics (np.ndarray): [..., 4, 4] extrinsics matrix
+ intrinsics (np.ndarray): [..., 3, 3] intrinsics matrix
+
+Returns:
+ points (np.ndarray): [..., N, 3] 3d points"""
+ utils3d.numpy.transforms.unproject_cv
+
+@overload
+def unproject_gl(screen_coord: numpy_.ndarray, model: numpy_.ndarray = None, view: numpy_.ndarray = None, perspective: numpy_.ndarray = None) -> numpy_.ndarray:
+ """Unproject screen space coordinates to 3D view space following the OpenGL convention (except for row major matrice)
+
+Args:
+ screen_coord (np.ndarray): [..., N, 3] screen space coordinates, value ranging in [0, 1].
+ The origin (0., 0., 0.) corresponds to the left & bottom & nearest
+ model (np.ndarray): [..., 4, 4] model matrix
+ view (np.ndarray): [..., 4, 4] view matrix
+ perspective (np.ndarray): [..., 4, 4] perspective matrix
+
+Returns:
+ points (np.ndarray): [..., N, 3] 3d points"""
+ utils3d.numpy.transforms.unproject_gl
+
+@overload
+def project_cv(points: numpy_.ndarray, extrinsics: numpy_.ndarray = None, intrinsics: numpy_.ndarray = None) -> Tuple[numpy_.ndarray, numpy_.ndarray]:
+ """Project 3D points to 2D following the OpenCV convention
+
+Args:
+ points (np.ndarray): [..., N, 3] or [..., N, 4] 3D points to project, if the last
+ dimension is 4, the points are assumed to be in homogeneous coordinates
+ extrinsics (np.ndarray): [..., 4, 4] extrinsics matrix
+ intrinsics (np.ndarray): [..., 3, 3] intrinsics matrix
+
+Returns:
+ uv_coord (np.ndarray): [..., N, 2] uv coordinates, value ranging in [0, 1].
+ The origin (0., 0.) corresponds to the left & top
+ linear_depth (np.ndarray): [..., N] linear depth"""
+ utils3d.numpy.transforms.project_cv
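+# Sketch of the OpenCV round trip implied by the two functions above
+# (the `points`, `extrinsics` and `intrinsics` arrays are hypothetical):
+# >>> uv, depth = project_cv(points, extrinsics, intrinsics)
+# >>> points_again = unproject_cv(uv, depth, extrinsics, intrinsics)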
+
+@overload
+def project_gl(points: numpy_.ndarray, model: numpy_.ndarray = None, view: numpy_.ndarray = None, perspective: numpy_.ndarray = None) -> Tuple[numpy_.ndarray, numpy_.ndarray]:
+ """Project 3D points to 2D following the OpenGL convention (except for row major matrice)
+
+Args:
+ points (np.ndarray): [..., N, 3] or [..., N, 4] 3D points to project, if the last
+ dimension is 4, the points are assumed to be in homogeneous coordinates
+ model (np.ndarray): [..., 4, 4] model matrix
+ view (np.ndarray): [..., 4, 4] view matrix
+ perspective (np.ndarray): [..., 4, 4] perspective matrix
+
+Returns:
+ scr_coord (np.ndarray): [..., N, 3] screen space coordinates, value ranging in [0, 1].
+ The origin (0., 0., 0.) corresponds to the left & bottom & nearest
+ linear_depth (np.ndarray): [..., N] linear depth"""
+ utils3d.numpy.transforms.project_gl
+
+@overload
+def quaternion_to_matrix(quaternion: numpy_.ndarray, eps: float = 1e-12) -> numpy_.ndarray:
+ """Converts a batch of quaternions (w, x, y, z) to rotation matrices
+
+Args:
+ quaternion (np.ndarray): shape (..., 4), the quaternions to convert
+
+Returns:
+ np.ndarray: shape (..., 3, 3), the rotation matrices corresponding to the given quaternions"""
+ utils3d.numpy.transforms.quaternion_to_matrix
+
+@overload
+def axis_angle_to_matrix(axis_angle: numpy_.ndarray, eps: float = 1e-12) -> numpy_.ndarray:
+ """Convert axis-angle representation (rotation vector) to rotation matrix, whose direction is the axis of rotation and length is the angle of rotation
+
+Args:
+ axis_angle (np.ndarray): shape (..., 3), axis-angle vectors
+
+Returns:
+ np.ndarray: shape (..., 3, 3) The rotation matrices for the given axis-angle parameters"""
+ utils3d.numpy.transforms.axis_angle_to_matrix
+
+@overload
+def matrix_to_quaternion(rot_mat: numpy_.ndarray, eps: float = 1e-12) -> numpy_.ndarray:
+ """Convert 3x3 rotation matrix to quaternion (w, x, y, z)
+
+Args:
+ rot_mat (np.ndarray): shape (..., 3, 3), the rotation matrices to convert
+
+Returns:
+ np.ndarray: shape (..., 4), the quaternions corresponding to the given rotation matrices"""
+ utils3d.numpy.transforms.matrix_to_quaternion
+
+@overload
+def extrinsics_to_essential(extrinsics: numpy_.ndarray):
+ """extrinsics matrix `[[R, t] [0, 0, 0, 1]]` such that `x' = R (x - t)` to essential matrix such that `x' E x = 0`
+
+Args:
+ extrinsics (np.ndarray): [..., 4, 4] extrinsics matrix
+
+Returns:
+ (np.ndarray): [..., 3, 3] essential matrix"""
+ utils3d.numpy.transforms.extrinsics_to_essential
+
+@overload
+def euler_axis_angle_rotation(axis: str, angle: numpy_.ndarray) -> numpy_.ndarray:
+ """Return the rotation matrices for one of the rotations about an axis
+of which Euler angles describe, for each value of the angle given.
+
+Args:
+ axis: Axis label "X", "Y", or "Z".
+ angle: any shape tensor of Euler angles in radians
+
+Returns:
+ Rotation matrices as tensor of shape (..., 3, 3)."""
+ utils3d.numpy.transforms.euler_axis_angle_rotation
+
+@overload
+def euler_angles_to_matrix(euler_angles: numpy_.ndarray, convention: str = 'XYZ') -> numpy_.ndarray:
+ """Convert rotations given as Euler angles in radians to rotation matrices.
+
+Args:
+ euler_angles: Euler angles in radians as ndarray of shape (..., 3), XYZ
+ convention: permutation of "X", "Y" or "Z", representing the order of Euler rotations to apply.
+
+Returns:
+ Rotation matrices as ndarray of shape (..., 3, 3)."""
+ utils3d.numpy.transforms.euler_angles_to_matrix
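+# Sketch: compose a rotation from Euler angles (radians) and convert it back to a
+# (w, x, y, z) quaternion with the helpers above (numpy assumed imported as np).
+# >>> R = euler_angles_to_matrix(np.array([0.1, 0.2, 0.3]), convention='XYZ')
+# >>> q = matrix_to_quaternion(R)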
+
+@overload
+def skew_symmetric(v: numpy_.ndarray):
+ """Skew symmetric matrix from a 3D vector"""
+ utils3d.numpy.transforms.skew_symmetric
+
+@overload
+def rotation_matrix_from_vectors(v1: numpy_.ndarray, v2: numpy_.ndarray):
+ """Rotation matrix that rotates v1 to v2"""
+ utils3d.numpy.transforms.rotation_matrix_from_vectors
+
+@overload
+def ray_intersection(p1: numpy_.ndarray, d1: numpy_.ndarray, p2: numpy_.ndarray, d2: numpy_.ndarray):
+ """Compute the intersection/closest point of two D-dimensional rays
+If the rays are intersecting, the closest point is the intersection point.
+
+Args:
+ p1 (np.ndarray): (..., D) origin of ray 1
+ d1 (np.ndarray): (..., D) direction of ray 1
+ p2 (np.ndarray): (..., D) origin of ray 2
+ d2 (np.ndarray): (..., D) direction of ray 2
+
+Returns:
+ (np.ndarray): (..., D) intersection/closest point
+ utils3d.numpy.transforms.ray_intersection
+
+@overload
+def se3_matrix(R: numpy_.ndarray, t: numpy_.ndarray) -> numpy_.ndarray:
+ """Convert rotation matrix and translation vector to 4x4 transformation matrix.
+
+Args:
+ R (np.ndarray): [..., 3, 3] rotation matrix
+ t (np.ndarray): [..., 3] translation vector
+
+Returns:
+ np.ndarray: [..., 4, 4] transformation matrix"""
+ utils3d.numpy.transforms.se3_matrix
+
+@overload
+def slerp_quaternion(q1: numpy_.ndarray, q2: numpy_.ndarray, t: numpy_.ndarray) -> numpy_.ndarray:
+ """Spherical linear interpolation between two unit quaternions.
+
+Args:
+ q1 (np.ndarray): [..., 4] unit quaternion 1
+ q2 (np.ndarray): [..., 4] unit quaternion 2
+ t (np.ndarray): [...] interpolation parameter in [0, 1]
+
+Returns:
+ np.ndarray: [..., 4] interpolated unit quaternion"""
+ utils3d.numpy.transforms.slerp_quaternion
+
+@overload
+def slerp_vector(v1: numpy_.ndarray, v2: numpy_.ndarray, t: numpy_.ndarray) -> numpy_.ndarray:
+ """Spherical linear interpolation between two unit vectors. The vectors are assumed to be normalized.
+
+Args:
+ v1 (np.ndarray): [..., d] unit vector 1
+ v2 (np.ndarray): [..., d] unit vector 2
+ t (np.ndarray): [...] interpolation parameter in [0, 1]
+
+Returns:
+ np.ndarray: [..., d] interpolated unit vector"""
+ utils3d.numpy.transforms.slerp_vector
+
+@overload
+def lerp(x1: numpy_.ndarray, x2: numpy_.ndarray, t: numpy_.ndarray) -> numpy_.ndarray:
+ """Linear interpolation between two vectors.
+
+Args:
+ x1 (np.ndarray): [..., d] vector 1
+ x2 (np.ndarray): [..., d] vector 2
+ t (np.ndarray): [...] interpolation parameter. [0, 1] for interpolation between x1 and x2, otherwise for extrapolation.
+
+Returns:
+ np.ndarray: [..., d] interpolated vector"""
+ utils3d.numpy.transforms.lerp
+
+@overload
+def lerp_se3_matrix(T1: numpy_.ndarray, T2: numpy_.ndarray, t: numpy_.ndarray) -> numpy_.ndarray:
+ """Linear interpolation between two SE(3) matrices.
+
+Args:
+ T1 (np.ndarray): [..., 4, 4] SE(3) matrix 1
+ T2 (np.ndarray): [..., 4, 4] SE(3) matrix 2
+ t (np.ndarray): [...] interpolation parameter in [0, 1]
+
+Returns:
+ np.ndarray: [..., 4, 4] interpolated SE(3) matrix"""
+ utils3d.numpy.transforms.lerp_se3_matrix
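+# Sketch (`T1` and `T2` are hypothetical [..., 4, 4] SE(3) poses; numpy imported as np):
+# blend two camera poses halfway between the endpoints.
+# >>> T_mid = lerp_se3_matrix(T1, T2, np.array(0.5))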
+
+@overload
+def piecewise_lerp(x: numpy_.ndarray, t: numpy_.ndarray, s: numpy_.ndarray, extrapolation_mode: Literal['constant', 'linear'] = 'constant') -> numpy_.ndarray:
+ """Linear spline interpolation.
+
+### Parameters:
+- `x`: np.ndarray, shape (n, d): the values of data points.
+- `t`: np.ndarray, shape (n,): the times of the data points.
+- `s`: np.ndarray, shape (m,): the times to be interpolated.
+- `extrapolation_mode`: str, the mode of extrapolation. 'constant' means extrapolate the boundary values, 'linear' means extrapolate linearly.
+
+### Returns:
+- `y`: np.ndarray, shape (..., m, d): the interpolated values."""
+ utils3d.numpy.transforms.piecewise_lerp
+
+@overload
+def piecewise_lerp_se3_matrix(T: numpy_.ndarray, t: numpy_.ndarray, s: numpy_.ndarray, extrapolation_mode: Literal['constant', 'linear'] = 'constant') -> numpy_.ndarray:
+ """Linear spline interpolation for SE(3) matrices.
+
+### Parameters:
+- `T`: np.ndarray, shape (n, 4, 4): the SE(3) matrices.
+- `t`: np.ndarray, shape (n,): the times of the data points.
+- `s`: np.ndarray, shape (m,): the times to be interpolated.
+- `extrapolation_mode`: str, the mode of extrapolation. 'constant' means extrapolate the boundary values, 'linear' means extrapolate linearly.
+
+### Returns:
+- `T_interp`: np.ndarray, shape (..., m, 4, 4): the interpolated SE(3) matrices."""
+ utils3d.numpy.transforms.piecewise_lerp_se3_matrix
+
+@overload
+def apply_transform(T: numpy_.ndarray, x: numpy_.ndarray) -> numpy_.ndarray:
+ """Apply SE(3) transformation to a point or a set of points.
+
+### Parameters:
+- `T`: np.ndarray, shape (..., 4, 4): the SE(3) matrix.
+- `x`: np.ndarray, shape (..., 3): the point or a set of points to be transformed.
+
+### Returns:
+- `x_transformed`: np.ndarray, shape (..., 3): the transformed point or a set of points."""
+ utils3d.numpy.transforms.apply_transform
+
+@overload
+def linear_spline_interpolate(x: numpy_.ndarray, t: numpy_.ndarray, s: numpy_.ndarray, extrapolation_mode: Literal['constant', 'linear'] = 'constant') -> numpy_.ndarray:
+ """Linear spline interpolation.
+
+### Parameters:
+- `x`: np.ndarray, shape (n, d): the values of data points.
+- `t`: np.ndarray, shape (n,): the times of the data points.
+- `s`: np.ndarray, shape (m,): the times to be interpolated.
+- `extrapolation_mode`: str, the mode of extrapolation. 'constant' means extrapolate the boundary values, 'linear' means extrapolate linearly.
+
+### Returns:
+- `y`: np.ndarray, shape (..., m, d): the interpolated values."""
+ utils3d.numpy.spline.linear_spline_interpolate
+
+@overload
+def RastContext(*args, **kwargs):
+ utils3d.numpy.rasterization.RastContext
+
+@overload
+def rasterize_triangle_faces(ctx: utils3d.numpy.rasterization.RastContext, vertices: numpy_.ndarray, faces: numpy_.ndarray, attr: numpy_.ndarray, width: int, height: int, transform: numpy_.ndarray = None, cull_backface: bool = True, return_depth: bool = False, image: numpy_.ndarray = None, depth: numpy_.ndarray = None) -> Tuple[numpy_.ndarray, numpy_.ndarray]:
+ """Rasterize vertex attribute.
+
+Args:
+ vertices (np.ndarray): [N, 3]
+ faces (np.ndarray): [T, 3]
+ attr (np.ndarray): [N, C]
+ width (int): width of rendered image
+ height (int): height of rendered image
+ transform (np.ndarray): [4, 4] model-view-projection transformation matrix.
+ cull_backface (bool): whether to cull backface
+ image (np.ndarray): [H, W, C] background image
+ depth (np.ndarray): [H, W] background depth
+
+Returns:
+ image (np.ndarray): [H, W, C] rendered image
+ depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None."""
+ utils3d.numpy.rasterization.rasterize_triangle_faces
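+# Sketch (the vertex/face/color/mvp arrays are hypothetical, and the RastContext
+# constructor arguments are backend-specific and omitted here):
+# >>> ctx = RastContext()
+# >>> image, depth = rasterize_triangle_faces(
+# ...     ctx, vertices, faces, colors, 512, 512, transform=mvp, return_depth=True)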
+
+@overload
+def rasterize_edges(ctx: utils3d.numpy.rasterization.RastContext, vertices: numpy_.ndarray, edges: numpy_.ndarray, attr: numpy_.ndarray, width: int, height: int, transform: numpy_.ndarray = None, line_width: float = 1.0, return_depth: bool = False, image: numpy_.ndarray = None, depth: numpy_.ndarray = None) -> Tuple[numpy_.ndarray, ...]:
+ """Rasterize vertex attribute.
+
+Args:
+ vertices (np.ndarray): [N, 3]
+ edges (np.ndarray): [E, 2]
+ attr (np.ndarray): [N, C]
+ width (int): width of rendered image
+ height (int): height of rendered image
+ transform (np.ndarray): [4, 4] model-view-projection matrix
+ line_width (float): width of line. Defaults to 1.0. NOTE: Values other than 1.0 may not work across all platforms.
+
+Returns:
+ image (np.ndarray): [H, W, C] rendered image
+ depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None."""
+ utils3d.numpy.rasterization.rasterize_edges
+
+@overload
+def texture(ctx: utils3d.numpy.rasterization.RastContext, uv: numpy_.ndarray, texture: numpy_.ndarray, interpolation: str = 'linear', wrap: str = 'clamp') -> numpy_.ndarray:
+ """Given an UV image, texturing from the texture map"""
+ utils3d.numpy.rasterization.texture
+
+@overload
+def warp_image_by_depth(ctx: utils3d.numpy.rasterization.RastContext, src_depth: numpy_.ndarray, src_image: numpy_.ndarray = None, width: int = None, height: int = None, *, extrinsics_src: numpy_.ndarray = None, extrinsics_tgt: numpy_.ndarray = None, intrinsics_src: numpy_.ndarray = None, intrinsics_tgt: numpy_.ndarray = None, near: float = 0.1, far: float = 100.0, cull_backface: bool = True, ssaa: int = 1, return_depth: bool = False) -> Tuple[numpy_.ndarray, ...]:
+ """Warp image by depth map.
+
+Args:
+ ctx (RastContext): rasterizer context
+ src_depth (np.ndarray): [H, W]
+ src_image (np.ndarray, optional): [H, W, C]. The image to warp. Defaults to None (use uv coordinates).
+ width (int, optional): width of the output image. None to use depth map width. Defaults to None.
+ height (int, optional): height of the output image. None to use depth map height. Defaults to None.
+ extrinsics_src (np.ndarray, optional): extrinsics matrix of the source camera. Defaults to None (identity).
+ extrinsics_tgt (np.ndarray, optional): extrinsics matrix of the target camera. Defaults to None (identity).
+ intrinsics_src (np.ndarray, optional): intrinsics matrix of the source camera. Defaults to None (use the same as intrinsics_tgt).
+ intrinsics_tgt (np.ndarray, optional): intrinsics matrix of the target camera. Defaults to None (use the same as intrinsics_src).
+ cull_backface (bool, optional): whether to cull backface. Defaults to True.
+ ssaa (int, optional): super sampling anti-aliasing. Defaults to 1.
+
+Returns:
+ tgt_image (np.ndarray): [H, W, C] warped image (or uv coordinates if image is None).
+ tgt_depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None."""
+ utils3d.numpy.rasterization.warp_image_by_depth
+
+@overload
+def test_rasterization(ctx: utils3d.numpy.rasterization.RastContext):
+ """Test if rasterization works. It will render a cube with random colors and save it as a CHECKME.png file."""
+ utils3d.numpy.rasterization.test_rasterization
+
+@overload
+def triangulate(faces: torch_.Tensor, vertices: torch_.Tensor = None, backslash: bool = None) -> torch_.Tensor:
+ """Triangulate a polygonal mesh.
+
+Args:
+ faces (torch.Tensor): [..., L, P] polygonal faces
+ vertices (torch.Tensor, optional): [..., N, 3] 3-dimensional vertices.
+ If given, the triangulation is performed according to the distance
+ between vertices. Defaults to None.
+ backslash (torch.Tensor, optional): [..., L] boolean array indicating
+ how to triangulate the quad faces. Defaults to None.
+
+Returns:
+ (torch.Tensor): [L * (P - 2), 3] triangular faces"""
+ utils3d.torch.mesh.triangulate
+
+@overload
+def compute_face_normal(vertices: torch_.Tensor, faces: torch_.Tensor) -> torch_.Tensor:
+ """Compute face normals of a triangular mesh
+
+Args:
+ vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices
+ faces (torch.Tensor): [..., T, 3] triangular face indices
+
+Returns:
+ normals (torch.Tensor): [..., T, 3] face normals"""
+ utils3d.torch.mesh.compute_face_normal
+
+@overload
+def compute_face_angles(vertices: torch_.Tensor, faces: torch_.Tensor) -> torch_.Tensor:
+ """Compute face angles of a triangular mesh
+
+Args:
+ vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices
+ faces (torch.Tensor): [T, 3] triangular face indices
+
+Returns:
+ angles (torch.Tensor): [..., T, 3] face angles"""
+ utils3d.torch.mesh.compute_face_angles
+
+@overload
+def compute_vertex_normal(vertices: torch_.Tensor, faces: torch_.Tensor, face_normal: torch_.Tensor = None) -> torch_.Tensor:
+ """Compute vertex normals of a triangular mesh by averaging neightboring face normals
+
+Args:
+ vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices
+ faces (torch.Tensor): [T, 3] triangular face indices
+ face_normal (torch.Tensor, optional): [..., T, 3] face normals.
+ None to compute face normals from vertices and faces. Defaults to None.
+
+Returns:
+ normals (torch.Tensor): [..., N, 3] vertex normals"""
+ utils3d.torch.mesh.compute_vertex_normal
+
+@overload
+def compute_vertex_normal_weighted(vertices: torch_.Tensor, faces: torch_.Tensor, face_normal: torch_.Tensor = None) -> torch_.Tensor:
+ """Compute vertex normals of a triangular mesh by weighted sum of neightboring face normals
+according to the angles
+
+Args:
+ vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices
+ faces (torch.Tensor): [T, 3] triangular face indices
+ face_normal (torch.Tensor, optional): [..., T, 3] face normals.
+ None to compute face normals from vertices and faces. Defaults to None.
+
+Returns:
+ normals (torch.Tensor): [..., N, 3] vertex normals"""
+ utils3d.torch.mesh.compute_vertex_normal_weighted
+
+@overload
+def remove_unreferenced_vertices(faces: torch_.Tensor, *vertice_attrs, return_indices: bool = False) -> Tuple[torch_.Tensor, ...]:
+ """Remove unreferenced vertices of a mesh.
+Unreferenced vertices are removed, and the face indices are updated accordingly.
+
+Args:
+ faces (torch.Tensor): [T, P] face indices
+ *vertice_attrs: vertex attributes
+
+Returns:
+ faces (torch.Tensor): [T, P] face indices
+ *vertice_attrs: vertex attributes
+ indices (torch.Tensor, optional): [N] indices of vertices that are kept. Defaults to None."""
+ utils3d.torch.mesh.remove_unreferenced_vertices
+
+@overload
+def remove_corrupted_faces(faces: torch_.Tensor) -> torch_.Tensor:
+ """Remove corrupted faces (faces with duplicated vertices)
+
+Args:
+ faces (torch.Tensor): [T, 3] triangular face indices
+
+Returns:
+ torch.Tensor: [T_, 3] triangular face indices"""
+ utils3d.torch.mesh.remove_corrupted_faces
+
+@overload
+def merge_duplicate_vertices(vertices: torch_.Tensor, faces: torch_.Tensor, tol: float = 1e-06) -> Tuple[torch_.Tensor, torch_.Tensor]:
+ """Merge duplicate vertices of a triangular mesh.
+Duplicate vertices are merged by selecting one of them, and the face indices are updated accordingly.
+
+Args:
+ vertices (torch.Tensor): [N, 3] 3-dimensional vertices
+ faces (torch.Tensor): [T, 3] triangular face indices
+ tol (float, optional): tolerance for merging. Defaults to 1e-6.
+
+Returns:
+ vertices (torch.Tensor): [N_, 3] 3-dimensional vertices
+ faces (torch.Tensor): [T, 3] triangular face indices"""
+ utils3d.torch.mesh.merge_duplicate_vertices
+
+@overload
+def subdivide_mesh_simple(vertices: torch_.Tensor, faces: torch_.Tensor, n: int = 1) -> Tuple[torch_.Tensor, torch_.Tensor]:
+ """Subdivide a triangular mesh by splitting each triangle into 4 smaller triangles.
+NOTE: All original vertices are kept, and new vertices are appended to the end of the vertex list.
+
+Args:
+ vertices (torch.Tensor): [N, 3] 3-dimensional vertices
+ faces (torch.Tensor): [T, 3] triangular face indices
+ n (int, optional): number of subdivisions. Defaults to 1.
+
+Returns:
+ vertices (torch.Tensor): [N_, 3] subdivided 3-dimensional vertices
+ faces (torch.Tensor): [4 * T, 3] subdivided triangular face indices"""
+ utils3d.torch.mesh.subdivide_mesh_simple
+
+@overload
+def compute_face_tbn(pos: torch_.Tensor, faces_pos: torch_.Tensor, uv: torch_.Tensor, faces_uv: torch_.Tensor, eps: float = 1e-07) -> torch_.Tensor:
+ """compute TBN matrix for each face
+
+Args:
+ pos (torch.Tensor): shape (..., N_pos, 3), positions
+ faces_pos (torch.Tensor): shape(T, 3)
+ uv (torch.Tensor): shape (..., N_uv, 3) uv coordinates,
+ faces_uv (torch.Tensor): shape(T, 3)
+
+Returns:
+ torch.Tensor: (..., T, 3, 3) TBN matrix for each face. Note TBN vectors are normalized but not necessarily orthogonal"""
+ utils3d.torch.mesh.compute_face_tbn
+
+@overload
+def compute_vertex_tbn(faces_topo: torch_.Tensor, pos: torch_.Tensor, faces_pos: torch_.Tensor, uv: torch_.Tensor, faces_uv: torch_.Tensor) -> torch_.Tensor:
+ """compute TBN matrix for each face
+
+Args:
+ faces_topo (torch.Tensor): (T, 3), face indice of topology
+ pos (torch.Tensor): shape (..., N_pos, 3), positions
+ faces_pos (torch.Tensor): shape(T, 3)
+ uv (torch.Tensor): shape (..., N_uv, 3) uv coordinates,
+ faces_uv (torch.Tensor): shape(T, 3)
+
+Returns:
+ torch.Tensor: (..., V, 3, 3) TBN matrix for each vertex. Note TBN vectors are normalized but not necessarily orthogonal"""
+ utils3d.torch.mesh.compute_vertex_tbn
+
+@overload
+def laplacian(vertices: torch_.Tensor, faces: torch_.Tensor, weight: str = 'uniform') -> torch_.Tensor:
+ """Laplacian smooth with cotangent weights
+
+Args:
+ vertices (torch.Tensor): shape (..., N, 3)
+ faces (torch.Tensor): shape (T, 3)
+ weight (str): 'uniform' or 'cotangent'"""
+ utils3d.torch.mesh.laplacian
+
+@overload
+def laplacian_smooth_mesh(vertices: torch_.Tensor, faces: torch_.Tensor, weight: str = 'uniform', times: int = 5) -> torch_.Tensor:
+ """Laplacian smooth with cotangent weights
+
+Args:
+ vertices (torch.Tensor): shape (..., N, 3)
+ faces (torch.Tensor): shape (T, 3)
+ weight (str): 'uniform' or 'cotangent'"""
+ utils3d.torch.mesh.laplacian_smooth_mesh
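+# Sketch (the `vertices` and `faces` tensors are hypothetical): a few uniform-weight
+# smoothing iterations on a torch mesh.
+# >>> smoothed = laplacian_smooth_mesh(vertices, faces, weight='uniform', times=5)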
+
+@overload
+def taubin_smooth_mesh(vertices: torch_.Tensor, faces: torch_.Tensor, lambda_: float = 0.5, mu_: float = -0.51) -> torch_.Tensor:
+ """Taubin smooth mesh
+
+Args:
+ vertices (torch.Tensor): shape (..., N, 3), mesh vertices
+ faces (torch.Tensor): shape (T, 3), triangular face indices
+ lambda_ (float, optional): positive smoothing (shrinking) factor. Defaults to 0.5.
+ mu_ (float, optional): negative inflating factor applied after each smoothing step. Defaults to -0.51.
+
+Returns:
+ torch.Tensor: shape (..., N, 3), smoothed vertices"""
+ utils3d.torch.mesh.taubin_smooth_mesh
+
+@overload
+def laplacian_hc_smooth_mesh(vertices: torch_.Tensor, faces: torch_.Tensor, times: int = 5, alpha: float = 0.5, beta: float = 0.5, weight: str = 'uniform'):
+ """HC algorithm from Improved Laplacian Smoothing of Noisy Surface Meshes by J.Vollmer et al.
+ """
+ utils3d.torch.mesh.laplacian_hc_smooth_mesh
+
+@overload
+def get_rays(extrinsics: torch_.Tensor, intrinsics: torch_.Tensor, uv: torch_.Tensor) -> Tuple[torch_.Tensor, torch_.Tensor]:
+ """Args:
+ extrinsics: (..., 4, 4) extrinsics matrices.
+ intrinsics: (..., 3, 3) intrinsics matrices.
+ uv: (..., n_rays, 2) uv coordinates of the rays.
+
+Returns:
+ rays_o: (..., 1, 3) ray origins
+ rays_d: (..., n_rays, 3) ray directions.
+ NOTE: ray directions are NOT normalized. They are scaled so that rays_o + rays_d * z gives world coordinates, where z is the depth."""
+ utils3d.torch.nerf.get_rays
+
+@overload
+def get_image_rays(extrinsics: torch_.Tensor, intrinsics: torch_.Tensor, width: int, height: int) -> Tuple[torch_.Tensor, torch_.Tensor]:
+ """Args:
+ extrinsics: (..., 4, 4) extrinsics matrices.
+ intrinsics: (..., 3, 3) intrinsics matrices.
+ width: width of the image.
+ height: height of the image.
+
+Returns:
+ rays_o: (..., 1, 1, 3) ray origins
+ rays_d: (..., height, width, 3) ray directions.
+ NOTE: ray directions are NOT normalized. They are scaled so that rays_o + rays_d * z gives world coordinates, where z is the depth."""
+ utils3d.torch.nerf.get_image_rays
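+# Sketch (the `extrinsics` and `intrinsics` tensors are hypothetical camera matrices):
+# per-pixel rays such that, per the NOTE above, rays_o + rays_d * depth gives world-space points.
+# >>> rays_o, rays_d = get_image_rays(extrinsics, intrinsics, width, height)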
+
+@overload
+def get_mipnerf_cones(rays_o: torch_.Tensor, rays_d: torch_.Tensor, z_vals: torch_.Tensor, pixel_width: torch_.Tensor) -> Tuple[torch_.Tensor, torch_.Tensor]:
+ """Args:
+ rays_o: (..., n_rays, 3) ray origins
+ rays_d: (..., n_rays, 3) ray directions.
+ z_vals: (..., n_rays, n_samples) z values.
+ pixel_width: (...) pixel width. = 1 / (normalized focal length * width)
+
+Returns:
+ mu: (..., n_rays, n_samples, 3) cone mu.
+ sigma: (..., n_rays, n_samples, 3, 3) cone sigma."""
+ utils3d.torch.nerf.get_mipnerf_cones
+
+@overload
+def volume_rendering(color: torch_.Tensor, sigma: torch_.Tensor, z_vals: torch_.Tensor, ray_length: torch_.Tensor, rgb: bool = True, depth: bool = True) -> Tuple[torch_.Tensor, torch_.Tensor, torch_.Tensor]:
+ """Given color, sigma and z_vals (linear depth of the sampling points), render the volume.
+
+NOTE: By default, color and sigma should have one less sample than z_vals, in correspondence with the average value in intervals.
+If the queried colors are aligned with z_vals, the trapezoidal rule is used to calculate the average values in intervals.
+
+Args:
+ color: (..., n_samples or n_samples - 1, 3) color values.
+ sigma: (..., n_samples or n_samples - 1) density values.
+ z_vals: (..., n_samples) z values.
+ ray_length: (...) length of the ray
+
+Returns:
+ rgb: (..., 3) rendered color values.
+ depth: (...) rendered depth values.
+ weights (..., n_samples) weights."""
+ utils3d.torch.nerf.volume_rendering
+
+@overload
+def bin_sample(size: Union[torch_.Size, Tuple[int, ...]], n_samples: int, min_value: numbers.Number, max_value: numbers.Number, spacing: Literal['linear', 'inverse_linear'], dtype: torch_.dtype = None, device: torch_.device = None) -> torch_.Tensor:
+ """Uniformly (or uniformly in inverse space) sample z values in `n_samples` bins in range [min_value, max_value].
+Args:
+ size: size of the rays
+ n_samples: number of samples to be sampled, also the number of bins
+ min_value: minimum value of the range
+ max_value: maximum value of the range
+ spacing: 'linear' or 'inverse_linear'. If 'inverse_linear', the sampling is uniform in inverse space.
+
+Returns:
+ z_rand: (*size, n_samples) sampled z values, sorted in ascending order."""
+ utils3d.torch.nerf.bin_sample
+
+@overload
+def importance_sample(z_vals: torch_.Tensor, weights: torch_.Tensor, n_samples: int) -> Tuple[torch_.Tensor, torch_.Tensor]:
+ """Importance sample z values.
+
+NOTE: By default, weights should have one less sample than z_vals, in correspondence with the intervals.
+If weights has the same number of samples as z_vals, we use trapezoidal rule to calculate the average weights in intervals.
+
+Args:
+ z_vals: (..., n_rays, n_input_samples) z values, sorted in ascending order.
+ weights: (..., n_rays, n_input_samples or n_input_samples - 1) weights.
+ n_samples: number of output samples for importance sampling.
+
+Returns:
+ z_importance: (..., n_rays, n_samples) importance sampled z values, unsorted."""
+ utils3d.torch.nerf.importance_sample
+
+@overload
+def nerf_render_rays(nerf: Union[Callable[[torch_.Tensor, torch_.Tensor], Tuple[torch_.Tensor, torch_.Tensor]], Tuple[Callable[[torch_.Tensor], Tuple[torch_.Tensor, torch_.Tensor]], Callable[[torch_.Tensor], Tuple[torch_.Tensor, torch_.Tensor]]]], rays_o: torch_.Tensor, rays_d: torch_.Tensor, *, return_dict: bool = False, n_coarse: int = 64, n_fine: int = 64, near: float = 0.1, far: float = 100.0, z_spacing: Literal['linear', 'inverse_linear'] = 'linear'):
+ """NeRF rendering of rays. Note that it supports arbitrary batch dimensions (denoted as `...`)
+
+Args:
+ nerf: nerf model, which takes (points, directions) as input and returns (color, density) as output.
+ If nerf is a tuple, it should be (nerf_coarse, nerf_fine), where nerf_coarse and nerf_fine are two nerf models for coarse and fine stages respectively.
+
+ nerf args:
+ points: (..., n_rays, n_samples, 3)
+ directions: (..., n_rays, n_samples, 3)
+ nerf returns:
+ color: (..., n_rays, n_samples, 3) color values.
+ density: (..., n_rays, n_samples) density values.
+
+ rays_o: (..., n_rays, 3) ray origins
+ rays_d: (..., n_rays, 3) ray directions.
+
+Returns
+ if return_dict is False, return rendered rgb and depth as a shortcut. (If there are separate coarse and fine results, return fine results)
+ rgb: (..., n_rays, 3) rendered color values.
+ depth: (..., n_rays) rendered depth values.
+ else, return a dict. If `n_fine == 0` or `nerf` is a single model, the dict only contains coarse results:
+ ```
+ {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..}
+ ```
+ If there are two models for coarse and fine stages, the dict contains both coarse and fine results:
+ ```
+ {
+ "coarse": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..},
+ "fine": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..}
+ }
+ ```"""
+ utils3d.torch.nerf.nerf_render_rays
+
+@overload
+def mipnerf_render_rays(mipnerf: Callable[[torch_.Tensor, torch_.Tensor, torch_.Tensor], Tuple[torch_.Tensor, torch_.Tensor]], rays_o: torch_.Tensor, rays_d: torch_.Tensor, pixel_width: torch_.Tensor, *, return_dict: bool = False, n_coarse: int = 64, n_fine: int = 64, uniform_ratio: float = 0.4, near: float = 0.1, far: float = 100.0, z_spacing: Literal['linear', 'inverse_linear'] = 'linear') -> Union[Tuple[torch_.Tensor, torch_.Tensor], Dict[str, torch_.Tensor]]:
+ """MipNeRF rendering.
+
+Args:
+ mipnerf: mipnerf model, which takes (points_mu, points_sigma) as input and returns (color, density) as output.
+
+ mipnerf args:
+ points_mu: (..., n_rays, n_samples, 3) cone mu.
+ points_sigma: (..., n_rays, n_samples, 3, 3) cone sigma.
+ directions: (..., n_rays, n_samples, 3)
+ mipnerf returns:
+ color: (..., n_rays, n_samples, 3) color values.
+ density: (..., n_rays, n_samples) density values.
+
+ rays_o: (..., n_rays, 3) ray origins
+ rays_d: (..., n_rays, 3) ray directions.
+ pixel_width: (..., n_rays) pixel width. How to compute? pixel_width = 1 / (normalized focal length * width)
+
+Returns
+ if return_dict is False, return rendered results only: (If `n_fine == 0`, return coarse results, otherwise return fine results)
+ rgb: (..., n_rays, 3) rendered color values.
+ depth: (..., n_rays) rendered depth values.
+ else, return a dict. If `n_fine == 0`, the dict only contains coarse results:
+ ```
+ {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..}
+ ```
+ If n_fine > 0, the dict contains both coarse and fine results:
+ ```
+ {
+ "coarse": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..},
+ "fine": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..}
+ }
+ ```"""
+ utils3d.torch.nerf.mipnerf_render_rays
+
+@overload
+def nerf_render_view(nerf: torch_.Tensor, extrinsics: torch_.Tensor, intrinsics: torch_.Tensor, width: int, height: int, *, patchify: bool = False, patch_size: Tuple[int, int] = (64, 64), **options: Dict[str, Any]) -> Tuple[torch_.Tensor, torch_.Tensor]:
+ """NeRF rendering of views. Note that it supports arbitrary batch dimensions (denoted as `...`)
+
+Args:
+ extrinsics: (..., 4, 4) extrinsics matrices of the rendered views
+ intrinsics (optional): (..., 3, 3) intrinsics matrices of the rendered views.
+ width (optional): image width of the rendered views.
+ height (optional): image height of the rendered views.
+ patchify (optional): If the image is too large, render it patch by patch
+ **options: rendering options.
+
+Returns:
+ rgb: (..., channels, height, width) rendered color values.
+ depth: (..., height, width) rendered depth values."""
+ utils3d.torch.nerf.nerf_render_view
+
+@overload
+def mipnerf_render_view(mipnerf: torch_.Tensor, extrinsics: torch_.Tensor, intrinsics: torch_.Tensor, width: int, height: int, *, patchify: bool = False, patch_size: Tuple[int, int] = (64, 64), **options: Dict[str, Any]) -> Tuple[torch_.Tensor, torch_.Tensor]:
+ """MipNeRF rendering of views. Note that it supports arbitrary batch dimensions (denoted as `...`)
+
+Args:
+ extrinsics: (..., 4, 4) extrinsics matrices of the rendered views
+ intrinsics (optional): (..., 3, 3) intrinsics matrices of the rendered views.
+ width (optional): image width of the rendered views.
+ height (optional): image height of the rendered views.
+ patchify (optional): If the image is too large, render it patch by patch
+ **options: rendering options.
+
+Returns:
+ rgb: (..., 3, height, width) rendered color values.
+ depth: (..., height, width) rendered depth values."""
+ utils3d.torch.nerf.mipnerf_render_view
+
+@overload
+def InstantNGP(view_dependent: bool = True, base_resolution: int = 16, finest_resolution: int = 2048, n_levels: int = 16, num_layers_density: int = 2, hidden_dim_density: int = 64, num_layers_color: int = 3, hidden_dim_color: int = 64, log2_hashmap_size: int = 19, bound: float = 1.0, color_channels: int = 3):
+ """An implementation of InstantNGP, Müller et al., https://nvlabs.github.io/instant-ngp/.
+Requires `tinycudann` package.
+Install it by:
+```
+pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
+```"""
+ utils3d.torch.nerf.InstantNGP
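+
+# Illustrative sketch: constructing the hash-grid model and using it as the `nerf` callable for
+# `nerf_render_rays` above, assuming its forward follows the (points, directions) -> (color, density)
+# convention. Requires `tinycudann`; the constructor arguments shown are just the defaults.
+# >>> model = InstantNGP(view_dependent=True, bound=1.0)
+# >>> rgb, depth = nerf_render_rays(model, rays_o, rays_d)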
+
+@overload
+def sliding_window_1d(x: torch_.Tensor, window_size: int, stride: int = 1, dim: int = -1) -> torch_.Tensor:
+ """Sliding window view of the input tensor. The dimension of the sliding window is appended to the end of the input tensor's shape.
+ NOTE: Since PyTorch already provides the `unfold` function, the 1D sliding window view is just a wrapper around it.
+ utils3d.torch.utils.sliding_window_1d
+
+@overload
+def sliding_window_2d(x: torch_.Tensor, window_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]], dim: Union[int, Tuple[int, int]] = (-2, -1)) -> torch_.Tensor:
+ utils3d.torch.utils.sliding_window_2d
+
+@overload
+def sliding_window_nd(x: torch_.Tensor, window_size: Tuple[int, ...], stride: Tuple[int, ...], dim: Tuple[int, ...]) -> torch_.Tensor:
+ utils3d.torch.utils.sliding_window_nd
+
+@overload
+def image_uv(height: int, width: int, left: int = None, top: int = None, right: int = None, bottom: int = None, device: torch_.device = None, dtype: torch_.dtype = None) -> torch_.Tensor:
+ """Get image space UV grid, ranging in [0, 1].
+
+>>> image_uv(10, 10):
+[[[0.05, 0.05], [0.15, 0.05], ..., [0.95, 0.05]],
+ [[0.05, 0.15], [0.15, 0.15], ..., [0.95, 0.15]],
+ ... ... ...
+ [[0.05, 0.95], [0.15, 0.95], ..., [0.95, 0.95]]]
+
+Args:
+ width (int): image width
+ height (int): image height
+
+Returns:
+ torch.Tensor: shape (height, width, 2)"""
+ utils3d.torch.utils.image_uv
+
+@overload
+def image_pixel_center(height: int, width: int, left: int = None, top: int = None, right: int = None, bottom: int = None, dtype: torch_.dtype = None, device: torch_.device = None) -> torch_.Tensor:
+ """Get image pixel center coordinates, ranging in [0, width] and [0, height].
+`image[i, j]` has pixel center coordinates `(j + 0.5, i + 0.5)`.
+
+>>> image_pixel_center(10, 10):
+[[[0.5, 0.5], [1.5, 0.5], ..., [9.5, 0.5]],
+ [[0.5, 1.5], [1.5, 1.5], ..., [9.5, 1.5]],
+ ... ... ...
+[[0.5, 9.5], [1.5, 9.5], ..., [9.5, 9.5]]]
+
+Args:
+ width (int): image width
+ height (int): image height
+
+Returns:
+ torch.Tensor: shape (height, width, 2)"""
+ utils3d.torch.utils.image_pixel_center
+
+@overload
+def image_mesh(height: int, width: int, mask: torch_.Tensor = None, device: torch_.device = None, dtype: torch_.dtype = None) -> Tuple[torch_.Tensor, torch_.Tensor]:
+ """Get a quad mesh that treats image pixel uv coordinates as vertices and the image grid as faces.
+
+Args:
+ width (int): image width
+ height (int): image height
+ mask (torch.Tensor, optional): binary mask of shape (height, width), dtype=bool. Defaults to None.
+
+Returns:
+ uv (torch.Tensor): uv corresponding to pixels as described in image_uv()
+ faces (torch.Tensor): quad faces connecting neighboring pixels
+ indices (torch.Tensor, optional): indices of vertices in the original mesh"""
+ utils3d.torch.utils.image_mesh
+
+@overload
+def chessboard(width: int, height: int, grid_size: int, color_a: torch_.Tensor, color_b: torch_.Tensor) -> torch_.Tensor:
+ """Get a chessboard image
+
+Args:
+ width (int): image width
+ height (int): image height
+ grid_size (int): size of chessboard grid
+ color_a (torch.Tensor): shape (channels,), color of the grid at the top-left corner
+ color_b (torch.Tensor): shape (channels,), color of the complementary grid cells
+
+Returns:
+ image (torch.Tensor): shape (height, width, channels), chessboard image"""
+ utils3d.torch.utils.chessboard
+
+@overload
+def depth_edge(depth: torch_.Tensor, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: torch_.Tensor = None) -> torch_.BoolTensor:
+ """Compute the edge mask of a depth map. The edge is defined as the pixels whose neighbors have a large difference in depth.
+
+Args:
+ depth (torch.Tensor): shape (..., height, width), linear depth map
+ atol (float): absolute tolerance
+ rtol (float): relative tolerance
+
+Returns:
+ edge (torch.Tensor): shape (..., height, width) of dtype torch.bool"""
+ utils3d.torch.utils.depth_edge
+
+@overload
+def depth_aliasing(depth: torch_.Tensor, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: torch_.Tensor = None) -> torch_.BoolTensor:
+ """Compute the map that indicates the aliasing of a depth map. The aliasing is defined as the pixels which are neither close to the maximum nor the minimum of their neighbors.
+Args:
+ depth (torch.Tensor): shape (..., height, width), linear depth map
+ atol (float): absolute tolerance
+ rtol (float): relative tolerance
+
+Returns:
+ edge (torch.Tensor): shape (..., height, width) of dtype torch.bool"""
+ utils3d.torch.utils.depth_aliasing
+
+@overload
+def image_mesh_from_depth(depth: torch_.Tensor, extrinsics: torch_.Tensor = None, intrinsics: torch_.Tensor = None) -> Tuple[torch_.Tensor, torch_.Tensor]:
+ utils3d.torch.utils.image_mesh_from_depth
+
+@overload
+def point_to_normal(point: torch_.Tensor, mask: torch_.Tensor = None) -> torch_.Tensor:
+ """Calculate normal map from point map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system.
+
+Args:
+ point (torch.Tensor): shape (..., height, width, 3), point map
+Returns:
+ normal (torch.Tensor): shape (..., height, width, 3), normal map. """
+ utils3d.torch.utils.point_to_normal
+
+@overload
+def depth_to_normal(depth: torch_.Tensor, intrinsics: torch_.Tensor, mask: torch_.Tensor = None) -> torch_.Tensor:
+ """Calculate normal map from depth map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system.
+
+Args:
+ depth (torch.Tensor): shape (..., height, width), linear depth map
+ intrinsics (torch.Tensor): shape (..., 3, 3), intrinsics matrix
+Returns:
+ normal (torch.Tensor): shape (..., 3, height, width), normal map. """
+ utils3d.torch.utils.depth_to_normal
+
+@overload
+def masked_min(input: torch_.Tensor, mask: torch_.BoolTensor, dim: int = None, keepdim: bool = False) -> Union[torch_.Tensor, Tuple[torch_.Tensor, torch_.Tensor]]:
+ """Similar to torch.min, but with mask
+ """
+ utils3d.torch.utils.masked_min
+
+@overload
+def masked_max(input: torch_.Tensor, mask: torch_.BoolTensor, dim: int = None, keepdim: bool = False) -> Union[torch_.Tensor, Tuple[torch_.Tensor, torch_.Tensor]]:
+ """Similar to torch.max, but with mask
+ """
+ utils3d.torch.utils.masked_max
+
+@overload
+def bounding_rect(mask: torch_.BoolTensor):
+ """Get the bounding rectangle of a mask
+
+Args:
+ mask (torch.Tensor): shape (..., height, width), mask
+
+Returns:
+ rect (torch.Tensor): shape (..., 4), bounding rectangle (left, top, right, bottom)"""
+ utils3d.torch.utils.bounding_rect
+
+@overload
+def perspective(fov_y: Union[float, torch_.Tensor], aspect: Union[float, torch_.Tensor], near: Union[float, torch_.Tensor], far: Union[float, torch_.Tensor]) -> torch_.Tensor:
+ """Get OpenGL perspective matrix
+
+Args:
+ fov_y (float | torch.Tensor): field of view in y axis
+ aspect (float | torch.Tensor): aspect ratio
+ near (float | torch.Tensor): near plane to clip
+ far (float | torch.Tensor): far plane to clip
+
+Returns:
+ (torch.Tensor): [..., 4, 4] perspective matrix"""
+ utils3d.torch.transforms.perspective
+
+@overload
+def perspective_from_fov(fov: Union[float, torch_.Tensor], width: Union[int, torch_.Tensor], height: Union[int, torch_.Tensor], near: Union[float, torch_.Tensor], far: Union[float, torch_.Tensor]) -> torch_.Tensor:
+ """Get OpenGL perspective matrix from field of view in largest dimension
+
+Args:
+ fov (float | torch.Tensor): field of view in largest dimension
+ width (int | torch.Tensor): image width
+ height (int | torch.Tensor): image height
+ near (float | torch.Tensor): near plane to clip
+ far (float | torch.Tensor): far plane to clip
+
+Returns:
+ (torch.Tensor): [..., 4, 4] perspective matrix"""
+ utils3d.torch.transforms.perspective_from_fov
+
+@overload
+def perspective_from_fov_xy(fov_x: Union[float, torch_.Tensor], fov_y: Union[float, torch_.Tensor], near: Union[float, torch_.Tensor], far: Union[float, torch_.Tensor]) -> torch_.Tensor:
+ """Get OpenGL perspective matrix from field of view in x and y axis
+
+Args:
+ fov_x (float | torch.Tensor): field of view in x axis
+ fov_y (float | torch.Tensor): field of view in y axis
+ near (float | torch.Tensor): near plane to clip
+ far (float | torch.Tensor): far plane to clip
+
+Returns:
+ (torch.Tensor): [..., 4, 4] perspective matrix"""
+ utils3d.torch.transforms.perspective_from_fov_xy
+
+@overload
+def intrinsics_from_focal_center(fx: Union[float, torch_.Tensor], fy: Union[float, torch_.Tensor], cx: Union[float, torch_.Tensor], cy: Union[float, torch_.Tensor]) -> torch_.Tensor:
+ """Get OpenCV intrinsics matrix
+
+Args:
+ fx (float | torch.Tensor): focal length in x axis
+ fy (float | torch.Tensor): focal length in y axis
+ cx (float | torch.Tensor): principal point in x axis
+ cy (float | torch.Tensor): principal point in y axis
+
+Returns:
+ (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix"""
+ utils3d.torch.transforms.intrinsics_from_focal_center
+
+@overload
+def intrinsics_from_fov(fov_max: Union[float, torch_.Tensor] = None, fov_min: Union[float, torch_.Tensor] = None, fov_x: Union[float, torch_.Tensor] = None, fov_y: Union[float, torch_.Tensor] = None, width: Union[int, torch_.Tensor] = None, height: Union[int, torch_.Tensor] = None) -> torch_.Tensor:
+ """Get normalized OpenCV intrinsics matrix from given field of view.
+ You can provide either fov_max, fov_min, fov_x, or fov_y.
+
+Args:
+ width (int | torch.Tensor): image width
+ height (int | torch.Tensor): image height
+ fov_max (float | torch.Tensor): field of view in largest dimension
+ fov_min (float | torch.Tensor): field of view in smallest dimension
+ fov_x (float | torch.Tensor): field of view in x axis
+ fov_y (float | torch.Tensor): field of view in y axis
+
+Returns:
+ (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix"""
+ utils3d.torch.transforms.intrinsics_from_fov
+
+@overload
+def intrinsics_from_fov_xy(fov_x: Union[float, torch_.Tensor], fov_y: Union[float, torch_.Tensor]) -> torch_.Tensor:
+ """Get OpenCV intrinsics matrix from field of view in x and y axis
+
+Args:
+ fov_x (float | torch.Tensor): field of view in x axis
+ fov_y (float | torch.Tensor): field of view in y axis
+
+Returns:
+ (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix"""
+ utils3d.torch.transforms.intrinsics_from_fov_xy
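+
+# Worked example (assuming the same normalized convention as `intrinsics_from_fov` above, i.e.
+# focal lengths divided by image size and the principal point at (0.5, 0.5)): for a 90-degree
+# field of view, fx = 0.5 / tan(fov_x / 2) = 0.5 / tan(45 deg) = 0.5.
+# >>> import math, torch
+# >>> K = intrinsics_from_fov_xy(torch.tensor(math.pi / 2), torch.tensor(math.pi / 2))
+# >>> K.shape  # torch.Size([3, 3])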
+
+@overload
+def view_look_at(eye: torch_.Tensor, look_at: torch_.Tensor, up: torch_.Tensor) -> torch_.Tensor:
+ """Get OpenGL view matrix looking at something
+
+Args:
+ eye (torch.Tensor): [..., 3] the eye position
+ look_at (torch.Tensor): [..., 3] the position to look at
+ up (torch.Tensor): [..., 3] head up direction (y axis in screen space). Not necessarily orthogonal to the view direction
+
+Returns:
+ (torch.Tensor): [..., 4, 4], view matrix"""
+ utils3d.torch.transforms.view_look_at
+
+@overload
+def extrinsics_look_at(eye: torch_.Tensor, look_at: torch_.Tensor, up: torch_.Tensor) -> torch_.Tensor:
+ """Get OpenCV extrinsics matrix looking at something
+
+Args:
+ eye (torch.Tensor): [..., 3] the eye position
+ look_at (torch.Tensor): [..., 3] the position to look at
+ up (torch.Tensor): [..., 3] head up direction (-y axis in screen space). Not necessarily orthogonal to the view direction
+
+Returns:
+ (torch.Tensor): [..., 4, 4], extrinsics matrix"""
+ utils3d.torch.transforms.extrinsics_look_at
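+
+# Illustrative sketch (example values only): an OpenCV-convention camera placed at (0, 0, 3),
+# looking at the origin, with +y as the up hint.
+# >>> import torch
+# >>> eye = torch.tensor([0., 0., 3.])
+# >>> extrinsics = extrinsics_look_at(eye, torch.zeros(3), torch.tensor([0., 1., 0.]))
+# >>> extrinsics.shape  # torch.Size([4, 4])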
+
+@overload
+def perspective_to_intrinsics(perspective: torch_.Tensor) -> torch_.Tensor:
+ """OpenGL perspective matrix to OpenCV intrinsics
+
+Args:
+ perspective (torch.Tensor): [..., 4, 4] OpenGL perspective matrix
+
+Returns:
+ (torch.Tensor): shape [..., 3, 3] OpenCV intrinsics"""
+ utils3d.torch.transforms.perspective_to_intrinsics
+
+@overload
+def intrinsics_to_perspective(intrinsics: torch_.Tensor, near: Union[float, torch_.Tensor], far: Union[float, torch_.Tensor]) -> torch_.Tensor:
+ """OpenCV intrinsics to OpenGL perspective matrix
+
+Args:
+ intrinsics (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix
+ near (float | torch.Tensor): [...] near plane to clip
+ far (float | torch.Tensor): [...] far plane to clip
+Returns:
+ (torch.Tensor): [..., 4, 4] OpenGL perspective matrix"""
+ utils3d.torch.transforms.intrinsics_to_perspective
+
+@overload
+def extrinsics_to_view(extrinsics: torch_.Tensor) -> torch_.Tensor:
+ """OpenCV camera extrinsics to OpenGL view matrix
+
+Args:
+ extrinsics (torch.Tensor): [..., 4, 4] OpenCV camera extrinsics matrix
+
+Returns:
+ (torch.Tensor): [..., 4, 4] OpenGL view matrix"""
+ utils3d.torch.transforms.extrinsics_to_view
+
+@overload
+def view_to_extrinsics(view: torch_.Tensor) -> torch_.Tensor:
+ """OpenGL view matrix to OpenCV camera extrinsics
+
+Args:
+ view (torch.Tensor): [..., 4, 4] OpenGL view matrix
+
+Returns:
+ (torch.Tensor): [..., 4, 4] OpenCV camera extrinsics matrix"""
+ utils3d.torch.transforms.view_to_extrinsics
+
+@overload
+def normalize_intrinsics(intrinsics: torch_.Tensor, width: Union[int, torch_.Tensor], height: Union[int, torch_.Tensor]) -> torch_.Tensor:
+ """Normalize camera intrinsics(s) to uv space
+
+Args:
+ intrinsics (torch.Tensor): [..., 3, 3] camera intrinsics(s) to normalize
+ width (int | torch.Tensor): [...] image width(s)
+ height (int | torch.Tensor): [...] image height(s)
+
+Returns:
+ (torch.Tensor): [..., 3, 3] normalized camera intrinsics(s)"""
+ utils3d.torch.transforms.normalize_intrinsics
+
+@overload
+def crop_intrinsics(intrinsics: torch_.Tensor, width: Union[int, torch_.Tensor], height: Union[int, torch_.Tensor], left: Union[int, torch_.Tensor], top: Union[int, torch_.Tensor], crop_width: Union[int, torch_.Tensor], crop_height: Union[int, torch_.Tensor]) -> torch_.Tensor:
+ """Evaluate the new intrinsics(s) after cropping the image: cropped_img = img[top:top+crop_height, left:left+crop_width]
+
+Args:
+ intrinsics (torch.Tensor): [..., 3, 3] camera intrinsics(s) to crop
+ width (int | torch.Tensor): [...] image width(s)
+ height (int | torch.Tensor): [...] image height(s)
+ left (int | torch.Tensor): [...] left crop boundary
+ top (int | torch.Tensor): [...] top crop boundary
+ crop_width (int | torch.Tensor): [...] crop width
+ crop_height (int | torch.Tensor): [...] crop height
+
+Returns:
+ (torch.Tensor): [..., 3, 3] cropped camera intrinsics(s)"""
+ utils3d.torch.transforms.crop_intrinsics
+
+@overload
+def pixel_to_uv(pixel: torch_.Tensor, width: Union[int, torch_.Tensor], height: Union[int, torch_.Tensor]) -> torch_.Tensor:
+ """Args:
+ pixel (torch.Tensor): [..., 2] pixel coordinates defined in image space, x range is (0, W - 1), y range is (0, H - 1)
+ width (int | torch.Tensor): [...] image width(s)
+ height (int | torch.Tensor): [...] image height(s)
+
+Returns:
+ (torch.Tensor): [..., 2] uv coordinates defined in uv space, the range is (0, 1)"""
+ utils3d.torch.transforms.pixel_to_uv
+
+@overload
+def pixel_to_ndc(pixel: torch_.Tensor, width: Union[int, torch_.Tensor], height: Union[int, torch_.Tensor]) -> torch_.Tensor:
+ """Args:
+ pixel (torch.Tensor): [..., 2] pixel coordinates defined in image space, x range is (0, W - 1), y range is (0, H - 1)
+ width (int | torch.Tensor): [...] image width(s)
+ height (int | torch.Tensor): [...] image height(s)
+
+Returns:
+ (torch.Tensor): [..., 2] ndc coordinates, the range is (-1, 1)"""
+ utils3d.torch.transforms.pixel_to_ndc
+
+@overload
+def uv_to_pixel(uv: torch_.Tensor, width: Union[int, torch_.Tensor], height: Union[int, torch_.Tensor]) -> torch_.Tensor:
+ """Args:
+ uv (torch.Tensor): [..., 2] uv coordinates defined in uv space, the range is (0, 1)
+ width (int | torch.Tensor): [...] image width(s)
+ height (int | torch.Tensor): [...] image height(s)
+
+Returns:
+ (torch.Tensor): [..., 2] pixel coordinates defined in image space, x range is (0, W - 1), y range is (0, H - 1)"""
+ utils3d.torch.transforms.uv_to_pixel
+
+@overload
+def project_depth(depth: torch_.Tensor, near: Union[float, torch_.Tensor], far: Union[float, torch_.Tensor]) -> torch_.Tensor:
+ """Project linear depth to depth value in screen space
+
+Args:
+ depth (torch.Tensor): [...] depth value
+ near (float | torch.Tensor): [...] near plane to clip
+ far (float | torch.Tensor): [...] far plane to clip
+
+Returns:
+ (torch.Tensor): [..., 1] depth value in screen space, value ranging in [0, 1]"""
+ utils3d.torch.transforms.project_depth
+
+@overload
+def depth_buffer_to_linear(depth: torch_.Tensor, near: Union[float, torch_.Tensor], far: Union[float, torch_.Tensor]) -> torch_.Tensor:
+ """Convert screen-space (depth buffer) depth values to linear depth
+
+Args:
+ depth (torch.Tensor): [...] screen depth value, ranging in [0, 1]
+ near (float | torch.Tensor): [...] near plane to clip
+ far (float | torch.Tensor): [...] far plane to clip
+
+Returns:
+ (torch.Tensor): [...] linear depth"""
+ utils3d.torch.transforms.depth_buffer_to_linear
+
+@overload
+def project_gl(points: torch_.Tensor, model: torch_.Tensor = None, view: torch_.Tensor = None, perspective: torch_.Tensor = None) -> Tuple[torch_.Tensor, torch_.Tensor]:
+ """Project 3D points to 2D following the OpenGL convention (except that matrices are row-major)
+
+Args:
+ points (torch.Tensor): [..., N, 3 or 4] 3D points to project, if the last
+ dimension is 4, the points are assumed to be in homogeneous coordinates
+ model (torch.Tensor): [..., 4, 4] model matrix
+ view (torch.Tensor): [..., 4, 4] view matrix
+ perspective (torch.Tensor): [..., 4, 4] perspective matrix
+
+Returns:
+ scr_coord (torch.Tensor): [..., N, 3] screen space coordinates, value ranging in [0, 1].
+ The origin (0., 0., 0.) corresponds to the left & bottom & nearest
+ linear_depth (torch.Tensor): [..., N] linear depth"""
+ utils3d.torch.transforms.project_gl
+
+@overload
+def project_cv(points: torch_.Tensor, extrinsics: torch_.Tensor = None, intrinsics: torch_.Tensor = None) -> Tuple[torch_.Tensor, torch_.Tensor]:
+ """Project 3D points to 2D following the OpenCV convention
+
+Args:
+ points (torch.Tensor): [..., N, 3] or [..., N, 4] 3D points to project, if the last
+ dimension is 4, the points are assumed to be in homogeneous coordinates
+ extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix
+ intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix
+
+Returns:
+ uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1].
+ The origin (0., 0.) corresponds to the left & top
+ linear_depth (torch.Tensor): [..., N] linear depth"""
+ utils3d.torch.transforms.project_cv
+
+@overload
+def unproject_gl(screen_coord: torch_.Tensor, model: torch_.Tensor = None, view: torch_.Tensor = None, perspective: torch_.Tensor = None) -> torch_.Tensor:
+ """Unproject screen space coordinates to 3D view space following the OpenGL convention (except that matrices are row-major)
+
+Args:
+ screen_coord (torch.Tensor): [... N, 3] screen space coordinates, value ranging in [0, 1].
+ The origin (0., 0., 0.) corresponds to the left & bottom & nearest
+ model (torch.Tensor): [..., 4, 4] model matrix
+ view (torch.Tensor): [..., 4, 4] view matrix
+ perspective (torch.Tensor): [..., 4, 4] perspective matrix
+
+Returns:
+ points (torch.Tensor): [..., N, 3] 3d points"""
+ utils3d.torch.transforms.unproject_gl
+
+@overload
+def unproject_cv(uv_coord: torch_.Tensor, depth: torch_.Tensor, extrinsics: torch_.Tensor = None, intrinsics: torch_.Tensor = None) -> torch_.Tensor:
+ """Unproject uv coordinates to 3D view space following the OpenCV convention
+
+Args:
+ uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1].
+ The origin (0., 0.) corresponds to the left & top
+ depth (torch.Tensor): [..., N] depth value
+ extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix
+ intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix
+
+Returns:
+ points (torch.Tensor): [..., N, 3] 3d points"""
+ utils3d.torch.transforms.unproject_cv
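+
+# Illustrative round trip (a sketch, assuming points in front of the camera; reuses the `extrinsics`
+# from the look-at sketch above and a normalized `intrinsics` from `intrinsics_from_fov_xy`):
+# projecting with `project_cv` and unprojecting the resulting uv/depth should recover the points.
+# >>> points = torch.randn(8, 3) * 0.1
+# >>> uv, depth = project_cv(points, extrinsics, intrinsics)
+# >>> restored = unproject_cv(uv, depth, extrinsics, intrinsics)
+# >>> torch.allclose(restored, points, atol=1e-5)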
+
+@overload
+def skew_symmetric(v: torch_.Tensor):
+ """Skew symmetric matrix from a 3D vector"""
+ utils3d.torch.transforms.skew_symmetric
+
+@overload
+def rotation_matrix_from_vectors(v1: torch_.Tensor, v2: torch_.Tensor):
+ """Rotation matrix that rotates v1 to v2"""
+ utils3d.torch.transforms.rotation_matrix_from_vectors
+
+@overload
+def euler_axis_angle_rotation(axis: str, angle: torch_.Tensor) -> torch_.Tensor:
+ """Return the rotation matrices for a rotation about one of the axes used by Euler angles,
+for each value of the angle given.
+
+Args:
+ axis: Axis label "X", "Y", or "Z".
+ angle: any shape tensor of Euler angles in radians
+
+Returns:
+ Rotation matrices as tensor of shape (..., 3, 3)."""
+ utils3d.torch.transforms.euler_axis_angle_rotation
+
+@overload
+def euler_angles_to_matrix(euler_angles: torch_.Tensor, convention: str = 'XYZ') -> torch_.Tensor:
+ """Convert rotations given as Euler angles in radians to rotation matrices.
+
+Args:
+ euler_angles: Euler angles in radians as tensor of shape (..., 3), XYZ
+ convention: permutation of "X", "Y" or "Z", representing the order of Euler rotations to apply.
+
+Returns:
+ Rotation matrices as tensor of shape (..., 3, 3)."""
+ utils3d.torch.transforms.euler_angles_to_matrix
+
+@overload
+def matrix_to_euler_angles(matrix: torch_.Tensor, convention: str) -> torch_.Tensor:
+ """Convert rotations given as rotation matrices to Euler angles in radians.
+NOTE: The composition order, e.g. `XYZ`, means `Rz * Ry * Rx` (like Blender), not `Rx * Ry * Rz` (like PyTorch3D)
+
+Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+ convention: Convention string of three uppercase letters.
+
+Returns:
+ Euler angles in radians as tensor of shape (..., 3), in XYZ order (like Blender), rather than in the order given by `convention` (like PyTorch3D)"""
+ utils3d.torch.transforms.matrix_to_euler_angles
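+
+# Illustrative check of the composition-order note above (a sketch; `euler_axis_angle_rotation`
+# is the single-axis helper defined earlier in this module):
+# >>> angles = torch.tensor([0.1, 0.2, 0.3])
+# >>> R = euler_angles_to_matrix(angles, 'XYZ')
+# >>> Rx = euler_axis_angle_rotation('X', angles[0])
+# >>> Ry = euler_axis_angle_rotation('Y', angles[1])
+# >>> Rz = euler_axis_angle_rotation('Z', angles[2])
+# >>> torch.allclose(R, Rz @ Ry @ Rx)  # 'XYZ' composes as Rz * Ry * Rx here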
+
+@overload
+def matrix_to_quaternion(rot_mat: torch_.Tensor, eps: float = 1e-12) -> torch_.Tensor:
+ """Convert 3x3 rotation matrix to quaternion (w, x, y, z)
+
+Args:
+ rot_mat (torch.Tensor): shape (..., 3, 3), the rotation matrices to convert
+
+Returns:
+ torch.Tensor: shape (..., 4), the quaternions corresponding to the given rotation matrices"""
+ utils3d.torch.transforms.matrix_to_quaternion
+
+@overload
+def quaternion_to_matrix(quaternion: torch_.Tensor, eps: float = 1e-12) -> torch_.Tensor:
+ """Converts a batch of quaternions (w, x, y, z) to rotation matrices
+
+Args:
+ quaternion (torch.Tensor): shape (..., 4), the quaternions to convert
+
+Returns:
+ torch.Tensor: shape (..., 3, 3), the rotation matrices corresponding to the given quaternions"""
+ utils3d.torch.transforms.quaternion_to_matrix
+
+@overload
+def matrix_to_axis_angle(rot_mat: torch_.Tensor, eps: float = 1e-12) -> torch_.Tensor:
+ """Convert a batch of 3x3 rotation matrices to axis-angle representation (rotation vector)
+
+Args:
+ rot_mat (torch.Tensor): shape (..., 3, 3), the rotation matrices to convert
+
+Returns:
+ torch.Tensor: shape (..., 3), the axis-angle vectors corresponding to the given rotation matrices"""
+ utils3d.torch.transforms.matrix_to_axis_angle
+
+@overload
+def axis_angle_to_matrix(axis_angle: torch_.Tensor, eps: float = 1e-12) -> torch_.Tensor:
+ """Convert an axis-angle representation (a rotation vector whose direction is the axis of rotation and whose length is the angle of rotation) to a rotation matrix
+
+Args:
+ axis_angle (torch.Tensor): shape (..., 3), axis-angle vectors
+
+Returns:
+ torch.Tensor: shape (..., 3, 3) The rotation matrices for the given axis-angle parameters"""
+ utils3d.torch.transforms.axis_angle_to_matrix
+
+@overload
+def axis_angle_to_quaternion(axis_angle: torch_.Tensor, eps: float = 1e-12) -> torch_.Tensor:
+ """Convert axis-angle representation (rotation vector) to quaternion (w, x, y, z)
+
+Args:
+ axis_angle (torch.Tensor): shape (..., 3), axis-angle vectors
+
+Returns:
+ torch.Tensor: shape (..., 4) The quaternions for the given axis-angle parameters"""
+ utils3d.torch.transforms.axis_angle_to_quaternion
+
+@overload
+def quaternion_to_axis_angle(quaternion: torch_.Tensor, eps: float = 1e-12) -> torch_.Tensor:
+ """Convert a batch of quaternions (w, x, y, z) to axis-angle representation (rotation vector)
+
+Args:
+ quaternion (torch.Tensor): shape (..., 4), the quaternions to convert
+
+Returns:
+ torch.Tensor: shape (..., 3), the axis-angle vectors corresponding to the given quaternions"""
+ utils3d.torch.transforms.quaternion_to_axis_angle
+
+@overload
+def slerp(rot_mat_1: torch_.Tensor, rot_mat_2: torch_.Tensor, t: Union[numbers.Number, torch_.Tensor]) -> torch_.Tensor:
+ """Spherical linear interpolation between two rotation matrices
+
+Args:
+ rot_mat_1 (torch.Tensor): shape (..., 3, 3), the first rotation matrix
+ rot_mat_2 (torch.Tensor): shape (..., 3, 3), the second rotation matrix
+ t (torch.Tensor): scalar or shape (...,), the interpolation factor
+
+Returns:
+ torch.Tensor: shape (..., 3, 3), the interpolated rotation matrix"""
+ utils3d.torch.transforms.slerp
+
+@overload
+def interpolate_extrinsics(ext1: torch_.Tensor, ext2: torch_.Tensor, t: Union[numbers.Number, torch_.Tensor]) -> torch_.Tensor:
+ """Interpolate extrinsics between two camera poses. Linear interpolation for translation, spherical linear interpolation for rotation.
+
+Args:
+ ext1 (torch.Tensor): shape (..., 4, 4), the first camera pose
+ ext2 (torch.Tensor): shape (..., 4, 4), the second camera pose
+ t (torch.Tensor): scalar or shape (...,), the interpolation factor
+
+Returns:
+ torch.Tensor: shape (..., 4, 4), the interpolated camera pose"""
+ utils3d.torch.transforms.interpolate_extrinsics
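+
+# Illustrative sketch: generating a simple camera path between two (4, 4) extrinsics `ext1` and
+# `ext2` (translation is lerped, rotation is slerped, as documented above).
+# >>> path = torch.stack([interpolate_extrinsics(ext1, ext2, float(t)) for t in torch.linspace(0, 1, 10)])
+# >>> path.shape  # torch.Size([10, 4, 4])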
+
+@overload
+def interpolate_view(view1: torch_.Tensor, view2: torch_.Tensor, t: Union[numbers.Number, torch_.Tensor]):
+ """Interpolate view matrices between two camera poses. Linear interpolation for translation, spherical linear interpolation for rotation.
+
+Args:
+ view1 (torch.Tensor): shape (..., 4, 4), the first view matrix
+ view2 (torch.Tensor): shape (..., 4, 4), the second view matrix
+ t (torch.Tensor): scalar or shape (...,), the interpolation factor
+
+Returns:
+ torch.Tensor: shape (..., 4, 4), the interpolated camera pose"""
+ utils3d.torch.transforms.interpolate_view
+
+@overload
+def extrinsics_to_essential(extrinsics: torch_.Tensor):
+ """Extrinsics matrix `[[R, t] [0, 0, 0, 1]]` such that `x' = R (x - t)` to essential matrix such that `x'^T E x = 0`
+
+Args:
+ extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix
+
+Returns:
+ (torch.Tensor): [..., 3, 3] essential matrix"""
+ utils3d.torch.transforms.extrinsics_to_essential
+
+@overload
+def to4x4(R: torch_.Tensor, t: torch_.Tensor):
+ """Compose rotation matrix and translation vector to 4x4 transformation matrix
+
+Args:
+ R (torch.Tensor): [..., 3, 3] rotation matrix
+ t (torch.Tensor): [..., 3] translation vector
+
+Returns:
+ (torch.Tensor): [..., 4, 4] transformation matrix"""
+ utils3d.torch.transforms.to4x4
+
+@overload
+def rotation_matrix_2d(theta: Union[float, torch_.Tensor]):
+ """2x2 matrix for 2D rotation
+
+Args:
+ theta (float | torch.Tensor): rotation angle in radians, arbitrary shape (...,)
+
+Returns:
+ (torch.Tensor): (..., 2, 2) rotation matrix"""
+ utils3d.torch.transforms.rotation_matrix_2d
+
+@overload
+def rotate_2d(theta: Union[float, torch_.Tensor], center: torch_.Tensor = None):
+ """3x3 matrix for 2D rotation around a center
+```
+ [[Rxx, Rxy, tx],
+ [Ryx, Ryy, ty],
+ [0, 0, 1]]
+```
+Args:
+ theta (float | torch.Tensor): rotation angle in radians, arbitrary shape (...,)
+ center (torch.Tensor): rotation center, arbitrary shape (..., 2). Default to (0, 0)
+
+Returns:
+ (torch.Tensor): (..., 3, 3) transformation matrix"""
+ utils3d.torch.transforms.rotate_2d
+
+@overload
+def translate_2d(translation: torch_.Tensor):
+ """Translation matrix for 2D translation
+```
+ [[1, 0, tx],
+ [0, 1, ty],
+ [0, 0, 1]]
+```
+Args:
+ translation (torch.Tensor): translation vector, arbitrary shape (..., 2)
+
+Returns:
+ (torch.Tensor): (..., 3, 3) transformation matrix"""
+ utils3d.torch.transforms.translate_2d
+
+@overload
+def scale_2d(scale: Union[float, torch_.Tensor], center: torch_.Tensor = None):
+ """Scale matrix for 2D scaling
+```
+ [[s, 0, tx],
+ [0, s, ty],
+ [0, 0, 1]]
+```
+Args:
+ scale (float | torch.Tensor): scale factor, arbitrary shape (...,)
+ center (torch.Tensor): scale center, arbitrary shape (..., 2). Default to (0, 0)
+
+Returns:
+ (torch.Tensor): (..., 3, 3) transformation matrix"""
+ utils3d.torch.transforms.scale_2d
+
+@overload
+def apply_2d(transform: torch_.Tensor, points: torch_.Tensor):
+ """Apply (3x3 or 2x3) 2D affine transformation to points
+```
+ p = R @ p + t
+```
+Args:
+ transform (torch.Tensor): (..., 2 or 3, 3) transformation matrix
+ points (torch.Tensor): (..., N, 2) points to transform
+
+Returns:
+ (torch.Tensor): (..., N, 2) transformed points"""
+ utils3d.torch.transforms.apply_2d
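+
+# Illustrative sketch of composing the 2D helpers above (example values only): rotate points by
+# 90 degrees about the point (1, 0), then apply the transform with `apply_2d`.
+# >>> import math
+# >>> T = rotate_2d(math.pi / 2, center=torch.tensor([1.0, 0.0]))
+# >>> pts = torch.tensor([[2.0, 0.0], [1.0, 1.0]])
+# >>> apply_2d(T, pts)  # rotates each point about (1, 0)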
+
+@overload
+def RastContext(nvd_ctx: Union[nvdiffrast.torch.ops.RasterizeCudaContext, nvdiffrast.torch.ops.RasterizeGLContext] = None, *, backend: Literal['cuda', 'gl'] = 'gl', device: Union[str, torch_.device] = None):
+ """Create a rasterization context. Nothing but a wrapper of nvdiffrast.torch.RasterizeCudaContext or nvdiffrast.torch.RasterizeGLContext."""
+ utils3d.torch.rasterization.RastContext
+
+@overload
+def rasterize_triangle_faces(ctx: utils3d.torch.rasterization.RastContext, vertices: torch_.Tensor, faces: torch_.Tensor, attr: torch_.Tensor, width: int, height: int, model: torch_.Tensor = None, view: torch_.Tensor = None, projection: torch_.Tensor = None, antialiasing: Union[bool, List[int]] = True, diff_attrs: Optional[List[int]] = None) -> Tuple[torch_.Tensor, torch_.Tensor, Optional[torch_.Tensor]]:
+ """Rasterize a mesh with vertex attributes.
+
+Args:
+ ctx (RastContext): rasterization context
+ vertices (torch.Tensor): (B, N, 2 or 3 or 4)
+ faces (torch.Tensor): (T, 3)
+ attr (torch.Tensor): (B, N, C)
+ width (int): width of the output image
+ height (int): height of the output image
+ model (torch.Tensor, optional): ([B,] 4, 4) model matrix. Defaults to None (identity).
+ view (torch.Tensor, optional): ([B,] 4, 4) view matrix. Defaults to None (identity).
+ projection (torch.Tensor, optional): ([B,] 4, 4) projection matrix. Defaults to None (identity).
+ antialiasing (Union[bool, List[int]], optional): whether to perform antialiasing. Defaults to True. If a list of indices is provided, only those channels will be antialiased.
+ diff_attrs (Union[None, List[int]], optional): indices of attributes to compute screen-space derivatives. Defaults to None.
+
+Returns:
+ image: (torch.Tensor): (B, C, H, W)
+ depth: (torch.Tensor): (B, H, W) screen space depth, ranging from 0 (near) to 1 (far).
+ NOTE: Empty pixels will have depth 1.0, i.e. the far plane.
+ utils3d.torch.rasterization.rasterize_triangle_faces
+
+@overload
+def warp_image_by_depth(ctx: utils3d.torch.rasterization.RastContext, depth: torch_.FloatTensor, image: torch_.FloatTensor = None, mask: torch_.BoolTensor = None, width: int = None, height: int = None, *, extrinsics_src: torch_.FloatTensor = None, extrinsics_tgt: torch_.FloatTensor = None, intrinsics_src: torch_.FloatTensor = None, intrinsics_tgt: torch_.FloatTensor = None, near: float = 0.1, far: float = 100.0, antialiasing: bool = True, backslash: bool = False, padding: int = 0, return_uv: bool = False, return_dr: bool = False) -> Tuple[torch_.FloatTensor, torch_.FloatTensor, torch_.BoolTensor, Optional[torch_.FloatTensor], Optional[torch_.FloatTensor]]:
+ """Warp image by depth.
+NOTE: if the batch size is 1, the image mesh is triangulated with awareness of the depth, yielding less distorted results.
+Otherwise, the image mesh is triangulated uniformly for batched rendering.
+
+Args:
+ ctx (Union[dr.RasterizeCudaContext, dr.RasterizeGLContext]): rasterization context
+ depth (torch.Tensor): (B, H, W) linear depth
+ image (torch.Tensor): (B, C, H, W). None to use image space uv. Defaults to None.
+ width (int, optional): width of the output image. None to use the same as depth. Defaults to None.
+ height (int, optional): height of the output image. Defaults to the same as depth.
+ extrinsics_src (torch.Tensor, optional): (B, 4, 4) extrinsics matrix for source. None to use identity. Defaults to None.
+ extrinsics_tgt (torch.Tensor, optional): (B, 4, 4) extrinsics matrix for target. None to use identity. Defaults to None.
+ intrinsics_src (torch.Tensor, optional): (B, 3, 3) intrinsics matrix for source. None to use the same as target. Defaults to None.
+ intrinsics_tgt (torch.Tensor, optional): (B, 3, 3) intrinsics matrix for target. None to use the same as source. Defaults to None.
+ near (float, optional): near plane. Defaults to 0.1.
+ far (float, optional): far plane. Defaults to 100.0.
+ antialiasing (bool, optional): whether to perform antialiasing. Defaults to True.
+ backslash (bool, optional): whether to use backslash triangulation. Defaults to False.
+ padding (int, optional): padding of the image. Defaults to 0.
+ return_uv (bool, optional): whether to return the uv. Defaults to False.
+ return_dr (bool, optional): whether to return the image-space derivatives of uv. Defaults to False.
+
+Returns:
+ image: (torch.FloatTensor): (B, C, H, W) rendered image
+ depth: (torch.FloatTensor): (B, H, W) linear depth, ranging from 0 to inf
+ mask: (torch.BoolTensor): (B, H, W) mask of valid pixels
+ uv: (torch.FloatTensor): (B, 2, H, W) image-space uv
+ dr: (torch.FloatTensor): (B, 4, H, W) image-space derivatives of uv"""
+ utils3d.torch.rasterization.warp_image_by_depth
+
+@overload
+def warp_image_by_forward_flow(ctx: utils3d.torch.rasterization.RastContext, image: torch_.FloatTensor, flow: torch_.FloatTensor, depth: torch_.FloatTensor = None, *, antialiasing: bool = True, backslash: bool = False) -> Tuple[torch_.FloatTensor, torch_.BoolTensor]:
+ """Warp image by forward flow.
+NOTE: if the batch size is 1, the image mesh is triangulated with awareness of the depth, yielding less distorted results.
+Otherwise, the image mesh is triangulated uniformly for batched rendering.
+
+Args:
+ ctx (Union[dr.RasterizeCudaContext, dr.RasterizeGLContext]): rasterization context
+ image (torch.Tensor): (B, C, H, W) image
+ flow (torch.Tensor): (B, 2, H, W) forward flow
+ depth (torch.Tensor, optional): (B, H, W) linear depth. If None, the same depth is used for all pixels. Defaults to None.
+ antialiasing (bool, optional): whether to perform antialiasing. Defaults to True.
+ backslash (bool, optional): whether to use backslash triangulation. Defaults to False.
+
+Returns:
+ image: (torch.FloatTensor): (B, C, H, W) rendered image
+ mask: (torch.BoolTensor): (B, H, W) mask of valid pixels"""
+ utils3d.torch.rasterization.warp_image_by_forward_flow
+
diff --git a/src/utils3d/io/__init__.py b/src/utils3d/io/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e67b88d2d06b04520fec1cb21b70bdda521eafed
--- /dev/null
+++ b/src/utils3d/io/__init__.py
@@ -0,0 +1,3 @@
+from .obj import *
+from .colmap import *
+from .ply import *
diff --git a/src/utils3d/io/colmap.py b/src/utils3d/io/colmap.py
new file mode 100644
index 0000000000000000000000000000000000000000..d00ccbea8b3974e45fc91678000e9083e4ce378b
--- /dev/null
+++ b/src/utils3d/io/colmap.py
@@ -0,0 +1,139 @@
+from typing import *
+from pathlib import Path
+
+import numpy as np
+from scipy.spatial.transform import Rotation
+
+
+__all__ = ['read_extrinsics_from_colmap', 'read_intrinsics_from_colmap', 'write_extrinsics_as_colmap', 'write_intrinsics_as_colmap']
+
+
+def write_extrinsics_as_colmap(file: Union[str, Path], extrinsics: np.ndarray, image_names: Union[str, List[str]] = 'image_{i:04d}.png', camera_ids: List[int] = None):
+ """
+ Write extrinsics to colmap `images.txt` file.
+ Args:
+ file: Path to `images.txt` file.
+ extrinsics: (N, 4, 4) array of extrinsics.
+ image_names: str or List of str, image names. Length is N.
+ If str, it should be a format string with `i` as the index. (i starts from 1, in correspondence with IMAGE_ID in colmap)
+ camera_ids: List of int, camera ids. Length is N.
+ If None, it will be set to [1, 2, ..., N].
+ """
+ assert extrinsics.shape[1:] == (4, 4) and extrinsics.ndim == 3 or extrinsics.shape == (4, 4)
+ if extrinsics.ndim == 2:
+ extrinsics = extrinsics[np.newaxis, ...]
+ quats = Rotation.from_matrix(extrinsics[:, :3, :3]).as_quat()
+ trans = extrinsics[:, :3, 3]
+ if camera_ids is None:
+ camera_ids = list(range(1, len(extrinsics) + 1))
+ if isinstance(image_names, str):
+ image_names = [image_names.format(i=i) for i in range(1, len(extrinsics) + 1)]
+ assert len(extrinsics) == len(image_names) == len(camera_ids), \
+ f'Number of extrinsics ({len(extrinsics)}), image_names ({len(image_names)}), and camera_ids ({len(camera_ids)}) must be the same'
+ with open(file, 'w') as fp:
+ print("# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME", file=fp)
+ for i, (quat, t, name, camera_id) in enumerate(zip(quats.tolist(), trans.tolist(), image_names, camera_ids)):
+ # Colmap has wxyz order while scipy.spatial.transform.Rotation has xyzw order.
+ qx, qy, qz, qw = quat
+ tx, ty, tz = t
+ print(f'{i + 1} {qw:f} {qx:f} {qy:f} {qz:f} {tx:f} {ty:f} {tz:f} {camera_id:d} {name}', file=fp)
+ # Each image entry in images.txt is followed by a POINTS2D[] line; write it as an empty line in the file.
+ print(file=fp)
+
+
+def write_intrinsics_as_colmap(file: Union[str, Path], intrinsics: np.ndarray, width: int, height: int, normalized: bool = False):
+ """
+ Write intrinsics to colmap `cameras.txt` file. Currently only support PINHOLE model (no distortion)
+ Args:
+ file: Path to `cameras.txt` file.
+ intrinsics: (N, 3, 3) array of intrinsics.
+ width: Image width.
+ height: Image height.
+ normalized: Whether the intrinsics are normalized. If True, the intrinsics will be unnormalized for writing.
+ """
+ assert intrinsics.shape[1:] == (3, 3) and intrinsics.ndim == 3 or intrinsics.shape == (3, 3)
+ if intrinsics.ndim == 2:
+ intrinsics = intrinsics[np.newaxis, ...]
+ if normalized:
+ intrinsics = intrinsics * np.array([width, height, 1])[:, None]
+ with open(file, 'w') as fp:
+ print("# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]", file=fp)
+ for i, intr in enumerate(intrinsics):
+ fx, fy, cx, cy = intr[0, 0], intr[1, 1], intr[0, 2], intr[1, 2]
+ print(f'{i + 1} PINHOLE {width:d} {height:d} {fx:f} {fy:f} {cx:f} {cy:f}', file=fp)
+
+
+def read_extrinsics_from_colmap(file: Union[str, Path]) -> Tuple[np.ndarray, List[int], List[str]]:
+ """
+ Read extrinsics from colmap `images.txt` file.
+ Args:
+ file: Path to `images.txt` file.
+ Returns:
+ extrinsics: (N, 4, 4) array of extrinsics.
+ camera_ids: List of int, camera ids. Length is N. Note that camera ids in colmap typically start from 1.
+ image_names: List of str, image names. Length is N.
+ """
+ with open(file) as fp:
+ lines = fp.readlines()
+ image_names, quats, trans, camera_ids = [], [], [], []
+ i_line = 0
+ for line in lines:
+ line = line.strip()
+ if line.startswith('#'):
+ continue
+ i_line += 1
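+ # In COLMAP's images.txt, each image entry spans two lines (the pose line and a POINTS2D[] line); skip every second non-comment line.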
+ if i_line % 2 == 0:
+ continue
+ image_id, qw, qx, qy, qz, tx, ty, tz, camera_id, name = line.split()
+ quats.append([float(qx), float(qy), float(qz), float(qw)])
+ trans.append([float(tx), float(ty), float(tz)])
+ camera_ids.append(int(camera_id))
+ image_names.append(name)
+
+ quats = np.array(quats, dtype=np.float32)
+ trans = np.array(trans, dtype=np.float32)
+ rotation = Rotation.from_quat(quats).as_matrix()
+ extrinsics = np.concatenate([
+ np.concatenate([rotation, trans[..., None]], axis=-1),
+ np.array([0, 0, 0, 1], dtype=np.float32)[None, None, :].repeat(len(quats), axis=0)
+ ], axis=-2)
+
+ return extrinsics, camera_ids, image_names
+
+
+def read_intrinsics_from_colmap(file: Union[str, Path], normalize: bool = False) -> Tuple[List[int], np.ndarray, np.ndarray]:
+ """
+ Read intrinsics from colmap `cameras.txt` file.
+ Args:
+ file: Path to `cameras.txt` file.
+ normalize: Whether to normalize the intrinsics so that coordinates map to the [0, 1] range.
+ Returns:
+ camera_ids: List of int, camera ids. Length is N. Note that camera ids in colmap typically start from 1.
+ intrinsics: (N, 3, 3) array of intrinsics.
+ distortions: (N, 5) array of distortions.
+ """
+ with open(file) as fp:
+ lines = fp.readlines()
+ intrinsics, distortions, camera_ids = [], [], []
+ for line in lines:
+ line = line.strip()
+ if not line or line.startswith('#'):
+ continue
+ camera_id, model, width, height, *params = line.split()
+ camera_id, width, height = int(camera_id), int(width), int(height)
+ if model == 'PINHOLE':
+ fx, fy, cx, cy = map(float, params[:4])
+ k1 = k2 = k3 = p1 = p2 = 0.0
+ elif model == 'OPENCV':
+ fx, fy, cx, cy, k1, k2, p1, p2, k3 = *map(float, params[:8]), 0.0
+ elif model == 'SIMPLE_RADIAL':
+ f, cx, cy, k = map(float, params[:4])
+ fx = fy = f
+ k1, k2, p1, p2, k3 = k, 0.0, 0.0, 0.0, 0.0
+ camera_ids.append(camera_id)
+ if normalize:
+ fx, fy, cx, cy = fx / width, fy / height, cx / width, cy / height
+ intrinsics.append([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
+ distortions.append([k1, k2, p1, p2, k3])
+ intrinsics = np.array(intrinsics, dtype=np.float32)
+ distortions = np.array(distortions, dtype=np.float32)
+ return camera_ids, intrinsics, distortions
diff --git a/src/utils3d/io/obj.py b/src/utils3d/io/obj.py
new file mode 100644
index 0000000000000000000000000000000000000000..3471e490bd758fbf58173cfb7297ec747f46f173
--- /dev/null
+++ b/src/utils3d/io/obj.py
@@ -0,0 +1,146 @@
+from io import TextIOWrapper
+from typing import Dict, Any, Union, Iterable
+import numpy as np
+from pathlib import Path
+
+__all__ = [
+ 'read_obj',
+ 'write_obj',
+ 'simple_write_obj'
+]
+
+def read_obj(
+ file : Union[str, Path, TextIOWrapper],
+ encoding: Union[str, None] = None,
+ ignore_unknown: bool = False
+):
+ """
+ Read wavefront .obj file, without preprocessing.
+
+ Why bother having this read_obj() when there are already libraries like `trimesh`?
+ This function reads the raw data from the .obj file and keeps the original order of vertices and faces,
+ whereas trimesh applies processing such as merging/splitting vertices, which can break those orders.
+ Such libraries mainly target geometry processing and rendering with support for various formats.
+ If you want mesh geometry processing, you may turn to `trimesh` for more features.
+
+ ### Parameters
+ `file` (str, Path, TextIOWrapper): filepath or file object
+ encoding (str, optional):
+
+ ### Returns
+ obj (dict): A dict containing .obj components
+ {
+ 'mtllib': [],
+ 'v': [[0.1, 0.2, 1.0], [1.2, 0.0, 0.0], ...],
+ 'vt': [[0.5, 0.5], ...],
+ 'vn': [[0., 0.7, 0.7], [0., -0.7, 0.7], ...],
+ 'f': [[0, 1, 2], [2, 3, 4],...],
+ 'usemtl': [('mtl1', 7)]
+ }
+ """
+ if hasattr(file,'read'):
+ lines = file.read().splitlines()
+ else:
+ with open(file, 'r', encoding=encoding) as fp:
+ lines = fp.read().splitlines()
+ mtllib = []
+ v, vt, vn, vp = [], [], [], [] # Vertex coordinates, Vertex texture coordinate, Vertex normal, Vertex parameter
+ f, ft, fn = [], [], [] # Face indices, Face texture indices, Face normal indices
+ o = []
+ s = []
+ usemtl = []
+
+ def pad(l: list, n: Any):
+ return l + [n] * (3 - len(l))
+
+ for i, line in enumerate(lines):
+ sq = line.strip().split()
+ if len(sq) == 0:
+ continue
+ if sq[0] == 'v':
+ assert 4 <= len(sq) <= 5, f'Invalid format of line {i}: {line}'
+ v.append([float(e) for e in sq[1:]][:3])
+ elif sq[0] == 'vt':
+ assert 3 <= len(sq) <= 4, f'Invalid format of line {i}: {line}'
+ vt.append([float(e) for e in sq[1:]][:2])
+ elif sq[0] == 'vn':
+ assert len(sq) == 4, f'Invalid format of line {i}: {line}'
+ vn.append([float(e) for e in sq[1:]])
+ elif sq[0] == 'vp':
+ assert 2 <= len(sq) <= 4, f'Invalid format of line {i}: {line}'
+ vp.append(pad([float(e) for e in sq[1:]], 0))
+ elif sq[0] == 'f':
+ # Face entries are 1-based v/vt/vn triples; missing components (e.g. "v//vn") become -1.
+ splits = [pad([int(j) - 1 if j else -1 for j in e.split('/')], -1) for e in sq[1:]]
+ f.append([e[0] for e in splits])
+ ft.append([e[1] for e in splits])
+ fn.append([e[2] for e in splits])
+ elif sq[0] == 'usemtl':
+ assert len(sq) == 2
+ usemtl.append((sq[1], len(f)))
+ elif sq[0] == 'o':
+ assert len(sq) == 2
+ o.append((sq[1], len(f)))
+ elif sq[0] == 's':
+ s.append((sq[1], len(f)))
+ elif sq[0] == 'mtllib':
+ assert len(sq) == 2
+ mtllib.append(sq[1])
+ elif sq[0][0] == '#':
+ continue
+ else:
+ if not ignore_unknown:
+ raise Exception(f'Unknown keyword {sq[0]}')
+
+ min_poly_vertices = min((len(face) for face in f), default=0)
+ max_poly_vertices = max((len(face) for face in f), default=0)
+
+ return {
+ 'mtllib': mtllib,
+ 'v': np.array(v, dtype=np.float32),
+ 'vt': np.array(vt, dtype=np.float32),
+ 'vn': np.array(vn, dtype=np.float32),
+ 'vp': np.array(vp, dtype=np.float32),
+ 'f': np.array(f, dtype=np.int32) if min_poly_vertices == max_poly_vertices else f,
+ 'ft': np.array(ft, dtype=np.int32) if min_poly_vertices == max_poly_vertices else ft,
+ 'fn': np.array(fn, dtype=np.int32) if min_poly_vertices == max_poly_vertices else fn,
+ 'o': o,
+ 's': s,
+ 'usemtl': usemtl,
+ }
+
+
+def write_obj(
+ file: Union[str, Path],
+ obj: Dict[str, Any],
+ encoding: Union[str, None] = None
+ ):
+ with open(file, 'w', encoding=encoding) as fp:
+ for k in ['v', 'vt', 'vn', 'vp']:
+ if k not in obj:
+ continue
+ for v in obj[k]:
+ print(k, *map(float, v), file=fp)
+ for face in obj['f']:
+ # Indices are stored 0-based by read_obj(); OBJ files are 1-based, so add 1 back when writing.
+ print('f', *('/'.join(str(int(j) + 1) for j in i) if isinstance(i, Iterable) else int(i) + 1 for i in face), file=fp)
+
+
+def simple_write_obj(
+ file: Union[str, Path],
+ vertices: np.ndarray,
+ faces: np.ndarray,
+ encoding: Union[str, None] = None
+ ):
+ """
+ Write wavefront .obj file, without preprocessing.
+
+ Args:
+ vertices (np.ndarray): [N, 3]
+ faces (np.ndarray): [T, 3]
+ file (Any): filepath
+ encoding (str, optional):
+ """
+ with open(file, 'w', encoding=encoding) as fp:
+ for v in vertices:
+ print('v', *map(float, v), file=fp)
+ for f in faces:
+ print('f', *map(int, f + 1), file=fp)
diff --git a/src/utils3d/io/ply.py b/src/utils3d/io/ply.py
new file mode 100644
index 0000000000000000000000000000000000000000..39fa41728a7be76d25743788c85dacb384d6d83e
--- /dev/null
+++ b/src/utils3d/io/ply.py
@@ -0,0 +1,104 @@
+import numpy as np
+
+from typing import *
+from pathlib import Path
+
+
+def read_ply(
+ file: Union[str, Path],
+ encoding: Union[str, None] = None,
+ ignore_unknown: bool = False
+) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Read .ply file, without preprocessing.
+
+ Args:
+ file (Any): filepath
+ encoding (str, optional):
+
+ Returns:
+ Tuple[np.ndarray, np.ndarray]: vertices, faces
+ """
+ import plyfile
+ plydata = plyfile.PlyData.read(file)
+ vertices = np.stack([plydata['vertex'][k] for k in ['x', 'y', 'z']], axis=-1)
+ if 'face' in plydata:
+ faces = np.array(plydata['face']['vertex_indices'].tolist())
+ else:
+ faces = None
+ return vertices, faces
+
+
+def write_ply(
+ file: Union[str, Path],
+ vertices: np.ndarray,
+ faces: np.ndarray = None,
+ edges: np.ndarray = None,
+ vertex_colors: np.ndarray = None,
+ edge_colors: np.ndarray = None,
+ text: bool = False
+):
+ """
+ Write .ply file, without preprocessing.
+
+ Args:
+ file (Any): filepath
+ vertices (np.ndarray): [N, 3]
+ faces (np.ndarray): [T, P] polygonal faces (P vertices per face)
+ edges (np.ndarray): [E, 2]
+ vertex_colors (np.ndarray, optional): [N, 3]. Defaults to None.
+ edge_colors (np.ndarray, optional): [E, 3]. Defaults to None.
+ text (bool, optional): save data in text format. Defaults to False.
+ """
+ import plyfile
+ assert vertices.ndim == 2 and vertices.shape[1] == 3
+ vertices = vertices.astype(np.float32)
+ if faces is not None:
+ assert faces.ndim == 2
+ faces = faces.astype(np.int32)
+ if edges is not None:
+ assert edges.ndim == 2 and edges.shape[1] == 2
+ edges = edges.astype(np.int32)
+
+ if vertex_colors is not None:
+ assert vertex_colors.ndim == 2 and vertex_colors.shape[1] == 3
+ if vertex_colors.dtype in [np.float32, np.float64]:
+ vertex_colors = vertex_colors * 255
+ vertex_colors = np.clip(vertex_colors, 0, 255).astype(np.uint8)
+ vertices_data = np.zeros(len(vertices), dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])
+ vertices_data['x'] = vertices[:, 0]
+ vertices_data['y'] = vertices[:, 1]
+ vertices_data['z'] = vertices[:, 2]
+ vertices_data['red'] = vertex_colors[:, 0]
+ vertices_data['green'] = vertex_colors[:, 1]
+ vertices_data['blue'] = vertex_colors[:, 2]
+ else:
+ vertices_data = np.array([tuple(v) for v in vertices], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
+
+ if faces is not None:
+ faces_data = np.zeros(len(faces), dtype=[('vertex_indices', 'i4', (faces.shape[1],))])
+ faces_data['vertex_indices'] = faces
+
+ if edges is not None:
+ if edge_colors is not None:
+ assert edge_colors.ndim == 2 and edge_colors.shape[1] == 3
+ if edge_colors.dtype in [np.float32, np.float64]:
+ edge_colors = edge_colors * 255
+ edge_colors = np.clip(edge_colors, 0, 255).astype(np.uint8)
+ edges_data = np.zeros(len(edges), dtype=[('vertex1', 'i4'), ('vertex2', 'i4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])
+ edges_data['vertex1'] = edges[:, 0]
+ edges_data['vertex2'] = edges[:, 1]
+ edges_data['red'] = edge_colors[:, 0]
+ edges_data['green'] = edge_colors[:, 1]
+ edges_data['blue'] = edge_colors[:, 2]
+ else:
+ edges_data = np.array([tuple(e) for e in edges], dtype=[('vertex1', 'i4'), ('vertex2', 'i4')])
+
+ ply_data = [plyfile.PlyElement.describe(vertices_data, 'vertex')]
+ if faces is not None:
+ ply_data.append(plyfile.PlyElement.describe(faces_data, 'face'))
+ if edges is not None:
+ ply_data.append(plyfile.PlyElement.describe(edges_data, 'edge'))
+
+ plyfile.PlyData(ply_data, text=text).write(file)
+
\ No newline at end of file
diff --git a/src/utils3d/numpy/__init__.py b/src/utils3d/numpy/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa06c53abe3b4abd39f1d7f8372851d7cdc58260
--- /dev/null
+++ b/src/utils3d/numpy/__init__.py
@@ -0,0 +1,142 @@
+"""
+3D utility functions working with NumPy.
+"""
+import importlib
+import itertools
+import numpy
+from typing import TYPE_CHECKING
+
+
+__modules_all__ = {
+ 'mesh':[
+ 'triangulate',
+ 'compute_face_normal',
+ 'compute_face_angle',
+ 'compute_vertex_normal',
+ 'compute_vertex_normal_weighted',
+ 'remove_corrupted_faces',
+ 'merge_duplicate_vertices',
+ 'remove_unreferenced_vertices',
+ 'subdivide_mesh_simple',
+ 'mesh_relations',
+ 'flatten_mesh_indices'
+ ],
+ 'quadmesh': [
+ 'calc_quad_candidates',
+ 'calc_quad_distortion',
+ 'calc_quad_direction',
+ 'calc_quad_smoothness',
+ 'sovle_quad',
+ 'sovle_quad_qp',
+ 'tri_to_quad'
+ ],
+ 'utils': [
+ 'sliding_window_1d',
+ 'sliding_window_nd',
+ 'sliding_window_2d',
+ 'max_pool_1d',
+ 'max_pool_2d',
+ 'max_pool_nd',
+ 'depth_edge',
+ 'normals_edge',
+ 'depth_aliasing',
+ 'interpolate',
+ 'image_scrcoord',
+ 'image_uv',
+ 'image_pixel_center',
+ 'image_pixel',
+ 'image_mesh',
+ 'image_mesh_from_depth',
+ 'depth_to_normals',
+ 'points_to_normals',
+ 'chessboard',
+ 'cube',
+ 'icosahedron',
+ 'square',
+ 'camera_frustum',
+ ],
+ 'transforms': [
+ 'perspective',
+ 'perspective_from_fov',
+ 'perspective_from_fov_xy',
+ 'intrinsics_from_focal_center',
+ 'intrinsics_from_fov',
+ 'fov_to_focal',
+ 'focal_to_fov',
+ 'intrinsics_to_fov',
+ 'view_look_at',
+ 'extrinsics_look_at',
+ 'perspective_to_intrinsics',
+ 'perspective_to_near_far',
+ 'intrinsics_to_perspective',
+ 'extrinsics_to_view',
+ 'view_to_extrinsics',
+ 'normalize_intrinsics',
+ 'crop_intrinsics',
+ 'pixel_to_uv',
+ 'pixel_to_ndc',
+ 'uv_to_pixel',
+ 'project_depth',
+ 'depth_buffer_to_linear',
+ 'unproject_cv',
+ 'unproject_gl',
+ 'project_cv',
+ 'project_gl',
+ 'quaternion_to_matrix',
+ 'axis_angle_to_matrix',
+ 'matrix_to_quaternion',
+ 'extrinsics_to_essential',
+ 'euler_axis_angle_rotation',
+ 'euler_angles_to_matrix',
+ 'skew_symmetric',
+ 'rotation_matrix_from_vectors',
+ 'ray_intersection',
+ 'se3_matrix',
+ 'slerp_quaternion',
+ 'slerp_vector',
+ 'lerp',
+ 'lerp_se3_matrix',
+ 'piecewise_lerp',
+ 'piecewise_lerp_se3_matrix',
+ 'apply_transform'
+ ],
+ 'spline': [
+ 'linear_spline_interpolate',
+ ],
+ 'rasterization': [
+ 'RastContext',
+ 'rasterize_triangle_faces',
+ 'rasterize_edges',
+ 'texture',
+ 'warp_image_by_depth',
+ 'test_rasterization'
+ ],
+}
+
+
+__all__ = list(itertools.chain(*__modules_all__.values()))
+
+def __getattr__(name):
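+ # Lazily import the submodule that exports the requested name and cache all of its exports in globals().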
+ try:
+ return globals()[name]
+ except KeyError:
+ pass
+
+ try:
+ module_name = next(m for m in __modules_all__ if name in __modules_all__[m])
+ except StopIteration:
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+ module = importlib.import_module(f'.{module_name}', __name__)
+ for key in __modules_all__[module_name]:
+ globals()[key] = getattr(module, key)
+
+ return globals()[name]
+
+
+if TYPE_CHECKING:
+ from .quadmesh import *
+ from .transforms import *
+ from .mesh import *
+ from .utils import *
+ from .rasterization import *
+ from .spline import *
\ No newline at end of file
diff --git a/src/utils3d/numpy/_helpers.py b/src/utils3d/numpy/_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c397df338e3e04e0f228341f68171d8e067eb4e
--- /dev/null
+++ b/src/utils3d/numpy/_helpers.py
@@ -0,0 +1,93 @@
+# decorator
+import numpy as np
+from numbers import Number
+import inspect
+from functools import wraps
+from typing import *
+from .._helpers import suppress_traceback
+
+
+def get_args_order(func, args, kwargs):
+ """
+ Get the order of the arguments of a function.
+ """
+ names = inspect.getfullargspec(func).args
+ names_idx = {name: i for i, name in enumerate(names)}
+ args_order = []
+ kwargs_order = {}
+ for name, arg in kwargs.items():
+ if name in names:
+ kwargs_order[name] = names_idx[name]
+ names.remove(name)
+ for i, arg in enumerate(args):
+ if i < len(names):
+ args_order.append(names_idx[names[i]])
+ return args_order, kwargs_order
+
+
+def broadcast_args(args, kwargs, args_dim, kwargs_dim):
+ spatial = []
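+ # Determine the common broadcast shape of the leading (spatial) dimensions across all array arguments.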
+ for arg, arg_dim in zip(args + list(kwargs.values()), args_dim + list(kwargs_dim.values())):
+ if isinstance(arg, np.ndarray) and arg_dim is not None:
+ arg_spatial = arg.shape[:arg.ndim-arg_dim]
+ if len(arg_spatial) > len(spatial):
+ spatial = [1] * (len(arg_spatial) - len(spatial)) + spatial
+ # Compare dimensions from the right; index with -(j) for j starting at 1 so the last dimension is checked first.
+ for j in range(1, len(arg_spatial) + 1):
+ if spatial[-j] < arg_spatial[-j]:
+ if spatial[-j] == 1:
+ spatial[-j] = arg_spatial[-j]
+ else:
+ raise ValueError("Cannot broadcast arguments.")
+ for i, arg in enumerate(args):
+ if isinstance(arg, np.ndarray) and args_dim[i] is not None:
+ args[i] = np.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-args_dim[i]:]])
+ for key, arg in kwargs.items():
+ if isinstance(arg, np.ndarray) and kwargs_dim[key] is not None:
+ kwargs[key] = np.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-kwargs_dim[key]:]])
+ return args, kwargs, spatial
+
+
+def batched(*dims):
+ """
+ Decorator that allows a function to be called with batched arguments.
+ """
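+    # Illustrative usage: `@batched(2, None)` declares that the first argument has two
+    # core dimensions (e.g. [..., N, 3] vertices) and that the second argument (e.g.
+    # [T, 3] faces) is passed through untouched. Leading batch dimensions of the array
+    # arguments are broadcast, flattened into a single batch axis before calling
+    # `func`, and restored on the returned arrays.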
+ def decorator(func):
+ @wraps(func)
+ @suppress_traceback
+ def wrapper(*args, **kwargs):
+ args = list(args)
+ # get arguments dimensions
+ args_order, kwargs_order = get_args_order(func, args, kwargs)
+ args_dim = [dims[i] for i in args_order]
+ kwargs_dim = {key: dims[i] for key, i in kwargs_order.items()}
+ # convert to numpy array
+ for i, arg in enumerate(args):
+ if isinstance(arg, (Number, list, tuple)) and args_dim[i] is not None:
+ args[i] = np.array(arg)
+ for key, arg in kwargs.items():
+ if isinstance(arg, (Number, list, tuple)) and kwargs_dim[key] is not None:
+ kwargs[key] = np.array(arg)
+ # broadcast arguments
+ args, kwargs, spatial = broadcast_args(args, kwargs, args_dim, kwargs_dim)
+ for i, (arg, arg_dim) in enumerate(zip(args, args_dim)):
+ if isinstance(arg, np.ndarray) and arg_dim is not None:
+ args[i] = arg.reshape([-1, *arg.shape[arg.ndim-arg_dim:]])
+ for key, arg in kwargs.items():
+ if isinstance(arg, np.ndarray) and kwargs_dim[key] is not None:
+ kwargs[key] = arg.reshape([-1, *arg.shape[arg.ndim-kwargs_dim[key]:]])
+ # call function
+ results = func(*args, **kwargs)
+ type_results = type(results)
+ results = list(results) if isinstance(results, (tuple, list)) else [results]
+ # restore spatial dimensions
+ for i, result in enumerate(results):
+ results[i] = result.reshape([*spatial, *result.shape[1:]])
+ if type_results == tuple:
+ results = tuple(results)
+ elif type_results == list:
+ results = list(results)
+ else:
+ results = results[0]
+ return results
+ return wrapper
+ return decorator
diff --git a/src/utils3d/numpy/mesh.py b/src/utils3d/numpy/mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..afadb5f2510b58a1c5acbabff2ff798c041744d6
--- /dev/null
+++ b/src/utils3d/numpy/mesh.py
@@ -0,0 +1,355 @@
+import numpy as np
+from typing import *
+from ._helpers import batched
+
+
+__all__ = [
+ 'triangulate',
+ 'compute_face_normal',
+ 'compute_face_angle',
+ 'compute_vertex_normal',
+ 'compute_vertex_normal_weighted',
+ 'remove_corrupted_faces',
+ 'merge_duplicate_vertices',
+ 'remove_unreferenced_vertices',
+ 'subdivide_mesh_simple',
+ 'mesh_relations',
+ 'flatten_mesh_indices'
+]
+
+
+def triangulate(
+ faces: np.ndarray,
+ vertices: np.ndarray = None,
+ backslash: np.ndarray = None
+) -> np.ndarray:
+ """
+ Triangulate a polygonal mesh.
+
+ Args:
+ faces (np.ndarray): [L, P] polygonal faces
+ vertices (np.ndarray, optional): [N, 3] 3-dimensional vertices.
+ If given, the triangulation is performed according to the distance
+ between vertices. Defaults to None.
+ backslash (np.ndarray, optional): [L] boolean array indicating
+ how to triangulate the quad faces. Defaults to None.
+
+ Returns:
+ (np.ndarray): [L * (P - 2), 3] triangular faces
+ """
+ if faces.shape[-1] == 3:
+ return faces
+ P = faces.shape[-1]
+ if vertices is not None:
+ assert faces.shape[-1] == 4, "now only support quad mesh"
+ if backslash is None:
+ backslash = np.linalg.norm(vertices[faces[:, 0]] - vertices[faces[:, 2]], axis=-1) < \
+ np.linalg.norm(vertices[faces[:, 1]] - vertices[faces[:, 3]], axis=-1)
+ if backslash is None:
+ loop_indice = np.stack([
+ np.zeros(P - 2, dtype=int),
+ np.arange(1, P - 1, 1, dtype=int),
+ np.arange(2, P, 1, dtype=int)
+ ], axis=1)
+ return faces[:, loop_indice].reshape((-1, 3))
+ else:
+ assert faces.shape[-1] == 4, "now only support quad mesh"
+ faces = np.where(
+ backslash[:, None],
+ faces[:, [0, 1, 2, 0, 2, 3]],
+ faces[:, [0, 1, 3, 3, 1, 2]]
+ ).reshape((-1, 3))
+ return faces
+
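+# Example for `triangulate` (illustrative): a single quad [[0, 1, 2, 3]] becomes
+# [[0, 1, 2], [0, 2, 3]] by fan triangulation; when `vertices` are given, the
+# shorter diagonal of each quad is chosen as the split instead.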
+
+@batched(2, None)
+def compute_face_normal(
+ vertices: np.ndarray,
+ faces: np.ndarray
+) -> np.ndarray:
+ """
+ Compute face normals of a triangular mesh
+
+ Args:
+ vertices (np.ndarray): [..., N, 3] 3-dimensional vertices
+ faces (np.ndarray): [T, 3] triangular face indices
+
+ Returns:
+ normals (np.ndarray): [..., T, 3] face normals
+ """
+ normal = np.cross(
+ vertices[..., faces[:, 1], :] - vertices[..., faces[:, 0], :],
+ vertices[..., faces[:, 2], :] - vertices[..., faces[:, 0], :]
+ )
+ normal_norm = np.linalg.norm(normal, axis=-1, keepdims=True)
+ normal_norm[normal_norm == 0] = 1
+ normal /= normal_norm
+ return normal
+
+
+@batched(2, None)
+def compute_face_angle(
+ vertices: np.ndarray,
+ faces: np.ndarray,
+ eps: float = 1e-12
+ ) -> np.ndarray:
+ """
+ Compute face angles of a triangular mesh
+
+ Args:
+ vertices (np.ndarray): [..., N, 3] 3-dimensional vertices
+ faces (np.ndarray): [T, 3] triangular face indices
+
+ Returns:
+ angles (np.ndarray): [..., T, 3] face angles
+ """
+    face_angle = np.zeros((*vertices.shape[:-2], *faces.shape), dtype=vertices.dtype)
+ for i in range(3):
+ edge1 = vertices[..., faces[:, (i + 1) % 3], :] - vertices[..., faces[:, i], :]
+ edge2 = vertices[..., faces[:, (i + 2) % 3], :] - vertices[..., faces[:, i], :]
+ face_angle[..., i] = np.arccos(np.sum(
+ edge1 / np.clip(np.linalg.norm(edge1, axis=-1, keepdims=True), eps, None) *
+ edge2 / np.clip(np.linalg.norm(edge2, axis=-1, keepdims=True), eps, None),
+ axis=-1
+ ))
+ return face_angle
+
+
+@batched(2, None, 2)
+def compute_vertex_normal(
+ vertices: np.ndarray,
+ faces: np.ndarray,
+ face_normal: np.ndarray = None
+) -> np.ndarray:
+ """
+    Compute vertex normals of a triangular mesh by averaging neighboring face normals
+ TODO: can be improved.
+
+ Args:
+ vertices (np.ndarray): [..., N, 3] 3-dimensional vertices
+ faces (np.ndarray): [T, 3] triangular face indices
+ face_normal (np.ndarray, optional): [..., T, 3] face normals.
+ None to compute face normals from vertices and faces. Defaults to None.
+
+ Returns:
+ normals (np.ndarray): [..., N, 3] vertex normals
+ """
+ if face_normal is None:
+ face_normal = compute_face_normal(vertices, faces)
+ vertex_normal = np.zeros_like(vertices, dtype=vertices.dtype)
+ for n in range(vertices.shape[0]):
+ for i in range(3):
+ vertex_normal[n, :, 0] += np.bincount(faces[:, i], weights=face_normal[n, :, 0], minlength=vertices.shape[1])
+ vertex_normal[n, :, 1] += np.bincount(faces[:, i], weights=face_normal[n, :, 1], minlength=vertices.shape[1])
+ vertex_normal[n, :, 2] += np.bincount(faces[:, i], weights=face_normal[n, :, 2], minlength=vertices.shape[1])
+ vertex_normal_norm = np.linalg.norm(vertex_normal, axis=-1, keepdims=True)
+ vertex_normal_norm[vertex_normal_norm == 0] = 1
+ vertex_normal /= vertex_normal_norm
+ return vertex_normal
+
+
+@batched(2, None, 2)
+def compute_vertex_normal_weighted(
+ vertices: np.ndarray,
+ faces: np.ndarray,
+ face_normal: np.ndarray = None
+) -> np.ndarray:
+ """
+    Compute vertex normals of a triangular mesh by a weighted sum of neighboring face normals,
+    where the weights are the face angles at each vertex
+
+ Args:
+ vertices (np.ndarray): [..., N, 3] 3-dimensional vertices
+        faces (np.ndarray): [T, 3] triangular face indices
+ face_normal (np.ndarray, optional): [..., T, 3] face normals.
+ None to compute face normals from vertices and faces. Defaults to None.
+
+ Returns:
+ normals (np.ndarray): [..., N, 3] vertex normals
+ """
+ if face_normal is None:
+ face_normal = compute_face_normal(vertices, faces)
+ face_angle = compute_face_angle(vertices, faces)
+ vertex_normal = np.zeros_like(vertices)
+ for n in range(vertices.shape[0]):
+ for i in range(3):
+            vertex_normal[n, :, 0] += np.bincount(faces[:, i], weights=face_normal[n, :, 0] * face_angle[n, :, i], minlength=vertices.shape[1])
+            vertex_normal[n, :, 1] += np.bincount(faces[:, i], weights=face_normal[n, :, 1] * face_angle[n, :, i], minlength=vertices.shape[1])
+            vertex_normal[n, :, 2] += np.bincount(faces[:, i], weights=face_normal[n, :, 2] * face_angle[n, :, i], minlength=vertices.shape[1])
+ vertex_normal_norm = np.linalg.norm(vertex_normal, axis=-1, keepdims=True)
+ vertex_normal_norm[vertex_normal_norm == 0] = 1
+ vertex_normal /= vertex_normal_norm
+ return vertex_normal
+
+
+def remove_corrupted_faces(
+ faces: np.ndarray
+ ) -> np.ndarray:
+ """
+ Remove corrupted faces (faces with duplicated vertices)
+
+ Args:
+ faces (np.ndarray): [T, 3] triangular face indices
+
+ Returns:
+ np.ndarray: [T_, 3] triangular face indices
+ """
+ corrupted = (faces[:, 0] == faces[:, 1]) | (faces[:, 1] == faces[:, 2]) | (faces[:, 2] == faces[:, 0])
+ return faces[~corrupted]
+
+
+def merge_duplicate_vertices(
+ vertices: np.ndarray,
+ faces: np.ndarray,
+ tol: float = 1e-6
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Merge duplicate vertices of a triangular mesh.
+    Duplicate vertices are merged by selecting one of them, and the face indices are updated accordingly.
+
+ Args:
+ vertices (np.ndarray): [N, 3] 3-dimensional vertices
+ faces (np.ndarray): [T, 3] triangular face indices
+ tol (float, optional): tolerance for merging. Defaults to 1e-6.
+
+ Returns:
+ vertices (np.ndarray): [N_, 3] 3-dimensional vertices
+ faces (np.ndarray): [T, 3] triangular face indices
+ """
+ vertices_round = np.round(vertices / tol)
+ _, uni_i, uni_inv = np.unique(vertices_round, return_index=True, return_inverse=True, axis=0)
+ vertices = vertices[uni_i]
+ faces = uni_inv[faces]
+ return vertices, faces
+
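+# Example for `merge_duplicate_vertices` (illustrative): with tol=1e-6, vertices whose
+# coordinates agree after rounding to the 1e-6 grid are collapsed to a single vertex,
+# and face indices are remapped to the surviving vertex.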
+
+def remove_unreferenced_vertices(
+ faces: np.ndarray,
+ *vertice_attrs,
+ return_indices: bool = False
+) -> Tuple[np.ndarray, ...]:
+ """
+ Remove unreferenced vertices of a mesh.
+ Unreferenced vertices are removed, and the face indices are updated accordingly.
+
+ Args:
+ faces (np.ndarray): [T, P] face indices
+ *vertice_attrs: vertex attributes
+
+ Returns:
+ faces (np.ndarray): [T, P] face indices
+ *vertice_attrs: vertex attributes
+ indices (np.ndarray, optional): [N] indices of vertices that are kept. Defaults to None.
+ """
+ P = faces.shape[-1]
+ fewer_indices, inv_map = np.unique(faces, return_inverse=True)
+ faces = inv_map.astype(np.int32).reshape(-1, P)
+ ret = [faces]
+ for attr in vertice_attrs:
+ ret.append(attr[fewer_indices])
+ if return_indices:
+ ret.append(fewer_indices)
+ return tuple(ret)
+
+
+def subdivide_mesh_simple(
+ vertices: np.ndarray,
+ faces: np.ndarray,
+ n: int = 1
+) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Subdivide a triangular mesh by splitting each triangle into 4 smaller triangles.
+ NOTE: All original vertices are kept, and new vertices are appended to the end of the vertex list.
+
+ Args:
+ vertices (np.ndarray): [N, 3] 3-dimensional vertices
+ faces (np.ndarray): [T, 3] triangular face indices
+ n (int, optional): number of subdivisions. Defaults to 1.
+
+ Returns:
+ vertices (np.ndarray): [N_, 3] subdivided 3-dimensional vertices
+ faces (np.ndarray): [4 * T, 3] subdivided triangular face indices
+ """
+ for _ in range(n):
+ edges = np.stack([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], axis=0)
+ edges = np.sort(edges, axis=2)
+ uni_edges, uni_inv = np.unique(edges.reshape(-1, 2), return_inverse=True, axis=0)
+ uni_inv = uni_inv.reshape(3, -1)
+ midpoints = (vertices[uni_edges[:, 0]] + vertices[uni_edges[:, 1]]) / 2
+
+ n_vertices = vertices.shape[0]
+ vertices = np.concatenate([vertices, midpoints], axis=0)
+ faces = np.concatenate([
+ np.stack([faces[:, 0], n_vertices + uni_inv[0], n_vertices + uni_inv[2]], axis=1),
+ np.stack([faces[:, 1], n_vertices + uni_inv[1], n_vertices + uni_inv[0]], axis=1),
+ np.stack([faces[:, 2], n_vertices + uni_inv[2], n_vertices + uni_inv[1]], axis=1),
+ np.stack([n_vertices + uni_inv[0], n_vertices + uni_inv[1], n_vertices + uni_inv[2]], axis=1),
+ ], axis=0)
+ return vertices, faces
+
+
+def mesh_relations(
+ faces: np.ndarray,
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+ """
+    Calculate the edge and adjacency relations of a triangular mesh.
+ NOTE: The input mesh must be a manifold triangle mesh.
+
+ Args:
+ faces (np.ndarray): [T, 3] triangular face indices
+
+ Returns:
+ edges (np.ndarray): [E, 2] edge indices
+ edge2face (np.ndarray): [E, 2] edge to face relation. The second column is -1 if the edge is boundary.
+ face2edge (np.ndarray): [T, 3] face to edge relation
+ face2face (np.ndarray): [T, 3] face to face relation
+ """
+ T = faces.shape[0]
+ edges = np.stack([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], axis=1).reshape(-1, 2) # [3T, 2]
+ edges = np.sort(edges, axis=1) # [3T, 2]
+ edges, face2edge, occurence = np.unique(edges, axis=0, return_inverse=True, return_counts=True) # [E, 2], [3T], [E]
+ E = edges.shape[0]
+ assert np.all(occurence <= 2), "The input mesh is not a manifold mesh."
+
+ # Edge to face relation
+ padding = np.arange(E, dtype=np.int32)[occurence == 1]
+ padded_face2edge = np.concatenate([face2edge, padding], axis=0) # [2E]
+ edge2face = np.argsort(padded_face2edge, kind='stable').reshape(-1, 2) // 3 # [E, 2]
+ edge2face_valid = edge2face[:, 1] < T # [E]
+ edge2face[~edge2face_valid, 1] = -1
+
+ # Face to edge relation
+ face2edge = face2edge.reshape(-1, 3) # [T, 3]
+
+ # Face to face relation
+ face2face = edge2face[face2edge] # [T, 3, 2]
+ face2face = face2face[face2face != np.arange(T)[:, None, None]].reshape(T, 3) # [T, 3]
+
+ return edges, edge2face, face2edge, face2face
+
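+# Example for `mesh_relations` (illustrative): for two triangles sharing a diagonal,
+# faces = [[0, 1, 2], [0, 2, 3]], five undirected edges are returned; the shared edge
+# (0, 2) maps to both faces in edge2face, while the four boundary edges map to
+# (face, -1), and face2face marks the missing neighbors with -1.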
+
+@overload
+def flatten_mesh_indices(faces1: np.ndarray, attr1: np.ndarray, *other_faces_attrs_pairs: np.ndarray) -> Tuple[np.ndarray, ...]:
+ """
+    Rearrange the indices of a mesh into a flattened version where vertices are no longer shared.
+
+ ### Parameters:
+ - `faces1`: [T, P] face indices of the first attribute
+ - `attr1`: [N1, ...] attributes of the first mesh
+ - ...
+
+ ### Returns:
+    - `faces`: [T, P] flattened face indices, contiguous from 0 to T * P - 1
+ - `attr1`: [T * P, ...] attributes of the first mesh, where every P values correspond to a face
+    - ...
+ """
+def flatten_mesh_indices(*args: np.ndarray) -> Tuple[np.ndarray, ...]:
+ assert len(args) % 2 == 0, "The number of arguments must be even."
+ T, P = args[0].shape
+ assert all(arg.shape[0] == T and arg.shape[1] == P for arg in args[::2]), "The faces must have the same shape."
+ attr_flat = []
+ for faces_, attr_ in zip(args[::2], args[1::2]):
+ attr_flat_ = attr_[faces_].reshape(-1, *attr_.shape[1:])
+ attr_flat.append(attr_flat_)
+ faces_flat = np.arange(T * P, dtype=np.int32).reshape(T, P)
+ return faces_flat, *attr_flat
\ No newline at end of file
diff --git a/src/utils3d/numpy/quadmesh.py b/src/utils3d/numpy/quadmesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..6728d91124020767cc9b3c1fdd6b21d50dc55828
--- /dev/null
+++ b/src/utils3d/numpy/quadmesh.py
@@ -0,0 +1,472 @@
+import numpy as np
+import scipy as sp
+import scipy.optimize as spopt
+from typing import *
+
+
+__all__ = [
+ 'calc_quad_candidates',
+ 'calc_quad_distortion',
+ 'calc_quad_direction',
+ 'calc_quad_smoothness',
+ 'sovle_quad',
+ 'sovle_quad_qp',
+ 'tri_to_quad'
+]
+
+
+def calc_quad_candidates(
+ edges: np.ndarray,
+ face2edge: np.ndarray,
+ edge2face: np.ndarray,
+):
+ """
+ Calculate the candidate quad faces.
+
+ Args:
+ edges (np.ndarray): [E, 2] edge indices
+ face2edge (np.ndarray): [T, 3] face to edge relation
+ edge2face (np.ndarray): [E, 2] edge to face relation
+
+ Returns:
+ quads (np.ndarray): [Q, 4] quad candidate indices
+ quad2edge (np.ndarray): [Q, 4] edge to quad candidate relation
+ quad2adj (np.ndarray): [Q, 8] adjacent quad candidates of each quad candidate
+ quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid
+ """
+ E = edges.shape[0]
+ T = face2edge.shape[0]
+
+ quads_valid = edge2face[:, 1] != -1
+ Q = quads_valid.sum()
+ quad2face = edge2face[quads_valid] # [Q, 2]
+ quad2edge = face2edge[quad2face] # [Q, 2, 3]
+ flag = quad2edge == np.arange(E)[quads_valid][:, None, None] # [Q, 2, 3]
+ flag = flag.argmax(axis=-1) # [Q, 2]
+ quad2edge = np.stack([
+ quad2edge[np.arange(Q)[:, None], np.arange(2)[None, :], (flag + 1) % 3],
+ quad2edge[np.arange(Q)[:, None], np.arange(2)[None, :], (flag + 2) % 3],
+ ], axis=-1).reshape(Q, 4) # [Q, 4]
+
+ quads = np.concatenate([
+ np.where(
+ (edges[quad2edge[:, 0:1], 1:] == edges[quad2edge[:, 1:2], :]).any(axis=-1),
+ edges[quad2edge[:, 0:1], [[0, 1]]],
+ edges[quad2edge[:, 0:1], [[1, 0]]],
+ ),
+ np.where(
+ (edges[quad2edge[:, 2:3], 1:] == edges[quad2edge[:, 3:4], :]).any(axis=-1),
+ edges[quad2edge[:, 2:3], [[0, 1]]],
+ edges[quad2edge[:, 2:3], [[1, 0]]],
+ ),
+ ], axis=1) # [Q, 4]
+
+ quad2adj = edge2face[quad2edge] # [Q, 4, 2]
+ quad2adj = quad2adj[quad2adj != quad2face[:, [0,0,1,1], None]].reshape(Q, 4) # [Q, 4]
+ quad2adj_valid = quad2adj != -1
+ quad2adj = face2edge[quad2adj] # [Q, 4, 3]
+ quad2adj[~quad2adj_valid, 0] = quad2edge[~quad2adj_valid]
+ quad2adj[~quad2adj_valid, 1:] = -1
+ quad2adj = quad2adj[quad2adj != quad2edge[..., None]].reshape(Q, 8) # [Q, 8]
+ edge_valid = -np.ones(E, dtype=np.int32)
+ edge_valid[quads_valid] = np.arange(Q)
+ quad2adj_valid = quad2adj != -1
+ quad2adj[quad2adj_valid] = edge_valid[quad2adj[quad2adj_valid]] # [Q, 8]
+
+ return quads, quad2edge, quad2adj, quads_valid
+
+
+def calc_quad_distortion(
+ vertices: np.ndarray,
+ quads: np.ndarray,
+):
+ """
+ Calculate the distortion of each candidate quad face.
+
+ Args:
+ vertices (np.ndarray): [N, 3] 3-dimensional vertices
+ quads (np.ndarray): [Q, 4] quad face indices
+
+ Returns:
+ distortion (np.ndarray): [Q] distortion of each quad face
+ """
+ edge0 = vertices[quads[:, 1]] - vertices[quads[:, 0]] # [Q, 3]
+ edge1 = vertices[quads[:, 2]] - vertices[quads[:, 1]] # [Q, 3]
+ edge2 = vertices[quads[:, 3]] - vertices[quads[:, 2]] # [Q, 3]
+ edge3 = vertices[quads[:, 0]] - vertices[quads[:, 3]] # [Q, 3]
+ cross = vertices[quads[:, 0]] - vertices[quads[:, 2]] # [Q, 3]
+
+ len0 = np.maximum(np.linalg.norm(edge0, axis=-1), 1e-10) # [Q]
+ len1 = np.maximum(np.linalg.norm(edge1, axis=-1), 1e-10) # [Q]
+ len2 = np.maximum(np.linalg.norm(edge2, axis=-1), 1e-10) # [Q]
+ len3 = np.maximum(np.linalg.norm(edge3, axis=-1), 1e-10) # [Q]
+ len_cross = np.maximum(np.linalg.norm(cross, axis=-1), 1e-10) # [Q]
+
+ angle0 = np.arccos(np.clip(np.sum(-edge0 * edge1, axis=-1) / (len0 * len1), -1, 1)) # [Q]
+ angle1 = np.arccos(np.clip(np.sum(-edge1 * cross, axis=-1) / (len1 * len_cross), -1, 1)) \
+ + np.arccos(np.clip(np.sum(cross * edge2, axis=-1) / (len_cross * len2), -1, 1)) # [Q]
+ angle2 = np.arccos(np.clip(np.sum(-edge2 * edge3, axis=-1) / (len2 * len3), -1, 1)) # [Q]
+ angle3 = np.arccos(np.clip(np.sum(-edge3 * -cross, axis=-1) / (len3 * len_cross), -1, 1)) \
+ + np.arccos(np.clip(np.sum(-cross * edge0, axis=-1) / (len_cross * len0), -1, 1)) # [Q]
+
+ normal0 = np.cross(edge0, edge1) # [Q, 3]
+ normal1 = np.cross(edge2, edge3) # [Q, 3]
+ normal0 = normal0 / np.maximum(np.linalg.norm(normal0, axis=-1, keepdims=True), 1e-10) # [Q, 3]
+ normal1 = normal1 / np.maximum(np.linalg.norm(normal1, axis=-1, keepdims=True), 1e-10) # [Q, 3]
+ angle_normal = np.arccos(np.clip(np.sum(normal0 * normal1, axis=-1), -1, 1)) # [Q]
+
+ D90 = np.pi / 2
+ D180 = np.pi
+ D360 = np.pi * 2
+ ang_eng = (np.abs(angle0 - D90)**2 + np.abs(angle1 - D90)**2 + np.abs(angle2 - D90)**2 + np.abs(angle3 - D90)**2) / 4 # [Q]
+ dist_eng = np.abs(angle0 - angle2)**2 / np.minimum(np.maximum(np.minimum(angle0, angle2), 1e-10), np.maximum(D180 - np.maximum(angle0, angle2), 1e-10)) \
+ + np.abs(angle1 - angle3)**2 / np.minimum(np.maximum(np.minimum(angle1, angle3), 1e-10), np.maximum(D180 - np.maximum(angle1, angle3), 1e-10)) # [Q]
+ plane_eng = np.where(angle_normal < D90/2, np.abs(angle_normal)**2, 1e10) # [Q]
+ eng = ang_eng + 2 * dist_eng + 2 * plane_eng # [Q]
+
+ return eng
+
+
+def calc_quad_direction(
+ vertices: np.ndarray,
+ quads: np.ndarray,
+ ):
+ """
+ Calculate the direction of each candidate quad face.
+
+ Args:
+ vertices (np.ndarray): [N, 3] 3-dimensional vertices
+ quads (np.ndarray): [Q, 4] quad face indices
+
+ Returns:
+        direction (np.ndarray): [Q, 4] direction of each quad face,
+            represented by the angle between each edge and the segment joining its midpoint to the midpoint of the opposite edge.
+ """
+ mid0 = (vertices[quads[:, 0]] + vertices[quads[:, 1]]) / 2 # [Q, 3]
+ mid1 = (vertices[quads[:, 1]] + vertices[quads[:, 2]]) / 2 # [Q, 3]
+ mid2 = (vertices[quads[:, 2]] + vertices[quads[:, 3]]) / 2 # [Q, 3]
+ mid3 = (vertices[quads[:, 3]] + vertices[quads[:, 0]]) / 2 # [Q, 3]
+
+ cross0 = mid2 - mid0 # [Q, 3]
+ cross1 = mid3 - mid1 # [Q, 3]
+ cross0 = cross0 / np.maximum(np.linalg.norm(cross0, axis=-1, keepdims=True), 1e-10) # [Q, 3]
+ cross1 = cross1 / np.maximum(np.linalg.norm(cross1, axis=-1, keepdims=True), 1e-10) # [Q, 3]
+
+ edge0 = vertices[quads[:, 1]] - vertices[quads[:, 0]] # [Q, 3]
+ edge1 = vertices[quads[:, 2]] - vertices[quads[:, 1]] # [Q, 3]
+ edge2 = vertices[quads[:, 3]] - vertices[quads[:, 2]] # [Q, 3]
+ edge3 = vertices[quads[:, 0]] - vertices[quads[:, 3]] # [Q, 3]
+ edge0 = edge0 / np.maximum(np.linalg.norm(edge0, axis=-1, keepdims=True), 1e-10) # [Q, 3]
+ edge1 = edge1 / np.maximum(np.linalg.norm(edge1, axis=-1, keepdims=True), 1e-10) # [Q, 3]
+ edge2 = edge2 / np.maximum(np.linalg.norm(edge2, axis=-1, keepdims=True), 1e-10) # [Q, 3]
+ edge3 = edge3 / np.maximum(np.linalg.norm(edge3, axis=-1, keepdims=True), 1e-10) # [Q, 3]
+
+ direction = np.stack([
+ np.arccos(np.clip(np.sum(cross0 * edge0, axis=-1), -1, 1)),
+ np.arccos(np.clip(np.sum(cross1 * edge1, axis=-1), -1, 1)),
+ np.arccos(np.clip(np.sum(-cross0 * edge2, axis=-1), -1, 1)),
+ np.arccos(np.clip(np.sum(-cross1 * edge3, axis=-1), -1, 1)),
+ ], axis=-1) # [Q, 4]
+
+ return direction
+
+
+def calc_quad_smoothness(
+ quad2edge: np.ndarray,
+ quad2adj: np.ndarray,
+ quads_direction: np.ndarray,
+ ):
+ """
+ Calculate the smoothness of each candidate quad face connection.
+
+ Args:
+        quad2edge (np.ndarray): [Q, 4] edges of each quad candidate
+        quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face
+ quads_direction (np.ndarray): [Q, 4] direction of each quad face
+
+ Returns:
+ smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection
+ """
+ Q = quad2adj.shape[0]
+ quad2adj_valid = quad2adj != -1
+ connections = np.stack([
+ np.arange(Q)[:, None].repeat(8, axis=1),
+ quad2adj,
+ ], axis=-1)[quad2adj_valid] # [C, 2]
+ shared_edge_idx_0 = np.array([[0, 0, 1, 1, 2, 2, 3, 3]]).repeat(Q, axis=0)[quad2adj_valid] # [C]
+ shared_edge_idx_1 = np.argmax(quad2edge[quad2adj][quad2adj_valid] == quad2edge[connections[:, 0], shared_edge_idx_0][:, None], axis=-1) # [C]
+ valid_smoothness = np.abs(quads_direction[connections[:, 0], shared_edge_idx_0] - quads_direction[connections[:, 1], shared_edge_idx_1])**2 # [C]
+ smoothness = np.zeros([Q, 8], dtype=np.float32)
+ smoothness[quad2adj_valid] = valid_smoothness
+ return smoothness
+
+
+def sovle_quad(
+ face2edge: np.ndarray,
+ edge2face: np.ndarray,
+ quad2adj: np.ndarray,
+ quads_distortion: np.ndarray,
+ quads_smoothness: np.ndarray,
+ quads_valid: np.ndarray,
+ ):
+ """
+ Solve the quad mesh from the candidate quad faces.
+
+ Args:
+ face2edge (np.ndarray): [T, 3] face to edge relation
+ edge2face (np.ndarray): [E, 2] edge to face relation
+ quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face
+ quads_distortion (np.ndarray): [Q] distortion of each quad face
+ quads_smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection
+ quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid
+
+ Returns:
+ weights (np.ndarray): [Q] weight of each valid quad face
+ """
+ T = face2edge.shape[0]
+ E = edge2face.shape[0]
+ Q = quads_distortion.shape[0]
+ edge_valid = -np.ones(E, dtype=np.int32)
+ edge_valid[quads_valid] = np.arange(Q)
+
+ quads_connection = np.stack([
+ np.arange(Q)[:, None].repeat(8, axis=1),
+ quad2adj,
+ ], axis=-1)[quad2adj != -1] # [C, 2]
+ quads_connection = np.sort(quads_connection, axis=-1) # [C, 2]
+ quads_connection, quads_connection_idx = np.unique(quads_connection, axis=0, return_index=True) # [C, 2], [C]
+ quads_smoothness = quads_smoothness[quad2adj != -1] # [C]
+ quads_smoothness = quads_smoothness[quads_connection_idx] # [C]
+ C = quads_connection.shape[0]
+
+ # Construct the linear programming problem
+
+ # Variables:
+ # quads_weight: [Q] weight of each quad face
+ # tri_min_weight: [T] minimum weight of each triangle face
+ # conn_min_weight: [C] minimum weight of each quad face connection
+ # conn_max_weight: [C] maximum weight of each quad face connection
+    # Objective:
+    #     minimize c^T x, i.e. reward selecting low-distortion quads and penalize
+    #     direction changes between adjacent selected quads
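+    # Constraints:
+    #     each triangle is covered by at most one selected quad (first T rows of A_ub);
+    #     for every connection (i, j): conn_min <= w_i + w_j - 1 and conn_max >= w_i + w_j - 1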
+
+ c = np.concatenate([
+ quads_distortion - 3,
+ quads_smoothness*4 - 2,
+ quads_smoothness*4,
+ ], axis=0) # [Q+C]
+
+ A_ub_triplet = np.concatenate([
+ np.stack([np.arange(T), edge_valid[face2edge[:, 0]], np.ones(T)], axis=1), # [T, 3]
+ np.stack([np.arange(T), edge_valid[face2edge[:, 1]], np.ones(T)], axis=1), # [T, 3]
+ np.stack([np.arange(T), edge_valid[face2edge[:, 2]], np.ones(T)], axis=1), # [T, 3]
+ np.stack([np.arange(T, T+C), np.arange(Q, Q+C), np.ones(C)], axis=1), # [C, 3]
+ np.stack([np.arange(T, T+C), quads_connection[:, 0], -np.ones(C)], axis=1), # [C, 3]
+ np.stack([np.arange(T, T+C), quads_connection[:, 1], -np.ones(C)], axis=1), # [C, 3]
+ np.stack([np.arange(T+C, T+2*C), np.arange(Q+C, Q+2*C), -np.ones(C)], axis=1), # [C, 3]
+ np.stack([np.arange(T+C, T+2*C), quads_connection[:, 0], np.ones(C)], axis=1), # [C, 3]
+ np.stack([np.arange(T+C, T+2*C), quads_connection[:, 1], np.ones(C)], axis=1), # [C, 3]
+ ], axis=0) # [3T+6C, 3]
+ A_ub_triplet = A_ub_triplet[A_ub_triplet[:, 1] != -1] # [3T', 3]
+    A_ub = sp.sparse.coo_matrix((A_ub_triplet[:, 2], (A_ub_triplet[:, 0], A_ub_triplet[:, 1])), shape=[T+2*C, Q+2*C])  # [T+2C, Q+2C]
+ b_ub = np.concatenate([np.ones(T), -np.ones(C), np.ones(C)], axis=0) # [T+2C]
+ bound = np.stack([
+ np.concatenate([np.zeros(Q), -np.ones(C), np.zeros(C)], axis=0),
+ np.concatenate([np.ones(Q), np.ones(C), np.ones(C)], axis=0),
+ ], axis=1) # [Q+2C, 2]
+ A_eq = None
+ b_eq = None
+
+ print('Solver statistics:')
+ print(f' #T = {T}')
+ print(f' #Q = {Q}')
+ print(f' #C = {C}')
+
+ # Solve the linear programming problem
+ last_num_valid = 0
+ for i in range(100):
+ res_ = spopt.linprog(c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, bounds=bound)
+ if not res_.success:
+ print(f' Iter {i} | Failed with {res_.message}')
+ break
+ res = res_
+ weights = res.x[:Q]
+ valid = (weights > 0.5)
+ num_valid = valid.sum()
+ print(f' Iter {i} | #Q_valid = {num_valid}')
+ if num_valid == last_num_valid:
+ break
+ last_num_valid = num_valid
+ A_eq_triplet = np.stack([
+ np.arange(num_valid),
+ np.arange(Q)[valid],
+ np.ones(num_valid),
+ ], axis=1) # [num_valid, 3]
+        A_eq = sp.sparse.coo_matrix((A_eq_triplet[:, 2], (A_eq_triplet[:, 0], A_eq_triplet[:, 1])), shape=[num_valid, Q+2*C]) # [num_valid, Q+2C]
+ b_eq = np.where(weights[valid] > 0.5, 1, 0) # [num_valid]
+
+ # Return the result
+ quads_weight = res.x[:Q]
+ conn_min_weight = res.x[Q:Q+C]
+ conn_max_weight = res.x[Q+C:Q+2*C]
+ return quads_weight, conn_min_weight, conn_max_weight
+
+
+def sovle_quad_qp(
+ face2edge: np.ndarray,
+ edge2face: np.ndarray,
+ quad2adj: np.ndarray,
+ quads_distortion: np.ndarray,
+ quads_smoothness: np.ndarray,
+ quads_valid: np.ndarray,
+ ):
+ """
+ Solve the quad mesh from the candidate quad faces.
+
+ Args:
+ face2edge (np.ndarray): [T, 3] face to edge relation
+ edge2face (np.ndarray): [E, 2] edge to face relation
+ quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face
+ quads_distortion (np.ndarray): [Q] distortion of each quad face
+ quads_smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection
+ quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid
+
+ Returns:
+ weights (np.ndarray): [Q] weight of each valid quad face
+ """
+ T = face2edge.shape[0]
+ E = edge2face.shape[0]
+ Q = quads_distortion.shape[0]
+ edge_valid = -np.ones(E, dtype=np.int32)
+ edge_valid[quads_valid] = np.arange(Q)
+
+ # Construct the quadratic programming problem
+ C_smoothness_triplet = np.stack([
+ np.arange(Q)[:, None].repeat(8, axis=1)[quad2adj != -1],
+ quad2adj[quad2adj != -1],
+ 5 * quads_smoothness[quad2adj != -1],
+ ], axis=-1) # [C, 3]
+ # C_smoothness_triplet = np.concatenate([
+ # C_smoothness_triplet,
+ # np.stack([np.arange(Q), np.arange(Q), 20*np.ones(Q)], axis=1),
+ # ], axis=0) # [C+Q, 3]
+ C_smoothness = sp.sparse.coo_matrix((C_smoothness_triplet[:, 2], (C_smoothness_triplet[:, 0], C_smoothness_triplet[:, 1])), shape=[Q, Q]) # [Q, Q]
+ C_smoothness = C_smoothness.tocsc()
+ C_dist = quads_distortion - 20 # [Q]
+
+    A_eq = sp.sparse.coo_matrix((np.zeros(Q), (np.zeros(Q), np.arange(Q))), shape=[1, Q]) # [1, Q]
+ A_eq = A_eq.tocsc()
+ b_eq = np.array([0])
+
+ A_ub_triplet = np.concatenate([
+ np.stack([np.arange(T), edge_valid[face2edge[:, 0]], np.ones(T)], axis=1), # [T, 3]
+ np.stack([np.arange(T), edge_valid[face2edge[:, 1]], np.ones(T)], axis=1), # [T, 3]
+ np.stack([np.arange(T), edge_valid[face2edge[:, 2]], np.ones(T)], axis=1), # [T, 3]
+ ], axis=0) # [3T, 3]
+ A_ub_triplet = A_ub_triplet[A_ub_triplet[:, 1] != -1] # [3T', 3]
+ A_ub = sp.sparse.coo_matrix((A_ub_triplet[:, 2], (A_ub_triplet[:, 0], A_ub_triplet[:, 1])), shape=[T, Q]) # [T, Q]
+ A_ub = A_ub.tocsc()
+ b_ub = np.ones(T)
+
+ lb = np.zeros(Q)
+ ub = np.ones(Q)
+
+ import piqp
+ solver = piqp.SparseSolver()
+ solver.settings.verbose = True
+ solver.settings.compute_timings = True
+ solver.setup(C_smoothness, C_dist, A_eq, b_eq, A_ub, b_ub, lb, ub)
+
+ status = solver.solve()
+
+ # x = cp.Variable(Q)
+ # prob = cp.Problem(
+ # cp.Minimize(cp.quad_form(x, C_smoothness) + C_dist.T @ x),
+ # [
+ # A_ub @ x <= b_ub,
+ # x >= 0, x <= 1,
+ # ]
+ # )
+
+ # # Solve the quadratic programming problem
+ # prob.solve(solver=cp.PIQP, verbose=True)
+
+ # Return the result
+ weights = solver.result.x
+ return weights
+
+
+def tri_to_quad(
+ vertices: np.ndarray,
+ faces: np.ndarray,
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Convert a triangle mesh to a quad mesh.
+ NOTE: The input mesh must be a manifold mesh.
+
+ Args:
+ vertices (np.ndarray): [N, 3] 3-dimensional vertices
+ faces (np.ndarray): [T, 3] triangular face indices
+
+ Returns:
+ vertices (np.ndarray): [N_, 3] 3-dimensional vertices
+ faces (np.ndarray): [Q, 4] quad face indices
+ """
+ raise NotImplementedError
+
+
+if __name__ == '__main__':
+ import os
+ import sys
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')))
+ import utils3d
+ import numpy as np
+ import cv2
+ from vis import vis_edge_color
+
+ file = 'miku'
+
+ vertices, faces = utils3d.io.read_ply(f'test/assets/{file}.ply')
+    edges, edge2face, face2edge, face2face = utils3d.numpy.mesh_relations(faces)
+ quad_cands, quad2edge, quad2adj, quad_valid = calc_quad_candidates(edges, face2edge, edge2face)
+ distortion = calc_quad_distortion(vertices, quad_cands)
+ direction = calc_quad_direction(vertices, quad_cands)
+ smoothness = calc_quad_smoothness(quad2edge, quad2adj, direction)
+ boundary_edges = edges[edge2face[:, 1] == -1]
+ quads_weight, conn_min_weight, conn_max_weight = sovle_quad(face2edge, edge2face, quad2adj, distortion, smoothness, quad_valid)
+ quads = quad_cands[quads_weight > 0.5]
+ print('Mesh statistics')
+ print(f' #V = {vertices.shape[0]}')
+ print(f' #F = {faces.shape[0]}')
+ print(f' #E = {edges.shape[0]}')
+ print(f' #B = {boundary_edges.shape[0]}')
+ print(f' #Q_cand = {quad_cands.shape[0]}')
+ print(f' #Q = {quads.shape[0]}')
+
+ utils3d.io.write_ply(f'test/assets/{file}_boundary_edges.ply', vertices=vertices, edges=boundary_edges)
+ utils3d.io.write_ply(f'test/assets/{file}_quad_candidates.ply', vertices=vertices, faces=quads)
+
+ edge_colors = np.zeros([edges.shape[0], 3], dtype=np.uint8)
+ distortion = (distortion - distortion.min()) / (distortion.max() - distortion.min())
+ distortion = (distortion * 255).astype(np.uint8)
+ edge_colors[quad_valid] = cv2.cvtColor(cv2.applyColorMap(distortion, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3)
+ utils3d.io.write_ply(f'test/assets/{file}_quad_candidates_distortion.ply', **vis_edge_color(vertices, edges, edge_colors))
+
+ edge_colors = np.zeros([edges.shape[0], 3], dtype=np.uint8)
+ edge_colors[quad_valid] = cv2.cvtColor(cv2.applyColorMap((quads_weight * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3)
+ utils3d.io.write_ply(f'test/assets/{file}_quad_candidates_weights.ply', **vis_edge_color(vertices, edges, edge_colors))
+ utils3d.io.write_ply(f'test/assets/{file}_quad.ply', vertices=vertices, faces=quads)
+
+ quad_centers = vertices[quad_cands].mean(axis=1)
+ conns = np.stack([
+ np.arange(quad_cands.shape[0])[:, None].repeat(8, axis=1),
+ quad2adj,
+ ], axis=-1)[quad2adj != -1] # [C, 2]
+ conns, conns_idx = np.unique(np.sort(conns, axis=-1), axis=0, return_index=True) # [C, 2], [C]
+ smoothness = smoothness[quad2adj != -1][conns_idx] # [C]
+ conns_color = cv2.cvtColor(cv2.applyColorMap((smoothness * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3)
+ utils3d.io.write_ply(f'test/assets/{file}_quad_conn_smoothness.ply', **vis_edge_color(quad_centers, conns, conns_color))
+ conns_color = cv2.cvtColor(cv2.applyColorMap((conn_min_weight * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3)
+ utils3d.io.write_ply(f'test/assets/{file}_quad_conn_min.ply', **vis_edge_color(quad_centers, conns, conns_color))
+ conns_color = cv2.cvtColor(cv2.applyColorMap((conn_max_weight * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3)
+ utils3d.io.write_ply(f'test/assets/{file}_quad_conn_max.ply', **vis_edge_color(quad_centers, conns, conns_color))
+
+
\ No newline at end of file
diff --git a/src/utils3d/numpy/rasterization.py b/src/utils3d/numpy/rasterization.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f0f0db55d87f37f108a778dac29ae6320418f3a
--- /dev/null
+++ b/src/utils3d/numpy/rasterization.py
@@ -0,0 +1,469 @@
+import os
+from typing import *
+
+import numpy as np
+import moderngl
+
+from . import transforms, utils, mesh
+
+
+__all__ = [
+ 'RastContext',
+ 'rasterize_triangle_faces',
+ 'rasterize_edges',
+ 'texture',
+ 'test_rasterization',
+ 'warp_image_by_depth',
+]
+
+
+def map_np_dtype(dtype) -> str:
+ if dtype == int:
+ return 'i4'
+ elif dtype == np.uint8:
+ return 'u1'
+    elif dtype == np.uint16:
+        return 'u2'
+    elif dtype == np.uint32:
+        return 'u4'
+ elif dtype == np.float16:
+ return 'f2'
+ elif dtype == np.float32:
+ return 'f4'
+
+
+def one_value(dtype):
+ if dtype == 'u1':
+ return 255
+ elif dtype == 'u2':
+ return 65535
+ else:
+ return 1
+
+
+class RastContext:
+ def __init__(self, *args, **kwargs):
+ """
+ Create a moderngl context.
+
+ Args:
+ See moderngl.create_context
+ """
+ if len(args) == 1 and isinstance(args[0], moderngl.Context):
+ self.mgl_ctx = args[0]
+ else:
+ self.mgl_ctx = moderngl.create_context(*args, **kwargs)
+ self.__prog_src = {}
+ self.__prog = {}
+
+ def program_vertex_attribute(self, n: int) -> moderngl.Program:
+ assert n in [1, 2, 3, 4], 'vertex attribute only supports channels 1, 2, 3, 4'
+
+ if 'vertex_attribute_vsh' not in self.__prog_src:
+ with open(os.path.join(os.path.dirname(__file__), 'shaders', 'vertex_attribute.vsh'), 'r') as f:
+ self.__prog_src['vertex_attribute_vsh'] = f.read()
+ if 'vertex_attribute_fsh' not in self.__prog_src:
+ with open(os.path.join(os.path.dirname(__file__), 'shaders', 'vertex_attribute.fsh'), 'r') as f:
+ self.__prog_src['vertex_attribute_fsh'] = f.read()
+
+ if f'vertex_attribute_{n}' not in self.__prog:
+ vsh = self.__prog_src['vertex_attribute_vsh'].replace('vecN', f'vec{n}')
+ fsh = self.__prog_src['vertex_attribute_fsh'].replace('vecN', f'vec{n}')
+ self.__prog[f'vertex_attribute_{n}'] = self.mgl_ctx.program(vertex_shader=vsh, fragment_shader=fsh)
+
+ return self.__prog[f'vertex_attribute_{n}']
+
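+    # The shader sources use a `vecN` placeholder for the attribute type; for an
+    # n-channel attribute it is replaced with `vec{n}`, e.g. `in vecN i_attr;` becomes
+    # `in vec3 i_attr;` for 3-channel vertex colors.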
+ def program_texture(self, n: int) -> moderngl.Program:
+ assert n in [1, 2, 3, 4], 'texture only supports channels 1, 2, 3, 4'
+
+ if 'texture_vsh' not in self.__prog_src:
+ with open(os.path.join(os.path.dirname(__file__), 'shaders', 'texture.vsh'), 'r') as f:
+ self.__prog_src['texture_vsh'] = f.read()
+ if 'texture_fsh' not in self.__prog_src:
+ with open(os.path.join(os.path.dirname(__file__), 'shaders', 'texture.fsh'), 'r') as f:
+ self.__prog_src['texture_fsh'] = f.read()
+
+ if f'texture_{n}' not in self.__prog:
+ vsh = self.__prog_src['texture_vsh'].replace('vecN', f'vec{n}')
+ fsh = self.__prog_src['texture_fsh'].replace('vecN', f'vec{n}')
+ self.__prog[f'texture_{n}'] = self.mgl_ctx.program(vertex_shader=vsh, fragment_shader=fsh)
+ self.__prog[f'texture_{n}']['tex'] = 0
+ self.__prog[f'texture_{n}']['uv'] = 1
+
+ return self.__prog[f'texture_{n}']
+
+
+def rasterize_triangle_faces(
+ ctx: RastContext,
+ vertices: np.ndarray,
+ faces: np.ndarray,
+ attr: np.ndarray,
+ width: int,
+ height: int,
+ transform: np.ndarray = None,
+ cull_backface: bool = True,
+ return_depth: bool = False,
+ image: np.ndarray = None,
+ depth: np.ndarray = None
+) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Rasterize vertex attribute.
+
+ Args:
+ vertices (np.ndarray): [N, 3]
+ faces (np.ndarray): [T, 3]
+ attr (np.ndarray): [N, C]
+ width (int): width of rendered image
+ height (int): height of rendered image
+ transform (np.ndarray): [4, 4] model-view-projection transformation matrix.
+ cull_backface (bool): whether to cull backface
+        return_depth (bool): whether to also return the depth buffer
+        image (np.ndarray, optional): [H, W, C] background image to render on top of. Defaults to None.
+        depth (np.ndarray, optional): [H, W] background depth buffer. Defaults to None.
+
+ Returns:
+ image (np.ndarray): [H, W, C] rendered image
+ depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None.
+ """
+ assert vertices.ndim == 2 and vertices.shape[1] == 3
+ assert faces.ndim == 2 and faces.shape[1] == 3, f"Faces should be a 2D array with shape (T, 3), but got {faces.shape}"
+ assert attr.ndim == 2 and attr.shape[1] in [1, 2, 3, 4], f'Vertex attribute only supports channels 1, 2, 3, 4, but got {attr.shape}'
+ assert vertices.shape[0] == attr.shape[0]
+ assert vertices.dtype == np.float32
+ assert faces.dtype == np.uint32 or faces.dtype == np.int32
+ assert attr.dtype == np.float32, "Attribute should be float32"
+ assert transform is None or transform.shape == (4, 4), f"Transform should be a 4x4 matrix, but got {transform.shape}"
+ assert transform is None or transform.dtype == np.float32, f"Transform should be float32, but got {transform.dtype}"
+ if image is not None:
+ assert image.ndim == 3 and image.shape == (height, width, attr.shape[1]), f"Image should be a 3D array with shape (H, W, {attr.shape[1]}), but got {image.shape}"
+ assert image.dtype == np.float32, f"Image should be float32, but got {image.dtype}"
+ if depth is not None:
+ assert depth.ndim == 2 and depth.shape == (height, width), f"Depth should be a 2D array with shape (H, W), but got {depth.shape}"
+ assert depth.dtype == np.float32, f"Depth should be float32, but got {depth.dtype}"
+
+ C = attr.shape[1]
+ prog = ctx.program_vertex_attribute(C)
+
+ transform = np.eye(4, np.float32) if transform is None else transform
+
+ # Create buffers
+ ibo = ctx.mgl_ctx.buffer(np.ascontiguousarray(faces, dtype='i4'))
+ vbo_vertices = ctx.mgl_ctx.buffer(np.ascontiguousarray(vertices, dtype='f4'))
+ vbo_attr = ctx.mgl_ctx.buffer(np.ascontiguousarray(attr, dtype='f4'))
+ vao = ctx.mgl_ctx.vertex_array(
+ prog,
+ [
+ (vbo_vertices, '3f', 'i_position'),
+ (vbo_attr, f'{C}f', 'i_attr'),
+ ],
+ ibo,
+ mode=moderngl.TRIANGLES,
+ )
+
+ # Create framebuffer
+ image_tex = ctx.mgl_ctx.texture((width, height), C, dtype='f4', data=np.ascontiguousarray(image[::-1, :, :]) if image is not None else None)
+ depth_tex = ctx.mgl_ctx.depth_texture((width, height), data=np.ascontiguousarray(depth[::-1, :]) if depth is not None else None)
+ fbo = ctx.mgl_ctx.framebuffer(
+ color_attachments=[image_tex],
+ depth_attachment=depth_tex,
+ )
+
+ # Render
+ prog['u_mvp'].write(transform.transpose().copy().astype('f4'))
+ fbo.use()
+ fbo.viewport = (0, 0, width, height)
+ ctx.mgl_ctx.depth_func = '<'
+ if depth is None:
+ ctx.mgl_ctx.clear(depth=1.0)
+ ctx.mgl_ctx.enable(ctx.mgl_ctx.DEPTH_TEST)
+ if cull_backface:
+ ctx.mgl_ctx.enable(ctx.mgl_ctx.CULL_FACE)
+ else:
+ ctx.mgl_ctx.disable(ctx.mgl_ctx.CULL_FACE)
+ vao.render()
+ ctx.mgl_ctx.disable(ctx.mgl_ctx.DEPTH_TEST)
+
+ # Read
+ image = np.zeros((height, width, C), dtype='f4')
+ image_tex.read_into(image)
+ image = image[::-1, :, :]
+ if return_depth:
+ depth = np.zeros((height, width), dtype='f4')
+ depth_tex.read_into(depth)
+ depth = depth[::-1, :]
+ else:
+ depth = None
+
+ # Release
+ vao.release()
+ ibo.release()
+ vbo_vertices.release()
+ vbo_attr.release()
+ fbo.release()
+ image_tex.release()
+ depth_tex.release()
+
+ return image, depth
+
+
+def rasterize_edges(
+ ctx: RastContext,
+ vertices: np.ndarray,
+ edges: np.ndarray,
+ attr: np.ndarray,
+ width: int,
+ height: int,
+ transform: np.ndarray = None,
+ line_width: float = 1.0,
+ return_depth: bool = False,
+ image: np.ndarray = None,
+ depth: np.ndarray = None
+) -> Tuple[np.ndarray, ...]:
+ """
+ Rasterize vertex attribute.
+
+ Args:
+ vertices (np.ndarray): [N, 3]
+        edges (np.ndarray): [E, 2] edge vertex indices
+ attr (np.ndarray): [N, C]
+ width (int): width of rendered image
+ height (int): height of rendered image
+ transform (np.ndarray): [4, 4] model-view-projection matrix
+ line_width (float): width of line. Defaults to 1.0. NOTE: Values other than 1.0 may not work across all platforms.
+        return_depth (bool): whether to also return the depth buffer
+        image (np.ndarray, optional): [H, W, C] background image to render on top of. Defaults to None.
+        depth (np.ndarray, optional): [H, W] background depth buffer. Defaults to None.
+
+ Returns:
+ image (np.ndarray): [H, W, C] rendered image
+ depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None.
+ """
+ assert vertices.ndim == 2 and vertices.shape[1] == 3
+ assert edges.ndim == 2 and edges.shape[1] == 2, f"Edges should be a 2D array with shape (T, 2), but got {edges.shape}"
+ assert attr.ndim == 2 and attr.shape[1] in [1, 2, 3, 4], f'Vertex attribute only supports channels 1, 2, 3, 4, but got {attr.shape}'
+ assert vertices.shape[0] == attr.shape[0]
+ assert vertices.dtype == np.float32
+ assert edges.dtype == np.uint32 or edges.dtype == np.int32
+ assert attr.dtype == np.float32, "Attribute should be float32"
+
+ C = attr.shape[1]
+ prog = ctx.program_vertex_attribute(C)
+
+ transform = transform if transform is not None else np.eye(4, np.float32)
+
+ # Create buffers
+ ibo = ctx.mgl_ctx.buffer(np.ascontiguousarray(edges, dtype='i4'))
+ vbo_vertices = ctx.mgl_ctx.buffer(np.ascontiguousarray(vertices, dtype='f4'))
+ vbo_attr = ctx.mgl_ctx.buffer(np.ascontiguousarray(attr, dtype='f4'))
+ vao = ctx.mgl_ctx.vertex_array(
+ prog,
+ [
+ (vbo_vertices, '3f', 'i_position'),
+ (vbo_attr, f'{C}f', 'i_attr'),
+ ],
+ ibo,
+ mode=moderngl.LINES,
+ )
+
+ # Create framebuffer
+ image_tex = ctx.mgl_ctx.texture((width, height), C, dtype='f4', data=np.ascontiguousarray(image[::-1, :, :]) if image is not None else None)
+ depth_tex = ctx.mgl_ctx.depth_texture((width, height), data=np.ascontiguousarray(depth[::-1, :]) if depth is not None else None)
+ fbo = ctx.mgl_ctx.framebuffer(
+ color_attachments=[image_tex],
+ depth_attachment=depth_tex,
+ )
+
+ # Render
+ prog['u_mvp'].write(transform.transpose().copy().astype('f4'))
+ fbo.use()
+ fbo.viewport = (0, 0, width, height)
+ if depth is None:
+ ctx.mgl_ctx.clear(depth=1.0)
+ ctx.mgl_ctx.depth_func = '<'
+ ctx.mgl_ctx.enable(ctx.mgl_ctx.DEPTH_TEST)
+ ctx.mgl_ctx.line_width = line_width
+ vao.render()
+ ctx.mgl_ctx.disable(ctx.mgl_ctx.DEPTH_TEST)
+
+ # Read
+ image = np.zeros((height, width, C), dtype='f4')
+ image_tex.read_into(image)
+ image = image[::-1, :, :]
+ if return_depth:
+ depth = np.zeros((height, width), dtype='f4')
+ depth_tex.read_into(depth)
+ depth = depth[::-1, :]
+ else:
+ depth = None
+
+ # Release
+ vao.release()
+ ibo.release()
+ vbo_vertices.release()
+ vbo_attr.release()
+ fbo.release()
+ image_tex.release()
+ depth_tex.release()
+
+ return image, depth
+
+
+def texture(
+ ctx: RastContext,
+ uv: np.ndarray,
+ texture: np.ndarray,
+    interpolation: str = 'linear',
+ wrap: str = 'clamp'
+) -> np.ndarray:
+ """
+    Given a UV coordinate image, sample the texture map at those coordinates.
+ """
+ assert len(texture.shape) == 3 and 1 <= texture.shape[2] <= 4
+ assert uv.shape[2] == 2
+ height, width = uv.shape[:2]
+ texture_dtype = map_np_dtype(texture.dtype)
+
+ # Create VAO
+ screen_quad_vbo = ctx.mgl_ctx.buffer(np.array([[-1, -1], [1, -1], [1, 1], [-1, 1]], dtype='f4'))
+ screen_quad_ibo = ctx.mgl_ctx.buffer(np.array([0, 1, 2, 0, 2, 3], dtype=np.int32))
+ screen_quad_vao = ctx.mgl_ctx.vertex_array(ctx.program_texture(texture.shape[2]), [(screen_quad_vbo, '2f4', 'in_vert')], index_buffer=screen_quad_ibo, index_element_size=4)
+
+ # Create texture, set filter and bind. TODO: min mag filter, mipmap
+ texture_tex = ctx.mgl_ctx.texture((texture.shape[1], texture.shape[0]), texture.shape[2], dtype=texture_dtype, data=np.ascontiguousarray(texture))
+ if interpolation == 'linear':
+ texture_tex.filter = (moderngl.LINEAR, moderngl.LINEAR)
+ elif interpolation == 'nearest':
+ texture_tex.filter = (moderngl.NEAREST, moderngl.NEAREST)
+ texture_tex.use(location=0)
+ texture_uv = ctx.mgl_ctx.texture((width, height), 2, dtype='f4', data=np.ascontiguousarray(uv.astype('f4', copy=False)))
+ texture_uv.filter = (moderngl.NEAREST, moderngl.NEAREST)
+ texture_uv.use(location=1)
+
+ # Create render buffer and frame buffer
+ rb = ctx.mgl_ctx.renderbuffer((uv.shape[1], uv.shape[0]), texture.shape[2], dtype=texture_dtype)
+ fbo = ctx.mgl_ctx.framebuffer(color_attachments=[rb])
+
+ # Render
+ fbo.use()
+ fbo.viewport = (0, 0, width, height)
+ ctx.mgl_ctx.disable(ctx.mgl_ctx.BLEND)
+ screen_quad_vao.render()
+
+ # Read buffer
+ image_buffer = np.frombuffer(fbo.read(components=texture.shape[2], attachment=0, dtype=texture_dtype), dtype=texture_dtype).reshape((height, width, texture.shape[2]))
+
+ # Release
+ texture_tex.release()
+ rb.release()
+ fbo.release()
+
+ return image_buffer
+
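+# Example for `texture` (illustrative): given a [H, W, 2] array of UV coordinates in
+# [0, 1]^2 and an [h, w, 3] image, `texture(ctx, uv, image)` samples the image at those
+# coordinates and returns an [H, W, 3] array, which can be used for image remapping or warping.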
+
+def warp_image_by_depth(
+ ctx: RastContext,
+ src_depth: np.ndarray,
+ src_image: np.ndarray = None,
+ width: int = None,
+ height: int = None,
+ *,
+ extrinsics_src: np.ndarray = None,
+ extrinsics_tgt: np.ndarray = None,
+ intrinsics_src: np.ndarray = None,
+ intrinsics_tgt: np.ndarray = None,
+ near: float = 0.1,
+ far: float = 100.0,
+ cull_backface: bool = True,
+ ssaa: int = 1,
+ return_depth: bool = False,
+) -> Tuple[np.ndarray, ...]:
+ """
+ Warp image by depth map.
+
+ Args:
+ ctx (RastContext): rasterizer context
+ src_depth (np.ndarray): [H, W]
+ src_image (np.ndarray, optional): [H, W, C]. The image to warp. Defaults to None (use uv coordinates).
+ width (int, optional): width of the output image. None to use depth map width. Defaults to None.
+ height (int, optional): height of the output image. None to use depth map height. Defaults to None.
+ extrinsics_src (np.ndarray, optional): extrinsics matrix of the source camera. Defaults to None (identity).
+ extrinsics_tgt (np.ndarray, optional): extrinsics matrix of the target camera. Defaults to None (identity).
+ intrinsics_src (np.ndarray, optional): intrinsics matrix of the source camera. Defaults to None (use the same as intrinsics_tgt).
+ intrinsics_tgt (np.ndarray, optional): intrinsics matrix of the target camera. Defaults to None (use the same as intrinsics_src).
+ cull_backface (bool, optional): whether to cull backface. Defaults to True.
+ ssaa (int, optional): super sampling anti-aliasing. Defaults to 1.
+
+ Returns:
+ tgt_image (np.ndarray): [H, W, C] warped image (or uv coordinates if image is None).
+ tgt_depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None.
+ """
+ assert src_depth.ndim == 2
+
+ if width is None:
+ width = src_depth.shape[1]
+ if height is None:
+ height = src_depth.shape[0]
+ if src_image is not None:
+        assert src_image.shape[:2] == src_depth.shape[:2], f'Shape of source image {src_image.shape} does not match shape of source depth {src_depth.shape}'
+
+ # set up default camera parameters
+ extrinsics_src = np.eye(4) if extrinsics_src is None else extrinsics_src
+ extrinsics_tgt = np.eye(4) if extrinsics_tgt is None else extrinsics_tgt
+ intrinsics_src = intrinsics_tgt if intrinsics_src is None else intrinsics_src
+ intrinsics_tgt = intrinsics_src if intrinsics_tgt is None else intrinsics_tgt
+
+ assert all(x is not None for x in [extrinsics_src, extrinsics_tgt, intrinsics_src, intrinsics_tgt]), "Make sure you have provided all the necessary camera parameters."
+
+ # check shapes
+ assert extrinsics_src.shape == (4, 4) and extrinsics_tgt.shape == (4, 4)
+ assert intrinsics_src.shape == (3, 3) and intrinsics_tgt.shape == (3, 3)
+
+ # convert to view and perspective matrices
+ view_tgt = transforms.extrinsics_to_view(extrinsics_tgt)
+ perspective_tgt = transforms.intrinsics_to_perspective(intrinsics_tgt, near=near, far=far)
+
+ # unproject depth map
+ uv, faces = utils.image_mesh(*src_depth.shape[-2:])
+ pts = transforms.unproject_cv(uv, src_depth.reshape(-1), extrinsics_src, intrinsics_src)
+ faces = mesh.triangulate(faces, vertices=pts)
+
+ # rasterize attributes
+ if src_image is not None:
+ attr = src_image.reshape(-1, src_image.shape[-1])
+ else:
+ attr = uv
+
+ tgt_image, tgt_depth = rasterize_triangle_faces(
+ ctx,
+ pts,
+ faces,
+ attr,
+ width * ssaa,
+ height * ssaa,
+ transform=perspective_tgt @ view_tgt,
+ cull_backface=cull_backface,
+ return_depth=return_depth,
+ )
+
+ if ssaa > 1:
+ tgt_image = tgt_image.reshape(height, ssaa, width, ssaa, -1).mean(axis=(1, 3))
+        tgt_depth = tgt_depth.reshape(height, ssaa, width, ssaa).mean(axis=(1, 3)) if return_depth else None
+
+ return tgt_image, tgt_depth
+
+def test_rasterization(ctx: RastContext):
+ """
+ Test if rasterization works. It will render a cube with random colors and save it as a CHECKME.png file.
+ """
+ vertices, faces = utils.cube(tri=True)
+ attr = np.random.rand(len(vertices), 3).astype(np.float32)
+ perspective = transforms.perspective(np.deg2rad(60), 1, 0.01, 100)
+ view = transforms.view_look_at(np.array([2, 2, 2]), np.array([0, 0, 0]), np.array([0, 1, 0]))
+ image, depth = rasterize_triangle_faces(
+ ctx,
+ vertices,
+ faces,
+ attr,
+ 512, 512,
+ transform=(perspective @ view).astype(np.float32),
+ cull_backface=False,
+ return_depth=True,
+ )
+ import cv2
+ cv2.imwrite('CHECKME.png', cv2.cvtColor((image.clip(0, 1) * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))
+
\ No newline at end of file
diff --git a/src/utils3d/numpy/shaders/texture.fsh b/src/utils3d/numpy/shaders/texture.fsh
new file mode 100644
index 0000000000000000000000000000000000000000..c8be72f94cbf38fb0b2a9609e8db4d50ac7753d6
--- /dev/null
+++ b/src/utils3d/numpy/shaders/texture.fsh
@@ -0,0 +1,11 @@
+#version 330
+
+uniform sampler2D tex;
+uniform sampler2D uv;
+
+in vec2 scr_coord;
+out vecN tex_color;
+
+void main() {
+ tex_color = vecN(texture(tex, texture(uv, scr_coord).xy));
+}
\ No newline at end of file
diff --git a/src/utils3d/numpy/shaders/texture.vsh b/src/utils3d/numpy/shaders/texture.vsh
new file mode 100644
index 0000000000000000000000000000000000000000..f96c6b14a8931fbcd5f4ca22ea917b9c8f80f195
--- /dev/null
+++ b/src/utils3d/numpy/shaders/texture.vsh
@@ -0,0 +1,9 @@
+#version 330 core
+
+in vec2 in_vert;
+out vec2 scr_coord;
+
+void main() {
+ scr_coord = in_vert * 0.5 + 0.5;
+ gl_Position = vec4(in_vert, 0., 1.);
+}
\ No newline at end of file
diff --git a/src/utils3d/numpy/shaders/vertex_attribute.fsh b/src/utils3d/numpy/shaders/vertex_attribute.fsh
new file mode 100644
index 0000000000000000000000000000000000000000..54409764c5600ee190db89313b07dd91b940d6eb
--- /dev/null
+++ b/src/utils3d/numpy/shaders/vertex_attribute.fsh
@@ -0,0 +1,9 @@
+#version 330
+
+in vecN v_attr;
+
+out vecN f_attr;
+
+void main() {
+ f_attr = v_attr;
+}
diff --git a/src/utils3d/numpy/shaders/vertex_attribute.vsh b/src/utils3d/numpy/shaders/vertex_attribute.vsh
new file mode 100644
index 0000000000000000000000000000000000000000..7c94f91aaabfd714a47a194b93f8e53bf63577f5
--- /dev/null
+++ b/src/utils3d/numpy/shaders/vertex_attribute.vsh
@@ -0,0 +1,13 @@
+#version 330
+
+uniform mat4 u_mvp;
+
+in vec3 i_position;
+in vecN i_attr;
+
+out vecN v_attr;
+
+void main() {
+ gl_Position = u_mvp * vec4(i_position, 1.0);
+ v_attr = i_attr;
+}
diff --git a/src/utils3d/numpy/spline.py b/src/utils3d/numpy/spline.py
new file mode 100644
index 0000000000000000000000000000000000000000..03c664136bc3734215d37669a3446c248dffe097
--- /dev/null
+++ b/src/utils3d/numpy/spline.py
@@ -0,0 +1,82 @@
+from typing import *
+
+import numpy as np
+
+
+__all__ = ['linear_spline_interpolate']
+
+
+def linear_spline_interpolate(x: np.ndarray, t: np.ndarray, s: np.ndarray, extrapolation_mode: Literal['constant', 'linear'] = 'constant') -> np.ndarray:
+ """
+ Linear spline interpolation.
+
+ ### Parameters:
+ - `x`: np.ndarray, shape (n, d): the values of data points.
+ - `t`: np.ndarray, shape (n,): the times of the data points.
+ - `s`: np.ndarray, shape (m,): the times to be interpolated.
+ - `extrapolation_mode`: str, the mode of extrapolation. 'constant' means extrapolate the boundary values, 'linear' means extrapolate linearly.
+
+ ### Returns:
+ - `y`: np.ndarray, shape (..., m, d): the interpolated values.
+ """
+ i = np.searchsorted(t, s, side='left')
+ if extrapolation_mode == 'constant':
+ prev = np.clip(i - 1, 0, len(t) - 1)
+ suc = np.clip(i, 0, len(t) - 1)
+ elif extrapolation_mode == 'linear':
+ prev = np.clip(i - 1, 0, len(t) - 2)
+ suc = np.clip(i, 1, len(t) - 1)
+ else:
+ raise ValueError(f'Invalid extrapolation_mode: {extrapolation_mode}')
+
+ u = (s - t[prev]) / np.maximum(t[suc] - t[prev], 1e-12)
+ y = u * x[suc] + (1 - u) * x[prev]
+
+ return y
+
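+# Example for `linear_spline_interpolate` (illustrative):
+#   linear_spline_interpolate(x=np.array([[0., 0.], [1., 2.]]), t=np.array([0., 1.]), s=np.array([0.5]))
+#   returns [[0.5, 1.0]]; with extrapolation_mode='constant', query times outside
+#   [t[0], t[-1]] clamp to the endpoint values.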
+
+
+def _solve_tridiagonal(a: np.ndarray, b: np.ndarray, c: np.ndarray, d: np.ndarray) -> np.ndarray:
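+    # Thomas algorithm: solve a tridiagonal system along the last axis, where `a` is the
+    # sub-diagonal, `b` the main diagonal, `c` the super-diagonal and `d` the right-hand
+    # side, using forward elimination followed by back substitution.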
+ n = b.shape[-1]
+ cc = np.zeros_like(b)
+ dd = np.zeros_like(b)
+ cc[..., 0] = c[..., 0] / b[..., 0]
+ dd[..., 0] = d[..., 0] / b[..., 0]
+ for i in range(1, n):
+ cc[..., i] = c[..., i] / (b[..., i] - a[..., i - 1] * cc[..., i - 1])
+ dd[..., i] = (d[..., i] - a[..., i - 1] * dd[..., i - 1]) / (b[..., i] - a[..., i - 1] * cc[..., i - 1])
+ x = np.zeros_like(b)
+ x[..., -1] = dd[..., -1]
+ for i in range(n - 2, -1, -1):
+ x[..., i] = dd[..., i] - cc[..., i] * x[..., i + 1]
+ return x
+
+
+def cubic_spline_interpolate(x: np.ndarray, t: np.ndarray, s: np.ndarray, v0: np.ndarray = None, vn: np.ndarray = None) -> np.ndarray:
+ """
+ Cubic spline interpolation.
+
+ ### Parameters:
+    - `x`: np.ndarray, shape (..., n): the values of the data points at the knots.
+ - `t`: np.ndarray, shape (n,): the knot vector. NOTE: t must be sorted in ascending order.
+    - `s`: np.ndarray, shape (m,): the times at which to evaluate the spline.
+ - `v0`: np.ndarray, shape (...,): the value of the derivative at the first knot, as the boundary condition. If None, it is set to zero.
+ - `vn`: np.ndarray, shape (...,): the value of the derivative at the last knot, as the boundary condition. If None, it is set to zero.
+
+ ### Returns:
+ - `y`: np.ndarray, shape (..., m): the interpolated values.
+ """
+    if v0 is None:
+        v0 = np.zeros_like(x[..., 0])
+    if vn is None:
+        vn = np.zeros_like(x[..., -1])
+    h = t[..., 1:] - t[..., :-1]
+ mu = h[..., :-1] / (h[..., :-1] + h[..., 1:])
+ la = 1 - mu
+ d = (x[..., 1:] - x[..., :-1]) / h
+ d = 6 * (d[..., 1:] - d[..., :-1]) / (t[..., 2:] - t[..., :-2])
+
+ mu = np.concatenate([mu, np.ones_like(mu[..., :1])], axis=-1)
+ la = np.concatenate([np.ones_like(la[..., :1]), la], axis=-1)
+    d = np.concatenate([(6 * ((x[..., 1] - x[..., 0]) / h[0] - v0) / h[0])[..., None], d, (6 * (vn - (x[..., -1] - x[..., -2]) / h[-1]) / h[-1])[..., None]], axis=-1)
+
+    M = _solve_tridiagonal(mu, np.full_like(d, fill_value=2), np.concatenate([la, np.zeros_like(la[..., :1])], axis=-1), d)
+
+    i = np.clip(np.searchsorted(t, s, side='left'), 1, len(t) - 1)
+    h_i = t[i] - t[i - 1]
+    # Evaluate the cubic on segment [t[i-1], t[i]] from the second derivatives M
+    y = M[..., i - 1] * (t[i] - s) ** 3 / (6 * h_i) \
+        + M[..., i] * (s - t[i - 1]) ** 3 / (6 * h_i) \
+        + (x[..., i - 1] - M[..., i - 1] * h_i ** 2 / 6) * (t[i] - s) / h_i \
+        + (x[..., i] - M[..., i] * h_i ** 2 / 6) * (s - t[i - 1]) / h_i
+
+    return y
+
diff --git a/src/utils3d/numpy/transforms.py b/src/utils3d/numpy/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e2418540b4a1d152a5e072900079f59f474a139
--- /dev/null
+++ b/src/utils3d/numpy/transforms.py
@@ -0,0 +1,1104 @@
+import numpy as np
+from typing import *
+from numbers import Number
+from ._helpers import batched
+from .._helpers import no_warnings
+
+
+__all__ = [
+ 'perspective',
+ 'perspective_from_fov',
+ 'perspective_from_fov_xy',
+ 'intrinsics_from_focal_center',
+ 'intrinsics_from_fov',
+ 'fov_to_focal',
+ 'focal_to_fov',
+ 'intrinsics_to_fov',
+ 'view_look_at',
+ 'extrinsics_look_at',
+ 'perspective_to_intrinsics',
+ 'perspective_to_near_far',
+ 'intrinsics_to_perspective',
+ 'extrinsics_to_view',
+ 'view_to_extrinsics',
+ 'normalize_intrinsics',
+ 'crop_intrinsics',
+ 'pixel_to_uv',
+ 'pixel_to_ndc',
+ 'uv_to_pixel',
+ 'project_depth',
+ 'depth_buffer_to_linear',
+ 'unproject_cv',
+ 'unproject_gl',
+ 'project_cv',
+ 'project_gl',
+ 'quaternion_to_matrix',
+ 'axis_angle_to_matrix',
+ 'matrix_to_quaternion',
+ 'extrinsics_to_essential',
+ 'euler_axis_angle_rotation',
+ 'euler_angles_to_matrix',
+ 'skew_symmetric',
+ 'rotation_matrix_from_vectors',
+ 'ray_intersection',
+ 'se3_matrix',
+ 'slerp_quaternion',
+ 'slerp_vector',
+ 'lerp',
+ 'lerp_se3_matrix',
+ 'piecewise_lerp',
+ 'piecewise_lerp_se3_matrix',
+ 'apply_transform'
+]
+
+
+@batched(0,0,0,0)
+def perspective(
+ fov_y: Union[float, np.ndarray],
+ aspect: Union[float, np.ndarray],
+ near: Union[float, np.ndarray],
+ far: Union[float, np.ndarray]
+) -> np.ndarray:
+ """
+ Get OpenGL perspective matrix
+
+ Args:
+ fov_y (float | np.ndarray): field of view in y axis
+ aspect (float | np.ndarray): aspect ratio
+ near (float | np.ndarray): near plane to clip
+ far (float | np.ndarray): far plane to clip
+
+ Returns:
+ (np.ndarray): [..., 4, 4] perspective matrix
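+
+ Example (illustrative, values are arbitrary):
+ >>> P = perspective(np.deg2rad(60.0), 16 / 9, 0.1, 100.0)  # OpenGL projection matrix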
+ """
+ N = fov_y.shape[0]
+ ret = np.zeros((N, 4, 4), dtype=fov_y.dtype)
+ ret[:, 0, 0] = 1. / (np.tan(fov_y / 2) * aspect)
+ ret[:, 1, 1] = 1. / (np.tan(fov_y / 2))
+ ret[:, 2, 2] = (near + far) / (near - far)
+ ret[:, 2, 3] = 2. * near * far / (near - far)
+ ret[:, 3, 2] = -1.
+ return ret
+
+
+def perspective_from_fov(
+ fov: Union[float, np.ndarray],
+ width: Union[int, np.ndarray],
+ height: Union[int, np.ndarray],
+ near: Union[float, np.ndarray],
+ far: Union[float, np.ndarray]
+) -> np.ndarray:
+ """
+ Get OpenGL perspective matrix from field of view in largest dimension
+
+ Args:
+ fov (float | np.ndarray): field of view in largest dimension
+ width (int | np.ndarray): image width
+ height (int | np.ndarray): image height
+ near (float | np.ndarray): near plane to clip
+ far (float | np.ndarray): far plane to clip
+
+ Returns:
+ (np.ndarray): [..., 4, 4] perspective matrix
+ """
+ fov_y = 2 * np.arctan(np.tan(fov / 2) * height / np.maximum(width, height))
+ aspect = width / height
+ return perspective(fov_y, aspect, near, far)
+
+
+def perspective_from_fov_xy(
+ fov_x: Union[float, np.ndarray],
+ fov_y: Union[float, np.ndarray],
+ near: Union[float, np.ndarray],
+ far: Union[float, np.ndarray]
+) -> np.ndarray:
+ """
+ Get OpenGL perspective matrix from field of view in x and y axis
+
+ Args:
+ fov_x (float | np.ndarray): field of view in x axis
+ fov_y (float | np.ndarray): field of view in y axis
+ near (float | np.ndarray): near plane to clip
+ far (float | np.ndarray): far plane to clip
+
+ Returns:
+ (np.ndarray): [..., 4, 4] perspective matrix
+ """
+ aspect = np.tan(fov_x / 2) / np.tan(fov_y / 2)
+ return perspective(fov_y, aspect, near, far)
+
+
+def intrinsics_from_focal_center(
+ fx: Union[float, np.ndarray],
+ fy: Union[float, np.ndarray],
+ cx: Union[float, np.ndarray],
+ cy: Union[float, np.ndarray],
+ dtype: Optional[np.dtype] = np.float32
+) -> np.ndarray:
+ """
+ Get OpenCV intrinsics matrix
+
+ Returns:
+ (np.ndarray): [..., 3, 3] OpenCV intrinsics matrix
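+
+ Example (illustrative, pixel-unit values are arbitrary):
+ >>> K = intrinsics_from_focal_center(fx=500.0, fy=500.0, cx=320.0, cy=240.0)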
+ """
+ if any(isinstance(x, np.ndarray) for x in (fx, fy, cx, cy)):
+ dtype = np.result_type(fx, fy, cx, cy)
+ fx, fy, cx, cy = np.broadcast_arrays(fx, fy, cx, cy)
+ ret = np.zeros((*fx.shape, 3, 3), dtype=dtype)
+ ret[..., 0, 0] = fx
+ ret[..., 1, 1] = fy
+ ret[..., 0, 2] = cx
+ ret[..., 1, 2] = cy
+ ret[..., 2, 2] = 1.
+ return ret
+
+
+def intrinsics_from_fov(
+ fov_max: Union[float, np.ndarray] = None,
+ fov_min: Union[float, np.ndarray] = None,
+ fov_x: Union[float, np.ndarray] = None,
+ fov_y: Union[float, np.ndarray] = None,
+ width: Union[int, np.ndarray] = None,
+ height: Union[int, np.ndarray] = None,
+) -> np.ndarray:
+ """
+ Get normalized OpenCV intrinsics matrix from given field of view.
+ You can provide either fov_max, fov_min, fov_x or fov_y
+
+ Args:
+ width (int | np.ndarray): image width
+ height (int | np.ndarray): image height
+ fov_max (float | np.ndarray): field of view in largest dimension
+ fov_min (float | np.ndarray): field of view in smallest dimension
+ fov_x (float | np.ndarray): field of view in x axis
+ fov_y (float | np.ndarray): field of view in y axis
+
+ Returns:
+ (np.ndarray): [..., 3, 3] OpenCV intrinsics matrix
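+
+ Example (illustrative):
+ >>> K = intrinsics_from_fov(fov_max=np.deg2rad(60), width=640, height=480)  # normalized intrinsics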
+ """
+ if fov_max is not None:
+ fx = np.maximum(width, height) / width / (2 * np.tan(fov_max / 2))
+ fy = np.maximum(width, height) / height / (2 * np.tan(fov_max / 2))
+ elif fov_min is not None:
+ fx = np.minimum(width, height) / width / (2 * np.tan(fov_min / 2))
+ fy = np.minimum(width, height) / height / (2 * np.tan(fov_min / 2))
+ elif fov_x is not None and fov_y is not None:
+ fx = 1 / (2 * np.tan(fov_x / 2))
+ fy = 1 / (2 * np.tan(fov_y / 2))
+ elif fov_x is not None:
+ fx = 1 / (2 * np.tan(fov_x / 2))
+ fy = fx * width / height
+ elif fov_y is not None:
+ fy = 1 / (2 * np.tan(fov_y / 2))
+ fx = fy * height / width
+ cx = 0.5
+ cy = 0.5
+ ret = intrinsics_from_focal_center(fx, fy, cx, cy)
+ return ret
+
+
+def focal_to_fov(focal: np.ndarray):
+ return 2 * np.arctan(0.5 / focal)
+
+
+def fov_to_focal(fov: np.ndarray):
+ return 0.5 / np.tan(fov / 2)
+
+
+def intrinsics_to_fov(intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ fov_x = focal_to_fov(intrinsics[..., 0, 0])
+ fov_y = focal_to_fov(intrinsics[..., 1, 1])
+ return fov_x, fov_y
+
+
+@batched(1,1,1)
+def view_look_at(
+ eye: np.ndarray,
+ look_at: np.ndarray,
+ up: np.ndarray
+ ) -> np.ndarray:
+ """
+ Get OpenGL view matrix looking at something
+
+ Args:
+ eye (np.ndarray): [..., 3] the eye position
+ look_at (np.ndarray): [..., 3] the position to look at
+ up (np.ndarray): [..., 3] head up direction (y axis in screen space). Not necessarily orthogonal to the view direction
+
+ Returns:
+ (np.ndarray): [..., 4, 4], view matrix
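+
+ Example (illustrative):
+ >>> V = view_look_at(np.array([0., 0., 3.]), np.array([0., 0., 0.]), np.array([0., 1., 0.]))  # camera at +z looking at the origin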
+ """
+ z = eye - look_at
+ x = np.cross(up, z)
+ y = np.cross(z, x)
+ # x = np.cross(y, z)
+ x = x / np.linalg.norm(x, axis=-1, keepdims=True)
+ y = y / np.linalg.norm(y, axis=-1, keepdims=True)
+ z = z / np.linalg.norm(z, axis=-1, keepdims=True)
+ R = np.stack([x, y, z], axis=-2)
+ t = -np.matmul(R, eye[..., None])
+ return np.concatenate([
+ np.concatenate([R, t], axis=-1),
+ np.array([[[0., 0., 0., 1.]]], dtype=eye.dtype).repeat(eye.shape[0], axis=0)
+ ], axis=-2)
+
+
+@batched(1,1,1)
+def extrinsics_look_at(
+ eye: np.ndarray,
+ look_at: np.ndarray,
+ up: np.ndarray
+) -> np.ndarray:
+ """
+ Get OpenCV extrinsics matrix looking at something
+
+ Args:
+ eye (np.ndarray): [..., 3] the eye position
+ look_at (np.ndarray): [..., 3] the position to look at
+ up (np.ndarray): [..., 3] head up direction (-y axis in screen space). Not necessarily orthogonal to the view direction
+
+ Returns:
+ (np.ndarray): [..., 4, 4], extrinsics matrix
+ """
+ z = look_at - eye
+ x = np.cross(-up, z)
+ y = np.cross(z, x)
+ # x = np.cross(y, z)
+ x = x / np.linalg.norm(x, axis=-1, keepdims=True)
+ y = y / np.linalg.norm(y, axis=-1, keepdims=True)
+ z = z / np.linalg.norm(z, axis=-1, keepdims=True)
+ R = np.stack([x, y, z], axis=-2)
+ t = -np.matmul(R, eye[..., None])
+ return np.concatenate([
+ np.concatenate([R, t], axis=-1),
+ np.array([[[0., 0., 0., 1.]]], dtype=eye.dtype).repeat(eye.shape[0], axis=0)
+ ], axis=-2)
+
+
+def perspective_to_intrinsics(
+ perspective: np.ndarray
+) -> np.ndarray:
+ """
+ OpenGL perspective matrix to OpenCV intrinsics
+
+ Args:
+ perspective (np.ndarray): [..., 4, 4] OpenGL perspective matrix
+
+ Returns:
+ (np.ndarray): shape [..., 3, 3] OpenCV intrinsics
+ """
+ ret = np.array([[0.5, 0., 0.5], [0., -0.5, 0.5], [0., 0., 1.]], dtype=perspective.dtype) \
+ @ perspective[..., [0, 1, 3], :3] \
+ @ np.diag(np.array([1, -1, -1], dtype=perspective.dtype))
+ return ret
+
+
+def perspective_to_near_far(perspective: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Get near and far planes from OpenGL perspective matrix
+
+ Args:
+ perspective (np.ndarray): [..., 4, 4] OpenGL perspective matrix
+
+ Returns:
+ near (np.ndarray): [...] near plane
+ far (np.ndarray): [...] far plane
+ """
+ a, b = perspective[..., 2, 2], perspective[..., 2, 3]
+ near, far = b / (a - 1), b / (a + 1)
+ return near, far
+
+
+@batched(2,0,0)
+def intrinsics_to_perspective(
+ intrinsics: np.ndarray,
+ near: Union[float, np.ndarray],
+ far: Union[float, np.ndarray],
+) -> np.ndarray:
+ """
+ OpenCV intrinsics to OpenGL perspective matrix
+ NOTE: this currently does not work for tile-shifting intrinsics
+
+ Args:
+ intrinsics (np.ndarray): [..., 3, 3] OpenCV intrinsics matrix
+ near (float | np.ndarray): [...] near plane to clip
+ far (float | np.ndarray): [...] far plane to clip
+ Returns:
+ (np.ndarray): [..., 4, 4] OpenGL perspective matrix
+ """
+ N = intrinsics.shape[0]
+ fx, fy = intrinsics[:, 0, 0], intrinsics[:, 1, 1]
+ cx, cy = intrinsics[:, 0, 2], intrinsics[:, 1, 2]
+ ret = np.zeros((N, 4, 4), dtype=intrinsics.dtype)
+ ret[:, 0, 0] = 2 * fx
+ ret[:, 1, 1] = 2 * fy
+ ret[:, 0, 2] = -2 * cx + 1
+ ret[:, 1, 2] = 2 * cy - 1
+ ret[:, 2, 2] = (near + far) / (near - far)
+ ret[:, 2, 3] = 2. * near * far / (near - far)
+ ret[:, 3, 2] = -1.
+ return ret
+
+
+@batched(2)
+def extrinsics_to_view(
+ extrinsics: np.ndarray
+ ) -> np.ndarray:
+ """
+ OpenCV camera extrinsics to OpenGL view matrix
+
+ Args:
+ extrinsics (np.ndarray): [..., 4, 4] OpenCV camera extrinsics matrix
+
+ Returns:
+ (np.ndarray): [..., 4, 4] OpenGL view matrix
+ """
+ return extrinsics * np.array([1, -1, -1, 1], dtype=extrinsics.dtype)[:, None]
+
+
+@batched(2)
+def view_to_extrinsics(
+ view: np.ndarray
+ ) -> np.ndarray:
+ """
+ OpenGL view matrix to OpenCV camera extrinsics
+
+ Args:
+ view (np.ndarray): [..., 4, 4] OpenGL view matrix
+
+ Returns:
+ (np.ndarray): [..., 4, 4] OpenCV camera extrinsics matrix
+ """
+ return view * np.array([1, -1, -1, 1], dtype=view.dtype)[:, None]
+
+
+@batched(2, 0, 0, None)
+def normalize_intrinsics(
+ intrinsics: np.ndarray,
+ width: Union[int, np.ndarray],
+ height: Union[int, np.ndarray],
+ integer_pixel_centers: bool = True
+) -> np.ndarray:
+ """
+ Normalize intrinsics from pixel coordinates to uv coordinates
+
+ Args:
+ intrinsics (np.ndarray): [..., 3, 3] camera intrinsics(s) to normalize
+ width (int | np.ndarray): [...] image width(s)
+ height (int | np.ndarray): [...] image height(s)
+ integer_pixel_centers (bool): whether the integer pixel coordinates are at the center of the pixel. If False, the integer coordinates are at the left-top corner of the pixel.
+
+ Returns:
+ (np.ndarray): [..., 3, 3] normalized camera intrinsics(s)
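+
+ Example (illustrative, pixel-unit intrinsics converted to uv units):
+ >>> K_uv = normalize_intrinsics(intrinsics_from_focal_center(500., 500., 320., 240.), width=640, height=480)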
+ """
+ zeros = np.zeros_like(width)
+ ones = np.ones_like(width)
+ if integer_pixel_centers:
+ transform = np.stack([
+ 1 / width, zeros, 0.5 / width,
+ zeros, 1 / height, 0.5 / height,
+ zeros, zeros, ones
+ ], axis=-1).reshape(*zeros.shape, 3, 3)
+ else:
+ transform = np.stack([
+ 1 / width, zeros, zeros,
+ zeros, 1 / height, zeros,
+ zeros, zeros, ones
+ ], axis=-1).reshape(*zeros.shape, 3, 3)
+ return transform @ intrinsics
+
+
+@batched(2,0,0,0,0,0,0)
+def crop_intrinsics(
+ intrinsics: np.ndarray,
+ width: Union[int, np.ndarray],
+ height: Union[int, np.ndarray],
+ left: Union[int, np.ndarray],
+ top: Union[int, np.ndarray],
+ crop_width: Union[int, np.ndarray],
+ crop_height: Union[int, np.ndarray]
+) -> np.ndarray:
+ """
+ Evaluate the new intrinsics(s) after crop the image: cropped_img = img[top:top+crop_height, left:left+crop_width]
+
+ Args:
+ intrinsics (np.ndarray): [..., 3, 3] camera intrinsics(s) to crop
+ width (int | np.ndarray): [...] image width(s)
+ height (int | np.ndarray): [...] image height(s)
+ left (int | np.ndarray): [...] left crop boundary
+ top (int | np.ndarray): [...] top crop boundary
+ crop_width (int | np.ndarray): [...] crop width
+ crop_height (int | np.ndarray): [...] crop height
+
+ Returns:
+ (np.ndarray): [..., 3, 3] cropped camera intrinsics(s)
+ """
+ zeros = np.zeros_like(width)
+ ones = np.ones_like(width)
+ transform = np.stack([
+ width / crop_width, zeros, -left / crop_width,
+ zeros, height / crop_height, -top / crop_height,
+ zeros, zeros, ones
+ ], axis=-1).reshape(*zeros.shape, 3, 3)
+ return transform @ intrinsics
+
+
+@batched(1,0,0)
+def pixel_to_uv(
+ pixel: np.ndarray,
+ width: Union[int, np.ndarray],
+ height: Union[int, np.ndarray]
+) -> np.ndarray:
+ """
+ Args:
+ pixel (np.ndarray): [..., 2] pixel coordinates defined in image space, x range is (0, W - 1), y range is (0, H - 1)
+ width (int | np.ndarray): [...] image width(s)
+ height (int | np.ndarray): [...] image height(s)
+
+ Returns:
+ (np.ndarray): [..., 2] uv coordinates, the range is (0, 1)
+ """
+ if not np.issubdtype(pixel.dtype, np.floating):
+ pixel = pixel.astype(np.float32)
+ dtype = pixel.dtype
+ uv = (pixel + np.array(0.5, dtype=dtype)) / np.stack([width, height], axis=-1)
+ return uv
+
+
+@batched(1,0,0)
+def uv_to_pixel(
+ uv: np.ndarray,
+ width: Union[int, np.ndarray],
+ height: Union[int, np.ndarray]
+) -> np.ndarray:
+ """
+ Args:
+ uv (np.ndarray): [..., 2] uv coordinates, the range is (0, 1)
+ width (int | np.ndarray): [...] image width(s)
+ height (int | np.ndarray): [...] image height(s)
+
+ Returns:
+ (np.ndarray): [..., 2] pixel coordinates defined in image space, x range is (0, W - 1), y range is (0, H - 1)
+ """
+ pixel = uv * np.stack([width, height], axis=-1).astype(uv.dtype) - 0.5
+ return pixel
+
+
+@batched(1,0,0)
+def pixel_to_ndc(
+ pixel: np.ndarray,
+ width: Union[int, np.ndarray],
+ height: Union[int, np.ndarray]
+) -> np.ndarray:
+ """
+ Args:
+ pixel (np.ndarray): [..., 2] pixel coordinates defined in image space, x range is (0, W - 1), y range is (0, H - 1)
+ width (int | np.ndarray): [...] image width(s)
+ height (int | np.ndarray): [...] image height(s)
+
+ Returns:
+ (np.ndarray): [..., 2] ndc coordinates, the range is (-1, 1)
+ """
+ if not np.issubdtype(pixel.dtype, np.floating):
+ pixel = pixel.astype(np.float32)
+ dtype = pixel.dtype
+ ndc = (pixel + np.array(0.5, dtype=dtype)) / np.stack([width, height], axis=-1) * np.array([2, -2], dtype=dtype) \
+ + np.array([-1, 1], dtype=dtype)
+ return ndc
+
+
+@batched(0,0,0)
+def project_depth(
+ depth: np.ndarray,
+ near: Union[float, np.ndarray],
+ far: Union[float, np.ndarray]
+) -> np.ndarray:
+ """
+ Project linear depth to depth value in screen space
+
+ Args:
+ depth (np.ndarray): [...] depth value
+ near (float | np.ndarray): [...] near plane to clip
+ far (float | np.ndarray): [...] far plane to clip
+
+ Returns:
+ (np.ndarray): [..., 1] depth value in screen space, value ranging in [0, 1]
+ """
+ return (far - near * far / depth) / (far - near)
+
+
+@batched(0,0,0)
+def depth_buffer_to_linear(
+ depth_buffer: np.ndarray,
+ near: Union[float, np.ndarray],
+ far: Union[float, np.ndarray]
+) -> np.ndarray:
+ """
+ OpenGL depth buffer to linear depth
+
+ Args:
+ depth_buffer (np.ndarray): [...] depth value
+ near (float | np.ndarray): [...] near plane to clip
+ far (float | np.ndarray): [...] far plane to clip
+
+ Returns:
+ (np.ndarray): [..., 1] linear depth
+ """
+ return near * far / (far - (far - near) * depth_buffer)
+
+
+@batched(2,2,2,2)
+def project_gl(
+ points: np.ndarray,
+ model: np.ndarray = None,
+ view: np.ndarray = None,
+ perspective: np.ndarray = None
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Project 3D points to 2D following the OpenGL convention (except for row major matrices)
+
+ Args:
+ points (np.ndarray): [..., N, 3] or [..., N, 4] 3D points to project, if the last
+ dimension is 4, the points are assumed to be in homogeneous coordinates
+ model (np.ndarray): [..., 4, 4] model matrix
+ view (np.ndarray): [..., 4, 4] view matrix
+ perspective (np.ndarray): [..., 4, 4] perspective matrix
+
+ Returns:
+ scr_coord (np.ndarray): [..., N, 3] screen space coordinates, value ranging in [0, 1].
+ The origin (0., 0., 0.) corresponds to the left & bottom & nearest
+ linear_depth (np.ndarray): [..., N] linear depth
+ """
+ assert perspective is not None, "perspective matrix is required"
+ if points.shape[-1] == 3:
+ points = np.concatenate([points, np.ones_like(points[..., :1])], axis=-1)
+ if model is not None:
+ points = points @ model.swapaxes(-1, -2)
+ if view is not None:
+ points = points @ view.swapaxes(-1, -2)
+ clip_coord = points @ perspective.swapaxes(-1, -2)
+ ndc_coord = clip_coord[..., :3] / clip_coord[..., 3:]
+ scr_coord = ndc_coord * 0.5 + 0.5
+ linear_depth = clip_coord[..., 3]
+ return scr_coord, linear_depth
+
+
+@batched(2,2,2)
+def project_cv(
+ points: np.ndarray,
+ extrinsics: np.ndarray = None,
+ intrinsics: np.ndarray = None
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Project 3D points to 2D following the OpenCV convention
+
+ Args:
+ points (np.ndarray): [..., N, 3] or [..., N, 4] 3D points to project, if the last
+ dimension is 4, the points are assumed to be in homogeneous coordinates
+ extrinsics (np.ndarray): [..., 4, 4] extrinsics matrix
+ intrinsics (np.ndarray): [..., 3, 3] intrinsics matrix
+
+ Returns:
+ uv_coord (np.ndarray): [..., N, 2] uv coordinates, value ranging in [0, 1].
+ The origin (0., 0.) corresponds to the left & top
+ linear_depth (np.ndarray): [..., N] linear depth
+ """
+ assert intrinsics is not None, "intrinsics matrix is required"
+ if points.shape[-1] == 3:
+ points = np.concatenate([points, np.ones_like(points[..., :1])], axis=-1)
+ if extrinsics is not None:
+ points = points @ extrinsics.swapaxes(-1, -2)
+ points = points[..., :3] @ intrinsics.swapaxes(-1, -2)
+ with no_warnings():
+ uv_coord = points[..., :2] / points[..., 2:]
+ linear_depth = points[..., 2]
+ return uv_coord, linear_depth
+
+
+@batched(2,2,2,2)
+def unproject_gl(
+ screen_coord: np.ndarray,
+ model: np.ndarray = None,
+ view: np.ndarray = None,
+ perspective: np.ndarray = None
+ ) -> np.ndarray:
+ """
+ Unproject screen space coordinates to 3D view space following the OpenGL convention (except for row major matrices)
+
+ Args:
+ screen_coord (np.ndarray): [..., N, 3] screen space coordinates, value ranging in [0, 1].
+ The origin (0., 0., 0.) corresponds to the left & bottom & nearest
+ model (np.ndarray): [..., 4, 4] model matrix
+ view (np.ndarray): [..., 4, 4] view matrix
+ perspective (np.ndarray): [..., 4, 4] perspective matrix
+
+ Returns:
+ points (np.ndarray): [..., N, 3] 3d points
+ """
+ assert perspective is not None, "perspective matrix is required"
+ ndc_xy = screen_coord * 2 - 1
+ clip_coord = np.concatenate([ndc_xy, np.ones_like(ndc_xy[..., :1])], axis=-1)
+ transform = perspective
+ if view is not None:
+ transform = transform @ view
+ if model is not None:
+ transform = transform @ model
+ transform = np.linalg.inv(transform)
+ points = clip_coord @ transform.swapaxes(-1, -2)
+ points = points[..., :3] / points[..., 3:]
+ return points
+
+
+@batched(2,1,2,2)
+def unproject_cv(
+ uv_coord: np.ndarray,
+ depth: np.ndarray = None,
+ extrinsics: np.ndarray = None,
+ intrinsics: np.ndarray = None
+) -> np.ndarray:
+ """
+ Unproject uv coordinates to 3D view space following the OpenCV convention
+
+ Args:
+ uv_coord (np.ndarray): [..., N, 2] uv coordinates, value ranging in [0, 1].
+ The origin (0., 0.) corresponds to the left & top
+ depth (np.ndarray): [..., N] depth value
+ extrinsics (np.ndarray): [..., 4, 4] extrinsics matrix
+ intrinsics (np.ndarray): [..., 3, 3] intrinsics matrix
+
+ Returns:
+ points (np.ndarray): [..., N, 3] 3d points
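+
+ Example (illustrative round trip with project_cv; `points` is an assumed (N, 3) array and `intrinsics` is normalized):
+ >>> uv, depth = project_cv(points, extrinsics, intrinsics)
+ >>> points_recovered = unproject_cv(uv, depth, extrinsics, intrinsics)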
+ """
+ assert intrinsics is not None, "intrinsics matrix is required"
+ points = np.concatenate([uv_coord, np.ones_like(uv_coord[..., :1])], axis=-1)
+ points = points @ np.linalg.inv(intrinsics).swapaxes(-1, -2)
+ if depth is not None:
+ points = points * depth[..., None]
+ if extrinsics is not None:
+ points = np.concatenate([points, np.ones_like(points[..., :1])], axis=-1)
+ points = (points @ np.linalg.inv(extrinsics).swapaxes(-1, -2))[..., :3]
+ return points
+
+
+def quaternion_to_matrix(quaternion: np.ndarray, eps: float = 1e-12) -> np.ndarray:
+ """Converts a batch of quaternions (w, x, y, z) to rotation matrices
+
+ Args:
+ quaternion (np.ndarray): shape (..., 4), the quaternions to convert
+
+ Returns:
+ np.ndarray: shape (..., 3, 3), the rotation matrices corresponding to the given quaternions
+ """
+ assert quaternion.shape[-1] == 4
+ quaternion = quaternion / np.linalg.norm(quaternion, axis=-1, keepdims=True).clip(min=eps)
+ w, x, y, z = quaternion[..., 0], quaternion[..., 1], quaternion[..., 2], quaternion[..., 3]
+ zeros = np.zeros_like(w)
+ I = np.eye(3, dtype=quaternion.dtype)
+ xyz = quaternion[..., 1:]
+ A = xyz[..., :, None] * xyz[..., None, :] - I * (xyz ** 2).sum(axis=-1)[..., None, None]
+ B = np.stack([
+ zeros, -z, y,
+ z, zeros, -x,
+ -y, x, zeros
+ ], axis=-1).reshape(*quaternion.shape[:-1], 3, 3)
+ rot_mat = I + 2 * (A + w[..., None, None] * B)
+ return rot_mat
+
+
+def matrix_to_quaternion(rot_mat: np.ndarray, eps: float = 1e-12) -> np.ndarray:
+ """Convert 3x3 rotation matrix to quaternion (w, x, y, z)
+
+ Args:
+ rot_mat (np.ndarray): shape (..., 3, 3), the rotation matrices to convert
+
+ Returns:
+ np.ndarray: shape (..., 4), the quaternions corresponding to the given rotation matrices
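+
+ Example (illustrative round trip):
+ >>> q = matrix_to_quaternion(np.eye(3))  # -> array([1., 0., 0., 0.])
+ >>> R = quaternion_to_matrix(q)          # -> identity rotation again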
+ """
+ # Extract the diagonal and off-diagonal elements of the rotation matrix
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = [rot_mat[..., i, j] for i in range(3) for j in range(3)]
+
+ diag = np.diagonal(rot_mat, axis1=-2, axis2=-1)
+ M = np.array([
+ [1, 1, 1],
+ [1, -1, -1],
+ [-1, 1, -1],
+ [-1, -1, 1]
+ ], dtype=rot_mat.dtype)
+ wxyz = 0.5 * np.clip(1 + diag @ M.T, 0.0, None) ** 0.5
+ max_idx = np.argmax(wxyz, axis=-1)
+ xw = np.sign(m21 - m12)
+ yw = np.sign(m02 - m20)
+ zw = np.sign(m10 - m01)
+ yz = np.sign(m21 + m12)
+ xz = np.sign(m02 + m20)
+ xy = np.sign(m01 + m10)
+ ones = np.ones_like(xw)
+ sign = np.where(
+ max_idx[..., None] == 0,
+ np.stack([ones, xw, yw, zw], axis=-1),
+ np.where(
+ max_idx[..., None] == 1,
+ np.stack([xw, ones, xy, xz], axis=-1),
+ np.where(
+ max_idx[..., None] == 2,
+ np.stack([yw, xy, ones, yz], axis=-1),
+ np.stack([zw, xz, yz, ones], axis=-1)
+ )
+ )
+ )
+ quat = sign * wxyz
+ quat = quat / np.linalg.norm(quat, axis=-1, keepdims=True).clip(min=eps)
+ return quat
+
+
+def extrinsics_to_essential(extrinsics: np.ndarray):
+ """
+ Convert an extrinsics matrix `[[R, t], [0, 0, 0, 1]]` such that `x' = R x + t` to the essential matrix such that `x'^T E x = 0`
+
+ Args:
+ extrinsics (np.ndarray): [..., 4, 4] extrinsics matrix
+
+ Returns:
+ (np.ndarray): [..., 3, 3] essential matrix
+ """
+ assert extrinsics.shape[-2:] == (4, 4)
+ R = extrinsics[..., :3, :3]
+ t = extrinsics[..., :3, 3]
+ zeros = np.zeros_like(t[..., 0])
+ t_x = np.stack([
+ zeros, -t[..., 2], t[..., 1],
+ t[..., 2], zeros, -t[..., 0],
+ -t[..., 1], t[..., 0], zeros
+ ], axis=-1).reshape(*t.shape[:-1], 3, 3)
+ return t_x @ R
+
+
+def euler_axis_angle_rotation(axis: str, angle: np.ndarray) -> np.ndarray:
+ """
+ Return the rotation matrices for one of the rotations about an axis
+ of which Euler angles describe, for each value of the angle given.
+
+ Args:
+ axis: Axis label "X", "Y", or "Z".
+ angle: any shape tensor of Euler angles in radians
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+
+ cos = np.cos(angle)
+ sin = np.sin(angle)
+ one = np.ones_like(angle)
+ zero = np.zeros_like(angle)
+
+ if axis == "X":
+ R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
+ elif axis == "Y":
+ R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
+ elif axis == "Z":
+ R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
+ else:
+ raise ValueError("letter must be either X, Y or Z.")
+
+ return np.stack(R_flat, -1).reshape(angle.shape + (3, 3))
+
+
+def euler_angles_to_matrix(euler_angles: np.ndarray, convention: str = 'XYZ') -> np.ndarray:
+ """
+ Convert rotations given as Euler angles in radians to rotation matrices.
+
+ Args:
+ euler_angles: Euler angles in radians as ndarray of shape (..., 3), XYZ
+ convention: permutation of "X", "Y" or "Z", representing the order of Euler rotations to apply.
+
+ Returns:
+ Rotation matrices as ndarray of shape (..., 3, 3).
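+
+ Example (illustrative):
+ >>> R = euler_angles_to_matrix(np.array([0.1, 0.2, 0.3]), convention='ZYX')  # (3, 3)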
+ """
+ if euler_angles.shape[-1] != 3:
+ raise ValueError("Invalid input euler angles.")
+ if len(convention) != 3:
+ raise ValueError("Convention must have 3 letters.")
+ if convention[1] in (convention[0], convention[2]):
+ raise ValueError(f"Invalid convention {convention}.")
+ for letter in convention:
+ if letter not in ("X", "Y", "Z"):
+ raise ValueError(f"Invalid letter {letter} in convention string.")
+ matrices = [
+ euler_axis_angle_rotation(c, euler_angles[..., 'XYZ'.index(c)])
+ for c in convention
+ ]
+ return matrices[2] @ matrices[1] @ matrices[0]
+
+
+def skew_symmetric(v: np.ndarray):
+ "Skew symmetric matrix from a 3D vector"
+ assert v.shape[-1] == 3, "v must be 3D"
+ x, y, z = v[..., 0], v[..., 1], v[..., 2]
+ zeros = np.zeros_like(x)
+ return np.stack([
+ zeros, -z, y,
+ z, zeros, -x,
+ -y, x, zeros,
+ ], axis=-1).reshape(*v.shape[:-1], 3, 3)
+
+
+def rotation_matrix_from_vectors(v1: np.ndarray, v2: np.ndarray):
+ "Rotation matrix that rotates v1 to v2"
+ I = np.eye(3, dtype=v1.dtype)
+ v1 = v1 / np.linalg.norm(v1, axis=-1)
+ v2 = v2 / np.linalg.norm(v2, axis=-1)
+ v = np.cross(v1, v2, axis=-1)
+ c = np.sum(v1 * v2, axis=-1)
+ K = skew_symmetric(v)
+ R = I + K + (1 / (1 + c)).astype(v1.dtype)[None, None] * (K @ K) # Avoid numpy's default type casting for scalars
+ return R
+
+
+def axis_angle_to_matrix(axis_angle: np.ndarray, eps: float = 1e-12) -> np.ndarray:
+ """Convert axis-angle representation (rotation vector) to rotation matrix, whose direction is the axis of rotation and length is the angle of rotation
+
+ Args:
+ axis_angle (np.ndarray): shape (..., 3), axis-angle vectors
+
+ Returns:
+ np.ndarray: shape (..., 3, 3) The rotation matrices for the given axis-angle parameters
+ """
+ batch_shape = axis_angle.shape[:-1]
+ dtype = axis_angle.dtype
+
+ angle = np.linalg.norm(axis_angle, axis=-1, keepdims=True)
+ axis = axis_angle / (angle + eps)
+
+ cos = np.cos(angle)[..., None, :]
+ sin = np.sin(angle)[..., None, :]
+
+ rx, ry, rz = np.split(axis, 3, axis=-1)
+ zeros = np.zeros((*batch_shape, 1), dtype=dtype)
+ K = np.concatenate([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], axis=-1).reshape((*batch_shape, 3, 3))
+
+ ident = np.eye(3, dtype=dtype)
+ rot_mat = ident + sin * K + (1 - cos) * (K @ K)
+ return rot_mat
+
+
+def ray_intersection(p1: np.ndarray, d1: np.ndarray, p2: np.ndarray, d2: np.ndarray):
+ """
+ Compute the intersection/closest point of two D-dimensional rays
+ If the rays are intersecting, the closest point is the intersection point.
+
+ Args:
+ p1 (np.ndarray): (..., D) origin of ray 1
+ d1 (np.ndarray): (..., D) direction of ray 1
+ p2 (np.ndarray): (..., D) origin of ray 2
+ d2 (np.ndarray): (..., D) direction of ray 2
+
+ Returns:
+ p (np.ndarray): (..., D) closest point (the intersection point if the rays intersect)
+ (t1, t2) (Tuple[np.ndarray, np.ndarray]): (...,) ray parameters of the closest points on ray 1 and ray 2
+ """
+ p1, d1, p2, d2 = np.broadcast_arrays(p1, d1, p2, d2)
+ dtype = p1.dtype
+ dim = p1.shape[-1]
+ d = np.stack([d1, d2], axis=-2) # (..., 2, D)
+ p = np.stack([p1, p2], axis=-2) # (..., 2, D)
+ A = np.concatenate([
+ (np.eye(dim, dtype=dtype) * np.ones((*p.shape[:-2], 2, 1, 1))).reshape(*d.shape[:-2], 2 * dim, dim), # (..., 2 * D, D)
+ -(np.eye(2, dtype=dtype)[..., None] * d[..., None, :]).swapaxes(-2, -1).reshape(*d.shape[:-2], 2 * dim, 2) # (..., 2 * D, 2)
+ ], axis=-1) # (..., 2 * D, D + 2)
+ b = p.reshape(*p.shape[:-2], 2 * dim) # (..., 2 * D)
+ x = np.linalg.solve(A.swapaxes(-1, -2) @ A + 1e-12 * np.eye(dim + 2, dtype=dtype), (A.swapaxes(-1, -2) @ b[..., :, None])[..., 0])
+ return x[..., :dim], (x[..., dim], x[..., dim + 1])
+
+
+def se3_matrix(R: np.ndarray, t: np.ndarray) -> np.ndarray:
+ """
+ Convert rotation matrix and translation vector to 4x4 transformation matrix.
+
+ Args:
+ R (np.ndarray): [..., 3, 3] rotation matrix
+ t (np.ndarray): [..., 3] translation vector
+
+ Returns:
+ np.ndarray: [..., 4, 4] transformation matrix
+ """
+ assert R.shape[:-2] == t.shape[:-1]
+ assert R.shape[-1] == 3 and R.shape[-2] == 3
+ return np.concatenate([
+ np.concatenate([R, t[..., None]], axis=-1),
+ np.concatenate([np.zeros_like(t), np.ones_like(t[..., :1])], axis=-1)[..., None, :]
+ ], axis=-2)
+
+
+def slerp_quaternion(q1: np.ndarray, q2: np.ndarray, t: np.ndarray) -> np.ndarray:
+ """
+ Spherical linear interpolation between two unit quaternions.
+
+ Args:
+ q1 (np.ndarray): [..., 4] unit quaternion 1
+ q2 (np.ndarray): [..., 4] unit quaternion 2
+ t (np.ndarray): [...] interpolation parameter in [0, 1]
+
+ Returns:
+ np.ndarray: [..., 4] interpolated unit quaternion
+ """
+ q1 = q1 / np.linalg.norm(q1, axis=-1, keepdims=True)
+ q2 = q2 / np.linalg.norm(q2, axis=-1, keepdims=True)
+ dot = np.sum(q1 * q2, axis=-1, keepdims=True)
+
+ q2 = np.where(dot < 0, -q2, q2) # flip q2 to take the shorter arc
+ dot = np.abs(dot)
+
+ dot = np.minimum(dot, 1.)
+ theta = np.arccos(dot) * np.asarray(t)[..., None]
+
+ q_ortho = q2 - q1 * dot
+ q_ortho = q_ortho / np.maximum(np.linalg.norm(q_ortho, axis=-1, keepdims=True), 1e-12)
+ q = q1 * np.cos(theta) + q_ortho * np.sin(theta)
+ return q
+
+
+def slerp_rotation_matrix(R1: np.ndarray, R2: np.ndarray, t: np.ndarray) -> np.ndarray:
+ """
+ Spherical linear interpolation between two rotation matrices.
+
+ Args:
+ R1 (np.ndarray): [..., 3, 3] rotation matrix 1
+ R2 (np.ndarray): [..., 3, 3] rotation matrix 2
+ t (np.ndarray): [...] interpolation parameter in [0, 1]
+
+ Returns:
+ np.ndarray: [..., 3, 3] interpolated rotation matrix
+ """
+ quat1 = matrix_to_quaternion(R1)
+ quat2 = matrix_to_quaternion(R2)
+ quat = slerp_quaternion(quat1, quat2, t)
+ return quaternion_to_matrix(quat)
+
+
+def slerp_vector(v1: np.ndarray, v2: np.ndarray, t: np.ndarray) -> np.ndarray:
+ """
+ Spherical linear interpolation between two unit vectors. The vectors are assumed to be normalized.
+
+ Args:
+ v1 (np.ndarray): [..., d] unit vector 1
+ v2 (np.ndarray): [..., d] unit vector 2
+ t (np.ndarray): [...] interpolation parameter in [0, 1]
+
+ Returns:
+ np.ndarray: [..., d] interpolated unit vector
+ """
+ dot = np.sum(v1 * v2, axis=-1, keepdims=True)
+
+ dot = np.minimum(dot, 1.)
+ theta = np.arccos(dot) * np.asarray(t)[..., None]
+
+ v_ortho = v2 - v1 * dot
+ v_ortho = v_ortho / np.maximum(np.linalg.norm(v_ortho, axis=-1, keepdims=True), 1e-12)
+ v = v1 * np.cos(theta) + v_ortho * np.sin(theta)
+ return v
+
+
+def lerp(x1: np.ndarray, x2: np.ndarray, t: np.ndarray) -> np.ndarray:
+ """
+ Linear interpolation between two vectors.
+
+ Args:
+ x1 (np.ndarray): [..., d] vector 1
+ x2 (np.ndarray): [..., d] vector 2
+ t (np.ndarray): [...] interpolation parameter. [0, 1] for interpolation between x1 and x2, otherwise for extrapolation.
+
+ Returns:
+ np.ndarray: [..., d] interpolated vector
+ """
+ return x1 + np.asarray(t)[..., None] * (x2 - x1)
+
+
+def lerp_se3_matrix(T1: np.ndarray, T2: np.ndarray, t: np.ndarray) -> np.ndarray:
+ """
+ Linear interpolation between two SE(3) matrices.
+
+ Args:
+ T1 (np.ndarray): [..., 4, 4] SE(3) matrix 1
+ T2 (np.ndarray): [..., 4, 4] SE(3) matrix 2
+ t (np.ndarray): [...] interpolation parameter in [0, 1]
+
+ Returns:
+ np.ndarray: [..., 4, 4] interpolated SE(3) matrix
+ """
+ R1 = T1[..., :3, :3]
+ R2 = T2[..., :3, :3]
+ trans1 = T1[..., :3, 3]
+ trans2 = T2[..., :3, 3]
+ R = slerp_rotation_matrix(R1, R2, t)
+ trans = lerp(trans1, trans2, t)
+ return se3_matrix(R, trans)
+
+
+def piecewise_lerp(x: np.ndarray, t: np.ndarray, s: np.ndarray, extrapolation_mode: Literal['constant', 'linear'] = 'constant') -> np.ndarray:
+ """
+ Linear spline interpolation.
+
+ ### Parameters:
+ - `x`: np.ndarray, shape (n, d): the values of data points.
+ - `t`: np.ndarray, shape (n,): the times of the data points.
+ - `s`: np.ndarray, shape (m,): the times to be interpolated.
+ - `extrapolation_mode`: str, the mode of extrapolation. 'constant' means extrapolate the boundary values, 'linear' means extrapolate linearly.
+
+ ### Returns:
+ - `y`: np.ndarray, shape (..., m, d): the interpolated values.
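+
+ Example (illustrative):
+ >>> x = np.array([[0., 0.], [1., 1.], [2., 0.]])
+ >>> y = piecewise_lerp(x, t=np.array([0., 1., 2.]), s=np.array([0.5, 1.5]))  # -> [[0.5, 0.5], [1.5, 0.5]]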
+ """
+ i = np.searchsorted(t, s, side='left')
+ if extrapolation_mode == 'constant':
+ prev = np.clip(i - 1, 0, len(t) - 1)
+ suc = np.clip(i, 0, len(t) - 1)
+ elif extrapolation_mode == 'linear':
+ prev = np.clip(i - 1, 0, len(t) - 2)
+ suc = np.clip(i, 1, len(t) - 1)
+ else:
+ raise ValueError(f'Invalid extrapolation_mode: {extrapolation_mode}')
+
+ u = (s - t[prev]) / np.maximum(t[suc] - t[prev], 1e-12)
+ y = lerp(x[prev], x[suc], u)
+
+ return y
+
+
+def piecewise_lerp_se3_matrix(T: np.ndarray, t: np.ndarray, s: np.ndarray, extrapolation_mode: Literal['constant', 'linear'] = 'constant') -> np.ndarray:
+ """
+ Linear spline interpolation for SE(3) matrices.
+
+ ### Parameters:
+ - `T`: np.ndarray, shape (n, 4, 4): the SE(3) matrices.
+ - `t`: np.ndarray, shape (n,): the times of the data points.
+ - `s`: np.ndarray, shape (m,): the times to be interpolated.
+ - `extrapolation_mode`: str, the mode of extrapolation. 'constant' means extrapolate the boundary values, 'linear' means extrapolate linearly.
+
+ ### Returns:
+ - `T_interp`: np.ndarray, shape (..., m, 4, 4): the interpolated SE(3) matrices.
+ """
+ i = np.searchsorted(t, s, side='left')
+ if extrapolation_mode == 'constant':
+ prev = np.clip(i - 1, 0, len(t) - 1)
+ suc = np.clip(i, 0, len(t) - 1)
+ elif extrapolation_mode == 'linear':
+ prev = np.clip(i - 1, 0, len(t) - 2)
+ suc = np.clip(i, 1, len(t) - 1)
+ else:
+ raise ValueError(f'Invalid extrapolation_mode: {extrapolation_mode}')
+
+ u = (s - t[prev]) / np.maximum(t[suc] - t[prev], 1e-12)
+ T = lerp_se3_matrix(T[prev], T[suc], u)
+
+ return T
+
+
+def apply_transform(T: np.ndarray, x: np.ndarray) -> np.ndarray:
+ """
+ Apply SE(3) transformation to a point or a set of points.
+
+ ### Parameters:
+ - `T`: np.ndarray, shape (..., 4, 4): the SE(3) matrix.
+ - `x`: np.ndarray, shape (..., 3): the point or a set of points to be transformed.
+
+ ### Returns:
+ - `x_transformed`: np.ndarray, shape (..., 3): the transformed point or a set of points.
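+
+ Example (illustrative):
+ >>> T = se3_matrix(np.eye(3), np.array([1., 0., 0.]))  # translation by +x
+ >>> apply_transform(T, np.array([0., 0., 0.]))         # -> array([1., 0., 0.])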
+ """
+ x = np.asarray(x)
+ assert x.shape[-1] == 3
+ T = np.asarray(T)
+ assert T.shape[-2:] == (4, 4)
+ x_transformed = (T[..., :3, :3] @ x[..., :, None]) + T[..., :3, 3][..., None]
+ return x_transformed[..., 0]
\ No newline at end of file
diff --git a/src/utils3d/numpy/utils.py b/src/utils3d/numpy/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f3fbfccef9eb8224f658307c425125095f4c2a7
--- /dev/null
+++ b/src/utils3d/numpy/utils.py
@@ -0,0 +1,625 @@
+import numpy as np
+from typing import *
+from numbers import Number
+import warnings
+import functools
+
+from ._helpers import batched
+from .._helpers import no_warnings
+from . import transforms
+from . import mesh
+
+__all__ = [
+ 'sliding_window_1d',
+ 'sliding_window_nd',
+ 'sliding_window_2d',
+ 'max_pool_1d',
+ 'max_pool_2d',
+ 'max_pool_nd',
+ 'depth_edge',
+ 'normals_edge',
+ 'depth_aliasing',
+ 'interpolate',
+ 'image_scrcoord',
+ 'image_uv',
+ 'image_pixel_center',
+ 'image_pixel',
+ 'image_mesh',
+ 'image_mesh_from_depth',
+ 'points_to_normals',
+ 'depth_to_normals',
+ 'chessboard',
+ 'cube',
+ 'icosahedron',
+ 'square',
+ 'camera_frustum',
+ 'to4x4'
+]
+
+
+
+def sliding_window_1d(x: np.ndarray, window_size: int, stride: int, axis: int = -1):
+ """
+ Return a view of the input array with a sliding window of the given window size and stride.
+ The sliding window is taken over the given axis, and the window dimension is appended to the end of the output array's shape.
+
+ Args:
+ x (np.ndarray): input array with shape (..., axis_size, ...)
+ window_size (int): size of the sliding window
+ stride (int): stride of the sliding window
+ axis (int): axis to perform the sliding window over
+
+ Returns:
+ x_sliding (np.ndarray): view of the input array with shape (..., n_windows, ..., window_size), where n_windows = (axis_size - window_size + 1) // stride
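+
+ Example (illustrative):
+ >>> sliding_window_1d(np.arange(5), window_size=3, stride=1).shape
+ (3, 3)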
+ """
+ assert x.shape[axis] >= window_size, f"window_size ({window_size}) is larger than axis_size ({x.shape[axis]})"
+ axis = axis % x.ndim
+ shape = (*x.shape[:axis], (x.shape[axis] - window_size + 1) // stride, *x.shape[axis + 1:], window_size)
+ strides = (*x.strides[:axis], stride * x.strides[axis], *x.strides[axis + 1:], x.strides[axis])
+ x_sliding = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
+ return x_sliding
+
+
+def sliding_window_nd(x: np.ndarray, window_size: Tuple[int,...], stride: Tuple[int,...], axis: Tuple[int,...]) -> np.ndarray:
+ axis = [axis[i] % x.ndim for i in range(len(axis))]
+ for i in range(len(axis)):
+ x = sliding_window_1d(x, window_size[i], stride[i], axis[i])
+ return x
+
+
+def sliding_window_2d(x: np.ndarray, window_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]], axis: Tuple[int, int] = (-2, -1)) -> np.ndarray:
+ if isinstance(window_size, int):
+ window_size = (window_size, window_size)
+ if isinstance(stride, int):
+ stride = (stride, stride)
+ return sliding_window_nd(x, window_size, stride, axis)
+
+
+def max_pool_1d(x: np.ndarray, kernel_size: int, stride: int, padding: int = 0, axis: int = -1):
+ axis = axis % x.ndim
+ if padding > 0:
+ fill_value = np.nan if x.dtype.kind == 'f' else np.iinfo(x.dtype).min
+ padding_arr = np.full((*x.shape[:axis], padding, *x.shape[axis + 1:]), fill_value=fill_value, dtype=x.dtype)
+ x = np.concatenate([padding_arr, x, padding_arr], axis=axis)
+ a_sliding = sliding_window_1d(x, kernel_size, stride, axis)
+ max_pool = np.nanmax(a_sliding, axis=-1)
+ return max_pool
+
+
+def max_pool_nd(x: np.ndarray, kernel_size: Tuple[int,...], stride: Tuple[int,...], padding: Tuple[int,...], axis: Tuple[int,...]) -> np.ndarray:
+ for i in range(len(axis)):
+ x = max_pool_1d(x, kernel_size[i], stride[i], padding[i], axis[i])
+ return x
+
+
+def max_pool_2d(x: np.ndarray, kernel_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]], padding: Union[int, Tuple[int, int]], axis: Tuple[int, int] = (-2, -1)):
+ if isinstance(kernel_size, Number):
+ kernel_size = (kernel_size, kernel_size)
+ if isinstance(stride, Number):
+ stride = (stride, stride)
+ if isinstance(padding, Number):
+ padding = (padding, padding)
+ axis = tuple(axis)
+ return max_pool_nd(x, kernel_size, stride, padding, axis)
+
+@no_warnings(category=RuntimeWarning)
+def depth_edge(depth: np.ndarray, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: np.ndarray = None) -> np.ndarray:
+ """
+ Compute the edge mask from depth map. The edge is defined as the pixels whose neighbors have large difference in depth.
+
+ Args:
+ depth (np.ndarray): shape (..., height, width), linear depth map
+ atol (float): absolute tolerance
+ rtol (float): relative tolerance
+
+ Returns:
+ edge (np.ndarray): shape (..., height, width) of dtype bool
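+
+ Example (illustrative, assuming `depth` is an (H, W) float array):
+ >>> edge = depth_edge(depth, rtol=0.05, kernel_size=3)  # boolean (H, W) edge mask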
+ """
+ if mask is None:
+ diff = (max_pool_2d(depth, kernel_size, stride=1, padding=kernel_size // 2) + max_pool_2d(-depth, kernel_size, stride=1, padding=kernel_size // 2))
+ else:
+ diff = (max_pool_2d(np.where(mask, depth, -np.inf), kernel_size, stride=1, padding=kernel_size // 2) + max_pool_2d(np.where(mask, -depth, -np.inf), kernel_size, stride=1, padding=kernel_size // 2))
+
+ edge = np.zeros_like(depth, dtype=bool)
+ if atol is not None:
+ edge |= diff > atol
+
+ if rtol is not None:
+ edge |= diff / depth > rtol
+ return edge
+
+
+def depth_aliasing(depth: np.ndarray, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: np.ndarray = None) -> np.ndarray:
+ """
+ Compute the map that indicates the aliasing of a depth map. Aliasing is defined as pixels that are close to neither the maximum nor the minimum of their neighbors.
+ Args:
+ depth (np.ndarray): shape (..., height, width), linear depth map
+ atol (float): absolute tolerance
+ rtol (float): relative tolerance
+
+ Returns:
+ edge (np.ndarray): shape (..., height, width) of dtype bool
+ """
+ if mask is None:
+ diff_max = max_pool_2d(depth, kernel_size, stride=1, padding=kernel_size // 2) - depth
+ diff_min = max_pool_2d(-depth, kernel_size, stride=1, padding=kernel_size // 2) + depth
+ else:
+ diff_max = max_pool_2d(np.where(mask, depth, -np.inf), kernel_size, stride=1, padding=kernel_size // 2) - depth
+ diff_min = max_pool_2d(np.where(mask, -depth, -np.inf), kernel_size, stride=1, padding=kernel_size // 2) + depth
+ diff = np.minimum(diff_max, diff_min)
+
+ edge = np.zeros_like(depth, dtype=bool)
+ if atol is not None:
+ edge |= diff > atol
+ if rtol is not None:
+ edge |= diff / depth > rtol
+ return edge
+
+@no_warnings(category=RuntimeWarning)
+def normals_edge(normals: np.ndarray, tol: float, kernel_size: int = 3, mask: np.ndarray = None) -> np.ndarray:
+ """
+ Compute the edge mask from normal map.
+
+ Args:
+ normal (np.ndarray): shape (..., height, width, 3), normal map
+ tol (float): tolerance in degrees
+
+ Returns:
+ edge (np.ndarray): shape (..., height, width) of dtype bool
+ """
+ assert normals.ndim >= 3 and normals.shape[-1] == 3, "normal should be of shape (..., height, width, 3)"
+ normals = normals / (np.linalg.norm(normals, axis=-1, keepdims=True) + 1e-12)
+
+ padding = kernel_size // 2
+ normals_window = sliding_window_2d(
+ np.pad(normals, (*([(0, 0)] * (normals.ndim - 3)), (padding, padding), (padding, padding), (0, 0)), mode='edge'),
+ window_size=kernel_size,
+ stride=1,
+ axis=(-3, -2)
+ )
+ if mask is None:
+ angle_diff = np.arccos((normals[..., None, None] * normals_window).sum(axis=-3)).max(axis=(-2, -1))
+ else:
+ mask_window = sliding_window_2d(
+ np.pad(mask, (*([(0, 0)] * (mask.ndim - 3)), (padding, padding), (padding, padding)), mode='edge'),
+ window_size=kernel_size,
+ stride=1,
+ axis=(-3, -2)
+ )
+ angle_diff = np.where(mask_window, np.arccos((normals[..., None, None] * normals_window).sum(axis=-3)), 0).max(axis=(-2, -1))
+
+ angle_diff = max_pool_2d(angle_diff, kernel_size, stride=1, padding=kernel_size // 2)
+ edge = angle_diff > np.deg2rad(tol)
+ return edge
+
+
+@no_warnings(category=RuntimeWarning)
+def points_to_normals(point: np.ndarray, mask: np.ndarray = None) -> np.ndarray:
+ """
+ Calculate normal map from point map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system.
+
+ Args:
+ point (np.ndarray): shape (height, width, 3), point map
+ mask (np.ndarray, optional): shape (height, width), dtype bool, valid-pixel mask. Defaults to None.
+ Returns:
+ normal (np.ndarray): shape (height, width, 3), normal map. If `mask` is given, a tuple (normal, normal_mask) is returned.
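+
+ Example (illustrative; `point_map` and `valid_mask` are assumed arrays of shape (H, W, 3) and (H, W)):
+ >>> normals = points_to_normals(point_map)
+ >>> normals, valid = points_to_normals(point_map, mask=valid_mask)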
+ """
+ height, width = point.shape[-3:-1]
+ has_mask = mask is not None
+
+ if mask is None:
+ mask = np.ones_like(point[..., 0], dtype=bool)
+ mask_pad = np.zeros((height + 2, width + 2), dtype=bool)
+ mask_pad[1:-1, 1:-1] = mask
+ mask = mask_pad
+
+ pts = np.zeros((height + 2, width + 2, 3), dtype=point.dtype)
+ pts[1:-1, 1:-1, :] = point
+ up = pts[:-2, 1:-1, :] - pts[1:-1, 1:-1, :]
+ left = pts[1:-1, :-2, :] - pts[1:-1, 1:-1, :]
+ down = pts[2:, 1:-1, :] - pts[1:-1, 1:-1, :]
+ right = pts[1:-1, 2:, :] - pts[1:-1, 1:-1, :]
+ normal = np.stack([
+ np.cross(up, left, axis=-1),
+ np.cross(left, down, axis=-1),
+ np.cross(down, right, axis=-1),
+ np.cross(right, up, axis=-1),
+ ])
+ normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12)
+ valid = np.stack([
+ mask[:-2, 1:-1] & mask[1:-1, :-2],
+ mask[1:-1, :-2] & mask[2:, 1:-1],
+ mask[2:, 1:-1] & mask[1:-1, 2:],
+ mask[1:-1, 2:] & mask[:-2, 1:-1],
+ ]) & mask[None, 1:-1, 1:-1]
+ normal = (normal * valid[..., None]).sum(axis=0)
+ normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12)
+
+ if has_mask:
+ normal_mask = valid.any(axis=0)
+ normal = np.where(normal_mask[..., None], normal, 0)
+ return normal, normal_mask
+ else:
+ return normal
+
+
+def depth_to_normals(depth: np.ndarray, intrinsics: np.ndarray, mask: np.ndarray = None) -> np.ndarray:
+ """
+ Calculate normal map from depth map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system.
+
+ Args:
+ depth (np.ndarray): shape (height, width), linear depth map
+ intrinsics (np.ndarray): shape (3, 3), intrinsics matrix
+ Returns:
+ normal (np.ndarray): shape (height, width, 3), normal map.
+ """
+ has_mask = mask is not None
+
+ height, width = depth.shape[-2:]
+ if mask is None:
+ mask = np.ones_like(depth, dtype=bool)
+
+ uv = image_uv(width=width, height=height, dtype=np.float32)
+ pts = transforms.unproject_cv(uv, depth, intrinsics=intrinsics, extrinsics=None)
+
+ return points_to_normals(pts, mask)
+
+def interpolate(bary: np.ndarray, tri_id: np.ndarray, attr: np.ndarray, faces: np.ndarray) -> np.ndarray:
+ """Interpolate with given barycentric coordinates and triangle indices
+
+ Args:
+ bary (np.ndarray): shape (..., 3), barycentric coordinates
+ tri_id (np.ndarray): int array of shape (...), triangle indices
+ attr (np.ndarray): shape (N, M), vertices attributes
+ faces (np.ndarray): int array of shape (T, 3), face vertex indices
+
+ Returns:
+ np.ndarray: shape (..., M) interpolated result
+ """
+ faces_ = np.concatenate([np.zeros((1, 3), dtype=faces.dtype), faces + 1], axis=0)
+ attr_ = np.concatenate([np.zeros((1, attr.shape[1]), dtype=attr.dtype), attr], axis=0)
+ return np.sum(bary[..., None] * attr_[faces_[tri_id + 1]], axis=-2)
+
+
+def image_scrcoord(
+ width: int,
+ height: int,
+) -> np.ndarray:
+ """
+ Get OpenGL's screen space coordinates, ranging in [0, 1].
+ [0, 0] is the bottom-left corner of the image.
+
+ Args:
+ width (int): image width
+ height (int): image height
+
+ Returns:
+ (np.ndarray): shape (height, width, 2)
+ """
+ x, y = np.meshgrid(
+ np.linspace(0.5 / width, 1 - 0.5 / width, width, dtype=np.float32),
+ np.linspace(1 - 0.5 / height, 0.5 / height, height, dtype=np.float32),
+ indexing='xy'
+ )
+ return np.stack([x, y], axis=2)
+
+
+def image_uv(
+ height: int,
+ width: int,
+ left: int = None,
+ top: int = None,
+ right: int = None,
+ bottom: int = None,
+ dtype: np.dtype = np.float32
+) -> np.ndarray:
+ """
+ Get image space UV grid, ranging in [0, 1].
+
+ >>> image_uv(10, 10):
+ [[[0.05, 0.05], [0.15, 0.05], ..., [0.95, 0.05]],
+ [[0.05, 0.15], [0.15, 0.15], ..., [0.95, 0.15]],
+ ... ... ...
+ [[0.05, 0.95], [0.15, 0.95], ..., [0.95, 0.95]]]
+
+ Args:
+ width (int): image width
+ height (int): image height
+
+ Returns:
+ np.ndarray: shape (height, width, 2)
+ """
+ if left is None: left = 0
+ if top is None: top = 0
+ if right is None: right = width
+ if bottom is None: bottom = height
+ u = np.linspace((left + 0.5) / width, (right - 0.5) / width, right - left, dtype=dtype)
+ v = np.linspace((top + 0.5) / height, (bottom - 0.5) / height, bottom - top, dtype=dtype)
+ u, v = np.meshgrid(u, v, indexing='xy')
+ return np.stack([u, v], axis=2)
+
+
+def image_pixel_center(
+ height: int,
+ width: int,
+ left: int = None,
+ top: int = None,
+ right: int = None,
+ bottom: int = None,
+ dtype: np.dtype = np.float32
+) -> np.ndarray:
+ """
+ Get image pixel center coordinates, ranging in [0, width] and [0, height].
+ `image[i, j]` has pixel center coordinates `(j + 0.5, i + 0.5)`.
+
+ >>> image_pixel_center(10, 10):
+ [[[0.5, 0.5], [1.5, 0.5], ..., [9.5, 0.5]],
+ [[0.5, 1.5], [1.5, 1.5], ..., [9.5, 1.5]],
+ ... ... ...
+ [[0.5, 9.5], [1.5, 9.5], ..., [9.5, 9.5]]]
+
+ Args:
+ width (int): image width
+ height (int): image height
+
+ Returns:
+ np.ndarray: shape (height, width, 2)
+ """
+ if left is None: left = 0
+ if top is None: top = 0
+ if right is None: right = width
+ if bottom is None: bottom = height
+ u = np.linspace(left + 0.5, right - 0.5, right - left, dtype=dtype)
+ v = np.linspace(top + 0.5, bottom - 0.5, bottom - top, dtype=dtype)
+ u, v = np.meshgrid(u, v, indexing='xy')
+ return np.stack([u, v], axis=2)
+
+def image_pixel(
+ height: int,
+ width: int,
+ left: int = None,
+ top: int = None,
+ right: int = None,
+ bottom: int = None,
+ dtype: np.dtype = np.int32
+) -> np.ndarray:
+ """
+ Get image pixel coordinates grid, ranging in [0, width - 1] and [0, height - 1].
+ `image[i, j]` has pixel center coordinates `(j, i)`.
+
+ >>> image_pixel(10, 10):
+ [[[0, 0], [1, 0], ..., [9, 0]],
+ [[0, 1], [1, 1], ..., [9, 1]],
+ ... ... ...
+ [[0, 9], [1, 9], ..., [9, 9]]]
+
+ Args:
+ width (int): image width
+ height (int): image height
+
+ Returns:
+ np.ndarray: shape (height, width, 2)
+ """
+ if left is None: left = 0
+ if top is None: top = 0
+ if right is None: right = width
+ if bottom is None: bottom = height
+ u = np.arange(left, right, dtype=dtype)
+ v = np.arange(top, bottom, dtype=dtype)
+ u, v = np.meshgrid(u, v, indexing='xy')
+ return np.stack([u, v], axis=2)
+
+
+def image_mesh(
+ *image_attrs: np.ndarray,
+ mask: np.ndarray = None,
+ tri: bool = False,
+ return_indices: bool = False
+) -> Tuple[np.ndarray, ...]:
+ """
+ Get a mesh treating image pixel uv coordinates as vertices and the image grid as faces.
+
+ Args:
+ *image_attrs (np.ndarray): image attributes in shape (height, width, [channels])
+ mask (np.ndarray, optional): binary mask of shape (height, width), dtype=bool. Defaults to None.
+
+ Returns:
+ faces (np.ndarray): faces connecting neighboring pixels. shape (T, 4) if tri is False, else (T, 3)
+ *vertex_attrs (np.ndarray): vertex attributes in corresponding order with input image_attrs
+ indices (np.ndarray, optional): indices of vertices in the original mesh
+ """
+ assert (len(image_attrs) > 0) or (mask is not None), "At least one of image_attrs or mask should be provided"
+ height, width = image_attrs[0].shape[:2] if mask is None else mask.shape
+ assert all(img.shape[:2] == (height, width) for img in image_attrs), "All image_attrs should have the same shape"
+
+ row_faces = np.stack([np.arange(0, width - 1, dtype=np.int32), np.arange(width, 2 * width - 1, dtype=np.int32), np.arange(1 + width, 2 * width, dtype=np.int32), np.arange(1, width, dtype=np.int32)], axis=1)
+ faces = (np.arange(0, (height - 1) * width, width, dtype=np.int32)[:, None, None] + row_faces[None, :, :]).reshape((-1, 4))
+ if mask is None:
+ if tri:
+ faces = mesh.triangulate(faces)
+ ret = [faces, *(img.reshape(-1, *img.shape[2:]) for img in image_attrs)]
+ if return_indices:
+ ret.append(np.arange(height * width, dtype=np.int32))
+ return tuple(ret)
+ else:
+ quad_mask = (mask[:-1, :-1] & mask[1:, :-1] & mask[1:, 1:] & mask[:-1, 1:]).ravel()
+ faces = faces[quad_mask]
+ if tri:
+ faces = mesh.triangulate(faces)
+ return mesh.remove_unreferenced_vertices(
+ faces,
+ *(x.reshape(-1, *x.shape[2:]) for x in image_attrs),
+ return_indices=return_indices
+ )
+
+def image_mesh_from_depth(
+ depth: np.ndarray,
+ extrinsics: np.ndarray = None,
+ intrinsics: np.ndarray = None,
+ *vertice_attrs: np.ndarray,
+ atol: float = None,
+ rtol: float = None,
+ remove_by_depth: bool = False,
+ return_uv: bool = False,
+ return_indices: bool = False
+) -> Tuple[np.ndarray, ...]:
+ """
+ Get a triangle mesh by lifting the depth map to 3D.
+
+ Args:
+ depth (np.ndarray): [H, W] depth map
+ extrinsics (np.ndarray, optional): [4, 4] extrinsics matrix. Defaults to None.
+ intrinsics (np.ndarray, optional): [3, 3] intrinsics matrix. Defaults to None.
+ *vertice_attrs (np.ndarray): [H, W, C] vertex attributes. Defaults to None.
+ atol (float, optional): absolute tolerance. Defaults to None.
+ rtol (float, optional): relative tolerance. Defaults to None.
+ triangles with vertices having depth difference larger than atol + rtol * depth will be marked.
+ remove_by_depth (bool, optional): whether to remove triangles with large depth difference. Defaults to False.
+ return_uv (bool, optional): whether to return uv coordinates. Defaults to False.
+ return_indices (bool, optional): whether to return indices of vertices in the original mesh. Defaults to False.
+
+ Returns:
+ vertices (np.ndarray): [N, 3] vertices
+ faces (np.ndarray): [T, 3] faces
+ *vertice_attrs (np.ndarray): [N, C] vertex attributes
+ image_uv (np.ndarray, optional): [N, 2] uv coordinates
+ ref_indices (np.ndarray, optional): [N] indices of vertices in the original mesh
+ """
+ height, width = depth.shape
+ uv = image_uv(height=height, width=width)
+ image_face, uv = image_mesh(uv)
+ depth = depth.reshape(-1)
+ pts = transforms.unproject_cv(uv, depth, extrinsics, intrinsics)
+ image_face = mesh.triangulate(image_face, vertices=pts)
+ ref_indices = None
+ ret = []
+ if atol is not None or rtol is not None:
+ atol = 0 if atol is None else atol
+ rtol = 0 if rtol is None else rtol
+ mean = depth[image_face].mean(axis=1)
+ diff = np.max(np.abs(depth[image_face] - depth[image_face[:, [1, 2, 0]]]), axis=1)
+ mask = (diff <= atol + rtol * mean)
+ image_face_ = image_face[mask]
+ image_face_, ref_indices = mesh.remove_unreferenced_vertices(image_face_, return_indices=True)
+
+ remove = remove_by_depth and ref_indices is not None
+ if remove:
+ pts = pts[ref_indices]
+ image_face = image_face_
+ ret += [pts, image_face]
+ for attr in vertice_attrs:
+ ret.append(attr.reshape(-1, attr.shape[-1]) if not remove else attr.reshape(-1, attr.shape[-1])[ref_indices])
+ if return_uv:
+ ret.append(uv if not remove else uv[ref_indices])
+ if return_indices and ref_indices is not None:
+ ret.append(ref_indices)
+ return tuple(ret)
+
+
+def chessboard(width: int, height: int, grid_size: int, color_a: np.ndarray, color_b: np.ndarray) -> np.ndarray:
+ """Get a chessboard image
+
+ Args:
+ width (int): image width
+ height (int): image height
+ grid_size (int): size of chessboard grid
+ color_a (np.ndarray): color of the grid at the top-left corner
+ color_b (np.ndarray): color in complementary grid cells
+
+ Returns:
+ image (np.ndarray): shape (height, width, channels), chessboard image
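+
+ Example (illustrative):
+ >>> img = chessboard(256, 256, grid_size=32, color_a=np.array([1., 1., 1.]), color_b=np.array([0., 0., 0.]))  # (256, 256, 3)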
+ """
+ x = np.arange(width) // grid_size
+ y = np.arange(height) // grid_size
+ mask = (x[None, :] + y[:, None]) % 2
+ image = (1 - mask[..., None]) * color_a + mask[..., None] * color_b
+ return image
+
+
+def square(tri: bool = False) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Get a square mesh of area 1 centered at origin in the xy-plane.
+
+ ### Returns
+ vertices (np.ndarray): shape (4, 3)
+ faces (np.ndarray): shape (1, 4) if `tri` is False, else (2, 3)
+ """
+ vertices = np.array([
+ [-0.5, 0.5, 0], [0.5, 0.5, 0], [0.5, -0.5, 0], [-0.5, -0.5, 0] # v0-v1-v2-v3
+ ], dtype=np.float32)
+ if tri:
+ faces = np.array([[0, 1, 2], [0, 2, 3]], dtype=np.int32)
+ else:
+ faces = np.array([[0, 1, 2, 3]], dtype=np.int32)
+ return vertices, faces
+
+
+def cube(tri: bool = False) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Get a cube mesh of size 1 centered at the origin.
+
+ ### Parameters
+ tri (bool, optional): return triangulated mesh. Defaults to False, which returns quad mesh.
+
+ ### Returns
+ vertices (np.ndarray): shape (8, 3)
+ faces (np.ndarray): shape (6, 4) if `tri` is False, else (12, 3)
+ """
+ vertices = np.array([
+ [-0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [0.5, -0.5, 0.5], [-0.5, -0.5, 0.5], # v0-v1-v2-v3
+ [-0.5, 0.5, -0.5], [0.5, 0.5, -0.5], [0.5, -0.5, -0.5], [-0.5, -0.5, -0.5] # v4-v5-v6-v7
+ ], dtype=np.float32).reshape((-1, 3))
+
+ faces = np.array([
+ [0, 1, 2, 3], # v0-v1-v2-v3 (front)
+ [4, 5, 1, 0], # v4-v5-v1-v0 (top)
+ [3, 2, 6, 7], # v3-v2-v6-v7 (bottom)
+ [5, 4, 7, 6], # v5-v4-v7-v6 (back)
+ [1, 5, 6, 2], # v1-v5-v6-v2 (right)
+ [4, 0, 3, 7] # v4-v0-v3-v7 (left)
+ ], dtype=np.int32)
+
+ if tri:
+ faces = mesh.triangulate(faces, vertices=vertices)
+
+ return vertices, faces
+
+
+def camera_frustum(extrinsics: np.ndarray, intrinsics: np.ndarray, depth: float = 1.0) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """
+ Get a triangle mesh of the camera frustum.
+ """
+ assert extrinsics.shape == (4, 4) and intrinsics.shape == (3, 3)
+ vertices = transforms.unproject_cv(
+ np.array([[0, 0], [0, 0], [0, 1], [1, 1], [1, 0]], dtype=np.float32),
+ np.array([0] + [depth] * 4, dtype=np.float32),
+ extrinsics,
+ intrinsics
+ ).astype(np.float32)
+ edges = np.array([
+ [0, 1], [0, 2], [0, 3], [0, 4],
+ [1, 2], [2, 3], [3, 4], [4, 1]
+ ], dtype=np.int32)
+ faces = np.array([
+ [0, 1, 2],
+ [0, 2, 3],
+ [0, 3, 4],
+ [0, 4, 1],
+ [1, 2, 3],
+ [1, 3, 4]
+ ], dtype=np.int32)
+ return vertices, edges, faces
+
+
+def icosahedron():
+ A = (1 + 5 ** 0.5) / 2
+ vertices = np.array([
+ [0, 1, A], [0, -1, A], [0, 1, -A], [0, -1, -A],
+ [1, A, 0], [-1, A, 0], [1, -A, 0], [-1, -A, 0],
+ [A, 0, 1], [A, 0, -1], [-A, 0, 1], [-A, 0, -1]
+ ], dtype=np.float32)
+ faces = np.array([
+ [0, 1, 8], [0, 8, 4], [0, 4, 5], [0, 5, 10], [0, 10, 1],
+ [3, 2, 9], [3, 9, 6], [3, 6, 7], [3, 7, 11], [3, 11, 2],
+ [1, 6, 8], [8, 9, 4], [4, 2, 5], [5, 11, 10], [10, 7, 1],
+ [2, 4, 9], [9, 8, 6], [6, 1, 7], [7, 10, 11], [11, 5, 2]
+ ], dtype=np.int32)
+ return vertices, faces
diff --git a/src/utils3d/torch/__init__.py b/src/utils3d/torch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bffcf41b5e906c25f8c8f01fb0a1b557151103c1
--- /dev/null
+++ b/src/utils3d/torch/__init__.py
@@ -0,0 +1,139 @@
+import importlib
+import itertools
+import torch
+from typing import TYPE_CHECKING
+
+__modules_all__ = {
+ 'mesh': [
+ 'triangulate',
+ 'compute_face_normal',
+ 'compute_face_angles',
+ 'compute_vertex_normal',
+ 'compute_vertex_normal_weighted',
+ 'compute_edges',
+ 'compute_connected_components',
+ 'compute_edge_connected_components',
+ 'compute_boundarys',
+ 'compute_dual_graph',
+ 'remove_unreferenced_vertices',
+ 'remove_corrupted_faces',
+ 'remove_isolated_pieces',
+ 'merge_duplicate_vertices',
+ 'subdivide_mesh_simple',
+ 'compute_face_tbn',
+ 'compute_vertex_tbn',
+ 'laplacian',
+ 'laplacian_smooth_mesh',
+ 'taubin_smooth_mesh',
+ 'laplacian_hc_smooth_mesh',
+ ],
+ 'nerf': [
+ 'get_rays',
+ 'get_image_rays',
+ 'get_mipnerf_cones',
+ 'volume_rendering',
+ 'bin_sample',
+ 'importance_sample',
+ 'nerf_render_rays',
+ 'mipnerf_render_rays',
+ 'nerf_render_view',
+ 'mipnerf_render_view',
+ 'InstantNGP',
+ ],
+ 'utils': [
+ 'sliding_window_1d',
+ 'sliding_window_2d',
+ 'sliding_window_nd',
+ 'image_uv',
+ 'image_pixel_center',
+ 'image_mesh',
+ 'chessboard',
+ 'depth_edge',
+ 'depth_aliasing',
+ 'image_mesh_from_depth',
+ 'point_to_normal',
+ 'depth_to_normal',
+ 'masked_min',
+ 'masked_max',
+ 'bounding_rect'
+ ],
+ 'transforms': [
+ 'perspective',
+ 'perspective_from_fov',
+ 'perspective_from_fov_xy',
+ 'intrinsics_from_focal_center',
+ 'intrinsics_from_fov',
+ 'intrinsics_from_fov_xy',
+ 'view_look_at',
+ 'extrinsics_look_at',
+ 'perspective_to_intrinsics',
+ 'intrinsics_to_perspective',
+ 'extrinsics_to_view',
+ 'view_to_extrinsics',
+ 'normalize_intrinsics',
+ 'crop_intrinsics',
+ 'pixel_to_uv',
+ 'pixel_to_ndc',
+ 'uv_to_pixel',
+ 'project_depth',
+ 'depth_buffer_to_linear',
+ 'project_gl',
+ 'project_cv',
+ 'unproject_gl',
+ 'unproject_cv',
+ 'skew_symmetric',
+ 'rotation_matrix_from_vectors',
+ 'euler_axis_angle_rotation',
+ 'euler_angles_to_matrix',
+ 'matrix_to_euler_angles',
+ 'matrix_to_quaternion',
+ 'quaternion_to_matrix',
+ 'matrix_to_axis_angle',
+ 'axis_angle_to_matrix',
+ 'axis_angle_to_quaternion',
+ 'quaternion_to_axis_angle',
+ 'slerp',
+ 'interpolate_extrinsics',
+ 'interpolate_view',
+ 'extrinsics_to_essential',
+ 'to4x4',
+ 'rotation_matrix_2d',
+ 'rotate_2d',
+ 'translate_2d',
+ 'scale_2d',
+ 'apply_2d',
+ ],
+ 'rasterization': [
+ 'RastContext',
+ 'rasterize_triangle_faces',
+ 'warp_image_by_depth',
+ 'warp_image_by_forward_flow',
+ ],
+}
+
+
+__all__ = list(itertools.chain(*__modules_all__.values()))
+
+def __getattr__(name):
+ try:
+ return globals()[name]
+ except KeyError:
+ pass
+
+ try:
+ module_name = next(m for m in __modules_all__ if name in __modules_all__[m])
+ except StopIteration:
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
+ module = importlib.import_module(f'.{module_name}', __name__)
+ for key in __modules_all__[module_name]:
+ globals()[key] = getattr(module, key)
+
+ return globals()[name]
+
+
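+# Example (illustrative): accessing any exported name triggers a lazy import of the
+# submodule that defines it, e.g. the first attribute access below imports `.mesh`:
+#
+# import utils3d.torch as u3d
+# tri_faces = u3d.triangulate(quad_faces) # loads utils3d.torch.mesh on first use
+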
+if TYPE_CHECKING:
+ from .transforms import *
+ from .mesh import *
+ from .utils import *
+ from .nerf import *
+ from .rasterization import *
\ No newline at end of file
diff --git a/src/utils3d/torch/_helpers.py b/src/utils3d/torch/_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..442e2cb6358ba4f105d2664f9b9f44b6ec6561ca
--- /dev/null
+++ b/src/utils3d/torch/_helpers.py
@@ -0,0 +1,103 @@
+# decorator
+import torch
+from numbers import Number
+import inspect
+from functools import wraps
+from .._helpers import suppress_traceback
+
+
+def get_device(args, kwargs):
+ device = None
+ for arg in (list(args) + list(kwargs.values())):
+ if isinstance(arg, torch.Tensor):
+ if device is None:
+ device = arg.device
+ elif device != arg.device:
+ raise ValueError("All tensors must be on the same device.")
+ return device
+
+
+def get_args_order(func, args, kwargs):
+ """
+ Get the order of the arguments of a function.
+ """
+ names = inspect.getfullargspec(func).args
+ names_idx = {name: i for i, name in enumerate(names)}
+ args_order = []
+ kwargs_order = {}
+ for name, arg in kwargs.items():
+ if name in names:
+ kwargs_order[name] = names_idx[name]
+ names.remove(name)
+ for i, arg in enumerate(args):
+ if i < len(names):
+ args_order.append(names_idx[names[i]])
+ return args_order, kwargs_order
+
+
+def broadcast_args(args, kwargs, args_dim, kwargs_dim):
+ spatial = []
+ for arg, arg_dim in zip(args + list(kwargs.values()), args_dim + list(kwargs_dim.values())):
+ if isinstance(arg, torch.Tensor) and arg_dim is not None:
+ arg_spatial = arg.shape[:arg.ndim-arg_dim]
+ if len(arg_spatial) > len(spatial):
+ spatial = [1] * (len(arg_spatial) - len(spatial)) + spatial
+ for j in range(len(arg_spatial)):
+ if spatial[-j-1] < arg_spatial[-j-1]:
+ if spatial[-j-1] == 1:
+ spatial[-j-1] = arg_spatial[-j-1]
+ else:
+ raise ValueError("Cannot broadcast arguments.")
+ for i, arg in enumerate(args):
+ if isinstance(arg, torch.Tensor) and args_dim[i] is not None:
+ args[i] = torch.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-args_dim[i]:]])
+ for key, arg in kwargs.items():
+ if isinstance(arg, torch.Tensor) and kwargs_dim[key] is not None:
+ kwargs[key] = torch.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-kwargs_dim[key]:]])
+ return args, kwargs, spatial
+
+@suppress_traceback
+def batched(*dims):
+ """
+ Decorator that allows a function to be called with batched arguments.
+ Each entry of `dims` gives the number of trailing data dimensions of the corresponding
+ parameter (None leaves that argument untouched). Leading batch dimensions are broadcast,
+ flattened into a single batch dimension before the call, and restored on the outputs.
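+
+ Example (a minimal sketch; `dims=(2, None)` marks `vertices` as ending in 2 data
+ dimensions and leaves the integer tensor `faces` untouched):
+
+ @batched(2, None)
+ def face_centers(vertices, faces): # vertices arrive as (N, V, 3) after flattening
+ return vertices[:, faces].mean(dim=-2) # -> (N, T, 3)
+
+ # may be called with vertices of shape (V, 3) or (B1, B2, V, 3); the leading
+ # batch dimensions are broadcast and restored on the output.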
+ """
+ def decorator(func):
+ @wraps(func)
+ def wrapper(*args, device=torch.device('cpu'), **kwargs):
+ args = list(args)
+ # get arguments dimensions
+ args_order, kwargs_order = get_args_order(func, args, kwargs)
+ args_dim = [dims[i] for i in args_order]
+ kwargs_dim = {key: dims[i] for key, i in kwargs_order.items()}
+ # convert to torch tensor
+ device = get_device(args, kwargs) or device
+ for i, arg in enumerate(args):
+ if isinstance(arg, (Number, list, tuple)) and args_dim[i] is not None:
+ args[i] = torch.tensor(arg, device=device)
+ for key, arg in kwargs.items():
+ if isinstance(arg, (Number, list, tuple)) and kwargs_dim[key] is not None:
+ kwargs[key] = torch.tensor(arg, device=device)
+ # broadcast arguments
+ args, kwargs, spatial = broadcast_args(args, kwargs, args_dim, kwargs_dim)
+ for i, (arg, arg_dim) in enumerate(zip(args, args_dim)):
+ if isinstance(arg, torch.Tensor) and arg_dim is not None:
+ args[i] = arg.reshape([-1, *arg.shape[arg.ndim-arg_dim:]])
+ for key, arg in kwargs.items():
+ if isinstance(arg, torch.Tensor) and kwargs_dim[key] is not None:
+ kwargs[key] = arg.reshape([-1, *arg.shape[arg.ndim-kwargs_dim[key]:]])
+ # call function
+ results = func(*args, **kwargs)
+ type_results = type(results)
+ results = list(results) if isinstance(results, (tuple, list)) else [results]
+ # restore spatial dimensions
+ for i, result in enumerate(results):
+ results[i] = result.reshape([*spatial, *result.shape[1:]])
+ if type_results == tuple:
+ results = tuple(results)
+ elif type_results == list:
+ results = list(results)
+ else:
+ results = results[0]
+ return results
+ return wrapper
+ return decorator
\ No newline at end of file
diff --git a/src/utils3d/torch/mesh.py b/src/utils3d/torch/mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b874d163e5edad3ef871a276b4edccf2e593265
--- /dev/null
+++ b/src/utils3d/torch/mesh.py
@@ -0,0 +1,688 @@
+import torch
+import torch.nn.functional as F
+from typing import *
+from ._helpers import batched
+
+
+__all__ = [
+ 'triangulate',
+ 'compute_face_normal',
+ 'compute_face_angles',
+ 'compute_vertex_normal',
+ 'compute_vertex_normal_weighted',
+ 'compute_edges',
+ 'compute_connected_components',
+ 'compute_edge_connected_components',
+ 'compute_boundarys',
+ 'compute_dual_graph',
+ 'remove_unreferenced_vertices',
+ 'remove_corrupted_faces',
+ 'remove_isolated_pieces',
+ 'merge_duplicate_vertices',
+ 'subdivide_mesh_simple',
+ 'compute_face_tbn',
+ 'compute_vertex_tbn',
+ 'laplacian',
+ 'laplacian_smooth_mesh',
+ 'taubin_smooth_mesh',
+ 'laplacian_hc_smooth_mesh',
+]
+
+
+def _group(
+ values: torch.Tensor,
+ required_group_size: Optional[int] = None,
+ return_values: bool = False
+) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]:
+ """
+ Group values into groups with identical values.
+
+ Args:
+ values (torch.Tensor): [N] values to group
+ required_group_size (int, optional): required group size. Defaults to None.
+ return_values (bool, optional): return values of groups. Defaults to False.
+
+ Returns:
+ group (Union[List[torch.Tensor], torch.Tensor]): list of groups if required_group_size is None, otherwise a (num_groups, required_group_size) tensor of element indices.
+ group_values (Optional[torch.Tensor]): values of groups. Only returned if return_values is True.
+ """
+ sorted_values, indices = torch.sort(values)
+ nondupe = torch.cat([torch.tensor([True], dtype=torch.bool, device=values.device), sorted_values[1:] != sorted_values[:-1]])
+ nondupe_indices = torch.cumsum(nondupe, dim=0) - 1
+ counts = torch.bincount(nondupe_indices)
+ if required_group_size is None:
+ groups = torch.split(indices, counts.tolist())
+ if return_values:
+ group_values = sorted_values[nondupe]
+ return groups, group_values
+ else:
+ return groups
+ else:
+ counts = counts[nondupe_indices]
+ groups = indices[counts == required_group_size].reshape(-1, required_group_size)
+ if return_values:
+ group_values = sorted_values[nondupe][counts[nondupe] == required_group_size]
+ return groups, group_values
+ else:
+ return groups
+
+def triangulate(
+ faces: torch.Tensor,
+ vertices: torch.Tensor = None,
+ backslash: bool = None
+) -> torch.Tensor:
+ """
+ Triangulate a polygonal mesh.
+
+ Args:
+ faces (torch.Tensor): [L, P] polygonal faces
+ vertices (torch.Tensor, optional): [N, 3] 3-dimensional vertices.
+ If given, the triangulation is performed according to the distance
+ between vertices. Defaults to None.
+ backslash (bool or torch.Tensor, optional): [L] boolean array indicating
+ how to triangulate the quad faces (True splits along the 0-2 diagonal, False along the 1-3 diagonal). Defaults to None.
+
+
+ Returns:
+ (torch.Tensor): [L * (P - 2), 3] triangular faces
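+
+ Example (illustrative; a single quad split with the default fan):
+ >>> triangulate(torch.tensor([[0, 1, 2, 3]]))
+ tensor([[0, 1, 2],
+ [0, 2, 3]])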
+ """
+ if faces.shape[-1] == 3:
+ return faces
+ P = faces.shape[-1]
+ if vertices is not None:
+ assert faces.shape[-1] == 4, "now only support quad mesh"
+ if backslash is None:
+ faces_idx = faces.long()
+ backslash = torch.norm(vertices[faces_idx[..., 0]] - vertices[faces_idx[..., 2]], p=2, dim=-1) < \
+ torch.norm(vertices[faces_idx[..., 1]] - vertices[faces_idx[..., 3]], p=2, dim=-1)
+ if backslash is None:
+ loop_indice = torch.stack([
+ torch.zeros(P - 2, dtype=int),
+ torch.arange(1, P - 1, 1, dtype=int),
+ torch.arange(2, P, 1, dtype=int)
+ ], axis=1)
+ return faces[:, loop_indice].reshape(-1, 3)
+ else:
+ assert faces.shape[-1] == 4, "now only support quad mesh"
+ if isinstance(backslash, bool):
+ if backslash:
+ faces = faces[:, [0, 1, 2, 0, 2, 3]].reshape(-1, 3)
+ else:
+ faces = faces[:, [0, 1, 3, 3, 1, 2]].reshape(-1, 3)
+ else:
+ faces = torch.where(
+ backslash[:, None],
+ faces[:, [0, 1, 2, 0, 2, 3]],
+ faces[:, [0, 1, 3, 3, 1, 2]]
+ ).reshape(-1, 3)
+ return faces
+
+
+@batched(2, None)
+def compute_face_normal(
+ vertices: torch.Tensor,
+ faces: torch.Tensor
+) -> torch.Tensor:
+ """
+ Compute face normals of a triangular mesh
+
+ Args:
+ vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices
+ faces (torch.Tensor): [..., T, 3] triangular face indices
+
+ Returns:
+ normals (torch.Tensor): [..., T, 3] face normals
+ """
+ N = vertices.shape[0]
+ index = torch.arange(N)[:, None]
+ normal = torch.cross(
+ vertices[index, faces[..., 1].long()] - vertices[index, faces[..., 0].long()],
+ vertices[index, faces[..., 2].long()] - vertices[index, faces[..., 0].long()],
+ dim=-1
+ )
+ return F.normalize(normal, p=2, dim=-1)
+
+
+@batched(2, None)
+def compute_face_angles(
+ vertices: torch.Tensor,
+ faces: torch.Tensor
+) -> torch.Tensor:
+ """
+ Compute face angles of a triangular mesh
+
+ Args:
+ vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices
+ faces (torch.Tensor): [T, 3] triangular face indices
+
+ Returns:
+ angles (torch.Tensor): [..., T, 3] face angles
+ """
+ face_angles = []
+ for i in range(3):
+ edge1 = torch.index_select(vertices, dim=-2, index=faces[:, (i + 1) % 3]) - torch.index_select(vertices, dim=-2, index=faces[:, i])
+ edge2 = torch.index_select(vertices, dim=-2, index=faces[:, (i + 2) % 3]) - torch.index_select(vertices, dim=-2, index=faces[:, i])
+ face_angle = torch.arccos(torch.sum(F.normalize(edge1, p=2, dim=-1) * F.normalize(edge2, p=2, dim=-1), dim=-1))
+ face_angles.append(face_angle)
+ face_angles = torch.stack(face_angles, dim=-1)
+ return face_angles
+
+
+@batched(2, None, 2)
+def compute_vertex_normal(
+ vertices: torch.Tensor,
+ faces: torch.Tensor,
+ face_normal: torch.Tensor = None
+) -> torch.Tensor:
+ """
+ Compute vertex normals of a triangular mesh by averaging neighboring face normals
+
+ Args:
+ vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices
+ faces (torch.Tensor): [T, 3] triangular face indices
+ face_normal (torch.Tensor, optional): [..., T, 3] face normals.
+ None to compute face normals from vertices and faces. Defaults to None.
+
+ Returns:
+ normals (torch.Tensor): [..., N, 3] vertex normals
+ """
+ N = vertices.shape[0]
+ assert faces.shape[-1] == 3, "Only support triangular mesh"
+ if face_normal is None:
+ face_normal = compute_face_normal(vertices, faces)
+ face_normal = face_normal[:, :, None, :].expand(-1, -1, 3, -1).flatten(-3, -2)
+ faces = faces.flatten()
+ vertex_normal = torch.index_put(torch.zeros_like(vertices), (torch.arange(N)[:, None], faces[None, :]), face_normal, accumulate=True)
+ vertex_normal = F.normalize(vertex_normal, p=2, dim=-1)
+ return vertex_normal
+
+
+@batched(2, None, 2)
+def compute_vertex_normal_weighted(
+ vertices: torch.Tensor,
+ faces: torch.Tensor,
+ face_normal: torch.Tensor = None
+) -> torch.Tensor:
+ """
+ Compute vertex normals of a triangular mesh by weighted sum of neighboring face normals
+ according to the angles
+
+ Args:
+ vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices
+ faces (torch.Tensor): [T, 3] triangular face indices
+ face_normal (torch.Tensor, optional): [..., T, 3] face normals.
+ None to compute face normals from vertices and faces. Defaults to None.
+
+ Returns:
+ normals (torch.Tensor): [..., N, 3] vertex normals
+ """
+ N = vertices.shape[0]
+ if face_normal is None:
+ face_normal = compute_face_normal(vertices, faces)
+ face_angle = compute_face_angles(vertices, faces)
+ face_normal = face_normal[:, :, None, :].expand(-1, -1, 3, -1) * face_angle[..., None]
+ vertex_normal = torch.index_put(torch.zeros_like(vertices), (torch.arange(N)[:, None], faces.flatten()[None, :]), face_normal.view(N, -1, 3), accumulate=True)
+ vertex_normal = F.normalize(vertex_normal, p=2, dim=-1)
+ return vertex_normal
+
+
+def compute_edges(
+ faces: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Compute edges of a mesh.
+
+ Args:
+ faces (torch.Tensor): [T, 3] triangular face indices
+
+ Returns:
+ edges (torch.Tensor): [E, 2] edge indices
+ face2edge (torch.Tensor): [T, 3] mapping from face to edge
+ counts (torch.Tensor): [E] degree of each edge
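+
+ Example (illustrative): for two triangles sharing the edge (1, 2),
+ >>> edges, face2edge, counts = compute_edges(torch.tensor([[0, 1, 2], [2, 1, 3]]))
+ the shared edge has degree 2 in `counts`, while the four boundary edges have degree 1.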
+ """
+ T = faces.shape[0]
+ edges = torch.cat([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], dim=0) # [3T, 2]
+ edges = torch.sort(edges, dim=1).values
+ edges, inv_map, counts = torch.unique(edges, return_inverse=True, return_counts=True, dim=0)
+ face2edge = inv_map.view(3, T).T
+ return edges, face2edge, counts
+
+
+def compute_connected_components(
+ faces: torch.Tensor,
+ edges: torch.Tensor=None,
+ face2edge: torch.Tensor=None
+) -> List[torch.Tensor]:
+ """
+ Compute connected faces of a mesh.
+
+ Args:
+ faces (torch.Tensor): [T, 3] triangular face indices
+ edges (torch.Tensor, optional): [E, 2] edge indices. Defaults to None.
+ face2edge (torch.Tensor, optional): [T, 3] mapping from face to edge. Defaults to None.
+ NOTE: If edges and face2edge are not provided, they will be computed.
+
+ Returns:
+ components (List[torch.Tensor]): list of connected faces
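+
+ Example (illustrative): two triangles with no shared edge form two components,
+ >>> compute_connected_components(torch.tensor([[0, 1, 2], [3, 4, 5]]))
+ returning one group of face indices per piece (here [0] and [1]).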
+ """
+ T = faces.shape[0]
+ if edges is None or face2edge is None:
+ edges, face2edge, _ = compute_edges(faces)
+ E = edges.shape[0]
+
+ labels = torch.arange(T, dtype=torch.int32, device=faces.device)
+ while True:
+ edge_labels = torch.scatter_reduce(
+ torch.zeros(E, dtype=torch.int32, device=faces.device),
+ 0,
+ face2edge.flatten().long(),
+ labels.view(-1, 1).expand(-1, 3).flatten(),
+ reduce='amin',
+ include_self=False
+ )
+ new_labels = torch.min(edge_labels[face2edge], dim=-1).values
+ if torch.equal(labels, new_labels):
+ break
+ labels = new_labels
+
+ components = _group(labels)
+
+ return components
+
+
+def compute_edge_connected_components(
+ edges: torch.Tensor,
+) -> List[torch.Tensor]:
+ """
+ Compute connected edges of a mesh.
+
+ Args:
+ edges (torch.Tensor): [E, 2] edge indices
+
+ Returns:
+ components (List[torch.Tensor]): list of connected edges
+ """
+ E = edges.shape[0]
+
+ # Re-index edges
+ verts, edges = torch.unique(edges.flatten(), return_inverse=True)
+ edges = edges.view(-1, 2)
+ V = verts.shape[0]
+
+ labels = torch.arange(E, dtype=torch.int32, device=edges.device)
+ while True:
+ vertex_labels = torch.scatter_reduce(
+ torch.zeros(V, dtype=torch.int32, device=edges.device),
+ 0,
+ edges.flatten().long(),
+ labels.view(-1, 1).expand(-1, 2).flatten(),
+ reduce='amin',
+ include_self=False
+ )
+ new_labels = torch.min(vertex_labels[edges], dim=-1).values
+ if torch.equal(labels, new_labels):
+ break
+ labels = new_labels
+
+ components = _group(labels)
+
+ return components
+
+
+def compute_boundarys(
+ faces: torch.Tensor,
+ edges: torch.Tensor=None,
+ face2edge: torch.Tensor=None,
+ edge_degrees: torch.Tensor=None
+) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+ """
+ Compute boundary edges of a mesh.
+
+ Args:
+ faces (torch.Tensor): [T, 3] triangular face indices
+ edges (torch.Tensor): [E, 2] edge indices.
+ face2edge (torch.Tensor): [T, 3] mapping from face to edge.
+ edge_degrees (torch.Tensor): [E] degree of each edge.
+
+ Returns:
+ boundary_edge_indices (List[torch.Tensor]): list of boundary edge indices
+ boundary_face_indices (List[torch.Tensor]): list of boundary face indices
+ """
+ # Map each edge to boundary edge index
+ boundary_edges = edges[edge_degrees == 1] # [BE, 2]
+ boundary_edges_idx = torch.nonzero(edge_degrees == 1, as_tuple=False).flatten() # [BE]
+ E = edges.shape[0] # Edge count
+ BE = boundary_edges.shape[0] # Boundary edge count
+ map_to_boundary_edges = torch.full((E,), -1, dtype=torch.int32, device=faces.device) # [E]
+ map_to_boundary_edges[boundary_edges_idx] = torch.arange(BE, dtype=torch.int32, device=faces.device)
+
+ # Re-index boundary vertices
+ boundary_vertices, boundary_edges = torch.unique(boundary_edges.flatten(), return_inverse=True)
+ boundary_edges = boundary_edges.view(-1, 2)
+ BV = boundary_vertices.shape[0]
+
+ boundary_edge_labels = torch.arange(BE, dtype=torch.int32, device=faces.device)
+ while True:
+ boundary_vertex_labels = torch.scatter_reduce(
+ torch.zeros(BV, dtype=torch.int32, device=faces.device),
+ 0,
+ boundary_edges.flatten().long(),
+ boundary_edge_labels.view(-1, 1).expand(-1, 2).flatten(),
+ reduce='amin',
+ include_self=False
+ )
+ new_boundary_edge_labels = torch.min(boundary_vertex_labels[boundary_edges], dim=-1).values
+ if torch.equal(boundary_edge_labels, new_boundary_edge_labels):
+ break
+ boundary_edge_labels = new_boundary_edge_labels
+
+ labels = torch.unique(boundary_edge_labels)
+ boundary_edge_indices = [boundary_edges_idx[boundary_edge_labels == label] for label in labels]
+ edge_labels = torch.full((E,), -1, dtype=torch.int32, device=faces.device)
+ edge_labels[boundary_edges_idx] = boundary_edge_labels
+ boundary_face_indices = [torch.nonzero((edge_labels[face2edge] == label).any(dim=-1), as_tuple=False).flatten() for label in labels]
+
+ return boundary_edge_indices, boundary_face_indices
+
+
+def compute_dual_graph(
+ face2edge: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Compute dual graph of a mesh.
+
+ Args:
+ face2edge (torch.Tensor): [T, 3] mapping from face to edge.
+
+ Returns:
+ dual_edges (torch.Tensor): [DE, 2] face indices of dual edges
+ dual_edge2edge (torch.Tensor): [DE] mapping from dual edge to edge
+ """
+ all_edge_indices = face2edge.flatten() # [3T]
+ dual_edges, dual_edge2edge = _group(all_edge_indices, required_group_size=2, return_values=True)
+ dual_edges = dual_edges // face2edge.shape[1]
+ return dual_edges, dual_edge2edge
+
+
+def remove_unreferenced_vertices(
+ faces: torch.Tensor,
+ *vertice_attrs,
+ return_indices: bool = False
+) -> Tuple[torch.Tensor, ...]:
+ """
+ Remove unreferenced vertices of a mesh.
+ Unreferenced vertices are removed, and the face indices are updated accordingly.
+
+ Args:
+ faces (torch.Tensor): [T, P] face indices
+ *vertice_attrs: vertex attributes
+
+ Returns:
+ faces (torch.Tensor): [T, P] face indices
+ *vertice_attrs: vertex attributes
+ indices (torch.Tensor, optional): [N_] indices of the vertices that are kept. Only returned if return_indices is True.
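+
+ Example (illustrative): a single face referencing vertices 2, 3 and 5 of a 6-vertex buffer,
+ >>> faces, vertices = remove_unreferenced_vertices(torch.tensor([[2, 3, 5]]), vertices)
+ keeps only those three vertices and re-indexes the face to [[0, 1, 2]].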
+ """
+ P = faces.shape[-1]
+ fewer_indices, inv_map = torch.unique(faces, return_inverse=True)
+ faces = inv_map.to(torch.int32).reshape(-1, P)
+ ret = [faces]
+ for attr in vertice_attrs:
+ ret.append(attr[fewer_indices])
+ if return_indices:
+ ret.append(fewer_indices)
+ return tuple(ret)
+
+
+def remove_corrupted_faces(
+ faces: torch.Tensor
+) -> torch.Tensor:
+ """
+ Remove corrupted faces (faces with duplicated vertices)
+
+ Args:
+ faces (torch.Tensor): [T, 3] triangular face indices
+
+ Returns:
+ torch.Tensor: [T_, 3] triangular face indices
+ """
+ corrupted = (faces[:, 0] == faces[:, 1]) | (faces[:, 1] == faces[:, 2]) | (faces[:, 2] == faces[:, 0])
+ return faces[~corrupted]
+
+
+def merge_duplicate_vertices(
+ vertices: torch.Tensor,
+ faces: torch.Tensor,
+ tol: float = 1e-6
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Merge duplicate vertices of a triangular mesh.
+ Duplicate vertices are merged by selecting one of them, and the face indices are updated accordingly.
+
+ Args:
+ vertices (torch.Tensor): [N, 3] 3-dimensional vertices
+ faces (torch.Tensor): [T, 3] triangular face indices
+ tol (float, optional): tolerance for merging. Defaults to 1e-6.
+
+ Returns:
+ vertices (torch.Tensor): [N_, 3] 3-dimensional vertices
+ faces (torch.Tensor): [T, 3] triangular face indices
+ """
+ vertices_round = torch.round(vertices / tol)
+ uni, uni_inv = torch.unique(vertices_round, dim=0, return_inverse=True)
+ uni[uni_inv] = vertices
+ faces = uni_inv[faces]
+ return uni, faces
+
+
+def remove_isolated_pieces(
+ vertices: torch.Tensor,
+ faces: torch.Tensor,
+ connected_components: List[torch.Tensor] = None,
+ thresh_num_faces: int = None,
+ thresh_radius: float = None,
+ thresh_boundary_ratio: float = None,
+ remove_unreferenced: bool = True,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Remove isolated pieces of a mesh.
+ Isolated pieces are removed, and the face indices are updated accordingly.
+ The largest connected component is always kept, so the result is never empty.
+
+ Args:
+ vertices (torch.Tensor): [N, 3] 3-dimensional vertices
+ faces (torch.Tensor): [T, 3] triangular face indices
+ connected_components (List[torch.Tensor], optional): connected components of the mesh. If None, it will be computed. Defaults to None.
+ thresh_num_faces (int, optional): threshold of number of faces for isolated pieces. Defaults to None.
+ thresh_radius (float, optional): threshold of radius for isolated pieces. Defaults to None.
+ thresh_boundary_ratio (float, optional): threshold of the boundary-edge-to-face ratio for isolated pieces. Defaults to None.
+ remove_unreferenced (bool, optional): remove unreferenced vertices after removing isolated pieces. Defaults to True.
+
+ Returns:
+ vertices (torch.Tensor): [N_, 3] 3-dimensional vertices
+ faces (torch.Tensor): [T, 3] triangular face indices
+ """
+ if connected_components is None:
+ connected_components = compute_connected_components(faces)
+ connected_components = sorted(connected_components, key=lambda x: len(x), reverse=True)
+ if thresh_num_faces is not None:
+ removed = []
+ for i in range(1, len(connected_components)):
+ if len(connected_components[i]) < thresh_num_faces:
+ removed.append(i)
+ for i in removed[::-1]:
+ connected_components.pop(i)
+ if thresh_radius is not None:
+ removed = []
+ for i in range(1, len(connected_components)):
+ comp_vertices = vertices[faces[connected_components[i]].flatten().unique()]
+ comp_center = comp_vertices.mean(dim=0)
+ comp_radius = (comp_vertices - comp_center).norm(p=2, dim=-1).max()
+ if comp_radius < thresh_radius:
+ removed.append(i)
+ for i in removed[::-1]:
+ connected_components.pop(i)
+ if thresh_boundary_ratio is not None:
+ removed = []
+ for i in range(1, len(connected_components)):
+ edges = torch.cat([faces[connected_components[i]][:, [0, 1]], faces[connected_components[i]][:, [1, 2]], faces[connected_components[i]][:, [2, 0]]], dim=0)
+ edges = torch.sort(edges, dim=1).values
+ edges, counts = torch.unique(edges, return_counts=True, dim=0)
+ num_boundary_edges = (counts == 1).sum().item()
+ num_faces = len(connected_components[i])
+ if num_boundary_edges / num_faces > thresh_boundary_ratio:
+ removed.append(i)
+ for i in removed[::-1]:
+ connected_components.pop(i)
+
+ # post-process
+ faces = torch.cat([faces[connected_components[i]] for i in range(len(connected_components))], dim=0)
+ if remove_unreferenced:
+ faces, vertices = remove_unreferenced_vertices(faces, vertices)
+ return vertices, faces
+
+
+def subdivide_mesh_simple(vertices: torch.Tensor, faces: torch.Tensor, n: int = 1) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Subdivide a triangular mesh by splitting each triangle into 4 smaller triangles.
+ NOTE: All original vertices are kept, and new vertices are appended to the end of the vertex list.
+
+ Args:
+ vertices (torch.Tensor): [N, 3] 3-dimensional vertices
+ faces (torch.Tensor): [T, 3] triangular face indices
+ n (int, optional): number of subdivisions. Defaults to 1.
+
+ Returns:
+ vertices (torch.Tensor): [N_, 3] subdivided 3-dimensional vertices
+ faces (torch.Tensor): [4 * T, 3] subdivided triangular face indices
+ """
+ for _ in range(n):
+ edges = torch.stack([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], dim=0) # (3, T, 2)
+ edges = torch.sort(edges, dim=2).values
+ uni_edges, uni_inv = torch.unique(edges.flatten(0, 1), return_inverse=True, dim=0)
+ uni_inv = uni_inv.view(3, -1)
+ midpoints = (vertices[uni_edges[:, 0]] + vertices[uni_edges[:, 1]]) / 2
+
+ n_vertices = vertices.shape[0]
+ vertices = torch.cat([vertices, midpoints], dim=0)
+ faces = torch.cat([
+ torch.stack([faces[:, 0], n_vertices + uni_inv[0], n_vertices + uni_inv[2]], axis=1),
+ torch.stack([faces[:, 1], n_vertices + uni_inv[1], n_vertices + uni_inv[0]], axis=1),
+ torch.stack([faces[:, 2], n_vertices + uni_inv[2], n_vertices + uni_inv[1]], axis=1),
+ torch.stack([n_vertices + uni_inv[0], n_vertices + uni_inv[1], n_vertices + uni_inv[2]], axis=1),
+ ], dim=0)
+ return vertices, faces
+
+
+def compute_face_tbn(pos: torch.Tensor, faces_pos: torch.Tensor, uv: torch.Tensor, faces_uv: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
+ """compute TBN matrix for each face
+
+ Args:
+ pos (torch.Tensor): shape (..., N_pos, 3), positions
+ faces_pos (torch.Tensor): shape(T, 3)
+ uv (torch.Tensor): shape (..., N_uv, 2) uv coordinates
+ faces_uv (torch.Tensor): shape(T, 3)
+
+ Returns:
+ torch.Tensor: (..., T, 3, 3) TBN matrix for each face. Note TBN vectors are normalized but not necessarily orthogonal
+ """
+ e01 = torch.index_select(pos, dim=-2, index=faces_pos[:, 1]) - torch.index_select(pos, dim=-2, index=faces_pos[:, 0])
+ e02 = torch.index_select(pos, dim=-2, index=faces_pos[:, 2]) - torch.index_select(pos, dim=-2, index=faces_pos[:, 0])
+ uv01 = torch.index_select(uv, dim=-2, index=faces_uv[:, 1]) - torch.index_select(uv, dim=-2, index=faces_uv[:, 0])
+ uv02 = torch.index_select(uv, dim=-2, index=faces_uv[:, 2]) - torch.index_select(uv, dim=-2, index=faces_uv[:, 0])
+ normal = torch.cross(e01, e02, dim=-1)
+ tangent_bitangent = torch.stack([e01, e02], dim=-1) @ torch.inverse(torch.stack([uv01, uv02], dim=-1))
+ tbn = torch.cat([tangent_bitangent, normal.unsqueeze(-1)], dim=-1)
+ tbn = tbn / (torch.norm(tbn, p=2, dim=-2, keepdim=True) + eps)
+ return tbn
+
+
+def compute_vertex_tbn(faces_topo: torch.Tensor, pos: torch.Tensor, faces_pos: torch.Tensor, uv: torch.Tensor, faces_uv: torch.Tensor) -> torch.Tensor:
+ """compute TBN matrix for each face
+
+ Args:
+ faces_topo (torch.Tensor): (T, 3), face indices of the topology
+ pos (torch.Tensor): shape (..., N_pos, 3), positions
+ faces_pos (torch.Tensor): shape(T, 3)
+ uv (torch.Tensor): shape (..., N_uv, 2) uv coordinates
+ faces_uv (torch.Tensor): shape(T, 3)
+
+ Returns:
+ torch.Tensor: (..., V, 3, 3) TBN matrix for each vertex. Note TBN vectors are normalized but not necessarily orthogonal
+ """
+ n_vertices = faces_topo.max().item() + 1
+ n_tri = faces_topo.shape[-2]
+ batch_shape = pos.shape[:-2]
+ face_tbn = compute_face_tbn(pos, faces_pos, uv, faces_uv) # (..., T, 3, 3)
+ face_tbn = face_tbn[..., :, None, :, :].repeat(*[1] * len(batch_shape), 1, 3, 1, 1).view(*batch_shape, n_tri * 3, 3, 3) # (..., T * 3, 3, 3)
+ vertex_tbn = torch.index_add(torch.zeros(*batch_shape, n_vertices, 3, 3).to(face_tbn), dim=-3, index=faces_topo.view(-1), source=face_tbn)
+ vertex_tbn = vertex_tbn / (torch.norm(vertex_tbn, p=2, dim=-2, keepdim=True) + 1e-7)
+ return vertex_tbn
+
+
+def laplacian(vertices: torch.Tensor, faces: torch.Tensor, weight: str = 'uniform') -> torch.Tensor:
+ """Laplacian smooth with cotangent weights
+
+ Args:
+ vertices (torch.Tensor): shape (..., N, 3)
+ faces (torch.Tensor): shape (T, 3)
+ weight (str): 'uniform' or 'cotangent'
+ """
+ sum_verts = torch.zeros_like(vertices) # (..., N, 3)
+ sum_weights = torch.zeros(*vertices.shape[:-1]).to(vertices) # (..., N)
+ face_verts = torch.index_select(vertices, -2, faces.view(-1)).view(*vertices.shape[:-2], *faces.shape, vertices.shape[-1]) # (..., T, 3, 3)
+ if weight == 'cotangent':
+ for i in range(3):
+ e1 = face_verts[..., (i + 1) % 3, :] - face_verts[..., i, :]
+ e2 = face_verts[..., (i + 2) % 3, :] - face_verts[..., i, :]
+ cot_angle = (e1 * e2).sum(dim=-1) / torch.cross(e1, e2, dim=-1).norm(p=2, dim=-1) # (..., T)
+ sum_verts = torch.index_add(sum_verts, -2, faces[:, (i + 1) % 3], face_verts[..., (i + 2) % 3, :] * cot_angle[..., None])
+ sum_weights = torch.index_add(sum_weights, -1, faces[:, (i + 1) % 3], cot_angle)
+ sum_verts = torch.index_add(sum_verts, -2, faces[:, (i + 2) % 3], face_verts[..., (i + 1) % 3, :] * cot_angle[..., None])
+ sum_weights = torch.index_add(sum_weights, -1, faces[:, (i + 2) % 3], cot_angle)
+ elif weight == 'uniform':
+ for i in range(3):
+ sum_verts = torch.index_add(sum_verts, -2, faces[:, i], face_verts[..., (i + 1) % 3, :])
+ sum_weights = torch.index_add(sum_weights, -1, faces[:, i], torch.ones_like(face_verts[..., i, 0]))
+ else:
+ raise NotImplementedError
+ return sum_verts / (sum_weights[..., None] + 1e-7)
+
+
+def laplacian_smooth_mesh(vertices: torch.Tensor, faces: torch.Tensor, weight: str = 'uniform', times: int = 5) -> torch.Tensor:
+ """Laplacian smooth with cotangent weights
+
+ Args:
+ vertices (torch.Tensor): shape (..., N, 3)
+ faces (torch.Tensor): shape (T, 3)
+ weight (str): 'uniform' or 'cotangent'
+ """
+ for _ in range(times):
+ vertices = laplacian(vertices, faces, weight)
+ return vertices
+
+
+def taubin_smooth_mesh(vertices: torch.Tensor, faces: torch.Tensor, lambda_: float = 0.5, mu_: float = -0.51) -> torch.Tensor:
+ """Taubin smooth mesh
+
+ Args:
+ vertices (torch.Tensor): _description_
+ faces (torch.Tensor): _description_
+ lambda_ (float, optional): _description_. Defaults to 0.5.
+ mu_ (float, optional): _description_. Defaults to -0.51.
+
+ Returns:
+ torch.Tensor: _description_
+ """
+ pt = vertices + lambda_ * (laplacian_smooth_mesh(vertices, faces) - vertices)
+ p = pt + mu_ * (laplacian_smooth_mesh(pt, faces) - pt)
+ return p
+
+
+def laplacian_hc_smooth_mesh(vertices: torch.Tensor, faces: torch.Tensor, times: int = 5, alpha: float = 0.5, beta: float = 0.5, weight: str = 'uniform'):
+ """HC algorithm from Improved Laplacian Smoothing of Noisy Surface Meshes by J.Vollmer et al.
+ """
+ p = vertices
+ for i in range(times):
+ q = p
+ p = laplacian_smooth_mesh(vertices, faces, weight)
+ b = p - (alpha * vertices + (1 - alpha) * q)
+ p = p - (beta * b + (1 - beta) * laplacian_smooth_mesh(b, faces, weight)) * 0.8
+ return p
diff --git a/src/utils3d/torch/nerf.py b/src/utils3d/torch/nerf.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d20bc747255dbb1a68191f93a395a824d76e108
--- /dev/null
+++ b/src/utils3d/torch/nerf.py
@@ -0,0 +1,749 @@
+from typing import *
+from numbers import Number
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from .utils import image_uv
+
+
+__all__ = [
+ 'get_rays',
+ 'get_image_rays',
+ 'get_mipnerf_cones',
+ 'volume_rendering',
+ 'bin_sample',
+ 'importance_sample',
+ 'nerf_render_rays',
+ 'mipnerf_render_rays',
+ 'nerf_render_view',
+ 'mipnerf_render_view',
+ 'InstantNGP',
+]
+
+
+def get_rays(extrinsics: Tensor, intrinsics: Tensor, uv: Tensor) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ extrinsics: (..., 4, 4) extrinsics matrices.
+ intrinsics: (..., 3, 3) intrinsics matrices.
+ uv: (..., n_rays, 2) uv coordinates of the rays.
+
+ Returns:
+ rays_o: (..., 1, 3) ray origins
+ rays_d: (..., n_rays, 3) ray directions.
+ NOTE: ray directions are NOT normalized. They are scaled so that rays_o + rays_d * z gives the world-space point at depth z.
+ """
+ uvz = torch.cat([uv, torch.ones_like(uv[..., :1])], dim=-1).to(extrinsics) # (n_batch, n_views, n_rays, 3)
+
+ with torch.cuda.amp.autocast(enabled=False):
+ inv_transformation = (intrinsics @ extrinsics[..., :3, :3]).inverse()
+ inv_extrinsics = extrinsics.inverse()
+ rays_d = uvz @ inv_transformation.transpose(-1, -2)
+ rays_o = inv_extrinsics[..., None, :3, 3] # (n_batch, n_views, 1, 3)
+ return rays_o, rays_d
+
+
+def get_image_rays(extrinsics: Tensor, intrinsics: Tensor, width: int, height: int) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ extrinsics: (..., 4, 4) extrinsics matrices.
+ intrinsics: (..., 3, 3) intrinsics matrices.
+ width: width of the image.
+ height: height of the image.
+
+ Returns:
+ rays_o: (..., 1, 1, 3) ray origins
+ rays_d: (..., height, width, 3) ray directions.
+ NOTE: ray directions are NOT normalized. They are scaled so that rays_o + rays_d * z gives the world-space point at depth z.
+ """
+ uv = image_uv(height, width).to(extrinsics).flatten(0, 1)
+ rays_o, rays_d = get_rays(extrinsics, intrinsics, uv)
+ rays_o = rays_o.unflatten(-2, (1, 1))
+ rays_d = rays_d.unflatten(-2, (height, width))
+ return rays_o, rays_d
+
+
+def get_mipnerf_cones(rays_o: Tensor, rays_d: Tensor, z_vals: Tensor, pixel_width: Tensor) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ rays_o: (..., n_rays, 3) ray origins
+ rays_d: (..., n_rays, 3) ray directions.
+ z_vals: (..., n_rays, n_samples) z values.
+ pixel_width: (...) pixel width. = 1 / (normalized focal length * width)
+
+ Returns:
+ mu: (..., n_rays, n_samples, 3) cone mu.
+ sigma: (..., n_rays, n_samples, 3, 3) cone sigma.
+ """
+ t_mu = (z_vals[..., 1:] + z_vals[..., :-1]).mul_(0.5)
+ t_delta = (z_vals[..., 1:] - z_vals[..., :-1]).mul_(0.5)
+ t_mu_square = t_mu.square()
+ t_delta_square = t_delta.square()
+ t_delta_quad = t_delta_square.square()
+ mu_t = t_mu + 2.0 * t_mu * t_delta_square / (3.0 * t_mu_square + t_delta_square)
+ sigma_t = t_delta_square / 3.0 - (4.0 / 15.0) * t_delta_quad / (3.0 * t_mu_square + t_delta_square).square() * (12.0 * t_mu_square - t_delta_square)
+ sigma_r = (pixel_width[..., None, None].square() / 3.0) * (t_mu_square / 4.0 + (5.0 / 12.0) * t_delta_square - (4.0 / 15.0) * t_delta_quad / (3.0 * t_mu_square + t_delta_square))
+ points_mu = rays_o[..., None, :] + rays_d[..., None, :] * mu_t[..., None]
+ d_dt = rays_d[..., :, None] * rays_d[..., None, :] # (..., n_rays, 3, 3)
+ points_sigma = sigma_t[..., None, None] * d_dt[..., None, :, :] + sigma_r[..., None, None] * (torch.eye(3).to(rays_o) - d_dt[..., None, :, :])
+ return points_mu, points_sigma
+
+
+def get_pixel_width(intrinsics: Tensor, width: int, height: int) -> Tensor:
+ """
+ Args:
+ intrinsics: (..., 3, 3) intrinsics matrices.
+ width: width of the image.
+ height: height of the image.
+
+ Returns:
+ pixel_width: (...) pixel width. = 1 / (normalized focal length * width)
+ """
+ assert width == height, "Currently, only square images are supported."
+ pixel_width = torch.reciprocal((intrinsics[..., 0, 0] * intrinsics[..., 1, 1]).sqrt() * width)
+ return pixel_width
+
+
+def volume_rendering(color: Tensor, sigma: Tensor, z_vals: Tensor, ray_length: Tensor, rgb: bool = True, depth: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
+ """
+ Given color, sigma and z_vals (linear depth of the sampling points), render the volume.
+
+ NOTE: By default, color and sigma should have one less sample than z_vals, corresponding to the average values within the intervals.
+ If the queried color and sigma are aligned with z_vals, the trapezoidal rule is used to compute the average values within the intervals.
+
+ Args:
+ color: (..., n_samples or n_samples - 1, 3) color values.
+ sigma: (..., n_samples or n_samples - 1) density values.
+ z_vals: (..., n_samples) z values.
+ ray_length: (...) length of the ray
+
+ Returns:
+ rgb: (..., 3) rendered color values.
+ depth: (...) rendered depth values.
+ weights: (..., n_samples - 1) interval weights.
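+
+ The quantities follow the standard quadrature of the volume rendering integral
+ (a sketch of what the code below computes, with delta_i the interval length along the ray):
+ alpha_i = 1 - exp(-sigma_i * delta_i)
+ T_i = exp(-sum_{j < i} sigma_j * delta_j)
+ w_i = T_i * alpha_i, rgb = sum_i w_i * c_i, depth = sum_i w_i * z_i / sum_i w_i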
+ """
+ dists = (z_vals[..., 1:] - z_vals[..., :-1]) * ray_length[..., None]
+ if color.shape[-2] == z_vals.shape[-1]:
+ color = (color[..., 1:, :] + color[..., :-1, :]).mul_(0.5)
+ sigma = (sigma[..., 1:] + sigma[..., :-1]).mul_(0.5)
+ sigma_delta = sigma * dists
+ transparency = (-torch.cat([torch.zeros_like(sigma_delta[..., :1]), sigma_delta[..., :-1]], dim=-1).cumsum(dim=-1)).exp_() # cumsum before exp for numerical stability
+ alpha = 1.0 - (-sigma_delta).exp_()
+ weights = alpha * transparency
+ if rgb:
+ rgb = torch.sum(weights[..., None] * color, dim=-2) if rgb else None
+ if depth:
+ z_vals = (z_vals[..., 1:] + z_vals[..., :-1]).mul_(0.5)
+ depth = torch.sum(weights * z_vals, dim=-1) / weights.sum(dim=-1).clamp_min_(1e-8) if depth else None
+ return rgb, depth, weights
+
+
+def neus_volume_rendering(color: Tensor, sdf: Tensor, s: torch.Tensor, z_vals: Tensor = None, rgb: bool = True, depth: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
+ """
+ Given color, sdf values and z_vals (linear depth of the sampling points), do volume rendering. (NeuS)
+
+ Args:
+ color: (..., n_samples or n_samples - 1, 3) color values.
+ sdf: (..., n_samples) sdf values.
+ s: (..., n_samples) S values of S-density function in NeuS. The standard deviation of such S-density distribution is 1 / s.
+ z_vals: (..., n_samples) z values.
+
+ Returns:
+ rgb: (..., 3) rendered color values.
+ depth: (...) rendered depth values.
+ weights (..., n_samples) weights.
+ """
+
+ if color.shape[-2] == z_vals.shape[-1]:
+ color = (color[..., 1:, :] + color[..., :-1, :]).mul_(0.5)
+
+ sigmoid_sdf = torch.sigmoid(s * sdf)
+ # NeuS discrete opacity: alpha_i = max((Phi_s(sdf_i) - Phi_s(sdf_{i+1})) / Phi_s(sdf_i), 0)
+ alpha = F.relu(1 - sigmoid_sdf[..., 1:] / sigmoid_sdf[..., :-1])
+ transparency = torch.cumprod(torch.cat([torch.ones_like(alpha[..., :1]), 1.0 - alpha], dim=-1), dim=-1)[..., :-1]
+ weights = alpha * transparency
+
+ if rgb:
+ rgb = torch.sum(weights[..., None] * color, dim=-2) if rgb else None
+ if depth:
+ z_vals = (z_vals[..., 1:] + z_vals[..., :-1]).mul_(0.5)
+ depth = torch.sum(weights * z_vals, dim=-1) / weights.sum(dim=-1).clamp_min_(1e-8) if depth else None
+ return rgb, depth, weights
+
+
+def bin_sample(size: Union[torch.Size, Tuple[int, ...]], n_samples: int, min_value: Number, max_value: Number, spacing: Literal['linear', 'inverse_linear'], dtype: torch.dtype = None, device: torch.device = None) -> Tensor:
+ """
+ Uniformly (or uniformly in inverse space) sample z values in `n_samples` bins in range [min_value, max_value].
+ Args:
+ size: size of the rays
+ n_samples: number of samples to be sampled, also the number of bins
+ min_value: minimum value of the range
+ max_value: maximum value of the range
+ spacing: 'linear' or 'inverse_linear'. If 'inverse_linear', the sampling is uniform in inverse (disparity) space.
+
+ Returns:
+ z_rand: (*size, n_samples) sampled z values, sorted in ascending order.
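+
+ Example (illustrative): bin_sample((2, 128), 64, 0.1, 10.0, spacing='linear') returns a
+ (2, 128, 64) tensor; each of the 64 values is jittered uniformly within its own bin.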
+ """
+ if spacing == 'linear':
+ pass
+ elif spacing == 'inverse_linear':
+ min_value = 1.0 / min_value
+ max_value = 1.0 / max_value
+ bin_length = (max_value - min_value) / n_samples
+ z_rand = (torch.rand(*size, n_samples, device=device, dtype=dtype) - 0.5) * bin_length + torch.linspace(min_value + bin_length * 0.5, max_value - bin_length * 0.5, n_samples, device=device, dtype=dtype)
+ if spacing == 'inverse_linear':
+ z_rand = 1.0 / z_rand
+ return z_rand
+
+
+def importance_sample(z_vals: Tensor, weights: Tensor, n_samples: int) -> Tuple[Tensor, Tensor]:
+ """
+ Importance sample z values.
+
+ NOTE: By default, weights should have one less sample than z_vals, in correspondence with the intervals.
+ If weights has the same number of samples as z_vals, we use trapezoidal rule to calculate the average weights in intervals.
+
+ Args:
+ z_vals: (..., n_rays, n_input_samples) z values, sorted in ascending order.
+ weights: (..., n_rays, n_input_samples or n_input_samples - 1) weights.
+ n_samples: number of output samples for importance sampling.
+
+ Returns:
+ z_importance: (..., n_rays, n_samples) importance sampled z values, unsorted.
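+
+ This is inverse-transform sampling of the piecewise-constant PDF defined by the weights:
+ cdf_k = sum_{j <= k} w_j / sum_j w_j, each uniform sample u selects the interval
+ [z_k, z_{k+1}] whose CDF bin contains u, and a point is then drawn uniformly inside it.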
+ """
+ if weights.shape[-1] == z_vals.shape[-1]:
+ weights = (weights[..., 1:] + weights[..., :-1]).mul_(0.5)
+ weights = weights / torch.sum(weights, dim=-1, keepdim=True) # (..., n_rays, n_input_samples - 1)
+ bins_a, bins_b = z_vals[..., :-1], z_vals[..., 1:]
+
+ pdf = weights / torch.sum(weights, dim=-1, keepdim=True) # (..., n_rays, n_input_samples - 1)
+ cdf = torch.cumsum(pdf, dim=-1)
+ u = torch.rand(*z_vals.shape[:-1], n_samples, device=z_vals.device, dtype=z_vals.dtype)
+
+ inds = torch.searchsorted(cdf, u, right=True).clamp(0, cdf.shape[-1] - 1) # (..., n_rays, n_samples)
+
+ bins_a = torch.gather(bins_a, dim=-1, index=inds)
+ bins_b = torch.gather(bins_b, dim=-1, index=inds)
+ z_importance = bins_a + (bins_b - bins_a) * torch.rand_like(u)
+ return z_importance
+
+
+def nerf_render_rays(
+ nerf: Union[Callable[[Tensor, Tensor], Tuple[Tensor, Tensor]], Tuple[Callable[[Tensor], Tuple[Tensor, Tensor]], Callable[[Tensor], Tuple[Tensor, Tensor]]]],
+ rays_o: Tensor, rays_d: Tensor,
+ *,
+ return_dict: bool = False,
+ n_coarse: int = 64, n_fine: int = 64,
+ near: float = 0.1, far: float = 100.0,
+ z_spacing: Literal['linear', 'inverse_linear'] = 'linear',
+):
+ """
+ NeRF rendering of rays. Note that it supports arbitrary batch dimensions (denoted as `...`)
+
+ Args:
+ nerf: nerf model, which takes (points, directions) as input and returns (color, density) as output.
+ If nerf is a tuple, it should be (nerf_coarse, nerf_fine), where nerf_coarse and nerf_fine are two nerf models for coarse and fine stages respectively.
+
+ nerf args:
+ points: (..., n_rays, n_samples, 3)
+ directions: (..., n_rays, n_samples, 3)
+ nerf returns:
+ color: (..., n_rays, n_samples, 3) color values.
+ density: (..., n_rays, n_samples) density values.
+
+ rays_o: (..., n_rays, 3) ray origins
+ rays_d: (..., n_rays, 3) ray directions.
+
+ Returns
+ if return_dict is False, return only the rendered rgb and depth. (If there are separate coarse and fine results, the fine results are returned.)
+ rgb: (..., n_rays, 3) rendered color values.
+ depth: (..., n_rays) rendered depth values.
+ else, return a dict. If `n_fine == 0` or `nerf` is a single model, the dict only contains coarse results:
+ ```
+ {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..}
+ ```
+ If there are two models for coarse and fine stages, the dict contains both coarse and fine results:
+ ```
+ {
+ "coarse": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..},
+ "fine": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..}
+ }
+ ```
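+
+ Example (a minimal sketch with a toy radiance field; `toy_nerf` is a stand-in, not part of this module):
+ >>> def toy_nerf(points, directions):
+ ... return torch.sigmoid(points), F.relu(points[..., 0]) # color, density
+ >>> rays_o, rays_d = torch.zeros(1, 128, 3), torch.randn(1, 128, 3)
+ >>> rgb, depth = nerf_render_rays(toy_nerf, rays_o, rays_d, n_coarse=32, n_fine=32)
+ >>> rgb.shape, depth.shape
+ (torch.Size([1, 128, 3]), torch.Size([1, 128]))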
+ """
+ if isinstance(nerf, tuple):
+ nerf_coarse, nerf_fine = nerf
+ else:
+ nerf_coarse = nerf_fine = nerf
+ # 1. Coarse: bin sampling
+ z_coarse = bin_sample(rays_d.shape[:-1], n_coarse, near, far, device=rays_o.device, dtype=rays_o.dtype, spacing=z_spacing) # (n_batch, n_views, n_rays, n_samples)
+ points_coarse = rays_o[..., None, :] + rays_d[..., None, :] * z_coarse[..., None] # (n_batch, n_views, n_rays, n_samples, 3)
+ ray_length = rays_d.norm(dim=-1)
+
+ # Query color and density
+ color_coarse, density_coarse = nerf_coarse(points_coarse, rays_d[..., None, :].expand_as(points_coarse)) # (n_batch, n_views, n_rays, n_samples, 3), (n_batch, n_views, n_rays, n_samples)
+
+ # Volume rendering
+ with torch.no_grad():
+ rgb_coarse, depth_coarse, weights = volume_rendering(color_coarse, density_coarse, z_coarse, ray_length) # (n_batch, n_views, n_rays, 3), (n_batch, n_views, n_rays, 1), (n_batch, n_views, n_rays, n_samples)
+
+ if n_fine == 0:
+ if return_dict:
+ return {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights, 'z_vals': z_coarse, 'color': color_coarse, 'density': density_coarse}
+ else:
+ return rgb_coarse, depth_coarse
+
+ # 2. Fine: Importance sampling
+ if nerf_coarse is nerf_fine:
+ # If coarse and fine stages share the same model, the points of coarse stage can be reused,
+ # and we only need to query the importance samples of fine stage.
+ z_fine = importance_sample(z_coarse, weights, n_fine)
+ points_fine = rays_o[..., None, :] + rays_d[..., None, :] * z_fine[..., None]
+ color_fine, density_fine = nerf_fine(points_fine, rays_d[..., None, :].expand_as(points_fine))
+
+ # Merge & volume rendering
+ z_vals = torch.cat([z_coarse, z_fine], dim=-1)
+ color = torch.cat([color_coarse, color_fine], dim=-2)
+ density = torch.cat([density_coarse, density_fine], dim=-1)
+ z_vals, sort_inds = torch.sort(z_vals, dim=-1)
+ color = torch.gather(color, dim=-2, index=sort_inds[..., None].expand_as(color))
+ density = torch.gather(density, dim=-1, index=sort_inds)
+ rgb, depth, weights = volume_rendering(color, density, z_vals, ray_length)
+
+ if return_dict:
+ return {'rgb': rgb, 'depth': depth, 'weights': weights, 'z_vals': z_vals, 'color': color, 'density': density}
+ else:
+ return rgb, depth
+ else:
+ # If coarse and fine stages use different models, we need to query the importance samples of both stages.
+ z_fine = importance_sample(z_coarse, weights, n_fine)
+ z_vals = torch.cat([z_coarse, z_fine], dim=-1)
+ points = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., None]
+ color, density = nerf_fine(points)
+ rgb, depth, weights = volume_rendering(color, density, z_vals, ray_length)
+
+ if return_dict:
+ return {
+ 'coarse': {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights, 'z_vals': z_coarse, 'color': color_coarse, 'density': density_coarse},
+ 'fine': {'rgb': rgb, 'depth': depth, 'weights': weights, 'z_vals': z_vals, 'color': color, 'density': density}
+ }
+ else:
+ return rgb, depth
+
+
+def mipnerf_render_rays(
+ mipnerf: Callable[[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]],
+ rays_o: Tensor, rays_d: Tensor, pixel_width: Tensor,
+ *,
+ return_dict: bool = False,
+ n_coarse: int = 64, n_fine: int = 64, uniform_ratio: float = 0.4,
+ near: float = 0.1, far: float = 100.0,
+ z_spacing: Literal['linear', 'inverse_linear'] = 'linear',
+) -> Union[Tuple[Tensor, Tensor], Dict[str, Tensor]]:
+ """
+ MipNeRF rendering.
+
+ Args:
+ mipnerf: mipnerf model, which takes (points_mu, points_sigma) as input and returns (color, density) as output.
+
+ mipnerf args:
+ points_mu: (..., n_rays, n_samples, 3) cone mu.
+ points_sigma: (..., n_rays, n_samples, 3, 3) cone sigma.
+ directions: (..., n_rays, n_samples, 3)
+ mipnerf returns:
+ color: (..., n_rays, n_samples, 3) color values.
+ density: (..., n_rays, n_samples) density values.
+
+ rays_o: (..., n_rays, 3) ray origins
+ rays_d: (..., n_rays, 3) ray directions.
+ pixel_width: (..., n_rays) pixel width. How to compute? pixel_width = 1 / (normalized focal length * width)
+
+ Returns
+ if return_dict is False, return rendered results only: (If `n_fine == 0`, return coarse results, otherwise return fine results)
+ rgb: (..., n_rays, 3) rendered color values.
+ depth: (..., n_rays) rendered depth values.
+ else, return a dict. If `n_fine == 0`, the dict only contains coarse results:
+ ```
+ {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..}
+ ```
+ If n_fine > 0, the dict contains both coarse and fine results :
+ ```
+ {
+ "coarse": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..},
+ "fine": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..}
+ }
+ ```
+ """
+ # 1. Coarse: bin sampling
+ z_coarse = bin_sample(rays_d.shape[:-1], n_coarse, near, far, spacing=z_spacing, device=rays_o.device, dtype=rays_o.dtype)
+ points_mu_coarse, points_sigma_coarse = get_mipnerf_cones(rays_o, rays_d, z_coarse, pixel_width)
+ ray_length = rays_d.norm(dim=-1)
+
+ # Query color and density
+ color_coarse, density_coarse = mipnerf(points_mu_coarse, points_sigma_coarse, rays_d[..., None, :].expand_as(points_mu_coarse)) # (n_batch, n_views, n_rays, n_samples, 3), (n_batch, n_views, n_rays, n_samples)
+
+ # Volume rendering
+ rgb_coarse, depth_coarse, weights_coarse = volume_rendering(color_coarse, density_coarse, z_coarse, ray_length) # (n_batch, n_views, n_rays, 3), (n_batch, n_views, n_rays, 1), (n_batch, n_views, n_rays, n_samples)
+
+ if n_fine == 0:
+ if return_dict:
+ return {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights_coarse, 'z_vals': z_coarse, 'color': color_coarse, 'density': density_coarse}
+ else:
+ return rgb_coarse, depth_coarse
+
+ # 2. Fine: Importance sampling. (NOTE: coarse stages and fine stages always share the same model, but coarse stage points can not be reused)
+ with torch.no_grad():
+ weights_coarse = (1.0 - uniform_ratio) * weights_coarse + uniform_ratio / weights_coarse.shape[-1]
+ z_fine = importance_sample(z_coarse, weights_coarse, n_fine)
+ z_fine, _ = torch.sort(z_fine, dim=-1)
+ points_mu_fine, points_sigma_fine = get_mipnerf_cones(rays_o, rays_d, z_fine, pixel_width)
+ color_fine, density_fine = mipnerf(points_mu_fine, points_sigma_fine, rays_d[..., None, :].expand_as(points_mu_fine))
+
+ # Volume rendering
+ rgb_fine, depth_fine, weights_fine = volume_rendering(color_fine, density_fine, z_fine, ray_length)
+
+ if return_dict:
+ return {
+ 'coarse': {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights_coarse, 'z_vals': z_coarse, 'color': color_coarse, 'density': density_coarse},
+ 'fine': {'rgb': rgb_fine, 'depth': depth_fine, 'weights': weights_fine, 'z_vals': z_fine, 'color': color_fine, 'density': density_fine}
+ }
+ else:
+ return rgb_fine, depth_fine
+
+
+def neus_render_rays(
+ neus: Callable[[Tensor, Tensor], Tuple[Tensor, Tensor]],
+ s: Union[Number, Tensor],
+ rays_o: Tensor, rays_d: Tensor,
+ *,
+ compute_normal: bool = True,
+ return_dict: bool = False,
+ n_coarse: int = 64, n_fine: int = 64,
+ near: float = 0.1, far: float = 100.0,
+ z_spacing: Literal['linear', 'inverse_linear'] = 'linear',
+):
+ """
+ TODO
+ NeuS rendering of rays. Note that it supports arbitrary batch dimensions (denoted as `...`)
+
+ Args:
+ neus: neus model, which takes (points, directions) as input and returns (color, sdf) as output.
+
+ neus args:
+ points: (..., n_rays, n_samples, 3)
+ directions: (..., n_rays, n_samples, 3)
+ neus returns:
+ color: (..., n_rays, n_samples, 3) color values.
+ sdf: (..., n_rays, n_samples) SDF values.
+
+ s: S value of the S-density function in NeuS (scalar or tensor); the standard deviation of the S-density is 1 / s.
+ rays_o: (..., n_rays, 3) ray origins
+ rays_d: (..., n_rays, 3) ray directions.
+
+ Returns
+ if return_dict is False, return rendered results only: (If `n_fine == 0`, return coarse results, otherwise return fine results)
+ rgb: (..., n_rays, 3) rendered color values.
+ depth: (..., n_rays) rendered depth values.
+ else, return a dict. If `n_fine == 0`, the dict only contains coarse results:
+ ```
+ {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'sdf': ..., 'normal': ...}
+ ```
+ If n_fine > 0, the dict contains both coarse and fine results:
+ ```
+ {
+ "coarse": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..},
+ "fine": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..}
+ }
+ ```
+ """
+
+ # 1. Coarse: bin sampling
+ z_coarse = bin_sample(rays_d.shape[:-1], n_coarse, near, far, device=rays_o.device, dtype=rays_o.dtype, spacing=z_spacing) # (n_batch, n_views, n_rays, n_samples)
+ points_coarse = rays_o[..., None, :] + rays_d[..., None, :] * z_coarse[..., None] # (n_batch, n_views, n_rays, n_samples, 3)
+
+ # Query color and density
+ color_coarse, sdf_coarse = neus(points_coarse, rays_d[..., None, :].expand_as(points_coarse)) # (n_batch, n_views, n_rays, n_samples, 3), (n_batch, n_views, n_rays, n_samples)
+
+ # Volume rendering
+ with torch.no_grad():
+ rgb_coarse, depth_coarse, weights = neus_volume_rendering(color_coarse, sdf_coarse, s, z_coarse) # (n_batch, n_views, n_rays, 3), (n_batch, n_views, n_rays, 1), (n_batch, n_views, n_rays, n_samples)
+
+ if n_fine == 0:
+ if return_dict:
+ return {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights, 'z_vals': z_coarse, 'color': color_coarse, 'sdf': sdf_coarse}
+ else:
+ return rgb_coarse, depth_coarse
+
+ # If coarse and fine stages share the same model, the points of coarse stage can be reused,
+ # and we only need to query the importance samples of fine stage.
+ z_fine = importance_sample(z_coarse, weights, n_fine)
+ points_fine = rays_o[..., None, :] + rays_d[..., None, :] * z_fine[..., None]
+ color_fine, sdf_fine = neus(points_fine, rays_d[..., None, :].expand_as(points_fine))
+
+ # Merge & volume rendering
+ z_vals = torch.cat([z_coarse, z_fine], dim=-1)
+ color = torch.cat([color_coarse, color_fine], dim=-2)
+ sdf = torch.cat([sdf_coarse, sdf_fine], dim=-1)
+ z_vals, sort_inds = torch.sort(z_vals, dim=-1)
+ color = torch.gather(color, dim=-2, index=sort_inds[..., None].expand_as(color))
+ sdf = torch.gather(sdf, dim=-1, index=sort_inds)
+ rgb, depth, weights = neus_volume_rendering(color, sdf, s, z_vals)
+
+ if return_dict:
+ return {
+ 'coarse': {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights, 'z_vals': z_coarse, 'color': color_coarse, 'sdf': sdf_coarse},
+ 'fine': {'rgb': rgb, 'depth': depth, 'weights': weights, 'z_vals': z_vals, 'color': color, 'sdf': sdf}
+ }
+ else:
+ return rgb, depth
+
+
+def nerf_render_view(
+ nerf: Tensor,
+ extrinsics: Tensor,
+ intrinsics: Tensor,
+ width: int,
+ height: int,
+ *,
+ patchify: bool = False,
+ patch_size: Tuple[int, int] = (64, 64),
+ **options: Dict[str, Any]
+) -> Tuple[Tensor, Tensor]:
+ """
+ NeRF rendering of views. Note that it supports arbitrary batch dimensions (denoted as `...`)
+
+ Args:
+ nerf: nerf model, which takes (points, directions) as input and returns (color, density).
+ extrinsics: (..., 4, 4) extrinsics matrices of the rendered views
+ intrinsics: (..., 3, 3) intrinsics matrices of the rendered views.
+ width: image width of the rendered views.
+ height: image height of the rendered views.
+ patchify (optional): If the image is too large, render it patch by patch
+ **options: rendering options.
+
+ Returns:
+ rgb: (..., channels, height, width) rendered color values.
+ depth: (..., height, width) rendered depth values.
+ """
+ if patchify:
+ # Patchified rendering
+ max_patch_width, max_patch_height = patch_size
+ n_rows, n_columns = math.ceil(height / max_patch_height), math.ceil(width / max_patch_width)
+
+ rgb_rows, depth_rows = [], []
+ for i_row in range(n_rows):
+ rgb_row, depth_row = [], []
+ for i_column in range(n_columns):
+ patch_shape = patch_height, patch_width = min(max_patch_height, height - i_row * max_patch_height), min(max_patch_width, width - i_column * max_patch_width)
+ uv = image_uv(height, width, i_column * max_patch_width, i_row * max_patch_height, i_column * max_patch_width + patch_width, i_row * max_patch_height + patch_height).to(extrinsics)
+ uv = uv.flatten(0, 1) # (patch_height * patch_width, 2)
+ ray_o_, ray_d_ = get_rays(extrinsics, intrinsics, uv)
+ rgb_, depth_ = nerf_render_rays(nerf, ray_o_, ray_d_, **options, return_dict=False)
+ rgb_ = rgb_.transpose(-1, -2).unflatten(-1, patch_shape) # (..., 3, patch_height, patch_width)
+ depth_ = depth_.unflatten(-1, patch_shape) # (..., patch_height, patch_width)
+
+ rgb_row.append(rgb_)
+ depth_row.append(depth_)
+ rgb_rows.append(torch.cat(rgb_row, dim=-1))
+ depth_rows.append(torch.cat(depth_row, dim=-1))
+ rgb = torch.cat(rgb_rows, dim=-2)
+ depth = torch.cat(depth_rows, dim=-2)
+
+ return rgb, depth
+ else:
+ # Full rendering
+ uv = image_uv(height, width).to(extrinsics)
+ uv = uv.flatten(0, 1) # (height * width, 2)
+ ray_o_, ray_d_ = get_rays(extrinsics, intrinsics, uv)
+ rgb, depth = nerf_render_rays(nerf, ray_o_, ray_d_, **options, return_dict=False)
+ rgb = rgb.transpose(-1, -2).unflatten(-1, (height, width)) # (..., 3, height, width)
+ depth = depth.unflatten(-1, (height, width)) # (..., height, width)
+
+ return rgb, depth
+
+
+def mipnerf_render_view(
+ mipnerf: Tensor,
+ extrinsics: Tensor,
+ intrinsics: Tensor,
+ width: int,
+ height: int,
+ *,
+ patchify: bool = False,
+ patch_size: Tuple[int, int] = (64, 64),
+ **options: Dict[str, Any]
+) -> Tuple[Tensor, Tensor]:
+ """
+ MipNeRF rendering of views. Note that it supports arbitrary batch dimensions (denoted as `...`)
+
+ Args:
+ mipnerf: mipnerf model, which takes (points_mu, points_sigma, directions) as input and returns (color, density).
+ extrinsics: (..., 4, 4) extrinsics matrices of the rendered views
+ intrinsics: (..., 3, 3) intrinsics matrices of the rendered views.
+ width: image width of the rendered views.
+ height: image height of the rendered views.
+ patchify (optional): If the image is too large, render it patch by patch
+ **options: rendering options.
+
+ Returns:
+ rgb: (..., 3, height, width) rendered color values.
+ depth: (..., height, width) rendered depth values.
+ """
+ pixel_width = get_pixel_width(intrinsics, width, height)
+
+ if patchify:
+ # Patchified rendering
+ max_patch_width, max_patch_height = patch_size
+ n_rows, n_columns = math.ceil(height / max_patch_height), math.ceil(width / max_patch_width)
+
+ rgb_rows, depth_rows = [], []
+ for i_row in range(n_rows):
+ rgb_row, depth_row = [], []
+ for i_column in range(n_columns):
+ patch_shape = patch_height, patch_width = min(max_patch_height, height - i_row * max_patch_height), min(max_patch_width, width - i_column * max_patch_width)
+ uv = image_uv(height, width, i_column * max_patch_width, i_row * max_patch_height, i_column * max_patch_width + patch_width, i_row * max_patch_height + patch_height).to(extrinsics)
+ uv = uv.flatten(0, 1) # (patch_height * patch_width, 2)
+ ray_o_, ray_d_ = get_rays(extrinsics, intrinsics, uv)
+ rgb_, depth_ = mipnerf_render_rays(mipnerf, ray_o_, ray_d_, pixel_width, **options)
+ rgb_ = rgb_.transpose(-1, -2).unflatten(-1, patch_shape) # (..., 3, patch_height, patch_width)
+ depth_ = depth_.unflatten(-1, patch_shape) # (..., patch_height, patch_width)
+
+ rgb_row.append(rgb_)
+ depth_row.append(depth_)
+ rgb_rows.append(torch.cat(rgb_row, dim=-1))
+ depth_rows.append(torch.cat(depth_row, dim=-1))
+ rgb = torch.cat(rgb_rows, dim=-2)
+ depth = torch.cat(depth_rows, dim=-2)
+
+ return rgb, depth
+ else:
+ # Full rendering
+ uv = image_uv(height, width).to(extrinsics)
+ uv = uv.flatten(0, 1) # (height * width, 2)
+ ray_o_, ray_d_ = get_rays(extrinsics, intrinsics, uv)
+ rgb, depth = mipnerf_render_rays(mipnerf, ray_o_, ray_d_, pixel_width, **options)
+ rgb = rgb.transpose(-1, -2).unflatten(-1, (height, width)) # (..., 3, height, width)
+ depth = depth.unflatten(-1, (height, width)) # (..., height, width)
+
+ return rgb, depth
+
+
+class InstantNGP(nn.Module):
+ """
+ An implementation of InstantNGP, Müller et al., https://nvlabs.github.io/instant-ngp/.
+ Requires `tinycudann` package.
+ Install it by:
+ ```
+ pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
+ ```
+ """
+ def __init__(self,
+ view_dependent: bool = True,
+ base_resolution: int = 16,
+ finest_resolution: int = 2048,
+ n_levels: int = 16,
+ num_layers_density: int = 2,
+ hidden_dim_density: int = 64,
+ num_layers_color: int = 3,
+ hidden_dim_color: int = 64,
+ log2_hashmap_size: int = 19,
+ bound: float = 1.0,
+ color_channels: int = 3,
+ ):
+ super().__init__()
+ import tinycudann
+ N_FEATURES_PER_LEVEL = 2
+ GEO_FEAT_DIM = 15
+
+ self.bound = bound
+ self.color_channels = color_channels
+
+ # density network
+ self.num_layers_density = num_layers_density
+ self.hidden_dim_density = hidden_dim_density
+
+ per_level_scale = (finest_resolution / base_resolution) ** (1 / (n_levels - 1))
+
+ self.encoder = tinycudann.Encoding(
+ n_input_dims=3,
+ encoding_config={
+ "otype": "HashGrid",
+ "n_levels": n_levels,
+ "n_features_per_level": N_FEATURES_PER_LEVEL,
+ "log2_hashmap_size": log2_hashmap_size,
+ "base_resolution": base_resolution,
+ "per_level_scale": per_level_scale,
+ },
+ )
+
+ self.density_net = tinycudann.Network(
+ n_input_dims=N_FEATURES_PER_LEVEL * n_levels,
+ n_output_dims=1 + GEO_FEAT_DIM,
+ network_config={
+ "otype": "FullyFusedMLP",
+ "activation": "ReLU",
+ "output_activation": "None",
+ "n_neurons": hidden_dim_density,
+ "n_hidden_layers": num_layers_density - 1,
+ },
+ )
+
+ # color network
+ self.num_layers_color = num_layers_color
+ self.hidden_dim_color = hidden_dim_color
+
+ self.view_dependent = view_dependent
+ if view_dependent:
+ self.encoder_dir = tinycudann.Encoding(
+ n_input_dims=3,
+ encoding_config={
+ "otype": "SphericalHarmonics",
+ "degree": 4,
+ },
+ )
+ self.in_dim_color = self.encoder_dir.n_output_dims + GEO_FEAT_DIM
+ else:
+ self.in_dim_color = GEO_FEAT_DIM
+
+ self.color_net = tinycudann.Network(
+ n_input_dims=self.in_dim_color,
+ n_output_dims=color_channels,
+ network_config={
+ "otype": "FullyFusedMLP",
+ "activation": "ReLU",
+ "output_activation": "None",
+ "n_neurons": hidden_dim_color,
+ "n_hidden_layers": num_layers_color - 1,
+ },
+ )
+
+ def forward(self, x: torch.Tensor, d: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ x: (..., 3) points
+ d: (..., 3) directions
+ Returns:
+ color: (..., color_channels) color values.
+ density: (...) density values.
+ """
+ batch_shape = x.shape[:-1]
+ x, d = x.reshape(-1, 3), d.reshape(-1, 3)
+
+ # density
+ x = (x + self.bound) / (2 * self.bound) # to [0, 1]
+ x = self.encoder(x)
+ density, geo_feat = self.density_net(x).split([1, 15], dim=-1)
+ density = F.softplus(density).squeeze(-1)
+
+ # color
+ if self.view_dependent:
+ d = (F.normalize(d, dim=-1) + 1) / 2 # tcnn SH encoding requires inputs to be in [0, 1]
+ d = self.encoder_dir(d)
+ h = torch.cat([d, geo_feat], dim=-1)
+ else:
+ h = geo_feat
+ color = self.color_net(h)
+
+ return color.reshape(*batch_shape, self.color_channels), density.reshape(*batch_shape)
+
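+# Usage sketch (illustrative comment only, not executed): querying an InstantNGP
+# field directly. Requires a CUDA device and the `tinycudann` package; the sizes
+# below are arbitrary.
+#
+#   field = InstantNGP(bound=1.0).cuda()
+#   x = torch.rand(4096, 3, device='cuda') * 2 - 1                 # points in [-bound, bound]
+#   d = F.normalize(torch.randn(4096, 3, device='cuda'), dim=-1)   # unit view directions
+#   color, density = field(x, d)                                   # (4096, 3), (4096,)
+#   # The same object can be passed as the radiance field to nerf_render_view above.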
diff --git a/src/utils3d/torch/rasterization.py b/src/utils3d/torch/rasterization.py
new file mode 100644
index 0000000000000000000000000000000000000000..11802737ebeae9be2e6b7bda7ee0d933b01ab909
--- /dev/null
+++ b/src/utils3d/torch/rasterization.py
@@ -0,0 +1,392 @@
+from typing import *
+
+import torch
+import nvdiffrast.torch as dr
+
+from . import utils, transforms, mesh
+from ._helpers import batched
+
+
+__all__ = [
+ 'RastContext',
+ 'rasterize_triangle_faces',
+ 'warp_image_by_depth',
+ 'warp_image_by_forward_flow',
+]
+
+
+class RastContext:
+ """
+ Create a rasterization context. Nothing but a wrapper of nvdiffrast.torch.RasterizeCudaContext or nvdiffrast.torch.RasterizeGLContext.
+ """
+ def __init__(self, nvd_ctx: Union[dr.RasterizeCudaContext, dr.RasterizeGLContext] = None, *, backend: Literal['cuda', 'gl'] = 'gl', device: Union[str, torch.device] = None):
+ if nvd_ctx is not None:
+ self.nvd_ctx = nvd_ctx
+ return
+
+ if backend == 'gl':
+ self.nvd_ctx = dr.RasterizeGLContext(device=device)
+ elif backend == 'cuda':
+ self.nvd_ctx = dr.RasterizeCudaContext(device=device)
+ else:
+ raise ValueError(f'Unknown backend: {backend}')
+
+
+def rasterize_triangle_faces(
+ ctx: RastContext,
+ vertices: torch.Tensor,
+ faces: torch.Tensor,
+ width: int,
+ height: int,
+ attr: torch.Tensor = None,
+ uv: torch.Tensor = None,
+ texture: torch.Tensor = None,
+ model: torch.Tensor = None,
+ view: torch.Tensor = None,
+ projection: torch.Tensor = None,
+ antialiasing: Union[bool, List[int]] = True,
+ diff_attrs: Union[None, List[int]] = None,
+) -> Dict[str, torch.Tensor]:
+ """
+ Rasterize a mesh with vertex attributes.
+
+ Args:
+ ctx (RastContext): rasterizer context
+ vertices (torch.Tensor): (B, N, 2 or 3 or 4)
+ faces (torch.Tensor): (T, 3)
+ width (int): width of the output image
+ height (int): height of the output image
+ attr (torch.Tensor, optional): (B, N, C) vertex attributes. Defaults to None.
+ uv (torch.Tensor, optional): (B, N, 2) uv coordinates. Defaults to None.
+ texture (torch.Tensor, optional): (B, H, W, C) texture. Defaults to None.
+ model (torch.Tensor, optional): ([B,] 4, 4) model matrix. Defaults to None (identity).
+ view (torch.Tensor, optional): ([B,] 4, 4) view matrix. Defaults to None (identity).
+ projection (torch.Tensor, optional): ([B,] 4, 4) projection matrix. Defaults to None (identity).
+ antialiasing (Union[bool, List[int]], optional): whether to perform antialiasing. Defaults to True. If a list of indices is provided, only those channels will be antialiased.
+ diff_attrs (Union[None, List[int]], optional): indices of attributes to compute screen-space derivatives. Defaults to None.
+
+ Returns:
+ Dictionary containing:
+ - image: (torch.Tensor): (B, C, H, W)
+ - depth: (torch.Tensor): (B, H, W) screen space depth, ranging from 0 (near) to 1 (far).
+ NOTE: Empty pixels will have depth 1., i.e. the far plane.
+ - mask: (torch.BoolTensor): (B, H, W) mask of valid pixels
+ - image_dr: (torch.Tensor): (B, 4, H, W) screen space derivatives of the attributes
+ - face_id: (torch.Tensor): (B, H, W) face ids
+ - uv: (torch.Tensor): (B, H, W, 2) rasterized uv coordinates (if uv is not None)
+ - uv_dr: (torch.Tensor): (B, H, W, 4) screen-space derivatives of uv (if uv is not None)
+ - texture: (torch.Tensor): (B, C, H, W) sampled texture (if uv and texture are not None)
+ """
+ assert vertices.ndim == 3
+ assert faces.ndim == 2
+
+ if vertices.shape[-1] == 2:
+ vertices = torch.cat([vertices, torch.zeros_like(vertices[..., :1]), torch.ones_like(vertices[..., :1])], dim=-1)
+ elif vertices.shape[-1] == 3:
+ vertices = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1)
+ elif vertices.shape[-1] == 4:
+ pass
+ else:
+ raise ValueError(f'Wrong shape of vertices: {vertices.shape}')
+
+ mvp = projection if projection is not None else torch.eye(4).to(vertices)
+ if view is not None:
+ mvp = mvp @ view
+ if model is not None:
+ mvp = mvp @ model
+
+ pos_clip = vertices @ mvp.transpose(-1, -2)
+ faces = faces.contiguous()
+ if attr is not None:
+ attr = attr.contiguous()
+
+ rast_out, rast_db = dr.rasterize(ctx.nvd_ctx, pos_clip, faces, resolution=[height, width], grad_db=True)
+ face_id = rast_out[..., 3].flip(1)
+ depth = rast_out[..., 2].flip(1)
+ mask = (face_id > 0).float()
+ depth = (depth * 0.5 + 0.5) * mask + (1.0 - mask)
+
+ ret = {
+ 'depth': depth,
+ 'mask': mask,
+ 'face_id': face_id,
+ }
+
+ if attr is not None:
+ image, image_dr = dr.interpolate(attr, rast_out, faces, rast_db, diff_attrs=diff_attrs)
+ if antialiasing == True:
+ image = dr.antialias(image, rast_out, pos_clip, faces)
+ elif isinstance(antialiasing, list):
+ aa_image = dr.antialias(image[..., antialiasing], rast_out, pos_clip, faces)
+ image[..., antialiasing] = aa_image
+ image = image.flip(1).permute(0, 3, 1, 2)
+ ret['image'] = image
+
+ if uv is not None:
+ uv_map, uv_map_dr = dr.interpolate(uv, rast_out, faces, rast_db, diff_attrs='all')
+ ret['uv'] = uv_map
+ ret['uv_dr'] = uv_map_dr
+ if texture is not None:
+ texture_map = dr.texture(texture, uv_map, uv_map_dr)
+ ret['texture'] = texture_map.flip(1).permute(0, 3, 1, 2)
+
+ if attr is not None and diff_attrs is not None:
+ image_dr = image_dr.flip(1).permute(0, 3, 1, 2)
+ ret['image_dr'] = image_dr
+
+ return ret
+
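+# Usage sketch (illustrative comment only, not executed): rasterizing a single
+# triangle with per-vertex colors. The backend, device and image size are
+# arbitrary choices.
+#
+#   ctx = RastContext(backend='cuda')
+#   vertices = torch.tensor([[[-0.5, -0.5, 0.0], [0.5, -0.5, 0.0], [0.0, 0.5, 0.0]]], device='cuda')
+#   faces = torch.tensor([[0, 1, 2]], dtype=torch.int32, device='cuda')
+#   attr = torch.eye(3, device='cuda')[None]     # one RGB color per vertex, (1, 3, 3)
+#   out = rasterize_triangle_faces(ctx, vertices, faces, 256, 256, attr=attr)
+#   image, depth, mask = out['image'], out['depth'], out['mask']   # (1, 3, 256, 256), (1, 256, 256), (1, 256, 256)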
+
+def texture(
+ ctx: RastContext,
+ uv: torch.Tensor,
+ uv_da: torch.Tensor,
+ texture: torch.Tensor,
+) -> torch.Tensor:
+ """
+ Sample a texture map at the given uv coordinates with screen-space derivatives (thin wrapper around nvdiffrast's dr.texture).
+ """
+ return dr.texture(texture, uv, uv_da)
+
+
+def warp_image_by_depth(
+ ctx: RastContext,
+ depth: torch.FloatTensor,
+ image: torch.FloatTensor = None,
+ mask: torch.BoolTensor = None,
+ width: int = None,
+ height: int = None,
+ *,
+ extrinsics_src: torch.FloatTensor = None,
+ extrinsics_tgt: torch.FloatTensor = None,
+ intrinsics_src: torch.FloatTensor = None,
+ intrinsics_tgt: torch.FloatTensor = None,
+ near: float = 0.1,
+ far: float = 100.0,
+ antialiasing: bool = True,
+ backslash: bool = False,
+ padding: int = 0,
+ return_uv: bool = False,
+ return_dr: bool = False,
+) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.BoolTensor, Optional[torch.FloatTensor], Optional[torch.FloatTensor]]:
+ """
+ Warp image by depth.
+ NOTE: if batch size is 1, image mesh will be triangulated aware of the depth, yielding less distorted results.
+ Otherwise, image mesh will be triangulated simply for batch rendering.
+
+ Args:
+ ctx (Union[dr.RasterizeCudaContext, dr.RasterizeGLContext]): rasterization context
+ depth (torch.Tensor): (B, H, W) linear depth
+ image (torch.Tensor): (B, C, H, W). None to use image space uv. Defaults to None.
+ mask (torch.BoolTensor, optional): (B, H, W) mask of valid pixels in the source image. Defaults to None.
+ width (int, optional): width of the output image. None to use the same as depth. Defaults to None.
+ height (int, optional): height of the output image. None to use the same as depth. Defaults to None.
+ extrinsics_src (torch.Tensor, optional): (B, 4, 4) extrinsics matrix for source. None to use identity. Defaults to None.
+ extrinsics_tgt (torch.Tensor, optional): (B, 4, 4) extrinsics matrix for target. None to use identity. Defaults to None.
+ intrinsics_src (torch.Tensor, optional): (B, 3, 3) intrinsics matrix for source. None to use the same as target. Defaults to None.
+ intrinsics_tgt (torch.Tensor, optional): (B, 3, 3) intrinsics matrix for target. None to use the same as source. Defaults to None.
+ near (float, optional): near plane. Defaults to 0.1.
+ far (float, optional): far plane. Defaults to 100.0.
+ antialiasing (bool, optional): whether to perform antialiasing. Defaults to True.
+ backslash (bool, optional): whether to use backslash triangulation. Defaults to False.
+ padding (int, optional): padding of the image. Defaults to 0.
+ return_uv (bool, optional): whether to return the uv. Defaults to False.
+ return_dr (bool, optional): whether to return the image-space derivatives of uv. Defaults to False.
+
+ Returns:
+ image: (torch.FloatTensor): (B, C, H, W) rendered image
+ depth: (torch.FloatTensor): (B, H, W) linear depth, ranging from 0 to inf
+ mask: (torch.BoolTensor): (B, H, W) mask of valid pixels
+ uv: (torch.FloatTensor): (B, 2, H, W) image-space uv
+ dr: (torch.FloatTensor): (B, 4, H, W) image-space derivatives of uv
+ """
+ assert depth.ndim == 3
+ batch_size = depth.shape[0]
+
+ if width is None:
+ width = depth.shape[-1]
+ if height is None:
+ height = depth.shape[-2]
+ if image is not None:
+ assert image.shape[-2:] == depth.shape[-2:], f'Shape of image {image.shape} does not match shape of depth {depth.shape}'
+
+ if extrinsics_src is None:
+ extrinsics_src = torch.eye(4).to(depth)
+ if extrinsics_tgt is None:
+ extrinsics_tgt = torch.eye(4).to(depth)
+ if intrinsics_src is None:
+ intrinsics_src = intrinsics_tgt
+ if intrinsics_tgt is None:
+ intrinsics_tgt = intrinsics_src
+
+ assert all(x is not None for x in [extrinsics_src, extrinsics_tgt, intrinsics_src, intrinsics_tgt]), "Make sure you have provided all the necessary camera parameters."
+
+ view_tgt = transforms.extrinsics_to_view(extrinsics_tgt)
+ perspective_tgt = transforms.intrinsics_to_perspective(intrinsics_tgt, near=near, far=far)
+
+ if padding > 0:
+ uv, faces = utils.image_mesh(width=width+2, height=height+2)
+ uv = (uv - 1 / (width + 2)) * ((width + 2) / width)
+ uv_ = uv.clone().reshape(height+2, width+2, 2)
+ uv_[0, :, 1] -= padding / height
+ uv_[-1, :, 1] += padding / height
+ uv_[:, 0, 0] -= padding / width
+ uv_[:, -1, 0] += padding / width
+ uv_ = uv_.reshape(-1, 2)
+ depth = torch.nn.functional.pad(depth, [1, 1, 1, 1], mode='replicate')
+ if image is not None:
+ image = torch.nn.functional.pad(image, [1, 1, 1, 1], mode='replicate')
+ uv, uv_, faces = uv.to(depth.device), uv_.to(depth.device), faces.to(depth.device)
+ pts = transforms.unproject_cv(
+ uv_,
+ depth.flatten(-2, -1),
+ extrinsics_src,
+ intrinsics_src,
+ )
+ else:
+ uv, faces = utils.image_mesh(width=depth.shape[-1], height=depth.shape[-2])
+ if mask is not None:
+ depth = torch.where(mask, depth, torch.tensor(far, dtype=depth.dtype, device=depth.device))
+ uv, faces = uv.to(depth.device), faces.to(depth.device)
+ pts = transforms.unproject_cv(
+ uv,
+ depth.flatten(-2, -1),
+ extrinsics_src,
+ intrinsics_src,
+ )
+
+ # triangulate
+ if batch_size == 1:
+ faces = mesh.triangulate(faces, vertices=pts[0])
+ else:
+ faces = mesh.triangulate(faces, backslash=backslash)
+
+ # rasterize attributes
+ diff_attrs = None
+ if image is not None:
+ attr = image.permute(0, 2, 3, 1).flatten(1, 2)
+ if return_dr or return_uv:
+ if return_dr:
+ diff_attrs = [image.shape[1], image.shape[1]+1]
+ if return_uv and antialiasing:
+ antialiasing = list(range(image.shape[1]))
+ attr = torch.cat([attr, uv.expand(batch_size, -1, -1)], dim=-1)
+ else:
+ attr = uv.expand(batch_size, -1, -1)
+ if antialiasing:
+ print("\033[93mWarning: you are performing antialiasing on uv. This may cause artifacts.\033[0m")
+ if return_uv:
+ return_uv = False
+ print("\033[93mWarning: image is None, return_uv is ignored.\033[0m")
+ if return_dr:
+ diff_attrs = [0, 1]
+
+ if mask is not None:
+ attr = torch.cat([attr, mask.float().flatten(1, 2).unsqueeze(-1)], dim=-1)
+
+ rast = rasterize_triangle_faces(
+ ctx,
+ pts,
+ faces,
+ width,
+ height,
+ attr=attr,
+ view=view_tgt,
+ projection=perspective_tgt,
+ antialiasing=antialiasing,
+ diff_attrs=diff_attrs,
+ )
+ if return_dr:
+ output_image, screen_depth, output_dr = rast['image'], rast['depth'], rast['image_dr']
+ else:
+ output_image, screen_depth = rast['image'], rast['depth']
+ output_mask = screen_depth < 1.0
+
+ if mask is not None:
+ output_image, rast_mask = output_image[..., :-1, :, :], output_image[..., -1, :, :]
+ output_mask &= (rast_mask > 0.9999).reshape(-1, height, width)
+
+ if (return_dr or return_uv) and image is not None:
+ output_image, output_uv = output_image[..., :-2, :, :], output_image[..., -2:, :, :]
+
+ output_depth = transforms.depth_buffer_to_linear(screen_depth, near=near, far=far) * output_mask
+ output_image = output_image * output_mask.unsqueeze(1)
+
+ outs = [output_image, output_depth, output_mask]
+ if return_uv:
+ outs.append(output_uv)
+ if return_dr:
+ outs.append(output_dr)
+ return tuple(outs)
+
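+# Usage sketch (illustrative comment only, not executed): warping a source RGB-D
+# frame into a nearby target view. All tensors (`image_src`, `depth_src`, the
+# camera matrices) are hypothetical placeholders; intrinsics are assumed to be
+# normalized to uv space.
+#
+#   ctx = RastContext(backend='cuda')
+#   warped_rgb, warped_depth, valid = warp_image_by_depth(
+#       ctx, depth_src, image_src,
+#       extrinsics_src=extr_src, extrinsics_tgt=extr_tgt,
+#       intrinsics_src=intr, intrinsics_tgt=intr,
+#   )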
+
+def warp_image_by_forward_flow(
+ ctx: RastContext,
+ image: torch.FloatTensor,
+ flow: torch.FloatTensor,
+ depth: torch.FloatTensor = None,
+ *,
+ antialiasing: bool = True,
+ backslash: bool = False,
+) -> Tuple[torch.FloatTensor, torch.BoolTensor]:
+ """
+ Warp image by forward flow.
+ NOTE: if batch size is 1, image mesh will be triangulated aware of the depth, yielding less distorted results.
+ Otherwise, image mesh will be triangulated simply for batch rendering.
+
+ Args:
+ ctx (Union[dr.RasterizeCudaContext, dr.RasterizeGLContext]): rasterization context
+ image (torch.Tensor): (B, C, H, W) image
+ flow (torch.Tensor): (B, 2, H, W) forward flow
+ depth (torch.Tensor, optional): (B, H, W) linear depth. If None, will use the same for all pixels. Defaults to None.
+ antialiasing (bool, optional): whether to perform antialiasing. Defaults to True.
+ backslash (bool, optional): whether to use backslash triangulation. Defaults to False.
+
+ Returns:
+ image: (torch.FloatTensor): (B, C, H, W) rendered image
+ mask: (torch.BoolTensor): (B, H, W) mask of valid pixels
+ """
+ assert image.ndim == 4, f'Wrong shape of image: {image.shape}'
+ batch_size, _, height, width = image.shape
+
+ if depth is None:
+ depth = torch.ones_like(flow[:, 0])
+
+ extrinsics = torch.eye(4).to(image)
+ fov = torch.deg2rad(torch.tensor([45.0], device=image.device))
+ intrinsics = transforms.intrinsics_from_fov(fov_max=fov, width=width, height=height)[0]
+
+ view = transforms.extrinsics_to_view(extrinsics)
+ perspective = transforms.intrinsics_to_perspective(intrinsics, near=0.1, far=100)
+
+ uv, faces = utils.image_mesh(width=width, height=height)
+ uv, faces = uv.to(image.device), faces.to(image.device)
+ uv = uv + flow.permute(0, 2, 3, 1).flatten(1, 2)
+ pts = transforms.unproject_cv(
+ uv,
+ depth.flatten(-2, -1),
+ extrinsics,
+ intrinsics,
+ )
+
+ # triangulate
+ if batch_size == 1:
+ faces = mesh.triangulate(faces, vertices=pts[0])
+ else:
+ faces = mesh.triangulate(faces, backslash=backslash)
+
+ # rasterize attributes
+ attr = image.permute(0, 2, 3, 1).flatten(1, 2)
+ rast = rasterize_triangle_faces(
+ ctx,
+ pts,
+ faces,
+ width,
+ height,
+ attr=attr,
+ view=view,
+ projection=perspective,
+ antialiasing=antialiasing,
+ )
+ output_image, screen_depth = rast['image'], rast['depth']
+ output_mask = screen_depth < 1.0
+ output_image = output_image * output_mask.unsqueeze(1)
+
+ outs = [output_image, output_mask]
+ return tuple(outs)
diff --git a/src/utils3d/torch/transforms.py b/src/utils3d/torch/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..46e61d741c6ef000c80aa65201b55e31ed4246c6
--- /dev/null
+++ b/src/utils3d/torch/transforms.py
@@ -0,0 +1,1189 @@
+from typing import *
+from numbers import Number
+
+import torch
+import torch.nn.functional as F
+
+from ._helpers import batched
+
+
+__all__ = [
+ 'perspective',
+ 'perspective_from_fov',
+ 'perspective_from_fov_xy',
+ 'intrinsics_from_focal_center',
+ 'intrinsics_from_fov',
+ 'intrinsics_from_fov_xy',
+ 'view_look_at',
+ 'extrinsics_look_at',
+ 'perspective_to_intrinsics',
+ 'intrinsics_to_perspective',
+ 'extrinsics_to_view',
+ 'view_to_extrinsics',
+ 'normalize_intrinsics',
+ 'crop_intrinsics',
+ 'pixel_to_uv',
+ 'pixel_to_ndc',
+ 'uv_to_pixel',
+ 'project_depth',
+ 'depth_buffer_to_linear',
+ 'project_gl',
+ 'project_cv',
+ 'unproject_gl',
+ 'unproject_cv',
+ 'skew_symmetric',
+ 'rotation_matrix_from_vectors',
+ 'euler_axis_angle_rotation',
+ 'euler_angles_to_matrix',
+ 'matrix_to_euler_angles',
+ 'matrix_to_quaternion',
+ 'quaternion_to_matrix',
+ 'matrix_to_axis_angle',
+ 'axis_angle_to_matrix',
+ 'axis_angle_to_quaternion',
+ 'quaternion_to_axis_angle',
+ 'slerp',
+ 'interpolate_extrinsics',
+ 'interpolate_view',
+ 'extrinsics_to_essential',
+ 'to4x4',
+ 'rotation_matrix_2d',
+ 'rotate_2d',
+ 'translate_2d',
+ 'scale_2d',
+ 'apply_2d',
+]
+
+
+@batched(0,0,0,0)
+def perspective(
+ fov_y: Union[float, torch.Tensor],
+ aspect: Union[float, torch.Tensor],
+ near: Union[float, torch.Tensor],
+ far: Union[float, torch.Tensor]
+ ) -> torch.Tensor:
+ """
+ Get OpenGL perspective matrix
+
+ Args:
+ fov_y (float | torch.Tensor): field of view in y axis
+ aspect (float | torch.Tensor): aspect ratio
+ near (float | torch.Tensor): near plane to clip
+ far (float | torch.Tensor): far plane to clip
+
+ Returns:
+ (torch.Tensor): [..., 4, 4] perspective matrix
+ """
+ N = fov_y.shape[0]
+ ret = torch.zeros((N, 4, 4), dtype=fov_y.dtype, device=fov_y.device)
+ ret[:, 0, 0] = 1. / (torch.tan(fov_y / 2) * aspect)
+ ret[:, 1, 1] = 1. / (torch.tan(fov_y / 2))
+ ret[:, 2, 2] = (near + far) / (near - far)
+ ret[:, 2, 3] = 2. * near * far / (near - far)
+ ret[:, 3, 2] = -1.
+ return ret
+
+
+def perspective_from_fov(
+ fov: Union[float, torch.Tensor],
+ width: Union[int, torch.Tensor],
+ height: Union[int, torch.Tensor],
+ near: Union[float, torch.Tensor],
+ far: Union[float, torch.Tensor]
+ ) -> torch.Tensor:
+ """
+ Get OpenGL perspective matrix from field of view in largest dimension
+
+ Args:
+ fov (float | torch.Tensor): field of view in largest dimension
+ width (int | torch.Tensor): image width
+ height (int | torch.Tensor): image height
+ near (float | torch.Tensor): near plane to clip
+ far (float | torch.Tensor): far plane to clip
+
+ Returns:
+ (torch.Tensor): [..., 4, 4] perspective matrix
+ """
+ fov_y = 2 * torch.atan(torch.tan(fov / 2) * height / torch.maximum(width, height))
+ aspect = width / height
+ return perspective(fov_y, aspect, near, far)
+
+
+def perspective_from_fov_xy(
+ fov_x: Union[float, torch.Tensor],
+ fov_y: Union[float, torch.Tensor],
+ near: Union[float, torch.Tensor],
+ far: Union[float, torch.Tensor]
+ ) -> torch.Tensor:
+ """
+ Get OpenGL perspective matrix from field of view in x and y axis
+
+ Args:
+ fov_x (float | torch.Tensor): field of view in x axis
+ fov_y (float | torch.Tensor): field of view in y axis
+ near (float | torch.Tensor): near plane to clip
+ far (float | torch.Tensor): far plane to clip
+
+ Returns:
+ (torch.Tensor): [..., 4, 4] perspective matrix
+ """
+ aspect = torch.tan(fov_x / 2) / torch.tan(fov_y / 2)
+ return perspective(fov_y, aspect, near, far)
+
+
+@batched(0,0,0,0)
+def intrinsics_from_focal_center(
+ fx: Union[float, torch.Tensor],
+ fy: Union[float, torch.Tensor],
+ cx: Union[float, torch.Tensor],
+ cy: Union[float, torch.Tensor]
+) -> torch.Tensor:
+ """
+ Get OpenCV intrinsics matrix
+
+ Args:
+ fx (float | torch.Tensor): focal length in x axis
+ fy (float | torch.Tensor): focal length in y axis
+ cx (float | torch.Tensor): principal point in x axis
+ cy (float | torch.Tensor): principal point in y axis
+
+ Returns:
+ (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix
+ """
+ N = fx.shape[0]
+ ret = torch.zeros((N, 3, 3), dtype=fx.dtype, device=fx.device)
+ zeros, ones = torch.zeros(N, dtype=fx.dtype, device=fx.device), torch.ones(N, dtype=fx.dtype, device=fx.device)
+ ret = torch.stack([fx, zeros, cx, zeros, fy, cy, zeros, zeros, ones], dim=-1).unflatten(-1, (3, 3))
+ return ret
+
+
+@batched(0, 0, 0, 0, 0, 0)
+def intrinsics_from_fov(
+ fov_max: Union[float, torch.Tensor] = None,
+ fov_min: Union[float, torch.Tensor] = None,
+ fov_x: Union[float, torch.Tensor] = None,
+ fov_y: Union[float, torch.Tensor] = None,
+ width: Union[int, torch.Tensor] = None,
+ height: Union[int, torch.Tensor] = None,
+) -> torch.Tensor:
+ """
+ Get normalized OpenCV intrinsics matrix from given field of view.
+ You can provide either fov_max, fov_min, fov_x or fov_y
+
+ Args:
+ width (int | torch.Tensor): image width
+ height (int | torch.Tensor): image height
+ fov_max (float | torch.Tensor): field of view in largest dimension
+ fov_min (float | torch.Tensor): field of view in smallest dimension
+ fov_x (float | torch.Tensor): field of view in x axis
+ fov_y (float | torch.Tensor): field of view in y axis
+
+ Returns:
+ (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix
+ """
+ if fov_max is not None:
+ fx = torch.maximum(width, height) / width / (2 * torch.tan(fov_max / 2))
+ fy = torch.maximum(width, height) / height / (2 * torch.tan(fov_max / 2))
+ elif fov_min is not None:
+ fx = torch.minimum(width, height) / width / (2 * torch.tan(fov_min / 2))
+ fy = torch.minimum(width, height) / height / (2 * torch.tan(fov_min / 2))
+ elif fov_x is not None and fov_y is not None:
+ fx = 1 / (2 * torch.tan(fov_x / 2))
+ fy = 1 / (2 * torch.tan(fov_y / 2))
+ elif fov_x is not None:
+ fx = 1 / (2 * torch.tan(fov_x / 2))
+ fy = fx * width / height
+ elif fov_y is not None:
+ fy = 1 / (2 * torch.tan(fov_y / 2))
+ fx = fy * height / width
+ cx = 0.5
+ cy = 0.5
+ ret = intrinsics_from_focal_center(fx, fy, cx, cy)
+ return ret
+
+
+
+def intrinsics_from_fov_xy(
+ fov_x: Union[float, torch.Tensor],
+ fov_y: Union[float, torch.Tensor]
+ ) -> torch.Tensor:
+ """
+ Get OpenCV intrinsics matrix from field of view in x and y axis
+
+ Args:
+ fov_x (float | torch.Tensor): field of view in x axis
+ fov_y (float | torch.Tensor): field of view in y axis
+
+ Returns:
+ (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix
+ """
+ focal_x = 0.5 / torch.tan(fov_x / 2)
+ focal_y = 0.5 / torch.tan(fov_y / 2)
+ cx = cy = 0.5
+ return intrinsics_from_focal_center(focal_x, focal_y, cx, cy)
+
+
+@batched(1,1,1)
+def view_look_at(
+ eye: torch.Tensor,
+ look_at: torch.Tensor,
+ up: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ Get OpenGL view matrix looking at something
+
+ Args:
+ eye (torch.Tensor): [..., 3] the eye position
+ look_at (torch.Tensor): [..., 3] the position to look at
+ up (torch.Tensor): [..., 3] head up direction (y axis in screen space). Not necessarily orthogonal to the view direction
+
+ Returns:
+ (torch.Tensor): [..., 4, 4], view matrix
+ """
+ N = eye.shape[0]
+ z = eye - look_at
+ x = torch.cross(up, z, dim=-1)
+ y = torch.cross(z, x, dim=-1)
+ # x = torch.cross(y, z, dim=-1)
+ x = x / x.norm(dim=-1, keepdim=True)
+ y = y / y.norm(dim=-1, keepdim=True)
+ z = z / z.norm(dim=-1, keepdim=True)
+ R = torch.stack([x, y, z], dim=-2)
+ t = -torch.matmul(R, eye[..., None])
+ ret = torch.zeros((N, 4, 4), dtype=eye.dtype, device=eye.device)
+ ret[:, :3, :3] = R
+ ret[:, :3, 3] = t[:, :, 0]
+ ret[:, 3, 3] = 1.
+ return ret
+
+
+@batched(1, 1, 1)
+def extrinsics_look_at(
+ eye: torch.Tensor,
+ look_at: torch.Tensor,
+ up: torch.Tensor
+) -> torch.Tensor:
+ """
+ Get OpenCV extrinsics matrix looking at something
+
+ Args:
+ eye (torch.Tensor): [..., 3] the eye position
+ look_at (torch.Tensor): [..., 3] the position to look at
+ up (torch.Tensor): [..., 3] head up direction (-y axis in screen space). Not necessarily orthogonal to the view direction
+
+ Returns:
+ (torch.Tensor): [..., 4, 4], extrinsics matrix
+ """
+ N = eye.shape[0]
+ z = look_at - eye
+ x = torch.cross(-up, z, dim=-1)
+ y = torch.cross(z, x, dim=-1)
+ # x = torch.cross(y, z, dim=-1)
+ x = x / x.norm(dim=-1, keepdim=True)
+ y = y / y.norm(dim=-1, keepdim=True)
+ z = z / z.norm(dim=-1, keepdim=True)
+ R = torch.stack([x, y, z], dim=-2)
+ t = -torch.matmul(R, eye[..., None])
+ ret = torch.zeros((N, 4, 4), dtype=eye.dtype, device=eye.device)
+ ret[:, :3, :3] = R
+ ret[:, :3, 3] = t[:, :, 0]
+ ret[:, 3, 3] = 1.
+ return ret
+
+
+@batched(2)
+def perspective_to_intrinsics(
+ perspective: torch.Tensor
+) -> torch.Tensor:
+ """
+ OpenGL perspective matrix to OpenCV intrinsics
+
+ Args:
+ perspective (torch.Tensor): [..., 4, 4] OpenGL perspective matrix
+
+ Returns:
+ (torch.Tensor): shape [..., 3, 3] OpenCV intrinsics
+ """
+ assert torch.allclose(perspective[:, [0, 1, 3], 3], torch.zeros_like(perspective[:, [0, 1, 3], 3])), "The perspective matrix is not a projection matrix"
+ ret = torch.tensor([[0.5, 0., 0.5], [0., -0.5, 0.5], [0., 0., 1.]], dtype=perspective.dtype, device=perspective.device) \
+ @ perspective[:, [0, 1, 3], :3] \
+ @ torch.diag(torch.tensor([1, -1, -1], dtype=perspective.dtype, device=perspective.device))
+ return ret / ret[:, 2, 2, None, None]
+
+
+@batched(2,0,0)
+def intrinsics_to_perspective(
+ intrinsics: torch.Tensor,
+ near: Union[float, torch.Tensor],
+ far: Union[float, torch.Tensor],
+ ) -> torch.Tensor:
+ """
+ OpenCV intrinsics to OpenGL perspective matrix
+
+ Args:
+ intrinsics (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix
+ near (float | torch.Tensor): [...] near plane to clip
+ far (float | torch.Tensor): [...] far plane to clip
+ Returns:
+ (torch.Tensor): [..., 4, 4] OpenGL perspective matrix
+ """
+ N = intrinsics.shape[0]
+ fx, fy = intrinsics[:, 0, 0], intrinsics[:, 1, 1]
+ cx, cy = intrinsics[:, 0, 2], intrinsics[:, 1, 2]
+ ret = torch.zeros((N, 4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
+ ret[:, 0, 0] = 2 * fx
+ ret[:, 1, 1] = 2 * fy
+ ret[:, 0, 2] = -2 * cx + 1
+ ret[:, 1, 2] = 2 * cy - 1
+ ret[:, 2, 2] = (near + far) / (near - far)
+ ret[:, 2, 3] = 2. * near * far / (near - far)
+ ret[:, 3, 2] = -1.
+ return ret
+
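+# Consistency sketch (illustrative comment only, not executed): a normalized pinhole
+# intrinsics matrix and the OpenGL perspective matrix built from it round-trip
+# through perspective_to_intrinsics up to numerical precision. Values are arbitrary.
+#
+#   K = intrinsics_from_fov_xy(torch.tensor(1.0), torch.tensor(0.8))
+#   P = intrinsics_to_perspective(K, 0.1, 100.0)
+#   assert torch.allclose(perspective_to_intrinsics(P), K, atol=1e-6)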
+
+@batched(2)
+def extrinsics_to_view(
+ extrinsics: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ OpenCV camera extrinsics to OpenGL view matrix
+
+ Args:
+ extrinsics (torch.Tensor): [..., 4, 4] OpenCV camera extrinsics matrix
+
+ Returns:
+ (torch.Tensor): [..., 4, 4] OpenGL view matrix
+ """
+ return extrinsics * torch.tensor([1, -1, -1, 1], dtype=extrinsics.dtype, device=extrinsics.device)[:, None]
+
+
+@batched(2)
+def view_to_extrinsics(
+ view: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ OpenGL view matrix to OpenCV camera extrinsics
+
+ Args:
+ view (torch.Tensor): [..., 4, 4] OpenGL view matrix
+
+ Returns:
+ (torch.Tensor): [..., 4, 4] OpenCV camera extrinsics matrix
+ """
+ return view * torch.tensor([1, -1, -1, 1], dtype=view.dtype, device=view.device)[:, None]
+
+
+@batched(2,0,0)
+def normalize_intrinsics(
+ intrinsics: torch.Tensor,
+ width: Union[int, torch.Tensor],
+ height: Union[int, torch.Tensor]
+ ) -> torch.Tensor:
+ """
+ Normalize camera intrinsics(s) to uv space
+
+ Args:
+ intrinsics (torch.Tensor): [..., 3, 3] camera intrinsics(s) to normalize
+ width (int | torch.Tensor): [...] image width(s)
+ height (int | torch.Tensor): [...] image height(s)
+
+ Returns:
+ (torch.Tensor): [..., 3, 3] normalized camera intrinsics(s)
+ """
+ zeros = torch.zeros_like(width)
+ ones = torch.ones_like(width)
+ transform = torch.stack([
+ 1 / width, zeros, 0.5 / width,
+ zeros, 1 / height, 0.5 / height,
+ zeros, zeros, ones
+ ]).reshape(*zeros.shape, 3, 3).to(intrinsics)
+ return transform @ intrinsics
+
+
+
+@batched(2,0,0,0,0,0,0)
+def crop_intrinsics(
+ intrinsics: torch.Tensor,
+ width: Union[int, torch.Tensor],
+ height: Union[int, torch.Tensor],
+ left: Union[int, torch.Tensor],
+ top: Union[int, torch.Tensor],
+ crop_width: Union[int, torch.Tensor],
+ crop_height: Union[int, torch.Tensor]
+) -> torch.Tensor:
+ """
+ Evaluate the new intrinsics(s) after crop the image: cropped_img = img[top:top+crop_height, left:left+crop_width]
+
+ Args:
+ intrinsics (torch.Tensor): [..., 3, 3] camera intrinsics(s) to crop
+ width (int | torch.Tensor): [...] image width(s)
+ height (int | torch.Tensor): [...] image height(s)
+ left (int | torch.Tensor): [...] left crop boundary
+ top (int | torch.Tensor): [...] top crop boundary
+ crop_width (int | torch.Tensor): [...] crop width
+ crop_height (int | torch.Tensor): [...] crop height
+
+ Returns:
+ (torch.Tensor): [..., 3, 3] cropped camera intrinsics(s)
+ """
+ zeros = torch.zeros_like(width)
+ ones = torch.ones_like(width)
+ transform = torch.stack([
+ width / crop_width, zeros, -left / crop_width,
+ zeros, height / crop_height, -top / crop_height,
+ zeros, zeros, ones
+ ]).reshape(*zeros.shape, 3, 3).to(intrinsics)
+ return transform @ intrinsics
+
+
+@batched(1,0,0)
+def pixel_to_uv(
+ pixel: torch.Tensor,
+ width: Union[int, torch.Tensor],
+ height: Union[int, torch.Tensor]
+) -> torch.Tensor:
+ """
+ Args:
+ pixel (torch.Tensor): [..., 2] pixel coordinates defined in image space, x range is (0, W - 1), y range is (0, H - 1)
+ width (int | torch.Tensor): [...] image width(s)
+ height (int | torch.Tensor): [...] image height(s)
+
+ Returns:
+ (torch.Tensor): [..., 2] coordinates defined in uv space, the range is (0, 1)
+ """
+ if not torch.is_floating_point(pixel):
+ pixel = pixel.float()
+ uv = (pixel + 0.5) / torch.stack([width, height], dim=-1).to(pixel)
+ return uv
+
+
+@batched(1,0,0)
+def uv_to_pixel(
+ uv: torch.Tensor,
+ width: Union[int, torch.Tensor],
+ height: Union[int, torch.Tensor]
+) -> torch.Tensor:
+ """
+ Args:
+ uv (torch.Tensor): [..., 2] coordinates defined in uv space, the range is (0, 1)
+ width (int | torch.Tensor): [...] image width(s)
+ height (int | torch.Tensor): [...] image height(s)
+
+ Returns:
+ (torch.Tensor): [..., 2] pixel coordinates defined in image space, x range is (0, W - 1), y range is (0, H - 1)
+ """
+ pixel = uv * torch.stack([width, height], dim=-1).to(uv) - 0.5
+ return pixel
+
+
+@batched(1,0,0)
+def pixel_to_ndc(
+ pixel: torch.Tensor,
+ width: Union[int, torch.Tensor],
+ height: Union[int, torch.Tensor]
+) -> torch.Tensor:
+ """
+ Args:
+ pixel (torch.Tensor): [..., 2] pixel coordinates defined in image space, x range is (0, W - 1), y range is (0, H - 1)
+ width (int | torch.Tensor): [...] image width(s)
+ height (int | torch.Tensor): [...] image height(s)
+
+ Returns:
+ (torch.Tensor): [..., 2] coordinates defined in ndc space, the range is (-1, 1)
+ """
+ if not torch.is_floating_point(pixel):
+ pixel = pixel.float()
+ ndc = (pixel + 0.5) / torch.stack([width, height], dim=-1).to(pixel) * torch.tensor([2, -2], dtype=pixel.dtype, device=pixel.device) \
+ + torch.tensor([-1, 1], dtype=pixel.dtype, device=pixel.device)
+ return ndc
+
+
+@batched(0,0,0)
+def project_depth(
+ depth: torch.Tensor,
+ near: Union[float, torch.Tensor],
+ far: Union[float, torch.Tensor]
+ ) -> torch.Tensor:
+ """
+ Project linear depth to depth value in screen space
+
+ Args:
+ depth (torch.Tensor): [...] depth value
+ near (float | torch.Tensor): [...] near plane to clip
+ far (float | torch.Tensor): [...] far plane to clip
+
+ Returns:
+ (torch.Tensor): [...] depth value in screen space, value ranging in [0, 1]
+ """
+ return (far - near * far / depth) / (far - near)
+
+
+@batched(0,0,0)
+def depth_buffer_to_linear(
+ depth: torch.Tensor,
+ near: Union[float, torch.Tensor],
+ far: Union[float, torch.Tensor]
+ ) -> torch.Tensor:
+ """
+ Linearize depth value to linear depth
+
+ Args:
+ depth (torch.Tensor): [...] screen depth value, ranging in [0, 1]
+ near (float | torch.Tensor): [...] near plane to clip
+ far (float | torch.Tensor): [...] far plane to clip
+
+ Returns:
+ (torch.Tensor): [...] linear depth
+ """
+ return near * far / (far - (far - near) * depth)
+
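+# Worked example (illustrative comment only, not executed): with near=0.1 and
+# far=100, a point at linear depth 10 maps to a screen-space depth of
+# (far - near * far / z) / (far - near) = (100 - 1) / 99.9 ≈ 0.991,
+# and depth_buffer_to_linear maps it back to 10.
+#
+#   z = torch.tensor(10.0)
+#   d = project_depth(z, 0.1, 100.0)                 # ≈ 0.991
+#   z_back = depth_buffer_to_linear(d, 0.1, 100.0)   # ≈ 10.0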
+
+@batched(2, 2, 2, 2)
+def project_gl(
+ points: torch.Tensor,
+ model: torch.Tensor = None,
+ view: torch.Tensor = None,
+ perspective: torch.Tensor = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Project 3D points to 2D following the OpenGL convention (except that matrices are row-major)
+
+ Args:
+ points (torch.Tensor): [..., N, 3 or 4] 3D points to project, if the last
+ dimension is 4, the points are assumed to be in homogeneous coordinates
+ model (torch.Tensor): [..., 4, 4] model matrix
+ view (torch.Tensor): [..., 4, 4] view matrix
+ perspective (torch.Tensor): [..., 4, 4] perspective matrix
+
+ Returns:
+ scr_coord (torch.Tensor): [..., N, 3] screen space coordinates, value ranging in [0, 1].
+ The origin (0., 0., 0.) corresponds to the left & bottom & nearest corner
+ linear_depth (torch.Tensor): [..., N] linear depth
+ """
+ assert perspective is not None, "perspective matrix is required"
+
+ if points.shape[-1] == 3:
+ points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
+ mvp = perspective if perspective is not None else torch.eye(4).to(points)
+ if view is not None:
+ mvp = mvp @ view
+ if model is not None:
+ mvp = mvp @ model
+ clip_coord = points @ mvp.transpose(-1, -2)
+ ndc_coord = clip_coord[..., :3] / clip_coord[..., 3:]
+ scr_coord = ndc_coord * 0.5 + 0.5
+ linear_depth = clip_coord[..., 3]
+ return scr_coord, linear_depth
+
+
+@batched(2, 2, 2)
+def project_cv(
+ points: torch.Tensor,
+ extrinsics: torch.Tensor = None,
+ intrinsics: torch.Tensor = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Project 3D points to 2D following the OpenCV convention
+
+ Args:
+ points (torch.Tensor): [..., N, 3] or [..., N, 4] 3D points to project, if the last
+ dimension is 4, the points are assumed to be in homogeneous coordinates
+ extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix
+ intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix
+
+ Returns:
+ uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1].
+ The origin (0., 0.) corresponds to the left & top corner
+ linear_depth (torch.Tensor): [..., N] linear depth
+ """
+ assert intrinsics is not None, "intrinsics matrix is required"
+ if points.shape[-1] == 3:
+ points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
+ if extrinsics is not None:
+ points = points @ extrinsics.transpose(-1, -2)
+ points = points[..., :3] @ intrinsics.transpose(-2, -1)
+ uv_coord = points[..., :2] / points[..., 2:]
+ linear_depth = points[..., 2]
+ return uv_coord, linear_depth
+
+
+@batched(2, 2, 2, 2)
+def unproject_gl(
+ screen_coord: torch.Tensor,
+ model: torch.Tensor = None,
+ view: torch.Tensor = None,
+ perspective: torch.Tensor = None
+ ) -> torch.Tensor:
+ """
+ Unproject screen space coordinates to 3D view space following the OpenGL convention (except that matrices are row-major)
+
+ Args:
+ screen_coord (torch.Tensor): [... N, 3] screen space coordinates, value ranging in [0, 1].
+ The origin (0., 0., 0.) corresponds to the left & bottom & nearest corner
+ model (torch.Tensor): [..., 4, 4] model matrix
+ view (torch.Tensor): [..., 4, 4] view matrix
+ perspective (torch.Tensor): [..., 4, 4] perspective matrix
+
+ Returns:
+ points (torch.Tensor): [..., N, 3] 3d points
+ """
+ assert perspective is not None, "perspective matrix is required"
+ ndc_xy = screen_coord * 2 - 1
+ clip_coord = torch.cat([ndc_xy, torch.ones_like(ndc_xy[..., :1])], dim=-1)
+ transform = perspective
+ if view is not None:
+ transform = transform @ view
+ if model is not None:
+ transform = transform @ model
+ transform = torch.inverse(transform)
+ points = clip_coord @ transform.transpose(-1, -2)
+ points = points[..., :3] / points[..., 3:]
+ return points
+
+
+@batched(2, 1, 2, 2)
+def unproject_cv(
+ uv_coord: torch.Tensor,
+ depth: torch.Tensor,
+ extrinsics: torch.Tensor = None,
+ intrinsics: torch.Tensor = None
+) -> torch.Tensor:
+ """
+ Unproject uv coordinates to 3D view space following the OpenCV convention
+
+ Args:
+ uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1].
+ The origin (0., 0.) corresponds to the left & top corner
+ depth (torch.Tensor): [..., N] depth value
+ extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix
+ intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix
+
+ Returns:
+ points (torch.Tensor): [..., N, 3] 3d points
+ """
+ assert intrinsics is not None, "intrinsics matrix is required"
+ points = torch.cat([uv_coord, torch.ones_like(uv_coord[..., :1])], dim=-1)
+ points = points @ torch.inverse(intrinsics).transpose(-2, -1)
+ points = points * depth[..., None]
+ if extrinsics is not None:
+ points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
+ points = (points @ torch.inverse(extrinsics).transpose(-2, -1))[..., :3]
+ return points
+
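+# Round-trip sketch (illustrative comment only, not executed): projecting random
+# camera-frame points with normalized intrinsics and unprojecting them with the
+# returned depth recovers the original points (extrinsics omitted, i.e. identity).
+#
+#   K = intrinsics_from_fov_xy(torch.tensor(1.0), torch.tensor(0.8))
+#   pts = torch.randn(100, 3)
+#   pts[:, 2] = pts[:, 2].abs() + 1.0                 # keep points in front of the camera
+#   uv, depth = project_cv(pts, intrinsics=K)
+#   pts_back = unproject_cv(uv, depth, intrinsics=K)  # ≈ pts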
+
+def euler_axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
+ """
+ Return the rotation matrices for one of the rotations about an axis
+ of which Euler angles describe, for each value of the angle given.
+
+ Args:
+ axis: Axis label "X", "Y", or "Z".
+ angle: any shape tensor of Euler angles in radians
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+
+ cos = torch.cos(angle)
+ sin = torch.sin(angle)
+ one = torch.ones_like(angle)
+ zero = torch.zeros_like(angle)
+
+ if axis == "X":
+ R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
+ elif axis == "Y":
+ R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
+ elif axis == "Z":
+ R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
+ else:
+ raise ValueError("letter must be either X, Y or Z.")
+
+ return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
+
+
+def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str = 'XYZ') -> torch.Tensor:
+ """
+ Convert rotations given as Euler angles in radians to rotation matrices.
+
+ Args:
+ euler_angles: Euler angles in radians as tensor of shape (..., 3), XYZ
+ convention: permutation of "X", "Y" or "Z", representing the order of Euler rotations to apply.
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
+ raise ValueError("Invalid input euler angles.")
+ if len(convention) != 3:
+ raise ValueError("Convention must have 3 letters.")
+ if convention[1] in (convention[0], convention[2]):
+ raise ValueError(f"Invalid convention {convention}.")
+ for letter in convention:
+ if letter not in ("X", "Y", "Z"):
+ raise ValueError(f"Invalid letter {letter} in convention string.")
+ matrices = [
+ euler_axis_angle_rotation(c, euler_angles[..., 'XYZ'.index(c)])
+ for c in convention
+ ]
+ # return functools.reduce(torch.matmul, matrices)
+ return matrices[2] @ matrices[1] @ matrices[0]
+
+
+def skew_symmetric(v: torch.Tensor):
+ "Skew symmetric matrix from a 3D vector"
+ assert v.shape[-1] == 3, "v must be 3D"
+ x, y, z = v.unbind(dim=-1)
+ zeros = torch.zeros_like(x)
+ return torch.stack([
+ zeros, -z, y,
+ z, zeros, -x,
+ -y, x, zeros,
+ ], dim=-1).reshape(*v.shape[:-1], 3, 3)
+
+
+def rotation_matrix_from_vectors(v1: torch.Tensor, v2: torch.Tensor):
+ "Rotation matrix that rotates v1 to v2"
+ I = torch.eye(3).to(v1)
+ v1 = F.normalize(v1, dim=-1)
+ v2 = F.normalize(v2, dim=-1)
+ v = torch.cross(v1, v2, dim=-1)
+ c = torch.sum(v1 * v2, dim=-1)
+ K = skew_symmetric(v)
+ R = I + K + (1 / (1 + c))[..., None, None] * (K @ K)
+ return R
+
+
+def _angle_from_tan(
+ axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
+) -> torch.Tensor:
+ """
+ Extract the first or third Euler angle from the two members of
+ the matrix which are positive constant times its sine and cosine.
+
+ Args:
+ axis: Axis label "X", "Y", or "Z" for the angle we are finding.
+ other_axis: Axis label "X", "Y", or "Z" for the middle axis in the
+ convention.
+ data: Rotation matrices as tensor of shape (..., 3, 3).
+ horizontal: Whether we are looking for the angle for the third axis,
+ which means the relevant entries are in the same row of the
+ rotation matrix. If not, they are in the same column.
+ tait_bryan: Whether the first and third axes in the convention differ.
+
+ Returns:
+ Euler Angles in radians for each matrix in data as a tensor
+ of shape (...).
+ """
+
+ i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
+ if horizontal:
+ i2, i1 = i1, i2
+ even = (axis + other_axis) in ["XY", "YZ", "ZX"]
+ if horizontal == even:
+ return torch.atan2(data[..., i1], data[..., i2])
+ if tait_bryan:
+ return torch.atan2(-data[..., i2], data[..., i1])
+ return torch.atan2(data[..., i2], -data[..., i1])
+
+
+def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
+ """
+ Convert rotations given as rotation matrices to Euler angles in radians.
+ NOTE: The composition order eg. `XYZ` means `Rz * Ry * Rx` (like blender), instead of `Rx * Ry * Rz` (like pytorch3d)
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+ convention: Convention string of three uppercase letters.
+
+ Returns:
+ Euler angles in radians as tensor of shape (..., 3), in the order of XYZ (like blender), instead of convention (like pytorch3d)
+ """
+ if not all(c in 'XYZ' for c in convention) or not all(c in convention for c in 'XYZ'):
+ raise ValueError(f"Invalid convention {convention}.")
+ if not matrix.shape[-2:] == (3, 3):
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+ i0 = 'XYZ'.index(convention[0])
+ i2 = 'XYZ'.index(convention[2])
+ tait_bryan = i0 != i2
+ if tait_bryan:
+ central_angle = torch.asin(matrix[..., i2, i0] * (-1.0 if i2 - i0 in [-1, 2] else 1.0))
+ else:
+ central_angle = torch.acos(matrix[..., i2, i2])
+
+ # Angles in composition order
+ o = [
+ _angle_from_tan(
+ convention[0], convention[1], matrix[..., i2, :], True, tait_bryan
+ ),
+ central_angle,
+ _angle_from_tan(
+ convention[2], convention[1], matrix[..., i0], False, tait_bryan
+ ),
+ ]
+ return torch.stack([o[convention.index(c)] for c in 'XYZ'], -1)
+
+
+def axis_angle_to_matrix(axis_angle: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
+ """Convert axis-angle representation (rotation vector) to rotation matrix, whose direction is the axis of rotation and length is the angle of rotation
+
+ Args:
+ axis_angle (torch.Tensor): shape (..., 3), axis-angle vectors
+
+ Returns:
+ torch.Tensor: shape (..., 3, 3) The rotation matrices for the given axis-angle parameters
+ """
+ batch_shape = axis_angle.shape[:-1]
+ device, dtype = axis_angle.device, axis_angle.dtype
+
+ angle = torch.norm(axis_angle + eps, dim=-1, keepdim=True)
+ axis = axis_angle / angle
+
+ cos = torch.cos(angle)[..., None, :]
+ sin = torch.sin(angle)[..., None, :]
+
+ rx, ry, rz = torch.split(axis, 1, dim=-1)
+ zeros = torch.zeros((*batch_shape, 1), dtype=dtype, device=device)
+ K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=-1).view((*batch_shape, 3, 3))
+
+ ident = torch.eye(3, dtype=dtype, device=device)
+ rot_mat = ident + sin * K + (1 - cos) * torch.matmul(K, K)
+ return rot_mat
+
+
+def matrix_to_axis_angle(rot_mat: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
+ """Convert a batch of 3x3 rotation matrices to axis-angle representation (rotation vector)
+
+ Args:
+ rot_mat (torch.Tensor): shape (..., 3, 3), the rotation matrices to convert
+
+ Returns:
+ torch.Tensor: shape (..., 3), the axis-angle vectors corresponding to the given rotation matrices
+ """
+ quat = matrix_to_quaternion(rot_mat)
+ axis_angle = quaternion_to_axis_angle(quat, eps=eps)
+ return axis_angle
+
+
+def quaternion_to_axis_angle(quaternion: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
+ """Convert a batch of quaternions (w, x, y, z) to axis-angle representation (rotation vector)
+
+ Args:
+ quaternion (torch.Tensor): shape (..., 4), the quaternions to convert
+
+ Returns:
+ torch.Tensor: shape (..., 3), the axis-angle vectors corresponding to the given quaternions
+ """
+ assert quaternion.shape[-1] == 4
+ norm = torch.norm(quaternion[..., 1:], dim=-1, keepdim=True)
+ axis = quaternion[..., 1:] / norm.clamp(min=eps)
+ angle = 2 * torch.atan2(norm, quaternion[..., 0:1])
+ return angle * axis
+
+
+def axis_angle_to_quaternion(axis_angle: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
+ """Convert axis-angle representation (rotation vector) to quaternion (w, x, y, z)
+
+ Args:
+ axis_angle (torch.Tensor): shape (..., 3), axis-angle vectors
+
+ Returns:
+ torch.Tensor: shape (..., 4) The quaternions for the given axis-angle parameters
+ """
+ axis = F.normalize(axis_angle, dim=-1, eps=eps)
+ angle = torch.norm(axis_angle, dim=-1, keepdim=True)
+ quat = torch.cat([torch.cos(angle / 2), torch.sin(angle / 2) * axis], dim=-1)
+ return quat
+
+
+def matrix_to_quaternion(rot_mat: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
+ """Convert 3x3 rotation matrix to quaternion (w, x, y, z)
+
+ Args:
+ rot_mat (torch.Tensor): shape (..., 3, 3), the rotation matrices to convert
+
+ Returns:
+ torch.Tensor: shape (..., 4), the quaternions corresponding to the given rotation matrices
+ """
+ # Extract the diagonal and off-diagonal elements of the rotation matrix
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = rot_mat.flatten(-2).unbind(dim=-1)
+
+ diag = torch.diagonal(rot_mat, dim1=-2, dim2=-1)
+ M = torch.tensor([
+ [1, 1, 1],
+ [1, -1, -1],
+ [-1, 1, -1],
+ [-1, -1, 1]
+ ], dtype=rot_mat.dtype, device=rot_mat.device)
+ wxyz = (1 + diag @ M.transpose(-1, -2)).clamp_(0).sqrt().mul(0.5)
+ _, max_idx = wxyz.max(dim=-1)
+ xw = torch.sign(m21 - m12)
+ yw = torch.sign(m02 - m20)
+ zw = torch.sign(m10 - m01)
+ yz = torch.sign(m21 + m12)
+ xz = torch.sign(m02 + m20)
+ xy = torch.sign(m01 + m10)
+ ones = torch.ones_like(xw)
+ sign = torch.where(
+ max_idx[..., None] == 0,
+ torch.stack([ones, xw, yw, zw], dim=-1),
+ torch.where(
+ max_idx[..., None] == 1,
+ torch.stack([xw, ones, xy, xz], dim=-1),
+ torch.where(
+ max_idx[..., None] == 2,
+ torch.stack([yw, xy, ones, yz], dim=-1),
+ torch.stack([zw, xz, yz, ones], dim=-1)
+ )
+ )
+ )
+ quat = sign * wxyz
+ quat = F.normalize(quat, dim=-1, eps=eps)
+ return quat
+
+
+def quaternion_to_matrix(quaternion: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
+ """Converts a batch of quaternions (w, x, y, z) to rotation matrices
+
+ Args:
+ quaternion (torch.Tensor): shape (..., 4), the quaternions to convert
+
+ Returns:
+ torch.Tensor: shape (..., 3, 3), the rotation matrices corresponding to the given quaternions
+ """
+ assert quaternion.shape[-1] == 4
+ quaternion = F.normalize(quaternion, dim=-1, eps=eps)
+ w, x, y, z = quaternion.unbind(dim=-1)
+ zeros = torch.zeros_like(w)
+ I = torch.eye(3, dtype=quaternion.dtype, device=quaternion.device)
+ xyz = quaternion[..., 1:]
+ A = xyz[..., :, None] * xyz[..., None, :] - I * (xyz ** 2).sum(dim=-1)[..., None, None]
+ B = torch.stack([
+ zeros, -z, y,
+ z, zeros, -x,
+ -y, x, zeros
+ ], dim=-1).unflatten(-1, (3, 3))
+ rot_mat = I + 2 * (A + w[..., None, None] * B)
+ return rot_mat
+
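+# Consistency sketch (illustrative comment only, not executed): converting a random
+# rotation to a quaternion and back reproduces the matrix. Quaternions follow the
+# (w, x, y, z) layout used throughout this module.
+#
+#   R = euler_angles_to_matrix(torch.randn(8, 3))
+#   q = matrix_to_quaternion(R)        # (8, 4), unit norm
+#   R_back = quaternion_to_matrix(q)   # ≈ R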
+
+def slerp(rot_mat_1: torch.Tensor, rot_mat_2: torch.Tensor, t: Union[Number, torch.Tensor]) -> torch.Tensor:
+ """Spherical linear interpolation between two rotation matrices
+
+ Args:
+ rot_mat_1 (torch.Tensor): shape (..., 3, 3), the first rotation matrix
+ rot_mat_2 (torch.Tensor): shape (..., 3, 3), the second rotation matrix
+ t (torch.Tensor): scalar or shape (...,), the interpolation factor
+
+ Returns:
+ torch.Tensor: shape (..., 3, 3), the interpolated rotation matrix
+ """
+ assert rot_mat_1.shape[-2:] == (3, 3)
+ rot_vec_1 = matrix_to_axis_angle(rot_mat_1)
+ rot_vec_2 = matrix_to_axis_angle(rot_mat_2)
+ if isinstance(t, Number):
+ t = torch.tensor(t, dtype=rot_mat_1.dtype, device=rot_mat_1.device)
+ rot_vec = (1 - t[..., None]) * rot_vec_1 + t[..., None] * rot_vec_2
+ rot_mat = axis_angle_to_matrix(rot_vec)
+ return rot_mat
+
+
+def interpolate_extrinsics(ext1: torch.Tensor, ext2: torch.Tensor, t: Union[Number, torch.Tensor]) -> torch.Tensor:
+ """Interpolate extrinsics between two camera poses. Linear interpolation for translation, spherical linear interpolation for rotation.
+
+ Args:
+ ext1 (torch.Tensor): shape (..., 4, 4), the first camera pose
+ ext2 (torch.Tensor): shape (..., 4, 4), the second camera pose
+ t (torch.Tensor): scalar or shape (...,), the interpolation factor
+
+ Returns:
+ torch.Tensor: shape (..., 4, 4), the interpolated camera pose
+ """
+ return torch.inverse(interpolate_transform(torch.inverse(ext1), torch.inverse(ext2), t))
+
+
+def interpolate_view(view1: torch.Tensor, view2: torch.Tensor, t: Union[Number, torch.Tensor]):
+ """Interpolate view matrices between two camera poses. Linear interpolation for translation, spherical linear interpolation for rotation.
+
+ Args:
+ view1 (torch.Tensor): shape (..., 4, 4), the first view matrix
+ view2 (torch.Tensor): shape (..., 4, 4), the second view matrix
+ t (torch.Tensor): scalar or shape (...,), the interpolation factor
+
+ Returns:
+ torch.Tensor: shape (..., 4, 4), the interpolated camera pose
+ """
+ return interpolate_extrinsics(view1, view2, t)
+
+
+def interpolate_transform(transform1: torch.Tensor, transform2: torch.Tensor, t: Union[Number, torch.Tensor]):
+ assert transform1.shape[-2:] == (4, 4) and transform2.shape[-2:] == (4, 4)
+ if isinstance(t, Number):
+ t = torch.tensor(t, dtype=transform1.dtype, device=transform1.device)
+ pos = (1 - t[..., None]) * transform1[..., :3, 3] + t[..., None] * transform2[..., :3, 3]
+ rot = slerp(transform1[..., :3, :3], transform2[..., :3, :3], t)
+ transform = torch.cat([rot, pos[..., None]], dim=-1)
+ transform = torch.cat([transform, torch.tensor([0, 0, 0, 1], dtype=transform.dtype, device=transform.device).expand_as(transform[..., :1, :])], dim=-2)
+ return transform
+
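+# Usage sketch (illustrative comment only, not executed): generating intermediate
+# camera extrinsics between two key poses. The look-at parameters are arbitrary.
+#
+#   ext0 = extrinsics_look_at(torch.tensor([2., 0., 0.]), torch.zeros(3), torch.tensor([0., 0., 1.]))
+#   ext1 = extrinsics_look_at(torch.tensor([0., 2., 0.]), torch.zeros(3), torch.tensor([0., 0., 1.]))
+#   path = [interpolate_extrinsics(ext0, ext1, t) for t in torch.linspace(0, 1, 10)]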
+
+def extrinsics_to_essential(extrinsics: torch.Tensor):
+ """
+ extrinsics matrix `[[R, t] [0, 0, 0, 1]]` such that `x' = R (x - t)` to essential matrix such that `x' E x = 0`
+
+ Args:
+ extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix
+
+ Returns:
+ (torch.Tensor): [..., 3, 3] essential matrix
+ """
+ assert extrinsics.shape[-2:] == (4, 4)
+ R = extrinsics[..., :3, :3]
+ t = extrinsics[..., :3, 3]
+ zeros = torch.zeros_like(t)
+ t_x = torch.stack([
+ zeros, -t[..., 2], t[..., 1],
+ t[..., 2], zeros, -t[..., 0],
+ -t[..., 1], t[..., 0], zeros
+ ]).reshape(*t.shape[:-1], 3, 3)
+ return R @ t_x
+
+
+def to4x4(R: torch.Tensor, t: torch.Tensor):
+ """
+ Compose rotation matrix and translation vector to 4x4 transformation matrix
+
+ Args:
+ R (torch.Tensor): [..., 3, 3] rotation matrix
+ t (torch.Tensor): [..., 3] translation vector
+
+ Returns:
+ (torch.Tensor): [..., 4, 4] transformation matrix
+ """
+ assert R.shape[-2:] == (3, 3)
+ assert t.shape[-1] == 3
+ assert R.shape[:-2] == t.shape[:-1]
+ return torch.cat([
+ torch.cat([R, t[..., None]], dim=-1),
+ torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device).expand(*R.shape[:-2], 1, 4)
+ ], dim=-2)
+
+
+def rotation_matrix_2d(theta: Union[float, torch.Tensor]):
+ """
+ 2x2 matrix for 2D rotation
+
+ Args:
+ theta (float | torch.Tensor): rotation angle in radians, arbitrary shape (...,)
+
+ Returns:
+ (torch.Tensor): (..., 2, 2) rotation matrix
+ """
+ if isinstance(theta, float):
+ theta = torch.tensor(theta)
+ return torch.stack([
+ torch.cos(theta), -torch.sin(theta),
+ torch.sin(theta), torch.cos(theta),
+ ], dim=-1).unflatten(-1, (2, 2))
+
+
+def rotate_2d(theta: Union[float, torch.Tensor], center: torch.Tensor = None):
+ """
+ 3x3 matrix for 2D rotation around a center
+ ```
+ [[Rxx, Rxy, tx],
+ [Ryx, Ryy, ty],
+ [0, 0, 1]]
+ ```
+ Args:
+ theta (float | torch.Tensor): rotation angle in radians, arbitrary shape (...,)
+ center (torch.Tensor): rotation center, arbitrary shape (..., 2). Default to (0, 0)
+
+ Returns:
+ (torch.Tensor): (..., 3, 3) transformation matrix
+ """
+ if isinstance(theta, float):
+ theta = torch.tensor(theta)
+ if center is not None:
+ theta = theta.to(center)
+ if center is None:
+ center = torch.zeros(2).to(theta).expand(*theta.shape, -1)
+ R = rotation_matrix_2d(theta)
+ return torch.cat([
+ torch.cat([
+ R,
+ center[..., :, None] - R @ center[..., :, None],
+ ], dim=-1),
+ torch.tensor([[0, 0, 1]], dtype=center.dtype, device=center.device).expand(*center.shape[:-1], -1, -1),
+ ], dim=-2)
+
+
+def translate_2d(translation: torch.Tensor):
+ """
+ Translation matrix for 2D translation
+ ```
+ [[1, 0, tx],
+ [0, 1, ty],
+ [0, 0, 1]]
+ ```
+ Args:
+ translation (torch.Tensor): translation vector, arbitrary shape (..., 2)
+
+ Returns:
+ (torch.Tensor): (..., 3, 3) transformation matrix
+ """
+ return torch.cat([
+ torch.cat([
+ torch.eye(2, dtype=translation.dtype, device=translation.device).expand(*translation.shape[:-1], -1, -1),
+ translation[..., None],
+ ], dim=-1),
+ torch.tensor([[0, 0, 1]], dtype=translation.dtype, device=translation.device).expand(*translation.shape[:-1], -1, -1),
+ ], dim=-2)
+
+
+def scale_2d(scale: Union[float, torch.Tensor], center: torch.Tensor = None):
+ """
+ Scale matrix for 2D scaling
+ ```
+ [[s, 0, tx],
+ [0, s, ty],
+ [0, 0, 1]]
+ ```
+ Args:
+ scale (float | torch.Tensor): scale factor, arbitrary shape (...,)
+ center (torch.Tensor): scale center, arbitrary shape (..., 2). Default to (0, 0)
+
+ Returns:
+ (torch.Tensor): (..., 3, 3) transformation matrix
+ """
+ if isinstance(scale, float):
+ scale = torch.tensor(scale)
+ if center is not None:
+ scale = scale.to(center)
+ if center is None:
+ center = torch.zeros(2, dtype=scale.dtype, device=scale.device).expand(*scale.shape, -1)
+ return torch.cat([
+ torch.cat([
+            scale[..., None, None] * torch.eye(2, dtype=scale.dtype, device=scale.device),  # broadcasts over any batch shape of scale
+ center[..., :, None] - center[..., :, None] * scale[..., None, None],
+ ], dim=-1),
+ torch.tensor([[0, 0, 1]], dtype=scale.dtype, device=scale.device).expand(*center.shape[:-1], -1, -1),
+ ], dim=-2)
+
+
+def apply_2d(transform: torch.Tensor, points: torch.Tensor):
+ """
+ Apply (3x3 or 2x3) 2D affine transformation to points
+ ```
+ p = R @ p + t
+ ```
+ Args:
+ transform (torch.Tensor): (..., 2 or 3, 3) transformation matrix
+ points (torch.Tensor): (..., N, 2) points to transform
+
+ Returns:
+ (torch.Tensor): (..., N, 2) transformed points
+ """
+ assert transform.shape[-2:] == (3, 3) or transform.shape[-2:] == (2, 3), "transform must be 3x3 or 2x3"
+ assert points.shape[-1] == 2, "points must be 2D"
+    return points @ transform[..., :2, :2].mT + transform[..., None, :2, 2]  # translation broadcasts over the N points
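+
+
+# Illustrative usage sketch (hypothetical helper, not called anywhere): 2D transforms
+# compose by plain matrix multiplication, and `apply_2d` applies the product to points.
+def _example_compose_2d():
+    pivot = torch.tensor([1., 0.])
+    M = translate_2d(torch.tensor([2., 3.])) @ rotate_2d(torch.pi / 2, center=pivot)
+    points = torch.tensor([[0., 0.], [1., 1.]])
+    return apply_2d(M, points)    # rotate about (1, 0) by 90 degrees, then shift by (2, 3)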
\ No newline at end of file
diff --git a/src/utils3d/torch/utils.py b/src/utils3d/torch/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..877ffb8a60a7f5206fbeb5a9e4a584758b875da4
--- /dev/null
+++ b/src/utils3d/torch/utils.py
@@ -0,0 +1,351 @@
+from typing import *
+
+import torch
+import torch.nn.functional as F
+
+from . import transforms
+from . import mesh
+from ._helpers import batched
+
+
+__all__ = [
+ 'sliding_window_1d',
+ 'sliding_window_2d',
+ 'sliding_window_nd',
+ 'image_uv',
+ 'image_pixel_center',
+ 'image_mesh',
+ 'chessboard',
+ 'depth_edge',
+ 'depth_aliasing',
+ 'image_mesh_from_depth',
+ 'point_to_normal',
+ 'depth_to_normal',
+ 'masked_min',
+ 'masked_max',
+ 'bounding_rect'
+]
+
+
+def sliding_window_1d(x: torch.Tensor, window_size: int, stride: int = 1, dim: int = -1) -> torch.Tensor:
+ """
+    Sliding window view of the input tensor. The dimension of the sliding window is appended to the end of the input tensor's shape.
+    NOTE: Since PyTorch already provides `Tensor.unfold`, this 1D sliding window view is just a thin wrapper around it.
+ """
+ return x.unfold(dim, window_size, stride)
+
+
+def sliding_window_nd(x: torch.Tensor, window_size: Tuple[int, ...], stride: Tuple[int, ...], dim: Tuple[int, ...]) -> torch.Tensor:
+ dim = [dim[i] % x.ndim for i in range(len(dim))]
+ assert len(window_size) == len(stride) == len(dim)
+ for i in range(len(window_size)):
+ x = sliding_window_1d(x, window_size[i], stride[i], dim[i])
+ return x
+
+
+def sliding_window_2d(x: torch.Tensor, window_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]], dim: Union[int, Tuple[int, int]] = (-2, -1)) -> torch.Tensor:
+ if isinstance(window_size, int):
+ window_size = (window_size, window_size)
+ if isinstance(stride, int):
+ stride = (stride, stride)
+ return sliding_window_nd(x, window_size, stride, dim)
+
+
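+# Illustrative usage sketch (hypothetical helper, not called anywhere): extract every
+# 3x3 patch of a 6x6 map with stride 1; the two window dimensions are appended last.
+def _example_sliding_window_2d():
+    x = torch.arange(36.).reshape(6, 6)
+    patches = sliding_window_2d(x, window_size=3, stride=1)
+    return patches.shape    # torch.Size([4, 4, 3, 3])
+
+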
+def image_uv(height: int, width: int, left: int = None, top: int = None, right: int = None, bottom: int = None, device: torch.device = None, dtype: torch.dtype = None) -> torch.Tensor:
+ """
+ Get image space UV grid, ranging in [0, 1].
+
+ >>> image_uv(10, 10):
+ [[[0.05, 0.05], [0.15, 0.05], ..., [0.95, 0.05]],
+ [[0.05, 0.15], [0.15, 0.15], ..., [0.95, 0.15]],
+ ... ... ...
+ [[0.05, 0.95], [0.15, 0.95], ..., [0.95, 0.95]]]
+
+    Args:
+        height (int): image height
+        width (int): image width
+
+    Returns:
+        torch.Tensor: shape (height, width, 2)
+ """
+ if left is None: left = 0
+ if top is None: top = 0
+ if right is None: right = width
+ if bottom is None: bottom = height
+ u = torch.linspace((left + 0.5) / width, (right - 0.5) / width, right - left, device=device, dtype=dtype)
+ v = torch.linspace((top + 0.5) / height, (bottom - 0.5) / height, bottom - top, device=device, dtype=dtype)
+ u, v = torch.meshgrid(u, v, indexing='xy')
+ uv = torch.stack([u, v], dim=-1)
+ return uv
+
+
+def image_pixel_center(
+ height: int,
+ width: int,
+ left: int = None,
+ top: int = None,
+ right: int = None,
+ bottom: int = None,
+ dtype: torch.dtype = None,
+ device: torch.device = None
+) -> torch.Tensor:
+ """
+ Get image pixel center coordinates, ranging in [0, width] and [0, height].
+ `image[i, j]` has pixel center coordinates `(j + 0.5, i + 0.5)`.
+
+ >>> image_pixel_center(10, 10):
+ [[[0.5, 0.5], [1.5, 0.5], ..., [9.5, 0.5]],
+ [[0.5, 1.5], [1.5, 1.5], ..., [9.5, 1.5]],
+ ... ... ...
+ [[0.5, 9.5], [1.5, 9.5], ..., [9.5, 9.5]]]
+
+    Args:
+        height (int): image height
+        width (int): image width
+
+    Returns:
+        torch.Tensor: shape (height, width, 2)
+ """
+ if left is None: left = 0
+ if top is None: top = 0
+ if right is None: right = width
+ if bottom is None: bottom = height
+ u = torch.linspace(left + 0.5, right - 0.5, right - left, dtype=dtype, device=device)
+ v = torch.linspace(top + 0.5, bottom - 0.5, bottom - top, dtype=dtype, device=device)
+ u, v = torch.meshgrid(u, v, indexing='xy')
+ return torch.stack([u, v], dim=2)
+
+
+def image_mesh(height: int, width: int, mask: torch.Tensor = None, device: torch.device = None, dtype: torch.dtype = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Get a quad mesh regarding image pixel uv coordinates as vertices and image grid as faces.
+
+    Args:
+        height (int): image height
+        width (int): image width
+        mask (torch.Tensor, optional): binary mask of shape (height, width), dtype=torch.bool. Defaults to None.
+
+    Returns:
+        uv (torch.Tensor): uv coordinates of the vertices, as described in image_uv()
+        faces (torch.Tensor): quad faces connecting neighboring pixels
+        indices (torch.Tensor, optional): indices of the kept vertices; only returned when a mask is given
+ """
+ if device is None and mask is not None:
+ device = mask.device
+ if mask is not None:
+ assert mask.shape[0] == height and mask.shape[1] == width
+ assert mask.dtype == torch.bool
+ uv = image_uv(height, width, device=device, dtype=dtype).reshape((-1, 2))
+ row_faces = torch.stack([
+ torch.arange(0, width - 1, dtype=torch.int32, device=device),
+ torch.arange(width, 2 * width - 1, dtype=torch.int32, device=device),
+ torch.arange(1 + width, 2 * width, dtype=torch.int32, device=device),
+ torch.arange(1, width, dtype=torch.int32, device=device)
+ ], dim=1)
+ faces = (torch.arange(0, (height - 1) * width, width, device=device, dtype=torch.int32)[:, None, None] + row_faces[None, :, :]).reshape((-1, 4))
+ if mask is not None:
+ quad_mask = (mask[:-1, :-1] & mask[1:, :-1] & mask[1:, 1:] & mask[:-1, 1:]).ravel()
+ faces = faces[quad_mask]
+ faces, uv, indices = mesh.remove_unreferenced_vertices(faces, uv, return_indices=True)
+ return uv, faces, indices
+ return uv, faces
+
+
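+# Illustrative usage sketch (hypothetical helper, not called anywhere): build a quad mesh
+# over a small image, keeping only quads whose four corners fall inside a circular mask.
+def _example_image_mesh():
+    h, w = 32, 32
+    mask = ((image_uv(h, w) - 0.5) ** 2).sum(-1) < 0.4 ** 2    # boolean disk mask
+    uv, faces, indices = image_mesh(h, w, mask=mask)
+    return uv.shape, faces.shape
+
+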
+def depth_edge(depth: torch.Tensor, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: torch.Tensor = None) -> torch.BoolTensor:
+ """
+ Compute the edge mask of a depth map. The edge is defined as the pixels whose neighbors have a large difference in depth.
+
+ Args:
+ depth (torch.Tensor): shape (..., height, width), linear depth map
+ atol (float): absolute tolerance
+ rtol (float): relative tolerance
+
+ Returns:
+ edge (torch.Tensor): shape (..., height, width) of dtype torch.bool
+ """
+ shape = depth.shape
+ depth = depth.reshape(-1, 1, *shape[-2:])
+ if mask is not None:
+ mask = mask.reshape(-1, 1, *shape[-2:])
+
+ if mask is None:
+ diff = (F.max_pool2d(depth, kernel_size, stride=1, padding=kernel_size // 2) + F.max_pool2d(-depth, kernel_size, stride=1, padding=kernel_size // 2))
+ else:
+ diff = (F.max_pool2d(torch.where(mask, depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2) + F.max_pool2d(torch.where(mask, -depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2))
+
+ edge = torch.zeros_like(depth, dtype=torch.bool)
+ if atol is not None:
+ edge |= diff > atol
+ if rtol is not None:
+ edge |= (diff / depth).nan_to_num_() > rtol
+ edge = edge.reshape(*shape)
+ return edge
+
+
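+# Illustrative usage sketch (hypothetical helper, not called anywhere): flag pixels whose
+# 3x3 neighborhood spans more than 5% relative depth difference; the thresholds are placeholders.
+def _example_depth_edge():
+    depth = 1.0 + torch.rand(1, 64, 64)    # a fake linear depth map in [1, 2)
+    return depth_edge(depth, rtol=0.05, kernel_size=3)
+
+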
+def depth_aliasing(depth: torch.Tensor, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: torch.Tensor = None) -> torch.BoolTensor:
+ """
+    Compute a map indicating depth aliasing. A pixel is flagged as aliased when its depth is close to neither the maximum nor the minimum of its neighbors.
+ Args:
+ depth (torch.Tensor): shape (..., height, width), linear depth map
+ atol (float): absolute tolerance
+ rtol (float): relative tolerance
+
+ Returns:
+ edge (torch.Tensor): shape (..., height, width) of dtype torch.bool
+ """
+ shape = depth.shape
+ depth = depth.reshape(-1, 1, *shape[-2:])
+ if mask is not None:
+ mask = mask.reshape(-1, 1, *shape[-2:])
+
+ if mask is None:
+ diff_max = F.max_pool2d(depth, kernel_size, stride=1, padding=kernel_size // 2) - depth
+ diff_min = F.max_pool2d(-depth, kernel_size, stride=1, padding=kernel_size // 2) + depth
+ else:
+ diff_max = F.max_pool2d(torch.where(mask, depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2) - depth
+ diff_min = F.max_pool2d(torch.where(mask, -depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2) + depth
+ diff = torch.minimum(diff_max, diff_min)
+
+ edge = torch.zeros_like(depth, dtype=torch.bool)
+ if atol is not None:
+ edge |= diff > atol
+ if rtol is not None:
+ edge |= (diff / depth).nan_to_num_() > rtol
+ edge = edge.reshape(*shape)
+ return edge
+
+
+def image_mesh_from_depth(
+ depth: torch.Tensor,
+ extrinsics: torch.Tensor = None,
+ intrinsics: torch.Tensor = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ height, width = depth.shape
+ uv, faces = image_mesh(height, width)
+ faces = faces.reshape(-1, 4)
+ depth = depth.reshape(-1)
+    pts = transforms.unproject_cv(uv, depth, extrinsics, intrinsics)
+ faces = mesh.triangulate(faces, vertices=pts)
+ return pts, faces
+
+
+@batched(3, 2, 2)
+def point_to_normal(point: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+ """
+ Calculate normal map from point map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system.
+
+    Args:
+        point (torch.Tensor): shape (..., height, width, 3), point map
+        mask (torch.Tensor, optional): shape (..., height, width), valid-pixel mask
+
+    Returns:
+        normal (torch.Tensor): shape (..., height, width, 3), normal map.
+        When a mask is given, a (..., height, width) validity mask for the normals is also returned.
+ """
+ has_mask = mask is not None
+
+ if mask is None:
+ mask = torch.ones_like(point[..., 0], dtype=torch.bool)
+ mask = F.pad(mask, (1, 1, 1, 1), mode='constant', value=0)
+
+ pts = F.pad(point.permute(0, 3, 1, 2), (1, 1, 1, 1), mode='constant', value=1).permute(0, 2, 3, 1)
+ up = pts[:, :-2, 1:-1, :] - pts[:, 1:-1, 1:-1, :]
+ left = pts[:, 1:-1, :-2, :] - pts[:, 1:-1, 1:-1, :]
+ down = pts[:, 2:, 1:-1, :] - pts[:, 1:-1, 1:-1, :]
+ right = pts[:, 1:-1, 2:, :] - pts[:, 1:-1, 1:-1, :]
+ normal = torch.stack([
+ torch.cross(up, left, dim=-1),
+ torch.cross(left, down, dim=-1),
+ torch.cross(down, right, dim=-1),
+ torch.cross(right, up, dim=-1),
+ ])
+ normal = F.normalize(normal, dim=-1)
+ valid = torch.stack([
+ mask[:, :-2, 1:-1] & mask[:, 1:-1, :-2],
+ mask[:, 1:-1, :-2] & mask[:, 2:, 1:-1],
+ mask[:, 2:, 1:-1] & mask[:, 1:-1, 2:],
+ mask[:, 1:-1, 2:] & mask[:, :-2, 1:-1],
+ ]) & mask[None, :, 1:-1, 1:-1]
+ normal = (normal * valid[..., None]).sum(dim=0)
+ normal = F.normalize(normal, dim=-1)
+
+ if has_mask:
+ return normal, valid.any(dim=0)
+ else:
+ return normal
+
+
+@batched(2, 2, 2)
+def depth_to_normal(depth: torch.Tensor, intrinsics: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+ """
+ Calculate normal map from depth map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system.
+
+ Args:
+ depth (torch.Tensor): shape (..., height, width), linear depth map
+        intrinsics (torch.Tensor): shape (..., 3, 3), intrinsics matrix
+        mask (torch.Tensor, optional): shape (..., height, width), valid-pixel mask
+    Returns:
+        normal (torch.Tensor): shape (..., height, width, 3), normal map.
+ """
+ has_mask = mask is not None
+
+ height, width = depth.shape[-2:]
+ if mask is None:
+ mask = torch.ones_like(depth, dtype=torch.bool)
+ mask = F.pad(mask, (1, 1, 1, 1), mode='constant', value=0)
+
+ uv = image_uv(*depth.shape[-2:]).unsqueeze(0).to(depth)
+ pts = transforms.unproject_cv(uv.reshape(-1, 2), depth.flatten(-2), intrinsics=intrinsics, extrinsics=None).unflatten(-2, (height, width))
+
+ return point_to_normal(pts, mask)
+
+
+def masked_min(input: torch.Tensor, mask: torch.BoolTensor, dim: int = None, keepdim: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """Similar to torch.min, but with mask
+ """
+ if dim is None:
+ return torch.where(mask, input, torch.tensor(torch.inf, dtype=input.dtype, device=input.device)).min()
+ else:
+ return torch.where(mask, input, torch.tensor(torch.inf, dtype=input.dtype, device=input.device)).min(dim=dim, keepdim=keepdim)
+
+
+def masked_max(input: torch.Tensor, mask: torch.BoolTensor, dim: int = None, keepdim: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """Similar to torch.max, but with mask
+ """
+ if dim is None:
+ return torch.where(mask, input, torch.tensor(-torch.inf, dtype=input.dtype, device=input.device)).max()
+ else:
+ return torch.where(mask, input, torch.tensor(-torch.inf, dtype=input.dtype, device=input.device)).max(dim=dim, keepdim=keepdim)
+
+
+def bounding_rect(mask: torch.BoolTensor):
+ """get bounding rectangle of a mask
+
+ Args:
+ mask (torch.Tensor): shape (..., height, width), mask
+
+ Returns:
+        rect (torch.Tensor): shape (..., 4), bounding rectangle (left, top, right, bottom) in normalized uv coordinates
+ """
+ height, width = mask.shape[-2:]
+ mask = mask.flatten(-2).unsqueeze(-1)
+ uv = image_uv(height, width).to(mask.device).reshape(-1, 2)
+ left_top = masked_min(uv, mask, dim=-2)[0]
+ right_bottom = masked_max(uv, mask, dim=-2)[0]
+ return torch.cat([left_top, right_bottom], dim=-1)
+
+
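+# Illustrative usage sketch (hypothetical helper, not called anywhere): `bounding_rect`
+# works in normalized uv coordinates; scale by (width, height) to recover pixels.
+def _example_bounding_rect():
+    mask = torch.zeros(64, 64, dtype=torch.bool)
+    mask[10:30, 20:50] = True
+    rect_uv = bounding_rect(mask)    # (left, top, right, bottom) in [0, 1]
+    return rect_uv * torch.tensor([64., 64., 64., 64.])    # approximate pixel rectangle
+
+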
+def chessboard(width: int, height: int, grid_size: int, color_a: torch.Tensor, color_b: torch.Tensor) -> torch.Tensor:
+ """get a chessboard image
+
+ Args:
+ width (int): image width
+ height (int): image height
+ grid_size (int): size of chessboard grid
+        color_a (torch.Tensor): shape (channels,), color of the grid at the top-left corner
+        color_b (torch.Tensor): shape (channels,), color of the complementary grid cells
+
+ Returns:
+ image (torch.Tensor): shape (height, width, channels), chessboard image
+ """
+ x = torch.div(torch.arange(width), grid_size, rounding_mode='floor')
+ y = torch.div(torch.arange(height), grid_size, rounding_mode='floor')
+ mask = ((x[None, :] + y[:, None]) % 2).to(color_a)
+ image = (1 - mask[..., None]) * color_a + mask[..., None] * color_b
+ return image
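+
+
+# Illustrative usage sketch (hypothetical helper, not called anywhere): a 128x128 board
+# of 16-pixel cells alternating between white and black.
+def _example_chessboard():
+    white = torch.tensor([1., 1., 1.])
+    black = torch.tensor([0., 0., 0.])
+    return chessboard(width=128, height=128, grid_size=16, color_a=white, color_b=black)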
\ No newline at end of file
diff --git a/tools.py b/tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3a6977dc826ef3b658a7a7af207afadb6ca597f
--- /dev/null
+++ b/tools.py
@@ -0,0 +1,329 @@
+import argparse
+import os
+import torch
+import numpy as np
+import trimesh
+from scipy.spatial.transform import Rotation
+import torch.backends.cudnn as cudnn
+import torch.nn.functional as F
+from PIL import Image
+from src.utils.vis import (
+ prob_to_mask,
+ colorize,
+ denormalize,
+)
+from src.lari.model import LaRIModel, DinoSegModel
+from rembg import remove
+from plyfile import PlyData, PlyElement
+import torchvision.transforms as transforms
+
+
+
+LAYER_COLOR = [
+    [255, 190, 11],   # FFBE0B
+ [251, 86, 7], # FB5607
+ [241, 91, 181], # F15BB5
+ [131, 56, 236], # 8338EC
+ [58, 134, 255], # 3A86FF
+]
+
+
+OPENGL = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
+
+
+
+
+def save_point_cloud(pcd, rgb, filename, binary=True):
+ """Save an RGB point cloud as a PLY file.
+
+    Args:
+        pcd: Nx3 matrix, the XYZ coordinates
+        rgb: Nx3 matrix, the RGB colors for each 3D point
+    """
+
+ if rgb is None:
+ gray_concat = np.tile(np.array([128], dtype=np.uint8),
+ (pcd.shape[0], 3))
+ points_3d = np.hstack((pcd, gray_concat))
+ else:
+ assert pcd.shape[0] == rgb.shape[0]
+ points_3d = np.hstack((pcd, rgb))
+ python_types = (float, float, float, int, int, int)
+ npy_types = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'),
+ ('green', 'u1'), ('blue', 'u1')]
+ if binary is True:
+ # Format into Numpy structured array
+ vertices = []
+ for row_idx in range(points_3d.shape[0]):
+ cur_point = points_3d[row_idx]
+ vertices.append(
+ tuple(
+ dtype(point)
+ for dtype, point in zip(python_types, cur_point)))
+ vertices_array = np.array(vertices, dtype=npy_types)
+ el = PlyElement.describe(vertices_array, 'vertex')
+
+ # write
+ PlyData([el]).write(filename)
+ else:
+ x = np.squeeze(points_3d[:, 0])
+ y = np.squeeze(points_3d[:, 1])
+ z = np.squeeze(points_3d[:, 2])
+ r = np.squeeze(points_3d[:, 3])
+ g = np.squeeze(points_3d[:, 4])
+ b = np.squeeze(points_3d[:, 5])
+
+ ply_head = 'ply\n' \
+ 'format ascii 1.0\n' \
+ 'element vertex %d\n' \
+ 'property float x\n' \
+ 'property float y\n' \
+ 'property float z\n' \
+ 'property uchar red\n' \
+ 'property uchar green\n' \
+ 'property uchar blue\n' \
+ 'end_header' % r.shape[0]
+ # ---- Save ply data to disk
+        np.savetxt(filename, np.column_stack((x, y, z, r, g, b)), fmt='%f %f %f %d %d %d', header=ply_head, comments='')
+
+
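+# Illustrative usage sketch (hypothetical helper, not called anywhere): write a small
+# random colored point cloud to disk; the output file name is a placeholder.
+def _example_save_point_cloud():
+    pts = np.random.rand(100, 3).astype(np.float32)
+    colors = np.random.randint(0, 256, size=(100, 3), dtype=np.uint8)
+    save_point_cloud(pts, colors, "example.ply", binary=True)
+
+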
+def load_model(model_info, ckpt_path, device):
+ model = eval(model_info)
+ model.to(device)
+ model.eval()
+
+ # Load pretrained weights
+ ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
+ if "model" in ckpt:
+ model.load_state_dict(ckpt["model"], strict=False)
+ else:
+ model.load_state_dict(ckpt, strict=False)
+ return model
+
+
+def process_image_custom(pil_image, resolution=512):
+ """
+    Read an image, resize the long side to `resolution` and pad the short side with black,
+ so that the final image is (resolution x resolution).
+
+ Returns:
+ padded_img (PIL.Image): The processed image.
+ crop_coords (tuple): (top, left, bottom, right) coordinates of the valid region.
+ original_size (tuple): (width, height) of the original image.
+ """
+ pil_image = pil_image.convert("RGB")
+ original_width, original_height = pil_image.size
+
+ # If already at fixed resolution, no processing is needed.
+ if original_width == resolution and original_height == resolution:
+ crop_coords = (0, 0, resolution, resolution)
+ return pil_image, crop_coords, (original_width, original_height), pil_image
+
+ # Compute scaling factor based on the long side.
+ if original_width >= original_height:
+ # Width is the long side.
+ scale = resolution / float(original_width)
+ new_width = resolution
+ new_height = int(round(original_height * scale))
+ resized_img = pil_image.resize((new_width, new_height), Image.BILINEAR)
+ # Compute vertical padding.
+ pad_top = (resolution - new_height) // 2
+ pad_bottom = resolution - new_height - pad_top
+ pad_left, pad_right = 0, 0
+ else:
+ # Height is the long side.
+ scale = resolution / float(original_height)
+ new_height = resolution
+ new_width = int(round(original_width * scale))
+ resized_img = pil_image.resize((new_width, new_height), Image.BILINEAR)
+ # Compute horizontal padding.
+ pad_left = (resolution - new_width) // 2
+ pad_right = resolution - new_width - pad_left
+ pad_top, pad_bottom = 0, 0
+
+ # Create new image filled with black
+ padded_img = Image.new("RGB", (resolution, resolution), (0, 0, 0))
+ padded_img.paste(resized_img, (pad_left, pad_top))
+
+ # The valid region (crop) is where the resized image was pasted.
+ crop_coords = (pad_top, pad_left, pad_top + new_height, pad_left + new_width)
+ return padded_img, crop_coords, (original_width, original_height), pil_image
+
+
+def process_image(pil_image, resolution=512):
+ """
+ Process the image: apply custom resize/pad then convert to normalized tensor.
+
+ Returns:
+ img_tensor (torch.Tensor): Tensor of shape (1, 3, resolution, resolution).
+ crop_coords (tuple): (top, left, bottom, right) coordinates of the valid region.
+ original_size (tuple): (width, height) of the original image.
+ """
+ padded_img, crop_coords, original_size, ori_img = process_image_custom(
+ pil_image, resolution
+ )
+ transform = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+ ]
+ )
+
+ img_tensor = transform(padded_img).unsqueeze(0)
+ ori_img_tensor = transform(ori_img).unsqueeze(0)
+ return img_tensor, ori_img_tensor, crop_coords, original_size
+
+
+def post_process_output(input_tensor, crop_coords, original_size):
+ """
+ Crop the input tensor using the crop_coords and then resize to the original image size.
+
+ Args:
+ input_tensor (torch.Tensor): Input with shape (H, W, L, C) where C is 1 or 3.
+ crop_coords (tuple): (top, left, bottom, right) coordinates for cropping.
+ original_size (tuple): (width, height) of the original image.
+
+ Returns:
+ processed_output (torch.Tensor): Output with shape (original_height, original_width, L, C).
+ """
+ top, left, bottom, right = crop_coords
+ # Crop the input spatially: resulting shape (crop_h, crop_w, L, C)
+ cropped = input_tensor[top:bottom, left:right, ...]
+ crop_h, crop_w, L, C = cropped.shape
+
+ # New shape becomes (1, L * C, crop_h, crop_w)
+ reshaped = cropped.permute(2, 3, 0, 1).reshape(1, L * C, crop_h, crop_w)
+
+    # Unpack the original size (width, height); use nearest-neighbor for single-layer
+    # (mask-like) outputs and bilinear interpolation otherwise.
+    new_width, new_height = original_size
+    mode = "nearest" if L == 1 else "bilinear"
+
+    resized = F.interpolate(
+        reshaped, size=(new_height, new_width), mode=mode,
+        align_corners=False if mode == "bilinear" else None,
+    )
+ resized = resized.reshape(L, C, new_height, new_width)
+
+ # Permute to the output shape: (new_height, new_width, L, C)
+ processed_output = resized.permute(2, 3, 0, 1)
+
+ return processed_output
+
+
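+# Illustrative usage sketch (hypothetical helper, not called anywhere): round-trip a
+# 640x480 image through the resize/pad preprocessing, then map a fake (H, W, L, C)
+# prediction back to the original resolution; the random tensor stands in for a model output.
+def _example_pre_post_roundtrip():
+    pil_img = Image.new("RGB", (640, 480), (127, 127, 127))
+    _, _, crop_coords, original_size = process_image(pil_img, resolution=512)
+    fake_pred = torch.rand(512, 512, 4, 3)    # e.g. 4 layers of xyz maps
+    restored = post_process_output(fake_pred, crop_coords, original_size)
+    return restored.shape    # torch.Size([480, 640, 4, 3])
+
+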
+def get_masked_depth(lari_map, valid_mask, layer_id):
+
+ layer_id = max(0, layer_id)
+
+ lari_depth = lari_map[:, :, layer_id, 2].cpu().numpy() # H W
+ valid_mask = valid_mask[:, :, layer_id, 0].cpu().numpy() # H W
+ valid_values = lari_depth[valid_mask]
+
+ # Handle empty valid values
+ if valid_values.size == 0:
+ vis_depth_range = [0, 1]
+ else:
+ vis_depth_range = [valid_values.min(), valid_values.max()]
+
+ depth_image = Image.fromarray(
+ colorize(
+ lari_depth,
+ vis_depth_range[0],
+ vis_depth_range[1],
+ invalid_mask=~valid_mask,
+ cmap="Spectral",
+ )
+ ).convert("RGB")
+ return depth_image
+
+
+def save_to_glb(pts3d, color3d, path):
+ scene = trimesh.Scene()
+ pct = trimesh.PointCloud(pts3d, colors=color3d)
+ scene.add_geometry(pct)
+ rot_y = np.eye(4)
+ rot_y[:3, :3] = Rotation.from_euler("y", np.deg2rad(180)).as_matrix()
+ scene.apply_transform(np.linalg.inv(OPENGL @ rot_y))
+ outfile = os.path.join(path, "res.glb")
+ scene.export(file_obj=outfile)
+ return outfile
+
+
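+# Illustrative usage sketch (hypothetical helper, not called anywhere): export a random
+# colored point cloud as res.glb into the current directory; the target path is a placeholder.
+def _example_save_to_glb():
+    pts = np.random.rand(200, 3)
+    colors = np.random.randint(0, 256, size=(200, 4), dtype=np.uint8)    # RGBA per point
+    return save_to_glb(pts, colors, ".")
+
+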
+def get_point_cloud(pred, img, mask, first_layer_color="image", target_folder=None):
+ """
+    pred: h w l 3 - the layered point map
+    img: 3 h w - the colored image
+    mask: h w l - indicating the valid layers
+    first_layer_color: "image" or "pseudo" - color source for the first layer
+    target_folder: folder in which to save the .ply / .glb files (defaults to this file's directory)
+ """
+
+ ori_shape = pred.shape
+ pred = pred.cpu().numpy()
+ pred = pred.reshape(-1, 3) # M 3
+
+ color_palette = LAYER_COLOR[: min(len(LAYER_COLOR), ori_shape[-2])]
+ assert first_layer_color in ["image", "pseudo"]
+
+ # assign color to point clouds: [M,3] -> [M, 6]
+ img = torch.clip(denormalize(img).squeeze(0), 0.0, 1.0)
+ img = img.permute(1, 2, 0).unsqueeze(2).cpu().numpy() # H W 1 3
+ img = (img * 255.0).astype(np.uint8)
+
+ layered_color = np.array([[color_palette]]).astype(np.uint8) # 1 1 n_layer 3
+    layered_color = np.broadcast_to(
+        layered_color, (img.shape[0], img.shape[1], ori_shape[2], 3)
+    ).copy()  # H W n_layer 3; copy because np.broadcast_to returns a read-only view
+
+ if first_layer_color == "image":
+ layered_color[:, :, :1, :] = img
+ layered_color = layered_color.reshape(-1, 3)
+
+ valid_mask_arr = mask.squeeze().reshape(-1).cpu().numpy() # [H,W,layers] -> [M]
+ pred = pred[valid_mask_arr.astype(bool)]
+ layered_color = layered_color[valid_mask_arr.astype(bool)] # V,3
+
+ save_folder = target_folder if target_folder is not None else os.path.dirname(__file__)
+
+ ply_path = os.path.join(save_folder, "res.ply")
+ save_point_cloud(pred, layered_color, filename=ply_path)
+
+ glb_path = save_to_glb(pred, layered_color, save_folder)
+
+ return glb_path, ply_path
+
+
+def removebg_crop(pil_input):
+
+ pil_input = remove(pil_input.convert("RGB"))
+
+ pil_np = np.array(pil_input)
+ alpha = pil_np[:, :, 3]
+    # crop only when the foreground covers less than 10% of the image
+    is_crop = np.sum(alpha > 0.8 * 255) <= 0.1 * (alpha.shape[0] * alpha.shape[1])
+
+ # adjust object size to fit the image resolution
+ if is_crop:
+ width, height = pil_input.size
+ # adjust object size
+ output_np = np.array(pil_input)
+ alpha = output_np[:, :, 3]
+ bbox = np.argwhere(alpha > 0.8 * 255)
+ bbox = (
+ np.min(bbox[:, 1]),
+ np.min(bbox[:, 0]),
+ np.max(bbox[:, 1]),
+ np.max(bbox[:, 0]),
+ )
+ center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
+ size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
+ size = int(size * 1.5)
+ bbox = (
+ max(center[0] - size // 2, 0),
+ max(center[1] - size // 2, 0),
+ min(center[0] + size // 2, width),
+ min(center[1] + size // 2, height),
+ )
+ pil_input = pil_input.crop(bbox) # type: ignore
+
+ return pil_input
\ No newline at end of file