diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..b75c5c5c8b54dd7eb396efd67261396b063f16b6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,195 @@ +scripts/rendering/blender-* + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class +.vscode/ + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/
+*.sif
+blender-4.2.5-linux-x64*/
+*.zip
+*.png
+*.jpg
+*.log
+intermediate/
+__pycache__
+*.ply
+*.npy
+*.npz
+*.obj
+*.mtl
+*.json.gz
+dcgm/
+wandb/
+# *.json
+
+
+
+# Exception to add training list
+!lgm_leq20Kpts_simtopo25_train.json.gz
+!lgm_leq20Kpts_simtopo25_test.json.gz
+!lgm_leq20Kpts_train.json.gz
+!lgm_leq20Kpts_train_same_size_wrt_simtopo.json.gz
+*.mp4
+*.gif
+*.glb
+!lgm_leq20Kpts_plus_3Kremain_train.json.gz
+!lgm_leq20Kpts_test_cleaned.json.gz
+!lgm_leq20Kpts_train_cleaned.json.gz
+test_metrics.json
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ca3d21fd05e09ce8861f4882e7a37a625167d73
--- /dev/null
+++ b/app.py
@@ -0,0 +1,282 @@
+import argparse
+import gradio
+import torch
+import torch.backends.cudnn as cudnn
+from src.utils.vis import prob_to_mask
+from src.lari.model import LaRIModel, DinoSegModel
+from tools import load_model, process_image, post_process_output, get_masked_depth, save_to_glb, get_point_cloud, removebg_crop
+from huggingface_hub import hf_hub_download
+
+parser = argparse.ArgumentParser("Arguments for deploying a LaRI Demo")
+
+parser.add_argument(
+    "--model_info_pm",
+    type=str,
+    default="LaRIModel(use_pretrained = 'moge_full', num_output_layer = 5, head_type = 'point')",
+    help="Network parameters to load the model",
+)
+
+parser.add_argument(
+    "--model_info_mask",
+    type=str,
+    default="DinoSegModel(use_pretrained = 'dinov2', dim_proj = 256, pretrained_path = '', num_output_layer = 4, output_type = 'ray_stop')",
+    help="Network parameters to load the model",
+)
+
+parser.add_argument(
+    "--ckpt_path_pm",
+    type=str,
+    default="lari_obj_16k_pointmap.pth",
+    help="Path to pre-trained weights",
+)
+
+parser.add_argument(
+    "--ckpt_path_mask",
+    type=str,
+    default="lari_obj_16k_seg.pth",
+    help="Path to pre-trained weights",
+)
+
+parser.add_argument(
+    "--resolution", type=int, default=512, help="Default model resolution"
+)
+args = parser.parse_args()
+
+
+
+def model_forward(pil_input, layered_id, rembg_checkbox):
+    """
+    Perform LaRI estimation by:
+    1. processing the input image
+    2. running the network forward pass
+    3. saving the masked layered depth images
+    4. saving the point cloud
+    """
+    if pil_input is None:
+        return (None,) * 9  # one None per Gradio output component
+
+    if rembg_checkbox:
+        pil_input = removebg_crop(pil_input)
+
+    # Process the input image.
+    input_tensor, ori_img_tensor, crop_coords, original_size = process_image(
+        pil_input, resolution=args.resolution
+    )
+    input_tensor = input_tensor.to(device)
+
+    # Run inference.
+    with torch.no_grad():
+        # LaRI map
+        pred_dict = model_pm(input_tensor)
+        lari_map = -pred_dict["pts3d"].squeeze(
+            0
+        )  # Expected output shape: (H_reso, W_reso, L, 3)
+        # mask
+        if model_mask:
+            pred_dict = model_mask(input_tensor)
+            assert "seg_prob" in pred_dict
+            valid_mask = prob_to_mask(pred_dict["seg_prob"].squeeze(0))  # H W L 1
+        else:
+            h, w, l, _ = lari_map.shape
+            valid_mask = torch.ones((h, w, l, 1), dtype=torch.bool, device=lari_map.device)  # no mask model: all layers valid
+
+    # crop & resize the output to the original resolution.
+ if original_size[0] != args.resolution or original_size[1] != args.resolution: + lari_map = post_process_output(lari_map, crop_coords, original_size) # H W L 3 + valid_mask = post_process_output( + valid_mask.float(), crop_coords, original_size + ).bool() # H W L 1 + + max_n_layer = min(valid_mask.shape[-2], lari_map.shape[-2]) + valid_mask = valid_mask[:, :, :max_n_layer, :] + lari_map = lari_map[:, :, :max_n_layer, :] + + curr_layer_id = min(max_n_layer - 1, layered_id - 1) + + # masked depth list + depth_image = get_masked_depth( + lari_map=lari_map, valid_mask=valid_mask, layer_id=curr_layer_id + ) + # point cloud + glb_path, ply_path = get_point_cloud( + lari_map, ori_img_tensor, valid_mask, first_layer_color="pseudo" + ) + + return ( + depth_image, + glb_path, + lari_map, + valid_mask, + 0, + max_n_layer - 1, + glb_path, + ply_path, + pil_input, + ) + + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +cudnn.benchmark = True + + +# Download the file +model_path_pm = hf_hub_download(repo_id="ruili3/LaRI", filename=args.ckpt_path_pm, repo_type="model") +model_path_mask = hf_hub_download(repo_id="ruili3/LaRI", filename=args.ckpt_path_mask, repo_type="model") + + +# Load the model with pretrained weights. +model_pm = load_model(args.model_info_pm, model_path_pm, device) +model_mask = ( + load_model(args.model_info_mask, model_path_mask, device) + if args.model_info_mask is not None + else None +) + + +def change_layer(slider_layer_id, lari_map, valid_mask, min_layer_id, max_layer_id): + + if lari_map is None: + return + + slider_layer_id = slider_layer_id - 1 + curr_layer_id = min(slider_layer_id, max_layer_id) + curr_layer_id = max(curr_layer_id, min_layer_id) + + # masked depth list + depth_image = get_masked_depth( + lari_map=lari_map, valid_mask=valid_mask, layer_id=curr_layer_id + ) + + return depth_image + + +def clear_everything(): + return ( + gradio.update(value=None), + gradio.update(value=None), + gradio.update(value=None), + gradio.update(value=None), + gradio.update(value=None), + gradio.update(value=None), + gradio.update(value=None), + ) + + +with gradio.Blocks( + css=""".gradio-container {margin: 0 !important; min-width: 100%};""", + title="LaRI Demo", +) as demo: + + gradio.Markdown( + "

LaRI: Layered Ray Intersections for Single-view 3D Geometric Reasoning

",
+        elem_id="title",
+    )
+
+    gradio.Markdown(
+        """
+    This is the official demo of Layered Ray Intersection (LaRI). For a quick start, click an image in 'Examples' and then click the 'Process' button.
+
+    You can also try your own images with the following steps:
+    - Load an image;
+    - Click the 'Process' button;
+    - Browse the layered depth maps (the z-channel of the resulting LaRI point map) by tuning 'Layer ID';
+
+    Note that in '3D Point Cloud', each color denotes a different intersection layer, i.e., layer 1, layer 2, layer 3, layer 4.
+    """
+    )
+
+    # , layer 5.
+    lari_map = gradio.State(None)
+    valid_mask = gradio.State(None)
+    min_layer_id = gradio.State(None)
+    max_layer_id = gradio.State(None)
+
+    with gradio.Column():
+        with gradio.Row(equal_height=True):
+            with gradio.Column(scale=1):
+                image_input = gradio.Image(
+                    label="Upload an Image", type="pil", height=350
+                )
+                with gradio.Row():
+                    rembg_checkbox = gradio.Checkbox(label="Remove background")
+                    clear_button = gradio.Button("Clear")
+                    submit_btn = gradio.Button("Process")
+            with gradio.Column(scale=1):
+                depth_output = gradio.Image(
+                    label="LaRI Map at Z-axis (depth)",
+                    type="pil",
+                    interactive=False,
+                    height=300,
+                )
+                slider_layer_id = gradio.Slider(
+                    minimum=1,
+                    maximum=4,
+                    step=1,
+                    value=1,
+                    label="Layer ID",
+                    interactive=True,
+                )
+
+        with gradio.Row(scale=1):
+            outmodel = gradio.Model3D(
+                label="3D Point Cloud (Color denotes different layers)",
+                interactive=False,
+                zoom_speed=0.5,
+                pan_speed=0.5,
+                height=450,
+            )
+
+        with gradio.Row():
+            ply_file_output = gradio.File(label="ply output", elem_classes="small-file")
+            glb_file_output = gradio.File(label="glb output", elem_classes="small-file")
+
+        submit_btn.click(
+            fn=model_forward,
+            inputs=[image_input, slider_layer_id, rembg_checkbox],
+            outputs=[
+                depth_output,
+                outmodel,
+                lari_map,
+                valid_mask,
+                min_layer_id,
+                max_layer_id,
+                glb_file_output,
+                ply_file_output,
+                image_input,
+            ],
+        )
+
+        clear_button.click(
+            fn=clear_everything,
+            outputs=[
+                lari_map,
+                valid_mask,
+                min_layer_id,
+                max_layer_id,
+                image_input,
+                depth_output,
+                outmodel,
+            ],
+        )
+
+        slider_layer_id.change(
+            fn=change_layer,
+            inputs=[slider_layer_id, lari_map, valid_mask, min_layer_id, max_layer_id],
+            outputs=depth_output,
+        )
+
+    gradio.Examples(examples=["assets/cole_hardware.png",
+                              "assets/3m_tape.png",
+                              "assets/horse.png",
+                              "assets/rhino.png",
+                              "assets/alphabet.png",
+                              "assets/martin_wedge.png",
+                              "assets/d_rose.png",
+                              "assets/ace.png",
+                              "assets/bifidus.png",
+                              "assets/fem.png",
+                              ],
+                    inputs=image_input)
+
+
+demo.launch(share=False)
diff --git a/demo.py b/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..389fd4ddda46554037b6a6138124ec77c07ac2d5
--- /dev/null
+++ b/demo.py
@@ -0,0 +1,136 @@
+import argparse
+import os
+import torch
+import torch.backends.cudnn as cudnn
+from PIL import Image
+from src.utils.vis import prob_to_mask
+from huggingface_hub import hf_hub_download
+from tools import load_model, process_image, post_process_output, get_masked_depth, get_point_cloud, removebg_crop
+
+parser = argparse.ArgumentParser("Arguments for deploying a LaRI Demo")
+parser.add_argument(
+    "--image_path",
+    type=str,
+    default="assets/cole_hardware.png",
+    help="Path to the input image",
+)
+
+parser.add_argument(
+    "--output_path",
+    type=str,
+    default="./results",
+    help="Path to save the results",
+)
+
+parser.add_argument(
+    "--model_info_pm",
+    type=str,
+    default="LaRIModel(use_pretrained = 'moge_full', num_output_layer = 5, head_type = 'point')",
+    help="Network parameters to load the model",
+)
+
+parser.add_argument(
+    "--model_info_mask",
+    type=str,
+    default="DinoSegModel(use_pretrained = 'dinov2', dim_proj = 256, pretrained_path = '', num_output_layer = 4, output_type = 'ray_stop')",
+    help="Network parameters to load the model",
+)
+
+parser.add_argument(
+    "--ckpt_path_pm",
+    type=str,
+    default="lari_obj_16k_pointmap.pth",
+    help="Path to pre-trained weights",
+)
+
+parser.add_argument(
+    "--ckpt_path_mask",
+    type=str,
+    default="lari_obj_16k_seg.pth",
+    help="Path to pre-trained weights",
+)
+
+parser.add_argument(
+    "--resolution", type=int, default=512, help="Default model resolution"
+)
+
+parser.add_argument(
+    "--is_remove_background", action="store_true", help="Automatically remove the background."
+)
+
+args = parser.parse_args()
+
+
+
+
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+cudnn.benchmark = True
+
+# === Load the model
+
+model_path_pm = hf_hub_download(repo_id="ruili3/LaRI", filename=args.ckpt_path_pm, repo_type="model")
+model_path_mask = hf_hub_download(repo_id="ruili3/LaRI", filename=args.ckpt_path_mask, repo_type="model")
+# Load the model with pretrained weights.
+model_pm = load_model(args.model_info_pm, model_path_pm, device)
+model_mask = (
+    load_model(args.model_info_mask, model_path_mask, device)
+    if args.model_info_mask is not None
+    else None
+)
+
+# === Image pre-processing
+pil_input = Image.open(args.image_path)
+if args.is_remove_background:
+    pil_input = removebg_crop(pil_input)  # remove background
+input_tensor, ori_img_tensor, crop_coords, original_size = process_image(
+    pil_input, resolution=args.resolution)  # crop & resize to fit the model input size
+input_tensor = input_tensor.to(device)
+
+
+# === Run inference
+with torch.no_grad():
+    # LaRI map
+    pred_dict = model_pm(input_tensor)
+    lari_map = -pred_dict["pts3d"].squeeze(
+        0
+    )
+    # mask
+    if model_mask:
+        pred_dict = model_mask(input_tensor)
+        assert "seg_prob" in pred_dict
+        valid_mask = prob_to_mask(pred_dict["seg_prob"].squeeze(0))  # H W L 1
+    else:
+        h, w, l, _ = lari_map.shape
+        valid_mask = torch.ones((h, w, l, 1), dtype=torch.bool, device=lari_map.device)  # all layers valid
+
+# === crop & resize back to the original resolution
+if original_size[0] != args.resolution or original_size[1] != args.resolution:
+    lari_map = post_process_output(lari_map, crop_coords, original_size)  # H W L 3
+    valid_mask = post_process_output(
+        valid_mask.float(), crop_coords, original_size
+    ).bool()  # H W L 1
+
+max_n_layer = min(valid_mask.shape[-2], lari_map.shape[-2])
+valid_mask = valid_mask[:, :, :max_n_layer, :]
+lari_map = lari_map[:, :, :max_n_layer, :]
+
+
+# === save output
+os.makedirs(args.output_path, exist_ok=True)
+
+for layer_id in range(max_n_layer):
+    depth_pil = get_masked_depth(
+        lari_map=lari_map, valid_mask=valid_mask, layer_id=layer_id
+    )
+    depth_pil.save(os.path.join(args.output_path, f"layered_depth_{layer_id}.jpg"))
+
+
+# point cloud
+glb_path, ply_path = get_point_cloud(
+    lari_map, ori_img_tensor, valid_mask, first_layer_color="pseudo",
+    target_folder=args.output_path
+)
+
+print("All results saved to `{}`.".format(args.output_path))
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..26d85c62b9f1d7294e4c0738272cedc83f2cc6ce
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,18 @@
+gradio==5.23.3
+huggingface_hub==0.30.1
+imageio==2.37.0
+matplotlib==3.10.1
+moderngl==5.12.0
+omegaconf==2.3.0
+opencv_python==4.11.0.86
+opencv_python_headless==4.11.0.86
+Pillow==11.1.0
+piqp==0.5.0
+plyfile==1.1
+rembg==2.0.65
+scipy==1.15.2
+torchvision==0.21.0
+trimesh==4.6.4
+xformers==0.0.29.post3
+numpy==1.26.4
+torch==2.6.0
\ No newline at end of file
diff --git a/src/lari/model/__init__.py b/src/lari/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d64ed1ea63d542c9852a5ef5cd933a8dd018d94
--- /dev/null
+++ b/src/lari/model/__init__.py
@@ -0,0 +1,2 @@
+from .lari_model import LaRIModel
+from .dinoseg_model import DinoSegModel
\ No newline at end of file
diff --git a/src/lari/model/blocks.py b/src/lari/model/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..51d061093ef4fb1f25532ae3b0bf3cbcbcf709f1
--- /dev/null
+++ b/src/lari/model/blocks.py
@@ -0,0 +1,209 @@
+from typing import *
+import torch.nn as nn
+
+class ResidualConvBlock(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int = None, hidden_channels: int = None, padding_mode: str = 'replicate', activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu', norm: Literal['group_norm', 'layer_norm'] = 'group_norm'):
+        super(ResidualConvBlock, self).__init__()
+        if out_channels is None:
+            out_channels = in_channels
+        if hidden_channels is None:
+            hidden_channels = in_channels
+
+        if activation == 'relu':
+            activation_cls = lambda: nn.ReLU(inplace=True)
+        elif activation == 'leaky_relu':
+            activation_cls = lambda: nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        elif activation == 'silu':
+            activation_cls = lambda: nn.SiLU(inplace=True)
+        elif activation == 'elu':
+            activation_cls = lambda: nn.ELU(inplace=True)
+        else:
+            raise ValueError(f'Unsupported activation function: {activation}')
+
+        self.layers = nn.Sequential(
+            nn.GroupNorm(1, in_channels),
+            activation_cls(),
+            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1, padding_mode=padding_mode),
+            nn.GroupNorm(hidden_channels // 32 if norm == 'group_norm' else 1, hidden_channels),
+            activation_cls(),
+            nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode)
+        )
+
+        self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) if in_channels != out_channels else nn.Identity()
+
+    def forward(self, x):
+        skip = self.skip_connection(x)
+        x = self.layers(x)
+        x = x + skip
+        return x
+
+
+
+
+def make_upsampler(in_channels: int, out_channels: int):
+    upsampler = nn.Sequential(
+        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
+        nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
+    )
+    upsampler[0].weight.data[:] = upsampler[0].weight.data[:, :, :1, :1]  # replicate the top-left kernel weight so the transposed conv starts as a pixel-replicating upsampler
+    return upsampler
+
+def make_output_block(dim_in: int, dim_out: int, dim_times_res_block_hidden: int, last_res_blocks: int, last_conv_channels: int, last_conv_size: int, res_block_norm: Literal['group_norm', 'layer_norm']):
+    return nn.Sequential(
+        nn.Conv2d(dim_in, last_conv_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
+        *(ResidualConvBlock(last_conv_channels, last_conv_channels, dim_times_res_block_hidden * last_conv_channels, activation='relu', norm=res_block_norm) for _ in range(last_res_blocks)),
+        nn.ReLU(inplace=True),
+        nn.Conv2d(last_conv_channels, dim_out, kernel_size=last_conv_size, stride=1, padding=last_conv_size // 2, padding_mode='replicate'),
+    )
+
+
+
+# ---- the following are from Depth Anything ----
+import torch.nn as nn
+
+
+def _make_scratch(in_shape, out_shape,
groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + if len(in_shape) >= 4: + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + + return scratch + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + if self.bn == True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn == True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn == True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=None + ): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + self.size=size + + def forward(self, *xs, size=None): + """Forward pass. 
+ + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) + + output = self.out_conv(output) + + return output diff --git a/src/lari/model/dinoseg_model.py b/src/lari/model/dinoseg_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5e6e218ac769a564ee297e1a466e030273e33374 --- /dev/null +++ b/src/lari/model/dinoseg_model.py @@ -0,0 +1,153 @@ +from typing import * +from numbers import Number +from functools import partial +from pathlib import Path +import importlib +import warnings +import json + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils +import torch.utils.checkpoint +import torch.version +from huggingface_hub import hf_hub_download + + +from src.lari.model.utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing +from src.lari.model.dpt_seg_head import DPTSegHead + + + +class DinoSegModel(nn.Module): + + def __init__(self, + encoder: str = 'dinov2_vitl14', + intermediate_layers: Union[int, List[int]] = 4, + dim_proj: int = 512, + use_pretrained: Literal["dinov2", "moge_full", "moge_backbone", None] = None, + pretrained_path: str = None, + num_output_layer: str = None, + output_type: str = "ray_stop", # "seg_sep" + **deprecated_kwargs + ): + super(DinoSegModel, self).__init__() + if deprecated_kwargs: + warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}") + + self.encoder = encoder + self.intermediate_layers = intermediate_layers + self.use_pretrained = use_pretrained + self.pretrained_path = pretrained_path + self.num_output_layer = num_output_layer + self.output_type = output_type + assert self.output_type in ["seg_sep", "ray_stop"] + + hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), encoder) + + self.backbone = hub_loader(pretrained=True if self.use_pretrained == "dinov2" else False) + dim_feature = self.backbone.blocks[0].attn.qkv.in_features + + + + + self.head = DPTSegHead(in_channels=dim_feature, + features=dim_proj, + use_bn=True, + out_channels=[256, 512, 1024, 1024], + use_clstoken=False, + num_classes = num_output_layer, + output_type = self.output_type + ) + + + if torch.__version__ >= '2.0': + self.enable_pytorch_native_sdpa() + + self._load_pretrained() + + + def _load_pretrained(self): + ''' + Load data from MoGe model + ''' + return + + + + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path, IO[bytes]], model_kwargs: Optional[Dict[str, Any]] = None, **hf_kwargs) -> 'DinoSegModel': + """ + Load a model from a checkpoint file. + + ### Parameters: + - `pretrained_model_name_or_path`: path to the checkpoint file or repo id. + - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint. + - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path. + + ### Returns: + - A new instance of `MoGe` with the parameters loaded from the checkpoint. 
+ """ + if Path(pretrained_model_name_or_path).exists(): + checkpoint = torch.load(pretrained_model_name_or_path, map_location='cpu', weights_only=True) + else: + cached_checkpoint_path = hf_hub_download( + repo_id=pretrained_model_name_or_path, + repo_type="model", + filename="model.pt", + **hf_kwargs + ) + checkpoint = torch.load(cached_checkpoint_path, map_location='cpu', weights_only=True) + model_config = checkpoint['model_config'] + if model_kwargs is not None: + model_config.update(model_kwargs) + model = cls(**model_config) + model.load_state_dict(checkpoint['model']) + return model + + @staticmethod + def cache_pretrained_backbone(encoder: str, pretrained: bool): + _ = torch.hub.load('facebookresearch/dinov2', encoder, pretrained=pretrained) + + def load_pretrained_backbone(self): + "Load the backbone with pretrained dinov2 weights from torch hub" + state_dict = torch.hub.load('facebookresearch/dinov2', self.encoder, pretrained=True).state_dict() + self.backbone.load_state_dict(state_dict) + + def enable_backbone_gradient_checkpointing(self): + for i in range(len(self.backbone.blocks)): + self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing(self.backbone.blocks[i]) + + def enable_pytorch_native_sdpa(self): + for i in range(len(self.backbone.blocks)): + self.backbone.blocks[i].attn = wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn) + + + + def forward(self, image: torch.Tensor, mixed_precision: bool = False) -> Dict[str, torch.Tensor]: + raw_img_h, raw_img_w = image.shape[-2:] + patch_h, patch_w = raw_img_h // 14, raw_img_w // 14 + # Apply image transformation for DINOv2 + image_14 = F.interpolate(image, (patch_h * 14, patch_w * 14), mode="bilinear", align_corners=False, antialias=True) + + # Get intermediate layers from the backbone + with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=mixed_precision): + features = self.backbone.get_intermediate_layers(image_14, self.intermediate_layers, return_class_token=True) + + # Predict points and mask (mask scores) + mask = self.head(features, patch_h, patch_w) + + # b c h w + mask = F.interpolate(mask, (raw_img_h, raw_img_w), mode="bilinear", align_corners=False) + + out_dict = {} + + if self.output_type == "seg_sep": + # mask = torch.nn.functional.sigmoid(mask) # for binary segmentation + out_dict["mask"] = mask.permute(0, 2, 3, 1).unsqueeze(-1) # B H W L 1 + elif self.output_type == "ray_stop": + out_dict["seg_prob"] = mask # B L+1 H W + + return out_dict \ No newline at end of file diff --git a/src/lari/model/dinov2/__init__.py b/src/lari/model/dinov2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae847e46898077fe3d8701b8a181d7b4e3d41cd9 --- /dev/null +++ b/src/lari/model/dinov2/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +__version__ = "0.0.1" diff --git a/src/lari/model/dinov2/hub/__init__.py b/src/lari/model/dinov2/hub/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/src/lari/model/dinov2/hub/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
diff --git a/src/lari/model/dinov2/hub/backbones.py b/src/lari/model/dinov2/hub/backbones.py new file mode 100644 index 0000000000000000000000000000000000000000..53fe83719d5107eb77a8f25ef1814c3d73446002 --- /dev/null +++ b/src/lari/model/dinov2/hub/backbones.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +from typing import Union + +import torch + +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name + + +class Weights(Enum): + LVD142M = "LVD142M" + + +def _make_dinov2_model( + *, + arch_name: str = "vit_large", + img_size: int = 518, + patch_size: int = 14, + init_values: float = 1.0, + ffn_layer: str = "mlp", + block_chunks: int = 0, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD142M, + **kwargs, +): + from ..models import vision_transformer as vits + + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + model_base_name = _make_dinov2_model_name(arch_name, patch_size) + vit_kwargs = dict( + img_size=img_size, + patch_size=patch_size, + init_values=init_values, + ffn_layer=ffn_layer, + block_chunks=block_chunks, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + ) + vit_kwargs.update(**kwargs) + model = vits.__dict__[arch_name](**vit_kwargs) + + if pretrained: + model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + model.load_state_dict(state_dict, strict=True) + + return model + + +def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + **kwargs, + ) + + +def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset. 
+ """ + return _make_dinov2_model( + arch_name="vit_small", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_base", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_large", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) diff --git a/src/lari/model/dinov2/hub/utils.py b/src/lari/model/dinov2/hub/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9c6641404093652d5a2f19b4cf283d976ec39e64 --- /dev/null +++ b/src/lari/model/dinov2/hub/utils.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import itertools +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" + + +def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: + compact_arch_name = arch_name.replace("_", "")[:4] + registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" + return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" + + +class CenterPadding(nn.Module): + def __init__(self, multiple): + super().__init__() + self.multiple = multiple + + def _get_pad(self, size): + new_size = math.ceil(size / self.multiple) * self.multiple + pad_size = new_size - size + pad_size_left = pad_size // 2 + pad_size_right = pad_size - pad_size_left + return pad_size_left, pad_size_right + + @torch.inference_mode() + def forward(self, x): + pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])) + output = F.pad(x, pads) + return output diff --git a/src/lari/model/dinov2/layers/__init__.py b/src/lari/model/dinov2/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05a0b61868e43abb821ca05a813bab2b8b43629e --- /dev/null +++ b/src/lari/model/dinov2/layers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +from .dino_head import DINOHead +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/src/lari/model/dinov2/layers/attention.py b/src/lari/model/dinov2/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..3fed573116d5c837be46a7525d8acf77422c2400 --- /dev/null +++ b/src/lari/model/dinov2/layers/attention.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os +import warnings + +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Attention)") + else: + # warnings.warn("xFormers is disabled (Attention)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Attention)") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/src/lari/model/dinov2/layers/block.py b/src/lari/model/dinov2/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5b8a7bb8527b74186af7c1e060e37bdb52c73d --- /dev/null +++ b/src/lari/model/dinov2/layers/block.py @@ -0,0 +1,259 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +import os +from typing import Callable, List, Any, Tuple, Dict +import warnings + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import fmha, scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Block)") + else: + # warnings.warn("xFormers is disabled (Block)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Block)") + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply 
residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return 
self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/src/lari/model/dinov2/layers/dino_head.py b/src/lari/model/dinov2/layers/dino_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0ace8ffd6297a1dd480b19db407b662a6ea0f565 --- /dev/null +++ b/src/lari/model/dinov2/layers/dino_head.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) + self.apply(self._init_weights) + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + x = self.last_layer(x) + return x + + +def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/src/lari/model/dinov2/layers/drop_path.py b/src/lari/model/dinov2/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5 --- /dev/null +++ b/src/lari/model/dinov2/layers/drop_path.py @@ -0,0 
+1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/src/lari/model/dinov2/layers/layer_scale.py b/src/lari/model/dinov2/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..51df0d7ce61f2b41fa9e6369f52391dd7fe7d386 --- /dev/null +++ b/src/lari/model/dinov2/layers/layer_scale.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/src/lari/model/dinov2/layers/mlp.py b/src/lari/model/dinov2/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e --- /dev/null +++ b/src/lari/model/dinov2/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/src/lari/model/dinov2/layers/patch_embed.py b/src/lari/model/dinov2/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..8b7c0804784a42cf80c0297d110dcc68cc85b339 --- /dev/null +++ b/src/lari/model/dinov2/layers/patch_embed.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/src/lari/model/dinov2/layers/swiglu_ffn.py b/src/lari/model/dinov2/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..5ce211515774d42e04c8b51003bae53b88f14b35 --- /dev/null +++ b/src/lari/model/dinov2/layers/swiglu_ffn.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +import os +from typing import Callable, Optional +import warnings + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (SwiGLU)") + else: + # warnings.warn("xFormers is disabled (SwiGLU)") + raise ImportError +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + # warnings.warn("xFormers is not available (SwiGLU)") + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/src/lari/model/dinov2/models/__init__.py b/src/lari/model/dinov2/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3fdff20badbd5244bf79f16bf18dd2cb73982265 --- /dev/null +++ b/src/lari/model/dinov2/models/__init__.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging + +from . 
import vision_transformer as vits + + +logger = logging.getLogger("dinov2") + + +def build_model(args, only_teacher=False, img_size=224): + args.arch = args.arch.removesuffix("_memeff") + if "vit" in args.arch: + vit_kwargs = dict( + img_size=img_size, + patch_size=args.patch_size, + init_values=args.layerscale, + ffn_layer=args.ffn_layer, + block_chunks=args.block_chunks, + qkv_bias=args.qkv_bias, + proj_bias=args.proj_bias, + ffn_bias=args.ffn_bias, + num_register_tokens=args.num_register_tokens, + interpolate_offset=args.interpolate_offset, + interpolate_antialias=args.interpolate_antialias, + ) + teacher = vits.__dict__[args.arch](**vit_kwargs) + if only_teacher: + return teacher, teacher.embed_dim + student = vits.__dict__[args.arch]( + **vit_kwargs, + drop_path_rate=args.drop_path_rate, + drop_path_uniform=args.drop_path_uniform, + ) + embed_dim = student.embed_dim + return student, teacher, embed_dim + + +def build_model_from_cfg(cfg, only_teacher=False): + return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) diff --git a/src/lari/model/dinov2/models/vision_transformer.py b/src/lari/model/dinov2/models/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..1007ba57ddb35109c91716f1f5bf203db346e7be --- /dev/null +++ b/src/lari/model/dinov2/models/vision_transformer.py @@ -0,0 +1,396 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio 
(int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, 
std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (w0, h0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. 
If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, 
attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/src/lari/model/dinov2/utils/__init__.py b/src/lari/model/dinov2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/src/lari/model/dinov2/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/src/lari/model/dinov2/utils/cluster.py b/src/lari/model/dinov2/utils/cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..3df87dc3e1eb4f0f8a280dc3137cfef031886314 --- /dev/null +++ b/src/lari/model/dinov2/utils/cluster.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +import os +from pathlib import Path +from typing import Any, Dict, Optional + + +class ClusterType(Enum): + AWS = "aws" + FAIR = "fair" + RSC = "rsc" + + +def _guess_cluster_type() -> ClusterType: + uname = os.uname() + if uname.sysname == "Linux": + if uname.release.endswith("-aws"): + # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws" + return ClusterType.AWS + elif uname.nodename.startswith("rsc"): + # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc" + return ClusterType.RSC + + return ClusterType.FAIR + + +def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]: + if cluster_type is None: + return _guess_cluster_type() + + return cluster_type + + +def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + CHECKPOINT_DIRNAMES = { + ClusterType.AWS: "checkpoints", + ClusterType.FAIR: "checkpoint", + ClusterType.RSC: "checkpoint/dino", + } + return Path("/") / CHECKPOINT_DIRNAMES[cluster_type] + + +def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + checkpoint_path = get_checkpoint_path(cluster_type) + if checkpoint_path is None: + return None + + username = os.environ.get("USER") + assert username is not None + return checkpoint_path / username + + +def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + SLURM_PARTITIONS = { + ClusterType.AWS: "learnlab", + ClusterType.FAIR: "learnlab", + ClusterType.RSC: "learn", + } + return SLURM_PARTITIONS[cluster_type] + + +def get_slurm_executor_parameters( + nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs +) -> Dict[str, Any]: + # create default parameters + params = { + "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html + "gpus_per_node": num_gpus_per_node, + "tasks_per_node": num_gpus_per_node, # one task per GPU + "cpus_per_task": 10, + "nodes": nodes, + "slurm_partition": get_slurm_partition(cluster_type), + } + # apply cluster-specific adjustments + cluster_type = get_cluster_type(cluster_type) + if cluster_type == ClusterType.AWS: + params["cpus_per_task"] = 12 + del params["mem_gb"] + elif cluster_type == ClusterType.RSC: + 
params["cpus_per_task"] = 12 + # set additional parameters / apply overrides + params.update(kwargs) + return params diff --git a/src/lari/model/dinov2/utils/config.py b/src/lari/model/dinov2/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..c9de578787bbcb376f8bd5a782206d0eb7ec1f52 --- /dev/null +++ b/src/lari/model/dinov2/utils/config.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import math +import logging +import os + +from omegaconf import OmegaConf + +import dinov2.distributed as distributed +from dinov2.logging import setup_logging +from dinov2.utils import utils +from dinov2.configs import dinov2_default_config + + +logger = logging.getLogger("dinov2") + + +def apply_scaling_rules_to_cfg(cfg): # to fix + if cfg.optim.scaling_rule == "sqrt_wrt_1024": + base_lr = cfg.optim.base_lr + cfg.optim.lr = base_lr + cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0) + logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}") + else: + raise NotImplementedError + return cfg + + +def write_config(cfg, output_dir, name="config.yaml"): + logger.info(OmegaConf.to_yaml(cfg)) + saved_cfg_path = os.path.join(output_dir, name) + with open(saved_cfg_path, "w") as f: + OmegaConf.save(config=cfg, f=f) + return saved_cfg_path + + +def get_cfg_from_args(args): + args.output_dir = os.path.abspath(args.output_dir) + args.opts += [f"train.output_dir={args.output_dir}"] + default_cfg = OmegaConf.create(dinov2_default_config) + cfg = OmegaConf.load(args.config_file) + cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts)) + return cfg + + +def default_setup(args): + distributed.enable(overwrite=True) + seed = getattr(args, "seed", 0) + rank = distributed.get_global_rank() + + global logger + setup_logging(output=args.output_dir, level=logging.INFO) + logger = logging.getLogger("dinov2") + + utils.fix_random_seeds(seed + rank) + logger.info("git:\n {}\n".format(utils.get_sha())) + logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) + + +def setup(args): + """ + Create configs and perform basic setups. + """ + cfg = get_cfg_from_args(args) + os.makedirs(args.output_dir, exist_ok=True) + default_setup(args) + apply_scaling_rules_to_cfg(cfg) + write_config(cfg, args.output_dir) + return cfg diff --git a/src/lari/model/dinov2/utils/dtype.py b/src/lari/model/dinov2/utils/dtype.py new file mode 100644 index 0000000000000000000000000000000000000000..80f4cd74d99faa2731dbe9f8d3a13d71b3f8e3a8 --- /dev/null +++ b/src/lari/model/dinov2/utils/dtype.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ + +from typing import Dict, Union + +import numpy as np +import torch + + +TypeSpec = Union[str, np.dtype, torch.dtype] + + +_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = { + np.dtype("bool"): torch.bool, + np.dtype("uint8"): torch.uint8, + np.dtype("int8"): torch.int8, + np.dtype("int16"): torch.int16, + np.dtype("int32"): torch.int32, + np.dtype("int64"): torch.int64, + np.dtype("float16"): torch.float16, + np.dtype("float32"): torch.float32, + np.dtype("float64"): torch.float64, + np.dtype("complex64"): torch.complex64, + np.dtype("complex128"): torch.complex128, +} + + +def as_torch_dtype(dtype: TypeSpec) -> torch.dtype: + if isinstance(dtype, torch.dtype): + return dtype + if isinstance(dtype, str): + dtype = np.dtype(dtype) + assert isinstance(dtype, np.dtype), f"Expected an instance of nunpy dtype, got {type(dtype)}" + return _NUMPY_TO_TORCH_DTYPE[dtype] diff --git a/src/lari/model/dinov2/utils/param_groups.py b/src/lari/model/dinov2/utils/param_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..9a5d2ff627cddadc222e5f836864ee39c865208f --- /dev/null +++ b/src/lari/model/dinov2/utils/param_groups.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from collections import defaultdict +import logging + + +logger = logging.getLogger("dinov2") + + +def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False): + """ + Calculate lr decay rate for different ViT blocks. + Args: + name (string): parameter name. + lr_decay_rate (float): base lr decay rate. + num_layers (int): number of ViT blocks. + Returns: + lr decay rate for the given parameter. + """ + layer_id = num_layers + 1 + if name.startswith("backbone") or force_is_backbone: + if ( + ".pos_embed" in name + or ".patch_embed" in name + or ".mask_token" in name + or ".cls_token" in name + or ".register_tokens" in name + ): + layer_id = 0 + elif force_is_backbone and ( + "pos_embed" in name + or "patch_embed" in name + or "mask_token" in name + or "cls_token" in name + or "register_tokens" in name + ): + layer_id = 0 + elif ".blocks." in name and ".residual." not in name: + layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 + elif chunked_blocks and "blocks." in name and "residual." not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1 + elif "blocks." in name and "residual." 
not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1 + + return lr_decay_rate ** (num_layers + 1 - layer_id) + + +def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0): + chunked_blocks = False + if hasattr(model, "n_blocks"): + logger.info("chunked fsdp") + n_blocks = model.n_blocks + chunked_blocks = model.chunked_blocks + elif hasattr(model, "blocks"): + logger.info("first code branch") + n_blocks = len(model.blocks) + elif hasattr(model, "backbone"): + logger.info("second code branch") + n_blocks = len(model.backbone.blocks) + else: + logger.info("else code branch") + n_blocks = 0 + all_param_groups = [] + + for name, param in model.named_parameters(): + name = name.replace("_fsdp_wrapped_module.", "") + if not param.requires_grad: + continue + decay_rate = get_vit_lr_decay_rate( + name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks + ) + d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name} + + if "last_layer" in name: + d.update({"is_last_layer": True}) + + if name.endswith(".bias") or "norm" in name or "gamma" in name: + d.update({"wd_multiplier": 0.0}) + + if "patch_embed" in name: + d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult}) + + all_param_groups.append(d) + logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""") + + return all_param_groups + + +def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")): + fused_params_groups = defaultdict(lambda: {"params": []}) + for d in all_params_groups: + identifier = "" + for k in keys: + identifier += k + str(d[k]) + "_" + + for k in keys: + fused_params_groups[identifier][k] = d[k] + fused_params_groups[identifier]["params"].append(d["params"]) + + return fused_params_groups.values() diff --git a/src/lari/model/dinov2/utils/utils.py b/src/lari/model/dinov2/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..68f8e2c3be5f780bbb7e00359b5ac4fd0ba0785f --- /dev/null +++ b/src/lari/model/dinov2/utils/utils.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
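Editorial aside (not part of the diff): get_vit_lr_decay_rate above reduces to lr_decay_rate ** (num_layers + 1 - layer_id), with layer_id 0 for the patch/positional embeddings and cls/register tokens, block k mapped to k + 1, and everything else (e.g. the head) to num_layers + 1. A standalone sketch of the resulting multipliers; the decay value 0.9 is an illustrative assumption.

def layer_lr_multiplier(layer_id: int, num_layers: int = 12, decay: float = 0.9) -> float:
    # Mirrors the formula above: deeper layers get multipliers closer to 1.
    return decay ** (num_layers + 1 - layer_id)

for layer_id in (0, 1, 6, 12, 13):
    print(layer_id, round(layer_lr_multiplier(layer_id), 4))
# layer_id 0 -> 0.2542 (embeddings), 12 -> 0.9 (last block), 13 -> 1.0 (head)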
+ +import logging +import os +import random +import subprocess +from urllib.parse import urlparse + +import numpy as np +import torch +from torch import nn + + +logger = logging.getLogger("dinov2") + + +def load_pretrained_weights(model, pretrained_weights, checkpoint_key): + if urlparse(pretrained_weights).scheme: # If it looks like an URL + state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu") + else: + state_dict = torch.load(pretrained_weights, map_location="cpu") + if checkpoint_key is not None and checkpoint_key in state_dict: + logger.info(f"Take key {checkpoint_key} in provided checkpoint dict") + state_dict = state_dict[checkpoint_key] + # remove `module.` prefix + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + # remove `backbone.` prefix induced by multicrop wrapper + state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} + msg = model.load_state_dict(state_dict, strict=False) + logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg)) + + +def fix_random_seeds(seed=31): + """ + Fix random seeds. + """ + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() + + sha = "N/A" + diff = "clean" + branch = "N/A" + try: + sha = _run(["git", "rev-parse", "HEAD"]) + subprocess.check_output(["git", "diff"], cwd=cwd) + diff = _run(["git", "diff-index", "HEAD"]) + diff = "has uncommitted changes" if diff else "clean" + branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +class CosineScheduler(object): + def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0): + super().__init__() + self.final_value = final_value + self.total_iters = total_iters + + freeze_schedule = np.zeros((freeze_iters)) + + warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) + + iters = np.arange(total_iters - warmup_iters - freeze_iters) + schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) + self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule)) + + assert len(self.schedule) == self.total_iters + + def __getitem__(self, it): + if it >= self.total_iters: + return self.final_value + else: + return self.schedule[it] + + +def has_batchnorms(model): + bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) + for name, module in model.named_modules(): + if isinstance(module, bn_types): + return True + return False diff --git a/src/lari/model/dpt_seg_head.py b/src/lari/model/dpt_seg_head.py new file mode 100644 index 0000000000000000000000000000000000000000..45a1f82d2bf6e01dad32346f5e925be5e580c3b4 --- /dev/null +++ b/src/lari/model/dpt_seg_head.py @@ -0,0 +1,158 @@ +''' +The code is modified based on Depth Anything and DPT +''' +from src.lari.model.blocks import FeatureFusionBlock, _make_scratch +import cv2 +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms import Compose + + + + +def _make_fusion_block(features, use_bn, size=None): + return FeatureFusionBlock( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, 
+ size=size, + ) + + +class ConvBlock(nn.Module): + def __init__(self, in_feature, out_feature): + super().__init__() + + self.conv_block = nn.Sequential( + nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(out_feature), + nn.ReLU(True) + ) + + def forward(self, x): + return self.conv_block(x) + + +class DPTSegHead(nn.Module): + def __init__( + self, + in_channels, + features=256, + use_bn=False, + out_channels=[256, 512, 1024, 1024], + use_clstoken=False, + num_classes = 5, + output_type = "ray_stop" # "seg_sep" + ): + super(DPTSegHead, self).__init__() + + self.use_clstoken = use_clstoken + self.output_type = output_type + + # output one more layer to indicate the invalid ray-stopping point using index 0 + self.num_classes = num_classes + 1 if self.output_type == "ray_stop" else num_classes + + + self.projects = nn.ModuleList([ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + stride=1, + padding=0, + ) for out_channel in out_channels + ]) + + self.resize_layers = nn.ModuleList([ + nn.ConvTranspose2d( + in_channels=out_channels[0], + out_channels=out_channels[0], + kernel_size=4, + stride=4, + padding=0), + nn.ConvTranspose2d( + in_channels=out_channels[1], + out_channels=out_channels[1], + kernel_size=2, + stride=2, + padding=0), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=2, + padding=1) + ]) + + if use_clstoken: + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append( + nn.Sequential( + nn.Linear(2 * in_channels, in_channels), + nn.GELU())) + + self.scratch = _make_scratch( + out_channels, + features, + groups=1, + expand=False, + ) + + self.scratch.stem_transpose = None + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv1 = nn.Sequential( + nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(features), + nn.ReLU(True), + nn.Dropout(0.1, False), + nn.Conv2d(features, self.num_classes, kernel_size=1), + ) + + + + def forward(self, out_features, patch_h, patch_w): + out = [] + for i, x in enumerate(out_features): + if self.use_clstoken: + x, cls_token = x[0], x[1] + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + else: + x = x[0] + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[i](x) + x = self.resize_layers[i](x) + + out.append(x) + + layer_1, layer_2, layer_3, layer_4 = out + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv1(path_1) + + # B C H W - segmentaton logits + out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True) + + return out \ No newline at end of file diff --git 
a/src/lari/model/heads.py b/src/lari/model/heads.py new file mode 100644 index 0000000000000000000000000000000000000000..c9559d07992b1c54e5a12b0efc7238c2fabaa410 --- /dev/null +++ b/src/lari/model/heads.py @@ -0,0 +1,104 @@ + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils +import torch.utils.checkpoint +import torch.version +from typing import * +import os +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))) +from src.lari.model.blocks import ResidualConvBlock, make_upsampler, make_output_block +from src.lari.utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift, gaussian_blur_2d + + +class PointHead(nn.Module): + def __init__( + self, + num_features: int, + dim_in: int, + dim_out: int, + dim_proj: int = 512, + dim_upsample: List[int] = [256, 128, 128], + dim_times_res_block_hidden: int = 1, + num_res_blocks: int = 1, + res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm', + last_res_blocks: int = 0, + last_conv_channels: int = 32, + last_conv_size: int = 1, + num_output_layer: int = 5 + ): + super().__init__() + + self.num_output_layer = num_output_layer + + self.projects = nn.ModuleList([ + nn.Conv2d(in_channels=dim_in, out_channels=dim_proj, kernel_size=1, stride=1, padding=0,) for _ in range(num_features) + ]) + + self.upsample_blocks = nn.ModuleList([ + nn.Sequential( + make_upsampler(in_ch + 2, out_ch), + *(ResidualConvBlock(out_ch, out_ch, dim_times_res_block_hidden * out_ch, activation="relu", norm=res_block_norm) for _ in range(num_res_blocks)) + ) for in_ch, out_ch in zip([dim_proj] + dim_upsample[:-1], dim_upsample) + ]) + + # layer iterations + self.first_layer_block = make_output_block(dim_upsample[-1] + 2, dim_out, + dim_times_res_block_hidden, last_res_blocks, last_conv_channels, last_conv_size, res_block_norm,) + + self.remaining_layer_block = nn.ModuleList([make_output_block(dim_upsample[-1] + 2, dim_out, + dim_times_res_block_hidden, last_res_blocks, last_conv_channels, last_conv_size, res_block_norm,) + for _ in range(self.num_output_layer - 1)]) + + + + def forward(self, hidden_states: torch.Tensor, image: torch.Tensor): + img_h, img_w = image.shape[-2:] + patch_h, patch_w = img_h // 14, img_w // 14 + + # Process the hidden states + x = torch.stack([ + proj(feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous()) + for proj, (feat, clstoken) in zip(self.projects, hidden_states) + ], dim=1).sum(dim=1) + + # Upsample stage + # (patch_h, patch_w) -> (patch_h * 2, patch_w * 2) -> (patch_h * 4, patch_w * 4) -> (patch_h * 8, patch_w * 8) + for i, block in enumerate(self.upsample_blocks): + # UV coordinates is for awareness of image aspect ratio + uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device) + uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1) + x = torch.cat([x, uv], dim=1) + for layer in block: + x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False) + + # (patch_h * 8, patch_w * 8) -> (img_h, img_w) + x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False) + uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device) + uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1) + x = torch.cat([x, uv], dim=1) + + + pts_list = [] + for layer_id in range(self.num_output_layer): + if layer_id == 0: + blocks = 
self.first_layer_block + else: + blocks = self.remaining_layer_block[layer_id-1] + + # for each block + if isinstance(blocks, nn.ModuleList): + raise NotImplementedError() + else: + res = torch.utils.checkpoint.checkpoint(blocks, x, use_reentrant=False)[:,:3, :,:] + pts_list.append(res[:, :3, :,:]) + + pts = torch.stack(pts_list, dim=-1) + seg = pts.new_zeros(pts.shape)[:, :1, ...] + + # output: [points (B, 3, H, W, L), segmentation placeholder (B, 1, H, W, L)] + output = [pts, seg] + + return output \ No newline at end of file diff --git a/src/lari/model/lari_model.py b/src/lari/model/lari_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4351c6b021e3a15ecb424a7f404ca08b39173305 --- /dev/null +++ b/src/lari/model/lari_model.py @@ -0,0 +1,177 @@ +from typing import * +from numbers import Number +from functools import partial +from pathlib import Path +import importlib +import warnings +import json + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils +import torch.utils.checkpoint +import torch.version +from huggingface_hub import hf_hub_download +from src.lari.model.utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing +from src.lari.model.heads import PointHead + + +class LaRIModel(nn.Module): + image_mean: torch.Tensor + image_std: torch.Tensor + + def __init__(self, + encoder: str = 'dinov2_vitl14', + intermediate_layers: Union[int, List[int]] = 4, + dim_proj: int = 512, + dim_upsample: List[int] = [256, 128, 64], + dim_times_res_block_hidden: int = 2, + num_res_blocks: int = 2, + output_mask: bool = True, + split_head: bool = True, + remap_output: Literal[False, True, 'linear', 'sinh', 'exp', 'sinh_exp'] = 'exp', + res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm', + last_res_blocks: int = 0, + last_conv_channels: int = 32, + last_conv_size: int = 1, + use_pretrained: Literal["dinov2", "moge_full", "moge_backbone", None] = None, + pretrained_path: str = "", + num_output_layer: int = None, + head_type = None, + **deprecated_kwargs + ): + super(LaRIModel, self).__init__() + if deprecated_kwargs: + warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}") + + self.encoder = encoder + self.remap_output = remap_output + self.intermediate_layers = intermediate_layers + self.head_type = head_type + self.output_mask = output_mask + self.split_head = split_head + self.use_pretrained = use_pretrained + self.pretrained_path = pretrained_path + self.num_output_layer = num_output_layer + + hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), encoder) + # hub_loader = getattr(importlib.import_module("dinov2.hub.backbones", __package__), encoder) + + self.backbone = hub_loader(pretrained=True if self.use_pretrained == "dinov2" else False) + dim_feature = self.backbone.blocks[0].attn.qkv.in_features + + if self.head_type == "point": + self.head = PointHead( + num_features=intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers), + dim_in=dim_feature, + dim_out=3, + dim_proj=dim_proj, + dim_upsample=dim_upsample, + dim_times_res_block_hidden=dim_times_res_block_hidden, + num_res_blocks=num_res_blocks, + res_block_norm=res_block_norm, + last_res_blocks=last_res_blocks, + last_conv_channels=last_conv_channels, + last_conv_size=last_conv_size, + num_output_layer = num_output_layer + ) + else: + raise NotImplementedError() + + + if torch.__version__ >= '2.0': + self.enable_pytorch_native_sdpa() + + self._load_pretrained() +
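# Editorial note (a sketch, not part of the original file): with the default
# remap_output='exp' set in the constructor above, forward() below splits the raw
# head output into (x, y, z) and remaps it as z' = exp(z), x' = x * z', y' = y * z',
# which keeps predicted depths strictly positive. For example, a raw prediction of
# (0.2, -0.1, 0.0) maps to (0.2, -0.1, 1.0), since exp(0) = 1.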
+ + def _load_pretrained(self): + ''' + Load pre-trained weights + ''' + if self.use_pretrained == "dinov2" or self.use_pretrained is None: return + + if self.use_pretrained == "moge_full" and self.pretrained_path != "": + checkpoint = torch.load(self.pretrained_path, map_location='cpu', weights_only=True) + if self.head_type == "point": + key_transition_map = {"output_block": "first_layer_block"} + model_state_dict = {} + + # change the key name of the dict + for key, val in checkpoint['model'].items(): + for trans_src, trans_target in key_transition_map.items(): + if trans_src in key: + model_state_dict[key.replace(trans_src, trans_target)] = val + else: + model_state_dict[key] = val + + self.load_state_dict(model_state_dict, strict=False) + del model_state_dict + + + else: + return + + + @staticmethod + def cache_pretrained_backbone(encoder: str, pretrained: bool): + _ = torch.hub.load('facebookresearch/dinov2', encoder, pretrained=pretrained) + + def load_pretrained_backbone(self): + "Load the backbone with pretrained dinov2 weights from torch hub" + state_dict = torch.hub.load('facebookresearch/dinov2', self.encoder, pretrained=True).state_dict() + self.backbone.load_state_dict(state_dict) + + def enable_backbone_gradient_checkpointing(self): + for i in range(len(self.backbone.blocks)): + self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing(self.backbone.blocks[i]) + + def enable_pytorch_native_sdpa(self): + for i in range(len(self.backbone.blocks)): + self.backbone.blocks[i].attn = wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn) + + def forward(self, image: torch.Tensor, mixed_precision: bool = False) -> Dict[str, torch.Tensor]: + raw_img_h, raw_img_w = image.shape[-2:] + patch_h, patch_w = raw_img_h // 14, raw_img_w // 14 + + # Apply image transformation for DINOv2 + image_14 = F.interpolate(image, (patch_h * 14, patch_w * 14), mode="bilinear", align_corners=False, antialias=True) + + # Get intermediate layers from the backbone + with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=mixed_precision): + features = self.backbone.get_intermediate_layers(image_14, self.intermediate_layers, return_class_token=True) + + # Predict points and mask (mask scores) + points, mask = self.head(features, image) + + is_output_prob = False + if mask.ndim == 5: + # , + points, mask = points.permute(0, 2, 3, 4, 1), mask.permute(0,2,3,4,1) + elif mask.ndim == 4: # , + points = points.permute(0, 2, 3, 4, 1) + is_output_prob = True + + if self.remap_output == 'linear' or self.remap_output == False: + pass + elif self.remap_output =='sinh' or self.remap_output == True: + points = torch.sinh(points) + elif self.remap_output == 'exp': + xy, z = points.split([2, 1], dim=-1) + z = torch.exp(z) + points = torch.cat([xy * z, z], dim=-1) + elif self.remap_output =='sinh_exp': + xy, z = points.split([2, 1], dim=-1) + points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1) + else: + raise ValueError(f"Invalid remap output type: {self.remap_output}") + + return_dict = {'pts3d': points} + + if not is_output_prob: + return_dict['mask'] = mask + else: + return_dict["seg_prob"] = mask + + return return_dict \ No newline at end of file diff --git a/src/lari/model/utils.py b/src/lari/model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..af0d0a042209ed87cb60f340529940359fdfa900 --- /dev/null +++ b/src/lari/model/utils.py @@ -0,0 +1,38 @@ +from typing import * + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def 
wrap_module_with_gradient_checkpointing(module: nn.Module): + from torch.utils.checkpoint import checkpoint + class _CheckpointingWrapper(module.__class__): + _restore_cls = module.__class__ + def forward(self, *args, **kwargs): + return checkpoint(super().forward, *args, use_reentrant=False, **kwargs) + + module.__class__ = _CheckpointingWrapper + return module + + +def unwrap_module_with_gradient_checkpointing(module: nn.Module): + module.__class__ = module.__class__._restore_cls + + +def wrap_dinov2_attention_with_sdpa(module: nn.Module): + assert torch.__version__ >= '2.0', "SDPA requires PyTorch 2.0 or later" + class _AttentionWrapper(module.__class__): + def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3, B, H, N, C // H) + + q, k, v = torch.unbind(qkv, 0) # (B, H, N, C // H) + + x = F.scaled_dot_product_attention(q, k, v, attn_bias) + x = x.permute(0, 2, 1, 3).reshape(B, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x + module.__class__ = _AttentionWrapper + return module \ No newline at end of file diff --git a/src/lari/utils/__init__.py b/src/lari/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/lari/utils/geometry_numpy.py b/src/lari/utils/geometry_numpy.py new file mode 100644 index 0000000000000000000000000000000000000000..b553af48e1d866f03089a90bb9d3d6d3e5a5337d --- /dev/null +++ b/src/lari/utils/geometry_numpy.py @@ -0,0 +1,187 @@ +from typing import * +from functools import partial +import math + +import numpy as np +import utils3d + +def weighted_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray: + if w is None: + return np.mean(x, axis=axis) + else: + w = w.astype(x.dtype) + return (x * w).mean(axis=axis) / np.clip(w.mean(axis=axis), eps, None) + + +def harmonic_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray: + if w is None: + return 1 / (1 / np.clip(x, eps, None)).mean(axis=axis) + else: + w = w.astype(x.dtype) + return 1 / (weighted_mean_numpy(1 / (x + eps), w, axis=axis, keepdims=keepdims, eps=eps) + eps) + + +def normalized_view_plane_uv_numpy(width: int, height: int, aspect_ratio: float = None, dtype: np.dtype = np.float32) -> np.ndarray: + "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)" + if aspect_ratio is None: + aspect_ratio = width / height + + span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 + span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5 + + u = np.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype) + v = np.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype) + u, v = np.meshgrid(u, v, indexing='xy') + uv = np.stack([u, v], axis=-1) + return uv + + +def focal_to_fov_numpy(focal: np.ndarray): + return 2 * np.arctan(0.5 / focal) + + +def fov_to_focal_numpy(fov: np.ndarray): + return 0.5 / np.tan(fov / 2) + + +def intrinsics_to_fov_numpy(intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + fov_x = focal_to_fov_numpy(intrinsics[..., 0, 0]) + fov_y = focal_to_fov_numpy(intrinsics[..., 1, 1]) + return fov_x, fov_y + + +def 
point_map_to_depth_legacy_numpy(points: np.ndarray): + height, width = points.shape[-3:-1] + diagonal = (height ** 2 + width ** 2) ** 0.5 + uv = normalized_view_plane_uv_numpy(width, height, dtype=points.dtype) # (H, W, 2) + _, uv = np.broadcast_arrays(points[..., :2], uv) + + # Solve least squares problem + b = (uv * points[..., 2:]).reshape(*points.shape[:-3], -1) # (..., H * W * 2) + A = np.stack([points[..., :2], -uv], axis=-1).reshape(*points.shape[:-3], -1, 2) # (..., H * W * 2, 2) + + M = A.swapaxes(-2, -1) @ A + solution = (np.linalg.inv(M + 1e-6 * np.eye(2)) @ (A.swapaxes(-2, -1) @ b[..., None])).squeeze(-1) + focal, shift = solution + + depth = points[..., 2] + shift[..., None, None] + fov_x = np.arctan(width / diagonal / focal) * 2 + fov_y = np.arctan(height / diagonal / focal) * 2 + return depth, fov_x, fov_y, shift + + +def solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray): + "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift and focal" + from scipy.optimize import least_squares + uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1) + + def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray): + xy_proj = xy / (z + shift)[: , None] + f = (xy_proj * uv).sum() / np.square(xy_proj).sum() + err = (f * xy_proj - uv).ravel() + return err + + solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm') + optim_shift = solution['x'].squeeze().astype(np.float32) + + xy_proj = xy / (z + optim_shift)[: , None] + optim_focal = (xy_proj * uv).sum() / np.square(xy_proj).sum() + + return optim_shift, optim_focal + + +def solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float): + "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift" + from scipy.optimize import least_squares + uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1) + + def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray): + xy_proj = xy/ (z + shift)[: , None] + err = (focal * xy_proj - uv).ravel() + return err + + solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm') + optim_shift = solution['x'].squeeze().astype(np.float32) + + return optim_shift + + +def recover_focal_shift_numpy(points: np.ndarray, mask: np.ndarray = None, focal: float = None, downsample_size: Tuple[int, int] = (64, 64)): + import cv2 + assert points.shape[-1] == 3, "Points should (H, W, 3)" + + height, width = points.shape[-3], points.shape[-2] + diagonal = (height ** 2 + width ** 2) ** 0.5 + + uv = normalized_view_plane_uv_numpy(width=width, height=height) + + if mask is None: + points_lr = cv2.resize(points, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 3) + uv_lr = cv2.resize(uv, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 2) + else: + index, mask_lr = mask_aware_nearest_resize_numpy(mask, *downsample_size) + points_lr, uv_lr = points[index][mask_lr], uv[index][mask_lr] + + if points_lr.size == 0: + return np.zeros((height, width)), 0, 0, 0 + + if focal is None: + focal, shift = solve_optimal_focal_shift(uv_lr, points_lr) + else: + shift = solve_optimal_shift(uv_lr, points_lr, focal) + + return focal, shift + + +def mask_aware_nearest_resize_numpy(mask: np.ndarray, target_width: int, target_height: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map. 
+ + ### Parameters + - `mask`: Input 2D mask of shape (..., H, W) + - `target_width`: target width of the resized map + - `target_height`: target height of the resized map + + ### Returns + - `nearest_idx`: Nearest neighbor index of the resized map of shape (..., target_height, target_width). Indices are like j + i * W, where j is the row index and i is the column index. + - `target_mask`: Mask of the resized map of shape (..., target_height, target_width) + """ + height, width = mask.shape[-2:] + filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width) + filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f) + filter_size = filter_h_i * filter_w_i + padding_h, padding_w = round(filter_h_f / 2), round(filter_w_f / 2) + + # Window the original mask and uv + uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32) + indices = np.arange(height * width, dtype=np.int32).reshape(height, width) + padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32) + padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv + padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool) + padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask + padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32) + padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices + windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1)) + windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1)) + windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, axis=(0, 1)) + + # Gather the target pixels's local window + target_uv = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32) + target_corner = target_uv - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32) + target_corner = np.round(target_corner - 0.5).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32) + + target_window_uv = windowed_uv[target_corner[..., 1], target_corner[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size) + target_window_mask = windowed_mask[..., target_corner[..., 1], target_corner[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size) + target_window_indices = windowed_indices[target_corner[..., 1], target_corner[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size) + + # Compute nearest neighbor in the local window for each pixel + dist = np.square(target_window_uv - target_uv[..., None]) + dist = dist[..., 0, :] + dist[..., 1, :] + dist = np.where(target_window_mask, dist, np.inf) # (..., target_height, tgt_width, filter_size) + nearest_in_window = np.argmin(dist, axis=-1, keepdims=True) # (..., target_height, tgt_width, 1) + nearest_idx = np.take_along_axis(target_window_indices, nearest_in_window, axis=-1).squeeze(-1) # (..., target_height, tgt_width) + nearest_i, nearest_j = nearest_idx // width, nearest_idx % width + target_mask = np.any(target_window_mask, axis=-1) + batch_indices = [np.arange(n).reshape([1] * i + [n] + [1] * (mask.ndim - i - 1)) for i, n in 
enumerate(mask.shape[:-2])] + + return (*batch_indices, nearest_i, nearest_j), target_mask \ No newline at end of file diff --git a/src/lari/utils/geometry_torch.py b/src/lari/utils/geometry_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..c28cd4de1b37fb61afdf12640c306a22d7796ce9 --- /dev/null +++ b/src/lari/utils/geometry_torch.py @@ -0,0 +1,221 @@ +from typing import * +import math +from collections import namedtuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.types + +import os +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) +import utils3d +from .geometry_numpy import solve_optimal_focal_shift, solve_optimal_shift + + +def weighted_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor: + if w is None: + return x.mean(dim=dim, keepdim=keepdim) + else: + w = w.to(x.dtype) + return (x * w).mean(dim=dim, keepdim=keepdim) / w.mean(dim=dim, keepdim=keepdim).add(eps) + + +def harmonic_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor: + if w is None: + return x.add(eps).reciprocal().mean(dim=dim, keepdim=keepdim).reciprocal() + else: + w = w.to(x.dtype) + return weighted_mean(x.add(eps).reciprocal(), w, dim=dim, keepdim=keepdim, eps=eps).add(eps).reciprocal() + + +def geometric_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor: + if w is None: + return x.add(eps).log().mean(dim=dim).exp() + else: + w = w.to(x.dtype) + return weighted_mean(x.add(eps).log(), w, dim=dim, keepdim=keepdim, eps=eps).exp() + + +def normalized_view_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor: + "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)" + if aspect_ratio is None: + aspect_ratio = width / height + + span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 + span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5 + + u = torch.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype, device=device) + v = torch.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype, device=device) + u, v = torch.meshgrid(u, v, indexing='xy') + uv = torch.stack([u, v], dim=-1) + return uv + + +def gaussian_blur_2d(input: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor: + kernel = torch.exp(-(torch.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=input.dtype, device=input.device) ** 2) / (2 * sigma ** 2)) + kernel = kernel / kernel.sum() + kernel = (kernel[:, None] * kernel[None, :]).reshape(1, 1, kernel_size, kernel_size) + input = F.pad(input, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), mode='replicate') + input = F.conv2d(input, kernel, groups=input.shape[1]) + return input + + +def focal_to_fov(focal: torch.Tensor): + return 2 * torch.atan(0.5 / focal) + + +def fov_to_focal(fov: torch.Tensor): + return 0.5 / torch.tan(fov / 2) + + +def intrinsics_to_fov(intrinsics: torch.Tensor): + """ + Returns field of view in radians from normalized intrinsics matrix. 
+ ### Parameters: + - intrinsics: torch.Tensor of shape (..., 3, 3) + + ### Returns: + - fov_x: torch.Tensor of shape (...) + - fov_y: torch.Tensor of shape (...) + """ + focal_x = intrinsics[..., 0, 0] + focal_y = intrinsics[..., 1, 1] + return 2 * torch.atan(0.5 / focal_x), 2 * torch.atan(0.5 / focal_y) + + +def point_map_to_depth_legacy(points: torch.Tensor): + height, width = points.shape[-3:-1] + diagonal = (height ** 2 + width ** 2) ** 0.5 + uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2) + + # Solve least squares problem + b = (uv * points[..., 2:]).flatten(-3, -1) # (..., H * W * 2) + A = torch.stack([points[..., :2], -uv.expand_as(points[..., :2])], dim=-1).flatten(-4, -2) # (..., H * W * 2, 2) + + M = A.transpose(-2, -1) @ A + solution = (torch.inverse(M + 1e-6 * torch.eye(2).to(A)) @ (A.transpose(-2, -1) @ b[..., None])).squeeze(-1) + focal, shift = solution.unbind(-1) + + depth = points[..., 2] + shift[..., None, None] + fov_x = torch.atan(width / diagonal / focal) * 2 + fov_y = torch.atan(height / diagonal / focal) * 2 + return depth, fov_x, fov_y, shift + + +def view_plane_uv_to_focal(uv: torch.Tensor): + normed_uv = normalized_view_plane_uv(width=uv.shape[-2], height=uv.shape[-3], device=uv.device, dtype=uv.dtype) + focal = (uv * normed_uv).sum() / uv.square().sum().add(1e-12) + return focal + + +def recover_focal_shift(points: torch.Tensor, mask: torch.Tensor = None, focal: torch.Tensor = None, downsample_size: Tuple[int, int] = (64, 64)): + """ + Recover the depth map and FoV from a point map with unknown z shift and focal. + + Note that it assumes: + - the optical center is at the center of the map + - the map is undistorted + - the map is isometric in the x and y directions + + ### Parameters: + - `points: torch.Tensor` of shape (..., H, W, 3) + - `mask: torch.Tensor` of shape (..., H, W). Optional. + - `focal: torch.Tensor` of shape (...). Optional. + - `downsample_size: Tuple[int, int]` in (height, width), the size of the downsampled map. Downsampling produces approximate solution and is efficient for large maps. + + ### Returns: + - `focal`: torch.Tensor of shape (...) the estimated focal length, relative to the half diagonal of the map + - `shift`: torch.Tensor of shape (...) 
Z-axis shift to translate the point map to camera space + """ + shape = points.shape + height, width = points.shape[-3], points.shape[-2] + diagonal = (height ** 2 + width ** 2) ** 0.5 + + points = points.reshape(-1, *shape[-3:]) + mask = None if mask is None else mask.reshape(-1, *shape[-3:-1]) + focal = focal.reshape(-1) if focal is not None else None + uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2) + + points_lr = F.interpolate(points.permute(0, 3, 1, 2), downsample_size, mode='nearest').permute(0, 2, 3, 1) + uv_lr = F.interpolate(uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode='nearest').squeeze(0).permute(1, 2, 0) + mask_lr = None if mask is None else F.interpolate(mask.to(torch.float32).unsqueeze(1), downsample_size, mode='nearest').squeeze(1) > 0 + + uv_lr_np = uv_lr.cpu().numpy() + points_lr_np = points_lr.detach().cpu().numpy() + focal_np = focal.cpu().numpy() if focal is not None else None + mask_lr_np = None if mask is None else mask_lr.cpu().numpy() + optim_shift, optim_focal = [], [] + for i in range(points.shape[0]): + points_lr_i_np = points_lr_np[i] if mask is None else points_lr_np[i][mask_lr_np[i]] + uv_lr_i_np = uv_lr_np if mask is None else uv_lr_np[mask_lr_np[i]] + if focal is None: + optim_shift_i, optim_focal_i = solve_optimal_focal_shift(uv_lr_i_np, points_lr_i_np) + optim_focal.append(float(optim_focal_i)) + else: + optim_shift_i = solve_optimal_shift(uv_lr_i_np, points_lr_i_np, focal_np[i]) + optim_shift.append(float(optim_shift_i)) + optim_shift = torch.tensor(optim_shift, device=points.device, dtype=points.dtype).reshape(shape[:-3]) + + if focal is None: + optim_focal = torch.tensor(optim_focal, device=points.device, dtype=points.dtype).reshape(shape[:-3]) + else: + optim_focal = focal.reshape(shape[:-3]) + + return optim_focal, optim_shift + + +def mask_aware_nearest_resize(mask: torch.BoolTensor, target_width: int, target_height: int) -> Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]: + """ + Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map. 
+ + ### Parameters + - `mask`: Input 2D mask of shape (..., H, W) + - `target_width`: target width of the resized map + - `target_height`: target height of the resized map + + ### Returns + - `nearest_idx`: Nearest neighbor index of the resized map of shape (..., target_height, target_width) for each dimension + - `target_mask`: Mask of the resized map of shape (..., target_height, target_width) + """ + height, width = mask.shape[-2:] + device = mask.device + filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width) + filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f) + filter_size = filter_h_i * filter_w_i + padding_h, padding_w = round(filter_h_f / 2), round(filter_w_f / 2) + + # Window the original mask and uv + uv = utils3d.torch.image_pixel_center(width=width, height=height, dtype=torch.float32, device=device) + indices = torch.arange(height * width, dtype=torch.long, device=device).reshape(height, width) + padded_uv = torch.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=torch.float32, device=device) + padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv + padded_mask = torch.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=torch.bool, device=device) + padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask + padded_indices = torch.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=torch.long, device=device) + padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices + windowed_uv = utils3d.torch.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, dim=(0, 1)) + windowed_mask = utils3d.torch.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, dim=(-2, -1)) + windowed_indices = utils3d.torch.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, dim=(0, 1)) + + # Gather the target pixels's local window + target_uv = utils3d.torch.image_uv(width=target_width, height=target_height, dtype=torch.float32, device=device) * torch.tensor([width, height], dtype=torch.float32, device=device) + target_corner = target_uv - torch.tensor((filter_w_f / 2, filter_h_f / 2), dtype=torch.float32, device=device) + target_corner = torch.round(target_corner - 0.5).long() + torch.tensor((padding_w, padding_h), dtype=torch.long, device=device) + + target_window_uv = windowed_uv[target_corner[..., 1], target_corner[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size) + target_window_mask = windowed_mask[..., target_corner[..., 1], target_corner[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size) + target_window_indices = windowed_indices[target_corner[..., 1], target_corner[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size) + target_window_indices = target_window_indices.expand_as(target_window_mask) + + # Compute nearest neighbor in the local window for each pixel + dist = torch.where(target_window_mask, torch.norm(target_window_uv - target_uv[..., None], dim=-2), torch.inf) # (..., target_height, tgt_width, filter_size) + nearest = torch.argmin(dist, dim=-1, keepdim=True) # (..., target_height, tgt_width, 1) + nearest_idx = torch.gather(target_window_indices, index=nearest, dim=-1).squeeze(-1) # (..., target_height, tgt_width) + target_mask = torch.any(target_window_mask, dim=-1) + nearest_i, nearest_j = 
nearest_idx // width, nearest_idx % width + batch_indices = [torch.arange(n, device=device).reshape([1] * i + [n] + [1] * (mask.dim() - i - 1)) for i, n in enumerate(mask.shape[:-2])] + + return (*batch_indices, nearest_i, nearest_j), target_mask diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a32692113d830ddc4af4e6ed608f222fbe062e6e --- /dev/null +++ b/src/utils/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024-present Naver Corporation. All rights reserved. +# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). diff --git a/src/utils/vis.py b/src/utils/vis.py new file mode 100644 index 0000000000000000000000000000000000000000..5c833afb5ab0b98cf11fb202d3bd8d93debe5737 --- /dev/null +++ b/src/utils/vis.py @@ -0,0 +1,105 @@ +# import torchvision.transforms as transforms +# import torch.nn.functional as F +# import cv2 +# import os +# import logging +# from pathlib import Path +import numpy as np +# import os +import torch +import matplotlib +# import cv2 +# import random +# from PIL import Image +# import imageio + +def prob_to_mask(prob): + """ + Transforms a probability map of stopping points (shape: (n_layer+1, H, W)) + into a binary mask (shape: (H, W, n_layer, 1)) where for each pixel, layers + with index ≤ stopping index (as given by argmax) are marked valid. + """ + num_layer_plus1, H, W = prob.shape + # Get stopping index for each pixel; values are in {0, 1, ..., n_layer} + stopping_indices = torch.argmax(prob, dim=0) # (H, W) + + # Create a tensor with layer indices [1, 2, ..., n_layer] + layer_indices = torch.arange(1, num_layer_plus1, device=prob.device).view(-1, 1, 1) + + # Compare: a layer is valid if its index is <= the stopping index. + pred_mask = (layer_indices <= stopping_indices.unsqueeze(0)) + + # Permute and unsqueeze to get shape (H, W, n_layer, 1) + pred_mask = pred_mask.permute(1, 2, 0).unsqueeze(-1) + return pred_mask + + + + +def colorize(value, vmin=None, vmax=None, cmap='rainbow', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None): + """Converts a depth map to a color image. + + Args: + value (torch.Tensor, numpy.ndarry): Input depth map. Shape: (H, W) or (1, H, W) or (1, 1, H, W). All singular dimensions are squeezed + vmin (float, optional): vmin-valued entries are mapped to start color of cmap. If None, value.min() is used. Defaults to None. + vmax (float, optional): vmax-valued entries are mapped to end color of cmap. If None, value.max() is used. Defaults to None. + cmap (str, optional): matplotlib colormap to use. Defaults to 'magma_r'. + invalid_val (int, optional): Specifies value of invalid pixels that should be colored as 'background_color'. Defaults to -99. + invalid_mask (numpy.ndarray, optional): Boolean mask for invalid regions. Defaults to None. + background_color (tuple[int], optional): 4-tuple RGB color to give to invalid pixels. Defaults to (128, 128, 128, 255). + gamma_corrected (bool, optional): Apply gamma correction to colored image. Defaults to False. + value_transform (Callable, optional): Apply transform function to valid pixels before coloring. Defaults to None. + + Returns: + numpy.ndarray, dtype - uint8: Colored depth map. 
Shape: (H, W, 4) + """ + if isinstance(value, torch.Tensor): + value = value.detach().cpu().numpy() + + value = value.squeeze() + if invalid_mask is None: + invalid_mask = value == invalid_val + mask = np.logical_not(invalid_mask) + + # normalize + vmin = np.percentile(value[mask],2) if vmin is None else vmin + vmax = np.percentile(value[mask],85) if vmax is None else vmax + if vmin != vmax: + value = (value - vmin) / (vmax - vmin) # vmin..vmax + else: + # Avoid 0-division + value = value * 0. + + value[invalid_mask] = np.nan + cmapper = matplotlib.cm.get_cmap(cmap) + if value_transform: + value = value_transform(value) + # value = value / value.max() + value = cmapper(value, bytes=True) # (nxmx4) + + # img = value[:, :, :] + img = value[...] + img[invalid_mask] = background_color + + if gamma_corrected: + # gamma correction + img = img / 255 + img = np.power(img, 2.2) + img = img * 255 + img = img.astype(np.uint8) + return img + + + +def denormalize(x): + """Reverses the imagenet normalization applied to the input. + + Args: + x (torch.Tensor - shape(N,3,H,W)): input tensor + + Returns: + torch.Tensor - shape(N,3,H,W): Denormalized input + """ + mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device) + std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device) + return x * std + mean \ No newline at end of file diff --git a/src/utils3d/README.md b/src/utils3d/README.md new file mode 100644 index 0000000000000000000000000000000000000000..25b9c737af1398819f077988b4dfca878f839205 --- /dev/null +++ b/src/utils3d/README.md @@ -0,0 +1,3 @@ +# utils3d + +This is a collection of utility functions for 3D computer vision tasks copied from https://github.com/EasternJournalist/utils3d. diff --git a/src/utils3d/__init__.py b/src/utils3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3291ba3a79a7f263c6208a546c20ca458a5058d9 --- /dev/null +++ b/src/utils3d/__init__.py @@ -0,0 +1,20 @@ +""" +A package for common utility functions in 3D computer graphics and vision. Providing NumPy utilities in `utils3d.numpy`, PyTorch utilities in `utils3d.torch`, and IO utilities in `utils3d.io`. +""" +import importlib +from typing import TYPE_CHECKING + +try: + from ._unified import * +except ImportError: + pass + +__all__ = ['numpy', 'torch', 'io'] + +def __getattr__(name: str): + return globals().get(name, importlib.import_module(f'.{name}', __package__)) + +if TYPE_CHECKING: + from . import torch + from . import numpy + from . 
import io \ No newline at end of file diff --git a/src/utils3d/_helpers.py b/src/utils3d/_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..9d1d5b086dec88b8e2283e2546520d6f2a3d8505 --- /dev/null +++ b/src/utils3d/_helpers.py @@ -0,0 +1,35 @@ +from functools import wraps +import warnings + + +def suppress_traceback(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + try: + return fn(*args, **kwargs) + except Exception as e: + e.__traceback__ = e.__traceback__.tb_next.tb_next + raise + return wrapper + + +class no_warnings: + def __init__(self, action: str = 'ignore', **kwargs): + self.action = action + self.filter_kwargs = kwargs + + def __call__(self, fn): + @wraps(fn) + def wrapper(*args, **kwargs): + with warnings.catch_warnings(): + warnings.simplefilter(self.action, **self.filter_kwargs) + return fn(*args, **kwargs) + return wrapper + + def __enter__(self): + self.warnings_manager = warnings.catch_warnings() + self.warnings_manager.__enter__() + warnings.simplefilter(self.action, **self.filter_kwargs) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.warnings_manager.__exit__(exc_type, exc_val, exc_tb) diff --git a/src/utils3d/_unified/__init__.py b/src/utils3d/_unified/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..84a675766935a51cbae7f3bcbb7378ca25227bd5 --- /dev/null +++ b/src/utils3d/_unified/__init__.py @@ -0,0 +1,934 @@ +# Auto-generated implementation redirecting to numpy/torch implementations +import sys +from typing import TYPE_CHECKING +import utils3d +from .._helpers import suppress_traceback + +__all__ = ["triangulate", +"compute_face_normal", +"compute_face_angle", +"compute_vertex_normal", +"compute_vertex_normal_weighted", +"remove_corrupted_faces", +"merge_duplicate_vertices", +"remove_unreferenced_vertices", +"subdivide_mesh_simple", +"mesh_relations", +"flatten_mesh_indices", +"calc_quad_candidates", +"calc_quad_distortion", +"calc_quad_direction", +"calc_quad_smoothness", +"sovle_quad", +"sovle_quad_qp", +"tri_to_quad", +"sliding_window_1d", +"sliding_window_nd", +"sliding_window_2d", +"max_pool_1d", +"max_pool_2d", +"max_pool_nd", +"depth_edge", +"normals_edge", +"depth_aliasing", +"interpolate", +"image_scrcoord", +"image_uv", +"image_pixel_center", +"image_pixel", +"image_mesh", +"image_mesh_from_depth", +"depth_to_normals", +"points_to_normals", +"chessboard", +"cube", +"icosahedron", +"square", +"camera_frustum", +"perspective", +"perspective_from_fov", +"perspective_from_fov_xy", +"intrinsics_from_focal_center", +"intrinsics_from_fov", +"fov_to_focal", +"focal_to_fov", +"intrinsics_to_fov", +"view_look_at", +"extrinsics_look_at", +"perspective_to_intrinsics", +"perspective_to_near_far", +"intrinsics_to_perspective", +"extrinsics_to_view", +"view_to_extrinsics", +"normalize_intrinsics", +"crop_intrinsics", +"pixel_to_uv", +"pixel_to_ndc", +"uv_to_pixel", +"project_depth", +"depth_buffer_to_linear", +"unproject_cv", +"unproject_gl", +"project_cv", +"project_gl", +"quaternion_to_matrix", +"axis_angle_to_matrix", +"matrix_to_quaternion", +"extrinsics_to_essential", +"euler_axis_angle_rotation", +"euler_angles_to_matrix", +"skew_symmetric", +"rotation_matrix_from_vectors", +"ray_intersection", +"se3_matrix", +"slerp_quaternion", +"slerp_vector", +"lerp", +"lerp_se3_matrix", +"piecewise_lerp", +"piecewise_lerp_se3_matrix", +"apply_transform", +"linear_spline_interpolate", +"RastContext", +"rasterize_triangle_faces", +"rasterize_edges", +"texture", +"warp_image_by_depth", +"test_rasterization", 
+"compute_face_angles", +"compute_face_tbn", +"compute_vertex_tbn", +"laplacian", +"laplacian_smooth_mesh", +"taubin_smooth_mesh", +"laplacian_hc_smooth_mesh", +"get_rays", +"get_image_rays", +"get_mipnerf_cones", +"volume_rendering", +"bin_sample", +"importance_sample", +"nerf_render_rays", +"mipnerf_render_rays", +"nerf_render_view", +"mipnerf_render_view", +"InstantNGP", +"point_to_normal", +"depth_to_normal", +"masked_min", +"masked_max", +"bounding_rect", +"intrinsics_from_fov_xy", +"matrix_to_euler_angles", +"matrix_to_axis_angle", +"axis_angle_to_quaternion", +"quaternion_to_axis_angle", +"slerp", +"interpolate_extrinsics", +"interpolate_view", +"to4x4", +"rotation_matrix_2d", +"rotate_2d", +"translate_2d", +"scale_2d", +"apply_2d", +"warp_image_by_forward_flow"] + +def _contains_tensor(obj): + if isinstance(obj, (list, tuple)): + return any(_contains_tensor(item) for item in obj) + elif isinstance(obj, dict): + return any(_contains_tensor(value) for value in obj.values()) + else: + import torch + return isinstance(obj, torch.Tensor) + + +@suppress_traceback +def _call_based_on_args(fname, args, kwargs): + if 'torch' in sys.modules: + if any(_contains_tensor(arg) for arg in args) or any(_contains_tensor(v) for v in kwargs.values()): + fn = getattr(utils3d.torch, fname, None) + if fn is None: + raise NotImplementedError(f"Function {fname} has no torch implementation.") + return fn(*args, **kwargs) + fn = getattr(utils3d.numpy, fname, None) + if fn is None: + raise NotImplementedError(f"Function {fname} has no numpy implementation.") + return fn(*args, **kwargs) + + +@suppress_traceback +def triangulate(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.triangulate, utils3d.torch.triangulate + return _call_based_on_args('triangulate', args, kwargs) + +@suppress_traceback +def compute_face_normal(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.compute_face_normal, utils3d.torch.compute_face_normal + return _call_based_on_args('compute_face_normal', args, kwargs) + +@suppress_traceback +def compute_face_angle(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.compute_face_angle, None + return _call_based_on_args('compute_face_angle', args, kwargs) + +@suppress_traceback +def compute_vertex_normal(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.compute_vertex_normal, utils3d.torch.compute_vertex_normal + return _call_based_on_args('compute_vertex_normal', args, kwargs) + +@suppress_traceback +def compute_vertex_normal_weighted(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.compute_vertex_normal_weighted, utils3d.torch.compute_vertex_normal_weighted + return _call_based_on_args('compute_vertex_normal_weighted', args, kwargs) + +@suppress_traceback +def remove_corrupted_faces(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.remove_corrupted_faces, utils3d.torch.remove_corrupted_faces + return _call_based_on_args('remove_corrupted_faces', args, kwargs) + +@suppress_traceback +def merge_duplicate_vertices(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.merge_duplicate_vertices, utils3d.torch.merge_duplicate_vertices + return _call_based_on_args('merge_duplicate_vertices', args, kwargs) + +@suppress_traceback +def remove_unreferenced_vertices(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.remove_unreferenced_vertices, utils3d.torch.remove_unreferenced_vertices + return 
_call_based_on_args('remove_unreferenced_vertices', args, kwargs) + +@suppress_traceback +def subdivide_mesh_simple(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.subdivide_mesh_simple, utils3d.torch.subdivide_mesh_simple + return _call_based_on_args('subdivide_mesh_simple', args, kwargs) + +@suppress_traceback +def mesh_relations(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.mesh_relations, None + return _call_based_on_args('mesh_relations', args, kwargs) + +@suppress_traceback +def flatten_mesh_indices(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.flatten_mesh_indices, None + return _call_based_on_args('flatten_mesh_indices', args, kwargs) + +@suppress_traceback +def calc_quad_candidates(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.calc_quad_candidates, None + return _call_based_on_args('calc_quad_candidates', args, kwargs) + +@suppress_traceback +def calc_quad_distortion(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.calc_quad_distortion, None + return _call_based_on_args('calc_quad_distortion', args, kwargs) + +@suppress_traceback +def calc_quad_direction(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.calc_quad_direction, None + return _call_based_on_args('calc_quad_direction', args, kwargs) + +@suppress_traceback +def calc_quad_smoothness(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.calc_quad_smoothness, None + return _call_based_on_args('calc_quad_smoothness', args, kwargs) + +@suppress_traceback +def sovle_quad(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.sovle_quad, None + return _call_based_on_args('sovle_quad', args, kwargs) + +@suppress_traceback +def sovle_quad_qp(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.sovle_quad_qp, None + return _call_based_on_args('sovle_quad_qp', args, kwargs) + +@suppress_traceback +def tri_to_quad(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.tri_to_quad, None + return _call_based_on_args('tri_to_quad', args, kwargs) + +@suppress_traceback +def sliding_window_1d(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.sliding_window_1d, utils3d.torch.sliding_window_1d + return _call_based_on_args('sliding_window_1d', args, kwargs) + +@suppress_traceback +def sliding_window_nd(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.sliding_window_nd, utils3d.torch.sliding_window_nd + return _call_based_on_args('sliding_window_nd', args, kwargs) + +@suppress_traceback +def sliding_window_2d(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.sliding_window_2d, utils3d.torch.sliding_window_2d + return _call_based_on_args('sliding_window_2d', args, kwargs) + +@suppress_traceback +def max_pool_1d(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.max_pool_1d, None + return _call_based_on_args('max_pool_1d', args, kwargs) + +@suppress_traceback +def max_pool_2d(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.max_pool_2d, None + return _call_based_on_args('max_pool_2d', args, kwargs) + +@suppress_traceback +def max_pool_nd(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.max_pool_nd, None + return _call_based_on_args('max_pool_nd', args, kwargs) + +@suppress_traceback +def depth_edge(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.depth_edge, utils3d.torch.depth_edge + return 
_call_based_on_args('depth_edge', args, kwargs) + +@suppress_traceback +def normals_edge(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.normals_edge, None + return _call_based_on_args('normals_edge', args, kwargs) + +@suppress_traceback +def depth_aliasing(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.depth_aliasing, utils3d.torch.depth_aliasing + return _call_based_on_args('depth_aliasing', args, kwargs) + +@suppress_traceback +def interpolate(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.interpolate, None + return _call_based_on_args('interpolate', args, kwargs) + +@suppress_traceback +def image_scrcoord(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.image_scrcoord, None + return _call_based_on_args('image_scrcoord', args, kwargs) + +@suppress_traceback +def image_uv(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.image_uv, utils3d.torch.image_uv + return _call_based_on_args('image_uv', args, kwargs) + +@suppress_traceback +def image_pixel_center(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.image_pixel_center, utils3d.torch.image_pixel_center + return _call_based_on_args('image_pixel_center', args, kwargs) + +@suppress_traceback +def image_pixel(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.image_pixel, None + return _call_based_on_args('image_pixel', args, kwargs) + +@suppress_traceback +def image_mesh(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.image_mesh, utils3d.torch.image_mesh + return _call_based_on_args('image_mesh', args, kwargs) + +@suppress_traceback +def image_mesh_from_depth(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.image_mesh_from_depth, utils3d.torch.image_mesh_from_depth + return _call_based_on_args('image_mesh_from_depth', args, kwargs) + +@suppress_traceback +def depth_to_normals(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.depth_to_normals, None + return _call_based_on_args('depth_to_normals', args, kwargs) + +@suppress_traceback +def points_to_normals(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.points_to_normals, None + return _call_based_on_args('points_to_normals', args, kwargs) + +@suppress_traceback +def chessboard(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.chessboard, utils3d.torch.chessboard + return _call_based_on_args('chessboard', args, kwargs) + +@suppress_traceback +def cube(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.cube, None + return _call_based_on_args('cube', args, kwargs) + +@suppress_traceback +def icosahedron(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.icosahedron, None + return _call_based_on_args('icosahedron', args, kwargs) + +@suppress_traceback +def square(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.square, None + return _call_based_on_args('square', args, kwargs) + +@suppress_traceback +def camera_frustum(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.camera_frustum, None + return _call_based_on_args('camera_frustum', args, kwargs) + +@suppress_traceback +def perspective(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.perspective, utils3d.torch.perspective + return _call_based_on_args('perspective', args, kwargs) + +@suppress_traceback +def perspective_from_fov(*args, **kwargs): + if TYPE_CHECKING: # redirected 
to: + utils3d.numpy.perspective_from_fov, utils3d.torch.perspective_from_fov + return _call_based_on_args('perspective_from_fov', args, kwargs) + +@suppress_traceback +def perspective_from_fov_xy(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.perspective_from_fov_xy, utils3d.torch.perspective_from_fov_xy + return _call_based_on_args('perspective_from_fov_xy', args, kwargs) + +@suppress_traceback +def intrinsics_from_focal_center(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.intrinsics_from_focal_center, utils3d.torch.intrinsics_from_focal_center + return _call_based_on_args('intrinsics_from_focal_center', args, kwargs) + +@suppress_traceback +def intrinsics_from_fov(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.intrinsics_from_fov, utils3d.torch.intrinsics_from_fov + return _call_based_on_args('intrinsics_from_fov', args, kwargs) + +@suppress_traceback +def fov_to_focal(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.fov_to_focal, None + return _call_based_on_args('fov_to_focal', args, kwargs) + +@suppress_traceback +def focal_to_fov(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.focal_to_fov, None + return _call_based_on_args('focal_to_fov', args, kwargs) + +@suppress_traceback +def intrinsics_to_fov(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.intrinsics_to_fov, None + return _call_based_on_args('intrinsics_to_fov', args, kwargs) + +@suppress_traceback +def view_look_at(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.view_look_at, utils3d.torch.view_look_at + return _call_based_on_args('view_look_at', args, kwargs) + +@suppress_traceback +def extrinsics_look_at(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.extrinsics_look_at, utils3d.torch.extrinsics_look_at + return _call_based_on_args('extrinsics_look_at', args, kwargs) + +@suppress_traceback +def perspective_to_intrinsics(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.perspective_to_intrinsics, utils3d.torch.perspective_to_intrinsics + return _call_based_on_args('perspective_to_intrinsics', args, kwargs) + +@suppress_traceback +def perspective_to_near_far(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.perspective_to_near_far, None + return _call_based_on_args('perspective_to_near_far', args, kwargs) + +@suppress_traceback +def intrinsics_to_perspective(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.intrinsics_to_perspective, utils3d.torch.intrinsics_to_perspective + return _call_based_on_args('intrinsics_to_perspective', args, kwargs) + +@suppress_traceback +def extrinsics_to_view(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.extrinsics_to_view, utils3d.torch.extrinsics_to_view + return _call_based_on_args('extrinsics_to_view', args, kwargs) + +@suppress_traceback +def view_to_extrinsics(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.view_to_extrinsics, utils3d.torch.view_to_extrinsics + return _call_based_on_args('view_to_extrinsics', args, kwargs) + +@suppress_traceback +def normalize_intrinsics(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.normalize_intrinsics, utils3d.torch.normalize_intrinsics + return _call_based_on_args('normalize_intrinsics', args, kwargs) + +@suppress_traceback +def crop_intrinsics(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.crop_intrinsics, 
utils3d.torch.crop_intrinsics + return _call_based_on_args('crop_intrinsics', args, kwargs) + +@suppress_traceback +def pixel_to_uv(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.pixel_to_uv, utils3d.torch.pixel_to_uv + return _call_based_on_args('pixel_to_uv', args, kwargs) + +@suppress_traceback +def pixel_to_ndc(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.pixel_to_ndc, utils3d.torch.pixel_to_ndc + return _call_based_on_args('pixel_to_ndc', args, kwargs) + +@suppress_traceback +def uv_to_pixel(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.uv_to_pixel, utils3d.torch.uv_to_pixel + return _call_based_on_args('uv_to_pixel', args, kwargs) + +@suppress_traceback +def project_depth(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.project_depth, utils3d.torch.project_depth + return _call_based_on_args('project_depth', args, kwargs) + +@suppress_traceback +def depth_buffer_to_linear(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.depth_buffer_to_linear, utils3d.torch.depth_buffer_to_linear + return _call_based_on_args('depth_buffer_to_linear', args, kwargs) + +@suppress_traceback +def unproject_cv(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.unproject_cv, utils3d.torch.unproject_cv + return _call_based_on_args('unproject_cv', args, kwargs) + +@suppress_traceback +def unproject_gl(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.unproject_gl, utils3d.torch.unproject_gl + return _call_based_on_args('unproject_gl', args, kwargs) + +@suppress_traceback +def project_cv(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.project_cv, utils3d.torch.project_cv + return _call_based_on_args('project_cv', args, kwargs) + +@suppress_traceback +def project_gl(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.project_gl, utils3d.torch.project_gl + return _call_based_on_args('project_gl', args, kwargs) + +@suppress_traceback +def quaternion_to_matrix(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.quaternion_to_matrix, utils3d.torch.quaternion_to_matrix + return _call_based_on_args('quaternion_to_matrix', args, kwargs) + +@suppress_traceback +def axis_angle_to_matrix(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.axis_angle_to_matrix, utils3d.torch.axis_angle_to_matrix + return _call_based_on_args('axis_angle_to_matrix', args, kwargs) + +@suppress_traceback +def matrix_to_quaternion(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.matrix_to_quaternion, utils3d.torch.matrix_to_quaternion + return _call_based_on_args('matrix_to_quaternion', args, kwargs) + +@suppress_traceback +def extrinsics_to_essential(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.extrinsics_to_essential, utils3d.torch.extrinsics_to_essential + return _call_based_on_args('extrinsics_to_essential', args, kwargs) + +@suppress_traceback +def euler_axis_angle_rotation(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.euler_axis_angle_rotation, utils3d.torch.euler_axis_angle_rotation + return _call_based_on_args('euler_axis_angle_rotation', args, kwargs) + +@suppress_traceback +def euler_angles_to_matrix(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.euler_angles_to_matrix, utils3d.torch.euler_angles_to_matrix + return _call_based_on_args('euler_angles_to_matrix', args, kwargs) + +@suppress_traceback +def 
skew_symmetric(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.skew_symmetric, utils3d.torch.skew_symmetric + return _call_based_on_args('skew_symmetric', args, kwargs) + +@suppress_traceback +def rotation_matrix_from_vectors(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.rotation_matrix_from_vectors, utils3d.torch.rotation_matrix_from_vectors + return _call_based_on_args('rotation_matrix_from_vectors', args, kwargs) + +@suppress_traceback +def ray_intersection(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.ray_intersection, None + return _call_based_on_args('ray_intersection', args, kwargs) + +@suppress_traceback +def se3_matrix(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.se3_matrix, None + return _call_based_on_args('se3_matrix', args, kwargs) + +@suppress_traceback +def slerp_quaternion(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.slerp_quaternion, None + return _call_based_on_args('slerp_quaternion', args, kwargs) + +@suppress_traceback +def slerp_vector(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.slerp_vector, None + return _call_based_on_args('slerp_vector', args, kwargs) + +@suppress_traceback +def lerp(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.lerp, None + return _call_based_on_args('lerp', args, kwargs) + +@suppress_traceback +def lerp_se3_matrix(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.lerp_se3_matrix, None + return _call_based_on_args('lerp_se3_matrix', args, kwargs) + +@suppress_traceback +def piecewise_lerp(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.piecewise_lerp, None + return _call_based_on_args('piecewise_lerp', args, kwargs) + +@suppress_traceback +def piecewise_lerp_se3_matrix(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.piecewise_lerp_se3_matrix, None + return _call_based_on_args('piecewise_lerp_se3_matrix', args, kwargs) + +@suppress_traceback +def apply_transform(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.apply_transform, None + return _call_based_on_args('apply_transform', args, kwargs) + +@suppress_traceback +def linear_spline_interpolate(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.linear_spline_interpolate, None + return _call_based_on_args('linear_spline_interpolate', args, kwargs) + +@suppress_traceback +def RastContext(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.RastContext, utils3d.torch.RastContext + return _call_based_on_args('RastContext', args, kwargs) + +@suppress_traceback +def rasterize_triangle_faces(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.rasterize_triangle_faces, utils3d.torch.rasterize_triangle_faces + return _call_based_on_args('rasterize_triangle_faces', args, kwargs) + +@suppress_traceback +def rasterize_edges(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.rasterize_edges, None + return _call_based_on_args('rasterize_edges', args, kwargs) + +@suppress_traceback +def texture(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.texture, None + return _call_based_on_args('texture', args, kwargs) + +@suppress_traceback +def warp_image_by_depth(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.warp_image_by_depth, utils3d.torch.warp_image_by_depth + return _call_based_on_args('warp_image_by_depth', args, kwargs) + 
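# Editor's illustrative sketch (not part of the patch above): the unified wrappers in this
# module all route through _call_based_on_args, which dispatches to utils3d.torch.<name>
# when any positional or keyword argument (including nested lists/dicts) contains a
# torch.Tensor, and to utils3d.numpy.<name> otherwise. A minimal, hypothetical usage
# example with depth_edge, which has both backends; the input array and the rtol value
# are made up, and it assumes the vendored package is importable as `utils3d`.
import numpy as np
import torch
import utils3d

depth_np = np.random.rand(48, 64).astype(np.float32)
edge_np = utils3d.depth_edge(depth_np, rtol=0.05)    # no tensors -> utils3d.numpy.depth_edge

depth_pt = torch.from_numpy(depth_np)
edge_pt = utils3d.depth_edge(depth_pt, rtol=0.05)    # tensor input -> utils3d.torch.depth_edge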
+@suppress_traceback +def test_rasterization(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + utils3d.numpy.test_rasterization, None + return _call_based_on_args('test_rasterization', args, kwargs) + +@suppress_traceback +def compute_face_angles(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.compute_face_angles + return _call_based_on_args('compute_face_angles', args, kwargs) + +@suppress_traceback +def compute_face_tbn(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.compute_face_tbn + return _call_based_on_args('compute_face_tbn', args, kwargs) + +@suppress_traceback +def compute_vertex_tbn(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.compute_vertex_tbn + return _call_based_on_args('compute_vertex_tbn', args, kwargs) + +@suppress_traceback +def laplacian(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.laplacian + return _call_based_on_args('laplacian', args, kwargs) + +@suppress_traceback +def laplacian_smooth_mesh(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.laplacian_smooth_mesh + return _call_based_on_args('laplacian_smooth_mesh', args, kwargs) + +@suppress_traceback +def taubin_smooth_mesh(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.taubin_smooth_mesh + return _call_based_on_args('taubin_smooth_mesh', args, kwargs) + +@suppress_traceback +def laplacian_hc_smooth_mesh(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.laplacian_hc_smooth_mesh + return _call_based_on_args('laplacian_hc_smooth_mesh', args, kwargs) + +@suppress_traceback +def get_rays(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.get_rays + return _call_based_on_args('get_rays', args, kwargs) + +@suppress_traceback +def get_image_rays(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.get_image_rays + return _call_based_on_args('get_image_rays', args, kwargs) + +@suppress_traceback +def get_mipnerf_cones(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.get_mipnerf_cones + return _call_based_on_args('get_mipnerf_cones', args, kwargs) + +@suppress_traceback +def volume_rendering(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.volume_rendering + return _call_based_on_args('volume_rendering', args, kwargs) + +@suppress_traceback +def bin_sample(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.bin_sample + return _call_based_on_args('bin_sample', args, kwargs) + +@suppress_traceback +def importance_sample(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.importance_sample + return _call_based_on_args('importance_sample', args, kwargs) + +@suppress_traceback +def nerf_render_rays(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.nerf_render_rays + return _call_based_on_args('nerf_render_rays', args, kwargs) + +@suppress_traceback +def mipnerf_render_rays(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.mipnerf_render_rays + return _call_based_on_args('mipnerf_render_rays', args, kwargs) + +@suppress_traceback +def nerf_render_view(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.nerf_render_view + return _call_based_on_args('nerf_render_view', args, kwargs) + +@suppress_traceback +def mipnerf_render_view(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, 
utils3d.torch.mipnerf_render_view + return _call_based_on_args('mipnerf_render_view', args, kwargs) + +@suppress_traceback +def InstantNGP(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.InstantNGP + return _call_based_on_args('InstantNGP', args, kwargs) + +@suppress_traceback +def point_to_normal(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.point_to_normal + return _call_based_on_args('point_to_normal', args, kwargs) + +@suppress_traceback +def depth_to_normal(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.depth_to_normal + return _call_based_on_args('depth_to_normal', args, kwargs) + +@suppress_traceback +def masked_min(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.masked_min + return _call_based_on_args('masked_min', args, kwargs) + +@suppress_traceback +def masked_max(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.masked_max + return _call_based_on_args('masked_max', args, kwargs) + +@suppress_traceback +def bounding_rect(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.bounding_rect + return _call_based_on_args('bounding_rect', args, kwargs) + +@suppress_traceback +def intrinsics_from_fov_xy(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.intrinsics_from_fov_xy + return _call_based_on_args('intrinsics_from_fov_xy', args, kwargs) + +@suppress_traceback +def matrix_to_euler_angles(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.matrix_to_euler_angles + return _call_based_on_args('matrix_to_euler_angles', args, kwargs) + +@suppress_traceback +def matrix_to_axis_angle(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.matrix_to_axis_angle + return _call_based_on_args('matrix_to_axis_angle', args, kwargs) + +@suppress_traceback +def axis_angle_to_quaternion(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.axis_angle_to_quaternion + return _call_based_on_args('axis_angle_to_quaternion', args, kwargs) + +@suppress_traceback +def quaternion_to_axis_angle(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.quaternion_to_axis_angle + return _call_based_on_args('quaternion_to_axis_angle', args, kwargs) + +@suppress_traceback +def slerp(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.slerp + return _call_based_on_args('slerp', args, kwargs) + +@suppress_traceback +def interpolate_extrinsics(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.interpolate_extrinsics + return _call_based_on_args('interpolate_extrinsics', args, kwargs) + +@suppress_traceback +def interpolate_view(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.interpolate_view + return _call_based_on_args('interpolate_view', args, kwargs) + +@suppress_traceback +def to4x4(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.to4x4 + return _call_based_on_args('to4x4', args, kwargs) + +@suppress_traceback +def rotation_matrix_2d(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.rotation_matrix_2d + return _call_based_on_args('rotation_matrix_2d', args, kwargs) + +@suppress_traceback +def rotate_2d(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.rotate_2d + return _call_based_on_args('rotate_2d', args, kwargs) + +@suppress_traceback +def translate_2d(*args, 
**kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.translate_2d + return _call_based_on_args('translate_2d', args, kwargs) + +@suppress_traceback +def scale_2d(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.scale_2d + return _call_based_on_args('scale_2d', args, kwargs) + +@suppress_traceback +def apply_2d(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.apply_2d + return _call_based_on_args('apply_2d', args, kwargs) + +@suppress_traceback +def warp_image_by_forward_flow(*args, **kwargs): + if TYPE_CHECKING: # redirected to: + None, utils3d.torch.warp_image_by_forward_flow + return _call_based_on_args('warp_image_by_forward_flow', args, kwargs) + diff --git a/src/utils3d/_unified/__init__.pyi b/src/utils3d/_unified/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..28f662efdf95ed2894c9af693674e74b307109dc --- /dev/null +++ b/src/utils3d/_unified/__init__.pyi @@ -0,0 +1,2431 @@ +# Auto-generated interface file +from typing import List, Tuple, Dict, Union, Optional, Any, overload, Literal, Callable +import numpy as numpy_ +import torch as torch_ +import nvdiffrast.torch +import numbers +from . import numpy, torch +import utils3d.numpy, utils3d.torch + +__all__ = ["triangulate", +"compute_face_normal", +"compute_face_angle", +"compute_vertex_normal", +"compute_vertex_normal_weighted", +"remove_corrupted_faces", +"merge_duplicate_vertices", +"remove_unreferenced_vertices", +"subdivide_mesh_simple", +"mesh_relations", +"flatten_mesh_indices", +"calc_quad_candidates", +"calc_quad_distortion", +"calc_quad_direction", +"calc_quad_smoothness", +"sovle_quad", +"sovle_quad_qp", +"tri_to_quad", +"sliding_window_1d", +"sliding_window_nd", +"sliding_window_2d", +"max_pool_1d", +"max_pool_2d", +"max_pool_nd", +"depth_edge", +"normals_edge", +"depth_aliasing", +"interpolate", +"image_scrcoord", +"image_uv", +"image_pixel_center", +"image_pixel", +"image_mesh", +"image_mesh_from_depth", +"depth_to_normals", +"points_to_normals", +"chessboard", +"cube", +"icosahedron", +"square", +"camera_frustum", +"perspective", +"perspective_from_fov", +"perspective_from_fov_xy", +"intrinsics_from_focal_center", +"intrinsics_from_fov", +"fov_to_focal", +"focal_to_fov", +"intrinsics_to_fov", +"view_look_at", +"extrinsics_look_at", +"perspective_to_intrinsics", +"perspective_to_near_far", +"intrinsics_to_perspective", +"extrinsics_to_view", +"view_to_extrinsics", +"normalize_intrinsics", +"crop_intrinsics", +"pixel_to_uv", +"pixel_to_ndc", +"uv_to_pixel", +"project_depth", +"depth_buffer_to_linear", +"unproject_cv", +"unproject_gl", +"project_cv", +"project_gl", +"quaternion_to_matrix", +"axis_angle_to_matrix", +"matrix_to_quaternion", +"extrinsics_to_essential", +"euler_axis_angle_rotation", +"euler_angles_to_matrix", +"skew_symmetric", +"rotation_matrix_from_vectors", +"ray_intersection", +"se3_matrix", +"slerp_quaternion", +"slerp_vector", +"lerp", +"lerp_se3_matrix", +"piecewise_lerp", +"piecewise_lerp_se3_matrix", +"apply_transform", +"linear_spline_interpolate", +"RastContext", +"rasterize_triangle_faces", +"rasterize_edges", +"texture", +"warp_image_by_depth", +"test_rasterization", +"compute_face_angles", +"compute_face_tbn", +"compute_vertex_tbn", +"laplacian", +"laplacian_smooth_mesh", +"taubin_smooth_mesh", +"laplacian_hc_smooth_mesh", +"get_rays", +"get_image_rays", +"get_mipnerf_cones", +"volume_rendering", +"bin_sample", +"importance_sample", +"nerf_render_rays", +"mipnerf_render_rays", 
+"nerf_render_view", +"mipnerf_render_view", +"InstantNGP", +"point_to_normal", +"depth_to_normal", +"masked_min", +"masked_max", +"bounding_rect", +"intrinsics_from_fov_xy", +"matrix_to_euler_angles", +"matrix_to_axis_angle", +"axis_angle_to_quaternion", +"quaternion_to_axis_angle", +"slerp", +"interpolate_extrinsics", +"interpolate_view", +"to4x4", +"rotation_matrix_2d", +"rotate_2d", +"translate_2d", +"scale_2d", +"apply_2d", +"warp_image_by_forward_flow"] + +@overload +def triangulate(faces: numpy_.ndarray, vertices: numpy_.ndarray = None, backslash: numpy_.ndarray = None) -> numpy_.ndarray: + """Triangulate a polygonal mesh. + +Args: + faces (np.ndarray): [L, P] polygonal faces + vertices (np.ndarray, optional): [N, 3] 3-dimensional vertices. + If given, the triangulation is performed according to the distance + between vertices. Defaults to None. + backslash (np.ndarray, optional): [L] boolean array indicating + how to triangulate the quad faces. Defaults to None. + +Returns: + (np.ndarray): [L * (P - 2), 3] triangular faces""" + utils3d.numpy.mesh.triangulate + +@overload +def compute_face_normal(vertices: numpy_.ndarray, faces: numpy_.ndarray) -> numpy_.ndarray: + """Compute face normals of a triangular mesh + +Args: + vertices (np.ndarray): [..., N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + +Returns: + normals (np.ndarray): [..., T, 3] face normals""" + utils3d.numpy.mesh.compute_face_normal + +@overload +def compute_face_angle(vertices: numpy_.ndarray, faces: numpy_.ndarray, eps: float = 1e-12) -> numpy_.ndarray: + """Compute face angles of a triangular mesh + +Args: + vertices (np.ndarray): [..., N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + +Returns: + angles (np.ndarray): [..., T, 3] face angles""" + utils3d.numpy.mesh.compute_face_angle + +@overload +def compute_vertex_normal(vertices: numpy_.ndarray, faces: numpy_.ndarray, face_normal: numpy_.ndarray = None) -> numpy_.ndarray: + """Compute vertex normals of a triangular mesh by averaging neightboring face normals +TODO: can be improved. + +Args: + vertices (np.ndarray): [..., N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + face_normal (np.ndarray, optional): [..., T, 3] face normals. + None to compute face normals from vertices and faces. Defaults to None. + +Returns: + normals (np.ndarray): [..., N, 3] vertex normals""" + utils3d.numpy.mesh.compute_vertex_normal + +@overload +def compute_vertex_normal_weighted(vertices: numpy_.ndarray, faces: numpy_.ndarray, face_normal: numpy_.ndarray = None) -> numpy_.ndarray: + """Compute vertex normals of a triangular mesh by weighted sum of neightboring face normals +according to the angles + +Args: + vertices (np.ndarray): [..., N, 3] 3-dimensional vertices + faces (np.ndarray): [..., T, 3] triangular face indices + face_normal (np.ndarray, optional): [..., T, 3] face normals. + None to compute face normals from vertices and faces. Defaults to None. 
+ +Returns: + normals (np.ndarray): [..., N, 3] vertex normals""" + utils3d.numpy.mesh.compute_vertex_normal_weighted + +@overload +def remove_corrupted_faces(faces: numpy_.ndarray) -> numpy_.ndarray: + """Remove corrupted faces (faces with duplicated vertices) + +Args: + faces (np.ndarray): [T, 3] triangular face indices + +Returns: + np.ndarray: [T_, 3] triangular face indices""" + utils3d.numpy.mesh.remove_corrupted_faces + +@overload +def merge_duplicate_vertices(vertices: numpy_.ndarray, faces: numpy_.ndarray, tol: float = 1e-06) -> Tuple[numpy_.ndarray, numpy_.ndarray]: + """Merge duplicate vertices of a triangular mesh. +Duplicate vertices are merged by selecte one of them, and the face indices are updated accordingly. + +Args: + vertices (np.ndarray): [N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + tol (float, optional): tolerance for merging. Defaults to 1e-6. + +Returns: + vertices (np.ndarray): [N_, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices""" + utils3d.numpy.mesh.merge_duplicate_vertices + +@overload +def remove_unreferenced_vertices(faces: numpy_.ndarray, *vertice_attrs, return_indices: bool = False) -> Tuple[numpy_.ndarray, ...]: + """Remove unreferenced vertices of a mesh. +Unreferenced vertices are removed, and the face indices are updated accordingly. + +Args: + faces (np.ndarray): [T, P] face indices + *vertice_attrs: vertex attributes + +Returns: + faces (np.ndarray): [T, P] face indices + *vertice_attrs: vertex attributes + indices (np.ndarray, optional): [N] indices of vertices that are kept. Defaults to None.""" + utils3d.numpy.mesh.remove_unreferenced_vertices + +@overload +def subdivide_mesh_simple(vertices: numpy_.ndarray, faces: numpy_.ndarray, n: int = 1) -> Tuple[numpy_.ndarray, numpy_.ndarray]: + """Subdivide a triangular mesh by splitting each triangle into 4 smaller triangles. +NOTE: All original vertices are kept, and new vertices are appended to the end of the vertex list. + +Args: + vertices (np.ndarray): [N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + n (int, optional): number of subdivisions. Defaults to 1. + +Returns: + vertices (np.ndarray): [N_, 3] subdivided 3-dimensional vertices + faces (np.ndarray): [4 * T, 3] subdivided triangular face indices""" + utils3d.numpy.mesh.subdivide_mesh_simple + +@overload +def mesh_relations(faces: numpy_.ndarray) -> Tuple[numpy_.ndarray, numpy_.ndarray]: + """Calculate the relation between vertices and faces. +NOTE: The input mesh must be a manifold triangle mesh. + +Args: + faces (np.ndarray): [T, 3] triangular face indices + +Returns: + edges (np.ndarray): [E, 2] edge indices + edge2face (np.ndarray): [E, 2] edge to face relation. The second column is -1 if the edge is boundary. + face2edge (np.ndarray): [T, 3] face to edge relation + face2face (np.ndarray): [T, 3] face to face relation""" + utils3d.numpy.mesh.mesh_relations + +@overload +def flatten_mesh_indices(*args: numpy_.ndarray) -> Tuple[numpy_.ndarray, ...]: + utils3d.numpy.mesh.flatten_mesh_indices + +@overload +def calc_quad_candidates(edges: numpy_.ndarray, face2edge: numpy_.ndarray, edge2face: numpy_.ndarray): + """Calculate the candidate quad faces. 
+ +Args: + edges (np.ndarray): [E, 2] edge indices + face2edge (np.ndarray): [T, 3] face to edge relation + edge2face (np.ndarray): [E, 2] edge to face relation + +Returns: + quads (np.ndarray): [Q, 4] quad candidate indices + quad2edge (np.ndarray): [Q, 4] edge to quad candidate relation + quad2adj (np.ndarray): [Q, 8] adjacent quad candidates of each quad candidate + quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid""" + utils3d.numpy.quadmesh.calc_quad_candidates + +@overload +def calc_quad_distortion(vertices: numpy_.ndarray, quads: numpy_.ndarray): + """Calculate the distortion of each candidate quad face. + +Args: + vertices (np.ndarray): [N, 3] 3-dimensional vertices + quads (np.ndarray): [Q, 4] quad face indices + +Returns: + distortion (np.ndarray): [Q] distortion of each quad face""" + utils3d.numpy.quadmesh.calc_quad_distortion + +@overload +def calc_quad_direction(vertices: numpy_.ndarray, quads: numpy_.ndarray): + """Calculate the direction of each candidate quad face. + +Args: + vertices (np.ndarray): [N, 3] 3-dimensional vertices + quads (np.ndarray): [Q, 4] quad face indices + +Returns: + direction (np.ndarray): [Q, 4] direction of each quad face. + Represented by the angle between the crossing and each edge.""" + utils3d.numpy.quadmesh.calc_quad_direction + +@overload +def calc_quad_smoothness(quad2edge: numpy_.ndarray, quad2adj: numpy_.ndarray, quads_direction: numpy_.ndarray): + """Calculate the smoothness of each candidate quad face connection. + +Args: + quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face + quads_direction (np.ndarray): [Q, 4] direction of each quad face + +Returns: + smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection""" + utils3d.numpy.quadmesh.calc_quad_smoothness + +@overload +def sovle_quad(face2edge: numpy_.ndarray, edge2face: numpy_.ndarray, quad2adj: numpy_.ndarray, quads_distortion: numpy_.ndarray, quads_smoothness: numpy_.ndarray, quads_valid: numpy_.ndarray): + """Solve the quad mesh from the candidate quad faces. + +Args: + face2edge (np.ndarray): [T, 3] face to edge relation + edge2face (np.ndarray): [E, 2] edge to face relation + quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face + quads_distortion (np.ndarray): [Q] distortion of each quad face + quads_smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection + quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid + +Returns: + weights (np.ndarray): [Q] weight of each valid quad face""" + utils3d.numpy.quadmesh.sovle_quad + +@overload +def sovle_quad_qp(face2edge: numpy_.ndarray, edge2face: numpy_.ndarray, quad2adj: numpy_.ndarray, quads_distortion: numpy_.ndarray, quads_smoothness: numpy_.ndarray, quads_valid: numpy_.ndarray): + """Solve the quad mesh from the candidate quad faces. 
+
+Args:
+ face2edge (np.ndarray): [T, 3] face to edge relation
+ edge2face (np.ndarray): [E, 2] edge to face relation
+ quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face
+ quads_distortion (np.ndarray): [Q] distortion of each quad face
+ quads_smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection
+ quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid
+
+Returns:
+ weights (np.ndarray): [Q] weight of each valid quad face"""
+ utils3d.numpy.quadmesh.sovle_quad_qp
+
+@overload
+def tri_to_quad(vertices: numpy_.ndarray, faces: numpy_.ndarray) -> Tuple[numpy_.ndarray, numpy_.ndarray]:
+ """Convert a triangle mesh to a quad mesh.
+NOTE: The input mesh must be a manifold mesh.
+
+Args:
+ vertices (np.ndarray): [N, 3] 3-dimensional vertices
+ faces (np.ndarray): [T, 3] triangular face indices
+
+Returns:
+ vertices (np.ndarray): [N_, 3] 3-dimensional vertices
+ faces (np.ndarray): [Q, 4] quad face indices"""
+ utils3d.numpy.quadmesh.tri_to_quad
+
+@overload
+def sliding_window_1d(x: numpy_.ndarray, window_size: int, stride: int, axis: int = -1):
+ """Return a view of the input array with a sliding window of the given window size and stride.
+The sliding window is performed over the given axis, and the window dimension is appended to the end of the output array's shape.
+
+Args:
+ x (np.ndarray): input array with shape (..., axis_size, ...)
+ window_size (int): size of the sliding window
+ stride (int): stride of the sliding window
+ axis (int): axis to perform sliding window over
+
+Returns:
+ a_sliding (np.ndarray): view of the input array with shape (..., n_windows, ..., window_size), where n_windows = (axis_size - window_size + 1) // stride"""
+ utils3d.numpy.utils.sliding_window_1d
+
+@overload
+def sliding_window_nd(x: numpy_.ndarray, window_size: Tuple[int, ...], stride: Tuple[int, ...], axis: Tuple[int, ...]) -> numpy_.ndarray:
+ utils3d.numpy.utils.sliding_window_nd
+
+@overload
+def sliding_window_2d(x: numpy_.ndarray, window_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]], axis: Tuple[int, int] = (-2, -1)) -> numpy_.ndarray:
+ utils3d.numpy.utils.sliding_window_2d
+
+@overload
+def max_pool_1d(x: numpy_.ndarray, kernel_size: int, stride: int, padding: int = 0, axis: int = -1):
+ utils3d.numpy.utils.max_pool_1d
+
+@overload
+def max_pool_2d(x: numpy_.ndarray, kernel_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]], padding: Union[int, Tuple[int, int]], axis: Tuple[int, int] = (-2, -1)):
+ utils3d.numpy.utils.max_pool_2d
+
+@overload
+def max_pool_nd(x: numpy_.ndarray, kernel_size: Tuple[int, ...], stride: Tuple[int, ...], padding: Tuple[int, ...], axis: Tuple[int, ...]) -> numpy_.ndarray:
+ utils3d.numpy.utils.max_pool_nd
+
+@overload
+def depth_edge(depth: numpy_.ndarray, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: numpy_.ndarray = None) -> numpy_.ndarray:
+ """Compute the edge mask from a depth map. The edge is defined as the pixels whose neighbors have a large difference in depth.
+
+Args:
+ depth (np.ndarray): shape (..., height, width), linear depth map
+ atol (float): absolute tolerance
+ rtol (float): relative tolerance
+
+Returns:
+ edge (np.ndarray): shape (..., height, width) of dtype bool"""
+ utils3d.numpy.utils.depth_edge
+
+@overload
+def normals_edge(normals: numpy_.ndarray, tol: float, kernel_size: int = 3, mask: numpy_.ndarray = None) -> numpy_.ndarray:
+ """Compute the edge mask from a normal map.
+
+Args:
+ normals (np.ndarray): shape (..., height, width, 3), normal map
+ tol (float): tolerance in degrees
+
+Returns:
+ edge (np.ndarray): shape (..., height, width) of dtype bool"""
+ utils3d.numpy.utils.normals_edge
+
+@overload
+def depth_aliasing(depth: numpy_.ndarray, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: numpy_.ndarray = None) -> numpy_.ndarray:
+ """Compute the map that indicates the aliasing of a depth map. The aliasing is defined as the pixels which are neither close to the maximum nor the minimum of their neighbors.
+Args:
+ depth (np.ndarray): shape (..., height, width), linear depth map
+ atol (float): absolute tolerance
+ rtol (float): relative tolerance
+
+Returns:
+ edge (np.ndarray): shape (..., height, width) of dtype bool"""
+ utils3d.numpy.utils.depth_aliasing
+
+@overload
+def interpolate(bary: numpy_.ndarray, tri_id: numpy_.ndarray, attr: numpy_.ndarray, faces: numpy_.ndarray) -> numpy_.ndarray:
+ """Interpolate with given barycentric coordinates and triangle indices
+
+Args:
+ bary (np.ndarray): shape (..., 3), barycentric coordinates
+ tri_id (np.ndarray): int array of shape (...), triangle indices
+ attr (np.ndarray): shape (N, M), vertices attributes
+ faces (np.ndarray): int array of shape (T, 3), face vertex indices
+
+Returns:
+ np.ndarray: shape (..., M) interpolated result"""
+ utils3d.numpy.utils.interpolate
+
+@overload
+def image_scrcoord(width: int, height: int) -> numpy_.ndarray:
+ """Get OpenGL's screen space coordinates, ranging in [0, 1].
+[0, 0] is the bottom-left corner of the image.
+
+Args:
+ width (int): image width
+ height (int): image height
+
+Returns:
+ (np.ndarray): shape (height, width, 2)"""
+ utils3d.numpy.utils.image_scrcoord
+
+@overload
+def image_uv(height: int, width: int, left: int = None, top: int = None, right: int = None, bottom: int = None, dtype: numpy_.dtype = numpy_.float32) -> numpy_.ndarray:
+ """Get image space UV grid, ranging in [0, 1].
+
+>>> image_uv(10, 10):
+[[[0.05, 0.05], [0.15, 0.05], ..., [0.95, 0.05]],
+ [[0.05, 0.15], [0.15, 0.15], ..., [0.95, 0.15]],
+ ... ... ...
+ [[0.05, 0.95], [0.15, 0.95], ..., [0.95, 0.95]]]
+
+Args:
+ width (int): image width
+ height (int): image height
+
+Returns:
+ np.ndarray: shape (height, width, 2)"""
+ utils3d.numpy.utils.image_uv
+
+@overload
+def image_pixel_center(height: int, width: int, left: int = None, top: int = None, right: int = None, bottom: int = None, dtype: numpy_.dtype = numpy_.float32) -> numpy_.ndarray:
+ """Get image pixel center coordinates, ranging in [0, width] and [0, height].
+`image[i, j]` has pixel center coordinates `(j + 0.5, i + 0.5)`.
+
+>>> image_pixel_center(10, 10):
+[[[0.5, 0.5], [1.5, 0.5], ..., [9.5, 0.5]],
+ [[0.5, 1.5], [1.5, 1.5], ..., [9.5, 1.5]],
+ ... ... ...
+[[0.5, 9.5], [1.5, 9.5], ..., [9.5, 9.5]]]
+
+Args:
+ width (int): image width
+ height (int): image height
+
+Returns:
+ np.ndarray: shape (height, width, 2)"""
+ utils3d.numpy.utils.image_pixel_center
+
+@overload
+def image_pixel(height: int, width: int, left: int = None, top: int = None, right: int = None, bottom: int = None, dtype: numpy_.dtype = numpy_.int32) -> numpy_.ndarray:
+ """Get image pixel coordinates grid, ranging in [0, width - 1] and [0, height - 1].
+`image[i, j]` has pixel coordinates `(j, i)`.
+
+>>> image_pixel(10, 10):
+[[[0, 0], [1, 0], ..., [9, 0]],
+ [[0, 1], [1, 1], ..., [9, 1]],
+ ... ... ...
+[[0, 9], [1, 9], ..., [9, 9]]]
+
+Args:
+ width (int): image width
+ height (int): image height
+
+Returns:
+ np.ndarray: shape (height, width, 2)"""
+ utils3d.numpy.utils.image_pixel
+
+@overload
+def image_mesh(*image_attrs: numpy_.ndarray, mask: numpy_.ndarray = None, tri: bool = False, return_indices: bool = False) -> Tuple[numpy_.ndarray, ...]:
+ """Get a mesh regarding image pixel uv coordinates as vertices and image grid as faces.
+
+Args:
+ *image_attrs (np.ndarray): image attributes in shape (height, width, [channels])
+ mask (np.ndarray, optional): binary mask of shape (height, width), dtype=bool. Defaults to None.
+
+Returns:
+ faces (np.ndarray): faces connecting neighboring pixels. shape (T, 4) if tri is False, else (T, 3)
+ *vertex_attrs (np.ndarray): vertex attributes in corresponding order with input image_attrs
+ indices (np.ndarray, optional): indices of vertices in the original mesh"""
+ utils3d.numpy.utils.image_mesh
+
+@overload
+def image_mesh_from_depth(depth: numpy_.ndarray, extrinsics: numpy_.ndarray = None, intrinsics: numpy_.ndarray = None, *vertice_attrs: numpy_.ndarray, atol: float = None, rtol: float = None, remove_by_depth: bool = False, return_uv: bool = False, return_indices: bool = False) -> Tuple[numpy_.ndarray, ...]:
+ """Get a triangle mesh by lifting a depth map to 3D.
+
+Args:
+ depth (np.ndarray): [H, W] depth map
+ extrinsics (np.ndarray, optional): [4, 4] extrinsics matrix. Defaults to None.
+ intrinsics (np.ndarray, optional): [3, 3] intrinsics matrix. Defaults to None.
+ *vertice_attrs (np.ndarray): [H, W, C] vertex attributes. Defaults to None.
+ atol (float, optional): absolute tolerance. Defaults to None.
+ rtol (float, optional): relative tolerance. Defaults to None.
+ triangles with vertices having depth difference larger than atol + rtol * depth will be marked.
+ remove_by_depth (bool, optional): whether to remove triangles with large depth difference. Defaults to False.
+ return_uv (bool, optional): whether to return uv coordinates. Defaults to False.
+ return_indices (bool, optional): whether to return indices of vertices in the original mesh. Defaults to False.
+
+Returns:
+ vertices (np.ndarray): [N, 3] vertices
+ faces (np.ndarray): [T, 3] faces
+ *vertice_attrs (np.ndarray): [N, C] vertex attributes
+ image_uv (np.ndarray, optional): [N, 2] uv coordinates
+ ref_indices (np.ndarray, optional): [N] indices of vertices in the original mesh"""
+ utils3d.numpy.utils.image_mesh_from_depth
+
+@overload
+def depth_to_normals(depth: numpy_.ndarray, intrinsics: numpy_.ndarray, mask: numpy_.ndarray = None) -> numpy_.ndarray:
+ """Calculate normal map from depth map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system.
+
+Args:
+ depth (np.ndarray): shape (height, width), linear depth map
+ intrinsics (np.ndarray): shape (3, 3), intrinsics matrix
+Returns:
+ normal (np.ndarray): shape (height, width, 3), normal map. """
+ utils3d.numpy.utils.depth_to_normals
+
+@overload
+def points_to_normals(point: numpy_.ndarray, mask: numpy_.ndarray = None) -> numpy_.ndarray:
+ """Calculate normal map from point map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system.
+
+Args:
+ point (np.ndarray): shape (height, width, 3), point map
+Returns:
+ normal (np.ndarray): shape (height, width, 3), normal map.
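+
+Example (an illustrative sketch, not part of the upstream docs; the synthetic point map below is hypothetical):
+>>> import numpy as np
+>>> u, v = np.meshgrid(np.linspace(0, 1, 64), np.linspace(0, 1, 48))
+>>> point = np.stack([u, v, 0.1 * u * v], axis=-1)  # (48, 64, 3) point map
+>>> points_to_normals(point).shape  # expected: (48, 64, 3), unit normals with components in [-1, 1]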
""" + utils3d.numpy.utils.points_to_normals + +@overload +def chessboard(width: int, height: int, grid_size: int, color_a: numpy_.ndarray, color_b: numpy_.ndarray) -> numpy_.ndarray: + """get x chessboard image + +Args: + width (int): image width + height (int): image height + grid_size (int): size of chessboard grid + color_a (np.ndarray): color of the grid at the top-left corner + color_b (np.ndarray): color in complementary grid cells + +Returns: + image (np.ndarray): shape (height, width, channels), chessboard image""" + utils3d.numpy.utils.chessboard + +@overload +def cube(tri: bool = False) -> Tuple[numpy_.ndarray, numpy_.ndarray]: + """Get x cube mesh of size 1 centered at origin. + +### Parameters + tri (bool, optional): return triangulated mesh. Defaults to False, which returns quad mesh. + +### Returns + vertices (np.ndarray): shape (8, 3) + faces (np.ndarray): shape (12, 3)""" + utils3d.numpy.utils.cube + +@overload +def icosahedron(): + utils3d.numpy.utils.icosahedron + +@overload +def square(tri: bool = False) -> Tuple[numpy_.ndarray, numpy_.ndarray]: + """Get a square mesh of area 1 centered at origin in the xy-plane. + +### Returns + vertices (np.ndarray): shape (4, 3) + faces (np.ndarray): shape (1, 4)""" + utils3d.numpy.utils.square + +@overload +def camera_frustum(extrinsics: numpy_.ndarray, intrinsics: numpy_.ndarray, depth: float = 1.0) -> Tuple[numpy_.ndarray, numpy_.ndarray, numpy_.ndarray]: + """Get x triangle mesh of camera frustum.""" + utils3d.numpy.utils.camera_frustum + +@overload +def perspective(fov_y: Union[float, numpy_.ndarray], aspect: Union[float, numpy_.ndarray], near: Union[float, numpy_.ndarray], far: Union[float, numpy_.ndarray]) -> numpy_.ndarray: + """Get OpenGL perspective matrix + +Args: + fov_y (float | np.ndarray): field of view in y axis + aspect (float | np.ndarray): aspect ratio + near (float | np.ndarray): near plane to clip + far (float | np.ndarray): far plane to clip + +Returns: + (np.ndarray): [..., 4, 4] perspective matrix""" + utils3d.numpy.transforms.perspective + +@overload +def perspective_from_fov(fov: Union[float, numpy_.ndarray], width: Union[int, numpy_.ndarray], height: Union[int, numpy_.ndarray], near: Union[float, numpy_.ndarray], far: Union[float, numpy_.ndarray]) -> numpy_.ndarray: + """Get OpenGL perspective matrix from field of view in largest dimension + +Args: + fov (float | np.ndarray): field of view in largest dimension + width (int | np.ndarray): image width + height (int | np.ndarray): image height + near (float | np.ndarray): near plane to clip + far (float | np.ndarray): far plane to clip + +Returns: + (np.ndarray): [..., 4, 4] perspective matrix""" + utils3d.numpy.transforms.perspective_from_fov + +@overload +def perspective_from_fov_xy(fov_x: Union[float, numpy_.ndarray], fov_y: Union[float, numpy_.ndarray], near: Union[float, numpy_.ndarray], far: Union[float, numpy_.ndarray]) -> numpy_.ndarray: + """Get OpenGL perspective matrix from field of view in x and y axis + +Args: + fov_x (float | np.ndarray): field of view in x axis + fov_y (float | np.ndarray): field of view in y axis + near (float | np.ndarray): near plane to clip + far (float | np.ndarray): far plane to clip + +Returns: + (np.ndarray): [..., 4, 4] perspective matrix""" + utils3d.numpy.transforms.perspective_from_fov_xy + +@overload +def intrinsics_from_focal_center(fx: Union[float, numpy_.ndarray], fy: Union[float, numpy_.ndarray], cx: Union[float, numpy_.ndarray], cy: Union[float, numpy_.ndarray], dtype: Optional[numpy_.dtype] = numpy_.float32) -> 
numpy_.ndarray: + """Get OpenCV intrinsics matrix + +Returns: + (np.ndarray): [..., 3, 3] OpenCV intrinsics matrix""" + utils3d.numpy.transforms.intrinsics_from_focal_center + +@overload +def intrinsics_from_fov(fov_max: Union[float, numpy_.ndarray] = None, fov_min: Union[float, numpy_.ndarray] = None, fov_x: Union[float, numpy_.ndarray] = None, fov_y: Union[float, numpy_.ndarray] = None, width: Union[int, numpy_.ndarray] = None, height: Union[int, numpy_.ndarray] = None) -> numpy_.ndarray: + """Get normalized OpenCV intrinsics matrix from given field of view. +You can provide either fov_max, fov_min, fov_x or fov_y + +Args: + width (int | np.ndarray): image width + height (int | np.ndarray): image height + fov_max (float | np.ndarray): field of view in largest dimension + fov_min (float | np.ndarray): field of view in smallest dimension + fov_x (float | np.ndarray): field of view in x axis + fov_y (float | np.ndarray): field of view in y axis + +Returns: + (np.ndarray): [..., 3, 3] OpenCV intrinsics matrix""" + utils3d.numpy.transforms.intrinsics_from_fov + +@overload +def fov_to_focal(fov: numpy_.ndarray): + utils3d.numpy.transforms.fov_to_focal + +@overload +def focal_to_fov(focal: numpy_.ndarray): + utils3d.numpy.transforms.focal_to_fov + +@overload +def intrinsics_to_fov(intrinsics: numpy_.ndarray) -> Tuple[numpy_.ndarray, numpy_.ndarray]: + utils3d.numpy.transforms.intrinsics_to_fov + +@overload +def view_look_at(eye: numpy_.ndarray, look_at: numpy_.ndarray, up: numpy_.ndarray) -> numpy_.ndarray: + """Get OpenGL view matrix looking at something + +Args: + eye (np.ndarray): [..., 3] the eye position + look_at (np.ndarray): [..., 3] the position to look at + up (np.ndarray): [..., 3] head up direction (y axis in screen space). Not necessarily othogonal to view direction + +Returns: + (np.ndarray): [..., 4, 4], view matrix""" + utils3d.numpy.transforms.view_look_at + +@overload +def extrinsics_look_at(eye: numpy_.ndarray, look_at: numpy_.ndarray, up: numpy_.ndarray) -> numpy_.ndarray: + """Get OpenCV extrinsics matrix looking at something + +Args: + eye (np.ndarray): [..., 3] the eye position + look_at (np.ndarray): [..., 3] the position to look at + up (np.ndarray): [..., 3] head up direction (-y axis in screen space). Not necessarily othogonal to view direction + +Returns: + (np.ndarray): [..., 4, 4], extrinsics matrix""" + utils3d.numpy.transforms.extrinsics_look_at + +@overload +def perspective_to_intrinsics(perspective: numpy_.ndarray) -> numpy_.ndarray: + """OpenGL perspective matrix to OpenCV intrinsics + +Args: + perspective (np.ndarray): [..., 4, 4] OpenGL perspective matrix + +Returns: + (np.ndarray): shape [..., 3, 3] OpenCV intrinsics""" + utils3d.numpy.transforms.perspective_to_intrinsics + +@overload +def perspective_to_near_far(perspective: numpy_.ndarray) -> Tuple[numpy_.ndarray, numpy_.ndarray]: + """Get near and far planes from OpenGL perspective matrix + +Args:""" + utils3d.numpy.transforms.perspective_to_near_far + +@overload +def intrinsics_to_perspective(intrinsics: numpy_.ndarray, near: Union[float, numpy_.ndarray], far: Union[float, numpy_.ndarray]) -> numpy_.ndarray: + """OpenCV intrinsics to OpenGL perspective matrix +NOTE: not work for tile-shifting intrinsics currently + +Args: + intrinsics (np.ndarray): [..., 3, 3] OpenCV intrinsics matrix + near (float | np.ndarray): [...] near plane to clip + far (float | np.ndarray): [...] 
far plane to clip
+Returns:
+ (np.ndarray): [..., 4, 4] OpenGL perspective matrix"""
+ utils3d.numpy.transforms.intrinsics_to_perspective
+
+@overload
+def extrinsics_to_view(extrinsics: numpy_.ndarray) -> numpy_.ndarray:
+ """OpenCV camera extrinsics to OpenGL view matrix
+
+Args:
+ extrinsics (np.ndarray): [..., 4, 4] OpenCV camera extrinsics matrix
+
+Returns:
+ (np.ndarray): [..., 4, 4] OpenGL view matrix"""
+ utils3d.numpy.transforms.extrinsics_to_view
+
+@overload
+def view_to_extrinsics(view: numpy_.ndarray) -> numpy_.ndarray:
+ """OpenGL view matrix to OpenCV camera extrinsics
+
+Args:
+ view (np.ndarray): [..., 4, 4] OpenGL view matrix
+
+Returns:
+ (np.ndarray): [..., 4, 4] OpenCV camera extrinsics matrix"""
+ utils3d.numpy.transforms.view_to_extrinsics
+
+@overload
+def normalize_intrinsics(intrinsics: numpy_.ndarray, width: Union[int, numpy_.ndarray], height: Union[int, numpy_.ndarray], integer_pixel_centers: bool = True) -> numpy_.ndarray:
+ """Normalize intrinsics from pixel coordinates to uv coordinates
+
+Args:
+ intrinsics (np.ndarray): [..., 3, 3] camera intrinsics(s) to normalize
+ width (int | np.ndarray): [...] image width(s)
+ height (int | np.ndarray): [...] image height(s)
+ integer_pixel_centers (bool): whether the integer pixel coordinates are at the center of the pixel. If False, the integer coordinates are at the left-top corner of the pixel.
+
+Returns:
+ (np.ndarray): [..., 3, 3] normalized camera intrinsics(s)"""
+ utils3d.numpy.transforms.normalize_intrinsics
+
+@overload
+def crop_intrinsics(intrinsics: numpy_.ndarray, width: Union[int, numpy_.ndarray], height: Union[int, numpy_.ndarray], left: Union[int, numpy_.ndarray], top: Union[int, numpy_.ndarray], crop_width: Union[int, numpy_.ndarray], crop_height: Union[int, numpy_.ndarray]) -> numpy_.ndarray:
+ """Evaluate the new intrinsics(s) after cropping the image: cropped_img = img[top:top+crop_height, left:left+crop_width]
+
+Args:
+ intrinsics (np.ndarray): [..., 3, 3] camera intrinsics(s) to crop
+ width (int | np.ndarray): [...] image width(s)
+ height (int | np.ndarray): [...] image height(s)
+ left (int | np.ndarray): [...] left crop boundary
+ top (int | np.ndarray): [...] top crop boundary
+ crop_width (int | np.ndarray): [...] crop width
+ crop_height (int | np.ndarray): [...] crop height
+
+Returns:
+ (np.ndarray): [..., 3, 3] cropped camera intrinsics(s)"""
+ utils3d.numpy.transforms.crop_intrinsics
+
+@overload
+def pixel_to_uv(pixel: numpy_.ndarray, width: Union[int, numpy_.ndarray], height: Union[int, numpy_.ndarray]) -> numpy_.ndarray:
+ """Args:
+ pixel (np.ndarray): [..., 2] pixel coordinates defined in image space, x range is (0, W - 1), y range is (0, H - 1)
+ width (int | np.ndarray): [...] image width(s)
+ height (int | np.ndarray): [...] image height(s)
+
+Returns:
+ (np.ndarray): [..., 2] pixel coordinates defined in uv space, the range is (0, 1)"""
+ utils3d.numpy.transforms.pixel_to_uv
+
+@overload
+def pixel_to_ndc(pixel: numpy_.ndarray, width: Union[int, numpy_.ndarray], height: Union[int, numpy_.ndarray]) -> numpy_.ndarray:
+ """Args:
+ pixel (np.ndarray): [..., 2] pixel coordinates defined in image space, x range is (0, W - 1), y range is (0, H - 1)
+ width (int | np.ndarray): [...] image width(s)
+ height (int | np.ndarray): [...] image height(s)
+
+Returns:
+ (np.ndarray): [..., 2] pixel coordinates defined in ndc space, the range is (-1, 1)"""
+ utils3d.numpy.transforms.pixel_to_ndc
+
+@overload
+def uv_to_pixel(uv: numpy_.ndarray, width: Union[int, numpy_.ndarray], height: Union[int, numpy_.ndarray]) -> numpy_.ndarray:
+ """Args:
+ uv (np.ndarray): [..., 2] uv coordinates defined in uv space, the range is (0, 1)
+ width (int | np.ndarray): [...] image width(s)
+ height (int | np.ndarray): [...] image height(s)
+
+Returns:
+ (np.ndarray): [..., 2] pixel coordinates defined in image space, x range is (0, W - 1), y range is (0, H - 1)"""
+ utils3d.numpy.transforms.uv_to_pixel
+
+@overload
+def project_depth(depth: numpy_.ndarray, near: Union[float, numpy_.ndarray], far: Union[float, numpy_.ndarray]) -> numpy_.ndarray:
+ """Project linear depth to depth value in screen space
+
+Args:
+ depth (np.ndarray): [...] depth value
+ near (float | np.ndarray): [...] near plane to clip
+ far (float | np.ndarray): [...] far plane to clip
+
+Returns:
+ (np.ndarray): [..., 1] depth value in screen space, value ranging in [0, 1]"""
+ utils3d.numpy.transforms.project_depth
+
+@overload
+def depth_buffer_to_linear(depth_buffer: numpy_.ndarray, near: Union[float, numpy_.ndarray], far: Union[float, numpy_.ndarray]) -> numpy_.ndarray:
+ """OpenGL depth buffer to linear depth
+
+Args:
+ depth_buffer (np.ndarray): [...] depth value
+ near (float | np.ndarray): [...] near plane to clip
+ far (float | np.ndarray): [...] far plane to clip
+
+Returns:
+ (np.ndarray): [..., 1] linear depth"""
+ utils3d.numpy.transforms.depth_buffer_to_linear
+
+@overload
+def unproject_cv(uv_coord: numpy_.ndarray, depth: numpy_.ndarray = None, extrinsics: numpy_.ndarray = None, intrinsics: numpy_.ndarray = None) -> numpy_.ndarray:
+ """Unproject uv coordinates to 3D view space following the OpenCV convention
+
+Args:
+ uv_coord (np.ndarray): [..., N, 2] uv coordinates, value ranging in [0, 1].
+ The origin (0., 0.) corresponds to the left & top
+ depth (np.ndarray): [..., N] depth value
+ extrinsics (np.ndarray): [..., 4, 4] extrinsics matrix
+ intrinsics (np.ndarray): [..., 3, 3] intrinsics matrix
+
+Returns:
+ points (np.ndarray): [..., N, 3] 3d points"""
+ utils3d.numpy.transforms.unproject_cv
+
+@overload
+def unproject_gl(screen_coord: numpy_.ndarray, model: numpy_.ndarray = None, view: numpy_.ndarray = None, perspective: numpy_.ndarray = None) -> numpy_.ndarray:
+ """Unproject screen space coordinates to 3D view space following the OpenGL convention (except for row-major matrices)
+
+Args:
+ screen_coord (np.ndarray): [..., N, 3] screen space coordinates, value ranging in [0, 1].
+ The origin (0., 0., 0.)
is corresponding to the left & bottom & nearest + model (np.ndarray): [..., 4, 4] model matrix + view (np.ndarray): [..., 4, 4] view matrix + perspective (np.ndarray): [..., 4, 4] perspective matrix + +Returns: + points (np.ndarray): [..., N, 3] 3d points""" + utils3d.numpy.transforms.unproject_gl + +@overload +def project_cv(points: numpy_.ndarray, extrinsics: numpy_.ndarray = None, intrinsics: numpy_.ndarray = None) -> Tuple[numpy_.ndarray, numpy_.ndarray]: + """Project 3D points to 2D following the OpenCV convention + +Args: + points (np.ndarray): [..., N, 3] or [..., N, 4] 3D points to project, if the last + dimension is 4, the points are assumed to be in homogeneous coordinates + extrinsics (np.ndarray): [..., 4, 4] extrinsics matrix + intrinsics (np.ndarray): [..., 3, 3] intrinsics matrix + +Returns: + uv_coord (np.ndarray): [..., N, 2] uv coordinates, value ranging in [0, 1]. + The origin (0., 0.) is corresponding to the left & top + linear_depth (np.ndarray): [..., N] linear depth""" + utils3d.numpy.transforms.project_cv + +@overload +def project_gl(points: numpy_.ndarray, model: numpy_.ndarray = None, view: numpy_.ndarray = None, perspective: numpy_.ndarray = None) -> Tuple[numpy_.ndarray, numpy_.ndarray]: + """Project 3D points to 2D following the OpenGL convention (except for row major matrice) + +Args: + points (np.ndarray): [..., N, 3] or [..., N, 4] 3D points to project, if the last + dimension is 4, the points are assumed to be in homogeneous coordinates + model (np.ndarray): [..., 4, 4] model matrix + view (np.ndarray): [..., 4, 4] view matrix + perspective (np.ndarray): [..., 4, 4] perspective matrix + +Returns: + scr_coord (np.ndarray): [..., N, 3] screen space coordinates, value ranging in [0, 1]. + The origin (0., 0., 0.) is corresponding to the left & bottom & nearest + linear_depth (np.ndarray): [..., N] linear depth""" + utils3d.numpy.transforms.project_gl + +@overload +def quaternion_to_matrix(quaternion: numpy_.ndarray, eps: float = 1e-12) -> numpy_.ndarray: + """Converts a batch of quaternions (w, x, y, z) to rotation matrices + +Args: + quaternion (np.ndarray): shape (..., 4), the quaternions to convert + +Returns: + np.ndarray: shape (..., 3, 3), the rotation matrices corresponding to the given quaternions""" + utils3d.numpy.transforms.quaternion_to_matrix + +@overload +def axis_angle_to_matrix(axis_angle: numpy_.ndarray, eps: float = 1e-12) -> numpy_.ndarray: + """Convert axis-angle representation (rotation vector) to rotation matrix, whose direction is the axis of rotation and length is the angle of rotation + +Args: + axis_angle (np.ndarray): shape (..., 3), axis-angle vcetors + +Returns: + np.ndarray: shape (..., 3, 3) The rotation matrices for the given axis-angle parameters""" + utils3d.numpy.transforms.axis_angle_to_matrix + +@overload +def matrix_to_quaternion(rot_mat: numpy_.ndarray, eps: float = 1e-12) -> numpy_.ndarray: + """Convert 3x3 rotation matrix to quaternion (w, x, y, z) + +Args: + rot_mat (np.ndarray): shape (..., 3, 3), the rotation matrices to convert + +Returns: + np.ndarray: shape (..., 4), the quaternions corresponding to the given rotation matrices""" + utils3d.numpy.transforms.matrix_to_quaternion + +@overload +def extrinsics_to_essential(extrinsics: numpy_.ndarray): + """extrinsics matrix `[[R, t] [0, 0, 0, 1]]` such that `x' = R (x - t)` to essential matrix such that `x' E x = 0` + +Args: + extrinsics (np.ndaray): [..., 4, 4] extrinsics matrix + +Returns: + (np.ndaray): [..., 3, 3] essential matrix""" + 
utils3d.numpy.transforms.extrinsics_to_essential + +@overload +def euler_axis_angle_rotation(axis: str, angle: numpy_.ndarray) -> numpy_.ndarray: + """Return the rotation matrices for one of the rotations about an axis +of which Euler angles describe, for each value of the angle given. + +Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + +Returns: + Rotation matrices as tensor of shape (..., 3, 3).""" + utils3d.numpy.transforms.euler_axis_angle_rotation + +@overload +def euler_angles_to_matrix(euler_angles: numpy_.ndarray, convention: str = 'XYZ') -> numpy_.ndarray: + """Convert rotations given as Euler angles in radians to rotation matrices. + +Args: + euler_angles: Euler angles in radians as ndarray of shape (..., 3), XYZ + convention: permutation of "X", "Y" or "Z", representing the order of Euler rotations to apply. + +Returns: + Rotation matrices as ndarray of shape (..., 3, 3).""" + utils3d.numpy.transforms.euler_angles_to_matrix + +@overload +def skew_symmetric(v: numpy_.ndarray): + """Skew symmetric matrix from a 3D vector""" + utils3d.numpy.transforms.skew_symmetric + +@overload +def rotation_matrix_from_vectors(v1: numpy_.ndarray, v2: numpy_.ndarray): + """Rotation matrix that rotates v1 to v2""" + utils3d.numpy.transforms.rotation_matrix_from_vectors + +@overload +def ray_intersection(p1: numpy_.ndarray, d1: numpy_.ndarray, p2: numpy_.ndarray, d2: numpy_.ndarray): + """Compute the intersection/closest point of two D-dimensional rays +If the rays are intersecting, the closest point is the intersection point. + +Args: + p1 (np.ndarray): (..., D) origin of ray 1 + d1 (np.ndarray): (..., D) direction of ray 1 + p2 (np.ndarray): (..., D) origin of ray 2 + d2 (np.ndarray): (..., D) direction of ray 2 + +Returns: + (np.ndarray): (..., N) intersection point""" + utils3d.numpy.transforms.ray_intersection + +@overload +def se3_matrix(R: numpy_.ndarray, t: numpy_.ndarray) -> numpy_.ndarray: + """Convert rotation matrix and translation vector to 4x4 transformation matrix. + +Args: + R (np.ndarray): [..., 3, 3] rotation matrix + t (np.ndarray): [..., 3] translation vector + +Returns: + np.ndarray: [..., 4, 4] transformation matrix""" + utils3d.numpy.transforms.se3_matrix + +@overload +def slerp_quaternion(q1: numpy_.ndarray, q2: numpy_.ndarray, t: numpy_.ndarray) -> numpy_.ndarray: + """Spherical linear interpolation between two unit quaternions. + +Args: + q1 (np.ndarray): [..., d] unit vector 1 + q2 (np.ndarray): [..., d] unit vector 2 + t (np.ndarray): [...] interpolation parameter in [0, 1] + +Returns: + np.ndarray: [..., 3] interpolated unit vector""" + utils3d.numpy.transforms.slerp_quaternion + +@overload +def slerp_vector(v1: numpy_.ndarray, v2: numpy_.ndarray, t: numpy_.ndarray) -> numpy_.ndarray: + """Spherical linear interpolation between two unit vectors. The vectors are assumed to be normalized. + +Args: + v1 (np.ndarray): [..., d] unit vector 1 + v2 (np.ndarray): [..., d] unit vector 2 + t (np.ndarray): [...] interpolation parameter in [0, 1] + +Returns: + np.ndarray: [..., d] interpolated unit vector""" + utils3d.numpy.transforms.slerp_vector + +@overload +def lerp(x1: numpy_.ndarray, x2: numpy_.ndarray, t: numpy_.ndarray) -> numpy_.ndarray: + """Linear interpolation between two vectors. + +Args: + x1 (np.ndarray): [..., d] vector 1 + x2 (np.ndarray): [..., d] vector 2 + t (np.ndarray): [...] interpolation parameter. [0, 1] for interpolation between x1 and x2, otherwise for extrapolation. 
+ +Returns: + np.ndarray: [..., d] interpolated vector""" + utils3d.numpy.transforms.lerp + +@overload +def lerp_se3_matrix(T1: numpy_.ndarray, T2: numpy_.ndarray, t: numpy_.ndarray) -> numpy_.ndarray: + """Linear interpolation between two SE(3) matrices. + +Args: + T1 (np.ndarray): [..., 4, 4] SE(3) matrix 1 + T2 (np.ndarray): [..., 4, 4] SE(3) matrix 2 + t (np.ndarray): [...] interpolation parameter in [0, 1] + +Returns: + np.ndarray: [..., 4, 4] interpolated SE(3) matrix""" + utils3d.numpy.transforms.lerp_se3_matrix + +@overload +def piecewise_lerp(x: numpy_.ndarray, t: numpy_.ndarray, s: numpy_.ndarray, extrapolation_mode: Literal['constant', 'linear'] = 'constant') -> numpy_.ndarray: + """Linear spline interpolation. + +### Parameters: +- `x`: np.ndarray, shape (n, d): the values of data points. +- `t`: np.ndarray, shape (n,): the times of the data points. +- `s`: np.ndarray, shape (m,): the times to be interpolated. +- `extrapolation_mode`: str, the mode of extrapolation. 'constant' means extrapolate the boundary values, 'linear' means extrapolate linearly. + +### Returns: +- `y`: np.ndarray, shape (..., m, d): the interpolated values.""" + utils3d.numpy.transforms.piecewise_lerp + +@overload +def piecewise_lerp_se3_matrix(T: numpy_.ndarray, t: numpy_.ndarray, s: numpy_.ndarray, extrapolation_mode: Literal['constant', 'linear'] = 'constant') -> numpy_.ndarray: + """Linear spline interpolation for SE(3) matrices. + +### Parameters: +- `T`: np.ndarray, shape (n, 4, 4): the SE(3) matrices. +- `t`: np.ndarray, shape (n,): the times of the data points. +- `s`: np.ndarray, shape (m,): the times to be interpolated. +- `extrapolation_mode`: str, the mode of extrapolation. 'constant' means extrapolate the boundary values, 'linear' means extrapolate linearly. + +### Returns: +- `T_interp`: np.ndarray, shape (..., m, 4, 4): the interpolated SE(3) matrices.""" + utils3d.numpy.transforms.piecewise_lerp_se3_matrix + +@overload +def apply_transform(T: numpy_.ndarray, x: numpy_.ndarray) -> numpy_.ndarray: + """Apply SE(3) transformation to a point or a set of points. + +### Parameters: +- `T`: np.ndarray, shape (..., 4, 4): the SE(3) matrix. +- `x`: np.ndarray, shape (..., 3): the point or a set of points to be transformed. + +### Returns: +- `x_transformed`: np.ndarray, shape (..., 3): the transformed point or a set of points.""" + utils3d.numpy.transforms.apply_transform + +@overload +def linear_spline_interpolate(x: numpy_.ndarray, t: numpy_.ndarray, s: numpy_.ndarray, extrapolation_mode: Literal['constant', 'linear'] = 'constant') -> numpy_.ndarray: + """Linear spline interpolation. + +### Parameters: +- `x`: np.ndarray, shape (n, d): the values of data points. +- `t`: np.ndarray, shape (n,): the times of the data points. +- `s`: np.ndarray, shape (m,): the times to be interpolated. +- `extrapolation_mode`: str, the mode of extrapolation. 'constant' means extrapolate the boundary values, 'linear' means extrapolate linearly. 
+ +### Returns: +- `y`: np.ndarray, shape (..., m, d): the interpolated values.""" + utils3d.numpy.spline.linear_spline_interpolate + +@overload +def RastContext(*args, **kwargs): + utils3d.numpy.rasterization.RastContext + +@overload +def rasterize_triangle_faces(ctx: utils3d.numpy.rasterization.RastContext, vertices: numpy_.ndarray, faces: numpy_.ndarray, attr: numpy_.ndarray, width: int, height: int, transform: numpy_.ndarray = None, cull_backface: bool = True, return_depth: bool = False, image: numpy_.ndarray = None, depth: numpy_.ndarray = None) -> Tuple[numpy_.ndarray, numpy_.ndarray]: + """Rasterize vertex attribute. + +Args: + vertices (np.ndarray): [N, 3] + faces (np.ndarray): [T, 3] + attr (np.ndarray): [N, C] + width (int): width of rendered image + height (int): height of rendered image + transform (np.ndarray): [4, 4] model-view-projection transformation matrix. + cull_backface (bool): whether to cull backface + image: (np.ndarray): [H, W, C] background image + depth: (np.ndarray): [H, W] background depth + +Returns: + image (np.ndarray): [H, W, C] rendered image + depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None.""" + utils3d.numpy.rasterization.rasterize_triangle_faces + +@overload +def rasterize_edges(ctx: utils3d.numpy.rasterization.RastContext, vertices: numpy_.ndarray, edges: numpy_.ndarray, attr: numpy_.ndarray, width: int, height: int, transform: numpy_.ndarray = None, line_width: float = 1.0, return_depth: bool = False, image: numpy_.ndarray = None, depth: numpy_.ndarray = None) -> Tuple[numpy_.ndarray, ...]: + """Rasterize vertex attribute. + +Args: + vertices (np.ndarray): [N, 3] + faces (np.ndarray): [T, 3] + attr (np.ndarray): [N, C] + width (int): width of rendered image + height (int): height of rendered image + transform (np.ndarray): [4, 4] model-view-projection matrix + line_width (float): width of line. Defaults to 1.0. NOTE: Values other than 1.0 may not work across all platforms. + cull_backface (bool): whether to cull backface + +Returns: + image (np.ndarray): [H, W, C] rendered image + depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None.""" + utils3d.numpy.rasterization.rasterize_edges + +@overload +def texture(ctx: utils3d.numpy.rasterization.RastContext, uv: numpy_.ndarray, texture: numpy_.ndarray, interpolation: str = 'linear', wrap: str = 'clamp') -> numpy_.ndarray: + """Given an UV image, texturing from the texture map""" + utils3d.numpy.rasterization.texture + +@overload +def warp_image_by_depth(ctx: utils3d.numpy.rasterization.RastContext, src_depth: numpy_.ndarray, src_image: numpy_.ndarray = None, width: int = None, height: int = None, *, extrinsics_src: numpy_.ndarray = None, extrinsics_tgt: numpy_.ndarray = None, intrinsics_src: numpy_.ndarray = None, intrinsics_tgt: numpy_.ndarray = None, near: float = 0.1, far: float = 100.0, cull_backface: bool = True, ssaa: int = 1, return_depth: bool = False) -> Tuple[numpy_.ndarray, ...]: + """Warp image by depth map. + +Args: + ctx (RastContext): rasterizer context + src_depth (np.ndarray): [H, W] + src_image (np.ndarray, optional): [H, W, C]. The image to warp. Defaults to None (use uv coordinates). + width (int, optional): width of the output image. None to use depth map width. Defaults to None. + height (int, optional): height of the output image. None to use depth map height. Defaults to None. + extrinsics_src (np.ndarray, optional): extrinsics matrix of the source camera. 
Defaults to None (identity). + extrinsics_tgt (np.ndarray, optional): extrinsics matrix of the target camera. Defaults to None (identity). + intrinsics_src (np.ndarray, optional): intrinsics matrix of the source camera. Defaults to None (use the same as intrinsics_tgt). + intrinsics_tgt (np.ndarray, optional): intrinsics matrix of the target camera. Defaults to None (use the same as intrinsics_src). + cull_backface (bool, optional): whether to cull backface. Defaults to True. + ssaa (int, optional): super sampling anti-aliasing. Defaults to 1. + +Returns: + tgt_image (np.ndarray): [H, W, C] warped image (or uv coordinates if image is None). + tgt_depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None.""" + utils3d.numpy.rasterization.warp_image_by_depth + +@overload +def test_rasterization(ctx: utils3d.numpy.rasterization.RastContext): + """Test if rasterization works. It will render a cube with random colors and save it as a CHECKME.png file.""" + utils3d.numpy.rasterization.test_rasterization + +@overload +def triangulate(faces: torch_.Tensor, vertices: torch_.Tensor = None, backslash: bool = None) -> torch_.Tensor: + """Triangulate a polygonal mesh. + +Args: + faces (torch.Tensor): [..., L, P] polygonal faces + vertices (torch.Tensor, optional): [..., N, 3] 3-dimensional vertices. + If given, the triangulation is performed according to the distance + between vertices. Defaults to None. + backslash (torch.Tensor, optional): [..., L] boolean array indicating + how to triangulate the quad faces. Defaults to None. + + +Returns: + (torch.Tensor): [L * (P - 2), 3] triangular faces""" + utils3d.torch.mesh.triangulate + +@overload +def compute_face_normal(vertices: torch_.Tensor, faces: torch_.Tensor) -> torch_.Tensor: + """Compute face normals of a triangular mesh + +Args: + vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices + faces (torch.Tensor): [..., T, 3] triangular face indices + +Returns: + normals (torch.Tensor): [..., T, 3] face normals""" + utils3d.torch.mesh.compute_face_normal + +@overload +def compute_face_angles(vertices: torch_.Tensor, faces: torch_.Tensor) -> torch_.Tensor: + """Compute face angles of a triangular mesh + +Args: + vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + +Returns: + angles (torch.Tensor): [..., T, 3] face angles""" + utils3d.torch.mesh.compute_face_angles + +@overload +def compute_vertex_normal(vertices: torch_.Tensor, faces: torch_.Tensor, face_normal: torch_.Tensor = None) -> torch_.Tensor: + """Compute vertex normals of a triangular mesh by averaging neightboring face normals + +Args: + vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + face_normal (torch.Tensor, optional): [..., T, 3] face normals. + None to compute face normals from vertices and faces. Defaults to None. + +Returns: + normals (torch.Tensor): [..., N, 3] vertex normals""" + utils3d.torch.mesh.compute_vertex_normal + +@overload +def compute_vertex_normal_weighted(vertices: torch_.Tensor, faces: torch_.Tensor, face_normal: torch_.Tensor = None) -> torch_.Tensor: + """Compute vertex normals of a triangular mesh by weighted sum of neightboring face normals +according to the angles + +Args: + vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + face_normal (torch.Tensor, optional): [..., T, 3] face normals. 
+ None to compute face normals from vertices and faces. Defaults to None. + +Returns: + normals (torch.Tensor): [..., N, 3] vertex normals""" + utils3d.torch.mesh.compute_vertex_normal_weighted + +@overload +def remove_unreferenced_vertices(faces: torch_.Tensor, *vertice_attrs, return_indices: bool = False) -> Tuple[torch_.Tensor, ...]: + """Remove unreferenced vertices of a mesh. +Unreferenced vertices are removed, and the face indices are updated accordingly. + +Args: + faces (torch.Tensor): [T, P] face indices + *vertice_attrs: vertex attributes + +Returns: + faces (torch.Tensor): [T, P] face indices + *vertice_attrs: vertex attributes + indices (torch.Tensor, optional): [N] indices of vertices that are kept. Defaults to None.""" + utils3d.torch.mesh.remove_unreferenced_vertices + +@overload +def remove_corrupted_faces(faces: torch_.Tensor) -> torch_.Tensor: + """Remove corrupted faces (faces with duplicated vertices) + +Args: + faces (torch.Tensor): [T, 3] triangular face indices + +Returns: + torch.Tensor: [T_, 3] triangular face indices""" + utils3d.torch.mesh.remove_corrupted_faces + +@overload +def merge_duplicate_vertices(vertices: torch_.Tensor, faces: torch_.Tensor, tol: float = 1e-06) -> Tuple[torch_.Tensor, torch_.Tensor]: + """Merge duplicate vertices of a triangular mesh. +Duplicate vertices are merged by selecte one of them, and the face indices are updated accordingly. + +Args: + vertices (torch.Tensor): [N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + tol (float, optional): tolerance for merging. Defaults to 1e-6. + +Returns: + vertices (torch.Tensor): [N_, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices""" + utils3d.torch.mesh.merge_duplicate_vertices + +@overload +def subdivide_mesh_simple(vertices: torch_.Tensor, faces: torch_.Tensor, n: int = 1) -> Tuple[torch_.Tensor, torch_.Tensor]: + """Subdivide a triangular mesh by splitting each triangle into 4 smaller triangles. +NOTE: All original vertices are kept, and new vertices are appended to the end of the vertex list. + +Args: + vertices (torch.Tensor): [N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + n (int, optional): number of subdivisions. Defaults to 1. + +Returns: + vertices (torch.Tensor): [N_, 3] subdivided 3-dimensional vertices + faces (torch.Tensor): [4 * T, 3] subdivided triangular face indices""" + utils3d.torch.mesh.subdivide_mesh_simple + +@overload +def compute_face_tbn(pos: torch_.Tensor, faces_pos: torch_.Tensor, uv: torch_.Tensor, faces_uv: torch_.Tensor, eps: float = 1e-07) -> torch_.Tensor: + """compute TBN matrix for each face + +Args: + pos (torch.Tensor): shape (..., N_pos, 3), positions + faces_pos (torch.Tensor): shape(T, 3) + uv (torch.Tensor): shape (..., N_uv, 3) uv coordinates, + faces_uv (torch.Tensor): shape(T, 3) + +Returns: + torch.Tensor: (..., T, 3, 3) TBN matrix for each face. 
Note TBN vectors are normalized but not necessarily orthognal""" + utils3d.torch.mesh.compute_face_tbn + +@overload +def compute_vertex_tbn(faces_topo: torch_.Tensor, pos: torch_.Tensor, faces_pos: torch_.Tensor, uv: torch_.Tensor, faces_uv: torch_.Tensor) -> torch_.Tensor: + """compute TBN matrix for each face + +Args: + faces_topo (torch.Tensor): (T, 3), face indice of topology + pos (torch.Tensor): shape (..., N_pos, 3), positions + faces_pos (torch.Tensor): shape(T, 3) + uv (torch.Tensor): shape (..., N_uv, 3) uv coordinates, + faces_uv (torch.Tensor): shape(T, 3) + +Returns: + torch.Tensor: (..., V, 3, 3) TBN matrix for each face. Note TBN vectors are normalized but not necessarily orthognal""" + utils3d.torch.mesh.compute_vertex_tbn + +@overload +def laplacian(vertices: torch_.Tensor, faces: torch_.Tensor, weight: str = 'uniform') -> torch_.Tensor: + """Laplacian smooth with cotangent weights + +Args: + vertices (torch.Tensor): shape (..., N, 3) + faces (torch.Tensor): shape (T, 3) + weight (str): 'uniform' or 'cotangent'""" + utils3d.torch.mesh.laplacian + +@overload +def laplacian_smooth_mesh(vertices: torch_.Tensor, faces: torch_.Tensor, weight: str = 'uniform', times: int = 5) -> torch_.Tensor: + """Laplacian smooth with cotangent weights + +Args: + vertices (torch.Tensor): shape (..., N, 3) + faces (torch.Tensor): shape (T, 3) + weight (str): 'uniform' or 'cotangent'""" + utils3d.torch.mesh.laplacian_smooth_mesh + +@overload +def taubin_smooth_mesh(vertices: torch_.Tensor, faces: torch_.Tensor, lambda_: float = 0.5, mu_: float = -0.51) -> torch_.Tensor: + """Taubin smooth mesh + +Args: + vertices (torch.Tensor): _description_ + faces (torch.Tensor): _description_ + lambda_ (float, optional): _description_. Defaults to 0.5. + mu_ (float, optional): _description_. Defaults to -0.51. + +Returns: + torch.Tensor: _description_""" + utils3d.torch.mesh.taubin_smooth_mesh + +@overload +def laplacian_hc_smooth_mesh(vertices: torch_.Tensor, faces: torch_.Tensor, times: int = 5, alpha: float = 0.5, beta: float = 0.5, weight: str = 'uniform'): + """HC algorithm from Improved Laplacian Smoothing of Noisy Surface Meshes by J.Vollmer et al. + """ + utils3d.torch.mesh.laplacian_hc_smooth_mesh + +@overload +def get_rays(extrinsics: torch_.Tensor, intrinsics: torch_.Tensor, uv: torch_.Tensor) -> Tuple[torch_.Tensor, torch_.Tensor]: + """Args: + extrinsics: (..., 4, 4) extrinsics matrices. + intrinsics: (..., 3, 3) intrinsics matrices. + uv: (..., n_rays, 2) uv coordinates of the rays. + +Returns: + rays_o: (..., 1, 3) ray origins + rays_d: (..., n_rays, 3) ray directions. + NOTE: ray directions are NOT normalized. They actuallys makes rays_o + rays_d * z = world coordinates, where z is the depth.""" + utils3d.torch.nerf.get_rays + +@overload +def get_image_rays(extrinsics: torch_.Tensor, intrinsics: torch_.Tensor, width: int, height: int) -> Tuple[torch_.Tensor, torch_.Tensor]: + """Args: + extrinsics: (..., 4, 4) extrinsics matrices. + intrinsics: (..., 3, 3) intrinsics matrices. + width: width of the image. + height: height of the image. + +Returns: + rays_o: (..., 1, 1, 3) ray origins + rays_d: (..., height, width, 3) ray directions. + NOTE: ray directions are NOT normalized. 
They actuallys makes rays_o + rays_d * z = world coordinates, where z is the depth.""" + utils3d.torch.nerf.get_image_rays + +@overload +def get_mipnerf_cones(rays_o: torch_.Tensor, rays_d: torch_.Tensor, z_vals: torch_.Tensor, pixel_width: torch_.Tensor) -> Tuple[torch_.Tensor, torch_.Tensor]: + """Args: + rays_o: (..., n_rays, 3) ray origins + rays_d: (..., n_rays, 3) ray directions. + z_vals: (..., n_rays, n_samples) z values. + pixel_width: (...) pixel width. = 1 / (normalized focal length * width) + +Returns: + mu: (..., n_rays, n_samples, 3) cone mu. + sigma: (..., n_rays, n_samples, 3, 3) cone sigma.""" + utils3d.torch.nerf.get_mipnerf_cones + +@overload +def volume_rendering(color: torch_.Tensor, sigma: torch_.Tensor, z_vals: torch_.Tensor, ray_length: torch_.Tensor, rgb: bool = True, depth: bool = True) -> Tuple[torch_.Tensor, torch_.Tensor, torch_.Tensor]: + """Given color, sigma and z_vals (linear depth of the sampling points), render the volume. + +NOTE: By default, color and sigma should have one less sample than z_vals, in correspondence with the average value in intervals. +If queried color are aligned with z_vals, we use trapezoidal rule to calculate the average values in intervals. + +Args: + color: (..., n_samples or n_samples - 1, 3) color values. + sigma: (..., n_samples or n_samples - 1) density values. + z_vals: (..., n_samples) z values. + ray_length: (...) length of the ray + +Returns: + rgb: (..., 3) rendered color values. + depth: (...) rendered depth values. + weights (..., n_samples) weights.""" + utils3d.torch.nerf.volume_rendering + +@overload +def bin_sample(size: Union[torch_.Size, Tuple[int, ...]], n_samples: int, min_value: numbers.Number, max_value: numbers.Number, spacing: Literal['linear', 'inverse_linear'], dtype: torch_.dtype = None, device: torch_.device = None) -> torch_.Tensor: + """Uniformly (or uniformly in inverse space) sample z values in `n_samples` bins in range [min_value, max_value]. +Args: + size: size of the rays + n_samples: number of samples to be sampled, also the number of bins + min_value: minimum value of the range + max_value: maximum value of the range + space: 'linear' or 'inverse_linear'. If 'inverse_linear', the sampling is uniform in inverse space. + +Returns: + z_rand: (*size, n_samples) sampled z values, sorted in ascending order.""" + utils3d.torch.nerf.bin_sample + +@overload +def importance_sample(z_vals: torch_.Tensor, weights: torch_.Tensor, n_samples: int) -> Tuple[torch_.Tensor, torch_.Tensor]: + """Importance sample z values. + +NOTE: By default, weights should have one less sample than z_vals, in correspondence with the intervals. +If weights has the same number of samples as z_vals, we use trapezoidal rule to calculate the average weights in intervals. + +Args: + z_vals: (..., n_rays, n_input_samples) z values, sorted in ascending order. + weights: (..., n_rays, n_input_samples or n_input_samples - 1) weights. + n_samples: number of output samples for importance sampling. 
+ +Returns: + z_importance: (..., n_rays, n_samples) importance sampled z values, unsorted.""" + utils3d.torch.nerf.importance_sample + +@overload +def nerf_render_rays(nerf: Union[Callable[[torch_.Tensor, torch_.Tensor], Tuple[torch_.Tensor, torch_.Tensor]], Tuple[Callable[[torch_.Tensor], Tuple[torch_.Tensor, torch_.Tensor]], Callable[[torch_.Tensor], Tuple[torch_.Tensor, torch_.Tensor]]]], rays_o: torch_.Tensor, rays_d: torch_.Tensor, *, return_dict: bool = False, n_coarse: int = 64, n_fine: int = 64, near: float = 0.1, far: float = 100.0, z_spacing: Literal['linear', 'inverse_linear'] = 'linear'): + """NeRF rendering of rays. Note that it supports arbitrary batch dimensions (denoted as `...`) + +Args: + nerf: nerf model, which takes (points, directions) as input and returns (color, density) as output. + If nerf is a tuple, it should be (nerf_coarse, nerf_fine), where nerf_coarse and nerf_fine are two nerf models for coarse and fine stages respectively. + + nerf args: + points: (..., n_rays, n_samples, 3) + directions: (..., n_rays, n_samples, 3) + nerf returns: + color: (..., n_rays, n_samples, 3) color values. + density: (..., n_rays, n_samples) density values. + + rays_o: (..., n_rays, 3) ray origins + rays_d: (..., n_rays, 3) ray directions. + pixel_width: (..., n_rays) pixel width. How to compute? pixel_width = 1 / (normalized focal length * width) + +Returns + if return_dict is False, return rendered rgb and depth for short cut. (If there are separate coarse and fine results, return fine results) + rgb: (..., n_rays, 3) rendered color values. + depth: (..., n_rays) rendered depth values. + else, return a dict. If `n_fine == 0` or `nerf` is a single model, the dict only contains coarse results: + ``` + {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} + ``` + If there are two models for coarse and fine stages, the dict contains both coarse and fine results: + ``` + { + "coarse": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..}, + "fine": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} + } + ```""" + utils3d.torch.nerf.nerf_render_rays + +@overload +def mipnerf_render_rays(mipnerf: Callable[[torch_.Tensor, torch_.Tensor, torch_.Tensor], Tuple[torch_.Tensor, torch_.Tensor]], rays_o: torch_.Tensor, rays_d: torch_.Tensor, pixel_width: torch_.Tensor, *, return_dict: bool = False, n_coarse: int = 64, n_fine: int = 64, uniform_ratio: float = 0.4, near: float = 0.1, far: float = 100.0, z_spacing: Literal['linear', 'inverse_linear'] = 'linear') -> Union[Tuple[torch_.Tensor, torch_.Tensor], Dict[str, torch_.Tensor]]: + """MipNeRF rendering. + +Args: + mipnerf: mipnerf model, which takes (points_mu, points_sigma) as input and returns (color, density) as output. + + mipnerf args: + points_mu: (..., n_rays, n_samples, 3) cone mu. + points_sigma: (..., n_rays, n_samples, 3, 3) cone sigma. + directions: (..., n_rays, n_samples, 3) + mipnerf returns: + color: (..., n_rays, n_samples, 3) color values. + density: (..., n_rays, n_samples) density values. + + rays_o: (..., n_rays, 3) ray origins + rays_d: (..., n_rays, 3) ray directions. + pixel_width: (..., n_rays) pixel width. How to compute? pixel_width = 1 / (normalized focal length * width) + +Returns + if return_dict is False, return rendered results only: (If `n_fine == 0`, return coarse results, otherwise return fine results) + rgb: (..., n_rays, 3) rendered color values. + depth: (..., n_rays) rendered depth values. + else, return a dict. 
If `n_fine == 0`, the dict only contains coarse results: + ``` + {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} + ``` + If n_fine > 0, the dict contains both coarse and fine results : + ``` + { + "coarse": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..}, + "fine": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} + } + ```""" + utils3d.torch.nerf.mipnerf_render_rays + +@overload +def nerf_render_view(nerf: torch_.Tensor, extrinsics: torch_.Tensor, intrinsics: torch_.Tensor, width: int, height: int, *, patchify: bool = False, patch_size: Tuple[int, int] = (64, 64), **options: Dict[str, Any]) -> Tuple[torch_.Tensor, torch_.Tensor]: + """NeRF rendering of views. Note that it supports arbitrary batch dimensions (denoted as `...`) + +Args: + extrinsics: (..., 4, 4) extrinsics matrice of the rendered views + intrinsics (optional): (..., 3, 3) intrinsics matrice of the rendered views. + width (optional): image width of the rendered views. + height (optional): image height of the rendered views. + patchify (optional): If the image is too large, render it patch by patch + **options: rendering options. + +Returns: + rgb: (..., channels, height, width) rendered color values. + depth: (..., height, width) rendered depth values.""" + utils3d.torch.nerf.nerf_render_view + +@overload +def mipnerf_render_view(mipnerf: torch_.Tensor, extrinsics: torch_.Tensor, intrinsics: torch_.Tensor, width: int, height: int, *, patchify: bool = False, patch_size: Tuple[int, int] = (64, 64), **options: Dict[str, Any]) -> Tuple[torch_.Tensor, torch_.Tensor]: + """MipNeRF rendering of views. Note that it supports arbitrary batch dimensions (denoted as `...`) + +Args: + extrinsics: (..., 4, 4) extrinsics matrice of the rendered views + intrinsics (optional): (..., 3, 3) intrinsics matrice of the rendered views. + width (optional): image width of the rendered views. + height (optional): image height of the rendered views. + patchify (optional): If the image is too large, render it patch by patch + **options: rendering options. + +Returns: + rgb: (..., 3, height, width) rendered color values. + depth: (..., height, width) rendered depth values.""" + utils3d.torch.nerf.mipnerf_render_view + +@overload +def InstantNGP(view_dependent: bool = True, base_resolution: int = 16, finest_resolution: int = 2048, n_levels: int = 16, num_layers_density: int = 2, hidden_dim_density: int = 64, num_layers_color: int = 3, hidden_dim_color: int = 64, log2_hashmap_size: int = 19, bound: float = 1.0, color_channels: int = 3): + """An implementation of InstantNGP, Müller et. al., https://nvlabs.github.io/instant-ngp/. +Requires `tinycudann` package. +Install it by: +``` +pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch +```""" + utils3d.torch.nerf.InstantNGP + +@overload +def sliding_window_1d(x: torch_.Tensor, window_size: int, stride: int = 1, dim: int = -1) -> torch_.Tensor: + """Sliding window view of the input tensor. The dimension of the sliding window is appended to the end of the input tensor's shape. 
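+
+Example (an illustrative sketch, not part of the upstream docs; assumes a 1D float tensor of 10 elements):
+>>> import torch
+>>> x = torch.arange(10.)
+>>> sliding_window_1d(x, window_size=3, stride=2).shape  # expected: torch.Size([4, 3]), a view of x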
+NOTE: Since Pytorch has `unfold` function, 1D sliding window view is just a wrapper of it.""" + utils3d.torch.utils.sliding_window_1d + +@overload +def sliding_window_2d(x: torch_.Tensor, window_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]], dim: Union[int, Tuple[int, int]] = (-2, -1)) -> torch_.Tensor: + utils3d.torch.utils.sliding_window_2d + +@overload +def sliding_window_nd(x: torch_.Tensor, window_size: Tuple[int, ...], stride: Tuple[int, ...], dim: Tuple[int, ...]) -> torch_.Tensor: + utils3d.torch.utils.sliding_window_nd + +@overload +def image_uv(height: int, width: int, left: int = None, top: int = None, right: int = None, bottom: int = None, device: torch_.device = None, dtype: torch_.dtype = None) -> torch_.Tensor: + """Get image space UV grid, ranging in [0, 1]. + +>>> image_uv(10, 10): +[[[0.05, 0.05], [0.15, 0.05], ..., [0.95, 0.05]], + [[0.05, 0.15], [0.15, 0.15], ..., [0.95, 0.15]], + ... ... ... + [[0.05, 0.95], [0.15, 0.95], ..., [0.95, 0.95]]] + +Args: + width (int): image width + height (int): image height + +Returns: + np.ndarray: shape (height, width, 2)""" + utils3d.torch.utils.image_uv + +@overload +def image_pixel_center(height: int, width: int, left: int = None, top: int = None, right: int = None, bottom: int = None, dtype: torch_.dtype = None, device: torch_.device = None) -> torch_.Tensor: + """Get image pixel center coordinates, ranging in [0, width] and [0, height]. +`image[i, j]` has pixel center coordinates `(j + 0.5, i + 0.5)`. + +>>> image_pixel_center(10, 10): +[[[0.5, 0.5], [1.5, 0.5], ..., [9.5, 0.5]], + [[0.5, 1.5], [1.5, 1.5], ..., [9.5, 1.5]], + ... ... ... +[[0.5, 9.5], [1.5, 9.5], ..., [9.5, 9.5]]] + +Args: + width (int): image width + height (int): image height + +Returns: + np.ndarray: shape (height, width, 2)""" + utils3d.torch.utils.image_pixel_center + +@overload +def image_mesh(height: int, width: int, mask: torch_.Tensor = None, device: torch_.device = None, dtype: torch_.dtype = None) -> Tuple[torch_.Tensor, torch_.Tensor]: + """Get a quad mesh regarding image pixel uv coordinates as vertices and image grid as faces. + +Args: + width (int): image width + height (int): image height + mask (np.ndarray, optional): binary mask of shape (height, width), dtype=bool. Defaults to None. + +Returns: + uv (np.ndarray): uv corresponding to pixels as described in image_uv() + faces (np.ndarray): quad faces connecting neighboring pixels + indices (np.ndarray, optional): indices of vertices in the original mesh""" + utils3d.torch.utils.image_mesh + +@overload +def chessboard(width: int, height: int, grid_size: int, color_a: torch_.Tensor, color_b: torch_.Tensor) -> torch_.Tensor: + """get a chessboard image + +Args: + width (int): image width + height (int): image height + grid_size (int): size of chessboard grid + color_a (torch.Tensor): shape (chanenls,), color of the grid at the top-left corner + color_b (torch.Tensor): shape (chanenls,), color in complementary grids + +Returns: + image (torch.Tensor): shape (height, width, channels), chessboard image""" + utils3d.torch.utils.chessboard + +@overload +def depth_edge(depth: torch_.Tensor, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: torch_.Tensor = None) -> torch_.BoolTensor: + """Compute the edge mask of a depth map. The edge is defined as the pixels whose neighbors have a large difference in depth. 
+ +Args: + depth (torch.Tensor): shape (..., height, width), linear depth map + atol (float): absolute tolerance + rtol (float): relative tolerance + +Returns: + edge (torch.Tensor): shape (..., height, width) of dtype torch.bool""" + utils3d.torch.utils.depth_edge + +@overload +def depth_aliasing(depth: torch_.Tensor, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: torch_.Tensor = None) -> torch_.BoolTensor: + """Compute the map that indicates the aliasing of a depth map. The aliasing is defined as the pixels which neither close to the maximum nor the minimum of its neighbors. +Args: + depth (torch.Tensor): shape (..., height, width), linear depth map + atol (float): absolute tolerance + rtol (float): relative tolerance + +Returns: + edge (torch.Tensor): shape (..., height, width) of dtype torch.bool""" + utils3d.torch.utils.depth_aliasing + +@overload +def image_mesh_from_depth(depth: torch_.Tensor, extrinsics: torch_.Tensor = None, intrinsics: torch_.Tensor = None) -> Tuple[torch_.Tensor, torch_.Tensor]: + utils3d.torch.utils.image_mesh_from_depth + +@overload +def point_to_normal(point: torch_.Tensor, mask: torch_.Tensor = None) -> torch_.Tensor: + """Calculate normal map from point map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system. + +Args: + point (torch.Tensor): shape (..., height, width, 3), point map +Returns: + normal (torch.Tensor): shape (..., height, width, 3), normal map. """ + utils3d.torch.utils.point_to_normal + +@overload +def depth_to_normal(depth: torch_.Tensor, intrinsics: torch_.Tensor, mask: torch_.Tensor = None) -> torch_.Tensor: + """Calculate normal map from depth map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system. + +Args: + depth (torch.Tensor): shape (..., height, width), linear depth map + intrinsics (torch.Tensor): shape (..., 3, 3), intrinsics matrix +Returns: + normal (torch.Tensor): shape (..., 3, height, width), normal map. 
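+
+Illustrative usage (assumes `depth` of shape (H, W) and `intrinsics` of shape (3, 3), as described above):
+    normal = depth_to_normal(depth, intrinsics)  # shape (3, H, W), values in [-1, 1]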
""" + utils3d.torch.utils.depth_to_normal + +@overload +def masked_min(input: torch_.Tensor, mask: torch_.BoolTensor, dim: int = None, keepdim: bool = False) -> Union[torch_.Tensor, Tuple[torch_.Tensor, torch_.Tensor]]: + """Similar to torch.min, but with mask + """ + utils3d.torch.utils.masked_min + +@overload +def masked_max(input: torch_.Tensor, mask: torch_.BoolTensor, dim: int = None, keepdim: bool = False) -> Union[torch_.Tensor, Tuple[torch_.Tensor, torch_.Tensor]]: + """Similar to torch.max, but with mask + """ + utils3d.torch.utils.masked_max + +@overload +def bounding_rect(mask: torch_.BoolTensor): + """get bounding rectangle of a mask + +Args: + mask (torch.Tensor): shape (..., height, width), mask + +Returns: + rect (torch.Tensor): shape (..., 4), bounding rectangle (left, top, right, bottom)""" + utils3d.torch.utils.bounding_rect + +@overload +def perspective(fov_y: Union[float, torch_.Tensor], aspect: Union[float, torch_.Tensor], near: Union[float, torch_.Tensor], far: Union[float, torch_.Tensor]) -> torch_.Tensor: + """Get OpenGL perspective matrix + +Args: + fov_y (float | torch.Tensor): field of view in y axis + aspect (float | torch.Tensor): aspect ratio + near (float | torch.Tensor): near plane to clip + far (float | torch.Tensor): far plane to clip + +Returns: + (torch.Tensor): [..., 4, 4] perspective matrix""" + utils3d.torch.transforms.perspective + +@overload +def perspective_from_fov(fov: Union[float, torch_.Tensor], width: Union[int, torch_.Tensor], height: Union[int, torch_.Tensor], near: Union[float, torch_.Tensor], far: Union[float, torch_.Tensor]) -> torch_.Tensor: + """Get OpenGL perspective matrix from field of view in largest dimension + +Args: + fov (float | torch.Tensor): field of view in largest dimension + width (int | torch.Tensor): image width + height (int | torch.Tensor): image height + near (float | torch.Tensor): near plane to clip + far (float | torch.Tensor): far plane to clip + +Returns: + (torch.Tensor): [..., 4, 4] perspective matrix""" + utils3d.torch.transforms.perspective_from_fov + +@overload +def perspective_from_fov_xy(fov_x: Union[float, torch_.Tensor], fov_y: Union[float, torch_.Tensor], near: Union[float, torch_.Tensor], far: Union[float, torch_.Tensor]) -> torch_.Tensor: + """Get OpenGL perspective matrix from field of view in x and y axis + +Args: + fov_x (float | torch.Tensor): field of view in x axis + fov_y (float | torch.Tensor): field of view in y axis + near (float | torch.Tensor): near plane to clip + far (float | torch.Tensor): far plane to clip + +Returns: + (torch.Tensor): [..., 4, 4] perspective matrix""" + utils3d.torch.transforms.perspective_from_fov_xy + +@overload +def intrinsics_from_focal_center(fx: Union[float, torch_.Tensor], fy: Union[float, torch_.Tensor], cx: Union[float, torch_.Tensor], cy: Union[float, torch_.Tensor]) -> torch_.Tensor: + """Get OpenCV intrinsics matrix + +Args: + focal_x (float | torch.Tensor): focal length in x axis + focal_y (float | torch.Tensor): focal length in y axis + cx (float | torch.Tensor): principal point in x axis + cy (float | torch.Tensor): principal point in y axis + +Returns: + (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix""" + utils3d.torch.transforms.intrinsics_from_focal_center + +@overload +def intrinsics_from_fov(fov_max: Union[float, torch_.Tensor] = None, fov_min: Union[float, torch_.Tensor] = None, fov_x: Union[float, torch_.Tensor] = None, fov_y: Union[float, torch_.Tensor] = None, width: Union[int, torch_.Tensor] = None, height: Union[int, torch_.Tensor] = 
None) -> torch_.Tensor: + """Get normalized OpenCV intrinsics matrix from given field of view. +You can provide either fov_max, fov_min, fov_x or fov_y + +Args: + width (int | torch.Tensor): image width + height (int | torch.Tensor): image height + fov_max (float | torch.Tensor): field of view in largest dimension + fov_min (float | torch.Tensor): field of view in smallest dimension + fov_x (float | torch.Tensor): field of view in x axis + fov_y (float | torch.Tensor): field of view in y axis + +Returns: + (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix""" + utils3d.torch.transforms.intrinsics_from_fov + +@overload +def intrinsics_from_fov_xy(fov_x: Union[float, torch_.Tensor], fov_y: Union[float, torch_.Tensor]) -> torch_.Tensor: + """Get OpenCV intrinsics matrix from field of view in x and y axis + +Args: + fov_x (float | torch.Tensor): field of view in x axis + fov_y (float | torch.Tensor): field of view in y axis + +Returns: + (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix""" + utils3d.torch.transforms.intrinsics_from_fov_xy + +@overload +def view_look_at(eye: torch_.Tensor, look_at: torch_.Tensor, up: torch_.Tensor) -> torch_.Tensor: + """Get OpenGL view matrix looking at something + +Args: + eye (torch.Tensor): [..., 3] the eye position + look_at (torch.Tensor): [..., 3] the position to look at + up (torch.Tensor): [..., 3] head up direction (y axis in screen space). Not necessarily othogonal to view direction + +Returns: + (torch.Tensor): [..., 4, 4], view matrix""" + utils3d.torch.transforms.view_look_at + +@overload +def extrinsics_look_at(eye: torch_.Tensor, look_at: torch_.Tensor, up: torch_.Tensor) -> torch_.Tensor: + """Get OpenCV extrinsics matrix looking at something + +Args: + eye (torch.Tensor): [..., 3] the eye position + look_at (torch.Tensor): [..., 3] the position to look at + up (torch.Tensor): [..., 3] head up direction (-y axis in screen space). Not necessarily othogonal to view direction + +Returns: + (torch.Tensor): [..., 4, 4], extrinsics matrix""" + utils3d.torch.transforms.extrinsics_look_at + +@overload +def perspective_to_intrinsics(perspective: torch_.Tensor) -> torch_.Tensor: + """OpenGL perspective matrix to OpenCV intrinsics + +Args: + perspective (torch.Tensor): [..., 4, 4] OpenGL perspective matrix + +Returns: + (torch.Tensor): shape [..., 3, 3] OpenCV intrinsics""" + utils3d.torch.transforms.perspective_to_intrinsics + +@overload +def intrinsics_to_perspective(intrinsics: torch_.Tensor, near: Union[float, torch_.Tensor], far: Union[float, torch_.Tensor]) -> torch_.Tensor: + """OpenCV intrinsics to OpenGL perspective matrix + +Args: + intrinsics (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix + near (float | torch.Tensor): [...] near plane to clip + far (float | torch.Tensor): [...] 
far plane to clip +Returns: + (torch.Tensor): [..., 4, 4] OpenGL perspective matrix""" + utils3d.torch.transforms.intrinsics_to_perspective + +@overload +def extrinsics_to_view(extrinsics: torch_.Tensor) -> torch_.Tensor: + """OpenCV camera extrinsics to OpenGL view matrix + +Args: + extrinsics (torch.Tensor): [..., 4, 4] OpenCV camera extrinsics matrix + +Returns: + (torch.Tensor): [..., 4, 4] OpenGL view matrix""" + utils3d.torch.transforms.extrinsics_to_view + +@overload +def view_to_extrinsics(view: torch_.Tensor) -> torch_.Tensor: + """OpenGL view matrix to OpenCV camera extrinsics + +Args: + view (torch.Tensor): [..., 4, 4] OpenGL view matrix + +Returns: + (torch.Tensor): [..., 4, 4] OpenCV camera extrinsics matrix""" + utils3d.torch.transforms.view_to_extrinsics + +@overload +def normalize_intrinsics(intrinsics: torch_.Tensor, width: Union[int, torch_.Tensor], height: Union[int, torch_.Tensor]) -> torch_.Tensor: + """Normalize camera intrinsics(s) to uv space + +Args: + intrinsics (torch.Tensor): [..., 3, 3] camera intrinsics(s) to normalize + width (int | torch.Tensor): [...] image width(s) + height (int | torch.Tensor): [...] image height(s) + +Returns: + (torch.Tensor): [..., 3, 3] normalized camera intrinsics(s)""" + utils3d.torch.transforms.normalize_intrinsics + +@overload +def crop_intrinsics(intrinsics: torch_.Tensor, width: Union[int, torch_.Tensor], height: Union[int, torch_.Tensor], left: Union[int, torch_.Tensor], top: Union[int, torch_.Tensor], crop_width: Union[int, torch_.Tensor], crop_height: Union[int, torch_.Tensor]) -> torch_.Tensor: + """Evaluate the new intrinsics(s) after crop the image: cropped_img = img[top:top+crop_height, left:left+crop_width] + +Args: + intrinsics (torch.Tensor): [..., 3, 3] camera intrinsics(s) to crop + width (int | torch.Tensor): [...] image width(s) + height (int | torch.Tensor): [...] image height(s) + left (int | torch.Tensor): [...] left crop boundary + top (int | torch.Tensor): [...] top crop boundary + crop_width (int | torch.Tensor): [...] crop width + crop_height (int | torch.Tensor): [...] crop height + +Returns: + (torch.Tensor): [..., 3, 3] cropped camera intrinsics(s)""" + utils3d.torch.transforms.crop_intrinsics + +@overload +def pixel_to_uv(pixel: torch_.Tensor, width: Union[int, torch_.Tensor], height: Union[int, torch_.Tensor]) -> torch_.Tensor: + """Args: + pixel (torch.Tensor): [..., 2] pixel coordinrates defined in image space, x range is (0, W - 1), y range is (0, H - 1) + width (int | torch.Tensor): [...] image width(s) + height (int | torch.Tensor): [...] image height(s) + +Returns: + (torch.Tensor): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1)""" + utils3d.torch.transforms.pixel_to_uv + +@overload +def pixel_to_ndc(pixel: torch_.Tensor, width: Union[int, torch_.Tensor], height: Union[int, torch_.Tensor]) -> torch_.Tensor: + """Args: + pixel (torch.Tensor): [..., 2] pixel coordinrates defined in image space, x range is (0, W - 1), y range is (0, H - 1) + width (int | torch.Tensor): [...] image width(s) + height (int | torch.Tensor): [...] image height(s) + +Returns: + (torch.Tensor): [..., 2] pixel coordinrates defined in ndc space, the range is (-1, 1)""" + utils3d.torch.transforms.pixel_to_ndc + +@overload +def uv_to_pixel(uv: torch_.Tensor, width: Union[int, torch_.Tensor], height: Union[int, torch_.Tensor]) -> torch_.Tensor: + """Args: + uv (torch.Tensor): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1) + width (int | torch.Tensor): [...] 
image width(s) + height (int | torch.Tensor): [...] image height(s) + +Returns: + (torch.Tensor): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1)""" + utils3d.torch.transforms.uv_to_pixel + +@overload +def project_depth(depth: torch_.Tensor, near: Union[float, torch_.Tensor], far: Union[float, torch_.Tensor]) -> torch_.Tensor: + """Project linear depth to depth value in screen space + +Args: + depth (torch.Tensor): [...] depth value + near (float | torch.Tensor): [...] near plane to clip + far (float | torch.Tensor): [...] far plane to clip + +Returns: + (torch.Tensor): [..., 1] depth value in screen space, value ranging in [0, 1]""" + utils3d.torch.transforms.project_depth + +@overload +def depth_buffer_to_linear(depth: torch_.Tensor, near: Union[float, torch_.Tensor], far: Union[float, torch_.Tensor]) -> torch_.Tensor: + """Linearize depth value to linear depth + +Args: + depth (torch.Tensor): [...] screen depth value, ranging in [0, 1] + near (float | torch.Tensor): [...] near plane to clip + far (float | torch.Tensor): [...] far plane to clip + +Returns: + (torch.Tensor): [...] linear depth""" + utils3d.torch.transforms.depth_buffer_to_linear + +@overload +def project_gl(points: torch_.Tensor, model: torch_.Tensor = None, view: torch_.Tensor = None, perspective: torch_.Tensor = None) -> Tuple[torch_.Tensor, torch_.Tensor]: + """Project 3D points to 2D following the OpenGL convention (except for row major matrice) + +Args: + points (torch.Tensor): [..., N, 3 or 4] 3D points to project, if the last + dimension is 4, the points are assumed to be in homogeneous coordinates + model (torch.Tensor): [..., 4, 4] model matrix + view (torch.Tensor): [..., 4, 4] view matrix + perspective (torch.Tensor): [..., 4, 4] perspective matrix + +Returns: + scr_coord (torch.Tensor): [..., N, 3] screen space coordinates, value ranging in [0, 1]. + The origin (0., 0., 0.) is corresponding to the left & bottom & nearest + linear_depth (torch.Tensor): [..., N] linear depth""" + utils3d.torch.transforms.project_gl + +@overload +def project_cv(points: torch_.Tensor, extrinsics: torch_.Tensor = None, intrinsics: torch_.Tensor = None) -> Tuple[torch_.Tensor, torch_.Tensor]: + """Project 3D points to 2D following the OpenCV convention + +Args: + points (torch.Tensor): [..., N, 3] or [..., N, 4] 3D points to project, if the last + dimension is 4, the points are assumed to be in homogeneous coordinates + extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix + intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix + +Returns: + uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1]. + The origin (0., 0.) is corresponding to the left & top + linear_depth (torch.Tensor): [..., N] linear depth""" + utils3d.torch.transforms.project_cv + +@overload +def unproject_gl(screen_coord: torch_.Tensor, model: torch_.Tensor = None, view: torch_.Tensor = None, perspective: torch_.Tensor = None) -> torch_.Tensor: + """Unproject screen space coordinates to 3D view space following the OpenGL convention (except for row major matrice) + +Args: + screen_coord (torch.Tensor): [... N, 3] screen space coordinates, value ranging in [0, 1]. + The origin (0., 0., 0.) 
is corresponding to the left & bottom & nearest + model (torch.Tensor): [..., 4, 4] model matrix + view (torch.Tensor): [..., 4, 4] view matrix + perspective (torch.Tensor): [..., 4, 4] perspective matrix + +Returns: + points (torch.Tensor): [..., N, 3] 3d points""" + utils3d.torch.transforms.unproject_gl + +@overload +def unproject_cv(uv_coord: torch_.Tensor, depth: torch_.Tensor, extrinsics: torch_.Tensor = None, intrinsics: torch_.Tensor = None) -> torch_.Tensor: + """Unproject uv coordinates to 3D view space following the OpenCV convention + +Args: + uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1]. + The origin (0., 0.) is corresponding to the left & top + depth (torch.Tensor): [..., N] depth value + extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix + intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix + +Returns: + points (torch.Tensor): [..., N, 3] 3d points""" + utils3d.torch.transforms.unproject_cv + +@overload +def skew_symmetric(v: torch_.Tensor): + """Skew symmetric matrix from a 3D vector""" + utils3d.torch.transforms.skew_symmetric + +@overload +def rotation_matrix_from_vectors(v1: torch_.Tensor, v2: torch_.Tensor): + """Rotation matrix that rotates v1 to v2""" + utils3d.torch.transforms.rotation_matrix_from_vectors + +@overload +def euler_axis_angle_rotation(axis: str, angle: torch_.Tensor) -> torch_.Tensor: + """Return the rotation matrices for one of the rotations about an axis +of which Euler angles describe, for each value of the angle given. + +Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + +Returns: + Rotation matrices as tensor of shape (..., 3, 3).""" + utils3d.torch.transforms.euler_axis_angle_rotation + +@overload +def euler_angles_to_matrix(euler_angles: torch_.Tensor, convention: str = 'XYZ') -> torch_.Tensor: + """Convert rotations given as Euler angles in radians to rotation matrices. + +Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3), XYZ + convention: permutation of "X", "Y" or "Z", representing the order of Euler rotations to apply. + +Returns: + Rotation matrices as tensor of shape (..., 3, 3).""" + utils3d.torch.transforms.euler_angles_to_matrix + +@overload +def matrix_to_euler_angles(matrix: torch_.Tensor, convention: str) -> torch_.Tensor: + """Convert rotations given as rotation matrices to Euler angles in radians. +NOTE: The composition order eg. `XYZ` means `Rz * Ry * Rx` (like blender), instead of `Rx * Ry * Rz` (like pytorch3d) + +Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. 
+ +Returns: + Euler angles in radians as tensor of shape (..., 3), in the order of XYZ (like blender), instead of convention (like pytorch3d)""" + utils3d.torch.transforms.matrix_to_euler_angles + +@overload +def matrix_to_quaternion(rot_mat: torch_.Tensor, eps: float = 1e-12) -> torch_.Tensor: + """Convert 3x3 rotation matrix to quaternion (w, x, y, z) + +Args: + rot_mat (torch.Tensor): shape (..., 3, 3), the rotation matrices to convert + +Returns: + torch.Tensor: shape (..., 4), the quaternions corresponding to the given rotation matrices""" + utils3d.torch.transforms.matrix_to_quaternion + +@overload +def quaternion_to_matrix(quaternion: torch_.Tensor, eps: float = 1e-12) -> torch_.Tensor: + """Converts a batch of quaternions (w, x, y, z) to rotation matrices + +Args: + quaternion (torch.Tensor): shape (..., 4), the quaternions to convert + +Returns: + torch.Tensor: shape (..., 3, 3), the rotation matrices corresponding to the given quaternions""" + utils3d.torch.transforms.quaternion_to_matrix + +@overload +def matrix_to_axis_angle(rot_mat: torch_.Tensor, eps: float = 1e-12) -> torch_.Tensor: + """Convert a batch of 3x3 rotation matrices to axis-angle representation (rotation vector) + +Args: + rot_mat (torch.Tensor): shape (..., 3, 3), the rotation matrices to convert + +Returns: + torch.Tensor: shape (..., 3), the axis-angle vectors corresponding to the given rotation matrices""" + utils3d.torch.transforms.matrix_to_axis_angle + +@overload +def axis_angle_to_matrix(axis_angle: torch_.Tensor, eps: float = 1e-12) -> torch_.Tensor: + """Convert axis-angle representation (rotation vector) to rotation matrix, whose direction is the axis of rotation and length is the angle of rotation + +Args: + axis_angle (torch.Tensor): shape (..., 3), axis-angle vcetors + +Returns: + torch.Tensor: shape (..., 3, 3) The rotation matrices for the given axis-angle parameters""" + utils3d.torch.transforms.axis_angle_to_matrix + +@overload +def axis_angle_to_quaternion(axis_angle: torch_.Tensor, eps: float = 1e-12) -> torch_.Tensor: + """Convert axis-angle representation (rotation vector) to quaternion (w, x, y, z) + +Args: + axis_angle (torch.Tensor): shape (..., 3), axis-angle vcetors + +Returns: + torch.Tensor: shape (..., 4) The quaternions for the given axis-angle parameters""" + utils3d.torch.transforms.axis_angle_to_quaternion + +@overload +def quaternion_to_axis_angle(quaternion: torch_.Tensor, eps: float = 1e-12) -> torch_.Tensor: + """Convert a batch of quaternions (w, x, y, z) to axis-angle representation (rotation vector) + +Args: + quaternion (torch.Tensor): shape (..., 4), the quaternions to convert + +Returns: + torch.Tensor: shape (..., 3), the axis-angle vectors corresponding to the given quaternions""" + utils3d.torch.transforms.quaternion_to_axis_angle + +@overload +def slerp(rot_mat_1: torch_.Tensor, rot_mat_2: torch_.Tensor, t: Union[numbers.Number, torch_.Tensor]) -> torch_.Tensor: + """Spherical linear interpolation between two rotation matrices + +Args: + rot_mat_1 (torch.Tensor): shape (..., 3, 3), the first rotation matrix + rot_mat_2 (torch.Tensor): shape (..., 3, 3), the second rotation matrix + t (torch.Tensor): scalar or shape (...,), the interpolation factor + +Returns: + torch.Tensor: shape (..., 3, 3), the interpolated rotation matrix""" + utils3d.torch.transforms.slerp + +@overload +def interpolate_extrinsics(ext1: torch_.Tensor, ext2: torch_.Tensor, t: Union[numbers.Number, torch_.Tensor]) -> torch_.Tensor: + """Interpolate extrinsics between two camera poses. 
Linear interpolation for translation, spherical linear interpolation for rotation. + +Args: + ext1 (torch.Tensor): shape (..., 4, 4), the first camera pose + ext2 (torch.Tensor): shape (..., 4, 4), the second camera pose + t (torch.Tensor): scalar or shape (...,), the interpolation factor + +Returns: + torch.Tensor: shape (..., 4, 4), the interpolated camera pose""" + utils3d.torch.transforms.interpolate_extrinsics + +@overload +def interpolate_view(view1: torch_.Tensor, view2: torch_.Tensor, t: Union[numbers.Number, torch_.Tensor]): + """Interpolate view matrices between two camera poses. Linear interpolation for translation, spherical linear interpolation for rotation. + +Args: + ext1 (torch.Tensor): shape (..., 4, 4), the first camera pose + ext2 (torch.Tensor): shape (..., 4, 4), the second camera pose + t (torch.Tensor): scalar or shape (...,), the interpolation factor + +Returns: + torch.Tensor: shape (..., 4, 4), the interpolated camera pose""" + utils3d.torch.transforms.interpolate_view + +@overload +def extrinsics_to_essential(extrinsics: torch_.Tensor): + """extrinsics matrix `[[R, t] [0, 0, 0, 1]]` such that `x' = R (x - t)` to essential matrix such that `x' E x = 0` + +Args: + extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix + +Returns: + (torch.Tensor): [..., 3, 3] essential matrix""" + utils3d.torch.transforms.extrinsics_to_essential + +@overload +def to4x4(R: torch_.Tensor, t: torch_.Tensor): + """Compose rotation matrix and translation vector to 4x4 transformation matrix + +Args: + R (torch.Tensor): [..., 3, 3] rotation matrix + t (torch.Tensor): [..., 3] translation vector + +Returns: + (torch.Tensor): [..., 4, 4] transformation matrix""" + utils3d.torch.transforms.to4x4 + +@overload +def rotation_matrix_2d(theta: Union[float, torch_.Tensor]): + """2x2 matrix for 2D rotation + +Args: + theta (float | torch.Tensor): rotation angle in radians, arbitrary shape (...,) + +Returns: + (torch.Tensor): (..., 2, 2) rotation matrix""" + utils3d.torch.transforms.rotation_matrix_2d + +@overload +def rotate_2d(theta: Union[float, torch_.Tensor], center: torch_.Tensor = None): + """3x3 matrix for 2D rotation around a center +``` + [[Rxx, Rxy, tx], + [Ryx, Ryy, ty], + [0, 0, 1]] +``` +Args: + theta (float | torch.Tensor): rotation angle in radians, arbitrary shape (...,) + center (torch.Tensor): rotation center, arbitrary shape (..., 2). Default to (0, 0) + +Returns: + (torch.Tensor): (..., 3, 3) transformation matrix""" + utils3d.torch.transforms.rotate_2d + +@overload +def translate_2d(translation: torch_.Tensor): + """Translation matrix for 2D translation +``` + [[1, 0, tx], + [0, 1, ty], + [0, 0, 1]] +``` +Args: + translation (torch.Tensor): translation vector, arbitrary shape (..., 2) + +Returns: + (torch.Tensor): (..., 3, 3) transformation matrix""" + utils3d.torch.transforms.translate_2d + +@overload +def scale_2d(scale: Union[float, torch_.Tensor], center: torch_.Tensor = None): + """Scale matrix for 2D scaling +``` + [[s, 0, tx], + [0, s, ty], + [0, 0, 1]] +``` +Args: + scale (float | torch.Tensor): scale factor, arbitrary shape (...,) + center (torch.Tensor): scale center, arbitrary shape (..., 2). 
Default to (0, 0) + +Returns: + (torch.Tensor): (..., 3, 3) transformation matrix""" + utils3d.torch.transforms.scale_2d + +@overload +def apply_2d(transform: torch_.Tensor, points: torch_.Tensor): + """Apply (3x3 or 2x3) 2D affine transformation to points +``` + p = R @ p + t +``` +Args: + transform (torch.Tensor): (..., 2 or 3, 3) transformation matrix + points (torch.Tensor): (..., N, 2) points to transform + +Returns: + (torch.Tensor): (..., N, 2) transformed points""" + utils3d.torch.transforms.apply_2d + +@overload +def RastContext(nvd_ctx: Union[nvdiffrast.torch.ops.RasterizeCudaContext, nvdiffrast.torch.ops.RasterizeGLContext] = None, *, backend: Literal['cuda', 'gl'] = 'gl', device: Union[str, torch_.device] = None): + """Create a rasterization context. Nothing but a wrapper of nvdiffrast.torch.RasterizeCudaContext or nvdiffrast.torch.RasterizeGLContext.""" + utils3d.torch.rasterization.RastContext + +@overload +def rasterize_triangle_faces(ctx: utils3d.torch.rasterization.RastContext, vertices: torch_.Tensor, faces: torch_.Tensor, attr: torch_.Tensor, width: int, height: int, model: torch_.Tensor = None, view: torch_.Tensor = None, projection: torch_.Tensor = None, antialiasing: Union[bool, List[int]] = True, diff_attrs: Optional[List[int]] = None) -> Tuple[torch_.Tensor, torch_.Tensor, Optional[torch_.Tensor]]: + """Rasterize a mesh with vertex attributes. + +Args: + ctx (GLContext): rasterizer context + vertices (np.ndarray): (B, N, 2 or 3 or 4) + faces (torch.Tensor): (T, 3) + attr (torch.Tensor): (B, N, C) + width (int): width of the output image + height (int): height of the output image + model (torch.Tensor, optional): ([B,] 4, 4) model matrix. Defaults to None (identity). + view (torch.Tensor, optional): ([B,] 4, 4) view matrix. Defaults to None (identity). + projection (torch.Tensor, optional): ([B,] 4, 4) projection matrix. Defaults to None (identity). + antialiasing (Union[bool, List[int]], optional): whether to perform antialiasing. Defaults to True. If a list of indices is provided, only those channels will be antialiased. + diff_attrs (Union[None, List[int]], optional): indices of attributes to compute screen-space derivatives. Defaults to None. + +Returns: + image: (torch.Tensor): (B, C, H, W) + depth: (torch.Tensor): (B, H, W) screen space depth, ranging from 0 (near) to 1. (far) + NOTE: Empty pixels will have depth 1., i.e. far plane.""" + utils3d.torch.rasterization.rasterize_triangle_faces + +@overload +def warp_image_by_depth(ctx: utils3d.torch.rasterization.RastContext, depth: torch_.FloatTensor, image: torch_.FloatTensor = None, mask: torch_.BoolTensor = None, width: int = None, height: int = None, *, extrinsics_src: torch_.FloatTensor = None, extrinsics_tgt: torch_.FloatTensor = None, intrinsics_src: torch_.FloatTensor = None, intrinsics_tgt: torch_.FloatTensor = None, near: float = 0.1, far: float = 100.0, antialiasing: bool = True, backslash: bool = False, padding: int = 0, return_uv: bool = False, return_dr: bool = False) -> Tuple[torch_.FloatTensor, torch_.FloatTensor, torch_.BoolTensor, Optional[torch_.FloatTensor], Optional[torch_.FloatTensor]]: + """Warp image by depth. +NOTE: if batch size is 1, image mesh will be triangulated aware of the depth, yielding less distorted results. +Otherwise, image mesh will be triangulated simply for batch rendering. + +Args: + ctx (Union[dr.RasterizeCudaContext, dr.RasterizeGLContext]): rasterization context + depth (torch.Tensor): (B, H, W) linear depth + image (torch.Tensor): (B, C, H, W). 
None to use image space uv. Defaults to None. + width (int, optional): width of the output image. None to use the same as depth. Defaults to None. + height (int, optional): height of the output image. Defaults the same as depth.. + extrinsics_src (torch.Tensor, optional): (B, 4, 4) extrinsics matrix for source. None to use identity. Defaults to None. + extrinsics_tgt (torch.Tensor, optional): (B, 4, 4) extrinsics matrix for target. None to use identity. Defaults to None. + intrinsics_src (torch.Tensor, optional): (B, 3, 3) intrinsics matrix for source. None to use the same as target. Defaults to None. + intrinsics_tgt (torch.Tensor, optional): (B, 3, 3) intrinsics matrix for target. None to use the same as source. Defaults to None. + near (float, optional): near plane. Defaults to 0.1. + far (float, optional): far plane. Defaults to 100.0. + antialiasing (bool, optional): whether to perform antialiasing. Defaults to True. + backslash (bool, optional): whether to use backslash triangulation. Defaults to False. + padding (int, optional): padding of the image. Defaults to 0. + return_uv (bool, optional): whether to return the uv. Defaults to False. + return_dr (bool, optional): whether to return the image-space derivatives of uv. Defaults to False. + +Returns: + image: (torch.FloatTensor): (B, C, H, W) rendered image + depth: (torch.FloatTensor): (B, H, W) linear depth, ranging from 0 to inf + mask: (torch.BoolTensor): (B, H, W) mask of valid pixels + uv: (torch.FloatTensor): (B, 2, H, W) image-space uv + dr: (torch.FloatTensor): (B, 4, H, W) image-space derivatives of uv""" + utils3d.torch.rasterization.warp_image_by_depth + +@overload +def warp_image_by_forward_flow(ctx: utils3d.torch.rasterization.RastContext, image: torch_.FloatTensor, flow: torch_.FloatTensor, depth: torch_.FloatTensor = None, *, antialiasing: bool = True, backslash: bool = False) -> Tuple[torch_.FloatTensor, torch_.BoolTensor]: + """Warp image by forward flow. +NOTE: if batch size is 1, image mesh will be triangulated aware of the depth, yielding less distorted results. +Otherwise, image mesh will be triangulated simply for batch rendering. + +Args: + ctx (Union[dr.RasterizeCudaContext, dr.RasterizeGLContext]): rasterization context + image (torch.Tensor): (B, C, H, W) image + flow (torch.Tensor): (B, 2, H, W) forward flow + depth (torch.Tensor, optional): (B, H, W) linear depth. If None, will use the same for all pixels. Defaults to None. + antialiasing (bool, optional): whether to perform antialiasing. Defaults to True. + backslash (bool, optional): whether to use backslash triangulation. Defaults to False. 
+ +Returns: + image: (torch.FloatTensor): (B, C, H, W) rendered image + mask: (torch.BoolTensor): (B, H, W) mask of valid pixels""" + utils3d.torch.rasterization.warp_image_by_forward_flow + diff --git a/src/utils3d/io/__init__.py b/src/utils3d/io/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e67b88d2d06b04520fec1cb21b70bdda521eafed --- /dev/null +++ b/src/utils3d/io/__init__.py @@ -0,0 +1,3 @@ +from .obj import * +from .colmap import * +from .ply import * diff --git a/src/utils3d/io/colmap.py b/src/utils3d/io/colmap.py new file mode 100644 index 0000000000000000000000000000000000000000..d00ccbea8b3974e45fc91678000e9083e4ce378b --- /dev/null +++ b/src/utils3d/io/colmap.py @@ -0,0 +1,139 @@ +from typing import * +from pathlib import Path + +import numpy as np +from scipy.spatial.transform import Rotation + + +__all__ = ['read_extrinsics_from_colmap', 'read_intrinsics_from_colmap', 'write_extrinsics_as_colmap', 'write_intrinsics_as_colmap'] + + +def write_extrinsics_as_colmap(file: Union[str, Path], extrinsics: np.ndarray, image_names: Union[str, List[str]] = 'image_{i:04d}.png', camera_ids: List[int] = None): + """ + Write extrinsics to colmap `images.txt` file. + Args: + file: Path to `images.txt` file. + extrinsics: (N, 4, 4) array of extrinsics. + image_names: str or List of str, image names. Length is N. + If str, it should be a format string with `i` as the index. (i starts from 1, in correspondence with IMAGE_ID in colmap) + camera_ids: List of int, camera ids. Length is N. + If None, it will be set to [1, 2, ..., N]. + """ + assert extrinsics.shape[1:] == (4, 4) and extrinsics.ndim == 3 or extrinsics.shape == (4, 4) + if extrinsics.ndim == 2: + extrinsics = extrinsics[np.newaxis, ...] + quats = Rotation.from_matrix(extrinsics[:, :3, :3]).as_quat() + trans = extrinsics[:, :3, 3] + if camera_ids is None: + camera_ids = list(range(1, len(extrinsics) + 1)) + if isinstance(image_names, str): + image_names = [image_names.format(i=i) for i in range(1, len(extrinsics) + 1)] + assert len(extrinsics) == len(image_names) == len(camera_ids), \ + f'Number of extrinsics ({len(extrinsics)}), image_names ({len(image_names)}), and camera_ids ({len(camera_ids)}) must be the same' + with open(file, 'w') as fp: + print("# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME", file=fp) + for i, (quat, t, name, camera_id) in enumerate(zip(quats.tolist(), trans.tolist(), image_names, camera_ids)): + # Colmap has wxyz order while scipy.spatial.transform.Rotation has xyzw order. + qx, qy, qz, qw = quat + tx, ty, tz = t + print(f'{i + 1} {qw:f} {qx:f} {qy:f} {qz:f} {tx:f} {ty:f} {tz:f} {camera_id:d} {name}', file=fp) + print() + + +def write_intrinsics_as_colmap(file: Union[str, Path], intrinsics: np.ndarray, width: int, height: int, normalized: bool = False): + """ + Write intrinsics to colmap `cameras.txt` file. Currently only support PINHOLE model (no distortion) + Args: + file: Path to `cameras.txt` file. + intrinsics: (N, 3, 3) array of intrinsics. + width: Image width. + height: Image height. + normalized: Whether the intrinsics are normalized. If True, the intrinsics will unnormalized for writing. + """ + assert intrinsics.shape[1:] == (3, 3) and intrinsics.ndim == 3 or intrinsics.shape == (3, 3) + if intrinsics.ndim == 2: + intrinsics = intrinsics[np.newaxis, ...] 
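+    # If the intrinsics were given normalized to [0, 1] (uv) coordinates, rescale the first
+    # row by the image width and the second row by the image height so that fx, cx and
+    # fy, cy are written in pixel units for the PINHOLE camera model below.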
+ if normalized: + intrinsics = intrinsics * np.array([width, height, 1])[:, None] + with open(file, 'w') as fp: + print("# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]", file=fp) + for i, intr in enumerate(intrinsics): + fx, fy, cx, cy = intr[0, 0], intr[1, 1], intr[0, 2], intr[1, 2] + print(f'{i + 1} PINHOLE {width:d} {height:d} {fx:f} {fy:f} {cx:f} {cy:f}', file=fp) + + +def read_extrinsics_from_colmap(file: Union[str, Path]) -> Union[np.ndarray, List[int], List[str]]: + """ + Read extrinsics from colmap `images.txt` file. + Args: + file: Path to `images.txt` file. + Returns: + extrinsics: (N, 4, 4) array of extrinsics. + camera_ids: List of int, camera ids. Length is N. Note that camera ids in colmap typically starts from 1. + image_names: List of str, image names. Length is N. + """ + with open(file) as fp: + lines = fp.readlines() + image_names, quats, trans, camera_ids = [], [], [], [] + i_line = 0 + for line in lines: + line = line.strip() + if line.startswith('#'): + continue + i_line += 1 + if i_line % 2 == 0: + continue + image_id, qw, qx, qy, qz, tx, ty, tz, camera_id, name = line.split() + quats.append([float(qx), float(qy), float(qz), float(qw)]) + trans.append([float(tx), float(ty), float(tz)]) + camera_ids.append(int(camera_id)) + image_names.append(name) + + quats = np.array(quats, dtype=np.float32) + trans = np.array(trans, dtype=np.float32) + rotation = Rotation.from_quat(quats).as_matrix() + extrinsics = np.concatenate([ + np.concatenate([rotation, trans[..., None]], axis=-1), + np.array([0, 0, 0, 1], dtype=np.float32)[None, None, :].repeat(len(quats), axis=0) + ], axis=-2) + + return extrinsics, camera_ids, image_names + + +def read_intrinsics_from_colmap(file: Union[str, Path], normalize: bool = False) -> Tuple[List[int], np.ndarray, np.ndarray]: + """ + Read intrinsics from colmap `cameras.txt` file. + Args: + file: Path to `cameras.txt` file. + normalize: Whether to normalize the intrinsics. If True, the intrinsics will be normalized. (mapping coordinates to [0, 1] range) + Returns: + camera_ids: List of int, camera ids. Length is N. Note that camera ids in colmap typically starts from 1. + intrinsics: (N, 3, 3) array of intrinsics. + distortions: (N, 5) array of distortions. 
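+
+    Example (illustrative):
+        camera_ids, intrinsics, distortions = read_intrinsics_from_colmap('cameras.txt', normalize=True)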
+ """ + with open(file) as fp: + lines = fp.readlines() + intrinsics, distortions, camera_ids = [], [], [] + for line in lines: + line = line.strip() + if not line or line.startswith('#'): + continue + camera_id, model, width, height, *params = line.split() + camera_id, width, height = int(camera_id), int(width), int(height) + if model == 'PINHOLE': + fx, fy, cx, cy = map(float, params[:4]) + k1 = k2 = k3 = p1 = p2 = 0.0 + elif model == 'OPENCV': + fx, fy, cx, cy, k1, k2, p1, p2, k3 = *map(float, params[:8]), 0.0 + elif model == 'SIMPLE_RADIAL': + f, cx, cy, k = map(float, params[:4]) + fx = fy = f + k1, k2, p1, p2, k3 = k, 0.0, 0.0, 0.0, 0.0 + camera_ids.append(camera_id) + if normalize: + fx, fy, cx, cy = fx / width, fy / height, cx / width, cy / height + intrinsics.append([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) + distortions.append([k1, k2, p1, p2, k3]) + intrinsics = np.array(intrinsics, dtype=np.float32) + distortions = np.array(distortions, dtype=np.float32) + return camera_ids, intrinsics, distortions diff --git a/src/utils3d/io/obj.py b/src/utils3d/io/obj.py new file mode 100644 index 0000000000000000000000000000000000000000..3471e490bd758fbf58173cfb7297ec747f46f173 --- /dev/null +++ b/src/utils3d/io/obj.py @@ -0,0 +1,146 @@ +from io import TextIOWrapper +from typing import Dict, Any, Union, Iterable +import numpy as np +from pathlib import Path + +__all__ = [ + 'read_obj', + 'write_obj', + 'simple_write_obj' +] + +def read_obj( + file : Union[str, Path, TextIOWrapper], + encoding: Union[str, None] = None, + ignore_unknown: bool = False +): + """ + Read wavefront .obj file, without preprocessing. + + Why bothering having this read_obj() while we already have other libraries like `trimesh`? + This function read the raw format from .obj file and keeps the order of vertices and faces, + while trimesh which involves modification like merge/split vertices, which could break the orders of vertices and faces, + Those libraries are commonly aiming at geometry processing and rendering supporting various formats. + If you want mesh geometry processing, you may turn to `trimesh` for more features. 
+ + ### Parameters + `file` (str, Path, TextIOWrapper): filepath or file object + encoding (str, optional): + + ### Returns + obj (dict): A dict containing .obj components + { + 'mtllib': [], + 'v': [[0,1, 0.2, 1.0], [1.2, 0.0, 0.0], ...], + 'vt': [[0.5, 0.5], ...], + 'vn': [[0., 0.7, 0.7], [0., -0.7, 0.7], ...], + 'f': [[0, 1, 2], [2, 3, 4],...], + 'usemtl': [{'name': 'mtl1', 'f': 7}] + } + """ + if hasattr(file,'read'): + lines = file.read().splitlines() + else: + with open(file, 'r', encoding=encoding) as fp: + lines = fp.read().splitlines() + mtllib = [] + v, vt, vn, vp = [], [], [], [] # Vertex coordinates, Vertex texture coordinate, Vertex normal, Vertex parameter + f, ft, fn = [], [], [] # Face indices, Face texture indices, Face normal indices + o = [] + s = [] + usemtl = [] + + def pad(l: list, n: Any): + return l + [n] * (3 - len(l)) + + for i, line in enumerate(lines): + sq = line.strip().split() + if len(sq) == 0: + continue + if sq[0] == 'v': + assert 4 <= len(sq) <= 5, f'Invalid format of line {i}: {line}' + v.append([float(e) for e in sq[1:]][:3]) + elif sq[0] == 'vt': + assert 3 <= len(sq) <= 4, f'Invalid format of line {i}: {line}' + vt.append([float(e) for e in sq[1:]][:2]) + elif sq[0] == 'vn': + assert len(sq) == 4, f'Invalid format of line {i}: {line}' + vn.append([float(e) for e in sq[1:]]) + elif sq[0] == 'vp': + assert 2 <= len(sq) <= 4, f'Invalid format of line {i}: {line}' + vp.append(pad([float(e) for e in sq[1:]], 0)) + elif sq[0] == 'f': + spliting = [pad([int(j) - 1 for j in e.split('/')], -1) for e in sq[1:]] + f.append([e[0] for e in spliting]) + ft.append([e[1] for e in spliting]) + fn.append([e[2] for e in spliting]) + elif sq[0] == 'usemtl': + assert len(sq) == 2 + usemtl.append((sq[1], len(f))) + elif sq[0] == 'o': + assert len(sq) == 2 + o.append((sq[1], len(f))) + elif sq[0] == 's': + s.append((sq[1], len(f))) + elif sq[0] == 'mtllib': + assert len(sq) == 2 + mtllib.append(sq[1]) + elif sq[0][0] == '#': + continue + else: + if not ignore_unknown: + raise Exception(f'Unknown keyword {sq[0]}') + + min_poly_vertices = min(len(f) for f in f) + max_poly_vertices = max(len(f) for f in f) + + return { + 'mtllib': mtllib, + 'v': np.array(v, dtype=np.float32), + 'vt': np.array(vt, dtype=np.float32), + 'vn': np.array(vn, dtype=np.float32), + 'vp': np.array(vp, dtype=np.float32), + 'f': np.array(f, dtype=np.int32) if min_poly_vertices == max_poly_vertices else f, + 'ft': np.array(ft, dtype=np.int32) if min_poly_vertices == max_poly_vertices else ft, + 'fn': np.array(fn, dtype=np.int32) if min_poly_vertices == max_poly_vertices else fn, + 'o': o, + 's': s, + 'usemtl': usemtl, + } + + +def write_obj( + file: Union[str, Path], + obj: Dict[str, Any], + encoding: Union[str, None] = None + ): + with open(file, 'w', encoding=encoding) as fp: + for k in ['v', 'vt', 'vn', 'vp']: + if k not in obj: + continue + for v in obj[k]: + print(k, *map(float, v), file=fp) + for f in obj['f']: + print('f', *((str('/').join(map(int, i)) if isinstance(int(i), Iterable) else i) for i in f), file=fp) + + +def simple_write_obj( + file: Union[str, Path], + vertices: np.ndarray, + faces: np.ndarray, + encoding: Union[str, None] = None + ): + """ + Write wavefront .obj file, without preprocessing. 
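+
+    Unlike `write_obj`, this writes only vertex positions and face indices (1-based), with no
+    texture coordinates, normals, or material statements.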
+ + Args: + vertices (np.ndarray): [N, 3] + faces (np.ndarray): [T, 3] + file (Any): filepath + encoding (str, optional): + """ + with open(file, 'w', encoding=encoding) as fp: + for v in vertices: + print('v', *map(float, v), file=fp) + for f in faces: + print('f', *map(int, f + 1), file=fp) diff --git a/src/utils3d/io/ply.py b/src/utils3d/io/ply.py new file mode 100644 index 0000000000000000000000000000000000000000..39fa41728a7be76d25743788c85dacb384d6d83e --- /dev/null +++ b/src/utils3d/io/ply.py @@ -0,0 +1,104 @@ +import numpy as np + +from typing import * +from pathlib import Path + + +def read_ply( + file: Union[str, Path], + encoding: Union[str, None] = None, + ignore_unknown: bool = False +) -> Tuple[np.ndarray, np.ndarray]: + """ + Read .ply file, without preprocessing. + + Args: + file (Any): filepath + encoding (str, optional): + + Returns: + Tuple[np.ndarray, np.ndarray]: vertices, faces + """ + import plyfile + plydata = plyfile.PlyData.read(file) + vertices = np.stack([plydata['vertex'][k] for k in ['x', 'y', 'z']], axis=-1) + if 'face' in plydata: + faces = np.array(plydata['face']['vertex_indices'].tolist()) + else: + faces = None + return vertices, faces + + +def write_ply( + file: Union[str, Path], + vertices: np.ndarray, + faces: np.ndarray = None, + edges: np.ndarray = None, + vertex_colors: np.ndarray = None, + edge_colors: np.ndarray = None, + text: bool = False +): + """ + Write .ply file, without preprocessing. + + Args: + file (Any): filepath + vertices (np.ndarray): [N, 3] + faces (np.ndarray): [T, E] + edges (np.ndarray): [E, 2] + vertex_colors (np.ndarray, optional): [N, 3]. Defaults to None. + edge_colors (np.ndarray, optional): [E, 3]. Defaults to None. + text (bool, optional): save data in text format. Defaults to False. 
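+
+    Example (illustrative; `vertices`, `faces`, and `colors` are assumed to be arrays with the shapes above):
+        write_ply('mesh.ply', vertices, faces, vertex_colors=colors, text=True)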
+ """ + import plyfile + assert vertices.ndim == 2 and vertices.shape[1] == 3 + vertices = vertices.astype(np.float32) + if faces is not None: + assert faces.ndim == 2 + faces = faces.astype(np.int32) + if edges is not None: + assert edges.ndim == 2 and edges.shape[1] == 2 + edges = edges.astype(np.int32) + + if vertex_colors is not None: + assert vertex_colors.ndim == 2 and vertex_colors.shape[1] == 3 + if vertex_colors.dtype in [np.float32, np.float64]: + vertex_colors = vertex_colors * 255 + vertex_colors = np.clip(vertex_colors, 0, 255).astype(np.uint8) + vertices_data = np.zeros(len(vertices), dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]) + vertices_data['x'] = vertices[:, 0] + vertices_data['y'] = vertices[:, 1] + vertices_data['z'] = vertices[:, 2] + vertices_data['red'] = vertex_colors[:, 0] + vertices_data['green'] = vertex_colors[:, 1] + vertices_data['blue'] = vertex_colors[:, 2] + else: + vertices_data = np.array([tuple(v) for v in vertices], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')]) + + if faces is not None: + faces_data = np.zeros(len(faces), dtype=[('vertex_indices', 'i4', (faces.shape[1],))]) + faces_data['vertex_indices'] = faces + + if edges is not None: + if edge_colors is not None: + assert edge_colors.ndim == 2 and edge_colors.shape[1] == 3 + if edge_colors.dtype in [np.float32, np.float64]: + edge_colors = edge_colors * 255 + edge_colors = np.clip(edge_colors, 0, 255).astype(np.uint8) + edges_data = np.zeros(len(edges), dtype=[('vertex1', 'i4'), ('vertex2', 'i4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]) + edges_data['vertex1'] = edges[:, 0] + edges_data['vertex2'] = edges[:, 1] + edges_data['red'] = edge_colors[:, 0] + edges_data['green'] = edge_colors[:, 1] + edges_data['blue'] = edge_colors[:, 2] + else: + edges_data = np.array([tuple(e) for e in edges], dtype=[('vertex1', 'i4'), ('vertex2', 'i4')]) + + ply_data = [plyfile.PlyElement.describe(vertices_data, 'vertex')] + if faces is not None: + ply_data.append(plyfile.PlyElement.describe(faces_data, 'face')) + if edges is not None: + ply_data.append(plyfile.PlyElement.describe(edges_data, 'edge')) + + plyfile.PlyData(ply_data, text=text).write(file) + \ No newline at end of file diff --git a/src/utils3d/numpy/__init__.py b/src/utils3d/numpy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aa06c53abe3b4abd39f1d7f8372851d7cdc58260 --- /dev/null +++ b/src/utils3d/numpy/__init__.py @@ -0,0 +1,142 @@ +""" +3D utility functions workings with NumPy. 
+""" +import importlib +import itertools +import numpy +from typing import TYPE_CHECKING + + +__modules_all__ = { + 'mesh':[ + 'triangulate', + 'compute_face_normal', + 'compute_face_angle', + 'compute_vertex_normal', + 'compute_vertex_normal_weighted', + 'remove_corrupted_faces', + 'merge_duplicate_vertices', + 'remove_unreferenced_vertices', + 'subdivide_mesh_simple', + 'mesh_relations', + 'flatten_mesh_indices' + ], + 'quadmesh': [ + 'calc_quad_candidates', + 'calc_quad_distortion', + 'calc_quad_direction', + 'calc_quad_smoothness', + 'sovle_quad', + 'sovle_quad_qp', + 'tri_to_quad' + ], + 'utils': [ + 'sliding_window_1d', + 'sliding_window_nd', + 'sliding_window_2d', + 'max_pool_1d', + 'max_pool_2d', + 'max_pool_nd', + 'depth_edge', + 'normals_edge', + 'depth_aliasing', + 'interpolate', + 'image_scrcoord', + 'image_uv', + 'image_pixel_center', + 'image_pixel', + 'image_mesh', + 'image_mesh_from_depth', + 'depth_to_normals', + 'points_to_normals', + 'chessboard', + 'cube', + 'icosahedron', + 'square', + 'camera_frustum', + ], + 'transforms': [ + 'perspective', + 'perspective_from_fov', + 'perspective_from_fov_xy', + 'intrinsics_from_focal_center', + 'intrinsics_from_fov', + 'fov_to_focal', + 'focal_to_fov', + 'intrinsics_to_fov', + 'view_look_at', + 'extrinsics_look_at', + 'perspective_to_intrinsics', + 'perspective_to_near_far', + 'intrinsics_to_perspective', + 'extrinsics_to_view', + 'view_to_extrinsics', + 'normalize_intrinsics', + 'crop_intrinsics', + 'pixel_to_uv', + 'pixel_to_ndc', + 'uv_to_pixel', + 'project_depth', + 'depth_buffer_to_linear', + 'unproject_cv', + 'unproject_gl', + 'project_cv', + 'project_gl', + 'quaternion_to_matrix', + 'axis_angle_to_matrix', + 'matrix_to_quaternion', + 'extrinsics_to_essential', + 'euler_axis_angle_rotation', + 'euler_angles_to_matrix', + 'skew_symmetric', + 'rotation_matrix_from_vectors', + 'ray_intersection', + 'se3_matrix', + 'slerp_quaternion', + 'slerp_vector', + 'lerp', + 'lerp_se3_matrix', + 'piecewise_lerp', + 'piecewise_lerp_se3_matrix', + 'apply_transform' + ], + 'spline': [ + 'linear_spline_interpolate', + ], + 'rasterization': [ + 'RastContext', + 'rasterize_triangle_faces', + 'rasterize_edges', + 'texture', + 'warp_image_by_depth', + 'test_rasterization' + ], +} + + +__all__ = list(itertools.chain(*__modules_all__.values())) + +def __getattr__(name): + try: + return globals()[name] + except KeyError: + pass + + try: + module_name = next(m for m in __modules_all__ if name in __modules_all__[m]) + except StopIteration: + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + module = importlib.import_module(f'.{module_name}', __name__) + for key in __modules_all__[module_name]: + globals()[key] = getattr(module, key) + + return globals()[name] + + +if TYPE_CHECKING: + from .quadmesh import * + from .transforms import * + from .mesh import * + from .utils import * + from .rasterization import * + from .spline import * \ No newline at end of file diff --git a/src/utils3d/numpy/_helpers.py b/src/utils3d/numpy/_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..7c397df338e3e04e0f228341f68171d8e067eb4e --- /dev/null +++ b/src/utils3d/numpy/_helpers.py @@ -0,0 +1,93 @@ +# decorator +import numpy as np +from numbers import Number +import inspect +from functools import wraps +from typing import * +from .._helpers import suppress_traceback + + +def get_args_order(func, args, kwargs): + """ + Get the order of the arguments of a function. 
+ """ + names = inspect.getfullargspec(func).args + names_idx = {name: i for i, name in enumerate(names)} + args_order = [] + kwargs_order = {} + for name, arg in kwargs.items(): + if name in names: + kwargs_order[name] = names_idx[name] + names.remove(name) + for i, arg in enumerate(args): + if i < len(names): + args_order.append(names_idx[names[i]]) + return args_order, kwargs_order + + +def broadcast_args(args, kwargs, args_dim, kwargs_dim): + spatial = [] + for arg, arg_dim in zip(args + list(kwargs.values()), args_dim + list(kwargs_dim.values())): + if isinstance(arg, np.ndarray) and arg_dim is not None: + arg_spatial = arg.shape[:arg.ndim-arg_dim] + if len(arg_spatial) > len(spatial): + spatial = [1] * (len(arg_spatial) - len(spatial)) + spatial + for j in range(len(arg_spatial)): + if spatial[-j] < arg_spatial[-j]: + if spatial[-j] == 1: + spatial[-j] = arg_spatial[-j] + else: + raise ValueError("Cannot broadcast arguments.") + for i, arg in enumerate(args): + if isinstance(arg, np.ndarray) and args_dim[i] is not None: + args[i] = np.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-args_dim[i]:]]) + for key, arg in kwargs.items(): + if isinstance(arg, np.ndarray) and kwargs_dim[key] is not None: + kwargs[key] = np.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-kwargs_dim[key]:]]) + return args, kwargs, spatial + + +def batched(*dims): + """ + Decorator that allows a function to be called with batched arguments. + """ + def decorator(func): + @wraps(func) + @suppress_traceback + def wrapper(*args, **kwargs): + args = list(args) + # get arguments dimensions + args_order, kwargs_order = get_args_order(func, args, kwargs) + args_dim = [dims[i] for i in args_order] + kwargs_dim = {key: dims[i] for key, i in kwargs_order.items()} + # convert to numpy array + for i, arg in enumerate(args): + if isinstance(arg, (Number, list, tuple)) and args_dim[i] is not None: + args[i] = np.array(arg) + for key, arg in kwargs.items(): + if isinstance(arg, (Number, list, tuple)) and kwargs_dim[key] is not None: + kwargs[key] = np.array(arg) + # broadcast arguments + args, kwargs, spatial = broadcast_args(args, kwargs, args_dim, kwargs_dim) + for i, (arg, arg_dim) in enumerate(zip(args, args_dim)): + if isinstance(arg, np.ndarray) and arg_dim is not None: + args[i] = arg.reshape([-1, *arg.shape[arg.ndim-arg_dim:]]) + for key, arg in kwargs.items(): + if isinstance(arg, np.ndarray) and kwargs_dim[key] is not None: + kwargs[key] = arg.reshape([-1, *arg.shape[arg.ndim-kwargs_dim[key]:]]) + # call function + results = func(*args, **kwargs) + type_results = type(results) + results = list(results) if isinstance(results, (tuple, list)) else [results] + # restore spatial dimensions + for i, result in enumerate(results): + results[i] = result.reshape([*spatial, *result.shape[1:]]) + if type_results == tuple: + results = tuple(results) + elif type_results == list: + results = list(results) + else: + results = results[0] + return results + return wrapper + return decorator diff --git a/src/utils3d/numpy/mesh.py b/src/utils3d/numpy/mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..afadb5f2510b58a1c5acbabff2ff798c041744d6 --- /dev/null +++ b/src/utils3d/numpy/mesh.py @@ -0,0 +1,355 @@ +import numpy as np +from typing import * +from ._helpers import batched + + +__all__ = [ + 'triangulate', + 'compute_face_normal', + 'compute_face_angle', + 'compute_vertex_normal', + 'compute_vertex_normal_weighted', + 'remove_corrupted_faces', + 'merge_duplicate_vertices', + 'remove_unreferenced_vertices', 
+ 'subdivide_mesh_simple', + 'mesh_relations', + 'flatten_mesh_indices' +] + + +def triangulate( + faces: np.ndarray, + vertices: np.ndarray = None, + backslash: np.ndarray = None +) -> np.ndarray: + """ + Triangulate a polygonal mesh. + + Args: + faces (np.ndarray): [L, P] polygonal faces + vertices (np.ndarray, optional): [N, 3] 3-dimensional vertices. + If given, the triangulation is performed according to the distance + between vertices. Defaults to None. + backslash (np.ndarray, optional): [L] boolean array indicating + how to triangulate the quad faces. Defaults to None. + + Returns: + (np.ndarray): [L * (P - 2), 3] triangular faces + """ + if faces.shape[-1] == 3: + return faces + P = faces.shape[-1] + if vertices is not None: + assert faces.shape[-1] == 4, "now only support quad mesh" + if backslash is None: + backslash = np.linalg.norm(vertices[faces[:, 0]] - vertices[faces[:, 2]], axis=-1) < \ + np.linalg.norm(vertices[faces[:, 1]] - vertices[faces[:, 3]], axis=-1) + if backslash is None: + loop_indice = np.stack([ + np.zeros(P - 2, dtype=int), + np.arange(1, P - 1, 1, dtype=int), + np.arange(2, P, 1, dtype=int) + ], axis=1) + return faces[:, loop_indice].reshape((-1, 3)) + else: + assert faces.shape[-1] == 4, "now only support quad mesh" + faces = np.where( + backslash[:, None], + faces[:, [0, 1, 2, 0, 2, 3]], + faces[:, [0, 1, 3, 3, 1, 2]] + ).reshape((-1, 3)) + return faces + + +@batched(2, None) +def compute_face_normal( + vertices: np.ndarray, + faces: np.ndarray +) -> np.ndarray: + """ + Compute face normals of a triangular mesh + + Args: + vertices (np.ndarray): [..., N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + + Returns: + normals (np.ndarray): [..., T, 3] face normals + """ + normal = np.cross( + vertices[..., faces[:, 1], :] - vertices[..., faces[:, 0], :], + vertices[..., faces[:, 2], :] - vertices[..., faces[:, 0], :] + ) + normal_norm = np.linalg.norm(normal, axis=-1, keepdims=True) + normal_norm[normal_norm == 0] = 1 + normal /= normal_norm + return normal + + +@batched(2, None) +def compute_face_angle( + vertices: np.ndarray, + faces: np.ndarray, + eps: float = 1e-12 + ) -> np.ndarray: + """ + Compute face angles of a triangular mesh + + Args: + vertices (np.ndarray): [..., N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + + Returns: + angles (np.ndarray): [..., T, 3] face angles + """ + face_angle = np.zeros_like(faces, dtype=vertices.dtype) + for i in range(3): + edge1 = vertices[..., faces[:, (i + 1) % 3], :] - vertices[..., faces[:, i], :] + edge2 = vertices[..., faces[:, (i + 2) % 3], :] - vertices[..., faces[:, i], :] + face_angle[..., i] = np.arccos(np.sum( + edge1 / np.clip(np.linalg.norm(edge1, axis=-1, keepdims=True), eps, None) * + edge2 / np.clip(np.linalg.norm(edge2, axis=-1, keepdims=True), eps, None), + axis=-1 + )) + return face_angle + + +@batched(2, None, 2) +def compute_vertex_normal( + vertices: np.ndarray, + faces: np.ndarray, + face_normal: np.ndarray = None +) -> np.ndarray: + """ + Compute vertex normals of a triangular mesh by averaging neightboring face normals + TODO: can be improved. + + Args: + vertices (np.ndarray): [..., N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + face_normal (np.ndarray, optional): [..., T, 3] face normals. + None to compute face normals from vertices and faces. Defaults to None. 
+ + Returns: + normals (np.ndarray): [..., N, 3] vertex normals + """ + if face_normal is None: + face_normal = compute_face_normal(vertices, faces) + vertex_normal = np.zeros_like(vertices, dtype=vertices.dtype) + for n in range(vertices.shape[0]): + for i in range(3): + vertex_normal[n, :, 0] += np.bincount(faces[:, i], weights=face_normal[n, :, 0], minlength=vertices.shape[1]) + vertex_normal[n, :, 1] += np.bincount(faces[:, i], weights=face_normal[n, :, 1], minlength=vertices.shape[1]) + vertex_normal[n, :, 2] += np.bincount(faces[:, i], weights=face_normal[n, :, 2], minlength=vertices.shape[1]) + vertex_normal_norm = np.linalg.norm(vertex_normal, axis=-1, keepdims=True) + vertex_normal_norm[vertex_normal_norm == 0] = 1 + vertex_normal /= vertex_normal_norm + return vertex_normal + + +@batched(2, None, 2) +def compute_vertex_normal_weighted( + vertices: np.ndarray, + faces: np.ndarray, + face_normal: np.ndarray = None +) -> np.ndarray: + """ + Compute vertex normals of a triangular mesh by weighted sum of neightboring face normals + according to the angles + + Args: + vertices (np.ndarray): [..., N, 3] 3-dimensional vertices + faces (np.ndarray): [..., T, 3] triangular face indices + face_normal (np.ndarray, optional): [..., T, 3] face normals. + None to compute face normals from vertices and faces. Defaults to None. + + Returns: + normals (np.ndarray): [..., N, 3] vertex normals + """ + if face_normal is None: + face_normal = compute_face_normal(vertices, faces) + face_angle = compute_face_angle(vertices, faces) + vertex_normal = np.zeros_like(vertices) + for n in range(vertices.shape[0]): + for i in range(3): + vertex_normal[n, :, 0] += np.bincount(faces[n, :, i], weights=face_normal[n, :, 0] * face_angle[n, :, i], minlength=vertices.shape[1]) + vertex_normal[n, :, 1] += np.bincount(faces[n, :, i], weights=face_normal[n, :, 1] * face_angle[n, :, i], minlength=vertices.shape[1]) + vertex_normal[n, :, 2] += np.bincount(faces[n, :, i], weights=face_normal[n, :, 2] * face_angle[n, :, i], minlength=vertices.shape[1]) + vertex_normal_norm = np.linalg.norm(vertex_normal, axis=-1, keepdims=True) + vertex_normal_norm[vertex_normal_norm == 0] = 1 + vertex_normal /= vertex_normal_norm + return vertex_normal + + +def remove_corrupted_faces( + faces: np.ndarray + ) -> np.ndarray: + """ + Remove corrupted faces (faces with duplicated vertices) + + Args: + faces (np.ndarray): [T, 3] triangular face indices + + Returns: + np.ndarray: [T_, 3] triangular face indices + """ + corrupted = (faces[:, 0] == faces[:, 1]) | (faces[:, 1] == faces[:, 2]) | (faces[:, 2] == faces[:, 0]) + return faces[~corrupted] + + +def merge_duplicate_vertices( + vertices: np.ndarray, + faces: np.ndarray, + tol: float = 1e-6 + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Merge duplicate vertices of a triangular mesh. + Duplicate vertices are merged by selecte one of them, and the face indices are updated accordingly. + + Args: + vertices (np.ndarray): [N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + tol (float, optional): tolerance for merging. Defaults to 1e-6. 
+ + Returns: + vertices (np.ndarray): [N_, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + """ + vertices_round = np.round(vertices / tol) + _, uni_i, uni_inv = np.unique(vertices_round, return_index=True, return_inverse=True, axis=0) + vertices = vertices[uni_i] + faces = uni_inv[faces] + return vertices, faces + + +def remove_unreferenced_vertices( + faces: np.ndarray, + *vertice_attrs, + return_indices: bool = False +) -> Tuple[np.ndarray, ...]: + """ + Remove unreferenced vertices of a mesh. + Unreferenced vertices are removed, and the face indices are updated accordingly. + + Args: + faces (np.ndarray): [T, P] face indices + *vertice_attrs: vertex attributes + + Returns: + faces (np.ndarray): [T, P] face indices + *vertice_attrs: vertex attributes + indices (np.ndarray, optional): [N] indices of vertices that are kept. Defaults to None. + """ + P = faces.shape[-1] + fewer_indices, inv_map = np.unique(faces, return_inverse=True) + faces = inv_map.astype(np.int32).reshape(-1, P) + ret = [faces] + for attr in vertice_attrs: + ret.append(attr[fewer_indices]) + if return_indices: + ret.append(fewer_indices) + return tuple(ret) + + +def subdivide_mesh_simple( + vertices: np.ndarray, + faces: np.ndarray, + n: int = 1 +) -> Tuple[np.ndarray, np.ndarray]: + """ + Subdivide a triangular mesh by splitting each triangle into 4 smaller triangles. + NOTE: All original vertices are kept, and new vertices are appended to the end of the vertex list. + + Args: + vertices (np.ndarray): [N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + n (int, optional): number of subdivisions. Defaults to 1. + + Returns: + vertices (np.ndarray): [N_, 3] subdivided 3-dimensional vertices + faces (np.ndarray): [4 * T, 3] subdivided triangular face indices + """ + for _ in range(n): + edges = np.stack([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], axis=0) + edges = np.sort(edges, axis=2) + uni_edges, uni_inv = np.unique(edges.reshape(-1, 2), return_inverse=True, axis=0) + uni_inv = uni_inv.reshape(3, -1) + midpoints = (vertices[uni_edges[:, 0]] + vertices[uni_edges[:, 1]]) / 2 + + n_vertices = vertices.shape[0] + vertices = np.concatenate([vertices, midpoints], axis=0) + faces = np.concatenate([ + np.stack([faces[:, 0], n_vertices + uni_inv[0], n_vertices + uni_inv[2]], axis=1), + np.stack([faces[:, 1], n_vertices + uni_inv[1], n_vertices + uni_inv[0]], axis=1), + np.stack([faces[:, 2], n_vertices + uni_inv[2], n_vertices + uni_inv[1]], axis=1), + np.stack([n_vertices + uni_inv[0], n_vertices + uni_inv[1], n_vertices + uni_inv[2]], axis=1), + ], axis=0) + return vertices, faces + + +def mesh_relations( + faces: np.ndarray, +) -> Tuple[np.ndarray, np.ndarray]: + """ + Calculate the relation between vertices and faces. + NOTE: The input mesh must be a manifold triangle mesh. + + Args: + faces (np.ndarray): [T, 3] triangular face indices + + Returns: + edges (np.ndarray): [E, 2] edge indices + edge2face (np.ndarray): [E, 2] edge to face relation. The second column is -1 if the edge is boundary. 
+ face2edge (np.ndarray): [T, 3] face to edge relation + face2face (np.ndarray): [T, 3] face to face relation + """ + T = faces.shape[0] + edges = np.stack([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], axis=1).reshape(-1, 2) # [3T, 2] + edges = np.sort(edges, axis=1) # [3T, 2] + edges, face2edge, occurence = np.unique(edges, axis=0, return_inverse=True, return_counts=True) # [E, 2], [3T], [E] + E = edges.shape[0] + assert np.all(occurence <= 2), "The input mesh is not a manifold mesh." + + # Edge to face relation + padding = np.arange(E, dtype=np.int32)[occurence == 1] + padded_face2edge = np.concatenate([face2edge, padding], axis=0) # [2E] + edge2face = np.argsort(padded_face2edge, kind='stable').reshape(-1, 2) // 3 # [E, 2] + edge2face_valid = edge2face[:, 1] < T # [E] + edge2face[~edge2face_valid, 1] = -1 + + # Face to edge relation + face2edge = face2edge.reshape(-1, 3) # [T, 3] + + # Face to face relation + face2face = edge2face[face2edge] # [T, 3, 2] + face2face = face2face[face2face != np.arange(T)[:, None, None]].reshape(T, 3) # [T, 3] + + return edges, edge2face, face2edge, face2face + + +@overload +def flatten_mesh_indices(faces1: np.ndarray, attr1: np.ndarray, *other_faces_attrs_pairs: np.ndarray) -> Tuple[np.ndarray, ...]: + """ + Rearrange the indices of a mesh to a flattened version. Vertices will be no longer shared. + + ### Parameters: + - `faces1`: [T, P] face indices of the first attribute + - `attr1`: [N1, ...] attributes of the first mesh + - ... + + ### Returns: + - `faces`: [T, P] flattened face indices, contigous from 0 to T * P - 1 + - `attr1`: [T * P, ...] attributes of the first mesh, where every P values correspond to a face + _ ... + """ +def flatten_mesh_indices(*args: np.ndarray) -> Tuple[np.ndarray, ...]: + assert len(args) % 2 == 0, "The number of arguments must be even." + T, P = args[0].shape + assert all(arg.shape[0] == T and arg.shape[1] == P for arg in args[::2]), "The faces must have the same shape." + attr_flat = [] + for faces_, attr_ in zip(args[::2], args[1::2]): + attr_flat_ = attr_[faces_].reshape(-1, *attr_.shape[1:]) + attr_flat.append(attr_flat_) + faces_flat = np.arange(T * P, dtype=np.int32).reshape(T, P) + return faces_flat, *attr_flat \ No newline at end of file diff --git a/src/utils3d/numpy/quadmesh.py b/src/utils3d/numpy/quadmesh.py new file mode 100644 index 0000000000000000000000000000000000000000..6728d91124020767cc9b3c1fdd6b21d50dc55828 --- /dev/null +++ b/src/utils3d/numpy/quadmesh.py @@ -0,0 +1,472 @@ +import numpy as np +import scipy as sp +import scipy.optimize as spopt +from typing import * + + +__all__ = [ + 'calc_quad_candidates', + 'calc_quad_distortion', + 'calc_quad_direction', + 'calc_quad_smoothness', + 'sovle_quad', + 'sovle_quad_qp', + 'tri_to_quad' +] + + +def calc_quad_candidates( + edges: np.ndarray, + face2edge: np.ndarray, + edge2face: np.ndarray, +): + """ + Calculate the candidate quad faces. 
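+    Each non-boundary edge of the triangle mesh yields one candidate quad, obtained by
+    merging the two triangles sharing that edge, so candidates are indexed by interior edges.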
+ + Args: + edges (np.ndarray): [E, 2] edge indices + face2edge (np.ndarray): [T, 3] face to edge relation + edge2face (np.ndarray): [E, 2] edge to face relation + + Returns: + quads (np.ndarray): [Q, 4] quad candidate indices + quad2edge (np.ndarray): [Q, 4] edge to quad candidate relation + quad2adj (np.ndarray): [Q, 8] adjacent quad candidates of each quad candidate + quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid + """ + E = edges.shape[0] + T = face2edge.shape[0] + + quads_valid = edge2face[:, 1] != -1 + Q = quads_valid.sum() + quad2face = edge2face[quads_valid] # [Q, 2] + quad2edge = face2edge[quad2face] # [Q, 2, 3] + flag = quad2edge == np.arange(E)[quads_valid][:, None, None] # [Q, 2, 3] + flag = flag.argmax(axis=-1) # [Q, 2] + quad2edge = np.stack([ + quad2edge[np.arange(Q)[:, None], np.arange(2)[None, :], (flag + 1) % 3], + quad2edge[np.arange(Q)[:, None], np.arange(2)[None, :], (flag + 2) % 3], + ], axis=-1).reshape(Q, 4) # [Q, 4] + + quads = np.concatenate([ + np.where( + (edges[quad2edge[:, 0:1], 1:] == edges[quad2edge[:, 1:2], :]).any(axis=-1), + edges[quad2edge[:, 0:1], [[0, 1]]], + edges[quad2edge[:, 0:1], [[1, 0]]], + ), + np.where( + (edges[quad2edge[:, 2:3], 1:] == edges[quad2edge[:, 3:4], :]).any(axis=-1), + edges[quad2edge[:, 2:3], [[0, 1]]], + edges[quad2edge[:, 2:3], [[1, 0]]], + ), + ], axis=1) # [Q, 4] + + quad2adj = edge2face[quad2edge] # [Q, 4, 2] + quad2adj = quad2adj[quad2adj != quad2face[:, [0,0,1,1], None]].reshape(Q, 4) # [Q, 4] + quad2adj_valid = quad2adj != -1 + quad2adj = face2edge[quad2adj] # [Q, 4, 3] + quad2adj[~quad2adj_valid, 0] = quad2edge[~quad2adj_valid] + quad2adj[~quad2adj_valid, 1:] = -1 + quad2adj = quad2adj[quad2adj != quad2edge[..., None]].reshape(Q, 8) # [Q, 8] + edge_valid = -np.ones(E, dtype=np.int32) + edge_valid[quads_valid] = np.arange(Q) + quad2adj_valid = quad2adj != -1 + quad2adj[quad2adj_valid] = edge_valid[quad2adj[quad2adj_valid]] # [Q, 8] + + return quads, quad2edge, quad2adj, quads_valid + + +def calc_quad_distortion( + vertices: np.ndarray, + quads: np.ndarray, +): + """ + Calculate the distortion of each candidate quad face. 
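+    The distortion is an energy combining the deviation of the four corner angles from
+    90 degrees, the asymmetry between opposite corner angles, and the non-planarity of the
+    two underlying triangles (angle between their normals); flat, nearly rectangular quads
+    score lowest.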
+ + Args: + vertices (np.ndarray): [N, 3] 3-dimensional vertices + quads (np.ndarray): [Q, 4] quad face indices + + Returns: + distortion (np.ndarray): [Q] distortion of each quad face + """ + edge0 = vertices[quads[:, 1]] - vertices[quads[:, 0]] # [Q, 3] + edge1 = vertices[quads[:, 2]] - vertices[quads[:, 1]] # [Q, 3] + edge2 = vertices[quads[:, 3]] - vertices[quads[:, 2]] # [Q, 3] + edge3 = vertices[quads[:, 0]] - vertices[quads[:, 3]] # [Q, 3] + cross = vertices[quads[:, 0]] - vertices[quads[:, 2]] # [Q, 3] + + len0 = np.maximum(np.linalg.norm(edge0, axis=-1), 1e-10) # [Q] + len1 = np.maximum(np.linalg.norm(edge1, axis=-1), 1e-10) # [Q] + len2 = np.maximum(np.linalg.norm(edge2, axis=-1), 1e-10) # [Q] + len3 = np.maximum(np.linalg.norm(edge3, axis=-1), 1e-10) # [Q] + len_cross = np.maximum(np.linalg.norm(cross, axis=-1), 1e-10) # [Q] + + angle0 = np.arccos(np.clip(np.sum(-edge0 * edge1, axis=-1) / (len0 * len1), -1, 1)) # [Q] + angle1 = np.arccos(np.clip(np.sum(-edge1 * cross, axis=-1) / (len1 * len_cross), -1, 1)) \ + + np.arccos(np.clip(np.sum(cross * edge2, axis=-1) / (len_cross * len2), -1, 1)) # [Q] + angle2 = np.arccos(np.clip(np.sum(-edge2 * edge3, axis=-1) / (len2 * len3), -1, 1)) # [Q] + angle3 = np.arccos(np.clip(np.sum(-edge3 * -cross, axis=-1) / (len3 * len_cross), -1, 1)) \ + + np.arccos(np.clip(np.sum(-cross * edge0, axis=-1) / (len_cross * len0), -1, 1)) # [Q] + + normal0 = np.cross(edge0, edge1) # [Q, 3] + normal1 = np.cross(edge2, edge3) # [Q, 3] + normal0 = normal0 / np.maximum(np.linalg.norm(normal0, axis=-1, keepdims=True), 1e-10) # [Q, 3] + normal1 = normal1 / np.maximum(np.linalg.norm(normal1, axis=-1, keepdims=True), 1e-10) # [Q, 3] + angle_normal = np.arccos(np.clip(np.sum(normal0 * normal1, axis=-1), -1, 1)) # [Q] + + D90 = np.pi / 2 + D180 = np.pi + D360 = np.pi * 2 + ang_eng = (np.abs(angle0 - D90)**2 + np.abs(angle1 - D90)**2 + np.abs(angle2 - D90)**2 + np.abs(angle3 - D90)**2) / 4 # [Q] + dist_eng = np.abs(angle0 - angle2)**2 / np.minimum(np.maximum(np.minimum(angle0, angle2), 1e-10), np.maximum(D180 - np.maximum(angle0, angle2), 1e-10)) \ + + np.abs(angle1 - angle3)**2 / np.minimum(np.maximum(np.minimum(angle1, angle3), 1e-10), np.maximum(D180 - np.maximum(angle1, angle3), 1e-10)) # [Q] + plane_eng = np.where(angle_normal < D90/2, np.abs(angle_normal)**2, 1e10) # [Q] + eng = ang_eng + 2 * dist_eng + 2 * plane_eng # [Q] + + return eng + + +def calc_quad_direction( + vertices: np.ndarray, + quads: np.ndarray, + ): + """ + Calculate the direction of each candidate quad face. + + Args: + vertices (np.ndarray): [N, 3] 3-dimensional vertices + quads (np.ndarray): [Q, 4] quad face indices + + Returns: + direction (np.ndarray): [Q, 4] direction of each quad face. + Represented by the angle between the crossing and each edge. 
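+            The crossings are the two segments connecting the midpoints of opposite edges.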
+ """ + mid0 = (vertices[quads[:, 0]] + vertices[quads[:, 1]]) / 2 # [Q, 3] + mid1 = (vertices[quads[:, 1]] + vertices[quads[:, 2]]) / 2 # [Q, 3] + mid2 = (vertices[quads[:, 2]] + vertices[quads[:, 3]]) / 2 # [Q, 3] + mid3 = (vertices[quads[:, 3]] + vertices[quads[:, 0]]) / 2 # [Q, 3] + + cross0 = mid2 - mid0 # [Q, 3] + cross1 = mid3 - mid1 # [Q, 3] + cross0 = cross0 / np.maximum(np.linalg.norm(cross0, axis=-1, keepdims=True), 1e-10) # [Q, 3] + cross1 = cross1 / np.maximum(np.linalg.norm(cross1, axis=-1, keepdims=True), 1e-10) # [Q, 3] + + edge0 = vertices[quads[:, 1]] - vertices[quads[:, 0]] # [Q, 3] + edge1 = vertices[quads[:, 2]] - vertices[quads[:, 1]] # [Q, 3] + edge2 = vertices[quads[:, 3]] - vertices[quads[:, 2]] # [Q, 3] + edge3 = vertices[quads[:, 0]] - vertices[quads[:, 3]] # [Q, 3] + edge0 = edge0 / np.maximum(np.linalg.norm(edge0, axis=-1, keepdims=True), 1e-10) # [Q, 3] + edge1 = edge1 / np.maximum(np.linalg.norm(edge1, axis=-1, keepdims=True), 1e-10) # [Q, 3] + edge2 = edge2 / np.maximum(np.linalg.norm(edge2, axis=-1, keepdims=True), 1e-10) # [Q, 3] + edge3 = edge3 / np.maximum(np.linalg.norm(edge3, axis=-1, keepdims=True), 1e-10) # [Q, 3] + + direction = np.stack([ + np.arccos(np.clip(np.sum(cross0 * edge0, axis=-1), -1, 1)), + np.arccos(np.clip(np.sum(cross1 * edge1, axis=-1), -1, 1)), + np.arccos(np.clip(np.sum(-cross0 * edge2, axis=-1), -1, 1)), + np.arccos(np.clip(np.sum(-cross1 * edge3, axis=-1), -1, 1)), + ], axis=-1) # [Q, 4] + + return direction + + +def calc_quad_smoothness( + quad2edge: np.ndarray, + quad2adj: np.ndarray, + quads_direction: np.ndarray, + ): + """ + Calculate the smoothness of each candidate quad face connection. + + Args: + quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face + quads_direction (np.ndarray): [Q, 4] direction of each quad face + + Returns: + smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection + """ + Q = quad2adj.shape[0] + quad2adj_valid = quad2adj != -1 + connections = np.stack([ + np.arange(Q)[:, None].repeat(8, axis=1), + quad2adj, + ], axis=-1)[quad2adj_valid] # [C, 2] + shared_edge_idx_0 = np.array([[0, 0, 1, 1, 2, 2, 3, 3]]).repeat(Q, axis=0)[quad2adj_valid] # [C] + shared_edge_idx_1 = np.argmax(quad2edge[quad2adj][quad2adj_valid] == quad2edge[connections[:, 0], shared_edge_idx_0][:, None], axis=-1) # [C] + valid_smoothness = np.abs(quads_direction[connections[:, 0], shared_edge_idx_0] - quads_direction[connections[:, 1], shared_edge_idx_1])**2 # [C] + smoothness = np.zeros([Q, 8], dtype=np.float32) + smoothness[quad2adj_valid] = valid_smoothness + return smoothness + + +def sovle_quad( + face2edge: np.ndarray, + edge2face: np.ndarray, + quad2adj: np.ndarray, + quads_distortion: np.ndarray, + quads_smoothness: np.ndarray, + quads_valid: np.ndarray, + ): + """ + Solve the quad mesh from the candidate quad faces. 
+ + Args: + face2edge (np.ndarray): [T, 3] face to edge relation + edge2face (np.ndarray): [E, 2] edge to face relation + quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face + quads_distortion (np.ndarray): [Q] distortion of each quad face + quads_smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection + quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid + + Returns: + weights (np.ndarray): [Q] weight of each valid quad face + """ + T = face2edge.shape[0] + E = edge2face.shape[0] + Q = quads_distortion.shape[0] + edge_valid = -np.ones(E, dtype=np.int32) + edge_valid[quads_valid] = np.arange(Q) + + quads_connection = np.stack([ + np.arange(Q)[:, None].repeat(8, axis=1), + quad2adj, + ], axis=-1)[quad2adj != -1] # [C, 2] + quads_connection = np.sort(quads_connection, axis=-1) # [C, 2] + quads_connection, quads_connection_idx = np.unique(quads_connection, axis=0, return_index=True) # [C, 2], [C] + quads_smoothness = quads_smoothness[quad2adj != -1] # [C] + quads_smoothness = quads_smoothness[quads_connection_idx] # [C] + C = quads_connection.shape[0] + + # Construct the linear programming problem + + # Variables: + # quads_weight: [Q] weight of each quad face + # tri_min_weight: [T] minimum weight of each triangle face + # conn_min_weight: [C] minimum weight of each quad face connection + # conn_max_weight: [C] maximum weight of each quad face connection + # Objective: + # mimi + + c = np.concatenate([ + quads_distortion - 3, + quads_smoothness*4 - 2, + quads_smoothness*4, + ], axis=0) # [Q+C] + + A_ub_triplet = np.concatenate([ + np.stack([np.arange(T), edge_valid[face2edge[:, 0]], np.ones(T)], axis=1), # [T, 3] + np.stack([np.arange(T), edge_valid[face2edge[:, 1]], np.ones(T)], axis=1), # [T, 3] + np.stack([np.arange(T), edge_valid[face2edge[:, 2]], np.ones(T)], axis=1), # [T, 3] + np.stack([np.arange(T, T+C), np.arange(Q, Q+C), np.ones(C)], axis=1), # [C, 3] + np.stack([np.arange(T, T+C), quads_connection[:, 0], -np.ones(C)], axis=1), # [C, 3] + np.stack([np.arange(T, T+C), quads_connection[:, 1], -np.ones(C)], axis=1), # [C, 3] + np.stack([np.arange(T+C, T+2*C), np.arange(Q+C, Q+2*C), -np.ones(C)], axis=1), # [C, 3] + np.stack([np.arange(T+C, T+2*C), quads_connection[:, 0], np.ones(C)], axis=1), # [C, 3] + np.stack([np.arange(T+C, T+2*C), quads_connection[:, 1], np.ones(C)], axis=1), # [C, 3] + ], axis=0) # [3T+6C, 3] + A_ub_triplet = A_ub_triplet[A_ub_triplet[:, 1] != -1] # [3T', 3] + A_ub = sp.sparse.coo_matrix((A_ub_triplet[:, 2], (A_ub_triplet[:, 0], A_ub_triplet[:, 1])), shape=[T+2*C, Q+2*C]) # [T, + b_ub = np.concatenate([np.ones(T), -np.ones(C), np.ones(C)], axis=0) # [T+2C] + bound = np.stack([ + np.concatenate([np.zeros(Q), -np.ones(C), np.zeros(C)], axis=0), + np.concatenate([np.ones(Q), np.ones(C), np.ones(C)], axis=0), + ], axis=1) # [Q+2C, 2] + A_eq = None + b_eq = None + + print('Solver statistics:') + print(f' #T = {T}') + print(f' #Q = {Q}') + print(f' #C = {C}') + + # Solve the linear programming problem + last_num_valid = 0 + for i in range(100): + res_ = spopt.linprog(c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, bounds=bound) + if not res_.success: + print(f' Iter {i} | Failed with {res_.message}') + break + res = res_ + weights = res.x[:Q] + valid = (weights > 0.5) + num_valid = valid.sum() + print(f' Iter {i} | #Q_valid = {num_valid}') + if num_valid == last_num_valid: + break + last_num_valid = num_valid + A_eq_triplet = np.stack([ + np.arange(num_valid), + np.arange(Q)[valid], + np.ones(num_valid), + 
], axis=1) # [num_valid, 3] + A_eq = sp.sparse.coo_matrix((A_eq_triplet[:, 2], (A_eq_triplet[:, 0], A_eq_triplet[:, 1])), shape=[num_valid, Q+2*C]) # [num_valid, Q+C] + b_eq = np.where(weights[valid] > 0.5, 1, 0) # [num_valid] + + # Return the result + quads_weight = res.x[:Q] + conn_min_weight = res.x[Q:Q+C] + conn_max_weight = res.x[Q+C:Q+2*C] + return quads_weight, conn_min_weight, conn_max_weight + + +def sovle_quad_qp( + face2edge: np.ndarray, + edge2face: np.ndarray, + quad2adj: np.ndarray, + quads_distortion: np.ndarray, + quads_smoothness: np.ndarray, + quads_valid: np.ndarray, + ): + """ + Solve the quad mesh from the candidate quad faces. + + Args: + face2edge (np.ndarray): [T, 3] face to edge relation + edge2face (np.ndarray): [E, 2] edge to face relation + quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face + quads_distortion (np.ndarray): [Q] distortion of each quad face + quads_smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection + quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid + + Returns: + weights (np.ndarray): [Q] weight of each valid quad face + """ + T = face2edge.shape[0] + E = edge2face.shape[0] + Q = quads_distortion.shape[0] + edge_valid = -np.ones(E, dtype=np.int32) + edge_valid[quads_valid] = np.arange(Q) + + # Construct the quadratic programming problem + C_smoothness_triplet = np.stack([ + np.arange(Q)[:, None].repeat(8, axis=1)[quad2adj != -1], + quad2adj[quad2adj != -1], + 5 * quads_smoothness[quad2adj != -1], + ], axis=-1) # [C, 3] + # C_smoothness_triplet = np.concatenate([ + # C_smoothness_triplet, + # np.stack([np.arange(Q), np.arange(Q), 20*np.ones(Q)], axis=1), + # ], axis=0) # [C+Q, 3] + C_smoothness = sp.sparse.coo_matrix((C_smoothness_triplet[:, 2], (C_smoothness_triplet[:, 0], C_smoothness_triplet[:, 1])), shape=[Q, Q]) # [Q, Q] + C_smoothness = C_smoothness.tocsc() + C_dist = quads_distortion - 20 # [Q] + + A_eq = sp.sparse.coo_matrix((np.zeros(Q), (np.zeros(Q), np.arange(Q))), shape=[1, Q]) # [1, Q]\ + A_eq = A_eq.tocsc() + b_eq = np.array([0]) + + A_ub_triplet = np.concatenate([ + np.stack([np.arange(T), edge_valid[face2edge[:, 0]], np.ones(T)], axis=1), # [T, 3] + np.stack([np.arange(T), edge_valid[face2edge[:, 1]], np.ones(T)], axis=1), # [T, 3] + np.stack([np.arange(T), edge_valid[face2edge[:, 2]], np.ones(T)], axis=1), # [T, 3] + ], axis=0) # [3T, 3] + A_ub_triplet = A_ub_triplet[A_ub_triplet[:, 1] != -1] # [3T', 3] + A_ub = sp.sparse.coo_matrix((A_ub_triplet[:, 2], (A_ub_triplet[:, 0], A_ub_triplet[:, 1])), shape=[T, Q]) # [T, Q] + A_ub = A_ub.tocsc() + b_ub = np.ones(T) + + lb = np.zeros(Q) + ub = np.ones(Q) + + import piqp + solver = piqp.SparseSolver() + solver.settings.verbose = True + solver.settings.compute_timings = True + solver.setup(C_smoothness, C_dist, A_eq, b_eq, A_ub, b_ub, lb, ub) + + status = solver.solve() + + # x = cp.Variable(Q) + # prob = cp.Problem( + # cp.Minimize(cp.quad_form(x, C_smoothness) + C_dist.T @ x), + # [ + # A_ub @ x <= b_ub, + # x >= 0, x <= 1, + # ] + # ) + + # # Solve the quadratic programming problem + # prob.solve(solver=cp.PIQP, verbose=True) + + # Return the result + weights = solver.result.x + return weights + + +def tri_to_quad( + vertices: np.ndarray, + faces: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Convert a triangle mesh to a quad mesh. + NOTE: The input mesh must be a manifold mesh. 
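+    Not implemented yet. The intended pipeline, exercised in the __main__ block below, is:
+    edge/face relations -> calc_quad_candidates -> calc_quad_distortion / calc_quad_direction /
+    calc_quad_smoothness -> sovle_quad, keeping candidates whose solved weight exceeds 0.5.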
+ + Args: + vertices (np.ndarray): [N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + + Returns: + vertices (np.ndarray): [N_, 3] 3-dimensional vertices + faces (np.ndarray): [Q, 4] quad face indices + """ + raise NotImplementedError + + +if __name__ == '__main__': + import os + import sys + sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))) + import utils3d + import numpy as np + import cv2 + from vis import vis_edge_color + + file = 'miku' + + vertices, faces = utils3d.io.read_ply(f'test/assets/{file}.ply') + edges, edge2face, face2edge, face2face = calc_relations(faces) + quad_cands, quad2edge, quad2adj, quad_valid = calc_quad_candidates(edges, face2edge, edge2face) + distortion = calc_quad_distortion(vertices, quad_cands) + direction = calc_quad_direction(vertices, quad_cands) + smoothness = calc_quad_smoothness(quad2edge, quad2adj, direction) + boundary_edges = edges[edge2face[:, 1] == -1] + quads_weight, conn_min_weight, conn_max_weight = sovle_quad(face2edge, edge2face, quad2adj, distortion, smoothness, quad_valid) + quads = quad_cands[quads_weight > 0.5] + print('Mesh statistics') + print(f' #V = {vertices.shape[0]}') + print(f' #F = {faces.shape[0]}') + print(f' #E = {edges.shape[0]}') + print(f' #B = {boundary_edges.shape[0]}') + print(f' #Q_cand = {quad_cands.shape[0]}') + print(f' #Q = {quads.shape[0]}') + + utils3d.io.write_ply(f'test/assets/{file}_boundary_edges.ply', vertices=vertices, edges=boundary_edges) + utils3d.io.write_ply(f'test/assets/{file}_quad_candidates.ply', vertices=vertices, faces=quads) + + edge_colors = np.zeros([edges.shape[0], 3], dtype=np.uint8) + distortion = (distortion - distortion.min()) / (distortion.max() - distortion.min()) + distortion = (distortion * 255).astype(np.uint8) + edge_colors[quad_valid] = cv2.cvtColor(cv2.applyColorMap(distortion, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3) + utils3d.io.write_ply(f'test/assets/{file}_quad_candidates_distortion.ply', **vis_edge_color(vertices, edges, edge_colors)) + + edge_colors = np.zeros([edges.shape[0], 3], dtype=np.uint8) + edge_colors[quad_valid] = cv2.cvtColor(cv2.applyColorMap((quads_weight * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3) + utils3d.io.write_ply(f'test/assets/{file}_quad_candidates_weights.ply', **vis_edge_color(vertices, edges, edge_colors)) + utils3d.io.write_ply(f'test/assets/{file}_quad.ply', vertices=vertices, faces=quads) + + quad_centers = vertices[quad_cands].mean(axis=1) + conns = np.stack([ + np.arange(quad_cands.shape[0])[:, None].repeat(8, axis=1), + quad2adj, + ], axis=-1)[quad2adj != -1] # [C, 2] + conns, conns_idx = np.unique(np.sort(conns, axis=-1), axis=0, return_index=True) # [C, 2], [C] + smoothness = smoothness[quad2adj != -1][conns_idx] # [C] + conns_color = cv2.cvtColor(cv2.applyColorMap((smoothness * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3) + utils3d.io.write_ply(f'test/assets/{file}_quad_conn_smoothness.ply', **vis_edge_color(quad_centers, conns, conns_color)) + conns_color = cv2.cvtColor(cv2.applyColorMap((conn_min_weight * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3) + utils3d.io.write_ply(f'test/assets/{file}_quad_conn_min.ply', **vis_edge_color(quad_centers, conns, conns_color)) + conns_color = cv2.cvtColor(cv2.applyColorMap((conn_max_weight * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3) + 
utils3d.io.write_ply(f'test/assets/{file}_quad_conn_max.ply', **vis_edge_color(quad_centers, conns, conns_color)) + + \ No newline at end of file diff --git a/src/utils3d/numpy/rasterization.py b/src/utils3d/numpy/rasterization.py new file mode 100644 index 0000000000000000000000000000000000000000..8f0f0db55d87f37f108a778dac29ae6320418f3a --- /dev/null +++ b/src/utils3d/numpy/rasterization.py @@ -0,0 +1,469 @@ +import os +from typing import * + +import numpy as np +import moderngl + +from . import transforms, utils, mesh + + +__all__ = [ + 'RastContext', + 'rasterize_triangle_faces', + 'rasterize_edges', + 'texture', + 'test_rasterization', + 'warp_image_by_depth', +] + + +def map_np_dtype(dtype) -> str: + if dtype == int: + return 'i4' + elif dtype == np.uint8: + return 'u1' + elif dtype == np.uint32: + return 'u2' + elif dtype == np.float16: + return 'f2' + elif dtype == np.float32: + return 'f4' + + +def one_value(dtype): + if dtype == 'u1': + return 255 + elif dtype == 'u2': + return 65535 + else: + return 1 + + +class RastContext: + def __init__(self, *args, **kwargs): + """ + Create a moderngl context. + + Args: + See moderngl.create_context + """ + if len(args) == 1 and isinstance(args[0], moderngl.Context): + self.mgl_ctx = args[0] + else: + self.mgl_ctx = moderngl.create_context(*args, **kwargs) + self.__prog_src = {} + self.__prog = {} + + def program_vertex_attribute(self, n: int) -> moderngl.Program: + assert n in [1, 2, 3, 4], 'vertex attribute only supports channels 1, 2, 3, 4' + + if 'vertex_attribute_vsh' not in self.__prog_src: + with open(os.path.join(os.path.dirname(__file__), 'shaders', 'vertex_attribute.vsh'), 'r') as f: + self.__prog_src['vertex_attribute_vsh'] = f.read() + if 'vertex_attribute_fsh' not in self.__prog_src: + with open(os.path.join(os.path.dirname(__file__), 'shaders', 'vertex_attribute.fsh'), 'r') as f: + self.__prog_src['vertex_attribute_fsh'] = f.read() + + if f'vertex_attribute_{n}' not in self.__prog: + vsh = self.__prog_src['vertex_attribute_vsh'].replace('vecN', f'vec{n}') + fsh = self.__prog_src['vertex_attribute_fsh'].replace('vecN', f'vec{n}') + self.__prog[f'vertex_attribute_{n}'] = self.mgl_ctx.program(vertex_shader=vsh, fragment_shader=fsh) + + return self.__prog[f'vertex_attribute_{n}'] + + def program_texture(self, n: int) -> moderngl.Program: + assert n in [1, 2, 3, 4], 'texture only supports channels 1, 2, 3, 4' + + if 'texture_vsh' not in self.__prog_src: + with open(os.path.join(os.path.dirname(__file__), 'shaders', 'texture.vsh'), 'r') as f: + self.__prog_src['texture_vsh'] = f.read() + if 'texture_fsh' not in self.__prog_src: + with open(os.path.join(os.path.dirname(__file__), 'shaders', 'texture.fsh'), 'r') as f: + self.__prog_src['texture_fsh'] = f.read() + + if f'texture_{n}' not in self.__prog: + vsh = self.__prog_src['texture_vsh'].replace('vecN', f'vec{n}') + fsh = self.__prog_src['texture_fsh'].replace('vecN', f'vec{n}') + self.__prog[f'texture_{n}'] = self.mgl_ctx.program(vertex_shader=vsh, fragment_shader=fsh) + self.__prog[f'texture_{n}']['tex'] = 0 + self.__prog[f'texture_{n}']['uv'] = 1 + + return self.__prog[f'texture_{n}'] + + +def rasterize_triangle_faces( + ctx: RastContext, + vertices: np.ndarray, + faces: np.ndarray, + attr: np.ndarray, + width: int, + height: int, + transform: np.ndarray = None, + cull_backface: bool = True, + return_depth: bool = False, + image: np.ndarray = None, + depth: np.ndarray = None +) -> Tuple[np.ndarray, np.ndarray]: + """ + Rasterize vertex attribute. 
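+    Vertices are transformed by the model-view-projection matrix into OpenGL clip space,
+    the per-vertex attribute is interpolated across each triangle, and the result is rendered
+    off-screen with moderngl into float32 color and depth attachments.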
+ + Args: + vertices (np.ndarray): [N, 3] + faces (np.ndarray): [T, 3] + attr (np.ndarray): [N, C] + width (int): width of rendered image + height (int): height of rendered image + transform (np.ndarray): [4, 4] model-view-projection transformation matrix. + cull_backface (bool): whether to cull backface + image: (np.ndarray): [H, W, C] background image + depth: (np.ndarray): [H, W] background depth + + Returns: + image (np.ndarray): [H, W, C] rendered image + depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None. + """ + assert vertices.ndim == 2 and vertices.shape[1] == 3 + assert faces.ndim == 2 and faces.shape[1] == 3, f"Faces should be a 2D array with shape (T, 3), but got {faces.shape}" + assert attr.ndim == 2 and attr.shape[1] in [1, 2, 3, 4], f'Vertex attribute only supports channels 1, 2, 3, 4, but got {attr.shape}' + assert vertices.shape[0] == attr.shape[0] + assert vertices.dtype == np.float32 + assert faces.dtype == np.uint32 or faces.dtype == np.int32 + assert attr.dtype == np.float32, "Attribute should be float32" + assert transform is None or transform.shape == (4, 4), f"Transform should be a 4x4 matrix, but got {transform.shape}" + assert transform is None or transform.dtype == np.float32, f"Transform should be float32, but got {transform.dtype}" + if image is not None: + assert image.ndim == 3 and image.shape == (height, width, attr.shape[1]), f"Image should be a 3D array with shape (H, W, {attr.shape[1]}), but got {image.shape}" + assert image.dtype == np.float32, f"Image should be float32, but got {image.dtype}" + if depth is not None: + assert depth.ndim == 2 and depth.shape == (height, width), f"Depth should be a 2D array with shape (H, W), but got {depth.shape}" + assert depth.dtype == np.float32, f"Depth should be float32, but got {depth.dtype}" + + C = attr.shape[1] + prog = ctx.program_vertex_attribute(C) + + transform = np.eye(4, np.float32) if transform is None else transform + + # Create buffers + ibo = ctx.mgl_ctx.buffer(np.ascontiguousarray(faces, dtype='i4')) + vbo_vertices = ctx.mgl_ctx.buffer(np.ascontiguousarray(vertices, dtype='f4')) + vbo_attr = ctx.mgl_ctx.buffer(np.ascontiguousarray(attr, dtype='f4')) + vao = ctx.mgl_ctx.vertex_array( + prog, + [ + (vbo_vertices, '3f', 'i_position'), + (vbo_attr, f'{C}f', 'i_attr'), + ], + ibo, + mode=moderngl.TRIANGLES, + ) + + # Create framebuffer + image_tex = ctx.mgl_ctx.texture((width, height), C, dtype='f4', data=np.ascontiguousarray(image[::-1, :, :]) if image is not None else None) + depth_tex = ctx.mgl_ctx.depth_texture((width, height), data=np.ascontiguousarray(depth[::-1, :]) if depth is not None else None) + fbo = ctx.mgl_ctx.framebuffer( + color_attachments=[image_tex], + depth_attachment=depth_tex, + ) + + # Render + prog['u_mvp'].write(transform.transpose().copy().astype('f4')) + fbo.use() + fbo.viewport = (0, 0, width, height) + ctx.mgl_ctx.depth_func = '<' + if depth is None: + ctx.mgl_ctx.clear(depth=1.0) + ctx.mgl_ctx.enable(ctx.mgl_ctx.DEPTH_TEST) + if cull_backface: + ctx.mgl_ctx.enable(ctx.mgl_ctx.CULL_FACE) + else: + ctx.mgl_ctx.disable(ctx.mgl_ctx.CULL_FACE) + vao.render() + ctx.mgl_ctx.disable(ctx.mgl_ctx.DEPTH_TEST) + + # Read + image = np.zeros((height, width, C), dtype='f4') + image_tex.read_into(image) + image = image[::-1, :, :] + if return_depth: + depth = np.zeros((height, width), dtype='f4') + depth_tex.read_into(depth) + depth = depth[::-1, :] + else: + depth = None + + # Release + vao.release() + ibo.release() + 
vbo_vertices.release() + vbo_attr.release() + fbo.release() + image_tex.release() + depth_tex.release() + + return image, depth + + +def rasterize_edges( + ctx: RastContext, + vertices: np.ndarray, + edges: np.ndarray, + attr: np.ndarray, + width: int, + height: int, + transform: np.ndarray = None, + line_width: float = 1.0, + return_depth: bool = False, + image: np.ndarray = None, + depth: np.ndarray = None +) -> Tuple[np.ndarray, ...]: + """ + Rasterize vertex attribute. + + Args: + vertices (np.ndarray): [N, 3] + faces (np.ndarray): [T, 3] + attr (np.ndarray): [N, C] + width (int): width of rendered image + height (int): height of rendered image + transform (np.ndarray): [4, 4] model-view-projection matrix + line_width (float): width of line. Defaults to 1.0. NOTE: Values other than 1.0 may not work across all platforms. + cull_backface (bool): whether to cull backface + + Returns: + image (np.ndarray): [H, W, C] rendered image + depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None. + """ + assert vertices.ndim == 2 and vertices.shape[1] == 3 + assert edges.ndim == 2 and edges.shape[1] == 2, f"Edges should be a 2D array with shape (T, 2), but got {edges.shape}" + assert attr.ndim == 2 and attr.shape[1] in [1, 2, 3, 4], f'Vertex attribute only supports channels 1, 2, 3, 4, but got {attr.shape}' + assert vertices.shape[0] == attr.shape[0] + assert vertices.dtype == np.float32 + assert edges.dtype == np.uint32 or edges.dtype == np.int32 + assert attr.dtype == np.float32, "Attribute should be float32" + + C = attr.shape[1] + prog = ctx.program_vertex_attribute(C) + + transform = transform if transform is not None else np.eye(4, np.float32) + + # Create buffers + ibo = ctx.mgl_ctx.buffer(np.ascontiguousarray(edges, dtype='i4')) + vbo_vertices = ctx.mgl_ctx.buffer(np.ascontiguousarray(vertices, dtype='f4')) + vbo_attr = ctx.mgl_ctx.buffer(np.ascontiguousarray(attr, dtype='f4')) + vao = ctx.mgl_ctx.vertex_array( + prog, + [ + (vbo_vertices, '3f', 'i_position'), + (vbo_attr, f'{C}f', 'i_attr'), + ], + ibo, + mode=moderngl.LINES, + ) + + # Create framebuffer + image_tex = ctx.mgl_ctx.texture((width, height), C, dtype='f4', data=np.ascontiguousarray(image[::-1, :, :]) if image is not None else None) + depth_tex = ctx.mgl_ctx.depth_texture((width, height), data=np.ascontiguousarray(depth[::-1, :]) if depth is not None else None) + fbo = ctx.mgl_ctx.framebuffer( + color_attachments=[image_tex], + depth_attachment=depth_tex, + ) + + # Render + prog['u_mvp'].write(transform.transpose().copy().astype('f4')) + fbo.use() + fbo.viewport = (0, 0, width, height) + if depth is None: + ctx.mgl_ctx.clear(depth=1.0) + ctx.mgl_ctx.depth_func = '<' + ctx.mgl_ctx.enable(ctx.mgl_ctx.DEPTH_TEST) + ctx.mgl_ctx.line_width = line_width + vao.render() + ctx.mgl_ctx.disable(ctx.mgl_ctx.DEPTH_TEST) + + # Read + image = np.zeros((height, width, C), dtype='f4') + image_tex.read_into(image) + image = image[::-1, :, :] + if return_depth: + depth = np.zeros((height, width), dtype='f4') + depth_tex.read_into(depth) + depth = depth[::-1, :] + else: + depth = None + + # Release + vao.release() + ibo.release() + vbo_vertices.release() + vbo_attr.release() + fbo.release() + image_tex.release() + depth_tex.release() + + return image, depth + + +def texture( + ctx: RastContext, + uv: np.ndarray, + texture: np.ndarray, + interpolation: str= 'linear', + wrap: str = 'clamp' +) -> np.ndarray: + """ + Given an UV image, texturing from the texture map + """ + assert len(texture.shape) == 
3 and 1 <= texture.shape[2] <= 4 + assert uv.shape[2] == 2 + height, width = uv.shape[:2] + texture_dtype = map_np_dtype(texture.dtype) + + # Create VAO + screen_quad_vbo = ctx.mgl_ctx.buffer(np.array([[-1, -1], [1, -1], [1, 1], [-1, 1]], dtype='f4')) + screen_quad_ibo = ctx.mgl_ctx.buffer(np.array([0, 1, 2, 0, 2, 3], dtype=np.int32)) + screen_quad_vao = ctx.mgl_ctx.vertex_array(ctx.program_texture(texture.shape[2]), [(screen_quad_vbo, '2f4', 'in_vert')], index_buffer=screen_quad_ibo, index_element_size=4) + + # Create texture, set filter and bind. TODO: min mag filter, mipmap + texture_tex = ctx.mgl_ctx.texture((texture.shape[1], texture.shape[0]), texture.shape[2], dtype=texture_dtype, data=np.ascontiguousarray(texture)) + if interpolation == 'linear': + texture_tex.filter = (moderngl.LINEAR, moderngl.LINEAR) + elif interpolation == 'nearest': + texture_tex.filter = (moderngl.NEAREST, moderngl.NEAREST) + texture_tex.use(location=0) + texture_uv = ctx.mgl_ctx.texture((width, height), 2, dtype='f4', data=np.ascontiguousarray(uv.astype('f4', copy=False))) + texture_uv.filter = (moderngl.NEAREST, moderngl.NEAREST) + texture_uv.use(location=1) + + # Create render buffer and frame buffer + rb = ctx.mgl_ctx.renderbuffer((uv.shape[1], uv.shape[0]), texture.shape[2], dtype=texture_dtype) + fbo = ctx.mgl_ctx.framebuffer(color_attachments=[rb]) + + # Render + fbo.use() + fbo.viewport = (0, 0, width, height) + ctx.mgl_ctx.disable(ctx.mgl_ctx.BLEND) + screen_quad_vao.render() + + # Read buffer + image_buffer = np.frombuffer(fbo.read(components=texture.shape[2], attachment=0, dtype=texture_dtype), dtype=texture_dtype).reshape((height, width, texture.shape[2])) + + # Release + texture_tex.release() + rb.release() + fbo.release() + + return image_buffer + + +def warp_image_by_depth( + ctx: RastContext, + src_depth: np.ndarray, + src_image: np.ndarray = None, + width: int = None, + height: int = None, + *, + extrinsics_src: np.ndarray = None, + extrinsics_tgt: np.ndarray = None, + intrinsics_src: np.ndarray = None, + intrinsics_tgt: np.ndarray = None, + near: float = 0.1, + far: float = 100.0, + cull_backface: bool = True, + ssaa: int = 1, + return_depth: bool = False, +) -> Tuple[np.ndarray, ...]: + """ + Warp image by depth map. + + Args: + ctx (RastContext): rasterizer context + src_depth (np.ndarray): [H, W] + src_image (np.ndarray, optional): [H, W, C]. The image to warp. Defaults to None (use uv coordinates). + width (int, optional): width of the output image. None to use depth map width. Defaults to None. + height (int, optional): height of the output image. None to use depth map height. Defaults to None. + extrinsics_src (np.ndarray, optional): extrinsics matrix of the source camera. Defaults to None (identity). + extrinsics_tgt (np.ndarray, optional): extrinsics matrix of the target camera. Defaults to None (identity). + intrinsics_src (np.ndarray, optional): intrinsics matrix of the source camera. Defaults to None (use the same as intrinsics_tgt). + intrinsics_tgt (np.ndarray, optional): intrinsics matrix of the target camera. Defaults to None (use the same as intrinsics_src). + cull_backface (bool, optional): whether to cull backface. Defaults to True. + ssaa (int, optional): super sampling anti-aliasing. Defaults to 1. + + Returns: + tgt_image (np.ndarray): [H, W, C] warped image (or uv coordinates if image is None). + tgt_depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None. 
+ """ + assert src_depth.ndim == 2 + + if width is None: + width = src_depth.shape[1] + if height is None: + height = src_depth.shape[0] + if src_image is not None: + assert src_image.shape[-2:] == src_depth.shape[-2:], f'Shape of source image {src_image.shape} does not match shape of source depth {src_depth.shape}' + + # set up default camera parameters + extrinsics_src = np.eye(4) if extrinsics_src is None else extrinsics_src + extrinsics_tgt = np.eye(4) if extrinsics_tgt is None else extrinsics_tgt + intrinsics_src = intrinsics_tgt if intrinsics_src is None else intrinsics_src + intrinsics_tgt = intrinsics_src if intrinsics_tgt is None else intrinsics_tgt + + assert all(x is not None for x in [extrinsics_src, extrinsics_tgt, intrinsics_src, intrinsics_tgt]), "Make sure you have provided all the necessary camera parameters." + + # check shapes + assert extrinsics_src.shape == (4, 4) and extrinsics_tgt.shape == (4, 4) + assert intrinsics_src.shape == (3, 3) and intrinsics_tgt.shape == (3, 3) + + # convert to view and perspective matrices + view_tgt = transforms.extrinsics_to_view(extrinsics_tgt) + perspective_tgt = transforms.intrinsics_to_perspective(intrinsics_tgt, near=near, far=far) + + # unproject depth map + uv, faces = utils.image_mesh(*src_depth.shape[-2:]) + pts = transforms.unproject_cv(uv, src_depth.reshape(-1), extrinsics_src, intrinsics_src) + faces = mesh.triangulate(faces, vertices=pts) + + # rasterize attributes + if src_image is not None: + attr = src_image.reshape(-1, src_image.shape[-1]) + else: + attr = uv + + tgt_image, tgt_depth = rasterize_triangle_faces( + ctx, + pts, + faces, + attr, + width * ssaa, + height * ssaa, + transform=perspective_tgt @ view_tgt, + cull_backface=cull_backface, + return_depth=return_depth, + ) + + if ssaa > 1: + tgt_image = tgt_image.reshape(height, ssaa, width, ssaa, -1).mean(axis=(1, 3)) + tgt_depth = tgt_depth.reshape(height, ssaa, width, ssaa, -1).mean(axis=(1, 3)) if return_depth else None + + return tgt_image, tgt_depth + +def test_rasterization(ctx: RastContext): + """ + Test if rasterization works. It will render a cube with random colors and save it as a CHECKME.png file. 
+ """ + vertices, faces = utils.cube(tri=True) + attr = np.random.rand(len(vertices), 3).astype(np.float32) + perspective = transforms.perspective(np.deg2rad(60), 1, 0.01, 100) + view = transforms.view_look_at(np.array([2, 2, 2]), np.array([0, 0, 0]), np.array([0, 1, 0])) + image, depth = rasterize_triangle_faces( + ctx, + vertices, + faces, + attr, + 512, 512, + transform=(perspective @ view).astype(np.float32), + cull_backface=False, + return_depth=True, + ) + import cv2 + cv2.imwrite('CHECKME.png', cv2.cvtColor((image.clip(0, 1) * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)) + \ No newline at end of file diff --git a/src/utils3d/numpy/shaders/texture.fsh b/src/utils3d/numpy/shaders/texture.fsh new file mode 100644 index 0000000000000000000000000000000000000000..c8be72f94cbf38fb0b2a9609e8db4d50ac7753d6 --- /dev/null +++ b/src/utils3d/numpy/shaders/texture.fsh @@ -0,0 +1,11 @@ +#version 330 + +uniform sampler2D tex; +uniform sampler2D uv; + +in vec2 scr_coord; +out vecN tex_color; + +void main() { + tex_color = vecN(texture(tex, texture(uv, scr_coord).xy)); +} \ No newline at end of file diff --git a/src/utils3d/numpy/shaders/texture.vsh b/src/utils3d/numpy/shaders/texture.vsh new file mode 100644 index 0000000000000000000000000000000000000000..f96c6b14a8931fbcd5f4ca22ea917b9c8f80f195 --- /dev/null +++ b/src/utils3d/numpy/shaders/texture.vsh @@ -0,0 +1,9 @@ + #version 330 core + +in vec2 in_vert; +out vec2 scr_coord; + +void main() { + scr_coord = in_vert * 0.5 + 0.5; + gl_Position = vec4(in_vert, 0., 1.); +} \ No newline at end of file diff --git a/src/utils3d/numpy/shaders/vertex_attribute.fsh b/src/utils3d/numpy/shaders/vertex_attribute.fsh new file mode 100644 index 0000000000000000000000000000000000000000..54409764c5600ee190db89313b07dd91b940d6eb --- /dev/null +++ b/src/utils3d/numpy/shaders/vertex_attribute.fsh @@ -0,0 +1,9 @@ +#version 330 + +in vecN v_attr; + +out vecN f_attr; + +void main() { + f_attr = v_attr; +} diff --git a/src/utils3d/numpy/shaders/vertex_attribute.vsh b/src/utils3d/numpy/shaders/vertex_attribute.vsh new file mode 100644 index 0000000000000000000000000000000000000000..7c94f91aaabfd714a47a194b93f8e53bf63577f5 --- /dev/null +++ b/src/utils3d/numpy/shaders/vertex_attribute.vsh @@ -0,0 +1,13 @@ +#version 330 + +uniform mat4 u_mvp; + +in vec3 i_position; +in vecN i_attr; + +out vecN v_attr; + +void main() { + gl_Position = u_mvp * vec4(i_position, 1.0); + v_attr = i_attr; +} diff --git a/src/utils3d/numpy/spline.py b/src/utils3d/numpy/spline.py new file mode 100644 index 0000000000000000000000000000000000000000..03c664136bc3734215d37669a3446c248dffe097 --- /dev/null +++ b/src/utils3d/numpy/spline.py @@ -0,0 +1,82 @@ +from typing import * + +import numpy as np + + +__all__ = ['linear_spline_interpolate'] + + +def linear_spline_interpolate(x: np.ndarray, t: np.ndarray, s: np.ndarray, extrapolation_mode: Literal['constant', 'linear'] = 'constant') -> np.ndarray: + """ + Linear spline interpolation. + + ### Parameters: + - `x`: np.ndarray, shape (n, d): the values of data points. + - `t`: np.ndarray, shape (n,): the times of the data points. + - `s`: np.ndarray, shape (m,): the times to be interpolated. + - `extrapolation_mode`: str, the mode of extrapolation. 'constant' means extrapolate the boundary values, 'linear' means extrapolate linearly. + + ### Returns: + - `y`: np.ndarray, shape (..., m, d): the interpolated values. 
+ """ + i = np.searchsorted(t, s, side='left') + if extrapolation_mode == 'constant': + prev = np.clip(i - 1, 0, len(t) - 1) + suc = np.clip(i, 0, len(t) - 1) + elif extrapolation_mode == 'linear': + prev = np.clip(i - 1, 0, len(t) - 2) + suc = np.clip(i, 1, len(t) - 1) + else: + raise ValueError(f'Invalid extrapolation_mode: {extrapolation_mode}') + + u = (s - t[prev]) / np.maximum(t[suc] - t[prev], 1e-12) + y = u * x[suc] + (1 - u) * x[prev] + + return y + + + +def _solve_tridiagonal(a: np.ndarray, b: np.ndarray, c: np.ndarray, d: np.ndarray) -> np.ndarray: + n = b.shape[-1] + cc = np.zeros_like(b) + dd = np.zeros_like(b) + cc[..., 0] = c[..., 0] / b[..., 0] + dd[..., 0] = d[..., 0] / b[..., 0] + for i in range(1, n): + cc[..., i] = c[..., i] / (b[..., i] - a[..., i - 1] * cc[..., i - 1]) + dd[..., i] = (d[..., i] - a[..., i - 1] * dd[..., i - 1]) / (b[..., i] - a[..., i - 1] * cc[..., i - 1]) + x = np.zeros_like(b) + x[..., -1] = dd[..., -1] + for i in range(n - 2, -1, -1): + x[..., i] = dd[..., i] - cc[..., i] * x[..., i + 1] + return x + + +def cubic_spline_interpolate(x: np.ndarray, t: np.ndarray, s: np.ndarray, v0: np.ndarray = None, vn: np.ndarray = None) -> np.ndarray: + """ + Cubic spline interpolation. + + ### Parameters: + - `x`: np.ndarray, shape (..., n,): the x-coordinates of the data points. + - `t`: np.ndarray, shape (n,): the knot vector. NOTE: t must be sorted in ascending order. + - `s`: np.ndarray, shape (..., m,): the y-coordinates of the data points. + - `v0`: np.ndarray, shape (...,): the value of the derivative at the first knot, as the boundary condition. If None, it is set to zero. + - `vn`: np.ndarray, shape (...,): the value of the derivative at the last knot, as the boundary condition. If None, it is set to zero. + + ### Returns: + - `y`: np.ndarray, shape (..., m): the interpolated values. 
+ """ + h = t[..., 1:] - t[..., :-1] + mu = h[..., :-1] / (h[..., :-1] + h[..., 1:]) + la = 1 - mu + d = (x[..., 1:] - x[..., :-1]) / h + d = 6 * (d[..., 1:] - d[..., :-1]) / (t[..., 2:] - t[..., :-2]) + + mu = np.concatenate([mu, np.ones_like(mu[..., :1])], axis=-1) + la = np.concatenate([np.ones_like(la[..., :1]), la], axis=-1) + d = np.concatenate([(((x[..., 1] - x[..., 0]) / h[0] - v0) / h[0])[..., None], d, ((vn - (x[..., -1] - x[..., -2]) / h[-1]) / h[-1])[..., None]], axis=-1) + + M = _solve_tridiagonal(mu, np.full_like(d, fill_value=2), la, d) + + i = np.searchsorted(t, s, side='left') + diff --git a/src/utils3d/numpy/transforms.py b/src/utils3d/numpy/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..2e2418540b4a1d152a5e072900079f59f474a139 --- /dev/null +++ b/src/utils3d/numpy/transforms.py @@ -0,0 +1,1104 @@ +import numpy as np +from typing import * +from numbers import Number +from ._helpers import batched +from .._helpers import no_warnings + + +__all__ = [ + 'perspective', + 'perspective_from_fov', + 'perspective_from_fov_xy', + 'intrinsics_from_focal_center', + 'intrinsics_from_fov', + 'fov_to_focal', + 'focal_to_fov', + 'intrinsics_to_fov', + 'view_look_at', + 'extrinsics_look_at', + 'perspective_to_intrinsics', + 'perspective_to_near_far', + 'intrinsics_to_perspective', + 'extrinsics_to_view', + 'view_to_extrinsics', + 'normalize_intrinsics', + 'crop_intrinsics', + 'pixel_to_uv', + 'pixel_to_ndc', + 'uv_to_pixel', + 'project_depth', + 'depth_buffer_to_linear', + 'unproject_cv', + 'unproject_gl', + 'project_cv', + 'project_gl', + 'quaternion_to_matrix', + 'axis_angle_to_matrix', + 'matrix_to_quaternion', + 'extrinsics_to_essential', + 'euler_axis_angle_rotation', + 'euler_angles_to_matrix', + 'skew_symmetric', + 'rotation_matrix_from_vectors', + 'ray_intersection', + 'se3_matrix', + 'slerp_quaternion', + 'slerp_vector', + 'lerp', + 'lerp_se3_matrix', + 'piecewise_lerp', + 'piecewise_lerp_se3_matrix', + 'apply_transform' +] + + +@batched(0,0,0,0) +def perspective( + fov_y: Union[float, np.ndarray], + aspect: Union[float, np.ndarray], + near: Union[float, np.ndarray], + far: Union[float, np.ndarray] +) -> np.ndarray: + """ + Get OpenGL perspective matrix + + Args: + fov_y (float | np.ndarray): field of view in y axis + aspect (float | np.ndarray): aspect ratio + near (float | np.ndarray): near plane to clip + far (float | np.ndarray): far plane to clip + + Returns: + (np.ndarray): [..., 4, 4] perspective matrix + """ + N = fov_y.shape[0] + ret = np.zeros((N, 4, 4), dtype=fov_y.dtype) + ret[:, 0, 0] = 1. / (np.tan(fov_y / 2) * aspect) + ret[:, 1, 1] = 1. / (np.tan(fov_y / 2)) + ret[:, 2, 2] = (near + far) / (near - far) + ret[:, 2, 3] = 2. * near * far / (near - far) + ret[:, 3, 2] = -1. 
+ return ret + + +def perspective_from_fov( + fov: Union[float, np.ndarray], + width: Union[int, np.ndarray], + height: Union[int, np.ndarray], + near: Union[float, np.ndarray], + far: Union[float, np.ndarray] +) -> np.ndarray: + """ + Get OpenGL perspective matrix from field of view in largest dimension + + Args: + fov (float | np.ndarray): field of view in largest dimension + width (int | np.ndarray): image width + height (int | np.ndarray): image height + near (float | np.ndarray): near plane to clip + far (float | np.ndarray): far plane to clip + + Returns: + (np.ndarray): [..., 4, 4] perspective matrix + """ + fov_y = 2 * np.arctan(np.tan(fov / 2) * height / np.maximum(width, height)) + aspect = width / height + return perspective(fov_y, aspect, near, far) + + +def perspective_from_fov_xy( + fov_x: Union[float, np.ndarray], + fov_y: Union[float, np.ndarray], + near: Union[float, np.ndarray], + far: Union[float, np.ndarray] +) -> np.ndarray: + """ + Get OpenGL perspective matrix from field of view in x and y axis + + Args: + fov_x (float | np.ndarray): field of view in x axis + fov_y (float | np.ndarray): field of view in y axis + near (float | np.ndarray): near plane to clip + far (float | np.ndarray): far plane to clip + + Returns: + (np.ndarray): [..., 4, 4] perspective matrix + """ + aspect = np.tan(fov_x / 2) / np.tan(fov_y / 2) + return perspective(fov_y, aspect, near, far) + + +def intrinsics_from_focal_center( + fx: Union[float, np.ndarray], + fy: Union[float, np.ndarray], + cx: Union[float, np.ndarray], + cy: Union[float, np.ndarray], + dtype: Optional[np.dtype] = np.float32 +) -> np.ndarray: + """ + Get OpenCV intrinsics matrix + + Returns: + (np.ndarray): [..., 3, 3] OpenCV intrinsics matrix + """ + if any(isinstance(x, np.ndarray) for x in (fx, fy, cx, cy)): + dtype = np.result_type(fx, fy, cx, cy) + fx, fy, cx, cy = np.broadcast_arrays(fx, fy, cx, cy) + ret = np.zeros((*fx.shape, 3, 3), dtype=dtype) + ret[..., 0, 0] = fx + ret[..., 1, 1] = fy + ret[..., 0, 2] = cx + ret[..., 1, 2] = cy + ret[..., 2, 2] = 1. + return ret + + +def intrinsics_from_fov( + fov_max: Union[float, np.ndarray] = None, + fov_min: Union[float, np.ndarray] = None, + fov_x: Union[float, np.ndarray] = None, + fov_y: Union[float, np.ndarray] = None, + width: Union[int, np.ndarray] = None, + height: Union[int, np.ndarray] = None, +) -> np.ndarray: + """ + Get normalized OpenCV intrinsics matrix from given field of view. 
+ You can provide either fov_max, fov_min, fov_x or fov_y + + Args: + width (int | np.ndarray): image width + height (int | np.ndarray): image height + fov_max (float | np.ndarray): field of view in largest dimension + fov_min (float | np.ndarray): field of view in smallest dimension + fov_x (float | np.ndarray): field of view in x axis + fov_y (float | np.ndarray): field of view in y axis + + Returns: + (np.ndarray): [..., 3, 3] OpenCV intrinsics matrix + """ + if fov_max is not None: + fx = np.maximum(width, height) / width / (2 * np.tan(fov_max / 2)) + fy = np.maximum(width, height) / height / (2 * np.tan(fov_max / 2)) + elif fov_min is not None: + fx = np.minimum(width, height) / width / (2 * np.tan(fov_min / 2)) + fy = np.minimum(width, height) / height / (2 * np.tan(fov_min / 2)) + elif fov_x is not None and fov_y is not None: + fx = 1 / (2 * np.tan(fov_x / 2)) + fy = 1 / (2 * np.tan(fov_y / 2)) + elif fov_x is not None: + fx = 1 / (2 * np.tan(fov_x / 2)) + fy = fx * width / height + elif fov_y is not None: + fy = 1 / (2 * np.tan(fov_y / 2)) + fx = fy * height / width + cx = 0.5 + cy = 0.5 + ret = intrinsics_from_focal_center(fx, fy, cx, cy) + return ret + + +def focal_to_fov(focal: np.ndarray): + return 2 * np.arctan(0.5 / focal) + + +def fov_to_focal(fov: np.ndarray): + return 0.5 / np.tan(fov / 2) + + +def intrinsics_to_fov(intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + fov_x = focal_to_fov(intrinsics[..., 0, 0]) + fov_y = focal_to_fov(intrinsics[..., 1, 1]) + return fov_x, fov_y + + +@batched(1,1,1) +def view_look_at( + eye: np.ndarray, + look_at: np.ndarray, + up: np.ndarray + ) -> np.ndarray: + """ + Get OpenGL view matrix looking at something + + Args: + eye (np.ndarray): [..., 3] the eye position + look_at (np.ndarray): [..., 3] the position to look at + up (np.ndarray): [..., 3] head up direction (y axis in screen space). Not necessarily othogonal to view direction + + Returns: + (np.ndarray): [..., 4, 4], view matrix + """ + z = eye - look_at + x = np.cross(up, z) + y = np.cross(z, x) + # x = np.cross(y, z) + x = x / np.linalg.norm(x, axis=-1, keepdims=True) + y = y / np.linalg.norm(y, axis=-1, keepdims=True) + z = z / np.linalg.norm(z, axis=-1, keepdims=True) + R = np.stack([x, y, z], axis=-2) + t = -np.matmul(R, eye[..., None]) + return np.concatenate([ + np.concatenate([R, t], axis=-1), + np.array([[[0., 0., 0., 1.]]]).repeat(eye.shape[0], axis=0) + ], axis=-2) + + +@batched(1,1,1) +def extrinsics_look_at( + eye: np.ndarray, + look_at: np.ndarray, + up: np.ndarray +) -> np.ndarray: + """ + Get OpenCV extrinsics matrix looking at something + + Args: + eye (np.ndarray): [..., 3] the eye position + look_at (np.ndarray): [..., 3] the position to look at + up (np.ndarray): [..., 3] head up direction (-y axis in screen space). 
Not necessarily othogonal to view direction + + Returns: + (np.ndarray): [..., 4, 4], extrinsics matrix + """ + z = look_at - eye + x = np.cross(-up, z) + y = np.cross(z, x) + # x = np.cross(y, z) + x = x / np.linalg.norm(x, axis=-1, keepdims=True) + y = y / np.linalg.norm(y, axis=-1, keepdims=True) + z = z / np.linalg.norm(z, axis=-1, keepdims=True) + R = np.stack([x, y, z], axis=-2) + t = -np.matmul(R, eye[..., None]) + return np.concatenate([ + np.concatenate([R, t], axis=-1), + np.array([[[0., 0., 0., 1.]]], dtype=eye.dtype).repeat(eye.shape[0], axis=0) + ], axis=-2) + + +def perspective_to_intrinsics( + perspective: np.ndarray +) -> np.ndarray: + """ + OpenGL perspective matrix to OpenCV intrinsics + + Args: + perspective (np.ndarray): [..., 4, 4] OpenGL perspective matrix + + Returns: + (np.ndarray): shape [..., 3, 3] OpenCV intrinsics + """ + ret = np.array([[0.5, 0., 0.5], [0., -0.5, 0.5], [0., 0., 1.]], dtype=perspective.dtype) \ + @ perspective[..., [0, 1, 3], :3] \ + @ np.diag(np.array([1, -1, -1], dtype=perspective.dtype)) + return ret + + +def perspective_to_near_far(perspective: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """ + Get near and far planes from OpenGL perspective matrix + + Args: + """ + a, b = perspective[..., 2, 2], perspective[..., 2, 3] + near, far = b / (a - 1), b / (a + 1) + return near, far + + +@batched(2,0,0) +def intrinsics_to_perspective( + intrinsics: np.ndarray, + near: Union[float, np.ndarray], + far: Union[float, np.ndarray], +) -> np.ndarray: + """ + OpenCV intrinsics to OpenGL perspective matrix + NOTE: not work for tile-shifting intrinsics currently + + Args: + intrinsics (np.ndarray): [..., 3, 3] OpenCV intrinsics matrix + near (float | np.ndarray): [...] near plane to clip + far (float | np.ndarray): [...] far plane to clip + Returns: + (np.ndarray): [..., 4, 4] OpenGL perspective matrix + """ + N = intrinsics.shape[0] + fx, fy = intrinsics[:, 0, 0], intrinsics[:, 1, 1] + cx, cy = intrinsics[:, 0, 2], intrinsics[:, 1, 2] + ret = np.zeros((N, 4, 4), dtype=intrinsics.dtype) + ret[:, 0, 0] = 2 * fx + ret[:, 1, 1] = 2 * fy + ret[:, 0, 2] = -2 * cx + 1 + ret[:, 1, 2] = 2 * cy - 1 + ret[:, 2, 2] = (near + far) / (near - far) + ret[:, 2, 3] = 2. * near * far / (near - far) + ret[:, 3, 2] = -1. + return ret + + +@batched(2) +def extrinsics_to_view( + extrinsics: np.ndarray + ) -> np.ndarray: + """ + OpenCV camera extrinsics to OpenGL view matrix + + Args: + extrinsics (np.ndarray): [..., 4, 4] OpenCV camera extrinsics matrix + + Returns: + (np.ndarray): [..., 4, 4] OpenGL view matrix + """ + return extrinsics * np.array([1, -1, -1, 1], dtype=extrinsics.dtype)[:, None] + + +@batched(2) +def view_to_extrinsics( + view: np.ndarray + ) -> np.ndarray: + """ + OpenGL view matrix to OpenCV camera extrinsics + + Args: + view (np.ndarray): [..., 4, 4] OpenGL view matrix + + Returns: + (np.ndarray): [..., 4, 4] OpenCV camera extrinsics matrix + """ + return view * np.array([1, -1, -1, 1], dtype=view.dtype)[:, None] + + +@batched(2, 0, 0, None) +def normalize_intrinsics( + intrinsics: np.ndarray, + width: Union[int, np.ndarray], + height: Union[int, np.ndarray], + integer_pixel_centers: bool = True +) -> np.ndarray: + """ + Normalize intrinsics from pixel cooridnates to uv coordinates + + Args: + intrinsics (np.ndarray): [..., 3, 3] camera intrinsics(s) to normalize + width (int | np.ndarray): [...] image width(s) + height (int | np.ndarray): [...] 
image height(s) + integer_pixel_centers (bool): whether the integer pixel coordinates are at the center of the pixel. If False, the integer coordinates are at the left-top corner of the pixel. + + Returns: + (np.ndarray): [..., 3, 3] normalized camera intrinsics(s) + """ + zeros = np.zeros_like(width) + ones = np.ones_like(width) + if integer_pixel_centers: + transform = np.stack([ + 1 / width, zeros, 0.5 / width, + zeros, 1 / height, 0.5 / height, + zeros, zeros, ones + ]).reshape(*zeros.shape, 3, 3) + else: + transform = np.stack([ + 1 / width, zeros, zeros, + zeros, 1 / height, zeros, + zeros, zeros, ones + ]).reshape(*zeros.shape, 3, 3) + return transform @ intrinsics + + +@batched(2,0,0,0,0,0,0) +def crop_intrinsics( + intrinsics: np.ndarray, + width: Union[int, np.ndarray], + height: Union[int, np.ndarray], + left: Union[int, np.ndarray], + top: Union[int, np.ndarray], + crop_width: Union[int, np.ndarray], + crop_height: Union[int, np.ndarray] +) -> np.ndarray: + """ + Evaluate the new intrinsics(s) after crop the image: cropped_img = img[top:top+crop_height, left:left+crop_width] + + Args: + intrinsics (np.ndarray): [..., 3, 3] camera intrinsics(s) to crop + width (int | np.ndarray): [...] image width(s) + height (int | np.ndarray): [...] image height(s) + left (int | np.ndarray): [...] left crop boundary + top (int | np.ndarray): [...] top crop boundary + crop_width (int | np.ndarray): [...] crop width + crop_height (int | np.ndarray): [...] crop height + + Returns: + (np.ndarray): [..., 3, 3] cropped camera intrinsics(s) + """ + zeros = np.zeros_like(width) + ones = np.ones_like(width) + transform = np.stack([ + width / crop_width, zeros, -left / crop_width, + zeros, height / crop_height, -top / crop_height, + zeros, zeros, ones + ]).reshape(*zeros.shape, 3, 3) + return transform @ intrinsics + + +@batched(1,0,0) +def pixel_to_uv( + pixel: np.ndarray, + width: Union[int, np.ndarray], + height: Union[int, np.ndarray] +) -> np.ndarray: + """ + Args: + pixel (np.ndarray): [..., 2] pixel coordinrates defined in image space, x range is (0, W - 1), y range is (0, H - 1) + width (int | np.ndarray): [...] image width(s) + height (int | np.ndarray): [...] image height(s) + + Returns: + (np.ndarray): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1) + """ + if not np.issubdtype(pixel.dtype, np.floating): + pixel = pixel.astype(np.float32) + dtype = pixel.dtype + uv = (pixel + np.array(0.5, dtype=dtype)) / np.stack([width, height], axis=-1) + return uv + + +@batched(1,0,0) +def uv_to_pixel( + uv: np.ndarray, + width: Union[int, np.ndarray], + height: Union[int, np.ndarray] +) -> np.ndarray: + """ + Args: + pixel (np.ndarray): [..., 2] pixel coordinrates defined in image space, x range is (0, W - 1), y range is (0, H - 1) + width (int | np.ndarray): [...] image width(s) + height (int | np.ndarray): [...] image height(s) + + Returns: + (np.ndarray): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1) + """ + pixel = uv * np.stack([width, height], axis=-1).astype(uv.dtype) - 0.5 + return pixel + + +@batched(1,0,0) +def pixel_to_ndc( + pixel: np.ndarray, + width: Union[int, np.ndarray], + height: Union[int, np.ndarray] +) -> np.ndarray: + """ + Args: + pixel (np.ndarray): [..., 2] pixel coordinrates defined in image space, x range is (0, W - 1), y range is (0, H - 1) + width (int | np.ndarray): [...] image width(s) + height (int | np.ndarray): [...] 
image height(s)
+
+    Returns:
+        (np.ndarray): [..., 2] pixel coordinates defined in ndc space, the range is (-1, 1)
+    """
+    if not np.issubdtype(pixel.dtype, np.floating):
+        pixel = pixel.astype(np.float32)
+    dtype = pixel.dtype
+    # uv in (0, 1) -> ndc in (-1, 1), with the y axis flipped
+    ndc = (pixel + np.array(0.5, dtype=dtype)) / np.stack([width, height], axis=-1) * np.array([2, -2], dtype=dtype) \
+        + np.array([-1, 1], dtype=dtype)
+    return ndc
+
+
+@batched(0,0,0)
+def project_depth(
+    depth: np.ndarray,
+    near: Union[float, np.ndarray],
+    far: Union[float, np.ndarray]
+) -> np.ndarray:
+    """
+    Project linear depth to depth value in screen space
+
+    Args:
+        depth (np.ndarray): [...] depth value
+        near (float | np.ndarray): [...] near plane to clip
+        far (float | np.ndarray): [...] far plane to clip
+
+    Returns:
+        (np.ndarray): [...] depth value in screen space, value ranging in [0, 1]
+    """
+    return (far - near * far / depth) / (far - near)
+
+
+@batched(0,0,0)
+def depth_buffer_to_linear(
+    depth_buffer: np.ndarray,
+    near: Union[float, np.ndarray],
+    far: Union[float, np.ndarray]
+) -> np.ndarray:
+    """
+    OpenGL depth buffer to linear depth
+
+    Args:
+        depth_buffer (np.ndarray): [...] depth value
+        near (float | np.ndarray): [...] near plane to clip
+        far (float | np.ndarray): [...] far plane to clip
+
+    Returns:
+        (np.ndarray): [...] linear depth
+    """
+    return near * far / (far - (far - near) * depth_buffer)
+
+
+@batched(2,2,2,2)
+def project_gl(
+    points: np.ndarray,
+    model: np.ndarray = None,
+    view: np.ndarray = None,
+    perspective: np.ndarray = None
+    ) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Project 3D points to 2D following the OpenGL convention (except for row major matrices)
+
+    Args:
+        points (np.ndarray): [..., N, 3] or [..., N, 4] 3D points to project, if the last
+            dimension is 4, the points are assumed to be in homogeneous coordinates
+        model (np.ndarray): [..., 4, 4] model matrix
+        view (np.ndarray): [..., 4, 4] view matrix
+        perspective (np.ndarray): [..., 4, 4] perspective matrix
+
+    Returns:
+        scr_coord (np.ndarray): [..., N, 3] screen space coordinates, value ranging in [0, 1].
+            The origin (0., 0., 0.) is corresponding to the left & bottom & nearest
+        linear_depth (np.ndarray): [..., N] linear depth
+    """
+    assert perspective is not None, "perspective matrix is required"
+    if points.shape[-1] == 3:
+        points = np.concatenate([points, np.ones_like(points[..., :1])], axis=-1)
+    if model is not None:
+        points = points @ model.swapaxes(-1, -2)
+    if view is not None:
+        points = points @ view.swapaxes(-1, -2)
+    clip_coord = points @ perspective.swapaxes(-1, -2)
+    ndc_coord = clip_coord[..., :3] / clip_coord[..., 3:]
+    scr_coord = ndc_coord * 0.5 + 0.5
+    linear_depth = clip_coord[..., 3]
+    return scr_coord, linear_depth
+
+
+@batched(2,2,2)
+def project_cv(
+    points: np.ndarray,
+    extrinsics: np.ndarray = None,
+    intrinsics: np.ndarray = None
+    ) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Project 3D points to 2D following the OpenCV convention
+
+    Args:
+        points (np.ndarray): [..., N, 3] or [..., N, 4] 3D points to project, if the last
+            dimension is 4, the points are assumed to be in homogeneous coordinates
+        extrinsics (np.ndarray): [..., 4, 4] extrinsics matrix
+        intrinsics (np.ndarray): [..., 3, 3] intrinsics matrix
+
+    Returns:
+        uv_coord (np.ndarray): [..., N, 2] uv coordinates, value ranging in [0, 1].
+            The origin (0., 0.)
is corresponding to the left & top + linear_depth (np.ndarray): [..., N] linear depth + """ + assert intrinsics is not None, "intrinsics matrix is required" + if points.shape[-1] == 3: + points = np.concatenate([points, np.ones_like(points[..., :1])], axis=-1) + if extrinsics is not None: + points = points @ extrinsics.swapaxes(-1, -2) + points = points[..., :3] @ intrinsics.swapaxes(-1, -2) + with no_warnings(): + uv_coord = points[..., :2] / points[..., 2:] + linear_depth = points[..., 2] + return uv_coord, linear_depth + + +@batched(2,2,2,2) +def unproject_gl( + screen_coord: np.ndarray, + model: np.ndarray = None, + view: np.ndarray = None, + perspective: np.ndarray = None + ) -> np.ndarray: + """ + Unproject screen space coordinates to 3D view space following the OpenGL convention (except for row major matrice) + + Args: + screen_coord (np.ndarray): [..., N, 3] screen space coordinates, value ranging in [0, 1]. + The origin (0., 0., 0.) is corresponding to the left & bottom & nearest + model (np.ndarray): [..., 4, 4] model matrix + view (np.ndarray): [..., 4, 4] view matrix + perspective (np.ndarray): [..., 4, 4] perspective matrix + + Returns: + points (np.ndarray): [..., N, 3] 3d points + """ + assert perspective is not None, "perspective matrix is required" + ndc_xy = screen_coord * 2 - 1 + clip_coord = np.concatenate([ndc_xy, np.ones_like(ndc_xy[..., :1])], axis=-1) + transform = perspective + if view is not None: + transform = transform @ view + if model is not None: + transform = transform @ model + transform = np.linalg.inv(transform) + points = clip_coord @ transform.swapaxes(-1, -2) + points = points[..., :3] / points[..., 3:] + return points + + +@batched(2,1,2,2) +def unproject_cv( + uv_coord: np.ndarray, + depth: np.ndarray = None, + extrinsics: np.ndarray = None, + intrinsics: np.ndarray = None +) -> np.ndarray: + """ + Unproject uv coordinates to 3D view space following the OpenCV convention + + Args: + uv_coord (np.ndarray): [..., N, 2] uv coordinates, value ranging in [0, 1]. + The origin (0., 0.) 
is corresponding to the left & top + depth (np.ndarray): [..., N] depth value + extrinsics (np.ndarray): [..., 4, 4] extrinsics matrix + intrinsics (np.ndarray): [..., 3, 3] intrinsics matrix + + Returns: + points (np.ndarray): [..., N, 3] 3d points + """ + assert intrinsics is not None, "intrinsics matrix is required" + points = np.concatenate([uv_coord, np.ones_like(uv_coord[..., :1])], axis=-1) + points = points @ np.linalg.inv(intrinsics).swapaxes(-1, -2) + if depth is not None: + points = points * depth[..., None] + if extrinsics is not None: + points = np.concatenate([points, np.ones_like(points[..., :1])], axis=-1) + points = (points @ np.linalg.inv(extrinsics).swapaxes(-1, -2))[..., :3] + return points + + +def quaternion_to_matrix(quaternion: np.ndarray, eps: float = 1e-12) -> np.ndarray: + """Converts a batch of quaternions (w, x, y, z) to rotation matrices + + Args: + quaternion (np.ndarray): shape (..., 4), the quaternions to convert + + Returns: + np.ndarray: shape (..., 3, 3), the rotation matrices corresponding to the given quaternions + """ + assert quaternion.shape[-1] == 4 + quaternion = quaternion / np.linalg.norm(quaternion, axis=-1, keepdims=True).clip(min=eps) + w, x, y, z = quaternion[..., 0], quaternion[..., 1], quaternion[..., 2], quaternion[..., 3] + zeros = np.zeros_like(w) + I = np.eye(3, dtype=quaternion.dtype) + xyz = quaternion[..., 1:] + A = xyz[..., :, None] * xyz[..., None, :] - I * (xyz ** 2).sum(axis=-1)[..., None, None] + B = np.stack([ + zeros, -z, y, + z, zeros, -x, + -y, x, zeros + ], axis=-1).reshape(*quaternion.shape[:-1], 3, 3) + rot_mat = I + 2 * (A + w[..., None, None] * B) + return rot_mat + + +def matrix_to_quaternion(rot_mat: np.ndarray, eps: float = 1e-12) -> np.ndarray: + """Convert 3x3 rotation matrix to quaternion (w, x, y, z) + + Args: + rot_mat (np.ndarray): shape (..., 3, 3), the rotation matrices to convert + + Returns: + np.ndarray: shape (..., 4), the quaternions corresponding to the given rotation matrices + """ + # Extract the diagonal and off-diagonal elements of the rotation matrix + m00, m01, m02, m10, m11, m12, m20, m21, m22 = [rot_mat[..., i, j] for i in range(3) for j in range(3)] + + diag = np.diagonal(rot_mat, axis1=-2, axis2=-1) + M = np.array([ + [1, 1, 1], + [1, -1, -1], + [-1, 1, -1], + [-1, -1, 1] + ], dtype=rot_mat.dtype) + wxyz = 0.5 * np.clip(1 + diag @ M.T, 0.0, None) ** 0.5 + max_idx = np.argmax(wxyz, axis=-1) + xw = np.sign(m21 - m12) + yw = np.sign(m02 - m20) + zw = np.sign(m10 - m01) + yz = np.sign(m21 + m12) + xz = np.sign(m02 + m20) + xy = np.sign(m01 + m10) + ones = np.ones_like(xw) + sign = np.where( + max_idx[..., None] == 0, + np.stack([ones, xw, yw, zw], axis=-1), + np.where( + max_idx[..., None] == 1, + np.stack([xw, ones, xy, xz], axis=-1), + np.where( + max_idx[..., None] == 2, + np.stack([yw, xy, ones, yz], axis=-1), + np.stack([zw, xz, yz, ones], axis=-1) + ) + ) + ) + quat = sign * wxyz + quat = quat / np.linalg.norm(quat, axis=-1, keepdims=True).clip(min=eps) + return quat + + +def extrinsics_to_essential(extrinsics: np.ndarray): + """ + extrinsics matrix `[[R, t] [0, 0, 0, 1]]` such that `x' = R (x - t)` to essential matrix such that `x' E x = 0` + + Args: + extrinsics (np.ndaray): [..., 4, 4] extrinsics matrix + + Returns: + (np.ndaray): [..., 3, 3] essential matrix + """ + assert extrinsics.shape[-2:] == (4, 4) + R = extrinsics[..., :3, :3] + t = extrinsics[..., :3, 3] + zeros = np.zeros_like(t[..., 0]) + t_x = np.stack([ + zeros, -t[..., 2], t[..., 1], + t[..., 2], zeros, -t[..., 0], + 
-t[..., 1], t[..., 0], zeros + ]).reshape(*t.shape[:-1], 3, 3) + return t_x @ R + + +def euler_axis_angle_rotation(axis: str, angle: np.ndarray) -> np.ndarray: + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = np.cos(angle) + sin = np.sin(angle) + one = np.ones_like(angle) + zero = np.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + elif axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + elif axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + else: + raise ValueError("letter must be either X, Y or Z.") + + return np.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles: np.ndarray, convention: str = 'XYZ') -> np.ndarray: + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as ndarray of shape (..., 3), XYZ + convention: permutation of "X", "Y" or "Z", representing the order of Euler rotations to apply. + + Returns: + Rotation matrices as ndarray of shape (..., 3, 3). + """ + if euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = [ + euler_axis_angle_rotation(c, euler_angles[..., 'XYZ'.index(c)]) + for c in convention + ] + return matrices[2] @ matrices[1] @ matrices[0] + + +def skew_symmetric(v: np.ndarray): + "Skew symmetric matrix from a 3D vector" + assert v.shape[-1] == 3, "v must be 3D" + x, y, z = v[..., 0], v[..., 1], v[..., 2] + zeros = np.zeros_like(x) + return np.stack([ + zeros, -z, y, + z, zeros, -x, + -y, x, zeros, + ], axis=-1).reshape(*v.shape[:-1], 3, 3) + + +def rotation_matrix_from_vectors(v1: np.ndarray, v2: np.ndarray): + "Rotation matrix that rotates v1 to v2" + I = np.eye(3, dtype=v1.dtype) + v1 = v1 / np.linalg.norm(v1, axis=-1) + v2 = v2 / np.linalg.norm(v2, axis=-1) + v = np.cross(v1, v2, axis=-1) + c = np.sum(v1 * v2, axis=-1) + K = skew_symmetric(v) + R = I + K + (1 / (1 + c)).astype(v1.dtype)[None, None] * (K @ K) # Avoid numpy's default type casting for scalars + return R + + +def axis_angle_to_matrix(axis_angle: np.ndarray, eps: float = 1e-12) -> np.ndarray: + """Convert axis-angle representation (rotation vector) to rotation matrix, whose direction is the axis of rotation and length is the angle of rotation + + Args: + axis_angle (np.ndarray): shape (..., 3), axis-angle vcetors + + Returns: + np.ndarray: shape (..., 3, 3) The rotation matrices for the given axis-angle parameters + """ + batch_shape = axis_angle.shape[:-1] + dtype = axis_angle.dtype + + angle = np.linalg.norm(axis_angle, axis=-1, keepdims=True) + axis = axis_angle / (angle + eps) + + cos = np.cos(angle)[..., None, :] + sin = np.sin(angle)[..., None, :] + + rx, ry, rz = np.split(axis, 3, axis=-1) + zeros = np.zeros((*batch_shape, 1), dtype=dtype) + K = np.concatenate([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], axis=-1).reshape((*batch_shape, 
3, 3)) + + ident = np.eye(3, dtype=dtype) + rot_mat = ident + sin * K + (1 - cos) * (K @ K) + return rot_mat + + +def ray_intersection(p1: np.ndarray, d1: np.ndarray, p2: np.ndarray, d2: np.ndarray): + """ + Compute the intersection/closest point of two D-dimensional rays + If the rays are intersecting, the closest point is the intersection point. + + Args: + p1 (np.ndarray): (..., D) origin of ray 1 + d1 (np.ndarray): (..., D) direction of ray 1 + p2 (np.ndarray): (..., D) origin of ray 2 + d2 (np.ndarray): (..., D) direction of ray 2 + + Returns: + (np.ndarray): (..., N) intersection point + """ + p1, d1, p2, d2 = np.broadcast_arrays(p1, d1, p2, d2) + dtype = p1.dtype + dim = p1.shape[-1] + d = np.stack([d1, d2], axis=-2) # (..., 2, D) + p = np.stack([p1, p2], axis=-2) # (..., 2, D) + A = np.concatenate([ + (np.eye(dim, dtype=dtype) * np.ones((*p.shape[:-2], 2, 1, 1))).reshape(*d.shape[:-2], 2 * dim, dim), # (..., 2 * D, D) + -(np.eye(2, dtype=dtype)[..., None] * d[..., None, :]).swapaxes(-2, -1).reshape(*d.shape[:-2], 2 * dim, 2) # (..., 2 * D, 2) + ], axis=-1) # (..., 2 * D, D + 2) + b = p.reshape(*p.shape[:-2], 2 * dim) # (..., 2 * D) + x = np.linalg.solve(A.swapaxes(-1, -2) @ A + 1e-12 * np.eye(dim + 2, dtype=dtype), (A.swapaxes(-1, -2) @ b[..., :, None])[..., 0]) + return x[..., :dim], (x[..., dim], x[..., dim + 1]) + + +def se3_matrix(R: np.ndarray, t: np.ndarray) -> np.ndarray: + """ + Convert rotation matrix and translation vector to 4x4 transformation matrix. + + Args: + R (np.ndarray): [..., 3, 3] rotation matrix + t (np.ndarray): [..., 3] translation vector + + Returns: + np.ndarray: [..., 4, 4] transformation matrix + """ + assert R.shape[:-2] == t.shape[:-1] + assert R.shape[-1] == 3 and R.shape[-2] == 3 + return np.concatenate([ + np.concatenate([R, t[..., None]], axis=-1), + np.concatenate([np.zeros_like(t), np.ones_like(t[..., :1])], axis=-1)[..., None, :] + ], axis=-2) + + +def slerp_quaternion(q1: np.ndarray, q2: np.ndarray, t: np.ndarray) -> np.ndarray: + """ + Spherical linear interpolation between two unit quaternions. + + Args: + q1 (np.ndarray): [..., d] unit vector 1 + q2 (np.ndarray): [..., d] unit vector 2 + t (np.ndarray): [...] interpolation parameter in [0, 1] + + Returns: + np.ndarray: [..., 3] interpolated unit vector + """ + q1 = q1 / np.linalg.norm(q1, axis=-1, keepdims=True) + q2 = q2 / np.linalg.norm(q2, axis=-1, keepdims=True) + dot = np.sum(q1 * q2, axis=-1, keepdims=True) + + dot = np.where(dot < 0, -dot, dot) # handle negative dot product + + dot = np.minimum(dot, 1.) + theta = np.arccos(dot) * t + + q_ortho = q2 - q1 * dot + q_ortho = q_ortho / np.maximum(np.linalg.norm(q_ortho, axis=-1, keepdims=True), 1e-12) + q = q1 * np.cos(theta) + q_ortho * np.sin(theta) + return q + + +def slerp_rotation_matrix(R1: np.ndarray, R2: np.ndarray, t: np.ndarray) -> np.ndarray: + """ + Spherical linear interpolation between two rotation matrices. + + Args: + R1 (np.ndarray): [..., 3, 3] rotation matrix 1 + R2 (np.ndarray): [..., 3, 3] rotation matrix 2 + t (np.ndarray): [...] interpolation parameter in [0, 1] + + Returns: + np.ndarray: [..., 3, 3] interpolated rotation matrix + """ + quat1 = matrix_to_quaternion(R1) + quat2 = matrix_to_quaternion(R2) + quat = slerp_quaternion(quat1, quat2, t) + return quaternion_to_matrix(quat) + + +def slerp_vector(v1: np.ndarray, v2: np.ndarray, t: np.ndarray) -> np.ndarray: + """ + Spherical linear interpolation between two unit vectors. The vectors are assumed to be normalized. 
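+
+    Example (an illustrative sketch, not from the original docs): halfway between the x and y axes lies at 45 degrees.
+    >>> slerp_vector(np.array([1., 0., 0.]), np.array([0., 1., 0.]), 0.5)
+    array([0.70710678, 0.70710678, 0.        ])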
+ + Args: + v1 (np.ndarray): [..., d] unit vector 1 + v2 (np.ndarray): [..., d] unit vector 2 + t (np.ndarray): [...] interpolation parameter in [0, 1] + + Returns: + np.ndarray: [..., d] interpolated unit vector + """ + dot = np.sum(v1 * v2, axis=-1, keepdims=True) + + dot = np.minimum(dot, 1.) + theta = np.arccos(dot) * t + + v_ortho = v2 - v1 * dot + v_ortho = v_ortho / np.maximum(np.linalg.norm(v_ortho, axis=-1, keepdims=True), 1e-12) + v = v1 * np.cos(theta) + v_ortho * np.sin(theta) + return v + + +def lerp(x1: np.ndarray, x2: np.ndarray, t: np.ndarray) -> np.ndarray: + """ + Linear interpolation between two vectors. + + Args: + x1 (np.ndarray): [..., d] vector 1 + x2 (np.ndarray): [..., d] vector 2 + t (np.ndarray): [...] interpolation parameter. [0, 1] for interpolation between x1 and x2, otherwise for extrapolation. + + Returns: + np.ndarray: [..., d] interpolated vector + """ + return x1 + np.asarray(t)[..., None] * (x2 - x1) + + +def lerp_se3_matrix(T1: np.ndarray, T2: np.ndarray, t: np.ndarray) -> np.ndarray: + """ + Linear interpolation between two SE(3) matrices. + + Args: + T1 (np.ndarray): [..., 4, 4] SE(3) matrix 1 + T2 (np.ndarray): [..., 4, 4] SE(3) matrix 2 + t (np.ndarray): [...] interpolation parameter in [0, 1] + + Returns: + np.ndarray: [..., 4, 4] interpolated SE(3) matrix + """ + R1 = T1[..., :3, :3] + R2 = T2[..., :3, :3] + trans1 = T1[..., :3, 3] + trans2 = T2[..., :3, 3] + R = slerp_rotation_matrix(R1, R2, t) + trans = lerp(trans1, trans2, t) + return se3_matrix(R, trans) + + +def piecewise_lerp(x: np.ndarray, t: np.ndarray, s: np.ndarray, extrapolation_mode: Literal['constant', 'linear'] = 'constant') -> np.ndarray: + """ + Linear spline interpolation. + + ### Parameters: + - `x`: np.ndarray, shape (n, d): the values of data points. + - `t`: np.ndarray, shape (n,): the times of the data points. + - `s`: np.ndarray, shape (m,): the times to be interpolated. + - `extrapolation_mode`: str, the mode of extrapolation. 'constant' means extrapolate the boundary values, 'linear' means extrapolate linearly. + + ### Returns: + - `y`: np.ndarray, shape (..., m, d): the interpolated values. + """ + i = np.searchsorted(t, s, side='left') + if extrapolation_mode == 'constant': + prev = np.clip(i - 1, 0, len(t) - 1) + suc = np.clip(i, 0, len(t) - 1) + elif extrapolation_mode == 'linear': + prev = np.clip(i - 1, 0, len(t) - 2) + suc = np.clip(i, 1, len(t) - 1) + else: + raise ValueError(f'Invalid extrapolation_mode: {extrapolation_mode}') + + u = (s - t[prev]) / np.maximum(t[suc] - t[prev], 1e-12) + y = lerp(x[prev], x[suc], u) + + return y + + +def piecewise_lerp_se3_matrix(T: np.ndarray, t: np.ndarray, s: np.ndarray, extrapolation_mode: Literal['constant', 'linear'] = 'constant') -> np.ndarray: + """ + Linear spline interpolation for SE(3) matrices. + + ### Parameters: + - `T`: np.ndarray, shape (n, 4, 4): the SE(3) matrices. + - `t`: np.ndarray, shape (n,): the times of the data points. + - `s`: np.ndarray, shape (m,): the times to be interpolated. + - `extrapolation_mode`: str, the mode of extrapolation. 'constant' means extrapolate the boundary values, 'linear' means extrapolate linearly. + + ### Returns: + - `T_interp`: np.ndarray, shape (..., m, 4, 4): the interpolated SE(3) matrices. 
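+
+    ### Example
+    Illustrative sketch (not part of the original docs): interpolating halfway between the
+    identity pose and a pure translation of 1 along x yields a translation of 0.5.
+    >>> T = np.stack([np.eye(4), se3_matrix(np.eye(3), np.array([1., 0., 0.]))])
+    >>> piecewise_lerp_se3_matrix(T, t=np.array([0., 1.]), s=np.array([0.5]))[0, :3, 3]
+    array([0.5, 0. , 0. ])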
+ """ + i = np.searchsorted(t, s, side='left') + if extrapolation_mode == 'constant': + prev = np.clip(i - 1, 0, len(t) - 1) + suc = np.clip(i, 0, len(t) - 1) + elif extrapolation_mode == 'linear': + prev = np.clip(i - 1, 0, len(t) - 2) + suc = np.clip(i, 1, len(t) - 1) + else: + raise ValueError(f'Invalid extrapolation_mode: {extrapolation_mode}') + + u = (s - t[prev]) / np.maximum(t[suc] - t[prev], 1e-12) + T = lerp_se3_matrix(T[prev], T[suc], u) + + return T + + +def apply_transform(T: np.ndarray, x: np.ndarray) -> np.ndarray: + """ + Apply SE(3) transformation to a point or a set of points. + + ### Parameters: + - `T`: np.ndarray, shape (..., 4, 4): the SE(3) matrix. + - `x`: np.ndarray, shape (..., 3): the point or a set of points to be transformed. + + ### Returns: + - `x_transformed`: np.ndarray, shape (..., 3): the transformed point or a set of points. + """ + x = np.asarray(x) + assert x.shape[-1] == 3 + T = np.asarray(T) + assert T.shape[-2:] == (4, 4) + x_transformed = (T[..., :3, :3] @ x[..., :, None]) + T[..., :3, 3][..., None] + return x_transformed[..., 0] \ No newline at end of file diff --git a/src/utils3d/numpy/utils.py b/src/utils3d/numpy/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8f3fbfccef9eb8224f658307c425125095f4c2a7 --- /dev/null +++ b/src/utils3d/numpy/utils.py @@ -0,0 +1,625 @@ +import numpy as np +from typing import * +from numbers import Number +import warnings +import functools + +from ._helpers import batched +from .._helpers import no_warnings +from . import transforms +from . import mesh + +__all__ = [ + 'sliding_window_1d', + 'sliding_window_nd', + 'sliding_window_2d', + 'max_pool_1d', + 'max_pool_2d', + 'max_pool_nd', + 'depth_edge', + 'normals_edge', + 'depth_aliasing', + 'interpolate', + 'image_scrcoord', + 'image_uv', + 'image_pixel_center', + 'image_pixel', + 'image_mesh', + 'image_mesh_from_depth', + 'points_to_normals', + 'points_to_normals', + 'chessboard', + 'cube', + 'icosahedron', + 'square', + 'camera_frustum', + 'to4x4' +] + + + +def sliding_window_1d(x: np.ndarray, window_size: int, stride: int, axis: int = -1): + """ + Return x view of the input array with x sliding window of the given kernel size and stride. + The sliding window is performed over the given axis, and the window dimension is append to the end of the output array's shape. + + Args: + x (np.ndarray): input array with shape (..., axis_size, ...) 
+ kernel_size (int): size of the sliding window + stride (int): stride of the sliding window + axis (int): axis to perform sliding window over + + Returns: + a_sliding (np.ndarray): view of the input array with shape (..., n_windows, ..., kernel_size), where n_windows = (axis_size - kernel_size + 1) // stride + """ + assert x.shape[axis] >= window_size, f"kernel_size ({window_size}) is larger than axis_size ({x.shape[axis]})" + axis = axis % x.ndim + shape = (*x.shape[:axis], (x.shape[axis] - window_size + 1) // stride, *x.shape[axis + 1:], window_size) + strides = (*x.strides[:axis], stride * x.strides[axis], *x.strides[axis + 1:], x.strides[axis]) + x_sliding = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides) + return x_sliding + + +def sliding_window_nd(x: np.ndarray, window_size: Tuple[int,...], stride: Tuple[int,...], axis: Tuple[int,...]) -> np.ndarray: + axis = [axis[i] % x.ndim for i in range(len(axis))] + for i in range(len(axis)): + x = sliding_window_1d(x, window_size[i], stride[i], axis[i]) + return x + + +def sliding_window_2d(x: np.ndarray, window_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]], axis: Tuple[int, int] = (-2, -1)) -> np.ndarray: + if isinstance(window_size, int): + window_size = (window_size, window_size) + if isinstance(stride, int): + stride = (stride, stride) + return sliding_window_nd(x, window_size, stride, axis) + + +def max_pool_1d(x: np.ndarray, kernel_size: int, stride: int, padding: int = 0, axis: int = -1): + axis = axis % x.ndim + if padding > 0: + fill_value = np.nan if x.dtype.kind == 'f' else np.iinfo(x.dtype).min + padding_arr = np.full((*x.shape[:axis], padding, *x.shape[axis + 1:]), fill_value=fill_value, dtype=x.dtype) + x = np.concatenate([padding_arr, x, padding_arr], axis=axis) + a_sliding = sliding_window_1d(x, kernel_size, stride, axis) + max_pool = np.nanmax(a_sliding, axis=-1) + return max_pool + + +def max_pool_nd(x: np.ndarray, kernel_size: Tuple[int,...], stride: Tuple[int,...], padding: Tuple[int,...], axis: Tuple[int,...]) -> np.ndarray: + for i in range(len(axis)): + x = max_pool_1d(x, kernel_size[i], stride[i], padding[i], axis[i]) + return x + + +def max_pool_2d(x: np.ndarray, kernel_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]], padding: Union[int, Tuple[int, int]], axis: Tuple[int, int] = (-2, -1)): + if isinstance(kernel_size, Number): + kernel_size = (kernel_size, kernel_size) + if isinstance(stride, Number): + stride = (stride, stride) + if isinstance(padding, Number): + padding = (padding, padding) + axis = tuple(axis) + return max_pool_nd(x, kernel_size, stride, padding, axis) + +@no_warnings(category=RuntimeWarning) +def depth_edge(depth: np.ndarray, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: np.ndarray = None) -> np.ndarray: + """ + Compute the edge mask from depth map. The edge is defined as the pixels whose neighbors have large difference in depth. 
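+
+    Example (illustrative): with an absolute tolerance of 1, the pixels adjacent to the
+    1 -> 5 depth jump are marked as edges.
+    >>> depth = np.array([[1., 1., 5.], [1., 1., 5.]])
+    >>> depth_edge(depth, atol=1.0, kernel_size=3)
+    array([[False,  True,  True],
+           [False,  True,  True]])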
+ + Args: + depth (np.ndarray): shape (..., height, width), linear depth map + atol (float): absolute tolerance + rtol (float): relative tolerance + + Returns: + edge (np.ndarray): shape (..., height, width) of dtype torch.bool + """ + if mask is None: + diff = (max_pool_2d(depth, kernel_size, stride=1, padding=kernel_size // 2) + max_pool_2d(-depth, kernel_size, stride=1, padding=kernel_size // 2)) + else: + diff = (max_pool_2d(np.where(mask, depth, -np.inf), kernel_size, stride=1, padding=kernel_size // 2) + max_pool_2d(np.where(mask, -depth, -np.inf), kernel_size, stride=1, padding=kernel_size // 2)) + + edge = np.zeros_like(depth, dtype=bool) + if atol is not None: + edge |= diff > atol + + if rtol is not None: + edge |= diff / depth > rtol + return edge + + +def depth_aliasing(depth: np.ndarray, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: np.ndarray = None) -> np.ndarray: + """ + Compute the map that indicates the aliasing of x depth map. The aliasing is defined as the pixels which neither close to the maximum nor the minimum of its neighbors. + Args: + depth (np.ndarray): shape (..., height, width), linear depth map + atol (float): absolute tolerance + rtol (float): relative tolerance + + Returns: + edge (np.ndarray): shape (..., height, width) of dtype torch.bool + """ + if mask is None: + diff_max = max_pool_2d(depth, kernel_size, stride=1, padding=kernel_size // 2) - depth + diff_min = max_pool_2d(-depth, kernel_size, stride=1, padding=kernel_size // 2) + depth + else: + diff_max = max_pool_2d(np.where(mask, depth, -np.inf), kernel_size, stride=1, padding=kernel_size // 2) - depth + diff_min = max_pool_2d(np.where(mask, -depth, -np.inf), kernel_size, stride=1, padding=kernel_size // 2) + depth + diff = np.minimum(diff_max, diff_min) + + edge = np.zeros_like(depth, dtype=bool) + if atol is not None: + edge |= diff > atol + if rtol is not None: + edge |= diff / depth > rtol + return edge + +@no_warnings(category=RuntimeWarning) +def normals_edge(normals: np.ndarray, tol: float, kernel_size: int = 3, mask: np.ndarray = None) -> np.ndarray: + """ + Compute the edge mask from normal map. + + Args: + normal (np.ndarray): shape (..., height, width, 3), normal map + tol (float): tolerance in degrees + + Returns: + edge (np.ndarray): shape (..., height, width) of dtype torch.bool + """ + assert normals.ndim >= 3 and normals.shape[-1] == 3, "normal should be of shape (..., height, width, 3)" + normals = normals / (np.linalg.norm(normals, axis=-1, keepdims=True) + 1e-12) + + padding = kernel_size // 2 + normals_window = sliding_window_2d( + np.pad(normals, (*([(0, 0)] * (normals.ndim - 3)), (padding, padding), (padding, padding), (0, 0)), mode='edge'), + window_size=kernel_size, + stride=1, + axis=(-3, -2) + ) + if mask is None: + angle_diff = np.arccos((normals[..., None, None] * normals_window).sum(axis=-3)).max(axis=(-2, -1)) + else: + mask_window = sliding_window_2d( + np.pad(mask, (*([(0, 0)] * (mask.ndim - 3)), (padding, padding), (padding, padding)), mode='edge'), + window_size=kernel_size, + stride=1, + axis=(-3, -2) + ) + angle_diff = np.where(mask_window, np.arccos((normals[..., None, None] * normals_window).sum(axis=-3)), 0).max(axis=(-2, -1)) + + angle_diff = max_pool_2d(angle_diff, kernel_size, stride=1, padding=kernel_size // 2) + edge = angle_diff > np.deg2rad(tol) + return edge + + +@no_warnings(category=RuntimeWarning) +def points_to_normals(point: np.ndarray, mask: np.ndarray = None) -> np.ndarray: + """ + Calculate normal map from point map. 
Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system. + + Args: + point (np.ndarray): shape (height, width, 3), point map + Returns: + normal (np.ndarray): shape (height, width, 3), normal map. + """ + height, width = point.shape[-3:-1] + has_mask = mask is not None + + if mask is None: + mask = np.ones_like(point[..., 0], dtype=bool) + mask_pad = np.zeros((height + 2, width + 2), dtype=bool) + mask_pad[1:-1, 1:-1] = mask + mask = mask_pad + + pts = np.zeros((height + 2, width + 2, 3), dtype=point.dtype) + pts[1:-1, 1:-1, :] = point + up = pts[:-2, 1:-1, :] - pts[1:-1, 1:-1, :] + left = pts[1:-1, :-2, :] - pts[1:-1, 1:-1, :] + down = pts[2:, 1:-1, :] - pts[1:-1, 1:-1, :] + right = pts[1:-1, 2:, :] - pts[1:-1, 1:-1, :] + normal = np.stack([ + np.cross(up, left, axis=-1), + np.cross(left, down, axis=-1), + np.cross(down, right, axis=-1), + np.cross(right, up, axis=-1), + ]) + normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12) + valid = np.stack([ + mask[:-2, 1:-1] & mask[1:-1, :-2], + mask[1:-1, :-2] & mask[2:, 1:-1], + mask[2:, 1:-1] & mask[1:-1, 2:], + mask[1:-1, 2:] & mask[:-2, 1:-1], + ]) & mask[None, 1:-1, 1:-1] + normal = (normal * valid[..., None]).sum(axis=0) + normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12) + + if has_mask: + normal_mask = valid.any(axis=0) + normal = np.where(normal_mask[..., None], normal, 0) + return normal, normal_mask + else: + return normal + + +def depth_to_normals(depth: np.ndarray, intrinsics: np.ndarray, mask: np.ndarray = None) -> np.ndarray: + """ + Calculate normal map from depth map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system. + + Args: + depth (np.ndarray): shape (height, width), linear depth map + intrinsics (np.ndarray): shape (3, 3), intrinsics matrix + Returns: + normal (np.ndarray): shape (height, width, 3), normal map. + """ + has_mask = mask is not None + + height, width = depth.shape[-2:] + if mask is None: + mask = np.ones_like(depth, dtype=bool) + + uv = image_uv(width=width, height=height, dtype=np.float32) + pts = transforms.unproject_cv(uv, depth, intrinsics=intrinsics, extrinsics=None) + + return points_to_normals(pts, mask) + +def interpolate(bary: np.ndarray, tri_id: np.ndarray, attr: np.ndarray, faces: np.ndarray) -> np.ndarray: + """Interpolate with given barycentric coordinates and triangle indices + + Args: + bary (np.ndarray): shape (..., 3), barycentric coordinates + tri_id (np.ndarray): int array of shape (...), triangle indices + attr (np.ndarray): shape (N, M), vertices attributes + faces (np.ndarray): int array of shape (T, 3), face vertex indices + + Returns: + np.ndarray: shape (..., M) interpolated result + """ + faces_ = np.concatenate([np.zeros((1, 3), dtype=faces.dtype), faces + 1], axis=0) + attr_ = np.concatenate([np.zeros((1, attr.shape[1]), dtype=attr.dtype), attr], axis=0) + return np.sum(bary[..., None] * attr_[faces_[tri_id + 1]], axis=-2) + + +def image_scrcoord( + width: int, + height: int, +) -> np.ndarray: + """ + Get OpenGL's screen space coordinates, ranging in [0, 1]. + [0, 0] is the bottom-left corner of the image. 
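+
+    For example (values illustrative, mirroring the `image_uv` example below):
+    >>> image_scrcoord(2, 2):
+    [[[0.25, 0.75], [0.75, 0.75]],
+     [[0.25, 0.25], [0.75, 0.25]]]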
+ + Args: + width (int): image width + height (int): image height + + Returns: + (np.ndarray): shape (height, width, 2) + """ + x, y = np.meshgrid( + np.linspace(0.5 / width, 1 - 0.5 / width, width, dtype=np.float32), + np.linspace(1 - 0.5 / height, 0.5 / height, height, dtype=np.float32), + indexing='xy' + ) + return np.stack([x, y], axis=2) + + +def image_uv( + height: int, + width: int, + left: int = None, + top: int = None, + right: int = None, + bottom: int = None, + dtype: np.dtype = np.float32 +) -> np.ndarray: + """ + Get image space UV grid, ranging in [0, 1]. + + >>> image_uv(10, 10): + [[[0.05, 0.05], [0.15, 0.05], ..., [0.95, 0.05]], + [[0.05, 0.15], [0.15, 0.15], ..., [0.95, 0.15]], + ... ... ... + [[0.05, 0.95], [0.15, 0.95], ..., [0.95, 0.95]]] + + Args: + width (int): image width + height (int): image height + + Returns: + np.ndarray: shape (height, width, 2) + """ + if left is None: left = 0 + if top is None: top = 0 + if right is None: right = width + if bottom is None: bottom = height + u = np.linspace((left + 0.5) / width, (right - 0.5) / width, right - left, dtype=dtype) + v = np.linspace((top + 0.5) / height, (bottom - 0.5) / height, bottom - top, dtype=dtype) + u, v = np.meshgrid(u, v, indexing='xy') + return np.stack([u, v], axis=2) + + +def image_pixel_center( + height: int, + width: int, + left: int = None, + top: int = None, + right: int = None, + bottom: int = None, + dtype: np.dtype = np.float32 +) -> np.ndarray: + """ + Get image pixel center coordinates, ranging in [0, width] and [0, height]. + `image[i, j]` has pixel center coordinates `(j + 0.5, i + 0.5)`. + + >>> image_pixel_center(10, 10): + [[[0.5, 0.5], [1.5, 0.5], ..., [9.5, 0.5]], + [[0.5, 1.5], [1.5, 1.5], ..., [9.5, 1.5]], + ... ... ... + [[0.5, 9.5], [1.5, 9.5], ..., [9.5, 9.5]]] + + Args: + width (int): image width + height (int): image height + + Returns: + np.ndarray: shape (height, width, 2) + """ + if left is None: left = 0 + if top is None: top = 0 + if right is None: right = width + if bottom is None: bottom = height + u = np.linspace(left + 0.5, right - 0.5, right - left, dtype=dtype) + v = np.linspace(top + 0.5, bottom - 0.5, bottom - top, dtype=dtype) + u, v = np.meshgrid(u, v, indexing='xy') + return np.stack([u, v], axis=2) + +def image_pixel( + height: int, + width: int, + left: int = None, + top: int = None, + right: int = None, + bottom: int = None, + dtype: np.dtype = np.int32 +) -> np.ndarray: + """ + Get image pixel coordinates grid, ranging in [0, width - 1] and [0, height - 1]. + `image[i, j]` has pixel center coordinates `(j, i)`. + + >>> image_pixel_center(10, 10): + [[[0, 0], [1, 0], ..., [9, 0]], + [[0, 1.5], [1, 1], ..., [9, 1]], + ... ... ... + [[0, 9.5], [1, 9], ..., [9, 9 ]]] + + Args: + width (int): image width + height (int): image height + + Returns: + np.ndarray: shape (height, width, 2) + """ + if left is None: left = 0 + if top is None: top = 0 + if right is None: right = width + if bottom is None: bottom = height + u = np.arange(left, right, dtype=dtype) + v = np.arange(top, bottom, dtype=dtype) + u, v = np.meshgrid(u, v, indexing='xy') + return np.stack([u, v], axis=2) + + +def image_mesh( + *image_attrs: np.ndarray, + mask: np.ndarray = None, + tri: bool = False, + return_indices: bool = False +) -> Tuple[np.ndarray, ...]: + """ + Get a mesh regarding image pixel uv coordinates as vertices and image grid as faces. 
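+
+    By default each group of four neighboring pixels is connected into one quad face; pass ``tri=True`` to triangulate.
+
+    Example (illustrative sketch, using ``image_uv`` from this module):
+    >>> faces, uv = image_mesh(image_uv(2, 2))
+    >>> faces
+    array([[0, 2, 3, 1]], dtype=int32)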
+
+    Args:
+        *image_attrs (np.ndarray): image attributes in shape (height, width, [channels])
+        mask (np.ndarray, optional): binary mask of shape (height, width), dtype=bool. Defaults to None.
+
+    Returns:
+        faces (np.ndarray): faces connecting neighboring pixels. shape (T, 4) if tri is False, else (T, 3)
+        *vertex_attrs (np.ndarray): vertex attributes in corresponding order with input image_attrs
+        indices (np.ndarray, optional): indices of vertices in the original mesh
+    """
+    assert (len(image_attrs) > 0) or (mask is not None), "At least one of image_attrs or mask should be provided"
+    height, width = image_attrs[0].shape[:2] if mask is None else mask.shape
+    assert all(img.shape[:2] == (height, width) for img in image_attrs), "All image_attrs should have the same shape"
+
+    row_faces = np.stack([np.arange(0, width - 1, dtype=np.int32), np.arange(width, 2 * width - 1, dtype=np.int32), np.arange(1 + width, 2 * width, dtype=np.int32), np.arange(1, width, dtype=np.int32)], axis=1)
+    faces = (np.arange(0, (height - 1) * width, width, dtype=np.int32)[:, None, None] + row_faces[None, :, :]).reshape((-1, 4))
+    if mask is None:
+        if tri:
+            faces = mesh.triangulate(faces)
+        ret = [faces, *(img.reshape(-1, *img.shape[2:]) for img in image_attrs)]
+        if return_indices:
+            ret.append(np.arange(height * width, dtype=np.int32))
+        return tuple(ret)
+    else:
+        quad_mask = (mask[:-1, :-1] & mask[1:, :-1] & mask[1:, 1:] & mask[:-1, 1:]).ravel()
+        faces = faces[quad_mask]
+        if tri:
+            faces = mesh.triangulate(faces)
+        return mesh.remove_unreferenced_vertices(
+            faces,
+            *(x.reshape(-1, *x.shape[2:]) for x in image_attrs),
+            return_indices=return_indices
+        )
+
+def image_mesh_from_depth(
+    depth: np.ndarray,
+    extrinsics: np.ndarray = None,
+    intrinsics: np.ndarray = None,
+    *vertice_attrs: np.ndarray,
+    atol: float = None,
+    rtol: float = None,
+    remove_by_depth: bool = False,
+    return_uv: bool = False,
+    return_indices: bool = False
+) -> Tuple[np.ndarray, ...]:
+    """
+    Get a triangle mesh by lifting depth map to 3D.
+
+    Args:
+        depth (np.ndarray): [H, W] depth map
+        extrinsics (np.ndarray, optional): [4, 4] extrinsics matrix. Defaults to None.
+        intrinsics (np.ndarray, optional): [3, 3] intrinsics matrix. Defaults to None.
+        *vertice_attrs (np.ndarray): [H, W, C] vertex attributes. Defaults to None.
+        atol (float, optional): absolute tolerance. Defaults to None.
+        rtol (float, optional): relative tolerance. Defaults to None.
+            triangles with vertices having depth difference larger than atol + rtol * depth will be marked.
+        remove_by_depth (bool, optional): whether to remove triangles with large depth difference. Defaults to False.
+        return_uv (bool, optional): whether to return uv coordinates. Defaults to False.
+        return_indices (bool, optional): whether to return indices of vertices in the original mesh. Defaults to False.
+ + Returns: + vertices (np.ndarray): [N, 3] vertices + faces (np.ndarray): [T, 3] faces + *vertice_attrs (np.ndarray): [N, C] vertex attributes + image_uv (np.ndarray, optional): [N, 2] uv coordinates + ref_indices (np.ndarray, optional): [N] indices of vertices in the original mesh + """ + height, width = depth.shape + image_uv, image_face = image_mesh(height, width) + depth = depth.reshape(-1) + pts = transforms.unproject_cv(image_uv, depth, extrinsics, intrinsics) + image_face = mesh.triangulate(image_face, vertices=pts) + ref_indices = None + ret = [] + if atol is not None or rtol is not None: + atol = 0 if atol is None else atol + rtol = 0 if rtol is None else rtol + mean = depth[image_face].mean(axis=1) + diff = np.max(np.abs(depth[image_face] - depth[image_face[:, [1, 2, 0]]]), axis=1) + mask = (diff <= atol + rtol * mean) + image_face_ = image_face[mask] + image_face_, ref_indices = mesh.remove_unreferenced_vertices(image_face_, return_indices=True) + + remove = remove_by_depth and ref_indices is not None + if remove: + pts = pts[ref_indices] + image_face = image_face_ + ret += [pts, image_face] + for attr in vertice_attrs: + ret.append(attr.reshape(-1, attr.shape[-1]) if not remove else attr.reshape(-1, attr.shape[-1])[ref_indices]) + if return_uv: + ret.append(image_uv if not remove else image_uv[ref_indices]) + if return_indices and ref_indices is not None: + ret.append(ref_indices) + return tuple(ret) + + +def chessboard(width: int, height: int, grid_size: int, color_a: np.ndarray, color_b: np.ndarray) -> np.ndarray: + """get x chessboard image + + Args: + width (int): image width + height (int): image height + grid_size (int): size of chessboard grid + color_a (np.ndarray): color of the grid at the top-left corner + color_b (np.ndarray): color in complementary grid cells + + Returns: + image (np.ndarray): shape (height, width, channels), chessboard image + """ + x = np.arange(width) // grid_size + y = np.arange(height) // grid_size + mask = (x[None, :] + y[:, None]) % 2 + image = (1 - mask[..., None]) * color_a + mask[..., None] * color_b + return image + + +def square(tri: bool = False) -> Tuple[np.ndarray, np.ndarray]: + """ + Get a square mesh of area 1 centered at origin in the xy-plane. + + ### Returns + vertices (np.ndarray): shape (4, 3) + faces (np.ndarray): shape (1, 4) + """ + vertices = np.array([ + [-0.5, 0.5, 0], [0.5, 0.5, 0], [0.5, -0.5, 0], [-0.5, -0.5, 0] # v0-v1-v2-v3 + ], dtype=np.float32) + if tri: + faces = np.array([[0, 1, 2], [0, 2, 3]], dtype=np.int32) + else: + faces = np.array([[0, 1, 2, 3]], dtype=np.int32) + return vertices, faces + + +def cube(tri: bool = False) -> Tuple[np.ndarray, np.ndarray]: + """ + Get x cube mesh of size 1 centered at origin. + + ### Parameters + tri (bool, optional): return triangulated mesh. Defaults to False, which returns quad mesh. 
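+
+    ### Example
+    Illustrative: with ``tri=True`` the six quads are split into twelve triangles.
+    >>> v, f = cube(tri=True)
+    >>> v.shape, f.shape
+    ((8, 3), (12, 3))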
+ + ### Returns + vertices (np.ndarray): shape (8, 3) + faces (np.ndarray): shape (12, 3) + """ + vertices = np.array([ + [-0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [0.5, -0.5, 0.5], [-0.5, -0.5, 0.5], # v0-v1-v2-v3 + [-0.5, 0.5, -0.5], [0.5, 0.5, -0.5], [0.5, -0.5, -0.5], [-0.5, -0.5, -0.5] # v4-v5-v6-v7 + ], dtype=np.float32).reshape((-1, 3)) + + faces = np.array([ + [0, 1, 2, 3], # v0-v1-v2-v3 (front) + [4, 5, 1, 0], # v4-v5-v1-v0 (top) + [3, 2, 6, 7], # v3-v2-v6-v7 (bottom) + [5, 4, 7, 6], # v5-v4-v7-v6 (back) + [1, 5, 6, 2], # v1-v5-v6-v2 (right) + [4, 0, 3, 7] # v4-v0-v3-v7 (left) + ], dtype=np.int32) + + if tri: + faces = mesh.triangulate(faces, vertices=vertices) + + return vertices, faces + + +def camera_frustum(extrinsics: np.ndarray, intrinsics: np.ndarray, depth: float = 1.0) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Get x triangle mesh of camera frustum. + """ + assert extrinsics.shape == (4, 4) and intrinsics.shape == (3, 3) + vertices = transforms.unproject_cv( + np.array([[0, 0], [0, 0], [0, 1], [1, 1], [1, 0]], dtype=np.float32), + np.array([0] + [depth] * 4, dtype=np.float32), + extrinsics, + intrinsics + ).astype(np.float32) + edges = np.array([ + [0, 1], [0, 2], [0, 3], [0, 4], + [1, 2], [2, 3], [3, 4], [4, 1] + ], dtype=np.int32) + faces = np.array([ + [0, 1, 2], + [0, 2, 3], + [0, 3, 4], + [0, 4, 1], + [1, 2, 3], + [1, 3, 4] + ], dtype=np.int32) + return vertices, edges, faces + + +def icosahedron(): + A = (1 + 5 ** 0.5) / 2 + vertices = np.array([ + [0, 1, A], [0, -1, A], [0, 1, -A], [0, -1, -A], + [1, A, 0], [-1, A, 0], [1, -A, 0], [-1, -A, 0], + [A, 0, 1], [A, 0, -1], [-A, 0, 1], [-A, 0, -1] + ], dtype=np.float32) + faces = np.array([ + [0, 1, 8], [0, 8, 4], [0, 4, 5], [0, 5, 10], [0, 10, 1], + [3, 2, 9], [3, 9, 6], [3, 6, 7], [3, 7, 11], [3, 11, 2], + [1, 6, 8], [8, 9, 4], [4, 2, 5], [5, 11, 10], [10, 7, 1], + [2, 4, 9], [9, 8, 6], [6, 1, 7], [7, 10, 11], [11, 5, 2] + ], dtype=np.int32) + return vertices, faces diff --git a/src/utils3d/torch/__init__.py b/src/utils3d/torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bffcf41b5e906c25f8c8f01fb0a1b557151103c1 --- /dev/null +++ b/src/utils3d/torch/__init__.py @@ -0,0 +1,139 @@ +import importlib +import itertools +import torch +from typing import TYPE_CHECKING + +__modules_all__ = { + 'mesh': [ + 'triangulate', + 'compute_face_normal', + 'compute_face_angles', + 'compute_vertex_normal', + 'compute_vertex_normal_weighted', + 'compute_edges', + 'compute_connected_components', + 'compute_edge_connected_components', + 'compute_boundarys', + 'compute_dual_graph', + 'remove_unreferenced_vertices', + 'remove_corrupted_faces', + 'remove_isolated_pieces', + 'merge_duplicate_vertices', + 'subdivide_mesh_simple', + 'compute_face_tbn', + 'compute_vertex_tbn', + 'laplacian', + 'laplacian_smooth_mesh', + 'taubin_smooth_mesh', + 'laplacian_hc_smooth_mesh', + ], + 'nerf': [ + 'get_rays', + 'get_image_rays', + 'get_mipnerf_cones', + 'volume_rendering', + 'bin_sample', + 'importance_sample', + 'nerf_render_rays', + 'mipnerf_render_rays', + 'nerf_render_view', + 'mipnerf_render_view', + 'InstantNGP', + ], + 'utils': [ + 'sliding_window_1d', + 'sliding_window_2d', + 'sliding_window_nd', + 'image_uv', + 'image_pixel_center', + 'image_mesh', + 'chessboard', + 'depth_edge', + 'depth_aliasing', + 'image_mesh_from_depth', + 'point_to_normal', + 'depth_to_normal', + 'masked_min', + 'masked_max', + 'bounding_rect' + ], + 'transforms': [ + 'perspective', + 'perspective_from_fov', + 
'perspective_from_fov_xy', + 'intrinsics_from_focal_center', + 'intrinsics_from_fov', + 'intrinsics_from_fov_xy', + 'view_look_at', + 'extrinsics_look_at', + 'perspective_to_intrinsics', + 'intrinsics_to_perspective', + 'extrinsics_to_view', + 'view_to_extrinsics', + 'normalize_intrinsics', + 'crop_intrinsics', + 'pixel_to_uv', + 'pixel_to_ndc', + 'uv_to_pixel', + 'project_depth', + 'depth_buffer_to_linear', + 'project_gl', + 'project_cv', + 'unproject_gl', + 'unproject_cv', + 'skew_symmetric', + 'rotation_matrix_from_vectors', + 'euler_axis_angle_rotation', + 'euler_angles_to_matrix', + 'matrix_to_euler_angles', + 'matrix_to_quaternion', + 'quaternion_to_matrix', + 'matrix_to_axis_angle', + 'axis_angle_to_matrix', + 'axis_angle_to_quaternion', + 'quaternion_to_axis_angle', + 'slerp', + 'interpolate_extrinsics', + 'interpolate_view', + 'extrinsics_to_essential', + 'to4x4', + 'rotation_matrix_2d', + 'rotate_2d', + 'translate_2d', + 'scale_2d', + 'apply_2d', + ], + 'rasterization': [ + 'RastContext', + 'rasterize_triangle_faces', + 'warp_image_by_depth', + 'warp_image_by_forward_flow', + ], +} + + +__all__ = list(itertools.chain(*__modules_all__.values())) + +def __getattr__(name): + try: + return globals()[name] + except KeyError: + pass + + try: + module_name = next(m for m in __modules_all__ if name in __modules_all__[m]) + except StopIteration: + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + module = importlib.import_module(f'.{module_name}', __name__) + for key in __modules_all__[module_name]: + globals()[key] = getattr(module, key) + + return globals()[name] + + +if TYPE_CHECKING: + from .transforms import * + from .mesh import * + from .utils import * + from .nerf import * + from .rasterization import * \ No newline at end of file diff --git a/src/utils3d/torch/_helpers.py b/src/utils3d/torch/_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..442e2cb6358ba4f105d2664f9b9f44b6ec6561ca --- /dev/null +++ b/src/utils3d/torch/_helpers.py @@ -0,0 +1,103 @@ +# decorator +import torch +from numbers import Number +import inspect +from functools import wraps +from .._helpers import suppress_traceback + + +def get_device(args, kwargs): + device = None + for arg in (list(args) + list(kwargs.values())): + if isinstance(arg, torch.Tensor): + if device is None: + device = arg.device + elif device != arg.device: + raise ValueError("All tensors must be on the same device.") + return device + + +def get_args_order(func, args, kwargs): + """ + Get the order of the arguments of a function. 
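+
+    For example (illustrative): for ``def f(a, b, c)`` called as ``f(x, c=y)``, this returns
+    ``([0], {'c': 2})`` -- the positional ``x`` binds to ``a`` (index 0) and the keyword
+    ``c`` keeps its index 2 in the signature.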
+ """ + names = inspect.getfullargspec(func).args + names_idx = {name: i for i, name in enumerate(names)} + args_order = [] + kwargs_order = {} + for name, arg in kwargs.items(): + if name in names: + kwargs_order[name] = names_idx[name] + names.remove(name) + for i, arg in enumerate(args): + if i < len(names): + args_order.append(names_idx[names[i]]) + return args_order, kwargs_order + + +def broadcast_args(args, kwargs, args_dim, kwargs_dim): + spatial = [] + for arg, arg_dim in zip(args + list(kwargs.values()), args_dim + list(kwargs_dim.values())): + if isinstance(arg, torch.Tensor) and arg_dim is not None: + arg_spatial = arg.shape[:arg.ndim-arg_dim] + if len(arg_spatial) > len(spatial): + spatial = [1] * (len(arg_spatial) - len(spatial)) + spatial + for j in range(len(arg_spatial)): + if spatial[-j] < arg_spatial[-j]: + if spatial[-j] == 1: + spatial[-j] = arg_spatial[-j] + else: + raise ValueError("Cannot broadcast arguments.") + for i, arg in enumerate(args): + if isinstance(arg, torch.Tensor) and args_dim[i] is not None: + args[i] = torch.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-args_dim[i]:]]) + for key, arg in kwargs.items(): + if isinstance(arg, torch.Tensor) and kwargs_dim[key] is not None: + kwargs[key] = torch.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-kwargs_dim[key]:]]) + return args, kwargs, spatial + +@suppress_traceback +def batched(*dims): + """ + Decorator that allows a function to be called with batched arguments. + """ + def decorator(func): + @wraps(func) + def wrapper(*args, device=torch.device('cpu'), **kwargs): + args = list(args) + # get arguments dimensions + args_order, kwargs_order = get_args_order(func, args, kwargs) + args_dim = [dims[i] for i in args_order] + kwargs_dim = {key: dims[i] for key, i in kwargs_order.items()} + # convert to torch tensor + device = get_device(args, kwargs) or device + for i, arg in enumerate(args): + if isinstance(arg, (Number, list, tuple)) and args_dim[i] is not None: + args[i] = torch.tensor(arg, device=device) + for key, arg in kwargs.items(): + if isinstance(arg, (Number, list, tuple)) and kwargs_dim[key] is not None: + kwargs[key] = torch.tensor(arg, device=device) + # broadcast arguments + args, kwargs, spatial = broadcast_args(args, kwargs, args_dim, kwargs_dim) + for i, (arg, arg_dim) in enumerate(zip(args, args_dim)): + if isinstance(arg, torch.Tensor) and arg_dim is not None: + args[i] = arg.reshape([-1, *arg.shape[arg.ndim-arg_dim:]]) + for key, arg in kwargs.items(): + if isinstance(arg, torch.Tensor) and kwargs_dim[key] is not None: + kwargs[key] = arg.reshape([-1, *arg.shape[arg.ndim-kwargs_dim[key]:]]) + # call function + results = func(*args, **kwargs) + type_results = type(results) + results = list(results) if isinstance(results, (tuple, list)) else [results] + # restore spatial dimensions + for i, result in enumerate(results): + results[i] = result.reshape([*spatial, *result.shape[1:]]) + if type_results == tuple: + results = tuple(results) + elif type_results == list: + results = list(results) + else: + results = results[0] + return results + return wrapper + return decorator \ No newline at end of file diff --git a/src/utils3d/torch/mesh.py b/src/utils3d/torch/mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..5b874d163e5edad3ef871a276b4edccf2e593265 --- /dev/null +++ b/src/utils3d/torch/mesh.py @@ -0,0 +1,688 @@ +import torch +import torch.nn.functional as F +from typing import * +from ._helpers import batched + + +__all__ = [ + 'triangulate', + 
'compute_face_normal', + 'compute_face_angles', + 'compute_vertex_normal', + 'compute_vertex_normal_weighted', + 'compute_edges', + 'compute_connected_components', + 'compute_edge_connected_components', + 'compute_boundarys', + 'compute_dual_graph', + 'remove_unreferenced_vertices', + 'remove_corrupted_faces', + 'remove_isolated_pieces', + 'merge_duplicate_vertices', + 'subdivide_mesh_simple', + 'compute_face_tbn', + 'compute_vertex_tbn', + 'laplacian', + 'laplacian_smooth_mesh', + 'taubin_smooth_mesh', + 'laplacian_hc_smooth_mesh', +] + + +def _group( + values: torch.Tensor, + required_group_size: Optional[int] = None, + return_values: bool = False +) -> Tuple[Union[List[torch.Tensor], torch.Tensor], Optional[torch.Tensor]]: + """ + Group values into groups with identical values. + + Args: + values (torch.Tensor): [N] values to group + required_group_size (int, optional): required group size. Defaults to None. + return_values (bool, optional): return values of groups. Defaults to False. + + Returns: + group (Union[List[torch.Tensor], torch.Tensor]): list of groups or group indices. It will be a list of groups if required_group_size is None, otherwise a tensor of group indices. + group_values (Optional[torch.Tensor]): values of groups. Only returned if return_values is True. + """ + sorted_values, indices = torch.sort(values) + nondupe = torch.cat([torch.tensor([True], dtype=torch.bool, device=values.device), sorted_values[1:] != sorted_values[:-1]]) + nondupe_indices = torch.cumsum(nondupe, dim=0) - 1 + counts = torch.bincount(nondupe_indices) + if required_group_size is None: + groups = torch.split(indices, counts.tolist()) + if return_values: + group_values = sorted_values[nondupe] + return groups, group_values + else: + return groups + else: + counts = counts[nondupe_indices] + groups = indices[counts == required_group_size].reshape(-1, required_group_size) + if return_values: + group_values = sorted_values[nondupe][counts[nondupe] == required_group_size] + return groups, group_values + else: + return groups + +def triangulate( + faces: torch.Tensor, + vertices: torch.Tensor = None, + backslash: bool = None +) -> torch.Tensor: + """ + Triangulate a polygonal mesh. + + Args: + faces (torch.Tensor): [..., L, P] polygonal faces + vertices (torch.Tensor, optional): [..., N, 3] 3-dimensional vertices. + If given, the triangulation is performed according to the distance + between vertices. Defaults to None. + backslash (torch.Tensor, optional): [..., L] boolean array indicating + how to triangulate the quad faces. Defaults to None. 
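+
+    Example (illustrative): a single quad split along the 1-3 diagonal.
+    >>> triangulate(torch.tensor([[0, 1, 2, 3]]), backslash=False)
+    tensor([[0, 1, 3],
+            [3, 1, 2]])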
+ + + Returns: + (torch.Tensor): [L * (P - 2), 3] triangular faces + """ + if faces.shape[-1] == 3: + return faces + P = faces.shape[-1] + if vertices is not None: + assert faces.shape[-1] == 4, "now only support quad mesh" + if backslash is None: + faces_idx = faces.long() + backslash = torch.norm(vertices[faces_idx[..., 0]] - vertices[faces_idx[..., 2]], p=2, dim=-1) < \ + torch.norm(vertices[faces_idx[..., 1]] - vertices[faces_idx[..., 3]], p=2, dim=-1) + if backslash is None: + loop_indice = torch.stack([ + torch.zeros(P - 2, dtype=int), + torch.arange(1, P - 1, 1, dtype=int), + torch.arange(2, P, 1, dtype=int) + ], axis=1) + return faces[:, loop_indice].reshape(-1, 3) + else: + assert faces.shape[-1] == 4, "now only support quad mesh" + if isinstance(backslash, bool): + if backslash: + faces = faces[:, [0, 1, 2, 0, 2, 3]].reshape(-1, 3) + else: + faces = faces[:, [0, 1, 3, 3, 1, 2]].reshape(-1, 3) + else: + faces = torch.where( + backslash[:, None], + faces[:, [0, 1, 2, 0, 2, 3]], + faces[:, [0, 1, 3, 3, 1, 2]] + ).reshape(-1, 3) + return faces + + +@batched(2, None) +def compute_face_normal( + vertices: torch.Tensor, + faces: torch.Tensor +) -> torch.Tensor: + """ + Compute face normals of a triangular mesh + + Args: + vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices + faces (torch.Tensor): [..., T, 3] triangular face indices + + Returns: + normals (torch.Tensor): [..., T, 3] face normals + """ + N = vertices.shape[0] + index = torch.arange(N)[:, None] + normal = torch.cross( + vertices[index, faces[..., 1].long()] - vertices[index, faces[..., 0].long()], + vertices[index, faces[..., 2].long()] - vertices[index, faces[..., 0].long()], + dim=-1 + ) + return F.normalize(normal, p=2, dim=-1) + + +@batched(2, None) +def compute_face_angles( + vertices: torch.Tensor, + faces: torch.Tensor +) -> torch.Tensor: + """ + Compute face angles of a triangular mesh + + Args: + vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + + Returns: + angles (torch.Tensor): [..., T, 3] face angles + """ + face_angles = [] + for i in range(3): + edge1 = torch.index_select(vertices, dim=-2, index=faces[:, (i + 1) % 3]) - torch.index_select(vertices, dim=-2, index=faces[:, i]) + edge2 = torch.index_select(vertices, dim=-2, index=faces[:, (i + 2) % 3]) - torch.index_select(vertices, dim=-2, index=faces[:, i]) + face_angle = torch.arccos(torch.sum(F.normalize(edge1, p=2, dim=-1) * F.normalize(edge2, p=2, dim=-1), dim=-1)) + face_angles.append(face_angle) + face_angles = torch.stack(face_angles, dim=-1) + return face_angles + + +@batched(2, None, 2) +def compute_vertex_normal( + vertices: torch.Tensor, + faces: torch.Tensor, + face_normal: torch.Tensor = None +) -> torch.Tensor: + """ + Compute vertex normals of a triangular mesh by averaging neightboring face normals + + Args: + vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + face_normal (torch.Tensor, optional): [..., T, 3] face normals. + None to compute face normals from vertices and faces. Defaults to None. 
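+
+    Example (illustrative): a single triangle in the xy-plane gives +z normals.
+    >>> vertices = torch.tensor([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]])
+    >>> faces = torch.tensor([[0, 1, 2]])
+    >>> compute_vertex_normal(vertices, faces)
+    tensor([[0., 0., 1.],
+            [0., 0., 1.],
+            [0., 0., 1.]])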
+ + Returns: + normals (torch.Tensor): [..., N, 3] vertex normals + """ + N = vertices.shape[0] + assert faces.shape[-1] == 3, "Only support triangular mesh" + if face_normal is None: + face_normal = compute_face_normal(vertices, faces) + face_normal = face_normal[:, :, None, :].expand(-1, -1, 3, -1).flatten(-3, -2) + faces = faces.flatten() + vertex_normal = torch.index_put(torch.zeros_like(vertices), (torch.arange(N)[:, None], faces[None, :]), face_normal, accumulate=True) + vertex_normal = F.normalize(vertex_normal, p=2, dim=-1) + return vertex_normal + + +@batched(2, None, 2) +def compute_vertex_normal_weighted( + vertices: torch.Tensor, + faces: torch.Tensor, + face_normal: torch.Tensor = None +) -> torch.Tensor: + """ + Compute vertex normals of a triangular mesh by weighted sum of neightboring face normals + according to the angles + + Args: + vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + face_normal (torch.Tensor, optional): [..., T, 3] face normals. + None to compute face normals from vertices and faces. Defaults to None. + + Returns: + normals (torch.Tensor): [..., N, 3] vertex normals + """ + N = vertices.shape[0] + if face_normal is None: + face_normal = compute_face_normal(vertices, faces) + face_angle = compute_face_angles(vertices, faces) + face_normal = face_normal[:, :, None, :].expand(-1, -1, 3, -1) * face_angle[..., None] + vertex_normal = torch.index_put(torch.zeros_like(vertices), (torch.arange(N)[:, None], faces.view(N, -1)), face_normal.view(N, -1, 3), accumulate=True) + vertex_normal = F.normalize(vertex_normal, p=2, dim=-1) + return vertex_normal + + +def compute_edges( + faces: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute edges of a mesh. + + Args: + faces (torch.Tensor): [T, 3] triangular face indices + + Returns: + edges (torch.Tensor): [E, 2] edge indices + face2edge (torch.Tensor): [T, 3] mapping from face to edge + counts (torch.Tensor): [E] degree of each edge + """ + T = faces.shape[0] + edges = torch.cat([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], dim=0) # [3T, 2] + edges = torch.sort(edges, dim=1).values + edges, inv_map, counts = torch.unique(edges, return_inverse=True, return_counts=True, dim=0) + face2edge = inv_map.view(3, T).T + return edges, face2edge, counts + + +def compute_connected_components( + faces: torch.Tensor, + edges: torch.Tensor=None, + face2edge: torch.Tensor=None +) -> List[torch.Tensor]: + """ + Compute connected faces of a mesh. + + Args: + faces (torch.Tensor): [T, 3] triangular face indices + edges (torch.Tensor, optional): [E, 2] edge indices. Defaults to None. + face2edge (torch.Tensor, optional): [T, 3] mapping from face to edge. Defaults to None. + NOTE: If edges and face2edge are not provided, they will be computed. 
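# Usage sketch (illustrative only): `compute_edges` on two triangles sharing the
# edge (1, 2). Five unique edges come back; the shared edge has degree 2 and the
# remaining four are boundary edges with degree 1.
import torch

faces = torch.tensor([[0, 1, 2], [2, 1, 3]])
edges, face2edge, counts = compute_edges(faces)
# edges  -> tensor([[0, 1], [0, 2], [1, 2], [1, 3], [2, 3]])
# counts -> tensor([1, 1, 2, 1, 1])
boundary_edges = edges[counts == 1]   # the four outer edges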
+ + Returns: + components (List[torch.Tensor]): list of connected faces + """ + T = faces.shape[0] + if edges is None or face2edge is None: + edges, face2edge, _ = compute_edges(faces) + E = edges.shape[0] + + labels = torch.arange(T, dtype=torch.int32, device=faces.device) + while True: + edge_labels = torch.scatter_reduce( + torch.zeros(E, dtype=torch.int32, device=faces.device), + 0, + face2edge.flatten().long(), + labels.view(-1, 1).expand(-1, 3).flatten(), + reduce='amin', + include_self=False + ) + new_labels = torch.min(edge_labels[face2edge], dim=-1).values + if torch.equal(labels, new_labels): + break + labels = new_labels + + components = _group(labels) + + return components + + +def compute_edge_connected_components( + edges: torch.Tensor, +) -> List[torch.Tensor]: + """ + Compute connected edges of a mesh. + + Args: + edges (torch.Tensor): [E, 2] edge indices + + Returns: + components (List[torch.Tensor]): list of connected edges + """ + E = edges.shape[0] + + # Re-index edges + verts, edges = torch.unique(edges.flatten(), return_inverse=True) + edges = edges.view(-1, 2) + V = verts.shape[0] + + labels = torch.arange(E, dtype=torch.int32, device=edges.device) + while True: + vertex_labels = torch.scatter_reduce( + torch.zeros(V, dtype=torch.int32, device=edges.device), + 0, + edges.flatten().long(), + labels.view(-1, 1).expand(-1, 2).flatten(), + reduce='amin', + include_self=False + ) + new_labels = torch.min(vertex_labels[edges], dim=-1).values + if torch.equal(labels, new_labels): + break + labels = new_labels + + components = _group(labels) + + return components + + +def compute_boundarys( + faces: torch.Tensor, + edges: torch.Tensor=None, + face2edge: torch.Tensor=None, + edge_degrees: torch.Tensor=None +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Compute boundary edges of a mesh. + + Args: + faces (torch.Tensor): [T, 3] triangular face indices + edges (torch.Tensor): [E, 2] edge indices. + face2edge (torch.Tensor): [T, 3] mapping from face to edge. + edge_degrees (torch.Tensor): [E] degree of each edge. 
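# Usage sketch (illustrative only): two triangles with no shared edge form two
# connected components; the returned list holds the face indices of each piece.
import torch

faces = torch.tensor([[0, 1, 2], [3, 4, 5]])
components = compute_connected_components(faces)
# components -> [tensor([0]), tensor([1])]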
+ + Returns: + boundary_edge_indices (List[torch.Tensor]): list of boundary edge indices + boundary_face_indices (List[torch.Tensor]): list of boundary face indices + """ + # Map each edge to boundary edge index + boundary_edges = edges[edge_degrees == 1] # [BE, 2] + boundary_edges_idx = torch.nonzero(edge_degrees == 1, as_tuple=False).flatten() # [BE] + E = edges.shape[0] # Edge count + BE = boundary_edges.shape[0] # Boundary edge count + map_to_boundary_edges = torch.full((E,), -1, dtype=torch.int32, device=faces.device) # [E] + map_to_boundary_edges[boundary_edges_idx] = torch.arange(BE, dtype=torch.int32, device=faces.device) + + # Re-index boundary vertices + boundary_vertices, boundary_edges = torch.unique(boundary_edges.flatten(), return_inverse=True) + boundary_edges = boundary_edges.view(-1, 2) + BV = boundary_vertices.shape[0] + + boundary_edge_labels = torch.arange(BE, dtype=torch.int32, device=faces.device) + while True: + boundary_vertex_labels = torch.scatter_reduce( + torch.zeros(BV, dtype=torch.int32, device=faces.device), + 0, + boundary_edges.flatten().long(), + boundary_edge_labels.view(-1, 1).expand(-1, 2).flatten(), + reduce='amin', + include_self=False + ) + new_boundary_edge_labels = torch.min(boundary_vertex_labels[boundary_edges], dim=-1).values + if torch.equal(boundary_edge_labels, new_boundary_edge_labels): + break + boundary_edge_labels = new_boundary_edge_labels + + labels = torch.unique(boundary_edge_labels) + boundary_edge_indices = [boundary_edges_idx[boundary_edge_labels == label] for label in labels] + edge_labels = torch.full((E,), -1, dtype=torch.int32, device=faces.device) + edge_labels[boundary_edges_idx] = boundary_edge_labels + boundary_face_indices = [torch.nonzero((edge_labels[face2edge] == label).any(dim=-1), as_tuple=False).flatten() for label in labels] + + return boundary_edge_indices, boundary_face_indices + + +def compute_dual_graph( + face2edge: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute dual graph of a mesh. + + Args: + face2edge (torch.Tensor): [T, 3] mapping from face to edge. + + Returns: + dual_edges (torch.Tensor): [DE, 2] face indices of dual edges + dual_edge2edge (torch.Tensor): [DE] mapping from dual edge to edge + """ + all_edge_indices = face2edge.flatten() # [3T] + dual_edges, dual_edge2edge = _group(all_edge_indices, required_group_size=2, return_values=True) + dual_edges = dual_edges // face2edge.shape[1] + return dual_edges, dual_edge2edge + + +def remove_unreferenced_vertices( + faces: torch.Tensor, + *vertice_attrs, + return_indices: bool = False +) -> Tuple[torch.Tensor, ...]: + """ + Remove unreferenced vertices of a mesh. + Unreferenced vertices are removed, and the face indices are updated accordingly. + + Args: + faces (torch.Tensor): [T, P] face indices + *vertice_attrs: vertex attributes + + Returns: + faces (torch.Tensor): [T, P] face indices + *vertice_attrs: vertex attributes + indices (torch.Tensor, optional): [N] indices of vertices that are kept. Defaults to None. 
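# Usage sketch (illustrative only): the dual graph links faces that share an edge.
# For the two-triangle mesh from the `compute_edges` sketch above there is exactly
# one dual edge, connecting face 0 and face 1 across their common edge (1, 2).
import torch

faces = torch.tensor([[0, 1, 2], [2, 1, 3]])
edges, face2edge, counts = compute_edges(faces)
dual_edges, dual_edge2edge = compute_dual_graph(face2edge)
# dual_edges     -> tensor([[0, 1]])   (face 0 is adjacent to face 1)
# dual_edge2edge -> tensor([2])        (index of the shared edge (1, 2) in `edges`)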
+ """ + P = faces.shape[-1] + fewer_indices, inv_map = torch.unique(faces, return_inverse=True) + faces = inv_map.to(torch.int32).reshape(-1, P) + ret = [faces] + for attr in vertice_attrs: + ret.append(attr[fewer_indices]) + if return_indices: + ret.append(fewer_indices) + return tuple(ret) + + +def remove_corrupted_faces( + faces: torch.Tensor +) -> torch.Tensor: + """ + Remove corrupted faces (faces with duplicated vertices) + + Args: + faces (torch.Tensor): [T, 3] triangular face indices + + Returns: + torch.Tensor: [T_, 3] triangular face indices + """ + corrupted = (faces[:, 0] == faces[:, 1]) | (faces[:, 1] == faces[:, 2]) | (faces[:, 2] == faces[:, 0]) + return faces[~corrupted] + + +def merge_duplicate_vertices( + vertices: torch.Tensor, + faces: torch.Tensor, + tol: float = 1e-6 +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Merge duplicate vertices of a triangular mesh. + Duplicate vertices are merged by selecte one of them, and the face indices are updated accordingly. + + Args: + vertices (torch.Tensor): [N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + tol (float, optional): tolerance for merging. Defaults to 1e-6. + + Returns: + vertices (torch.Tensor): [N_, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + """ + vertices_round = torch.round(vertices / tol) + uni, uni_inv = torch.unique(vertices_round, dim=0, return_inverse=True) + uni[uni_inv] = vertices + faces = uni_inv[faces] + return uni, faces + + +def remove_isolated_pieces( + vertices: torch.Tensor, + faces: torch.Tensor, + connected_components: List[torch.Tensor] = None, + thresh_num_faces: int = None, + thresh_radius: float = None, + thresh_boundary_ratio: float = None, + remove_unreferenced: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Remove isolated pieces of a mesh. + Isolated pieces are removed, and the face indices are updated accordingly. + If no face is left, will return the largest connected component. + + Args: + vertices (torch.Tensor): [N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + connected_components (List[torch.Tensor], optional): connected components of the mesh. If None, it will be computed. Defaults to None. + thresh_num_faces (int, optional): threshold of number of faces for isolated pieces. Defaults to None. + thresh_radius (float, optional): threshold of radius for isolated pieces. Defaults to None. + remove_unreferenced (bool, optional): remove unreferenced vertices after removing isolated pieces. Defaults to True. 
+ + Returns: + vertices (torch.Tensor): [N_, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + """ + if connected_components is None: + connected_components = compute_connected_components(faces) + connected_components = sorted(connected_components, key=lambda x: len(x), reverse=True) + if thresh_num_faces is not None: + removed = [] + for i in range(1, len(connected_components)): + if len(connected_components[i]) < thresh_num_faces: + removed.append(i) + for i in removed[::-1]: + connected_components.pop(i) + if thresh_radius is not None: + removed = [] + for i in range(1, len(connected_components)): + comp_vertices = vertices[faces[connected_components[i]].flatten().unique()] + comp_center = comp_vertices.mean(dim=0) + comp_radius = (comp_vertices - comp_center).norm(p=2, dim=-1).max() + if comp_radius < thresh_radius: + removed.append(i) + for i in removed[::-1]: + connected_components.pop(i) + if thresh_boundary_ratio is not None: + removed = [] + for i in range(1, len(connected_components)): + edges = torch.cat([faces[connected_components[i]][:, [0, 1]], faces[connected_components[i]][:, [1, 2]], faces[connected_components[i]][:, [2, 0]]], dim=0) + edges = torch.sort(edges, dim=1).values + edges, counts = torch.unique(edges, return_counts=True, dim=0) + num_boundary_edges = (counts == 1).sum().item() + num_faces = len(connected_components[i]) + if num_boundary_edges / num_faces > thresh_boundary_ratio: + removed.append(i) + for i in removed[::-1]: + connected_components.pop(i) + + # post-process + faces = torch.cat([faces[connected_components[i]] for i in range(len(connected_components))], dim=0) + if remove_unreferenced: + faces, vertices = remove_unreferenced_vertices(faces, vertices) + return vertices, faces + + +def subdivide_mesh_simple(vertices: torch.Tensor, faces: torch.Tensor, n: int = 1) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Subdivide a triangular mesh by splitting each triangle into 4 smaller triangles. + NOTE: All original vertices are kept, and new vertices are appended to the end of the vertex list. + + Args: + vertices (torch.Tensor): [N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + n (int, optional): number of subdivisions. Defaults to 1. 
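# Usage sketch (illustrative only): a 98-face triangle strip plus one stray triangle.
# With `thresh_num_faces=2` the single-face component is discarded, and because
# `remove_unreferenced` defaults to True its three vertices are dropped as well.
import torch

verts = torch.rand(103, 3)
strip = torch.stack([torch.arange(0, 98), torch.arange(1, 99), torch.arange(2, 100)], dim=1)
stray = torch.tensor([[100, 101, 102]])
verts_out, faces_out = remove_isolated_pieces(verts, torch.cat([strip, stray]), thresh_num_faces=2)
# faces_out.shape[0] == 98, verts_out.shape[0] == 100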
+ + Returns: + vertices (torch.Tensor): [N_, 3] subdivided 3-dimensional vertices + faces (torch.Tensor): [4 * T, 3] subdivided triangular face indices + """ + for _ in range(n): + edges = torch.stack([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], dim=0) + edges = torch.sort(edges, dim=2) + uni_edges, uni_inv = torch.unique(edges, return_inverse=True, dim=0) + midpoints = (vertices[uni_edges[:, 0]] + vertices[uni_edges[:, 1]]) / 2 + + n_vertices = vertices.shape[0] + vertices = torch.cat([vertices, midpoints], dim=0) + faces = torch.cat([ + torch.stack([faces[:, 0], n_vertices + uni_inv[0], n_vertices + uni_inv[2]], axis=1), + torch.stack([faces[:, 1], n_vertices + uni_inv[1], n_vertices + uni_inv[0]], axis=1), + torch.stack([faces[:, 2], n_vertices + uni_inv[2], n_vertices + uni_inv[1]], axis=1), + torch.stack([n_vertices + uni_inv[0], n_vertices + uni_inv[1], n_vertices + uni_inv[2]], axis=1), + ], dim=0) + return vertices, faces + + +def compute_face_tbn(pos: torch.Tensor, faces_pos: torch.Tensor, uv: torch.Tensor, faces_uv: torch.Tensor, eps: float = 1e-7) -> torch.Tensor: + """compute TBN matrix for each face + + Args: + pos (torch.Tensor): shape (..., N_pos, 3), positions + faces_pos (torch.Tensor): shape(T, 3) + uv (torch.Tensor): shape (..., N_uv, 3) uv coordinates, + faces_uv (torch.Tensor): shape(T, 3) + + Returns: + torch.Tensor: (..., T, 3, 3) TBN matrix for each face. Note TBN vectors are normalized but not necessarily orthognal + """ + e01 = torch.index_select(pos, dim=-2, index=faces_pos[:, 1]) - torch.index_select(pos, dim=-2, index=faces_pos[:, 0]) + e02 = torch.index_select(pos, dim=-2, index=faces_pos[:, 2]) - torch.index_select(pos, dim=-2, index=faces_pos[:, 0]) + uv01 = torch.index_select(uv, dim=-2, index=faces_uv[:, 1]) - torch.index_select(uv, dim=-2, index=faces_uv[:, 0]) + uv02 = torch.index_select(uv, dim=-2, index=faces_uv[:, 2]) - torch.index_select(uv, dim=-2, index=faces_uv[:, 0]) + normal = torch.cross(e01, e02) + tangent_bitangent = torch.stack([e01, e02], dim=-1) @ torch.inverse(torch.stack([uv01, uv02], dim=-1)) + tbn = torch.cat([tangent_bitangent, normal.unsqueeze(-1)], dim=-1) + tbn = tbn / (torch.norm(tbn, p=2, dim=-2, keepdim=True) + eps) + return tbn + + +def compute_vertex_tbn(faces_topo: torch.Tensor, pos: torch.Tensor, faces_pos: torch.Tensor, uv: torch.Tensor, faces_uv: torch.Tensor) -> torch.Tensor: + """compute TBN matrix for each face + + Args: + faces_topo (torch.Tensor): (T, 3), face indice of topology + pos (torch.Tensor): shape (..., N_pos, 3), positions + faces_pos (torch.Tensor): shape(T, 3) + uv (torch.Tensor): shape (..., N_uv, 3) uv coordinates, + faces_uv (torch.Tensor): shape(T, 3) + + Returns: + torch.Tensor: (..., V, 3, 3) TBN matrix for each face. 
Note TBN vectors are normalized but not necessarily orthognal + """ + n_vertices = faces_topo.max().item() + 1 + n_tri = faces_topo.shape[-2] + batch_shape = pos.shape[:-2] + face_tbn = compute_face_tbn(pos, faces_pos, uv, faces_uv) # (..., T, 3, 3) + face_tbn = face_tbn[..., :, None, :, :].repeat(*[1] * len(batch_shape), 1, 3, 1, 1).view(*batch_shape, n_tri * 3, 3, 3) # (..., T * 3, 3, 3) + vertex_tbn = torch.index_add(torch.zeros(*batch_shape, n_vertices, 3, 3).to(face_tbn), dim=-3, index=faces_topo.view(-1), source=face_tbn) + vertex_tbn = vertex_tbn / (torch.norm(vertex_tbn, p=2, dim=-2, keepdim=True) + 1e-7) + return vertex_tbn + + +def laplacian(vertices: torch.Tensor, faces: torch.Tensor, weight: str = 'uniform') -> torch.Tensor: + """Laplacian smooth with cotangent weights + + Args: + vertices (torch.Tensor): shape (..., N, 3) + faces (torch.Tensor): shape (T, 3) + weight (str): 'uniform' or 'cotangent' + """ + sum_verts = torch.zeros_like(vertices) # (..., N, 3) + sum_weights = torch.zeros(*vertices.shape[:-1]).to(vertices) # (..., N) + face_verts = torch.index_select(vertices, -2, faces.view(-1)).view(*vertices.shape[:-2], *faces.shape, vertices.shape[-1]) # (..., T, 3) + if weight == 'cotangent': + for i in range(3): + e1 = face_verts[..., (i + 1) % 3, :] - face_verts[..., i, :] + e2 = face_verts[..., (i + 2) % 3, :] - face_verts[..., i, :] + cot_angle = (e1 * e2).sum(dim=-1) / torch.cross(e1, e2, dim=-1).norm(p=2, dim=-1) # (..., T, 3) + sum_verts = torch.index_add(sum_verts, -2, faces[:, (i + 1) % 3], face_verts[..., (i + 2) % 3, :] * cot_angle[..., None]) + sum_weights = torch.index_add(sum_weights, -1, faces[:, (i + 1) % 3], cot_angle) + sum_verts = torch.index_add(sum_verts, -2, faces[:, (i + 2) % 3], face_verts[..., (i + 1) % 3, :] * cot_angle[..., None]) + sum_weights = torch.index_add(sum_weights, -1, faces[:, (i + 2) % 3], cot_angle) + elif weight == 'uniform': + for i in range(3): + sum_verts = torch.index_add(sum_verts, -2, faces[:, i], face_verts[..., (i + 1) % 3, :]) + sum_weights = torch.index_add(sum_weights, -1, faces[:, i], torch.ones_like(face_verts[..., i, 0])) + else: + raise NotImplementedError + return sum_verts / (sum_weights[..., None] + 1e-7) + + +def laplacian_smooth_mesh(vertices: torch.Tensor, faces: torch.Tensor, weight: str = 'uniform', times: int = 5) -> torch.Tensor: + """Laplacian smooth with cotangent weights + + Args: + vertices (torch.Tensor): shape (..., N, 3) + faces (torch.Tensor): shape (T, 3) + weight (str): 'uniform' or 'cotangent' + """ + for _ in range(times): + vertices = laplacian(vertices, faces, weight) + return vertices + + +def taubin_smooth_mesh(vertices: torch.Tensor, faces: torch.Tensor, lambda_: float = 0.5, mu_: float = -0.51) -> torch.Tensor: + """Taubin smooth mesh + + Args: + vertices (torch.Tensor): _description_ + faces (torch.Tensor): _description_ + lambda_ (float, optional): _description_. Defaults to 0.5. + mu_ (float, optional): _description_. Defaults to -0.51. + + Returns: + torch.Tensor: _description_ + """ + pt = vertices + lambda_ * laplacian_smooth_mesh(vertices, faces) + p = pt + mu_ * laplacian_smooth_mesh(pt, faces) + return p + + +def laplacian_hc_smooth_mesh(vertices: torch.Tensor, faces: torch.Tensor, times: int = 5, alpha: float = 0.5, beta: float = 0.5, weight: str = 'uniform'): + """HC algorithm from Improved Laplacian Smoothing of Noisy Surface Meshes by J.Vollmer et al. 
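# Usage sketch (illustrative only): smoothing a slightly noisy tetrahedron. Plain
# Laplacian smoothing tends to shrink the mesh; Taubin's negative `mu_` step pushes
# back after each `lambda_` step, keeping the overall shape closer to the original.
import torch

verts = torch.tensor([
    [0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
]) + 0.01 * torch.randn(4, 3)
faces = torch.tensor([[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]])
smoothed = laplacian_smooth_mesh(verts, faces, weight='uniform', times=3)
relaxed = taubin_smooth_mesh(verts, faces, lambda_=0.5, mu_=-0.51)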
+ """ + p = vertices + for i in range(times): + q = p + p = laplacian_smooth_mesh(vertices, faces, weight) + b = p - (alpha * vertices + (1 - alpha) * q) + p = p - (beta * b + (1 - beta) * laplacian_smooth_mesh(b, faces, weight)) * 0.8 + return p diff --git a/src/utils3d/torch/nerf.py b/src/utils3d/torch/nerf.py new file mode 100644 index 0000000000000000000000000000000000000000..7d20bc747255dbb1a68191f93a395a824d76e108 --- /dev/null +++ b/src/utils3d/torch/nerf.py @@ -0,0 +1,749 @@ +from typing import * +from numbers import Number +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from .utils import image_uv + + +__all__ = [ + 'get_rays', + 'get_image_rays', + 'get_mipnerf_cones', + 'volume_rendering', + 'bin_sample', + 'importance_sample', + 'nerf_render_rays', + 'mipnerf_render_rays', + 'nerf_render_view', + 'mipnerf_render_view', + 'InstantNGP', +] + + +def get_rays(extrinsics: Tensor, intrinsics: Tensor, uv: Tensor) -> Tuple[Tensor, Tensor]: + """ + Args: + extrinsics: (..., 4, 4) extrinsics matrices. + intrinsics: (..., 3, 3) intrinsics matrices. + uv: (..., n_rays, 2) uv coordinates of the rays. + + Returns: + rays_o: (..., 1, 3) ray origins + rays_d: (..., n_rays, 3) ray directions. + NOTE: ray directions are NOT normalized. They actuallys makes rays_o + rays_d * z = world coordinates, where z is the depth. + """ + uvz = torch.cat([uv, torch.ones_like(uv[..., :1])], dim=-1).to(extrinsics) # (n_batch, n_views, n_rays, 3) + + with torch.cuda.amp.autocast(enabled=False): + inv_transformation = (intrinsics @ extrinsics[..., :3, :3]).inverse() + inv_extrinsics = extrinsics.inverse() + rays_d = uvz @ inv_transformation.transpose(-1, -2) + rays_o = inv_extrinsics[..., None, :3, 3] # (n_batch, n_views, 1, 3) + return rays_o, rays_d + + +def get_image_rays(extrinsics: Tensor, intrinsics: Tensor, width: int, height: int) -> Tuple[Tensor, Tensor]: + """ + Args: + extrinsics: (..., 4, 4) extrinsics matrices. + intrinsics: (..., 3, 3) intrinsics matrices. + width: width of the image. + height: height of the image. + + Returns: + rays_o: (..., 1, 1, 3) ray origins + rays_d: (..., height, width, 3) ray directions. + NOTE: ray directions are NOT normalized. They actuallys makes rays_o + rays_d * z = world coordinates, where z is the depth. + """ + uv = image_uv(height, width).to(extrinsics).flatten(0, 1) + rays_o, rays_d = get_rays(extrinsics, intrinsics, uv) + rays_o = rays_o.unflatten(-2, (1, 1)) + rays_d = rays_d.unflatten(-2, (height, width)) + return rays_o, rays_d + + +def get_mipnerf_cones(rays_o: Tensor, rays_d: Tensor, z_vals: Tensor, pixel_width: Tensor) -> Tuple[Tensor, Tensor]: + """ + Args: + rays_o: (..., n_rays, 3) ray origins + rays_d: (..., n_rays, 3) ray directions. + z_vals: (..., n_rays, n_samples) z values. + pixel_width: (...) pixel width. = 1 / (normalized focal length * width) + + Returns: + mu: (..., n_rays, n_samples, 3) cone mu. + sigma: (..., n_rays, n_samples, 3, 3) cone sigma. 
+ """ + t_mu = (z_vals[..., 1:] + z_vals[..., :-1]).mul_(0.5) + t_delta = (z_vals[..., 1:] - z_vals[..., :-1]).mul_(0.5) + t_mu_square = t_mu.square() + t_delta_square = t_delta.square() + t_delta_quad = t_delta_square.square() + mu_t = t_mu + 2.0 * t_mu * t_delta_square / (3.0 * t_mu_square + t_delta_square) + sigma_t = t_delta_square / 3.0 - (4.0 / 15.0) * t_delta_quad / (3.0 * t_mu_square + t_delta_square).square() * (12.0 * t_mu_square - t_delta_square) + sigma_r = (pixel_width[..., None, None].square() / 3.0) * (t_mu_square / 4.0 + (5.0 / 12.0) * t_delta_square - (4.0 / 15.0) * t_delta_quad / (3.0 * t_mu_square + t_delta_square)) + points_mu = rays_o[:, :, :, None, :] + rays_d[:, :, :, None, :] * mu_t[..., None] + d_dt = rays_d[..., :, None] * rays_d[..., None, :] # (..., n_rays, 3, 3) + points_sigma = sigma_t[..., None, None] * d_dt[..., None, :, :] + sigma_r[..., None, None] * (torch.eye(3).to(rays_o) - d_dt[..., None, :, :]) + return points_mu, points_sigma + + +def get_pixel_width(intrinsics: Tensor, width: int, height: int) -> Tensor: + """ + Args: + intrinsics: (..., 3, 3) intrinsics matrices. + width: width of the image. + height: height of the image. + + Returns: + pixel_width: (...) pixel width. = 1 / (normalized focal length * width) + """ + assert width == height, "Currently, only square images are supported." + pixel_width = torch.reciprocal((intrinsics[..., 0, 0] * intrinsics[..., 1, 1]).sqrt() * width) + return pixel_width + + +def volume_rendering(color: Tensor, sigma: Tensor, z_vals: Tensor, ray_length: Tensor, rgb: bool = True, depth: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + """ + Given color, sigma and z_vals (linear depth of the sampling points), render the volume. + + NOTE: By default, color and sigma should have one less sample than z_vals, in correspondence with the average value in intervals. + If queried color are aligned with z_vals, we use trapezoidal rule to calculate the average values in intervals. + + Args: + color: (..., n_samples or n_samples - 1, 3) color values. + sigma: (..., n_samples or n_samples - 1) density values. + z_vals: (..., n_samples) z values. + ray_length: (...) length of the ray + + Returns: + rgb: (..., 3) rendered color values. + depth: (...) rendered depth values. + weights (..., n_samples) weights. + """ + dists = (z_vals[..., 1:] - z_vals[..., :-1]) * ray_length[..., None] + if color.shape[-2] == z_vals.shape[-1]: + color = (color[..., 1:, :] + color[..., :-1, :]).mul_(0.5) + sigma = (sigma[..., 1:] + sigma[..., :-1]).mul_(0.5) + sigma_delta = sigma * dists + transparancy = (-torch.cat([torch.zeros_like(sigma_delta[..., :1]), sigma_delta[..., :-1]], dim=-1).cumsum(dim=-1)).exp_() # First cumsum then exp for numerical stability + alpha = 1.0 - (-sigma_delta).exp_() + weights = alpha * transparancy + if rgb: + rgb = torch.sum(weights[..., None] * color, dim=-2) if rgb else None + if depth: + z_vals = (z_vals[..., 1:] + z_vals[..., :-1]).mul_(0.5) + depth = torch.sum(weights * z_vals, dim=-1) / weights.sum(dim=-1).clamp_min_(1e-8) if depth else None + return rgb, depth, weights + + +def neus_volume_rendering(color: Tensor, sdf: Tensor, s: torch.Tensor, z_vals: Tensor = None, rgb: bool = True, depth: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + """ + Given color, sdf values and z_vals (linear depth of the sampling points), do volume rendering. (NeuS) + + Args: + color: (..., n_samples or n_samples - 1, 3) color values. + sdf: (..., n_samples) sdf values. + s: (..., n_samples) S values of S-density function in NeuS. 
The standard deviation of such S-density distribution is 1 / s. + z_vals: (..., n_samples) z values. + ray_length: (...) length of the ray + + Returns: + rgb: (..., 3) rendered color values. + depth: (...) rendered depth values. + weights (..., n_samples) weights. + """ + + if color.shape[-2] == z_vals.shape[-1]: + color = (color[..., 1:, :] + color[..., :-1, :]).mul_(0.5) + + sigmoid_sdf = torch.sigmoid(s * sdf) + alpha = F.relu(1 - sigmoid_sdf[..., :-1] / sigmoid_sdf[..., :-1]) + transparancy = torch.cumprod(torch.cat([torch.ones_like(alpha[..., :1]), alpha], dim=-1), dim=-1) + weights = alpha * transparancy + + if rgb: + rgb = torch.sum(weights[..., None] * color, dim=-2) if rgb else None + if depth: + z_vals = (z_vals[..., 1:] + z_vals[..., :-1]).mul_(0.5) + depth = torch.sum(weights * z_vals, dim=-1) / weights.sum(dim=-1).clamp_min_(1e-8) if depth else None + return rgb, depth, weights + + +def bin_sample(size: Union[torch.Size, Tuple[int, ...]], n_samples: int, min_value: Number, max_value: Number, spacing: Literal['linear', 'inverse_linear'], dtype: torch.dtype = None, device: torch.device = None) -> Tensor: + """ + Uniformly (or uniformly in inverse space) sample z values in `n_samples` bins in range [min_value, max_value]. + Args: + size: size of the rays + n_samples: number of samples to be sampled, also the number of bins + min_value: minimum value of the range + max_value: maximum value of the range + space: 'linear' or 'inverse_linear'. If 'inverse_linear', the sampling is uniform in inverse space. + + Returns: + z_rand: (*size, n_samples) sampled z values, sorted in ascending order. + """ + if spacing == 'linear': + pass + elif spacing == 'inverse_linear': + min_value = 1.0 / min_value + max_value = 1.0 / max_value + bin_length = (max_value - min_value) / n_samples + z_rand = (torch.rand(*size, n_samples, device=device, dtype=dtype) - 0.5) * bin_length + torch.linspace(min_value + bin_length * 0.5, max_value - bin_length * 0.5, n_samples, device=device, dtype=dtype) + if spacing == 'inverse_linear': + z_rand = 1.0 / z_rand + return z_rand + + +def importance_sample(z_vals: Tensor, weights: Tensor, n_samples: int) -> Tuple[Tensor, Tensor]: + """ + Importance sample z values. + + NOTE: By default, weights should have one less sample than z_vals, in correspondence with the intervals. + If weights has the same number of samples as z_vals, we use trapezoidal rule to calculate the average weights in intervals. + + Args: + z_vals: (..., n_rays, n_input_samples) z values, sorted in ascending order. + weights: (..., n_rays, n_input_samples or n_input_samples - 1) weights. + n_samples: number of output samples for importance sampling. + + Returns: + z_importance: (..., n_rays, n_samples) importance sampled z values, unsorted. 
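# Usage sketch (illustrative only): stratified sampling of 8 depths per ray for a
# batch of 4 x 128 rays. Each sample is jittered inside its own bin, so the output
# is already sorted; 'inverse_linear' spaces the bins uniformly in inverse depth.
import torch

z = bin_sample((4, 128), n_samples=8, min_value=0.1, max_value=10.0, spacing='linear')
# z.shape == (4, 128, 8)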
+ """ + if weights.shape[-1] == z_vals.shape[-1]: + weights = (weights[..., 1:] + weights[..., :-1]).mul_(0.5) + weights = weights / torch.sum(weights, dim=-1, keepdim=True) # (..., n_rays, n_input_samples - 1) + bins_a, bins_b = z_vals[..., :-1], z_vals[..., 1:] + + pdf = weights / torch.sum(weights, dim=-1, keepdim=True) # (..., n_rays, n_input_samples - 1) + cdf = torch.cumsum(pdf, dim=-1) + u = torch.rand(*z_vals.shape[:-1], n_samples, device=z_vals.device, dtype=z_vals.dtype) + + inds = torch.searchsorted(cdf, u, right=True).clamp(0, cdf.shape[-1] - 1) # (..., n_rays, n_samples) + + bins_a = torch.gather(bins_a, dim=-1, index=inds) + bins_b = torch.gather(bins_b, dim=-1, index=inds) + z_importance = bins_a + (bins_b - bins_a) * torch.rand_like(u) + return z_importance + + +def nerf_render_rays( + nerf: Union[Callable[[Tensor, Tensor], Tuple[Tensor, Tensor]], Tuple[Callable[[Tensor], Tuple[Tensor, Tensor]], Callable[[Tensor], Tuple[Tensor, Tensor]]]], + rays_o: Tensor, rays_d: Tensor, + *, + return_dict: bool = False, + n_coarse: int = 64, n_fine: int = 64, + near: float = 0.1, far: float = 100.0, + z_spacing: Literal['linear', 'inverse_linear'] = 'linear', +): + """ + NeRF rendering of rays. Note that it supports arbitrary batch dimensions (denoted as `...`) + + Args: + nerf: nerf model, which takes (points, directions) as input and returns (color, density) as output. + If nerf is a tuple, it should be (nerf_coarse, nerf_fine), where nerf_coarse and nerf_fine are two nerf models for coarse and fine stages respectively. + + nerf args: + points: (..., n_rays, n_samples, 3) + directions: (..., n_rays, n_samples, 3) + nerf returns: + color: (..., n_rays, n_samples, 3) color values. + density: (..., n_rays, n_samples) density values. + + rays_o: (..., n_rays, 3) ray origins + rays_d: (..., n_rays, 3) ray directions. + pixel_width: (..., n_rays) pixel width. How to compute? pixel_width = 1 / (normalized focal length * width) + + Returns + if return_dict is False, return rendered rgb and depth for short cut. (If there are separate coarse and fine results, return fine results) + rgb: (..., n_rays, 3) rendered color values. + depth: (..., n_rays) rendered depth values. + else, return a dict. If `n_fine == 0` or `nerf` is a single model, the dict only contains coarse results: + ``` + {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} + ``` + If there are two models for coarse and fine stages, the dict contains both coarse and fine results: + ``` + { + "coarse": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..}, + "fine": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} + } + ``` + """ + if isinstance(nerf, tuple): + nerf_coarse, nerf_fine = nerf + else: + nerf_coarse = nerf_fine = nerf + # 1. 
Coarse: bin sampling + z_coarse = bin_sample(rays_d.shape[:-1], n_coarse, near, far, device=rays_o.device, dtype=rays_o.dtype, spacing=z_spacing) # (n_batch, n_views, n_rays, n_samples) + points_coarse = rays_o[..., None, :] + rays_d[..., None, :] * z_coarse[..., None] # (n_batch, n_views, n_rays, n_samples, 3) + ray_length = rays_d.norm(dim=-1) + + # Query color and density + color_coarse, density_coarse = nerf_coarse(points_coarse, rays_d[..., None, :].expand_as(points_coarse)) # (n_batch, n_views, n_rays, n_samples, 3), (n_batch, n_views, n_rays, n_samples) + + # Volume rendering + with torch.no_grad(): + rgb_coarse, depth_coarse, weights = volume_rendering(color_coarse, density_coarse, z_coarse, ray_length) # (n_batch, n_views, n_rays, 3), (n_batch, n_views, n_rays, 1), (n_batch, n_views, n_rays, n_samples) + + if n_fine == 0: + if return_dict: + return {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights, 'z_vals': z_coarse, 'color': color_coarse, 'density': density_coarse} + else: + return rgb_coarse, depth_coarse + + # 2. Fine: Importance sampling + if nerf_coarse is nerf_fine: + # If coarse and fine stages share the same model, the points of coarse stage can be reused, + # and we only need to query the importance samples of fine stage. + z_fine = importance_sample(z_coarse, weights, n_fine) + points_fine = rays_o[..., None, :] + rays_d[..., None, :] * z_fine[..., None] + color_fine, density_fine = nerf_fine(points_fine, rays_d[..., None, :].expand_as(points_fine)) + + # Merge & volume rendering + z_vals = torch.cat([z_coarse, z_fine], dim=-1) + color = torch.cat([color_coarse, color_fine], dim=-2) + density = torch.cat([density_coarse, density_fine], dim=-1) + z_vals, sort_inds = torch.sort(z_vals, dim=-1) + color = torch.gather(color, dim=-2, index=sort_inds[..., None].expand_as(color)) + density = torch.gather(density, dim=-1, index=sort_inds) + rgb, depth, weights = volume_rendering(color, density, z_vals, ray_length) + + if return_dict: + return {'rgb': rgb, 'depth': depth, 'weights': weights, 'z_vals': z_vals, 'color': color, 'density': density} + else: + return rgb, depth + else: + # If coarse and fine stages use different models, we need to query the importance samples of both stages. + z_fine = importance_sample(z_coarse, weights, n_fine) + z_vals = torch.cat([z_coarse, z_fine], dim=-1) + points = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., None] + color, density = nerf_fine(points) + rgb, depth, weights = volume_rendering(color, density, z_vals, ray_length) + + if return_dict: + return { + 'coarse': {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights, 'z_vals': z_coarse, 'color': color_coarse, 'density': density_coarse}, + 'fine': {'rgb': rgb, 'depth': depth, 'weights': weights, 'z_vals': z_vals, 'color': color, 'density': density} + } + else: + return rgb, depth + + +def mipnerf_render_rays( + mipnerf: Callable[[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]], + rays_o: Tensor, rays_d: Tensor, pixel_width: Tensor, + *, + return_dict: bool = False, + n_coarse: int = 64, n_fine: int = 64, uniform_ratio: float = 0.4, + near: float = 0.1, far: float = 100.0, + z_spacing: Literal['linear', 'inverse_linear'] = 'linear', +) -> Union[Tuple[Tensor, Tensor], Dict[str, Tensor]]: + """ + MipNeRF rendering. + + Args: + mipnerf: mipnerf model, which takes (points_mu, points_sigma) as input and returns (color, density) as output. + + mipnerf args: + points_mu: (..., n_rays, n_samples, 3) cone mu. + points_sigma: (..., n_rays, n_samples, 3, 3) cone sigma. 
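# Usage sketch (illustrative only): `nerf_render_rays` only needs a callable mapping
# (points, directions) -> (color, density). The constant-density toy field below is
# enough to exercise the shared coarse/fine path end to end.
import torch

def toy_nerf(points, directions):
    color = torch.sigmoid(points)                 # (..., n_samples, 3)
    density = torch.ones_like(points[..., 0])     # (..., n_samples)
    return color, density

rays_o = torch.zeros(1, 16, 3)
rays_d = torch.nn.functional.normalize(torch.rand(1, 16, 3), dim=-1)
rgb, depth = nerf_render_rays(toy_nerf, rays_o, rays_d, n_coarse=32, n_fine=32, near=0.5, far=5.0)
# rgb: (1, 16, 3), depth: (1, 16)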
+ directions: (..., n_rays, n_samples, 3) + mipnerf returns: + color: (..., n_rays, n_samples, 3) color values. + density: (..., n_rays, n_samples) density values. + + rays_o: (..., n_rays, 3) ray origins + rays_d: (..., n_rays, 3) ray directions. + pixel_width: (..., n_rays) pixel width. How to compute? pixel_width = 1 / (normalized focal length * width) + + Returns + if return_dict is False, return rendered results only: (If `n_fine == 0`, return coarse results, otherwise return fine results) + rgb: (..., n_rays, 3) rendered color values. + depth: (..., n_rays) rendered depth values. + else, return a dict. If `n_fine == 0`, the dict only contains coarse results: + ``` + {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} + ``` + If n_fine > 0, the dict contains both coarse and fine results : + ``` + { + "coarse": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..}, + "fine": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} + } + ``` + """ + # 1. Coarse: bin sampling + z_coarse = bin_sample(rays_d.shape[:-1], n_coarse, near, far, spacing=z_spacing, device=rays_o.device, dtype=rays_o.dtype) + points_mu_coarse, points_sigma_coarse = get_mipnerf_cones(rays_o, rays_d, z_coarse, pixel_width) + ray_length = rays_d.norm(dim=-1) + + # Query color and density + color_coarse, density_coarse = mipnerf(points_mu_coarse, points_sigma_coarse, rays_d[..., None, :].expand_as(points_mu_coarse)) # (n_batch, n_views, n_rays, n_samples, 3), (n_batch, n_views, n_rays, n_samples) + + # Volume rendering + rgb_coarse, depth_coarse, weights_coarse = volume_rendering(color_coarse, density_coarse, z_coarse, ray_length) # (n_batch, n_views, n_rays, 3), (n_batch, n_views, n_rays, 1), (n_batch, n_views, n_rays, n_samples) + + if n_fine == 0: + if return_dict: + return {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights_coarse, 'z_vals': z_coarse, 'color': color_coarse, 'density': density_coarse} + else: + return rgb_coarse, depth_coarse + + # 2. Fine: Importance sampling. (NOTE: coarse stages and fine stages always share the same model, but coarse stage points can not be reused) + with torch.no_grad(): + weights_coarse = (1.0 - uniform_ratio) * weights_coarse + uniform_ratio / weights_coarse.shape[-1] + z_fine = importance_sample(z_coarse, weights_coarse, n_fine) + z_fine, _ = torch.sort(z_fine, dim=-2) + points_mu_fine, points_sigma_fine = get_mipnerf_cones(rays_o, rays_d, z_fine, pixel_width) + color_fine, density_fine = mipnerf(points_mu_fine, points_sigma_fine, rays_d[..., None, :].expand_as(points_mu_fine)) + + # Volume rendering + rgb_fine, depth_fine, weights_fine = volume_rendering(color_fine, density_fine, z_fine, ray_length) + + if return_dict: + return { + 'coarse': {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights_coarse, 'z_vals': z_coarse, 'color': color_coarse, 'density': density_coarse}, + 'fine': {'rgb': rgb_fine, 'depth': depth_fine, 'weights': weights_fine, 'z_vals': z_fine, 'color': color_fine, 'density': density_fine} + } + else: + return rgb_fine, depth_fine + + +def neus_render_rays( + neus: Callable[[Tensor, Tensor], Tuple[Tensor, Tensor]], + s: Union[Number, Tensor], + rays_o: Tensor, rays_d: Tensor, + *, + compute_normal: bool = True, + return_dict: bool = False, + n_coarse: int = 64, n_fine: int = 64, + near: float = 0.1, far: float = 100.0, + z_spacing: Literal['linear', 'inverse_linear'] = 'linear', +): + """ + TODO + NeuS rendering of rays. 
Note that it supports arbitrary batch dimensions (denoted as `...`) + + Args: + neus: neus model, which takes (points, directions) as input and returns (color, density) as output. + + nerf args: + points: (..., n_rays, n_samples, 3) + directions: (..., n_rays, n_samples, 3) + nerf returns: + color: (..., n_rays, n_samples, 3) color values. + density: (..., n_rays, n_samples) density values. + + rays_o: (..., n_rays, 3) ray origins + rays_d: (..., n_rays, 3) ray directions. + pixel_width: (..., n_rays) pixel width. How to compute? pixel_width = 1 / (normalized focal length * width) + + Returns + if return_dict is False, return rendered results only: (If `n_fine == 0`, return coarse results, otherwise return fine results) + rgb: (..., n_rays, 3) rendered color values. + depth: (..., n_rays) rendered depth values. + else, return a dict. If `n_fine == 0`, the dict only contains coarse results: + ``` + {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'sdf': ..., 'normal': ...} + ``` + If n_fine > 0, the dict contains both coarse and fine results: + ``` + { + "coarse": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..}, + "fine": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} + } + ``` + """ + + # 1. Coarse: bin sampling + z_coarse = bin_sample(rays_d.shape[:-1], n_coarse, near, far, device=rays_o.device, dtype=rays_o.dtype, spacing=z_spacing) # (n_batch, n_views, n_rays, n_samples) + points_coarse = rays_o[..., None, :] + rays_d[..., None, :] * z_coarse[..., None] # (n_batch, n_views, n_rays, n_samples, 3) + + # Query color and density + color_coarse, sdf_coarse = neus(points_coarse, rays_d[..., None, :].expand_as(points_coarse)) # (n_batch, n_views, n_rays, n_samples, 3), (n_batch, n_views, n_rays, n_samples) + + # Volume rendering + with torch.no_grad(): + rgb_coarse, depth_coarse, weights = neus_volume_rendering(color_coarse, sdf_coarse, s, z_coarse) # (n_batch, n_views, n_rays, 3), (n_batch, n_views, n_rays, 1), (n_batch, n_views, n_rays, n_samples) + + if n_fine == 0: + if return_dict: + return {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights, 'z_vals': z_coarse, 'color': color_coarse, 'sdf': sdf_coarse} + else: + return rgb_coarse, depth_coarse + + # If coarse and fine stages share the same model, the points of coarse stage can be reused, + # and we only need to query the importance samples of fine stage. 
+ z_fine = importance_sample(z_coarse, weights, n_fine) + points_fine = rays_o[..., None, :] + rays_d[..., None, :] * z_fine[..., None] + color_fine, sdf_fine = neus(points_fine, rays_d[..., None, :].expand_as(points_fine)) + + # Merge & volume rendering + z_vals = torch.cat([z_coarse, z_fine], dim=-1) + color = torch.cat([color_coarse, color_fine], dim=-2) + sdf = torch.cat([sdf_coarse, sdf_fine], dim=-1) + z_vals, sort_inds = torch.sort(z_vals, dim=-1) + color = torch.gather(color, dim=-2, index=sort_inds[..., None].expand_as(color)) + sdf = torch.gather(sdf, dim=-1, index=sort_inds) + rgb, depth, weights = neus_volume_rendering(color, sdf, s, z_vals) + + if return_dict: + return { + 'coarse': {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights, 'z_vals': z_coarse, 'color': color_coarse, 'sdf': sdf_coarse}, + 'fine': {'rgb': rgb, 'depth': depth, 'weights': weights, 'z_vals': z_vals, 'color': color, 'sdf': sdf} + } + else: + return rgb, depth + + +def nerf_render_view( + nerf: Tensor, + extrinsics: Tensor, + intrinsics: Tensor, + width: int, + height: int, + *, + patchify: bool = False, + patch_size: Tuple[int, int] = (64, 64), + **options: Dict[str, Any] +) -> Tuple[Tensor, Tensor]: + """ + NeRF rendering of views. Note that it supports arbitrary batch dimensions (denoted as `...`) + + Args: + extrinsics: (..., 4, 4) extrinsics matrice of the rendered views + intrinsics (optional): (..., 3, 3) intrinsics matrice of the rendered views. + width (optional): image width of the rendered views. + height (optional): image height of the rendered views. + patchify (optional): If the image is too large, render it patch by patch + **options: rendering options. + + Returns: + rgb: (..., channels, height, width) rendered color values. + depth: (..., height, width) rendered depth values. 
+ """ + if patchify: + # Patchified rendering + max_patch_width, max_patch_height = patch_size + n_rows, n_columns = math.ceil(height / max_patch_height), math.ceil(width / max_patch_width) + + rgb_rows, depth_rows = [], [] + for i_row in range(n_rows): + rgb_row, depth_row = [], [] + for i_column in range(n_columns): + patch_shape = patch_height, patch_width = min(max_patch_height, height - i_row * max_patch_height), min(max_patch_width, width - i_column * max_patch_width) + uv = image_uv(height, width, i_column * max_patch_width, i_row * max_patch_height, i_column * max_patch_width + patch_width, i_row * max_patch_height + patch_height).to(extrinsics) + uv = uv.flatten(0, 1) # (patch_height * patch_width, 2) + ray_o_, ray_d_ = get_rays(extrinsics, intrinsics, uv) + rgb_, depth_ = nerf_render_rays(nerf, ray_o_, ray_d_, **options, return_dict=False) + rgb_ = rgb_.transpose(-1, -2).unflatten(-1, patch_shape) # (..., 3, patch_height, patch_width) + depth_ = depth_.unflatten(-1, patch_shape) # (..., patch_height, patch_width) + + rgb_row.append(rgb_) + depth_row.append(depth_) + rgb_rows.append(torch.cat(rgb_row, dim=-1)) + depth_rows.append(torch.cat(depth_row, dim=-1)) + rgb = torch.cat(rgb_rows, dim=-2) + depth = torch.cat(depth_rows, dim=-2) + + return rgb, depth + else: + # Full rendering + uv = image_uv(height, width).to(extrinsics) + uv = uv.flatten(0, 1) # (height * width, 2) + ray_o_, ray_d_ = get_rays(extrinsics, intrinsics, uv) + rgb, depth = nerf_render_rays(nerf, ray_o_, ray_d_, **options, return_dict=False) + rgb = rgb.transpose(-1, -2).unflatten(-1, (height, width)) # (..., 3, height, width) + depth = depth.unflatten(-1, (height, width)) # (..., height, width) + + return rgb, depth + + +def mipnerf_render_view( + mipnerf: Tensor, + extrinsics: Tensor, + intrinsics: Tensor, + width: int, + height: int, + *, + patchify: bool = False, + patch_size: Tuple[int, int] = (64, 64), + **options: Dict[str, Any] +) -> Tuple[Tensor, Tensor]: + """ + MipNeRF rendering of views. Note that it supports arbitrary batch dimensions (denoted as `...`) + + Args: + extrinsics: (..., 4, 4) extrinsics matrice of the rendered views + intrinsics (optional): (..., 3, 3) intrinsics matrice of the rendered views. + width (optional): image width of the rendered views. + height (optional): image height of the rendered views. + patchify (optional): If the image is too large, render it patch by patch + **options: rendering options. + + Returns: + rgb: (..., 3, height, width) rendered color values. + depth: (..., height, width) rendered depth values. 
+ """ + pixel_width = get_pixel_width(intrinsics, width, height) + + if patchify: + # Patchified rendering + max_patch_width, max_patch_height = patch_size + n_rows, n_columns = math.ceil(height / max_patch_height), math.ceil(width / max_patch_width) + + rgb_rows, depth_rows = [], [] + for i_row in range(n_rows): + rgb_row, depth_row = [], [] + for i_column in range(n_columns): + patch_shape = patch_height, patch_width = min(max_patch_height, height - i_row * max_patch_height), min(max_patch_width, width - i_column * max_patch_width) + uv = image_uv(height, width, i_column * max_patch_width, i_row * max_patch_height, i_column * max_patch_width + patch_width, i_row * max_patch_height + patch_height).to(extrinsics) + uv = uv.flatten(0, 1) # (patch_height * patch_width, 2) + ray_o_, ray_d_ = get_rays(extrinsics, intrinsics, uv) + rgb_, depth_ = mipnerf_render_rays(mipnerf, ray_o_, ray_d_, pixel_width, **options) + rgb_ = rgb_.transpose(-1, -2).unflatten(-1, patch_shape) # (..., 3, patch_height, patch_width) + depth_ = depth_.unflatten(-1, patch_shape) # (..., patch_height, patch_width) + + rgb_row.append(rgb_) + depth_row.append(depth_) + rgb_rows.append(torch.cat(rgb_row, dim=-1)) + depth_rows.append(torch.cat(depth_row, dim=-1)) + rgb = torch.cat(rgb_rows, dim=-2) + depth = torch.cat(depth_rows, dim=-2) + + return rgb, depth + else: + # Full rendering + uv = image_uv(height, width).to(extrinsics) + uv = uv.flatten(0, 1) # (height * width, 2) + ray_o_, ray_d_ = get_rays(extrinsics, intrinsics, uv) + rgb, depth = mipnerf_render_rays(mipnerf, ray_o_, ray_d_, pixel_width, **options) + rgb = rgb.transpose(-1, -2).unflatten(-1, (height, width)) # (..., 3, height, width) + depth = depth.unflatten(-1, (height, width)) # (..., height, width) + + return rgb, depth + + +class InstantNGP(nn.Module): + """ + An implementation of InstantNGP, Müller et. al., https://nvlabs.github.io/instant-ngp/. + Requires `tinycudann` package. 
+ Install it by: + ``` + pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch + ``` + """ + def __init__(self, + view_dependent: bool = True, + base_resolution: int = 16, + finest_resolution: int = 2048, + n_levels: int = 16, + num_layers_density: int = 2, + hidden_dim_density: int = 64, + num_layers_color: int = 3, + hidden_dim_color: int = 64, + log2_hashmap_size: int = 19, + bound: float = 1.0, + color_channels: int = 3, + ): + super().__init__() + import tinycudann + N_FEATURES_PER_LEVEL = 2 + GEO_FEAT_DIM = 15 + + self.bound = bound + self.color_channels = color_channels + + # density network + self.num_layers_density = num_layers_density + self.hidden_dim_density = hidden_dim_density + + per_level_scale = (finest_resolution / base_resolution) ** (1 / (n_levels - 1)) + + self.encoder = tinycudann.Encoding( + n_input_dims=3, + encoding_config={ + "otype": "HashGrid", + "n_levels": n_levels, + "n_features_per_level": N_FEATURES_PER_LEVEL, + "log2_hashmap_size": log2_hashmap_size, + "base_resolution": base_resolution, + "per_level_scale": per_level_scale, + }, + ) + + self.density_net = tinycudann.Network( + n_input_dims=N_FEATURES_PER_LEVEL * n_levels, + n_output_dims=1 + GEO_FEAT_DIM, + network_config={ + "otype": "FullyFusedMLP", + "activation": "ReLU", + "output_activation": "None", + "n_neurons": hidden_dim_density, + "n_hidden_layers": num_layers_density - 1, + }, + ) + + # color network + self.num_layers_color = num_layers_color + self.hidden_dim_color = hidden_dim_color + + self.view_dependent = view_dependent + if view_dependent: + self.encoder_dir = tinycudann.Encoding( + n_input_dims=3, + encoding_config={ + "otype": "SphericalHarmonics", + "degree": 4, + }, + ) + self.in_dim_color = self.encoder_dir.n_output_dims + GEO_FEAT_DIM + else: + self.in_dim_color = GEO_FEAT_DIM + + self.color_net = tinycudann.Network( + n_input_dims=self.in_dim_color, + n_output_dims=color_channels, + network_config={ + "otype": "FullyFusedMLP", + "activation": "ReLU", + "output_activation": "None", + "n_neurons": hidden_dim_color, + "n_hidden_layers": num_layers_color - 1, + }, + ) + + def forward(self, x: torch.Tensor, d: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: (..., 3) points + d: (..., 3) directions + Returns: + color: (..., 3) color values. + density: (..., 1) density values. + """ + batch_shape = x.shape[:-1] + x, d = x.reshape(-1, 3), d.reshape(-1, 3) + + # density + x = (x + self.bound) / (2 * self.bound) # to [0, 1] + x = self.encoder(x) + density, geo_feat = self.density_net(x).split([1, 15], dim=-1) + density = F.softplus(density).squeeze(-1) + + # color + if self.view_dependent: + d = (F.normalize(d, dim=-1) + 1) / 2 # tcnn SH encoding requires inputs to be in [0, 1] + d = self.encoder_dir(d) + h = torch.cat([d, geo_feat], dim=-1) + else: + h = geo_feat + color = self.color_net(h) + + return color.reshape(*batch_shape, self.color_channels), density.reshape(*batch_shape) + diff --git a/src/utils3d/torch/rasterization.py b/src/utils3d/torch/rasterization.py new file mode 100644 index 0000000000000000000000000000000000000000..11802737ebeae9be2e6b7bda7ee0d933b01ab909 --- /dev/null +++ b/src/utils3d/torch/rasterization.py @@ -0,0 +1,392 @@ +from typing import * + +import torch +import nvdiffrast.torch as dr + +from . 
import utils, transforms, mesh +from ._helpers import batched + + +__all__ = [ + 'RastContext', + 'rasterize_triangle_faces', + 'warp_image_by_depth', + 'warp_image_by_forward_flow', +] + + +class RastContext: + """ + Create a rasterization context. Nothing but a wrapper of nvdiffrast.torch.RasterizeCudaContext or nvdiffrast.torch.RasterizeGLContext. + """ + def __init__(self, nvd_ctx: Union[dr.RasterizeCudaContext, dr.RasterizeGLContext] = None, *, backend: Literal['cuda', 'gl'] = 'gl', device: Union[str, torch.device] = None): + if nvd_ctx is not None: + self.nvd_ctx = nvd_ctx + return + + if backend == 'gl': + self.nvd_ctx = dr.RasterizeGLContext(device=device) + elif backend == 'cuda': + self.nvd_ctx = dr.RasterizeCudaContext(device=device) + else: + raise ValueError(f'Unknown backend: {backend}') + + +def rasterize_triangle_faces( + ctx: RastContext, + vertices: torch.Tensor, + faces: torch.Tensor, + width: int, + height: int, + attr: torch.Tensor = None, + uv: torch.Tensor = None, + texture: torch.Tensor = None, + model: torch.Tensor = None, + view: torch.Tensor = None, + projection: torch.Tensor = None, + antialiasing: Union[bool, List[int]] = True, + diff_attrs: Union[None, List[int]] = None, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Rasterize a mesh with vertex attributes. + + Args: + ctx (GLContext): rasterizer context + vertices (np.ndarray): (B, N, 2 or 3 or 4) + faces (torch.Tensor): (T, 3) + width (int): width of the output image + height (int): height of the output image + attr (torch.Tensor, optional): (B, N, C) vertex attributes. Defaults to None. + uv (torch.Tensor, optional): (B, N, 2) uv coordinates. Defaults to None. + texture (torch.Tensor, optional): (B, H, W, C) texture. Defaults to None. + model (torch.Tensor, optional): ([B,] 4, 4) model matrix. Defaults to None (identity). + view (torch.Tensor, optional): ([B,] 4, 4) view matrix. Defaults to None (identity). + projection (torch.Tensor, optional): ([B,] 4, 4) projection matrix. Defaults to None (identity). + antialiasing (Union[bool, List[int]], optional): whether to perform antialiasing. Defaults to True. If a list of indices is provided, only those channels will be antialiased. + diff_attrs (Union[None, List[int]], optional): indices of attributes to compute screen-space derivatives. Defaults to None. + + Returns: + Dictionary containing: + - image: (torch.Tensor): (B, C, H, W) + - depth: (torch.Tensor): (B, H, W) screen space depth, ranging from 0 (near) to 1. (far) + NOTE: Empty pixels will have depth 1., i.e. far plane. 
+ - mask: (torch.BoolTensor): (B, H, W) mask of valid pixels + - image_dr: (torch.Tensor): (B, 4, H, W) screen space derivatives of the attributes + - face_id: (torch.Tensor): (B, H, W) face ids + - uv: (torch.Tensor): (B, N, 2) uv coordinates (if uv is not None) + - uv_dr: (torch.Tensor): (B, N, 4) uv derivatives (if uv is not None) + - texture: (torch.Tensor): (B, H, W, C) texture (if uv and texture are not None) + """ + assert vertices.ndim == 3 + assert faces.ndim == 2 + + if vertices.shape[-1] == 2: + vertices = torch.cat([vertices, torch.zeros_like(vertices[..., :1]), torch.ones_like(vertices[..., :1])], dim=-1) + elif vertices.shape[-1] == 3: + vertices = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1) + elif vertices.shape[-1] == 4: + pass + else: + raise ValueError(f'Wrong shape of vertices: {vertices.shape}') + + mvp = projection if projection is not None else torch.eye(4).to(vertices) + if view is not None: + mvp = mvp @ view + if model is not None: + mvp = mvp @ model + + pos_clip = vertices @ mvp.transpose(-1, -2) + faces = faces.contiguous() + if attr is not None: + attr = attr.contiguous() + + rast_out, rast_db = dr.rasterize(ctx.nvd_ctx, pos_clip, faces, resolution=[height, width], grad_db=True) + face_id = rast_out[..., 3].flip(1) + depth = rast_out[..., 2].flip(1) + mask = (face_id > 0).float() + depth = (depth * 0.5 + 0.5) * mask + (1.0 - mask) + + ret = { + 'depth': depth, + 'mask': mask, + 'face_id': face_id, + } + + if attr is not None: + image, image_dr = dr.interpolate(attr, rast_out, faces, rast_db, diff_attrs=diff_attrs) + if antialiasing == True: + image = dr.antialias(image, rast_out, pos_clip, faces) + elif isinstance(antialiasing, list): + aa_image = dr.antialias(image[..., antialiasing], rast_out, pos_clip, faces) + image[..., antialiasing] = aa_image + image = image.flip(1).permute(0, 3, 1, 2) + ret['image'] = image + + if uv is not None: + uv_map, uv_map_dr = dr.interpolate(uv, rast_out, faces, rast_db, diff_attrs='all') + ret['uv'] = uv_map + ret['uv_dr'] = uv_map_dr + if texture is not None: + texture_map = dr.texture(ctx.nvd_ctx, uv_map, uv_map_dr) + ret['texture'] = texture_map.flip(1).permute(0, 3, 1, 2) + + if diff_attrs is not None: + image_dr = image_dr.flip(1).permute(0, 3, 1, 2) + ret['image_dr'] = image_dr + + return ret + + +def texture( + ctx: RastContext, + uv: torch.Tensor, + uv_da: torch.Tensor, + texture: torch.Tensor, +) -> torch.Tensor: + dr.texture(ctx.nvd_ctx, uv, texture) + + +def warp_image_by_depth( + ctx: RastContext, + depth: torch.FloatTensor, + image: torch.FloatTensor = None, + mask: torch.BoolTensor = None, + width: int = None, + height: int = None, + *, + extrinsics_src: torch.FloatTensor = None, + extrinsics_tgt: torch.FloatTensor = None, + intrinsics_src: torch.FloatTensor = None, + intrinsics_tgt: torch.FloatTensor = None, + near: float = 0.1, + far: float = 100.0, + antialiasing: bool = True, + backslash: bool = False, + padding: int = 0, + return_uv: bool = False, + return_dr: bool = False, +) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.BoolTensor, Optional[torch.FloatTensor], Optional[torch.FloatTensor]]: + """ + Warp image by depth. + NOTE: if batch size is 1, image mesh will be triangulated aware of the depth, yielding less distorted results. + Otherwise, image mesh will be triangulated simply for batch rendering. + + Args: + ctx (Union[dr.RasterizeCudaContext, dr.RasterizeGLContext]): rasterization context + depth (torch.Tensor): (B, H, W) linear depth + image (torch.Tensor): (B, C, H, W). 
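# Usage sketch (illustrative only, assumes a CUDA device and nvdiffrast installed):
# rasterizing one triangle with per-vertex colors as the attribute. The vertices are
# already in clip space, so no model/view/projection matrices are passed.
import torch

ctx = RastContext(backend='cuda', device='cuda')
vertices = torch.tensor([[[-0.5, -0.5, 0.0],
                          [ 0.5, -0.5, 0.0],
                          [ 0.0,  0.5, 0.0]]], device='cuda')    # (1, 3, 3)
faces = torch.tensor([[0, 1, 2]], dtype=torch.int32, device='cuda')
colors = torch.tensor([[[1.0, 0.0, 0.0],
                        [0.0, 1.0, 0.0],
                        [0.0, 0.0, 1.0]]], device='cuda')        # (1, 3, 3) per-vertex RGB
out = rasterize_triangle_faces(ctx, vertices, faces, width=256, height=256, attr=colors)
image, mask = out['image'], out['mask']                          # (1, 3, 256, 256), (1, 256, 256)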
None to use image space uv. Defaults to None. + width (int, optional): width of the output image. None to use the same as depth. Defaults to None. + height (int, optional): height of the output image. Defaults the same as depth.. + extrinsics_src (torch.Tensor, optional): (B, 4, 4) extrinsics matrix for source. None to use identity. Defaults to None. + extrinsics_tgt (torch.Tensor, optional): (B, 4, 4) extrinsics matrix for target. None to use identity. Defaults to None. + intrinsics_src (torch.Tensor, optional): (B, 3, 3) intrinsics matrix for source. None to use the same as target. Defaults to None. + intrinsics_tgt (torch.Tensor, optional): (B, 3, 3) intrinsics matrix for target. None to use the same as source. Defaults to None. + near (float, optional): near plane. Defaults to 0.1. + far (float, optional): far plane. Defaults to 100.0. + antialiasing (bool, optional): whether to perform antialiasing. Defaults to True. + backslash (bool, optional): whether to use backslash triangulation. Defaults to False. + padding (int, optional): padding of the image. Defaults to 0. + return_uv (bool, optional): whether to return the uv. Defaults to False. + return_dr (bool, optional): whether to return the image-space derivatives of uv. Defaults to False. + + Returns: + image: (torch.FloatTensor): (B, C, H, W) rendered image + depth: (torch.FloatTensor): (B, H, W) linear depth, ranging from 0 to inf + mask: (torch.BoolTensor): (B, H, W) mask of valid pixels + uv: (torch.FloatTensor): (B, 2, H, W) image-space uv + dr: (torch.FloatTensor): (B, 4, H, W) image-space derivatives of uv + """ + assert depth.ndim == 3 + batch_size = depth.shape[0] + + if width is None: + width = depth.shape[-1] + if height is None: + height = depth.shape[-2] + if image is not None: + assert image.shape[-2:] == depth.shape[-2:], f'Shape of image {image.shape} does not match shape of depth {depth.shape}' + + if extrinsics_src is None: + extrinsics_src = torch.eye(4).to(depth) + if extrinsics_tgt is None: + extrinsics_tgt = torch.eye(4).to(depth) + if intrinsics_src is None: + intrinsics_src = intrinsics_tgt + if intrinsics_tgt is None: + intrinsics_tgt = intrinsics_src + + assert all(x is not None for x in [extrinsics_src, extrinsics_tgt, intrinsics_src, intrinsics_tgt]), "Make sure you have provided all the necessary camera parameters." 
+ + view_tgt = transforms.extrinsics_to_view(extrinsics_tgt) + perspective_tgt = transforms.intrinsics_to_perspective(intrinsics_tgt, near=near, far=far) + + if padding > 0: + uv, faces = utils.image_mesh(width=width+2, height=height+2) + uv = (uv - 1 / (width + 2)) * ((width + 2) / width) + uv_ = uv.clone().reshape(height+2, width+2, 2) + uv_[0, :, 1] -= padding / height + uv_[-1, :, 1] += padding / height + uv_[:, 0, 0] -= padding / width + uv_[:, -1, 0] += padding / width + uv_ = uv_.reshape(-1, 2) + depth = torch.nn.functional.pad(depth, [1, 1, 1, 1], mode='replicate') + if image is not None: + image = torch.nn.functional.pad(image, [1, 1, 1, 1], mode='replicate') + uv, uv_, faces = uv.to(depth.device), uv_.to(depth.device), faces.to(depth.device) + pts = transforms.unproject_cv( + uv_, + depth.flatten(-2, -1), + extrinsics_src, + intrinsics_src, + ) + else: + uv, faces = utils.image_mesh(width=depth.shape[-1], height=depth.shape[-2]) + if mask is not None: + depth = torch.where(mask, depth, torch.tensor(far, dtype=depth.dtype, device=depth.device)) + uv, faces = uv.to(depth.device), faces.to(depth.device) + pts = transforms.unproject_cv( + uv, + depth.flatten(-2, -1), + extrinsics_src, + intrinsics_src, + ) + + # triangulate + if batch_size == 1: + faces = mesh.triangulate(faces, vertices=pts[0]) + else: + faces = mesh.triangulate(faces, backslash=backslash) + + # rasterize attributes + diff_attrs = None + if image is not None: + attr = image.permute(0, 2, 3, 1).flatten(1, 2) + if return_dr or return_uv: + if return_dr: + diff_attrs = [image.shape[1], image.shape[1]+1] + if return_uv and antialiasing: + antialiasing = list(range(image.shape[1])) + attr = torch.cat([attr, uv.expand(batch_size, -1, -1)], dim=-1) + else: + attr = uv.expand(batch_size, -1, -1) + if antialiasing: + print("\033[93mWarning: you are performing antialiasing on uv. This may cause artifacts.\033[0m") + if return_uv: + return_uv = False + print("\033[93mWarning: image is None, return_uv is ignored.\033[0m") + if return_dr: + diff_attrs = [0, 1] + + if mask is not None: + attr = torch.cat([attr, mask.float().flatten(1, 2).unsqueeze(-1)], dim=-1) + + rast = rasterize_triangle_faces( + ctx, + pts, + faces, + width, + height, + attr=attr, + view=view_tgt, + perspective=perspective_tgt, + antialiasing=antialiasing, + diff_attrs=diff_attrs, + ) + if return_dr: + output_image, screen_depth, output_dr = rast['image'], rast['depth'], rast['image_dr'] + else: + output_image, screen_depth = rast['image'], rast['depth'] + output_mask = screen_depth < 1.0 + + if mask is not None: + output_image, rast_mask = output_image[..., :-1, :, :], output_image[..., -1, :, :] + output_mask &= (rast_mask > 0.9999).reshape(-1, height, width) + + if (return_dr or return_uv) and image is not None: + output_image, output_uv = output_image[..., :-2, :, :], output_image[..., -2:, :, :] + + output_depth = transforms.depth_buffer_to_linear(screen_depth, near=near, far=far) * output_mask + output_image = output_image * output_mask.unsqueeze(1) + + outs = [output_image, output_depth, output_mask] + if return_uv: + outs.append(output_uv) + if return_dr: + outs.append(output_dr) + return tuple(outs) + + +def warp_image_by_forward_flow( + ctx: RastContext, + image: torch.FloatTensor, + flow: torch.FloatTensor, + depth: torch.FloatTensor = None, + *, + antialiasing: bool = True, + backslash: bool = False, +) -> Tuple[torch.FloatTensor, torch.BoolTensor]: + """ + Warp image by forward flow. 
+ NOTE: if batch size is 1, image mesh will be triangulated aware of the depth, yielding less distorted results. + Otherwise, image mesh will be triangulated simply for batch rendering. + + Args: + ctx (Union[dr.RasterizeCudaContext, dr.RasterizeGLContext]): rasterization context + image (torch.Tensor): (B, C, H, W) image + flow (torch.Tensor): (B, 2, H, W) forward flow + depth (torch.Tensor, optional): (B, H, W) linear depth. If None, will use the same for all pixels. Defaults to None. + antialiasing (bool, optional): whether to perform antialiasing. Defaults to True. + backslash (bool, optional): whether to use backslash triangulation. Defaults to False. + + Returns: + image: (torch.FloatTensor): (B, C, H, W) rendered image + mask: (torch.BoolTensor): (B, H, W) mask of valid pixels + """ + assert image.ndim == 4, f'Wrong shape of image: {image.shape}' + batch_size, _, height, width = image.shape + + if depth is None: + depth = torch.ones_like(flow[:, 0]) + + extrinsics = torch.eye(4).to(image) + fov = torch.deg2rad(torch.tensor([45.0], device=image.device)) + intrinsics = transforms.intrinsics_from_fov(fov, width, height, normalize=True)[0] + + view = transforms.extrinsics_to_view(extrinsics) + perspective = transforms.intrinsics_to_perspective(intrinsics, near=0.1, far=100) + + uv, faces = utils.image_mesh(width=width, height=height) + uv, faces = uv.to(image.device), faces.to(image.device) + uv = uv + flow.permute(0, 2, 3, 1).flatten(1, 2) + pts = transforms.unproject_cv( + uv, + depth.flatten(-2, -1), + extrinsics, + intrinsics, + ) + + # triangulate + if batch_size == 1: + faces = mesh.triangulate(faces, vertices=pts[0]) + else: + faces = mesh.triangulate(faces, backslash=backslash) + + # rasterize attributes + attr = image.permute(0, 2, 3, 1).flatten(1, 2) + rast = rasterize_triangle_faces( + ctx, + pts, + faces, + width, + height, + attr=attr, + view=view, + perspective=perspective, + antialiasing=antialiasing, + ) + output_image, screen_depth = rast['image'], rast['depth'] + output_mask = screen_depth < 1.0 + output_image = output_image * output_mask.unsqueeze(1) + + outs = [output_image, output_mask] + return tuple(outs) diff --git a/src/utils3d/torch/transforms.py b/src/utils3d/torch/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..46e61d741c6ef000c80aa65201b55e31ed4246c6 --- /dev/null +++ b/src/utils3d/torch/transforms.py @@ -0,0 +1,1189 @@ +from typing import * +from numbers import Number + +import torch +import torch.nn.functional as F + +from ._helpers import batched + + +__all__ = [ + 'perspective', + 'perspective_from_fov', + 'perspective_from_fov_xy', + 'intrinsics_from_focal_center', + 'intrinsics_from_fov', + 'intrinsics_from_fov_xy', + 'view_look_at', + 'extrinsics_look_at', + 'perspective_to_intrinsics', + 'intrinsics_to_perspective', + 'extrinsics_to_view', + 'view_to_extrinsics', + 'normalize_intrinsics', + 'crop_intrinsics', + 'pixel_to_uv', + 'pixel_to_ndc', + 'uv_to_pixel', + 'project_depth', + 'depth_buffer_to_linear', + 'project_gl', + 'project_cv', + 'unproject_gl', + 'unproject_cv', + 'skew_symmetric', + 'rotation_matrix_from_vectors', + 'euler_axis_angle_rotation', + 'euler_angles_to_matrix', + 'matrix_to_euler_angles', + 'matrix_to_quaternion', + 'quaternion_to_matrix', + 'matrix_to_axis_angle', + 'axis_angle_to_matrix', + 'axis_angle_to_quaternion', + 'quaternion_to_axis_angle', + 'slerp', + 'interpolate_extrinsics', + 'interpolate_view', + 'extrinsics_to_essential', + 'to4x4', + 'rotation_matrix_2d', + 'rotate_2d', + 
'translate_2d', + 'scale_2d', + 'apply_2d', +] + + +@batched(0,0,0,0) +def perspective( + fov_y: Union[float, torch.Tensor], + aspect: Union[float, torch.Tensor], + near: Union[float, torch.Tensor], + far: Union[float, torch.Tensor] + ) -> torch.Tensor: + """ + Get OpenGL perspective matrix + + Args: + fov_y (float | torch.Tensor): field of view in y axis + aspect (float | torch.Tensor): aspect ratio + near (float | torch.Tensor): near plane to clip + far (float | torch.Tensor): far plane to clip + + Returns: + (torch.Tensor): [..., 4, 4] perspective matrix + """ + N = fov_y.shape[0] + ret = torch.zeros((N, 4, 4), dtype=fov_y.dtype, device=fov_y.device) + ret[:, 0, 0] = 1. / (torch.tan(fov_y / 2) * aspect) + ret[:, 1, 1] = 1. / (torch.tan(fov_y / 2)) + ret[:, 2, 2] = (near + far) / (near - far) + ret[:, 2, 3] = 2. * near * far / (near - far) + ret[:, 3, 2] = -1. + return ret + + +def perspective_from_fov( + fov: Union[float, torch.Tensor], + width: Union[int, torch.Tensor], + height: Union[int, torch.Tensor], + near: Union[float, torch.Tensor], + far: Union[float, torch.Tensor] + ) -> torch.Tensor: + """ + Get OpenGL perspective matrix from field of view in largest dimension + + Args: + fov (float | torch.Tensor): field of view in largest dimension + width (int | torch.Tensor): image width + height (int | torch.Tensor): image height + near (float | torch.Tensor): near plane to clip + far (float | torch.Tensor): far plane to clip + + Returns: + (torch.Tensor): [..., 4, 4] perspective matrix + """ + fov_y = 2 * torch.atan(torch.tan(fov / 2) * height / torch.maximum(width, height)) + aspect = width / height + return perspective(fov_y, aspect, near, far) + + +def perspective_from_fov_xy( + fov_x: Union[float, torch.Tensor], + fov_y: Union[float, torch.Tensor], + near: Union[float, torch.Tensor], + far: Union[float, torch.Tensor] + ) -> torch.Tensor: + """ + Get OpenGL perspective matrix from field of view in x and y axis + + Args: + fov_x (float | torch.Tensor): field of view in x axis + fov_y (float | torch.Tensor): field of view in y axis + near (float | torch.Tensor): near plane to clip + far (float | torch.Tensor): far plane to clip + + Returns: + (torch.Tensor): [..., 4, 4] perspective matrix + """ + aspect = torch.tan(fov_x / 2) / torch.tan(fov_y / 2) + return perspective(fov_y, aspect, near, far) + + +@batched(0,0,0,0) +def intrinsics_from_focal_center( + fx: Union[float, torch.Tensor], + fy: Union[float, torch.Tensor], + cx: Union[float, torch.Tensor], + cy: Union[float, torch.Tensor] +) -> torch.Tensor: + """ + Get OpenCV intrinsics matrix + + Args: + focal_x (float | torch.Tensor): focal length in x axis + focal_y (float | torch.Tensor): focal length in y axis + cx (float | torch.Tensor): principal point in x axis + cy (float | torch.Tensor): principal point in y axis + + Returns: + (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix + """ + N = fx.shape[0] + ret = torch.zeros((N, 3, 3), dtype=fx.dtype, device=fx.device) + zeros, ones = torch.zeros(N, dtype=fx.dtype, device=fx.device), torch.ones(N, dtype=fx.dtype, device=fx.device) + ret = torch.stack([fx, zeros, cx, zeros, fy, cy, zeros, zeros, ones], dim=-1).unflatten(-1, (3, 3)) + return ret + + +@batched(0, 0, 0, 0, 0, 0) +def intrinsics_from_fov( + fov_max: Union[float, torch.Tensor] = None, + fov_min: Union[float, torch.Tensor] = None, + fov_x: Union[float, torch.Tensor] = None, + fov_y: Union[float, torch.Tensor] = None, + width: Union[int, torch.Tensor] = None, + height: Union[int, torch.Tensor] = None, +) -> torch.Tensor: 
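+    # Note: the matrix returned below is normalized (focal lengths relative to the image size,
+    # principal point at 0.5). Illustrative sketch, assuming the `@batched` helper promotes Python
+    # scalars to tensors; the numbers are made-up example values:
+    #   K = intrinsics_from_fov(fov_max=torch.deg2rad(torch.tensor(60.0)),
+    #                           width=torch.tensor(640.0), height=torch.tensor(480.0))
+    #   # K[..., 0, 0] ~= 0.866 (= 1 / (2 * tan(30 deg))), K[..., 0, 2] == 0.5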
+ """ + Get normalized OpenCV intrinsics matrix from given field of view. + You can provide either fov_max, fov_min, fov_x or fov_y + + Args: + width (int | torch.Tensor): image width + height (int | torch.Tensor): image height + fov_max (float | torch.Tensor): field of view in largest dimension + fov_min (float | torch.Tensor): field of view in smallest dimension + fov_x (float | torch.Tensor): field of view in x axis + fov_y (float | torch.Tensor): field of view in y axis + + Returns: + (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix + """ + if fov_max is not None: + fx = torch.maximum(width, height) / width / (2 * torch.tan(fov_max / 2)) + fy = torch.maximum(width, height) / height / (2 * torch.tan(fov_max / 2)) + elif fov_min is not None: + fx = torch.minimum(width, height) / width / (2 * torch.tan(fov_min / 2)) + fy = torch.minimum(width, height) / height / (2 * torch.tan(fov_min / 2)) + elif fov_x is not None and fov_y is not None: + fx = 1 / (2 * torch.tan(fov_x / 2)) + fy = 1 / (2 * torch.tan(fov_y / 2)) + elif fov_x is not None: + fx = 1 / (2 * torch.tan(fov_x / 2)) + fy = fx * width / height + elif fov_y is not None: + fy = 1 / (2 * torch.tan(fov_y / 2)) + fx = fy * height / width + cx = 0.5 + cy = 0.5 + ret = intrinsics_from_focal_center(fx, fy, cx, cy) + return ret + + + +def intrinsics_from_fov_xy( + fov_x: Union[float, torch.Tensor], + fov_y: Union[float, torch.Tensor] + ) -> torch.Tensor: + """ + Get OpenCV intrinsics matrix from field of view in x and y axis + + Args: + fov_x (float | torch.Tensor): field of view in x axis + fov_y (float | torch.Tensor): field of view in y axis + + Returns: + (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix + """ + focal_x = 0.5 / torch.tan(fov_x / 2) + focal_y = 0.5 / torch.tan(fov_y / 2) + cx = cy = 0.5 + return intrinsics_from_focal_center(focal_x, focal_y, cx, cy) + + +@batched(1,1,1) +def view_look_at( + eye: torch.Tensor, + look_at: torch.Tensor, + up: torch.Tensor + ) -> torch.Tensor: + """ + Get OpenGL view matrix looking at something + + Args: + eye (torch.Tensor): [..., 3] the eye position + look_at (torch.Tensor): [..., 3] the position to look at + up (torch.Tensor): [..., 3] head up direction (y axis in screen space). Not necessarily othogonal to view direction + + Returns: + (torch.Tensor): [..., 4, 4], view matrix + """ + N = eye.shape[0] + z = eye - look_at + x = torch.cross(up, z, dim=-1) + y = torch.cross(z, x, dim=-1) + # x = torch.cross(y, z, dim=-1) + x = x / x.norm(dim=-1, keepdim=True) + y = y / y.norm(dim=-1, keepdim=True) + z = z / z.norm(dim=-1, keepdim=True) + R = torch.stack([x, y, z], dim=-2) + t = -torch.matmul(R, eye[..., None]) + ret = torch.zeros((N, 4, 4), dtype=eye.dtype, device=eye.device) + ret[:, :3, :3] = R + ret[:, :3, 3] = t[:, :, 0] + ret[:, 3, 3] = 1. + return ret + + +@batched(1, 1, 1) +def extrinsics_look_at( + eye: torch.Tensor, + look_at: torch.Tensor, + up: torch.Tensor +) -> torch.Tensor: + """ + Get OpenCV extrinsics matrix looking at something + + Args: + eye (torch.Tensor): [..., 3] the eye position + look_at (torch.Tensor): [..., 3] the position to look at + up (torch.Tensor): [..., 3] head up direction (-y axis in screen space). 
Not necessarily othogonal to view direction + + Returns: + (torch.Tensor): [..., 4, 4], extrinsics matrix + """ + N = eye.shape[0] + z = look_at - eye + x = torch.cross(-up, z, dim=-1) + y = torch.cross(z, x, dim=-1) + # x = torch.cross(y, z, dim=-1) + x = x / x.norm(dim=-1, keepdim=True) + y = y / y.norm(dim=-1, keepdim=True) + z = z / z.norm(dim=-1, keepdim=True) + R = torch.stack([x, y, z], dim=-2) + t = -torch.matmul(R, eye[..., None]) + ret = torch.zeros((N, 4, 4), dtype=eye.dtype, device=eye.device) + ret[:, :3, :3] = R + ret[:, :3, 3] = t[:, :, 0] + ret[:, 3, 3] = 1. + return ret + + +@batched(2) +def perspective_to_intrinsics( + perspective: torch.Tensor +) -> torch.Tensor: + """ + OpenGL perspective matrix to OpenCV intrinsics + + Args: + perspective (torch.Tensor): [..., 4, 4] OpenGL perspective matrix + + Returns: + (torch.Tensor): shape [..., 3, 3] OpenCV intrinsics + """ + assert torch.allclose(perspective[:, [0, 1, 3], 3], 0), "The perspective matrix is not a projection matrix" + ret = torch.tensor([[0.5, 0., 0.5], [0., -0.5, 0.5], [0., 0., 1.]], dtype=perspective.dtype, device=perspective.device) \ + @ perspective[:, [0, 1, 3], :3] \ + @ torch.diag(torch.tensor([1, -1, -1], dtype=perspective.dtype, device=perspective.device)) + return ret / ret[:, 2, 2, None, None] + + +@batched(2,0,0) +def intrinsics_to_perspective( + intrinsics: torch.Tensor, + near: Union[float, torch.Tensor], + far: Union[float, torch.Tensor], + ) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix + near (float | torch.Tensor): [...] near plane to clip + far (float | torch.Tensor): [...] far plane to clip + Returns: + (torch.Tensor): [..., 4, 4] OpenGL perspective matrix + """ + N = intrinsics.shape[0] + fx, fy = intrinsics[:, 0, 0], intrinsics[:, 1, 1] + cx, cy = intrinsics[:, 0, 2], intrinsics[:, 1, 2] + ret = torch.zeros((N, 4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[:, 0, 0] = 2 * fx + ret[:, 1, 1] = 2 * fy + ret[:, 0, 2] = -2 * cx + 1 + ret[:, 1, 2] = 2 * cy - 1 + ret[:, 2, 2] = (near + far) / (near - far) + ret[:, 2, 3] = 2. * near * far / (near - far) + ret[:, 3, 2] = -1. + return ret + + +@batched(2) +def extrinsics_to_view( + extrinsics: torch.Tensor + ) -> torch.Tensor: + """ + OpenCV camera extrinsics to OpenGL view matrix + + Args: + extrinsics (torch.Tensor): [..., 4, 4] OpenCV camera extrinsics matrix + + Returns: + (torch.Tensor): [..., 4, 4] OpenGL view matrix + """ + return extrinsics * torch.tensor([1, -1, -1, 1], dtype=extrinsics.dtype, device=extrinsics.device)[:, None] + + +@batched(2) +def view_to_extrinsics( + view: torch.Tensor + ) -> torch.Tensor: + """ + OpenGL view matrix to OpenCV camera extrinsics + + Args: + view (torch.Tensor): [..., 4, 4] OpenGL view matrix + + Returns: + (torch.Tensor): [..., 4, 4] OpenCV camera extrinsics matrix + """ + return view * torch.tensor([1, -1, -1, 1], dtype=view.dtype, device=view.device)[:, None] + + +@batched(2,0,0) +def normalize_intrinsics( + intrinsics: torch.Tensor, + width: Union[int, torch.Tensor], + height: Union[int, torch.Tensor] + ) -> torch.Tensor: + """ + Normalize camera intrinsics(s) to uv space + + Args: + intrinsics (torch.Tensor): [..., 3, 3] camera intrinsics(s) to normalize + width (int | torch.Tensor): [...] image width(s) + height (int | torch.Tensor): [...] 
image height(s) + + Returns: + (torch.Tensor): [..., 3, 3] normalized camera intrinsics(s) + """ + zeros = torch.zeros_like(width) + ones = torch.ones_like(width) + transform = torch.stack([ + 1 / width, zeros, 0.5 / width, + zeros, 1 / height, 0.5 / height, + zeros, zeros, ones + ]).reshape(*zeros.shape, 3, 3).to(intrinsics) + return transform @ intrinsics + + + +@batched(2,0,0,0,0,0,0) +def crop_intrinsics( + intrinsics: torch.Tensor, + width: Union[int, torch.Tensor], + height: Union[int, torch.Tensor], + left: Union[int, torch.Tensor], + top: Union[int, torch.Tensor], + crop_width: Union[int, torch.Tensor], + crop_height: Union[int, torch.Tensor] +) -> torch.Tensor: + """ + Evaluate the new intrinsics(s) after crop the image: cropped_img = img[top:top+crop_height, left:left+crop_width] + + Args: + intrinsics (torch.Tensor): [..., 3, 3] camera intrinsics(s) to crop + width (int | torch.Tensor): [...] image width(s) + height (int | torch.Tensor): [...] image height(s) + left (int | torch.Tensor): [...] left crop boundary + top (int | torch.Tensor): [...] top crop boundary + crop_width (int | torch.Tensor): [...] crop width + crop_height (int | torch.Tensor): [...] crop height + + Returns: + (torch.Tensor): [..., 3, 3] cropped camera intrinsics(s) + """ + zeros = torch.zeros_like(width) + ones = torch.ones_like(width) + transform = torch.stack([ + width / crop_width, zeros, -left / crop_width, + zeros, height / crop_height, -top / crop_height, + zeros, zeros, ones + ]).reshape(*zeros.shape, 3, 3).to(intrinsics) + return transform @ intrinsics + + +@batched(1,0,0) +def pixel_to_uv( + pixel: torch.Tensor, + width: Union[int, torch.Tensor], + height: Union[int, torch.Tensor] +) -> torch.Tensor: + """ + Args: + pixel (torch.Tensor): [..., 2] pixel coordinrates defined in image space, x range is (0, W - 1), y range is (0, H - 1) + width (int | torch.Tensor): [...] image width(s) + height (int | torch.Tensor): [...] image height(s) + + Returns: + (torch.Tensor): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1) + """ + if not torch.is_floating_point(pixel): + pixel = pixel.float() + uv = (pixel + 0.5) / torch.stack([width, height], dim=-1).to(pixel) + return uv + + +@batched(1,0,0) +def uv_to_pixel( + uv: torch.Tensor, + width: Union[int, torch.Tensor], + height: Union[int, torch.Tensor] +) -> torch.Tensor: + """ + Args: + uv (torch.Tensor): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1) + width (int | torch.Tensor): [...] image width(s) + height (int | torch.Tensor): [...] image height(s) + + Returns: + (torch.Tensor): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1) + """ + pixel = uv * torch.stack([width, height], dim=-1).to(uv) - 0.5 + return pixel + + +@batched(1,0,0) +def pixel_to_ndc( + pixel: torch.Tensor, + width: Union[int, torch.Tensor], + height: Union[int, torch.Tensor] +) -> torch.Tensor: + """ + Args: + pixel (torch.Tensor): [..., 2] pixel coordinrates defined in image space, x range is (0, W - 1), y range is (0, H - 1) + width (int | torch.Tensor): [...] image width(s) + height (int | torch.Tensor): [...] 
image height(s) + + Returns: + (torch.Tensor): [..., 2] pixel coordinrates defined in ndc space, the range is (-1, 1) + """ + if not torch.is_floating_point(pixel): + pixel = pixel.float() + ndc = (pixel + 0.5) / (torch.stack([width, height], dim=-1).to(pixel) * torch.tensor([2, -2], dtype=pixel.dtype, device=pixel.device)) \ + + torch.tensor([-1, 1], dtype=pixel.dtype, device=pixel.device) + return ndc + + +@batched(0,0,0) +def project_depth( + depth: torch.Tensor, + near: Union[float, torch.Tensor], + far: Union[float, torch.Tensor] + ) -> torch.Tensor: + """ + Project linear depth to depth value in screen space + + Args: + depth (torch.Tensor): [...] depth value + near (float | torch.Tensor): [...] near plane to clip + far (float | torch.Tensor): [...] far plane to clip + + Returns: + (torch.Tensor): [..., 1] depth value in screen space, value ranging in [0, 1] + """ + return (far - near * far / depth) / (far - near) + + +@batched(0,0,0) +def depth_buffer_to_linear( + depth: torch.Tensor, + near: Union[float, torch.Tensor], + far: Union[float, torch.Tensor] + ) -> torch.Tensor: + """ + Linearize depth value to linear depth + + Args: + depth (torch.Tensor): [...] screen depth value, ranging in [0, 1] + near (float | torch.Tensor): [...] near plane to clip + far (float | torch.Tensor): [...] far plane to clip + + Returns: + (torch.Tensor): [...] linear depth + """ + return near * far / (far - (far - near) * depth) + + +@batched(2, 2, 2, 2) +def project_gl( + points: torch.Tensor, + model: torch.Tensor = None, + view: torch.Tensor = None, + perspective: torch.Tensor = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Project 3D points to 2D following the OpenGL convention (except for row major matrice) + + Args: + points (torch.Tensor): [..., N, 3 or 4] 3D points to project, if the last + dimension is 4, the points are assumed to be in homogeneous coordinates + model (torch.Tensor): [..., 4, 4] model matrix + view (torch.Tensor): [..., 4, 4] view matrix + perspective (torch.Tensor): [..., 4, 4] perspective matrix + + Returns: + scr_coord (torch.Tensor): [..., N, 3] screen space coordinates, value ranging in [0, 1]. + The origin (0., 0., 0.) is corresponding to the left & bottom & nearest + linear_depth (torch.Tensor): [..., N] linear depth + """ + assert perspective is not None, "perspective matrix is required" + + if points.shape[-1] == 3: + points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) + mvp = perspective if perspective is not None else torch.eye(4).to(points) + if view is not None: + mvp = mvp @ view + if model is not None: + mvp = mvp @ model + clip_coord = points @ mvp.transpose(-1, -2) + ndc_coord = clip_coord[..., :3] / clip_coord[..., 3:] + scr_coord = ndc_coord * 0.5 + 0.5 + linear_depth = clip_coord[..., 3] + return scr_coord, linear_depth + + +@batched(2, 2, 2) +def project_cv( + points: torch.Tensor, + extrinsics: torch.Tensor = None, + intrinsics: torch.Tensor = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Project 3D points to 2D following the OpenCV convention + + Args: + points (torch.Tensor): [..., N, 3] or [..., N, 4] 3D points to project, if the last + dimension is 4, the points are assumed to be in homogeneous coordinates + extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix + intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix + + Returns: + uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1]. + The origin (0., 0.) 
is corresponding to the left & top + linear_depth (torch.Tensor): [..., N] linear depth + """ + assert intrinsics is not None, "intrinsics matrix is required" + if points.shape[-1] == 3: + points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) + if extrinsics is not None: + points = points @ extrinsics.transpose(-1, -2) + points = points[..., :3] @ intrinsics.transpose(-2, -1) + uv_coord = points[..., :2] / points[..., 2:] + linear_depth = points[..., 2] + return uv_coord, linear_depth + + +@batched(2, 2, 2, 2) +def unproject_gl( + screen_coord: torch.Tensor, + model: torch.Tensor = None, + view: torch.Tensor = None, + perspective: torch.Tensor = None + ) -> torch.Tensor: + """ + Unproject screen space coordinates to 3D view space following the OpenGL convention (except for row major matrice) + + Args: + screen_coord (torch.Tensor): [... N, 3] screen space coordinates, value ranging in [0, 1]. + The origin (0., 0., 0.) is corresponding to the left & bottom & nearest + model (torch.Tensor): [..., 4, 4] model matrix + view (torch.Tensor): [..., 4, 4] view matrix + perspective (torch.Tensor): [..., 4, 4] perspective matrix + + Returns: + points (torch.Tensor): [..., N, 3] 3d points + """ + assert perspective is not None, "perspective matrix is required" + ndc_xy = screen_coord * 2 - 1 + clip_coord = torch.cat([ndc_xy, torch.ones_like(ndc_xy[..., :1])], dim=-1) + transform = perspective + if view is not None: + transform = transform @ view + if model is not None: + transform = transform @ model + transform = torch.inverse(transform) + points = clip_coord @ transform.transpose(-1, -2) + points = points[..., :3] / points[..., 3:] + return points + + +@batched(2, 1, 2, 2) +def unproject_cv( + uv_coord: torch.Tensor, + depth: torch.Tensor, + extrinsics: torch.Tensor = None, + intrinsics: torch.Tensor = None +) -> torch.Tensor: + """ + Unproject uv coordinates to 3D view space following the OpenCV convention + + Args: + uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1]. + The origin (0., 0.) is corresponding to the left & top + depth (torch.Tensor): [..., N] depth value + extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix + intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix + + Returns: + points (torch.Tensor): [..., N, 3] 3d points + """ + assert intrinsics is not None, "intrinsics matrix is required" + points = torch.cat([uv_coord, torch.ones_like(uv_coord[..., :1])], dim=-1) + points = points @ torch.inverse(intrinsics).transpose(-2, -1) + points = points * depth[..., None] + if extrinsics is not None: + points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) + points = (points @ torch.inverse(extrinsics).transpose(-2, -1))[..., :3] + return points + + +def euler_axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor: + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). 
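+
+    Example (illustrative, values rounded):
+        `euler_axis_angle_rotation("Z", torch.tensor(1.5708))` is approximately
+        [[0, -1, 0], [1, 0, 0], [0, 0, 1]], i.e. a 90-degree rotation about Z that maps the
+        x axis onto the y axis.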
+ """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + elif axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + elif axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + else: + raise ValueError("letter must be either X, Y or Z.") + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str = 'XYZ') -> torch.Tensor: + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3), XYZ + convention: permutation of "X", "Y" or "Z", representing the order of Euler rotations to apply. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = [ + euler_axis_angle_rotation(c, euler_angles[..., 'XYZ'.index(c)]) + for c in convention + ] + # return functools.reduce(torch.matmul, matrices) + return matrices[2] @ matrices[1] @ matrices[0] + + +def skew_symmetric(v: torch.Tensor): + "Skew symmetric matrix from a 3D vector" + assert v.shape[-1] == 3, "v must be 3D" + x, y, z = v.unbind(dim=-1) + zeros = torch.zeros_like(x) + return torch.stack([ + zeros, -z, y, + z, zeros, -x, + -y, x, zeros, + ], dim=-1).reshape(*v.shape[:-1], 3, 3) + + +def rotation_matrix_from_vectors(v1: torch.Tensor, v2: torch.Tensor): + "Rotation matrix that rotates v1 to v2" + I = torch.eye(3).to(v1) + v1 = F.normalize(v1, dim=-1) + v2 = F.normalize(v2, dim=-1) + v = torch.cross(v1, v2, dim=-1) + c = torch.sum(v1 * v2, dim=-1) + K = skew_symmetric(v) + R = I + K + (1 / (1 + c))[None, None] * (K @ K) + return R + + +def _angle_from_tan( + axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool +) -> torch.Tensor: + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in data as a tensor + of shape (...). 
+ """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to Euler angles in radians. + NOTE: The composition order eg. `XYZ` means `Rz * Ry * Rx` (like blender), instead of `Rx * Ry * Rz` (like pytorch3d) + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3), in the order of XYZ (like blender), instead of convention (like pytorch3d) + """ + if not all(c in 'XYZ' for c in convention) or not all(c in convention for c in 'XYZ'): + raise ValueError(f"Invalid convention {convention}.") + if not matrix.shape[-2:] == (3, 3): + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + i0 = 'XYZ'.index(convention[0]) + i2 = 'XYZ'.index(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin(matrix[..., i2, i0] * (-1.0 if i2 - i0 in [-1, 2] else 1.0)) + else: + central_angle = torch.acos(matrix[..., i2, i2]) + + # Angles in composition order + o = [ + _angle_from_tan( + convention[0], convention[1], matrix[..., i2, :], True, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0], False, tait_bryan + ), + ] + return torch.stack([o[convention.index(c)] for c in 'XYZ'], -1) + + +def axis_angle_to_matrix(axis_angle: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: + """Convert axis-angle representation (rotation vector) to rotation matrix, whose direction is the axis of rotation and length is the angle of rotation + + Args: + axis_angle (torch.Tensor): shape (..., 3), axis-angle vcetors + + Returns: + torch.Tensor: shape (..., 3, 3) The rotation matrices for the given axis-angle parameters + """ + batch_shape = axis_angle.shape[:-1] + device, dtype = axis_angle.device, axis_angle.dtype + + angle = torch.norm(axis_angle + eps, dim=-1, keepdim=True) + axis = axis_angle / angle + + cos = torch.cos(angle)[..., None, :] + sin = torch.sin(angle)[..., None, :] + + rx, ry, rz = torch.split(axis, 3, dim=-1) + zeros = torch.zeros((*batch_shape, 1), dtype=dtype, device=device) + K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=-1).view((*batch_shape, 3, 3)) + + ident = torch.eye(3, dtype=dtype, device=device) + rot_mat = ident + sin * K + (1 - cos) * torch.matmul(K, K) + return rot_mat + + +def matrix_to_axis_angle(rot_mat: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: + """Convert a batch of 3x3 rotation matrices to axis-angle representation (rotation vector) + + Args: + rot_mat (torch.Tensor): shape (..., 3, 3), the rotation matrices to convert + + Returns: + torch.Tensor: shape (..., 3), the axis-angle vectors corresponding to the given rotation matrices + """ + quat = matrix_to_quaternion(rot_mat) + axis_angle = quaternion_to_axis_angle(quat, eps=eps) + return axis_angle + + +def quaternion_to_axis_angle(quaternion: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: + """Convert a batch of quaternions (w, x, y, z) to axis-angle representation (rotation vector) + + Args: + quaternion (torch.Tensor): shape (..., 4), the quaternions to 
convert + + Returns: + torch.Tensor: shape (..., 3), the axis-angle vectors corresponding to the given quaternions + """ + assert quaternion.shape[-1] == 4 + norm = torch.norm(quaternion[..., 1:], dim=-1, keepdim=True) + axis = quaternion[..., 1:] / norm.clamp(min=eps) + angle = 2 * torch.atan2(norm, quaternion[..., 0:1]) + return angle * axis + + +def axis_angle_to_quaternion(axis_angle: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: + """Convert axis-angle representation (rotation vector) to quaternion (w, x, y, z) + + Args: + axis_angle (torch.Tensor): shape (..., 3), axis-angle vcetors + + Returns: + torch.Tensor: shape (..., 4) The quaternions for the given axis-angle parameters + """ + axis = F.normalize(axis_angle, dim=-1, eps=eps) + angle = torch.norm(axis_angle, dim=-1, keepdim=True) + quat = torch.cat([torch.cos(angle / 2), torch.sin(angle / 2) * axis], dim=-1) + return quat + + +def matrix_to_quaternion(rot_mat: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: + """Convert 3x3 rotation matrix to quaternion (w, x, y, z) + + Args: + rot_mat (torch.Tensor): shape (..., 3, 3), the rotation matrices to convert + + Returns: + torch.Tensor: shape (..., 4), the quaternions corresponding to the given rotation matrices + """ + # Extract the diagonal and off-diagonal elements of the rotation matrix + m00, m01, m02, m10, m11, m12, m20, m21, m22 = rot_mat.flatten(-2).unbind(dim=-1) + + diag = torch.diagonal(rot_mat, dim1=-2, dim2=-1) + M = torch.tensor([ + [1, 1, 1], + [1, -1, -1], + [-1, 1, -1], + [-1, -1, 1] + ], dtype=rot_mat.dtype, device=rot_mat.device) + wxyz = (1 + diag @ M.transpose(-1, -2)).clamp_(0).sqrt().mul(0.5) + _, max_idx = wxyz.max(dim=-1) + xw = torch.sign(m21 - m12) + yw = torch.sign(m02 - m20) + zw = torch.sign(m10 - m01) + yz = torch.sign(m21 + m12) + xz = torch.sign(m02 + m20) + xy = torch.sign(m01 + m10) + ones = torch.ones_like(xw) + sign = torch.where( + max_idx[..., None] == 0, + torch.stack([ones, xw, yw, zw], dim=-1), + torch.where( + max_idx[..., None] == 1, + torch.stack([xw, ones, xy, xz], dim=-1), + torch.where( + max_idx[..., None] == 2, + torch.stack([yw, xy, ones, yz], dim=-1), + torch.stack([zw, xz, yz, ones], dim=-1) + ) + ) + ) + quat = sign * wxyz + quat = F.normalize(quat, dim=-1, eps=eps) + return quat + + +def quaternion_to_matrix(quaternion: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: + """Converts a batch of quaternions (w, x, y, z) to rotation matrices + + Args: + quaternion (torch.Tensor): shape (..., 4), the quaternions to convert + + Returns: + torch.Tensor: shape (..., 3, 3), the rotation matrices corresponding to the given quaternions + """ + assert quaternion.shape[-1] == 4 + quaternion = F.normalize(quaternion, dim=-1, eps=eps) + w, x, y, z = quaternion.unbind(dim=-1) + zeros = torch.zeros_like(w) + I = torch.eye(3, dtype=quaternion.dtype, device=quaternion.device) + xyz = quaternion[..., 1:] + A = xyz[..., :, None] * xyz[..., None, :] - I * (xyz ** 2).sum(dim=-1)[..., None, None] + B = torch.stack([ + zeros, -z, y, + z, zeros, -x, + -y, x, zeros + ], dim=-1).unflatten(-1, (3, 3)) + rot_mat = I + 2 * (A + w[..., None, None] * B) + return rot_mat + + +def slerp(rot_mat_1: torch.Tensor, rot_mat_2: torch.Tensor, t: Union[Number, torch.Tensor]) -> torch.Tensor: + """Spherical linear interpolation between two rotation matrices + + Args: + rot_mat_1 (torch.Tensor): shape (..., 3, 3), the first rotation matrix + rot_mat_2 (torch.Tensor): shape (..., 3, 3), the second rotation matrix + t (torch.Tensor): scalar or shape (...,), the 
interpolation factor
+
+    Returns:
+        torch.Tensor: shape (..., 3, 3), the interpolated rotation matrix
+    """
+    assert rot_mat_1.shape[-2:] == (3, 3)
+    rot_vec_1 = matrix_to_axis_angle(rot_mat_1)
+    rot_vec_2 = matrix_to_axis_angle(rot_mat_2)
+    if isinstance(t, Number):
+        t = torch.tensor(t, dtype=rot_mat_1.dtype, device=rot_mat_1.device)
+    rot_vec = (1 - t[..., None]) * rot_vec_1 + t[..., None] * rot_vec_2
+    rot_mat = axis_angle_to_matrix(rot_vec)
+    return rot_mat
+
+
+def interpolate_extrinsics(ext1: torch.Tensor, ext2: torch.Tensor, t: Union[Number, torch.Tensor]) -> torch.Tensor:
+    """Interpolate extrinsics between two camera poses. Linear interpolation for translation, spherical linear interpolation for rotation.
+
+    Args:
+        ext1 (torch.Tensor): shape (..., 4, 4), the first camera pose
+        ext2 (torch.Tensor): shape (..., 4, 4), the second camera pose
+        t (torch.Tensor): scalar or shape (...,), the interpolation factor
+
+    Returns:
+        torch.Tensor: shape (..., 4, 4), the interpolated camera pose
+    """
+    return torch.inverse(interpolate_transform(torch.inverse(ext1), torch.inverse(ext2), t))
+
+
+def interpolate_view(view1: torch.Tensor, view2: torch.Tensor, t: Union[Number, torch.Tensor]):
+    """Interpolate view matrices between two camera poses. Linear interpolation for translation, spherical linear interpolation for rotation.
+
+    Args:
+        view1 (torch.Tensor): shape (..., 4, 4), the first camera pose
+        view2 (torch.Tensor): shape (..., 4, 4), the second camera pose
+        t (torch.Tensor): scalar or shape (...,), the interpolation factor
+
+    Returns:
+        torch.Tensor: shape (..., 4, 4), the interpolated camera pose
+    """
+    return interpolate_extrinsics(view1, view2, t)
+
+
+def interpolate_transform(transform1: torch.Tensor, transform2: torch.Tensor, t: Union[Number, torch.Tensor]):
+    assert transform1.shape[-2:] == (4, 4) and transform2.shape[-2:] == (4, 4)
+    if isinstance(t, Number):
+        t = torch.tensor(t, dtype=transform1.dtype, device=transform1.device)
+    pos = (1 - t[..., None]) * transform1[..., :3, 3] + t[..., None] * transform2[..., :3, 3]
+    rot = slerp(transform1[..., :3, :3], transform2[..., :3, :3], t)
+    transform = torch.cat([rot, pos[..., None]], dim=-1)
+    transform = torch.cat([transform, torch.tensor([0, 0, 0, 1], dtype=transform.dtype, device=transform.device).expand_as(transform[..., :1, :])], dim=-2)
+    return transform
+
+
+def extrinsics_to_essential(extrinsics: torch.Tensor):
+    """
+    extrinsics matrix `[[R, t] [0, 0, 0, 1]]` such that `x' = R (x - t)` to essential matrix such that `x'^T E x = 0`
+
+    Args:
+        extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix
+
+    Returns:
+        (torch.Tensor): [..., 3, 3] essential matrix
+    """
+    assert extrinsics.shape[-2:] == (4, 4)
+    R = extrinsics[..., :3, :3]
+    t = extrinsics[..., :3, 3]
+    zeros = torch.zeros_like(t[..., 0])
+    t_x = torch.stack([
+        zeros, -t[..., 2], t[..., 1],
+        t[..., 2], zeros, -t[..., 0],
+        -t[..., 1], t[..., 0], zeros
+    ], dim=-1).reshape(*t.shape[:-1], 3, 3)
+    return R @ t_x
+
+
+def to4x4(R: torch.Tensor, t: torch.Tensor):
+    """
+    Compose rotation matrix and translation vector to 4x4 transformation matrix
+
+    Args:
+        R (torch.Tensor): [..., 3, 3] rotation matrix
+        t (torch.Tensor): [..., 3] translation vector
+
+    Returns:
+        (torch.Tensor): [..., 4, 4] transformation matrix
+    """
+    assert R.shape[-2:] == (3, 3)
+    assert t.shape[-1] == 3
+    assert R.shape[:-2] == t.shape[:-1]
+    return torch.cat([
+        torch.cat([R, t[..., None]], dim=-1),
+        torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device).expand(*R.shape[:-2], 1, 4)
+    ], dim=-2)
+
+
+def 
rotation_matrix_2d(theta: Union[float, torch.Tensor]): + """ + 2x2 matrix for 2D rotation + + Args: + theta (float | torch.Tensor): rotation angle in radians, arbitrary shape (...,) + + Returns: + (torch.Tensor): (..., 2, 2) rotation matrix + """ + if isinstance(theta, float): + theta = torch.tensor(theta) + return torch.stack([ + torch.cos(theta), -torch.sin(theta), + torch.sin(theta), torch.cos(theta), + ], dim=-1).unflatten(-1, (2, 2)) + + +def rotate_2d(theta: Union[float, torch.Tensor], center: torch.Tensor = None): + """ + 3x3 matrix for 2D rotation around a center + ``` + [[Rxx, Rxy, tx], + [Ryx, Ryy, ty], + [0, 0, 1]] + ``` + Args: + theta (float | torch.Tensor): rotation angle in radians, arbitrary shape (...,) + center (torch.Tensor): rotation center, arbitrary shape (..., 2). Default to (0, 0) + + Returns: + (torch.Tensor): (..., 3, 3) transformation matrix + """ + if isinstance(theta, float): + theta = torch.tensor(theta) + if center is not None: + theta = theta.to(center) + if center is None: + center = torch.zeros(2).to(theta).expand(*theta.shape, -1) + R = rotation_matrix_2d(theta) + return torch.cat([ + torch.cat([ + R, + center[..., :, None] - R @ center[..., :, None], + ], dim=-1), + torch.tensor([[0, 0, 1]], dtype=center.dtype, device=center.device).expand(*center.shape[:-1], -1, -1), + ], dim=-2) + + +def translate_2d(translation: torch.Tensor): + """ + Translation matrix for 2D translation + ``` + [[1, 0, tx], + [0, 1, ty], + [0, 0, 1]] + ``` + Args: + translation (torch.Tensor): translation vector, arbitrary shape (..., 2) + + Returns: + (torch.Tensor): (..., 3, 3) transformation matrix + """ + return torch.cat([ + torch.cat([ + torch.eye(2, dtype=translation.dtype, device=translation.device).expand(*translation.shape[:-1], -1, -1), + translation[..., None], + ], dim=-1), + torch.tensor([[0, 0, 1]], dtype=translation.dtype, device=translation.device).expand(*translation.shape[:-1], -1, -1), + ], dim=-2) + + +def scale_2d(scale: Union[float, torch.Tensor], center: torch.Tensor = None): + """ + Scale matrix for 2D scaling + ``` + [[s, 0, tx], + [0, s, ty], + [0, 0, 1]] + ``` + Args: + scale (float | torch.Tensor): scale factor, arbitrary shape (...,) + center (torch.Tensor): scale center, arbitrary shape (..., 2). 
Default to (0, 0) + + Returns: + (torch.Tensor): (..., 3, 3) transformation matrix + """ + if isinstance(scale, float): + scale = torch.tensor(scale) + if center is not None: + scale = scale.to(center) + if center is None: + center = torch.zeros(2, dtype=scale.dtype, device=scale.device).expand(*scale.shape, -1) + return torch.cat([ + torch.cat([ + scale * torch.eye(2, dtype=scale.dtype, device=scale.device).expand(*scale.shape[:-1], -1, -1), + center[..., :, None] - center[..., :, None] * scale[..., None, None], + ], dim=-1), + torch.tensor([[0, 0, 1]], dtype=scale.dtype, device=scale.device).expand(*center.shape[:-1], -1, -1), + ], dim=-2) + + +def apply_2d(transform: torch.Tensor, points: torch.Tensor): + """ + Apply (3x3 or 2x3) 2D affine transformation to points + ``` + p = R @ p + t + ``` + Args: + transform (torch.Tensor): (..., 2 or 3, 3) transformation matrix + points (torch.Tensor): (..., N, 2) points to transform + + Returns: + (torch.Tensor): (..., N, 2) transformed points + """ + assert transform.shape[-2:] == (3, 3) or transform.shape[-2:] == (2, 3), "transform must be 3x3 or 2x3" + assert points.shape[-1] == 2, "points must be 2D" + return points @ transform[..., :2, :2].mT + transform[..., :2, None, 2] \ No newline at end of file diff --git a/src/utils3d/torch/utils.py b/src/utils3d/torch/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..877ffb8a60a7f5206fbeb5a9e4a584758b875da4 --- /dev/null +++ b/src/utils3d/torch/utils.py @@ -0,0 +1,351 @@ +from typing import * + +import torch +import torch.nn.functional as F + +from . import transforms +from . import mesh +from ._helpers import batched + + +__all__ = [ + 'sliding_window_1d', + 'sliding_window_2d', + 'sliding_window_nd', + 'image_uv', + 'image_pixel_center', + 'image_mesh', + 'chessboard', + 'depth_edge', + 'depth_aliasing', + 'image_mesh_from_depth', + 'point_to_normal', + 'depth_to_normal', + 'masked_min', + 'masked_max', + 'bounding_rect' +] + + +def sliding_window_1d(x: torch.Tensor, window_size: int, stride: int = 1, dim: int = -1) -> torch.Tensor: + """ + Sliding window view of the input tensor. The dimension of the sliding window is appended to the end of the input tensor's shape. + NOTE: Since Pytorch has `unfold` function, 1D sliding window view is just a wrapper of it. + """ + return x.unfold(dim, window_size, stride) + + +def sliding_window_nd(x: torch.Tensor, window_size: Tuple[int, ...], stride: Tuple[int, ...], dim: Tuple[int, ...]) -> torch.Tensor: + dim = [dim[i] % x.ndim for i in range(len(dim))] + assert len(window_size) == len(stride) == len(dim) + for i in range(len(window_size)): + x = sliding_window_1d(x, window_size[i], stride[i], dim[i]) + return x + + +def sliding_window_2d(x: torch.Tensor, window_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]], dim: Union[int, Tuple[int, int]] = (-2, -1)) -> torch.Tensor: + if isinstance(window_size, int): + window_size = (window_size, window_size) + if isinstance(stride, int): + stride = (stride, stride) + return sliding_window_nd(x, window_size, stride, dim) + + +def image_uv(height: int, width: int, left: int = None, top: int = None, right: int = None, bottom: int = None, device: torch.device = None, dtype: torch.dtype = None) -> torch.Tensor: + """ + Get image space UV grid, ranging in [0, 1]. + + >>> image_uv(10, 10): + [[[0.05, 0.05], [0.15, 0.05], ..., [0.95, 0.05]], + [[0.05, 0.15], [0.15, 0.15], ..., [0.95, 0.15]], + ... ... ... 
+ [[0.05, 0.95], [0.15, 0.95], ..., [0.95, 0.95]]] + + Args: + width (int): image width + height (int): image height + + Returns: + np.ndarray: shape (height, width, 2) + """ + if left is None: left = 0 + if top is None: top = 0 + if right is None: right = width + if bottom is None: bottom = height + u = torch.linspace((left + 0.5) / width, (right - 0.5) / width, right - left, device=device, dtype=dtype) + v = torch.linspace((top + 0.5) / height, (bottom - 0.5) / height, bottom - top, device=device, dtype=dtype) + u, v = torch.meshgrid(u, v, indexing='xy') + uv = torch.stack([u, v], dim=-1) + return uv + + +def image_pixel_center( + height: int, + width: int, + left: int = None, + top: int = None, + right: int = None, + bottom: int = None, + dtype: torch.dtype = None, + device: torch.device = None +) -> torch.Tensor: + """ + Get image pixel center coordinates, ranging in [0, width] and [0, height]. + `image[i, j]` has pixel center coordinates `(j + 0.5, i + 0.5)`. + + >>> image_pixel_center(10, 10): + [[[0.5, 0.5], [1.5, 0.5], ..., [9.5, 0.5]], + [[0.5, 1.5], [1.5, 1.5], ..., [9.5, 1.5]], + ... ... ... + [[0.5, 9.5], [1.5, 9.5], ..., [9.5, 9.5]]] + + Args: + width (int): image width + height (int): image height + + Returns: + np.ndarray: shape (height, width, 2) + """ + if left is None: left = 0 + if top is None: top = 0 + if right is None: right = width + if bottom is None: bottom = height + u = torch.linspace(left + 0.5, right - 0.5, right - left, dtype=dtype, device=device) + v = torch.linspace(top + 0.5, bottom - 0.5, bottom - top, dtype=dtype, device=device) + u, v = torch.meshgrid(u, v, indexing='xy') + return torch.stack([u, v], dim=2) + + +def image_mesh(height: int, width: int, mask: torch.Tensor = None, device: torch.device = None, dtype: torch.dtype = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get a quad mesh regarding image pixel uv coordinates as vertices and image grid as faces. + + Args: + width (int): image width + height (int): image height + mask (np.ndarray, optional): binary mask of shape (height, width), dtype=bool. Defaults to None. + + Returns: + uv (np.ndarray): uv corresponding to pixels as described in image_uv() + faces (np.ndarray): quad faces connecting neighboring pixels + indices (np.ndarray, optional): indices of vertices in the original mesh + """ + if device is None and mask is not None: + device = mask.device + if mask is not None: + assert mask.shape[0] == height and mask.shape[1] == width + assert mask.dtype == torch.bool + uv = image_uv(height, width, device=device, dtype=dtype).reshape((-1, 2)) + row_faces = torch.stack([ + torch.arange(0, width - 1, dtype=torch.int32, device=device), + torch.arange(width, 2 * width - 1, dtype=torch.int32, device=device), + torch.arange(1 + width, 2 * width, dtype=torch.int32, device=device), + torch.arange(1, width, dtype=torch.int32, device=device) + ], dim=1) + faces = (torch.arange(0, (height - 1) * width, width, device=device, dtype=torch.int32)[:, None, None] + row_faces[None, :, :]).reshape((-1, 4)) + if mask is not None: + quad_mask = (mask[:-1, :-1] & mask[1:, :-1] & mask[1:, 1:] & mask[:-1, 1:]).ravel() + faces = faces[quad_mask] + faces, uv, indices = mesh.remove_unreferenced_vertices(faces, uv, return_indices=True) + return uv, faces, indices + return uv, faces + + +def depth_edge(depth: torch.Tensor, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: torch.Tensor = None) -> torch.BoolTensor: + """ + Compute the edge mask of a depth map. 
The edge is defined as the pixels whose neighbors have a large difference in depth. + + Args: + depth (torch.Tensor): shape (..., height, width), linear depth map + atol (float): absolute tolerance + rtol (float): relative tolerance + + Returns: + edge (torch.Tensor): shape (..., height, width) of dtype torch.bool + """ + shape = depth.shape + depth = depth.reshape(-1, 1, *shape[-2:]) + if mask is not None: + mask = mask.reshape(-1, 1, *shape[-2:]) + + if mask is None: + diff = (F.max_pool2d(depth, kernel_size, stride=1, padding=kernel_size // 2) + F.max_pool2d(-depth, kernel_size, stride=1, padding=kernel_size // 2)) + else: + diff = (F.max_pool2d(torch.where(mask, depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2) + F.max_pool2d(torch.where(mask, -depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2)) + + edge = torch.zeros_like(depth, dtype=torch.bool) + if atol is not None: + edge |= diff > atol + if rtol is not None: + edge |= (diff / depth).nan_to_num_() > rtol + edge = edge.reshape(*shape) + return edge + + +def depth_aliasing(depth: torch.Tensor, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: torch.Tensor = None) -> torch.BoolTensor: + """ + Compute the map that indicates the aliasing of a depth map. The aliasing is defined as the pixels which neither close to the maximum nor the minimum of its neighbors. + Args: + depth (torch.Tensor): shape (..., height, width), linear depth map + atol (float): absolute tolerance + rtol (float): relative tolerance + + Returns: + edge (torch.Tensor): shape (..., height, width) of dtype torch.bool + """ + shape = depth.shape + depth = depth.reshape(-1, 1, *shape[-2:]) + if mask is not None: + mask = mask.reshape(-1, 1, *shape[-2:]) + + if mask is None: + diff_max = F.max_pool2d(depth, kernel_size, stride=1, padding=kernel_size // 2) - depth + diff_min = F.max_pool2d(-depth, kernel_size, stride=1, padding=kernel_size // 2) + depth + else: + diff_max = F.max_pool2d(torch.where(mask, depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2) - depth + diff_min = F.max_pool2d(torch.where(mask, -depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2) + depth + diff = torch.minimum(diff_max, diff_min) + + edge = torch.zeros_like(depth, dtype=torch.bool) + if atol is not None: + edge |= diff > atol + if rtol is not None: + edge |= (diff / depth).nan_to_num_() > rtol + edge = edge.reshape(*shape) + return edge + + +def image_mesh_from_depth( + depth: torch.Tensor, + extrinsics: torch.Tensor = None, + intrinsics: torch.Tensor = None +) -> Tuple[torch.Tensor, torch.Tensor]: + height, width = depth.shape + uv, faces = image_mesh(height, width) + faces = faces.reshape(-1, 4) + depth = depth.reshape(-1) + pts = transforms.unproject_cv(image_uv, depth, extrinsics, intrinsics) + faces = mesh.triangulate(faces, vertices=pts) + return pts, faces + + +@batched(3, 2, 2) +def point_to_normal(point: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor: + """ + Calculate normal map from point map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system. + + Args: + point (torch.Tensor): shape (..., height, width, 3), point map + Returns: + normal (torch.Tensor): shape (..., height, width, 3), normal map. 
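+        When `mask` is given, a tuple `(normal, valid)` is returned instead, where `valid` is a
+        (..., height, width) boolean map marking pixels whose normal could be estimated from at
+        least one valid pair of neighboring points.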
+ """ + has_mask = mask is not None + + if mask is None: + mask = torch.ones_like(point[..., 0], dtype=torch.bool) + mask = F.pad(mask, (1, 1, 1, 1), mode='constant', value=0) + + pts = F.pad(point.permute(0, 3, 1, 2), (1, 1, 1, 1), mode='constant', value=1).permute(0, 2, 3, 1) + up = pts[:, :-2, 1:-1, :] - pts[:, 1:-1, 1:-1, :] + left = pts[:, 1:-1, :-2, :] - pts[:, 1:-1, 1:-1, :] + down = pts[:, 2:, 1:-1, :] - pts[:, 1:-1, 1:-1, :] + right = pts[:, 1:-1, 2:, :] - pts[:, 1:-1, 1:-1, :] + normal = torch.stack([ + torch.cross(up, left, dim=-1), + torch.cross(left, down, dim=-1), + torch.cross(down, right, dim=-1), + torch.cross(right, up, dim=-1), + ]) + normal = F.normalize(normal, dim=-1) + valid = torch.stack([ + mask[:, :-2, 1:-1] & mask[:, 1:-1, :-2], + mask[:, 1:-1, :-2] & mask[:, 2:, 1:-1], + mask[:, 2:, 1:-1] & mask[:, 1:-1, 2:], + mask[:, 1:-1, 2:] & mask[:, :-2, 1:-1], + ]) & mask[None, :, 1:-1, 1:-1] + normal = (normal * valid[..., None]).sum(dim=0) + normal = F.normalize(normal, dim=-1) + + if has_mask: + return normal, valid.any(dim=0) + else: + return normal + + +@batched(2, 2, 2) +def depth_to_normal(depth: torch.Tensor, intrinsics: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor: + """ + Calculate normal map from depth map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system. + + Args: + depth (torch.Tensor): shape (..., height, width), linear depth map + intrinsics (torch.Tensor): shape (..., 3, 3), intrinsics matrix + Returns: + normal (torch.Tensor): shape (..., 3, height, width), normal map. + """ + has_mask = mask is not None + + height, width = depth.shape[-2:] + if mask is None: + mask = torch.ones_like(depth, dtype=torch.bool) + mask = F.pad(mask, (1, 1, 1, 1), mode='constant', value=0) + + uv = image_uv(*depth.shape[-2:]).unsqueeze(0).to(depth) + pts = transforms.unproject_cv(uv.reshape(-1, 2), depth.flatten(-2), intrinsics=intrinsics, extrinsics=None).unflatten(-2, (height, width)) + + return point_to_normal(pts, mask) + + +def masked_min(input: torch.Tensor, mask: torch.BoolTensor, dim: int = None, keepdim: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Similar to torch.min, but with mask + """ + if dim is None: + return torch.where(mask, input, torch.tensor(torch.inf, dtype=input.dtype, device=input.device)).min() + else: + return torch.where(mask, input, torch.tensor(torch.inf, dtype=input.dtype, device=input.device)).min(dim=dim, keepdim=keepdim) + + +def masked_max(input: torch.Tensor, mask: torch.BoolTensor, dim: int = None, keepdim: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Similar to torch.max, but with mask + """ + if dim is None: + return torch.where(mask, input, torch.tensor(-torch.inf, dtype=input.dtype, device=input.device)).max() + else: + return torch.where(mask, input, torch.tensor(-torch.inf, dtype=input.dtype, device=input.device)).max(dim=dim, keepdim=keepdim) + + +def bounding_rect(mask: torch.BoolTensor): + """get bounding rectangle of a mask + + Args: + mask (torch.Tensor): shape (..., height, width), mask + + Returns: + rect (torch.Tensor): shape (..., 4), bounding rectangle (left, top, right, bottom) + """ + height, width = mask.shape[-2:] + mask = mask.flatten(-2).unsqueeze(-1) + uv = image_uv(height, width).to(mask.device).reshape(-1, 2) + left_top = masked_min(uv, mask, dim=-2)[0] + right_bottom = masked_max(uv, mask, dim=-2)[0] + return torch.cat([left_top, right_bottom], dim=-1) + + +def chessboard(width: int, height: int, 
grid_size: int, color_a: torch.Tensor, color_b: torch.Tensor) -> torch.Tensor: + """get a chessboard image + + Args: + width (int): image width + height (int): image height + grid_size (int): size of chessboard grid + color_a (torch.Tensor): shape (chanenls,), color of the grid at the top-left corner + color_b (torch.Tensor): shape (chanenls,), color in complementary grids + + Returns: + image (torch.Tensor): shape (height, width, channels), chessboard image + """ + x = torch.div(torch.arange(width), grid_size, rounding_mode='floor') + y = torch.div(torch.arange(height), grid_size, rounding_mode='floor') + mask = ((x[None, :] + y[:, None]) % 2).to(color_a) + image = (1 - mask[..., None]) * color_a + mask[..., None] * color_b + return image \ No newline at end of file diff --git a/tools.py b/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..a3a6977dc826ef3b658a7a7af207afadb6ca597f --- /dev/null +++ b/tools.py @@ -0,0 +1,329 @@ +import argparse +import os +import torch +import numpy as np +import trimesh +from scipy.spatial.transform import Rotation +import torch.backends.cudnn as cudnn +import torch.nn.functional as F +from PIL import Image +from src.utils.vis import ( + prob_to_mask, + colorize, + denormalize, +) +import numpy as np +from src.lari.model import LaRIModel, DinoSegModel +from rembg import remove +from plyfile import PlyData, PlyElement +import torchvision.transforms as transforms + + + +LAYER_COLOR = [ + [255, 190, 11], # FFFF0B + [251, 86, 7], # FB5607 + [241, 91, 181], # F15BB5 + [131, 56, 236], # 8338EC + [58, 134, 255], # 3A86FF +] + + +OPENGL = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) + + + + +def save_point_cloud(pcd, rgb, filename, binary=True): + """Save an RGB point cloud as a PLY file. 
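+    Depending on `binary`, the cloud is written either as a binary PLY (via plyfile) or as a
+    plain-text ASCII PLY; if `rgb` is None, all points are written in mid-gray (128, 128, 128).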
+ Args: + pcd: Nx3 array of XYZ coordinates. + rgb: Nx3 array of per-point RGB colors (uint8); if None, all points are colored gray. + """ + + if rgb is None: + gray_concat = np.tile(np.array([128], dtype=np.uint8), + (pcd.shape[0], 3)) + points_3d = np.hstack((pcd, gray_concat)) + else: + assert pcd.shape[0] == rgb.shape[0] + points_3d = np.hstack((pcd, rgb)) + python_types = (float, float, float, int, int, int) + npy_types = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), + ('green', 'u1'), ('blue', 'u1')] + if binary: + # Format into a NumPy structured array + vertices = [] + for row_idx in range(points_3d.shape[0]): + cur_point = points_3d[row_idx] + vertices.append( + tuple( + dtype(point) + for dtype, point in zip(python_types, cur_point))) + vertices_array = np.array(vertices, dtype=npy_types) + el = PlyElement.describe(vertices_array, 'vertex') + + # Write the binary PLY + PlyData([el]).write(filename) + else: + x = np.squeeze(points_3d[:, 0]) + y = np.squeeze(points_3d[:, 1]) + z = np.squeeze(points_3d[:, 2]) + r = np.squeeze(points_3d[:, 3]) + g = np.squeeze(points_3d[:, 4]) + b = np.squeeze(points_3d[:, 5]) + + ply_head = 'ply\n' \ 'format ascii 1.0\n' \ 'element vertex %d\n' \ 'property float x\n' \ 'property float y\n' \ 'property float z\n' \ 'property uchar red\n' \ 'property uchar green\n' \ 'property uchar blue\n' \ 'end_header' % r.shape[0] + # ---- Save ply data to disk + np.savetxt(filename, np.column_stack([x, y, z, r, g, b]), fmt='%f %f %f %d %d %d', header=ply_head, comments='') + + +def load_model(model_info, ckpt_path, device): + model = eval(model_info) + model.to(device) + model.eval() + + # Load pretrained weights + ckpt = torch.load(ckpt_path, map_location=device, weights_only=False) + if "model" in ckpt: + model.load_state_dict(ckpt["model"], strict=False) + else: + model.load_state_dict(ckpt, strict=False) + return model + + +def process_image_custom(pil_image, resolution=512): + """ + Read an image, resize the long side to `resolution` and pad the short side with black, + so that the final image is (resolution x resolution). + + Returns: + padded_img (PIL.Image): The processed image. + crop_coords (tuple): (top, left, bottom, right) coordinates of the valid region. + original_size (tuple): (width, height) of the original image. + ori_img (PIL.Image): The original (unprocessed) RGB image. + """ + pil_image = pil_image.convert("RGB") + original_width, original_height = pil_image.size + + # If already at the fixed resolution, no processing is needed. + if original_width == resolution and original_height == resolution: + crop_coords = (0, 0, resolution, resolution) + return pil_image, crop_coords, (original_width, original_height), pil_image + + # Compute the scaling factor based on the long side. + if original_width >= original_height: + # Width is the long side. + scale = resolution / float(original_width) + new_width = resolution + new_height = int(round(original_height * scale)) + resized_img = pil_image.resize((new_width, new_height), Image.BILINEAR) + # Compute vertical padding. + pad_top = (resolution - new_height) // 2 + pad_bottom = resolution - new_height - pad_top + pad_left, pad_right = 0, 0 + else: + # Height is the long side. + scale = resolution / float(original_height) + new_height = resolution + new_width = int(round(original_width * scale)) + resized_img = pil_image.resize((new_width, new_height), Image.BILINEAR) + # Compute horizontal padding.
+ pad_left = (resolution - new_width) // 2 + pad_right = resolution - new_width - pad_left + pad_top, pad_bottom = 0, 0 + + # Create a new image filled with black + padded_img = Image.new("RGB", (resolution, resolution), (0, 0, 0)) + padded_img.paste(resized_img, (pad_left, pad_top)) + + # The valid region (crop) is where the resized image was pasted. + crop_coords = (pad_top, pad_left, pad_top + new_height, pad_left + new_width) + return padded_img, crop_coords, (original_width, original_height), pil_image + + +def process_image(pil_image, resolution=512): + """ + Process the image: apply the custom resize/pad, then convert to a normalized tensor. + + Returns: + img_tensor (torch.Tensor): Tensor of shape (1, 3, resolution, resolution). + ori_img_tensor (torch.Tensor): Normalized tensor of the original (unpadded) image. + crop_coords (tuple): (top, left, bottom, right) coordinates of the valid region. + original_size (tuple): (width, height) of the original image. + """ + padded_img, crop_coords, original_size, ori_img = process_image_custom( + pil_image, resolution + ) + transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + + img_tensor = transform(padded_img).unsqueeze(0) + ori_img_tensor = transform(ori_img).unsqueeze(0) + return img_tensor, ori_img_tensor, crop_coords, original_size + + +def post_process_output(input_tensor, crop_coords, original_size): + """ + Crop the input tensor using crop_coords and then resize it to the original image size. + + Args: + input_tensor (torch.Tensor): Input with shape (H, W, L, C) where C is 1 or 3. + crop_coords (tuple): (top, left, bottom, right) coordinates for cropping. + original_size (tuple): (width, height) of the original image. + + Returns: + processed_output (torch.Tensor): Output with shape (original_height, original_width, L, C). + """ + top, left, bottom, right = crop_coords + # Crop the input spatially: resulting shape (crop_h, crop_w, L, C) + cropped = input_tensor[top:bottom, left:right, ...] + crop_h, crop_w, L, C = cropped.shape + + # New shape becomes (1, L * C, crop_h, crop_w) + reshaped = cropped.permute(2, 3, 0, 1).reshape(1, L * C, crop_h, crop_w) + + # Unpack the original size (width, height) and resize back to it.
+ new_width, new_height = original_size + mode = "nearest" if L == 1 else "bilinear" + + resized = F.interpolate( + reshaped, size=(new_height, new_width), mode=mode, + align_corners=False if mode == "bilinear" else None, + ) + resized = resized.reshape(L, C, new_height, new_width) + + # Permute to the output shape: (new_height, new_width, L, C) + processed_output = resized.permute(2, 3, 0, 1) + + return processed_output + + +def get_masked_depth(lari_map, valid_mask, layer_id): + """Colorize the z-channel of one LaRI layer as a depth image.""" + + layer_id = max(0, layer_id) + + lari_depth = lari_map[:, :, layer_id, 2].cpu().numpy() # H W + valid_mask = valid_mask[:, :, layer_id, 0].cpu().numpy().astype(bool) # H W + valid_values = lari_depth[valid_mask] + + # Handle the case where no pixel is valid + if valid_values.size == 0: + vis_depth_range = [0, 1] + else: + vis_depth_range = [valid_values.min(), valid_values.max()] + + depth_image = Image.fromarray( + colorize( + lari_depth, + vis_depth_range[0], + vis_depth_range[1], + invalid_mask=~valid_mask, + cmap="Spectral", + ) + ).convert("RGB") + return depth_image + + +def save_to_glb(pts3d, color3d, path): + """Export a colored point cloud as res.glb under `path`.""" + scene = trimesh.Scene() + pct = trimesh.PointCloud(pts3d, colors=color3d) + scene.add_geometry(pct) + rot_y = np.eye(4) + rot_y[:3, :3] = Rotation.from_euler("y", np.deg2rad(180)).as_matrix() + scene.apply_transform(np.linalg.inv(OPENGL @ rot_y)) + outfile = os.path.join(path, "res.glb") + scene.export(file_obj=outfile) + return outfile + + +def get_point_cloud(pred, img, mask, first_layer_color="image", target_folder=None): + """ + Export the layered point map as a colored point cloud. + + Args: + pred: (H, W, L, 3) layered point map. + img: (1, 3, H, W) normalized image tensor providing colors for the first layer. + mask: (H, W, L, 1) or (H, W, L) validity mask for each layer. + first_layer_color: "image" to color the first layer with image colors, "pseudo" to use the layer palette. + target_folder: optional output folder; defaults to this file's directory. + + Returns: + Paths to the exported .glb and .ply files. + """ + + ori_shape = pred.shape + pred = pred.cpu().numpy() + pred = pred.reshape(-1, 3) # M 3 + + color_palette = LAYER_COLOR[: min(len(LAYER_COLOR), ori_shape[-2])] + assert first_layer_color in ["image", "pseudo"] + + # Assign a color to every 3D point: one palette color per layer + img = torch.clip(denormalize(img).squeeze(0), 0.0, 1.0) + img = img.permute(1, 2, 0).unsqueeze(2).cpu().numpy() # H W 1 3 + img = (img * 255.0).astype(np.uint8) + + layered_color = np.array([[color_palette]]).astype(np.uint8) # 1 1 n_layer 3 + layered_color = np.broadcast_to( + layered_color, (img.shape[0], img.shape[1], ori_shape[2], 3) + ).copy() # H W n_layer 3; copy() because np.broadcast_to returns a read-only view + + if first_layer_color == "image": + layered_color[:, :, :1, :] = img + layered_color = layered_color.reshape(-1, 3) + + valid_mask_arr = mask.squeeze().reshape(-1).cpu().numpy() # [H,W,layers] -> [M] + pred = pred[valid_mask_arr.astype(bool)] + layered_color = layered_color[valid_mask_arr.astype(bool)] # V,3 + + save_folder = target_folder if target_folder is not None else os.path.dirname(__file__) + + ply_path = os.path.join(save_folder, "res.ply") + save_point_cloud(pred, layered_color, filename=ply_path) + + glb_path = save_to_glb(pred, layered_color, save_folder) + + return glb_path, ply_path + + +def removebg_crop(pil_input): + + pil_input = remove(pil_input.convert("RGB")) + + pil_np = np.array(pil_input) + alpha = pil_np[:, :, 3] + # Crop only if the opaque region covers less than 10% of the image + is_crop = np.sum(alpha > 0.8 * 255) <= 0.1 * (alpha.shape[0] * alpha.shape[1]) + + # Adjust the object size to fit the image resolution + if is_crop: + width, height = pil_input.size + # Recompute the alpha bounding box + output_np = np.array(pil_input) + alpha = output_np[:, :, 3] + bbox = np.argwhere(alpha > 0.8 * 255) + bbox = ( + np.min(bbox[:, 1]), + np.min(bbox[:, 0]), + np.max(bbox[:, 1]), + np.max(bbox[:, 0]), + ) + center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 + size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) + size = int(size * 1.5) + bbox = ( + int(max(center[0] - size // 2, 0)), + int(max(center[1] - size // 2, 0)), + int(min(center[0] + size // 2, width)), + int(min(center[1] + size // 2, height)), + ) + pil_input = pil_input.crop(bbox) + + return pil_input \ No newline at end of file
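For reference, a minimal usage sketch of the padding/unpadding helpers in tools.py above. This is not part of the diff; it assumes tools.py is importable alongside PyTorch, torchvision and Pillow, uses a placeholder image path ("example.jpg"), and substitutes a random tensor for the network prediction:

import torch
from PIL import Image

from tools import process_image, post_process_output

# Load any RGB image; "example.jpg" is a placeholder path.
pil_img = Image.open("example.jpg")

# Resize the long side to 512, pad to 512x512, and normalize.
img_tensor, ori_img_tensor, crop_coords, original_size = process_image(pil_img, resolution=512)
print(img_tensor.shape)  # torch.Size([1, 3, 512, 512])

# Stand-in for a network output with L=5 layers and C=3 channels,
# laid out as (H, W, L, C) as post_process_output expects.
dummy_pred = torch.rand(512, 512, 5, 3)

# Undo the padding and resize back to the original image resolution.
restored = post_process_output(dummy_pred, crop_coords, original_size)
print(restored.shape)  # (original_height, original_width, 5, 3)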