import logging |
|
import numpy as np |
|
import torch |
|
from typing import Dict, Union |
|
import math |
|
|
|
from diffusers import ( |
|
AutoencoderKL, |
|
DDIMScheduler, |
|
DiffusionPipeline, |
|
LCMScheduler, |
|
UNet2DConditionModel, |
|
AutoencoderTiny, |
|
) |
|
from diffusers.utils import BaseOutput |
|
from PIL import Image |
|
from torch.utils.data import DataLoader, TensorDataset |
|
from torchvision.transforms.functional import resize, pil_to_tensor |
|
from torchvision.transforms import InterpolationMode |
|
from tqdm.auto import tqdm |
|
from transformers import CLIPTextModel, CLIPTokenizer |
|
from functools import partial |
|
from typing import Optional, Tuple |
|
|
|
|
|
class MarigoldDepthOutput(BaseOutput): |
|
""" |
|
Output class for Marigold Monocular Depth Estimation pipeline. |
|
|
|
Args: |
|
depth_np (`np.ndarray`): |
|
Predicted depth map, with depth values in the range of [0, 1]. |
|
base_depth_np (`np.ndarray`): |
|
Upsampled base depth map, with depth values in the range of [0, 1]. |
|
This is the depth map used as a global guidance for the boosted inference. |
|
It is upsampled to the same resolution as the final depth map. |
|
This is useful for visualization and debugging purposes. |
|
""" |
|
|
|
depth_np: np.ndarray |
|
base_depth_np: np.ndarray |
|
|
|
|
|
class MarigoldDepthHRPipeline(DiffusionPipeline): |
|
""" |
|
Pipeline for high resolution monocular depth estimation using Marigold: https://marigoldcomputervision.github.io. |
|
|
|
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the |
|
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) |
|
|
|
Args: |
|
unet (`UNet2DConditionModel`): |
|
Conditional U-Net to denoise the prediction latent, conditioned on image latent. |
|
vae (`AutoencoderKL`): |
|
Variational Auto-Encoder (VAE) Model to encode and decode images and predictions |
|
to and from latent representations. |
|
scheduler (`DDIMScheduler`): |
|
A scheduler to be used in combination with `unet` to denoise the encoded image latents. |
|
text_encoder (`CLIPTextModel`): |
|
Text-encoder, for empty text embedding. |
|
tokenizer (`CLIPTokenizer`): |
|
CLIP tokenizer. |
|
boosting_unet (`UNet2DConditionModel`): |
|
Conditional U-Net to denoise the depth latent, conditioned on image latent and a global depth map. |
|
scale_invariant (`bool`, *optional*): |
|
A model property specifying whether the predicted depth maps are scale-invariant. This value must be set in |
|
the model config. When used together with the `shift_invariant=True` flag, the model is also called |
|
"affine-invariant". NB: overriding this value is not supported. |
|
shift_invariant (`bool`, *optional*): |
|
A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in |
|
the model config. When used together with the `scale_invariant=True` flag, the model is also called |
|
"affine-invariant". NB: overriding this value is not supported. |
|
default_denoising_steps (`int`, *optional*): |
|
The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable |
|
quality with the given model. This value must be set in the model config. When the pipeline is called |
|
without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure |
|
reasonable results with various model flavors compatible with the pipeline, such as those relying on very |
|
short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`). |
|
default_boosting_denoising_steps (`int`, *optional*): |
|
Same as `default_denoising_steps` but for `boosting_unet`. |
|
default_processing_resolution (`int`, *optional*): |
|
The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in |
|
the model config. When the pipeline is called without explicitly setting `processing_resolution`, the |
|
default value is used. This is required to ensure reasonable results with various model flavors trained |
|
with varying optimal processing resolution values.
base_depth_model_uri (`str`, *optional*):
    Hugging Face Hub id or local path of a base (low-resolution) depth pipeline. It is loaded with
    `DiffusionPipeline.from_pretrained` and used to produce the global guidance when no `base_depth`
    is passed at inference time.
variant (`str`, *optional*):
    Checkpoint variant (e.g. `"fp16"`) forwarded when loading the base depth pipeline.
|
""" |
|
|
|
latent_scale_factor = 0.18215 |
|
|
|
def __init__( |
|
self, |
|
unet: UNet2DConditionModel, |
|
vae: AutoencoderKL, |
|
scheduler: Union[DDIMScheduler, LCMScheduler], |
|
text_encoder: CLIPTextModel, |
|
tokenizer: CLIPTokenizer, |
|
boosting_unet: Optional[UNet2DConditionModel], |
|
scale_invariant: Optional[bool] = True, |
|
shift_invariant: Optional[bool] = True, |
|
default_denoising_steps: Optional[int] = None, |
|
default_boosting_denoising_steps: Optional[int] = None, |
|
default_processing_resolution: Optional[int] = None, |
|
base_depth_model_uri: Optional[str] = None, |
|
variant: Optional[str] = None, |
|
): |
|
super().__init__() |
|
|
|
if boosting_unet is None: |
|
logging.warning( |
|
"Boosting U-Net is not provided. If this message appears during training, it is expected." |
|
) |
|
|
|
self.register_modules( |
|
unet=unet, |
|
vae=vae, |
|
scheduler=scheduler, |
|
text_encoder=text_encoder, |
|
tokenizer=tokenizer, |
|
boosting_unet=boosting_unet, |
|
) |
|
self.register_to_config( |
|
scale_invariant=scale_invariant, |
|
shift_invariant=shift_invariant, |
|
default_denoising_steps=default_denoising_steps, |
|
default_boosting_denoising_steps=default_boosting_denoising_steps,
|
default_processing_resolution=default_processing_resolution, |
|
) |
|
|
|
self.register_to_config(base_depth_model_uri=base_depth_model_uri) |
|
|
|
self.scale_invariant = scale_invariant |
|
self.shift_invariant = shift_invariant |
|
self.default_denoising_steps = default_denoising_steps |
|
self.default_boosting_denoising_steps = default_boosting_denoising_steps |
|
self.default_processing_resolution = default_processing_resolution |
|
|
|
if base_depth_model_uri is not None: |
|
|
|
self.base_pipe = DiffusionPipeline.from_pretrained( |
|
base_depth_model_uri, |
|
variant=variant, |
|
torch_dtype=self.dtype, |
|
trust_remote_code=True, |
|
) |
|
self.base_pipe.to(self.device) |
|
else: |
|
self.base_pipe = None |
|
|
|
self.empty_text_embed = None |
|
|
|
@torch.no_grad() |
|
def __call__( |
|
self, |
|
input_image: Union[Image.Image, torch.Tensor], |
|
*, |
|
base_depth: Optional[Union[Image.Image, np.ndarray, torch.Tensor, MarigoldDepthOutput]] = None,
|
denoising_steps: Optional[int] = None, |
|
boosted_denoising_steps: Optional[int] = None, |
|
ensemble_size: int = 10, |
|
boosted_ensemble_size: int = 5, |
|
processing_res: Optional[int] = None, |
|
match_input_res: bool = True, |
|
resample_method: str = "bilinear", |
|
batch_size: int = 0, |
|
show_progress_bar: bool = True, |
|
ensemble_kwargs: Optional[Dict] = None,
|
upscale_factor: int = 2, |
|
) -> MarigoldDepthOutput: |
|
""" |
|
Function invoked when calling the pipeline. |
|
|
|
Args: |
|
input_image (`Image` or `torch.Tensor`): |
|
Input RGB (or gray-scale) image. |
|
base_depth (`Image`, `np.ndarray`, `torch.Tensor` or `MarigoldDepthOutput`, *optional*): |
|
Base depth map to be used as a global guidance for the boosted inference. |
|
denoising_steps (`int`, *optional*, defaults to `None`):
    Number of diffusion denoising steps (DDIM) during inference. If `None`, `default_denoising_steps`
    from the model config is used.
boosted_denoising_steps (`int`, *optional*, defaults to `None`):
    Number of denoising steps for the boosting U-Net. If `None`, `default_boosting_denoising_steps`
    from the model config is used.
|
ensemble_size (`int`, *optional*, defaults to `10`): |
|
Number of predictions to be ensembled. |
|
boosted_ensemble_size (`int`, *optional*, defaults to `5`): |
|
Number of predictions to be ensembled in the boosted inference. |
|
processing_res (`int`, *optional*, defaults to `None`):
    Maximum resolution of processing. If set to 0, a fallback of 768 is used. If `None`, the default
    processing resolution from the model config is used.
|
match_input_res (`bool`, *optional*, defaults to `True`): |
|
Resize depth prediction to match input resolution. |
|
Only valid if `processing_res` > 0. |
|
resample_method: (`str`, *optional*, defaults to `bilinear`): |
|
Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`. |
|
batch_size (`int`, *optional*, defaults to `0`): |
|
Inference batch size, no bigger than `ensemble_size`. If set to 0, a suitable batch size is chosen automatically.
|
show_progress_bar (`bool`, *optional*, defaults to `True`): |
|
Display a progress bar of diffusion denoising. |
|
upscale_factor (`int`, *optional*, defaults to `2`):
    Upscaling factor of the boosted prediction relative to the processing resolution. Must be a power of 2;
    a value of 1 returns the base prediction without boosting.
|
ensemble_kwargs (`dict`, *optional*, defaults to `None`): |
|
Arguments for detailed ensembling settings. |
|
Returns: |
|
`MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including: |
|
- **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1] |
|
- **base_depth_np** (`np.ndarray`) Upsampled base depth map, with depth values in the range of [0, 1]. |
|
This is the depth map used as a global guidance for the boosted inference. |
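
Example (illustrative sketch; the checkpoint path below is a placeholder, not necessarily a published model):

    import torch
    from PIL import Image

    pipe = MarigoldDepthHRPipeline.from_pretrained(
        "path/or/hub-id-of-a-marigold-hr-checkpoint", torch_dtype=torch.float16
    ).to("cuda")
    image = Image.open("input.jpg")
    out = pipe(image, ensemble_size=5, upscale_factor=2)
    depth = out.depth_np  # np.ndarray in [0, 1]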
|
""" |
|
|
|
|
|
if denoising_steps is None: |
|
denoising_steps = self.default_denoising_steps |
|
if boosted_denoising_steps is None: |
|
boosted_denoising_steps = self.default_boosting_denoising_steps |
|
if processing_res is None: |
|
processing_res = self.default_processing_resolution |
|
|
|
|
|
assert processing_res >= 0, "Processing resolution must be non-negative." |
|
assert ensemble_size >= 1, "Ensemble size must be at least 1." |
|
assert boosted_ensemble_size >= 1, "Boosted ensemble size must be at least 1." |
|
assert math.log2( |
|
upscale_factor |
|
).is_integer(), "Upscale factor must be a power of 2." |
|
assert upscale_factor >= 1, "Upscale factor must be at least 1."
|
assert batch_size >= 0, "Batch size must be non-negative." |
|
|
|
|
|
if upscale_factor >= 8 and self.dtype == torch.float16: |
|
logging.warning( |
|
"Warning: Upscaling factors of 8 (and more) with half precision or more may lead to artifacts in the final prediction." |
|
) |
|
if upscale_factor >= 4 and isinstance(self.vae, AutoencoderTiny): |
|
logging.warning( |
|
"Warning: Upscaling factors of 4 (and more) with the Tiny VAE may lead to instabilities." |
|
) |
|
|
|
|
|
if isinstance(input_image, Image.Image):
    input_width, input_height = input_image.size
else:
    # Tensor inputs are [..., H, W]
    input_height, input_width = input_image.shape[-2:]
|
|
|
|
|
if base_depth is not None: |
|
|
|
if isinstance(base_depth, Image.Image): |
|
lowres = np.asarray(base_depth.convert("L"), dtype=np.float32) |
|
elif isinstance(base_depth, torch.Tensor): |
|
lowres = base_depth.squeeze().cpu().numpy().astype(np.float32) |
|
elif isinstance(base_depth, np.ndarray): |
|
lowres = base_depth.squeeze().astype(np.float32) |
|
elif isinstance(base_depth, MarigoldDepthOutput): |
|
lowres = base_depth.depth_np.astype(np.float32) |
|
else: |
|
raise TypeError(f"Unsupported base_depth type: {type(base_depth)}") |
|
|
|
|
|
min_v, max_v = lowres.min(), lowres.max() |
|
eps = 1e-8 |
|
if max_v - min_v > eps: |
|
lowres = (lowres - min_v) / (max_v - min_v + eps) |
|
else: |
|
|
|
lowres = np.zeros_like(lowres, dtype=np.float32) |
|
else: |
|
assert self.base_pipe is not None |
|
if self.base_pipe.device != self.device: |
|
|
|
self.base_pipe.to(self.device) |
|
|
|
base_out = self.base_pipe( |
|
input_image, |
|
num_inference_steps=denoising_steps, |
|
ensemble_size=ensemble_size, |
|
processing_resolution=processing_res, |
|
match_input_resolution=False, |
|
batch_size=1, |
|
|
|
resample_method_input=resample_method, |
|
resample_method_output=resample_method, |
|
ensembling_kwargs=ensemble_kwargs, |
|
) |
|
lowres = base_out.prediction[0, :, :, 0]
|
base_out.depth_np = lowres |
|
|
|
|
|
|
|
t = torch.from_numpy(lowres[None, None])
|
up = resize(t, (input_height, input_width), |
|
interpolation=InterpolationMode.NEAREST_EXACT, |
|
antialias=True) |
|
base_depth_np_upsampled = up.squeeze().cpu().numpy() |
|
|
|
|
|
if upscale_factor == 1: |
|
|
|
return MarigoldDepthOutput( |
|
depth_np=lowres, |
|
base_depth_np=base_depth_np_upsampled |
|
) |
|
|
|
|
|
global_pred = torch.from_numpy(lowres).to(self.device) |
|
global_pred = (global_pred - global_pred.min()) / (global_pred.max() - global_pred.min()) |
|
|
|
|
|
current_pred = global_pred |
|
|
|
|
|
|
upscale_factors = [2**i for i in range(1, int(math.log2(upscale_factor)) + 1)] |
|
|
|
|
|
if processing_res == 0: |
|
processing_res = 768 |
|
|
|
df = min(
    processing_res / input_width, processing_res / input_height
)
patch_height = int(input_height * df)
patch_width = int(input_width * df)
|
|
|
|
|
for factor in upscale_factors: |
|
tw = patch_width * factor |
|
th = patch_height * factor |
|
if tw > input_width * 1.1 or th > input_height * 1.1: |
|
logging.warning( |
|
f"Warning: Attempting to upsample to {tw}×{th}, " |
|
f"which exceeds the original input of {input_width}×{input_height}. " |
|
"This technically works, but may lead to suboptimal results." |
|
) |
|
|
|
|
|
with tqdm( |
|
total=len(upscale_factors), |
|
desc=" Upscaling Progress", |
|
unit="step", |
|
leave=False, |
|
) as pbar: |
|
for current_factor in upscale_factors: |
|
|
|
|
|
pbar.set_description(f" Upscaling x{current_factor}") |
|
|
|
|
|
is_final_step = current_factor == upscale_factor |
|
|
|
|
|
boosted_output = self.boosted_inference( |
|
input_image=input_image, |
|
denoising_steps=boosted_denoising_steps, |
|
ensemble_size=boosted_ensemble_size,
|
processing_res=processing_res, |
|
match_input_res=match_input_res and is_final_step, |
|
batch_size=batch_size, |
|
resample_method=resample_method, |
|
show_progress_bar=show_progress_bar, |
|
ensemble_kwargs=ensemble_kwargs, |
|
global_pred=current_pred, |
|
upscale_factor=current_factor, |
|
) |
|
|
|
|
|
current_pred = torch.from_numpy(boosted_output.depth_np) |
|
|
|
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
|
|
|
pbar.update(1) |
|
|
|
|
|
out = boosted_output |
|
out.base_depth_np = base_depth_np_upsampled |
|
|
|
return out |
|
|
|
def boosted_inference( |
|
self, |
|
input_image: Union[Image.Image, torch.Tensor],
|
denoising_steps: int = 10, |
|
ensemble_size: int = 10, |
|
processing_res: int = 768, |
|
match_input_res: bool = True, |
|
batch_size: int = 0, |
|
resample_method: str = "bilinear", |
|
seed: Union[int, None] = None, |
|
show_progress_bar: bool = True, |
|
ensemble_kwargs: Optional[Dict] = None,

global_pred: Optional[torch.Tensor] = None,
|
upscale_factor: int = 2, |
|
) -> MarigoldDepthOutput: |
|
""" |
|
Function invoked when calling the pipeline with boosted inference. |
|
|
|
Args: |
|
input_image (`Image` or `torch.Tensor`):
|
Input RGB image. |
|
denoising_steps (`int`, *optional*, defaults to `10`): |
|
Number of diffusion denoising steps (DDIM) during inference. |
|
ensemble_size (`int`, *optional*, defaults to `10`): |
|
Number of predictions to be ensembled. |
|
processing_res (`int`, *optional*, defaults to `768`): |
|
Maximum resolution of processing. |
|
If set to 0: will not resize at all. |
|
match_input_res (`bool`, *optional*, defaults to `True`): |
|
Resize depth prediction to match input resolution. |
|
Only valid if `processing_res` > 0. |
|
resample_method: (`str`, *optional*, defaults to `bilinear`): |
|
Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`. |
|
batch_size (`int`, *optional*, defaults to `0`): |
|
Inference batch size, no bigger than `ensemble_size`. If set to 0, a suitable batch size is chosen automatically.
|
seed (`int`, *optional*, defaults to `None`): |
|
Random seed for the diffusion process. |
|
show_progress_bar (`bool`, *optional*, defaults to `True`): |
|
Display a progress bar of diffusion denoising. |
|
ensemble_kwargs (`dict`, *optional*, defaults to `None`): |
|
Arguments for detailed ensembling settings. |
|
global_pred (`torch.Tensor`): |
|
Global depth map to be used as guidance. |
|
upscale_factor (`int`, *optional*, defaults to `2`): |
|
Upscale factor of the global depth map. |
|
|
|
Returns: |
|
`MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including: |
|
- **depth_np** (`np.ndarray`) Predicted depth map with depth values in the range of [0, 1] |
|
|
""" |
|
device = self.device |
|
|
|
self._check_inference_step(denoising_steps) |
|
resample_method: InterpolationMode = get_tv_resample_method(resample_method) |
|
|
|
|
|
if isinstance(input_image, Image.Image): |
|
input_image = input_image.convert("RGB") |
|
|
|
input_image = pil_to_tensor(input_image) |
|
elif isinstance(input_image, torch.Tensor): |
|
input_image = input_image.squeeze() |
|
|
|
else: |
|
raise TypeError(f"Unknown input type: {type(input_image) = }") |
|
input_size = input_image.shape |
|
assert ( |
|
3 == input_image.dim() and 3 == input_size[0] |
|
), f"Wrong input shape {input_size}, expected [rgb, H, W]" |
|
|
|
if isinstance(global_pred, torch.Tensor): |
|
global_pred = global_pred.squeeze().unsqueeze(0).to(device) |
|
else: |
|
raise TypeError(f"Unknown global_pred type: {type(global_pred) = }") |
|
|
|
if processing_res == 0: |
|
|
|
processing_res = 768 |
|
|
|
df = min( |
|
processing_res / input_image.shape[2], processing_res / input_image.shape[1] |
|
) |
|
patch_height = int(input_image.shape[1] * df) |
|
patch_width = int(input_image.shape[2] * df) |
|
|
|
|
|
patch_height = round(patch_height / 16) * 16 |
|
patch_width = round(patch_width / 16) * 16 |
|
|
|
patch_size = patch_height, patch_width |
|
global_size = (patch_size[0] * upscale_factor, patch_size[1] * upscale_factor) |
|
|
|
if global_pred.shape[1:] != global_size: |
|
global_pred = resize( |
|
global_pred, |
|
global_size, |
|
interpolation=resample_method, |
|
antialias=True, |
|
) |
|
|
|
if input_image.shape[1:] != global_size:
|
input_image = resize( |
|
input_image, |
|
global_size, |
|
interpolation=resample_method, |
|
antialias=True, |
|
).squeeze() |
|
|
|
input_image = ( |
|
input_image.unsqueeze(0) / 255.0 * 2.0 - 1.0 |
|
) |
|
input_image = input_image.to(self.dtype).to(device) |
|
assert input_image.min() >= -1.0 and input_image.max() <= 1.0 |
|
global_pred = global_pred.to(self.dtype).to(device) |
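# With 50% patch overlap, the global canvas of size (patch * upscale_factor) is covered by a
# (2 * upscale_factor - 1) x (2 * upscale_factor - 1) grid of patches, so the effective number of
# U-Net forward passes per denoising step scales accordingly.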
|
|
|
if batch_size > 0: |
|
_bs = batch_size |
|
else: |
|
_bs = find_batch_size( |
|
ensemble_size=ensemble_size |
|
* (2 * upscale_factor - 1) |
|
* (2 * upscale_factor - 1), |
|
input_res=max(patch_size), |
|
dtype=self.dtype, |
|
) |
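# Slightly compress the guidance depth range: the two lines below map [0, 1] to [0.05, 0.95].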
|
|
|
|
|
global_pred = 0.9 * (global_pred * 2 - 1) |
|
global_pred = (global_pred + 1) / 2 |
|
|
|
depth_pred, pred_uncert = self.multidiffusion_inference( |
|
rgb_norm=input_image, |
|
global_pred=global_pred, |
|
num_inference_steps=denoising_steps, |
|
patch_size=patch_size, |
|
seed=seed, |
|
show_pbar=show_progress_bar, |
|
ensemble_size=ensemble_size, |
|
batch_size=_bs, |
|
ensemble_kwargs=ensemble_kwargs, |
|
) |
|
depth_pred = depth_pred.squeeze(0) |
|
|
|
|
|
min_d = torch.min(depth_pred) |
|
max_d = torch.max(depth_pred) |
|
depth_pred = (depth_pred - min_d) / (max_d - min_d) |
|
|
|
if depth_pred.shape[1:] != input_size[1:] and match_input_res: |
|
depth_pred = resize( |
|
depth_pred.unsqueeze(0), |
|
input_size[1:], |
|
interpolation=resample_method, |
|
antialias=True, |
|
).squeeze() |
|
depth_pred = depth_pred.squeeze().cpu().numpy() |
|
depth_pred = depth_pred.clip(0, 1) |
|
|
|
return MarigoldDepthOutput( |
|
depth_np=depth_pred, |
|
) |
|
|
|
def _check_inference_step(self, n_step: int) -> None: |
|
""" |
|
Check if denoising step is reasonable |
|
Args: |
|
n_step (`int`): denoising steps |
|
""" |
|
assert n_step >= 1 |
|
|
|
def encode_empty_text(self): |
|
""" |
|
Encode text embedding for empty prompt |
|
""" |
|
prompt = "" |
|
text_inputs = self.tokenizer( |
|
prompt, |
|
padding="do_not_pad", |
|
max_length=self.tokenizer.model_max_length, |
|
truncation=True, |
|
return_tensors="pt", |
|
) |
|
text_input_ids = text_inputs.input_ids.to(self.text_encoder.device) |
|
self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype) |
|
|
|
@torch.no_grad() |
|
def multidiffusion_inference( |
|
self, |
|
rgb_norm: torch.Tensor, |
|
global_pred: torch.Tensor, |
|
num_inference_steps: int, |
|
seed: Union[int, None], |
|
show_pbar: bool, |
|
patch_size=(512, 768), |
|
encoder_patch_size=None, |
|
ensemble_size=1, |
|
batch_size=1, |
|
ensemble_kwargs: Optional[Dict] = None,
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Perform a single patch-based (multidiffusion) depth prediction, optionally ensembling several noise
    initializations.

    Args:
        rgb_norm (`torch.Tensor`):
            Input RGB image, normalized to [-1, 1].
        global_pred (`torch.Tensor`):
            Global depth map used as guidance, with values in [0, 1].
        num_inference_steps (`int`):
            Number of diffusion denoising steps (DDIM) during inference.
        seed (`int` or `None`):
            Random seed for the diffusion process.
        show_pbar (`bool`):
            Display a progress bar of diffusion denoising.
        patch_size (`tuple`, *optional*, defaults to `(512, 768)`):
            Size of the processing patches in pixels.
        encoder_patch_size (`tuple`, *optional*, defaults to `None`):
            Patch size used for VAE encoding/decoding; defaults to `patch_size`.
        ensemble_size (`int`, *optional*, defaults to `1`):
            Number of predictions to be ensembled.
        batch_size (`int`, *optional*, defaults to `1`):
            Batch size for the patch-wise U-Net calls.
        ensemble_kwargs (`dict`, *optional*, defaults to `None`):
            Arguments for detailed ensembling settings.
    Returns:
        `torch.Tensor`: Predicted depth map.
        `torch.Tensor` or `None`: Uncertainty map, or `None` when it is not computed.
    """
|
device = self.device |
|
rgb_norm = rgb_norm.to(device) |
|
global_pred = global_pred.to(device) |
|
|
|
self.scheduler.set_timesteps(num_inference_steps, device=device) |
|
timesteps = self.scheduler.timesteps |
|
|
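# The Stable Diffusion VAE downsamples spatially by a factor of 8, hence the latent patch/canvas sizes.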
|
latent_patch_size = (patch_size[0] // 8, patch_size[1] // 8) |
|
latent_full_size = (rgb_norm.shape[-2] // 8, rgb_norm.shape[-1] // 8) |
|
|
|
|
|
global_pred = global_pred * 2 - 1 |
|
|
|
|
|
if encoder_patch_size is None: |
|
encoder_patch_size = patch_size |
|
rgb_latent = self.encode_rgb_patched( |
|
rgb_norm, patch_size=encoder_patch_size, show_pbar=show_pbar |
|
) |
|
global_pred_latent = self.encode_depth_patched( |
|
global_pred, patch_size=encoder_patch_size, show_pbar=show_pbar |
|
) |
|
|
|
|
|
rgb_norm = rgb_norm.cpu() |
|
global_pred = global_pred.cpu() |
|
|
|
|
|
if seed is None: |
|
rand_num_generator = None |
|
else: |
|
rand_num_generator = torch.Generator(device=device) |
|
rand_num_generator.manual_seed(seed) |
|
|
|
|
|
( |
|
rgb_latent_patched, |
|
num_patches_vert, |
|
num_patches_horz, |
|
step_height, |
|
step_width, |
|
) = self.extract_patches(rgb_latent[0], patch_size=latent_patch_size) |
|
|
|
|
|
global_pred_latent_patched, _, _, _, _ = self.extract_patches( |
|
global_pred_latent[0], patch_size=latent_patch_size |
|
) |
|
|
|
|
|
if self.empty_text_embed is None: |
|
self.encode_empty_text() |
|
batch_empty_text_embed = self.empty_text_embed.repeat((batch_size, 1, 1)).to( |
|
device |
|
) |
|
if hasattr(self, "pooled_empty_text_embeds"): |
|
batch_pooled_empty_text_embed = self.pooled_empty_text_embeds.repeat( |
|
(batch_size, 1, 1) |
|
).to(device) |
|
|
|
|
|
if ensemble_size > 1: |
|
rgb_latent_patched = rgb_latent_patched.repeat(ensemble_size, 1, 1, 1) |
|
global_pred_latent_patched = global_pred_latent_patched.repeat( |
|
ensemble_size, 1, 1, 1 |
|
) |
|
|
|
batch_empty_text_embed = batch_empty_text_embed.repeat(ensemble_size, 1, 1) |
|
if hasattr(self, "pooled_empty_text_embeds"): |
|
batch_pooled_empty_text_embed = batch_pooled_empty_text_embed.repeat( |
|
ensemble_size, 1, 1 |
|
) |
|
|
|
|
|
depth_latent = torch.randn( |
|
(ensemble_size, 4, latent_full_size[0], latent_full_size[1]), |
|
device=device, |
|
dtype=self.dtype, |
|
generator=rand_num_generator, |
|
) |
|
( |
|
depth_latent_patched, |
|
num_patches_vert, |
|
num_patches_horz, |
|
step_height, |
|
step_width, |
|
) = self.extract_patches(depth_latent, patch_size=latent_patch_size) |
|
|
|
|
|
if show_pbar: |
|
iterable = tqdm( |
|
enumerate(timesteps), |
|
total=len(timesteps), |
|
leave=False, |
|
desc=" " * 4 + "Diffusion denoising", |
|
) |
|
else: |
|
iterable = enumerate(timesteps) |
|
|
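# MultiDiffusion-style denoising: at every timestep, all overlapping patches are denoised independently,
# blended back onto the full latent canvas, and then re-extracted so that neighbouring patches stay
# consistent with each other.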
|
for _, t in iterable: |
|
|
|
assert ( |
|
self.boosting_unet.conv_in.in_channels == 12 |
|
), "The input channels of the boosting unet must be 12." |
|
unet_input = torch.cat( |
|
[rgb_latent_patched, global_pred_latent_patched, depth_latent_patched], |
|
dim=1, |
|
) |
|
|
|
|
|
dataset = TensorDataset(unet_input) |
|
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False) |
|
noise_preds = [] |
|
for batch in tqdm( |
|
loader, |
|
leave=False, |
|
disable=not show_pbar, |
|
desc=" " * 6 + "UNet patch inference", |
|
): |
|
(unet_input,) = batch |
|
noise_pred = self.boosting_unet( |
|
unet_input, |
|
t, |
|
encoder_hidden_states=batch_empty_text_embed[: unet_input.shape[0]], |
|
).sample |
|
noise_preds.append(noise_pred) |
|
noise_preds = torch.concat(noise_preds, dim=0) |
|
|
|
|
|
scheduler_out = self.scheduler.step( |
|
noise_preds, t, depth_latent_patched, generator=rand_num_generator |
|
) |
|
|
|
|
|
depth_latent = scheduler_out.prev_sample.reshape( |
|
ensemble_size, |
|
num_patches_vert, |
|
num_patches_horz, |
|
4, |
|
latent_patch_size[0], |
|
latent_patch_size[1], |
|
) |
|
|
|
|
|
depth_latent_full = self.blend_patches( |
|
depth_latent, |
|
canvas_size=global_pred_latent.shape[2:], |
|
num_patches_vert=num_patches_vert, |
|
num_patches_horz=num_patches_horz, |
|
step_height=step_height, |
|
step_width=step_width, |
|
) |
|
|
|
|
|
if t != timesteps[-1]: |
|
depth_latent_patched, _, _, _, _ = self.extract_patches( |
|
depth_latent_full, patch_size=latent_patch_size |
|
) |
|
|
|
|
|
depth = self.decode_depth_patched( |
|
depth_latent_full=depth_latent_full, |
|
canvas_size=global_pred.shape[1:], |
|
latent_decoder_patch_size=( |
|
encoder_patch_size[0] // 8, |
|
encoder_patch_size[1] // 8, |
|
), |
|
show_pbar=show_pbar, |
|
ensemble_size=ensemble_size, |
|
) |
|
|
|
if ensemble_size > 1: |
|
depth, pred_uncert = ensemble_depth( |
|
depth, |
|
scale_invariant=self.scale_invariant, |
|
shift_invariant=self.shift_invariant, |
|
**(ensemble_kwargs or {}), |
|
) |
|
else: |
|
depth = (depth + 1.0) / 2.0 |
|
pred_uncert = None |
|
|
|
return depth, pred_uncert |
|
|
|
def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor: |
|
""" |
|
Encode RGB image into latent. |
|
|
|
Args: |
|
rgb_in (`torch.Tensor`): |
|
Input RGB image to be encoded. |
|
|
|
Returns: |
|
`torch.Tensor`: Image latent. |
|
""" |
|
if isinstance(self.vae, AutoencoderTiny): |
|
rgb_latent = self.vae.encoder(rgb_in) |
|
else: |
|
h = self.vae.encoder(rgb_in) |
|
moments = self.vae.quant_conv(h) |
|
mean, _ = torch.chunk(moments, 2, dim=1) |
|
rgb_latent = mean * self.latent_scale_factor |
|
|
|
return rgb_latent |
|
|
|
def encode_rgb_patched( |
|
self, rgb_in: torch.Tensor, patch_size: tuple, show_pbar: bool = False |
|
) -> torch.Tensor: |
|
""" |
|
Encode depth map into latent. |
|
|
|
Args: |
|
depth_in (`torch.Tensor`): |
|
Input depth map to be encoded. |
|
patch_size (`Tuple[int, int]`): |
|
Size of the patch. |
|
|
|
Returns: |
|
`torch.Tensor`: Depth latent. |
|
""" |
|
device = self.device |
|
|
|
( |
|
rgb_patched, |
|
num_patches_vert_pix, |
|
num_patches_horz_pix, |
|
step_height_pix, |
|
step_width_pix, |
|
) = self.extract_patches(rgb_in.squeeze(0), patch_size=patch_size) |
|
rgb_in = rgb_in.cpu() |
|
rgb_patched = rgb_patched.cpu() |
|
|
|
|
|
rgb_latent_patched = [] |
|
for rgb in tqdm( |
|
rgb_patched, |
|
leave=False, |
|
disable=not show_pbar, |
|
desc=" " * 4 + "Encoding RGB", |
|
): |
|
patch = rgb.unsqueeze(0).to(device) |
|
rgb_latent_patched.append(self.encode_rgb(patch).cpu()) |
|
|
|
|
|
rgb_latent_patched = torch.concat(rgb_latent_patched, dim=0) |
|
rgb_latent_patched = rgb_latent_patched.reshape( |
|
num_patches_vert_pix, |
|
num_patches_horz_pix, |
|
4, |
|
rgb_latent_patched.shape[-2], |
|
rgb_latent_patched.shape[-1], |
|
).to(device) |
|
|
|
|
|
rgb_latent = self.blend_patches( |
|
rgb_latent_patched, |
|
canvas_size=(rgb_in.shape[-2] // 8, rgb_in.shape[-1] // 8), |
|
num_patches_vert=num_patches_vert_pix, |
|
num_patches_horz=num_patches_horz_pix, |
|
step_height=step_height_pix // 8, |
|
step_width=step_width_pix // 8, |
|
) |
|
|
|
if len(rgb_latent.shape) == 3: |
|
rgb_latent = rgb_latent.unsqueeze(0) |
|
|
|
return rgb_latent |
|
|
|
def encode_depth_patched( |
|
self, |
|
depth_in: torch.Tensor, |
|
patch_size, |
|
show_pbar: bool = False, |
|
) -> torch.Tensor: |
|
""" |
|
Encode depth map into latent, but in a patched and scalable way |
|
|
|
Args: |
|
depth_in (`torch.Tensor`): |
|
Input depth map to be encoded. |
|
patch_size (`Tuple[int, int]`): |
|
Size of the patch. |
|
show_pbar (`bool`): |
|
Display a progress bar. |
|
|
|
Returns: |
|
`torch.Tensor`: Depth latent. |
|
""" |
|
device = self.device |
|
ensemble_size = depth_in.shape[0] |
|
|
|
( |
|
depth_in_patched, |
|
num_patches_vert_pix, |
|
num_patches_horz_pix, |
|
step_height_pix, |
|
step_width_pix, |
|
) = self.extract_patches(depth_in, patch_size=patch_size) |
|
depth_in = depth_in.cpu() |
|
depth_in_patched = depth_in_patched.cpu() |
|
|
|
|
|
depth_latent_patched = [] |
|
for gpred in tqdm( |
|
depth_in_patched, |
|
leave=False, |
|
disable=not show_pbar, |
|
desc=" " * 4 + "Encoding context depth", |
|
): |
|
patch = gpred.unsqueeze(0).to(device) |
|
depth_latent_patched.append(self.encode_depth(patch).cpu()) |
|
|
|
|
|
depth_latent_patched = torch.concat(depth_latent_patched, dim=0) |
|
depth_latent_patched = depth_latent_patched.reshape( |
|
ensemble_size, |
|
num_patches_vert_pix, |
|
num_patches_horz_pix, |
|
4, |
|
depth_latent_patched.shape[-2], |
|
depth_latent_patched.shape[-1], |
|
).to(device) |
|
|
|
|
|
depth_latent = self.blend_patches( |
|
depth_latent_patched, |
|
canvas_size=(depth_in.shape[-2] // 8, depth_in.shape[-1] // 8), |
|
num_patches_vert=num_patches_vert_pix, |
|
num_patches_horz=num_patches_horz_pix, |
|
step_height=step_height_pix // 8, |
|
step_width=step_width_pix // 8, |
|
) |
|
|
|
if len(depth_latent.shape) == 3: |
|
depth_latent = depth_latent.unsqueeze(0) |
|
|
|
return depth_latent |
|
|
|
def decode_depth_patched( |
|
self, |
|
depth_latent_full: torch.Tensor, |
|
canvas_size: tuple, |
|
latent_decoder_patch_size: tuple, |
|
show_pbar: bool = True, |
|
ensemble_size: int = 1, |
|
) -> torch.Tensor: |
|
""" |
|
Decode depth map from latent in a patched and scalable way. |
|
|
|
Args: |
|
depth_latent_full (`torch.Tensor`): |
|
Depth latent to be decoded. |
|
canvas_size (`tuple`): |
|
Size of the canvas. |
|
latent_decoder_patch_size (`tuple`): |
|
Size of the patch. |
|
show_pbar (`bool`): |
|
Display a progress bar. |
|
ensemble_size (`int`): |
|
Ensemble size. |
|
|
|
Returns: |
|
`torch.Tensor`: Decoded depth map. |
|
""" |
|
|
|
encoder_patch_size = ( |
|
latent_decoder_patch_size[0] * 8, |
|
latent_decoder_patch_size[1] * 8, |
|
) |
|
|
|
|
|
( |
|
depth_latent_patched, |
|
num_patches_vert, |
|
num_patches_horz, |
|
step_height, |
|
step_width, |
|
) = self.extract_patches( |
|
depth_latent_full, patch_size=latent_decoder_patch_size, overlap=0.5 |
|
) |
|
|
|
|
|
depthp = [] |
|
for patch in tqdm( |
|
depth_latent_patched, |
|
leave=False, |
|
desc=" " * 6 + "Decoding Depth", |
|
disable=not show_pbar, |
|
): |
|
depthp.append(self.decode_depth(patch.unsqueeze(0))) |
|
depthp = torch.concat(depthp, dim=0) |
|
depthp = depthp.reshape( |
|
ensemble_size, |
|
num_patches_vert, |
|
num_patches_horz, |
|
1, |
|
encoder_patch_size[0], |
|
encoder_patch_size[1], |
|
) |
|
|
|
|
|
depth = self.blend_patches( |
|
depthp, |
|
canvas_size=canvas_size, |
|
num_patches_vert=num_patches_vert, |
|
num_patches_horz=num_patches_horz, |
|
step_height=step_height * 8, |
|
step_width=step_width * 8, |
|
) |
|
|
|
return depth |
|
|
|
def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor: |
|
""" |
|
Decode depth latent into depth map. |
|
|
|
Args: |
|
depth_latent (`torch.Tensor`): |
|
Depth latent to be decoded. |
|
|
|
Returns: |
|
`torch.Tensor`: Decoded depth map of the input depth latent. |
|
""" |
|
|
|
|
|
if isinstance(self.vae, AutoencoderTiny): |
|
stacked = self.vae.decoder(depth_latent) |
|
else: |
|
depth_latent = depth_latent / self.latent_scale_factor |
|
z = self.vae.post_quant_conv(depth_latent) |
|
stacked = self.vae.decoder(z) |
|
|
|
|
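# The decoder outputs 3 channels (depth was replicated to 3 channels before encoding);
# average them back into a single-channel depth map.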
|
depth_mean = stacked.mean(dim=1, keepdim=True) |
|
return depth_mean |
|
|
|
def encode_depth( |
|
self, |
|
depth_in: torch.Tensor, |
|
) -> torch.Tensor: |
|
""" |
|
Encode depth map into latent. |
|
|
|
Args: |
|
depth_in (`torch.Tensor`): |
|
Input depth map to be encoded. |
|
|
|
Returns: |
|
`torch.Tensor`: Depth latent of the input depth map. |
|
""" |
|
|
|
|
|
stacked = self.stack_depth_images(depth_in) |
|
|
|
|
|
depth_latent = self.encode_rgb(stacked) |
|
return depth_latent |
|
|
|
@staticmethod |
|
def stack_depth_images(depth_in):
    # Replicate the single-channel depth map to 3 channels so it can be encoded with the RGB VAE.
    if 4 == len(depth_in.shape):
        stacked = depth_in.repeat(1, 3, 1, 1)
    elif 3 == len(depth_in.shape):
        stacked = depth_in.unsqueeze(1)
        stacked = stacked.repeat(1, 3, 1, 1)
    return stacked
|
|
|
def extract_patches(self, canvas_input_image, patch_size, overlap=0.5): |
|
""" |
|
Extract patches from an image |
|
Args:
    canvas_input_image (`torch.Tensor`): Input canvas of shape [channels, height, width] or
        [ensemble, channels, height, width]
    patch_size (`tuple`): Size of the patch as (height, width)
    overlap (`float`): Fractional overlap between neighbouring patches
Returns:
    input_image_patched (`torch.Tensor`): Extracted patches
    num_patches_vert (`int`): Number of patches in the vertical direction
    num_patches_horz (`int`): Number of patches in the horizontal direction
    step_height (`int`): Step size in the vertical direction
    step_width (`int`): Step size in the horizontal direction
|
""" |
|
|
|
if len(canvas_input_image.shape) == 4: |
|
ensemble_size = canvas_input_image.shape[0] |
|
else: |
|
canvas_input_image = canvas_input_image.unsqueeze(0) |
|
ensemble_size = 1 |
|
|
|
|
|
h, w = patch_size |
|
step_height, step_width = int(h * (1 - overlap)), int(w * (1 - overlap)) |
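# The number of patches per side follows from sliding a window of size (h, w) with the computed step;
# with overlap=0.5 a canvas of size patch * f yields (2f - 1) patches per side.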
|
|
|
|
|
num_patches_vert = (canvas_input_image.shape[-2] - h) // step_height + 1 |
|
num_patches_horz = (canvas_input_image.shape[-1] - w) // step_width + 1 |
|
|
|
|
|
patches = [] |
|
|
|
for e in range(ensemble_size): |
|
for i in range(num_patches_vert): |
|
for j in range(num_patches_horz): |
|
|
|
start_y = i * step_height |
|
start_x = j * step_width |
|
|
|
|
|
patch = canvas_input_image[ |
|
e, :, start_y : start_y + h, start_x : start_x + w |
|
] |
|
patches.append(patch) |
|
|
|
|
|
input_image_patched = torch.stack(patches) |
|
|
|
return ( |
|
input_image_patched, |
|
num_patches_vert, |
|
num_patches_horz, |
|
step_height, |
|
step_width, |
|
) |
|
|
|
def blend_patches( |
|
self, |
|
depth_preds, |
|
num_patches_vert, |
|
num_patches_horz, |
|
step_height, |
|
step_width, |
|
global_depth_pred=None, |
|
canvas_size=None, |
|
alpha_center=1.0, |
|
alpha_edge=1e-4, |
|
noise_blend=False, |
|
eps=1e-8, |
|
): |
|
""" |
|
Blend patches of depth maps and apply the transformation to the global depth map |
|
|
|
Args:
    depth_preds (`torch.Tensor`): Local depth (or latent) patches of shape
        [ensemble, patches_vert, patches_horz, channels, h, w] (the ensemble dimension may be omitted)
    num_patches_vert (`int`): Number of patches in the vertical direction
    num_patches_horz (`int`): Number of patches in the horizontal direction
    step_height (`int`): Step size in the vertical direction
    step_width (`int`): Step size in the horizontal direction
    global_depth_pred (`torch.Tensor`, *optional*): Global depth map used to fill pixels that received no
        valid contribution
    canvas_size (`tuple`): Size (height, width) of the output canvas
    alpha_center (`float`): Weight at the center of the patch
    alpha_edge (`float`): Weight at the edge of the patch, should not be exactly 0 to avoid division by zero
    noise_blend (`bool`): If True, square the weights in the denominator and take its square root, which
        preserves variance when blending noise latents
    eps (`float`): Small constant added to the denominator to avoid division by zero
|
|
|
Returns: |
|
`torch.Tensor`: Blended depth map |
|
""" |
|
|
|
|
|
|
channels, h, w = ( |
|
depth_preds.shape[-3], |
|
depth_preds.shape[-2], |
|
depth_preds.shape[-1], |
|
) |
|
|
|
if len(depth_preds.shape) == 6:
    ensemble_size = depth_preds.shape[0]
|
elif len(depth_preds.shape) == 5: |
|
ensemble_size = 1 |
|
depth_preds = depth_preds.unsqueeze(0) |
|
else: |
|
raise ValueError("depth_preds should have 5 or 6 dimensions") |
|
|
|
|
|
blended_depth_map = torch.zeros( |
|
(ensemble_size, channels, canvas_size[0], canvas_size[1]), |
|
device=depth_preds.device, |
|
dtype=depth_preds.dtype, |
|
) |
|
denominator_map = torch.zeros( |
|
(1, canvas_size[0], canvas_size[1]), |
|
device=depth_preds.device, |
|
dtype=depth_preds.dtype, |
|
) |
|
|
|
for i in range(num_patches_vert): |
|
for j in range(num_patches_horz): |
|
|
|
start_y = i * step_height |
|
start_x = j * step_width |
|
|
|
|
|
local_patch = depth_preds[:, i, j] |
|
|
|
|
|
weights = self.get_linear_weight_map( |
|
h, |
|
w, |
|
device=local_patch.device, |
|
alpha_center=alpha_center, |
|
alpha_edge=alpha_edge, |
|
cosine_blending=True, |
|
margin=0.0, |
|
) |
|
|
|
|
|
blended_depth_map[ |
|
:, :, start_y : start_y + h, start_x : start_x + w |
|
] += (local_patch * weights) |
|
|
|
if noise_blend: |
|
denominator_weights = weights**2 |
|
else: |
|
denominator_weights = weights |
|
denominator_map[ |
|
:, start_y : start_y + h, start_x : start_x + w |
|
] += denominator_weights |
|
|
|
if noise_blend: |
|
denominator_map = torch.sqrt(denominator_map) |
|
|
|
blended_depth_map /= denominator_map + eps |
|
|
|
|
|
|
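# Any pixels that received no valid contribution (NaNs) are filled from the global prediction, if given.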
|
if global_depth_pred is not None: |
|
blended_depth_map[torch.isnan(blended_depth_map)] = ( |
|
global_depth_pred.repeat(ensemble_size, 1, 1, 1)[ |
|
torch.isnan(blended_depth_map) |
|
] |
|
) |
|
|
|
return blended_depth_map |
|
|
|
def get_linear_weight_map( |
|
self, |
|
h, |
|
w, |
|
device, |
|
alpha_center=1.0, |
|
alpha_edge=1e-4, |
|
margin=0.0, |
|
cosine_blending=False, |
|
): |
|
""" |
|
Generate a linear weight map for blending patches |
|
Args: |
|
h (`int`): Height of the weight map |
|
w (`int`): Width of the weight map |
|
device (`torch.device`): Device to use |
|
alpha_center (`float`): Weight at the center of the patch |
|
alpha_edge (`float`): Weight at the edge of the patch |
|
margin (`float`): Percentage of the patch dimensions; this margin at the edges is filled with near-zero weights.
|
Returns: |
|
`torch.Tensor`: Linear distance weight map (looks like a pyramid) |
|
""" |
|
|
|
x = torch.linspace(-1, 1, h, device=device) |
|
y = torch.linspace(-1, 1, w, device=device) |
|
xx, yy = torch.meshgrid(x, y, indexing="ij") |
|
dist = torch.stack([xx.abs(), yy.abs()]).max(dim=0).values |
|
norm_dist = dist / torch.max(dist) |
|
|
|
|
|
norm_dist = torch.clamp(norm_dist + margin, 0, 1) |
|
|
|
|
|
mindist = torch.min(norm_dist) |
|
maxdist = torch.max(norm_dist) |
|
norm_dist = (norm_dist - mindist) / (maxdist - mindist) |
|
|
|
|
|
if cosine_blending: |
|
weights = alpha_edge + (alpha_center - alpha_edge) * ( |
|
0.5 * (1 + torch.cos(norm_dist * math.pi)) |
|
) |
|
else: |
|
weights = alpha_edge + (alpha_center - alpha_edge) * (1 - norm_dist) |
|
|
|
return weights |
|
|
|
|
|
|
|
|
|
|
|
def get_tv_resample_method(method_str: str) -> InterpolationMode: |
|
resample_method_dict = { |
|
"bilinear": InterpolationMode.BILINEAR, |
|
"bicubic": InterpolationMode.BICUBIC, |
|
"nearest": InterpolationMode.NEAREST_EXACT, |
|
"nearest-exact": InterpolationMode.NEAREST_EXACT, |
|
} |
|
resample_method = resample_method_dict.get(method_str, None) |
|
if resample_method is None: |
|
raise ValueError(f"Unknown resampling method: {resample_method}") |
|
else: |
|
return resample_method |
|
|
|
|
|
def resize_max_res( |
|
img: torch.Tensor, |
|
max_edge_resolution: int, |
|
resample_method: InterpolationMode = InterpolationMode.BILINEAR, |
|
) -> torch.Tensor: |
|
""" |
|
Resize image to limit maximum edge length while keeping aspect ratio. |
|
|
|
Args: |
|
img (`torch.Tensor`): |
|
Image tensor to be resized. Expected shape: [B, C, H, W] |
|
max_edge_resolution (`int`): |
|
Maximum edge length (pixel). |
|
resample_method (`InterpolationMode`):
|
Resampling method used to resize images. |
|
|
|
Returns: |
|
`torch.Tensor`: Resized image. |
|
""" |
|
assert 4 == img.dim(), f"Invalid input shape {img.shape}" |
|
|
|
original_height, original_width = img.shape[-2:] |
|
downscale_factor = min( |
|
max_edge_resolution / original_width, max_edge_resolution / original_height |
|
) |
|
|
|
new_width = int(original_width * downscale_factor) |
|
new_height = int(original_height * downscale_factor) |
|
|
|
resized_img = resize(img, (new_height, new_width), resample_method, antialias=True) |
|
return resized_img |
|
|
|
|
|
def ensemble_depth( |
|
depth: torch.Tensor, |
|
scale_invariant: bool = True, |
|
shift_invariant: bool = True, |
|
output_uncertainty: bool = False, |
|
reduction: str = "median", |
|
regularizer_strength: float = 0.02, |
|
max_iter: int = 50, |
|
tol: float = 1e-6, |
|
max_res: int = 1024, |
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: |
|
""" |
|
Ensembles depth maps represented by the `depth` tensor with expected shape `(B, 1, H, W)`, where B is the |
|
number of ensemble members for a given prediction of size `(H x W)`. Even though the function is designed for |
|
depth maps, it can also be used with disparity maps as long as the input tensor values are non-negative. The |
|
alignment happens when the predictions have one or more degrees of freedom, that is when they are either |
|
affine-invariant (`scale_invariant=True` and `shift_invariant=True`), or just scale-invariant (only |
|
`scale_invariant=True`). For absolute predictions (`scale_invariant=False` and `shift_invariant=False`) |
|
alignment is skipped and only ensembling is performed. |
|
|
|
Args: |
|
depth (`torch.Tensor`): |
|
Input ensemble depth maps. |
|
scale_invariant (`bool`, *optional*, defaults to `True`): |
|
Whether to treat predictions as scale-invariant. |
|
shift_invariant (`bool`, *optional*, defaults to `True`): |
|
Whether to treat predictions as shift-invariant. |
|
output_uncertainty (`bool`, *optional*, defaults to `False`): |
|
Whether to output uncertainty map. |
|
reduction (`str`, *optional*, defaults to `"median"`): |
|
Reduction method used to ensemble aligned predictions. The accepted values are: `"mean"` and |
|
`"median"`. |
|
regularizer_strength (`float`, *optional*, defaults to `0.02`): |
|
Strength of the regularizer that pulls the aligned predictions to the unit range from 0 to 1. |
|
max_iter (`int`, *optional*, defaults to `50`):
    Maximum number of the alignment solver steps. Refer to the `scipy.optimize.minimize` function,
    `options` argument.
tol (`float`, *optional*, defaults to `1e-6`):
|
Alignment solver tolerance. The solver stops when the tolerance is reached. |
|
max_res (`int`, *optional*, defaults to `1024`):
    Resolution at which the alignment is performed; if `None`, the alignment is performed at the original
    resolution.
|
Returns: |
|
A tensor of aligned and ensembled depth maps and optionally a tensor of uncertainties of the same shape: |
|
`(1, 1, H, W)`. |
|
""" |
|
if depth.dim() != 4 or depth.shape[1] != 1: |
|
raise ValueError(f"Expecting 4D tensor of shape [B,1,H,W]; got {depth.shape}.") |
|
if reduction not in ("mean", "median"): |
|
raise ValueError(f"Unrecognized reduction method: {reduction}.") |
|
if not scale_invariant and shift_invariant: |
|
raise ValueError("Pure shift-invariant ensembling is not supported.") |
|
|
|
def init_param(depth: torch.Tensor): |
|
init_min = depth.reshape(ensemble_size, -1).min(dim=1).values |
|
init_max = depth.reshape(ensemble_size, -1).max(dim=1).values |
|
|
|
if scale_invariant and shift_invariant: |
|
init_s = 1.0 / (init_max - init_min).clamp(min=1e-6) |
|
init_t = -init_s * init_min |
|
param = torch.cat((init_s, init_t)).cpu().numpy() |
|
elif scale_invariant: |
|
init_s = 1.0 / init_max.clamp(min=1e-6) |
|
param = init_s.cpu().numpy() |
|
else: |
|
raise ValueError("Unrecognized alignment.") |
|
|
|
return param.astype(np.float64) |
|
|
|
def align(depth: torch.Tensor, param: np.ndarray) -> torch.Tensor: |
|
if scale_invariant and shift_invariant: |
|
s, t = np.split(param, 2) |
|
s = torch.from_numpy(s).to(depth).view(ensemble_size, 1, 1, 1) |
|
t = torch.from_numpy(t).to(depth).view(ensemble_size, 1, 1, 1) |
|
out = depth * s + t |
|
elif scale_invariant: |
|
s = torch.from_numpy(param).to(depth).view(ensemble_size, 1, 1, 1) |
|
out = depth * s |
|
else: |
|
raise ValueError("Unrecognized alignment.") |
|
return out |
|
|
|
def ensemble( |
|
depth_aligned: torch.Tensor, return_uncertainty: bool = False |
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: |
|
uncertainty = None |
|
if reduction == "mean": |
|
prediction = torch.mean(depth_aligned, dim=0, keepdim=True) |
|
if return_uncertainty: |
|
uncertainty = torch.std(depth_aligned, dim=0, keepdim=True) |
|
elif reduction == "median": |
|
prediction = torch.median(depth_aligned, dim=0, keepdim=True).values |
|
if return_uncertainty: |
|
uncertainty = torch.median( |
|
torch.abs(depth_aligned - prediction), dim=0, keepdim=True |
|
).values |
|
else: |
|
raise ValueError(f"Unrecognized reduction method: {reduction}.") |
|
return prediction, uncertainty |
|
|
|
def cost_fn(param: np.ndarray, depth: torch.Tensor) -> float: |
|
cost = 0.0 |
|
depth_aligned = align(depth, param) |
|
|
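# Pairwise RMSE between every pair of aligned ensemble members; lower cost means better mutual agreement.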
|
for i, j in torch.combinations(torch.arange(ensemble_size)): |
|
diff = depth_aligned[i] - depth_aligned[j] |
|
cost += (diff**2).mean().sqrt().item() |
|
|
|
if regularizer_strength > 0: |
|
prediction, _ = ensemble(depth_aligned, return_uncertainty=False) |
|
err_near = (0.0 - prediction.min()).abs().item() |
|
err_far = (1.0 - prediction.max()).abs().item() |
|
cost += (err_near + err_far) * regularizer_strength |
|
|
|
return cost |
|
|
|
def compute_param(depth: torch.Tensor): |
|
import scipy |
|
|
|
depth_to_align = depth.to(torch.float32) |
|
if max_res is not None and max(depth_to_align.shape[2:]) > max_res: |
|
depth_to_align = resize_max_res( |
|
depth_to_align, max_res, get_tv_resample_method("nearest-exact") |
|
) |
|
|
|
param = init_param(depth_to_align) |
|
|
|
res = scipy.optimize.minimize( |
|
partial(cost_fn, depth=depth_to_align), |
|
param, |
|
method="BFGS", |
|
tol=tol, |
|
options={"maxiter": max_iter, "disp": False}, |
|
) |
|
|
|
return res.x |
|
|
|
requires_aligning = scale_invariant or shift_invariant |
|
ensemble_size = depth.shape[0] |
|
|
|
if requires_aligning: |
|
param = compute_param(depth) |
|
depth = align(depth, param) |
|
|
|
depth, uncertainty = ensemble(depth, return_uncertainty=output_uncertainty) |
|
|
|
depth_max = depth.max() |
|
if scale_invariant and shift_invariant: |
|
depth_min = depth.min() |
|
elif scale_invariant: |
|
depth_min = 0 |
|
else: |
|
raise ValueError("Unrecognized alignment.") |
|
depth_range = (depth_max - depth_min).clamp(min=1e-6) |
|
depth = (depth - depth_min) / depth_range |
|
if output_uncertainty: |
|
uncertainty /= depth_range |
|
|
|
return depth, uncertainty |
|
|
|
|
|
|
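# Empirical lookup table for automatic batch-size selection: "res" is the processing resolution,
# "total_vram" is the total GPU memory in GiB, and "bs" is the batch size that fits for the given dtype.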
|
bs_search_table = [ |
|
|
|
{"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32}, |
|
{"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32}, |
|
|
|
{"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32}, |
|
{"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32}, |
|
{"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16}, |
|
{"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16}, |
|
|
|
{"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32}, |
|
{"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32}, |
|
{"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32}, |
|
{"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16}, |
|
{"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16}, |
|
{"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16}, |
|
|
|
{"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32}, |
|
{"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32}, |
|
{"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16}, |
|
{"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16}, |
|
{"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16}, |
|
] |
|
|
|
|
|
def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int: |
|
""" |
|
Automatically search for suitable operating batch size. |
|
|
|
Args: |
|
ensemble_size (`int`): |
|
Number of predictions to be ensembled. |
|
input_res (`int`):
    Operating resolution of the input image.
dtype (`torch.dtype`):
    Data type used for inference.
|
|
|
Returns: |
|
`int`: Operating batch size. |
|
""" |
|
if not torch.cuda.is_available(): |
|
return 1 |
|
|
|
total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3 |
|
filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype] |
|
for settings in sorted( |
|
filtered_bs_search_table, |
|
key=lambda k: (k["res"], -k["total_vram"]), |
|
): |
|
if input_res <= settings["res"] and total_vram >= settings["total_vram"]: |
|
bs = settings["bs"] |
|
if bs > ensemble_size: |
|
bs = ensemble_size |
|
elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size: |
|
bs = math.ceil(ensemble_size / 2) |
|
return bs |
|
|
|
return 1 |
|
|