# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # -------------------------------------------------------------------------- # More information about Marigold: # https://marigoldmonodepth.github.io # https://marigoldcomputervision.github.io # Efficient inference pipelines are now part of diffusers: # https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage # https://huggingface.co/docs/diffusers/api/pipelines/marigold # Examples of trained models and live demos: # https://huggingface.co/prs-eth # Related projects: # https://rollingdepth.github.io/ # https://marigolddepthcompletion.github.io/ # Citation (BibTeX): # https://github.com/prs-eth/Marigold#-citation # If you find Marigold useful, we kindly ask you to cite our papers. # -------------------------------------------------------------------------- import logging import numpy as np import torch from typing import Dict, Union import math from diffusers import ( AutoencoderKL, DDIMScheduler, DiffusionPipeline, LCMScheduler, UNet2DConditionModel, AutoencoderTiny, ) from diffusers.utils import BaseOutput from PIL import Image from torch.utils.data import DataLoader, TensorDataset from torchvision.transforms.functional import resize, pil_to_tensor from torchvision.transforms import InterpolationMode from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from functools import partial from typing import Optional, Tuple class MarigoldDepthOutput(BaseOutput): """ Output class for Marigold Monocular Depth Estimation pipeline. Args: depth_np (`np.ndarray`): Predicted depth map, with depth values in the range of [0, 1]. base_depth_np (`np.ndarray`): Upsampled base depth map, with depth values in the range of [0, 1]. This is the depth map used as a global guidance for the boosted inference. It is upsampled to the same resolution as the final depth map. This is useful for visualization and debugging purposes. """ depth_np: np.ndarray base_depth_np: np.ndarray # NEW: upsampled base depth class MarigoldDepthHRPipeline(DiffusionPipeline): """ Pipeline for high resolution monocular depth estimation using Marigold: https://marigoldcomputervision.github.io. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) Args: unet (`UNet2DConditionModel`): Conditional U-Net to denoise the prediction latent, conditioned on image latent. vae (`AutoencoderKL`): Variational Auto-Encoder (VAE) Model to encode and decode images and predictions to and from latent representations. scheduler (`DDIMScheduler`): A scheduler to be used in combination with `unet` to denoise the encoded image latents. text_encoder (`CLIPTextModel`): Text-encoder, for empty text embedding. tokenizer (`CLIPTokenizer`): CLIP tokenizer. 
boosting_unet (`UNet2DConditionModel`): Conditional U-Net to denoise the depth latent, conditioned on image latent and a global depth map. scale_invariant (`bool`, *optional*): A model property specifying whether the predicted depth maps are scale-invariant. This value must be set in the model config. When used together with the `shift_invariant=True` flag, the model is also called "affine-invariant". NB: overriding this value is not supported. shift_invariant (`bool`, *optional*): A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in the model config. When used together with the `scale_invariant=True` flag, the model is also called "affine-invariant". NB: overriding this value is not supported. default_denoising_steps (`int`, *optional*): The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable quality with the given model. This value must be set in the model config. When the pipeline is called without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure reasonable results with various model flavors compatible with the pipeline, such as those relying on very short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`). default_boosting_denoising_steps (`int`, *optional*): Same as `default_denoising_steps` but for `boosting_unet`. default_processing_resolution (`int`, *optional*): The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in the model config. When the pipeline is called without explicitly setting `processing_resolution`, the default value is used. This is required to ensure reasonable results with various model flavors trained with varying optimal processing resolution values. """ latent_scale_factor = 0.18215 def __init__( self, unet: UNet2DConditionModel, vae: AutoencoderKL, scheduler: Union[DDIMScheduler, LCMScheduler], text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, boosting_unet: Optional[UNet2DConditionModel], scale_invariant: Optional[bool] = True, shift_invariant: Optional[bool] = True, default_denoising_steps: Optional[int] = None, default_boosting_denoising_steps: Optional[int] = None, default_processing_resolution: Optional[int] = None, base_depth_model_uri: Optional[str] = None, variant: Optional[str] = None, ): super().__init__() if boosting_unet is None: logging.warning( "Boosting U-Net is not provided. If this message appears during training, it is expected." 
            )

        self.register_modules(
            unet=unet,
            vae=vae,
            scheduler=scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            boosting_unet=boosting_unet,
        )
        self.register_to_config(
            scale_invariant=scale_invariant,
            shift_invariant=shift_invariant,
            default_denoising_steps=default_denoising_steps,
            default_boosting_denoising_steps=default_boosting_denoising_steps,
            default_processing_resolution=default_processing_resolution,
        )
        self.register_to_config(base_depth_model_uri=base_depth_model_uri)

        self.scale_invariant = scale_invariant
        self.shift_invariant = shift_invariant
        self.default_denoising_steps = default_denoising_steps
        self.default_boosting_denoising_steps = default_boosting_denoising_steps
        self.default_processing_resolution = default_processing_resolution

        if base_depth_model_uri is not None:
            # Load the base (low-resolution) depth model used to produce the global guidance
            self.base_pipe = DiffusionPipeline.from_pretrained(
                base_depth_model_uri,
                variant=variant,
                torch_dtype=self.dtype,
                trust_remote_code=True,
            )
            self.base_pipe.to(self.device)
        else:
            self.base_pipe = None

        self.empty_text_embed = None

    @torch.no_grad()
    def __call__(
        self,
        input_image: Union[Image.Image, torch.Tensor],
        *,
        base_depth: Optional[
            Union[Image.Image, np.ndarray, torch.Tensor, MarigoldDepthOutput]
        ] = None,
        denoising_steps: Optional[int] = None,
        boosted_denoising_steps: Optional[int] = None,
        ensemble_size: int = 10,
        boosted_ensemble_size: int = 5,
        processing_res: Optional[int] = None,
        match_input_res: bool = True,
        resample_method: str = "bilinear",
        batch_size: int = 0,
        show_progress_bar: bool = True,
        ensemble_kwargs: Optional[Dict] = None,
        upscale_factor: int = 2,
    ) -> MarigoldDepthOutput:
        """
        Function invoked when calling the pipeline.

        Args:
            input_image (`Image` or `torch.Tensor`):
                Input RGB (or gray-scale) image.
            base_depth (`Image`, `np.ndarray`, `torch.Tensor` or `MarigoldDepthOutput`, *optional*):
                Base depth map used as global guidance for the boosted inference. If `None`, it is
                predicted with the base pipeline loaded from `base_depth_model_uri`.
            denoising_steps (`int`, *optional*, defaults to `None`):
                Number of diffusion denoising steps (DDIM) during base inference. If `None`, the
                model's `default_denoising_steps` is used.
            boosted_denoising_steps (`int`, *optional*, defaults to `None`):
                Number of denoising steps per boosted refinement stage. If `None`, the model's
                `default_boosting_denoising_steps` is used.
            ensemble_size (`int`, *optional*, defaults to `10`):
                Number of predictions to be ensembled.
            boosted_ensemble_size (`int`, *optional*, defaults to `5`):
                Number of predictions to be ensembled in the boosted inference.
            processing_res (`int`, *optional*, defaults to `None`):
                Maximum (longer-edge) resolution of processing. If set to 0, the base prediction is
                run at the original resolution and the boosted stages fall back to a patch
                resolution of 768. If `None`, the model's `default_processing_resolution` is used.
            match_input_res (`bool`, *optional*, defaults to `True`):
                Resize depth prediction to match input resolution. Only valid if `processing_res` > 0.
            resample_method (`str`, *optional*, defaults to `bilinear`):
                Resampling method used to resize images and depth predictions. This can be one of
                `bilinear`, `bicubic` or `nearest`.
            batch_size (`int`, *optional*, defaults to `0`):
                Inference batch size, no bigger than `ensemble_size`. If set to 0, a suitable batch
                size is chosen automatically.
            show_progress_bar (`bool`, *optional*, defaults to `True`):
                Display a progress bar of diffusion denoising.
            ensemble_kwargs (`dict`, *optional*, defaults to `None`):
                Arguments for detailed ensembling settings.
            upscale_factor (`int`, *optional*, defaults to `2`):
                Total upscaling factor applied by the boosted (patch-based) inference relative to
                `processing_res`. Must be a power of 2; with `1`, only the base prediction is returned.
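        Example (illustrative sketch; the checkpoint identifier below is a placeholder, not a
        confirmed released model):

        >>> pipe = MarigoldDepthHRPipeline.from_pretrained("prs-eth/marigold-depth-hr")  # placeholder id
        >>> image = Image.open("scene.jpg")
        >>> out = pipe(image, ensemble_size=5, upscale_factor=2)
        >>> # The base prediction can be reused to boost again at a higher factor:
        >>> out_x4 = pipe(image, base_depth=out.base_depth_np, upscale_factor=4)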
        Returns:
            `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
            - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
            - **base_depth_np** (`np.ndarray`) Upsampled base depth map, with depth values in the range of
              [0, 1]. This is the depth map used as global guidance for the boosted inference.
        """
        # Model-specific optimal default values leading to fast and reasonable results.
        if denoising_steps is None:
            denoising_steps = self.default_denoising_steps
        if boosted_denoising_steps is None:
            boosted_denoising_steps = self.default_boosting_denoising_steps
        if processing_res is None:
            processing_res = self.default_processing_resolution

        # Asserts
        assert processing_res >= 0, "Processing resolution must be non-negative."
        assert ensemble_size >= 1, "Ensemble size must be at least 1."
        assert boosted_ensemble_size >= 1, "Boosted ensemble size must be at least 1."
        assert math.log2(
            upscale_factor
        ).is_integer(), "Upscale factor must be a power of 2."
        assert upscale_factor >= 1, "Upscale factor must be at least 1."
        assert batch_size >= 0, "Batch size must be non-negative."

        # Warnings
        if upscale_factor >= 8 and self.dtype == torch.float16:
            logging.warning(
                "Upscaling factors of 8 and above combined with half precision may lead to artifacts in the final prediction."
            )
        if upscale_factor >= 4 and isinstance(self.vae, AutoencoderTiny):
            logging.warning(
                "Upscaling factors of 4 and above with the Tiny VAE may lead to instabilities."
            )

        # Get the resolution of the input RGB image
        if isinstance(input_image, Image.Image):
            input_width, input_height = input_image.size
        else:
            # torch.Tensor inputs are [..., H, W]
            input_height, input_width = input_image.shape[-2:]

        # 1) Get the base (low-resolution) prediction
        if base_depth is not None:
            # load into float32 np.ndarray
            if isinstance(base_depth, Image.Image):
                lowres = np.asarray(base_depth.convert("L"), dtype=np.float32)
            elif isinstance(base_depth, torch.Tensor):
                lowres = base_depth.squeeze().cpu().numpy().astype(np.float32)
            elif isinstance(base_depth, np.ndarray):
                lowres = base_depth.squeeze().astype(np.float32)
            elif isinstance(base_depth, MarigoldDepthOutput):
                lowres = base_depth.depth_np.astype(np.float32)
            else:
                raise TypeError(f"Unsupported base_depth type: {type(base_depth)}")

            # min-max normalize to [0, 1]
            min_v, max_v = lowres.min(), lowres.max()
            eps = 1e-8
            if max_v - min_v > eps:
                lowres = (lowres - min_v) / (max_v - min_v + eps)
            else:
                # flat image -> all zeros
                lowres = np.zeros_like(lowres, dtype=np.float32)
        else:
            assert (
                self.base_pipe is not None
            ), "base_depth was not provided and no base depth model is loaded."
            if self.base_pipe.device != self.device:
                # Move the base pipe to the correct device
                self.base_pipe.to(self.device)
            base_out = self.base_pipe(
                input_image,
                num_inference_steps=denoising_steps,
                ensemble_size=ensemble_size,
                processing_resolution=processing_res,
                match_input_resolution=False,
                batch_size=1,  # base inference is always done with batch size 1
                # show_progress_bar=show_progress_bar,
                resample_method_input=resample_method,
                resample_method_output=resample_method,
                ensembling_kwargs=ensemble_kwargs,
            )
            lowres = base_out.prediction[0, :, :, 0]  # [H, W]
            base_out.depth_np = lowres

        # 2) Upsample the base prediction to the input resolution for the output
        t = torch.from_numpy(lowres[None, None])
        up = resize(
            t,
            (input_height, input_width),
            interpolation=InterpolationMode.NEAREST_EXACT,
            antialias=True,
        )
        base_depth_np_upsampled = up.squeeze().cpu().numpy()

        # 3) If no upscaling is requested, return the base prediction early
        if upscale_factor == 1:
            return MarigoldDepthOutput(
                depth_np=lowres, base_depth_np=base_depth_np_upsampled
            )

        #
4) Normalize and run boosted inference (unchanged) global_pred = torch.from_numpy(lowres).to(self.device) global_pred = (global_pred - global_pred.min()) / (global_pred.max() - global_pred.min()) # Iterative refinement logic current_pred = global_pred current_factor = 2 # Start with an upscale factor of 2 # Create a list of all upscale factors up to the target upscale_factors = [2**i for i in range(1, int(math.log2(upscale_factor)) + 1)] # precalculate patch dimensions if processing_res == 0: processing_res = 768 df = min( processing_res / input_image.width, processing_res / input_image.height ) patch_height = int(input_image.height * df) patch_width = int(input_image.width * df) # Pre‐warn about any that exceed the 1.1× threshold for factor in upscale_factors: tw = patch_width * factor th = patch_height * factor if tw > input_width * 1.1 or th > input_height * 1.1: logging.warning( f"Warning: Attempting to upsample to {tw}×{th}, " f"which exceeds the original input of {input_width}×{input_height}. " "This technically works, but may lead to suboptimal results." ) # 5) Perform iterative boosted inference with tqdm( total=len(upscale_factors), desc=" Upscaling Progress", unit="step", leave=False, ) as pbar: for current_factor in upscale_factors: # Update the description with the current upscaling factor pbar.set_description(f" Upscaling x{current_factor}") # Determine if this is the final step is_final_step = current_factor == upscale_factor # 2. Perform a single boosted inference step boosted_output = self.boosted_inference( input_image=input_image, denoising_steps=boosted_denoising_steps, ensemble_size=( boosted_ensemble_size if current_factor < upscale_factors[-1] else boosted_ensemble_size ), processing_res=processing_res, match_input_res=match_input_res and is_final_step, batch_size=batch_size, resample_method=resample_method, show_progress_bar=show_progress_bar, ensemble_kwargs=ensemble_kwargs, global_pred=current_pred, upscale_factor=current_factor, ) # Update predictions current_pred = torch.from_numpy(boosted_output.depth_np) # Clean up GPU memory torch.cuda.empty_cache() # Progress to the next upscale factor current_factor *= 2 # Update the progress bar pbar.update(1) # Return the final output, and attach base depth map out = boosted_output out.base_depth_np = base_depth_np_upsampled return out def boosted_inference( self, input_image: Union[torch.Tensor], denoising_steps: int = 10, ensemble_size: int = 10, processing_res: int = 768, match_input_res: bool = True, batch_size: int = 0, resample_method: str = "bilinear", seed: Union[int, None] = None, show_progress_bar: bool = True, ensemble_kwargs: Dict = None, global_pred: torch.Tensor = None, upscale_factor: int = 2, ) -> MarigoldDepthOutput: """ Function invoked when calling the pipeline with boosted inference. Args: input_image (`torch.Tensor`): Input RGB image. denoising_steps (`int`, *optional*, defaults to `10`): Number of diffusion denoising steps (DDIM) during inference. ensemble_size (`int`, *optional*, defaults to `10`): Number of predictions to be ensembled. processing_res (`int`, *optional*, defaults to `768`): Maximum resolution of processing. If set to 0: will not resize at all. match_input_res (`bool`, *optional*, defaults to `True`): Resize depth prediction to match input resolution. Only valid if `processing_res` > 0. resample_method: (`str`, *optional*, defaults to `bilinear`): Resampling method used to resize images and depth predictions. 
This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`. batch_size (`int`, *optional*, defaults to `0`): Inference batch size, no bigger than `num_ensemble`. If set to 0, the script will automatically decide the proper batch size. seed (`int`, *optional*, defaults to `None`): Random seed for the diffusion process. show_progress_bar (`bool`, *optional*, defaults to `True`): Display a progress bar of diffusion denoising. ensemble_kwargs (`dict`, *optional*, defaults to `None`): Arguments for detailed ensembling settings. global_pred (`torch.Tensor`): Global depth map to be used as guidance. upscale_factor (`int`, *optional*, defaults to `2`): Upscale factor of the global depth map. Returns: `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including: - **depth_np** (`np.ndarray`) Predicted depth map with depth values in the range of [0, 1] - **base_depth_np** (`np.ndarray`) Upsampled base depth map with depth values in the range of [0, 1]. """ device = self.device self._check_inference_step(denoising_steps) resample_method: InterpolationMode = get_tv_resample_method(resample_method) # Convert to torch tensor if isinstance(input_image, Image.Image): input_image = input_image.convert("RGB") # convert to torch tensor [H, W, rgb] -> [rgb, H, W] input_image = pil_to_tensor(input_image) elif isinstance(input_image, torch.Tensor): input_image = input_image.squeeze() # pass else: raise TypeError(f"Unknown input type: {type(input_image) = }") input_size = input_image.shape assert ( 3 == input_image.dim() and 3 == input_size[0] ), f"Wrong input shape {input_size}, expected [rgb, H, W]" if isinstance(global_pred, torch.Tensor): global_pred = global_pred.squeeze().unsqueeze(0).to(device) else: raise TypeError(f"Unknown global_pred type: {type(global_pred) = }") if processing_res == 0: # fallback to original resolution processing_res = 768 df = min( processing_res / input_image.shape[2], processing_res / input_image.shape[1] ) patch_height = int(input_image.shape[1] * df) patch_width = int(input_image.shape[2] * df) # need to be divisible by 8 patch_height = round(patch_height / 16) * 16 patch_width = round(patch_width / 16) * 16 patch_size = patch_height, patch_width global_size = (patch_size[0] * upscale_factor, patch_size[1] * upscale_factor) if global_pred.shape[1:] != global_size: global_pred = resize( global_pred, global_size, interpolation=resample_method, antialias=True, ) if input_image.shape[1:] != patch_size: input_image = resize( input_image, global_size, interpolation=resample_method, antialias=True, ).squeeze() input_image = ( input_image.unsqueeze(0) / 255.0 * 2.0 - 1.0 ) # [0, 255] -> [-1, 1] input_image = input_image.to(self.dtype).to(device) assert input_image.min() >= -1.0 and input_image.max() <= 1.0 global_pred = global_pred.to(self.dtype).to(device) if batch_size > 0: _bs = batch_size else: _bs = find_batch_size( ensemble_size=ensemble_size * (2 * upscale_factor - 1) * (2 * upscale_factor - 1), input_res=max(patch_size), dtype=self.dtype, ) # create a small buffer in z-dimension global_pred = 0.9 * (global_pred * 2 - 1) global_pred = (global_pred + 1) / 2 depth_pred, pred_uncert = self.multidiffusion_inference( rgb_norm=input_image, global_pred=global_pred, num_inference_steps=denoising_steps, patch_size=patch_size, seed=seed, show_pbar=show_progress_bar, ensemble_size=ensemble_size, batch_size=_bs, ensemble_kwargs=ensemble_kwargs, ) depth_pred = depth_pred.squeeze(0) # rescale to to [0, 1] min_d = torch.min(depth_pred) 
max_d = torch.max(depth_pred) depth_pred = (depth_pred - min_d) / (max_d - min_d) if depth_pred.shape[1:] != input_size[1:] and match_input_res: depth_pred = resize( depth_pred.unsqueeze(0), input_size[1:], interpolation=resample_method, antialias=True, ).squeeze() depth_pred = depth_pred.squeeze().cpu().numpy() depth_pred = depth_pred.clip(0, 1) return MarigoldDepthOutput( depth_np=depth_pred, ) def _check_inference_step(self, n_step: int) -> None: """ Check if denoising step is reasonable Args: n_step (`int`): denoising steps """ assert n_step >= 1 def encode_empty_text(self): """ Encode text embedding for empty prompt """ prompt = "" text_inputs = self.tokenizer( prompt, padding="do_not_pad", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids.to(self.text_encoder.device) self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype) @torch.no_grad() def multidiffusion_inference( self, rgb_norm: torch.Tensor, global_pred: torch.Tensor, num_inference_steps: int, seed: Union[int, None], show_pbar: bool, patch_size=(512, 768), encoder_patch_size=None, ensemble_size=1, batch_size=1, ensemble_kwargs: Dict = None, ) -> torch.Tensor: """ Perform an individual depth prediction without ensembling. Args: rgb_norm (`torch.Tensor`): Input RGB image. num_inference_steps (`int`): Number of diffusion denoisign steps (DDIM) during inference. num_patches_vert (`int`): Number of vertical patches. num_patches_horz (`int`): Number of horizontal patches. step_height (`int`): Height of the patch. step_width (`int`): Width of the patch. Returns: `torch.Tensor`: Predicted depth map. `torch.Tensor`: Uncertainty map. """ device = self.device rgb_norm = rgb_norm.to(device) global_pred = global_pred.to(device) # Set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps latent_patch_size = (patch_size[0] // 8, patch_size[1] // 8) latent_full_size = (rgb_norm.shape[-2] // 8, rgb_norm.shape[-1] // 8) # Normalize the global prediction to [-1, 1] global_pred = global_pred * 2 - 1 # Encode rgb and depth map if encoder_patch_size is None: encoder_patch_size = patch_size rgb_latent = self.encode_rgb_patched( rgb_norm, patch_size=encoder_patch_size, show_pbar=show_pbar ) global_pred_latent = self.encode_depth_patched( global_pred, patch_size=encoder_patch_size, show_pbar=show_pbar ) # offload to cpu for memory efficiency rgb_norm = rgb_norm.cpu() global_pred = global_pred.cpu() # Initial depth map (noise) if seed is None: rand_num_generator = None else: rand_num_generator = torch.Generator(device=device) rand_num_generator.manual_seed(seed) # patch the input images ( rgb_latent_patched, num_patches_vert, num_patches_horz, step_height, step_width, ) = self.extract_patches(rgb_latent[0], patch_size=latent_patch_size) # also patch the global depth map global_pred_latent_patched, _, _, _, _ = self.extract_patches( global_pred_latent[0], patch_size=latent_patch_size ) # Batched empty text embedding if self.empty_text_embed is None: self.encode_empty_text() batch_empty_text_embed = self.empty_text_embed.repeat((batch_size, 1, 1)).to( device ) if hasattr(self, "pooled_empty_text_embeds"): batch_pooled_empty_text_embed = self.pooled_empty_text_embeds.repeat( (batch_size, 1, 1) ).to(device) # enlarge the variable according to the ensemble size if ensemble_size > 1: rgb_latent_patched = rgb_latent_patched.repeat(ensemble_size, 1, 1, 1) global_pred_latent_patched = 
global_pred_latent_patched.repeat( ensemble_size, 1, 1, 1 ) batch_empty_text_embed = batch_empty_text_embed.repeat(ensemble_size, 1, 1) if hasattr(self, "pooled_empty_text_embeds"): batch_pooled_empty_text_embed = batch_pooled_empty_text_embed.repeat( ensemble_size, 1, 1 ) # Initialize the canvas and split it to get identical noise on overlaps depth_latent = torch.randn( (ensemble_size, 4, latent_full_size[0], latent_full_size[1]), device=device, dtype=self.dtype, generator=rand_num_generator, ) ( depth_latent_patched, num_patches_vert, num_patches_horz, step_height, step_width, ) = self.extract_patches(depth_latent, patch_size=latent_patch_size) # Denoising loop if show_pbar: iterable = tqdm( enumerate(timesteps), total=len(timesteps), leave=False, desc=" " * 4 + "Diffusion denoising", ) else: iterable = enumerate(timesteps) for _, t in iterable: # 1. inference all the patches with unet assert ( self.boosting_unet.conv_in.in_channels == 12 ), "The input channels of the boosting unet must be 12." unet_input = torch.cat( [rgb_latent_patched, global_pred_latent_patched, depth_latent_patched], dim=1, ) # 2. Create a dataloader and predict the noise dataset = TensorDataset(unet_input) loader = DataLoader(dataset, batch_size=batch_size, shuffle=False) noise_preds = [] for batch in tqdm( loader, leave=False, disable=not show_pbar, desc=" " * 6 + "UNet patch inference", ): (unet_input,) = batch noise_pred = self.boosting_unet( unet_input, t, encoder_hidden_states=batch_empty_text_embed[: unet_input.shape[0]], ).sample noise_preds.append(noise_pred) noise_preds = torch.concat(noise_preds, dim=0) # 3. Default ddim scheduler step for each patch scheduler_out = self.scheduler.step( noise_preds, t, depth_latent_patched, generator=rand_num_generator ) # 4. Reshape patches to patch-spatial dimension depth_latent = scheduler_out.prev_sample.reshape( ensemble_size, num_patches_vert, num_patches_horz, 4, latent_patch_size[0], latent_patch_size[1], ) # 5. Blend the patches with multidiffusion formula depth_latent_full = self.blend_patches( depth_latent, canvas_size=global_pred_latent.shape[2:], num_patches_vert=num_patches_vert, num_patches_horz=num_patches_horz, step_height=step_height, step_width=step_width, ) # 6. Update the depth_latent_patched if t != timesteps[-1]: depth_latent_patched, _, _, _, _ = self.extract_patches( depth_latent_full, patch_size=latent_patch_size ) # if t<=1: at the end of the loop we decode the full latent depth = self.decode_depth_patched( depth_latent_full=depth_latent_full, canvas_size=global_pred.shape[1:], latent_decoder_patch_size=( encoder_patch_size[0] // 8, encoder_patch_size[1] // 8, ), show_pbar=show_pbar, ensemble_size=ensemble_size, ) if ensemble_size > 1: depth, pred_uncert = ensemble_depth( depth, scale_invariant=self.scale_invariant, shift_invariant=self.shift_invariant, **(ensemble_kwargs or {}), ) else: depth = (depth + 1.0) / 2.0 pred_uncert = None return depth, pred_uncert def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor: """ Encode RGB image into latent. Args: rgb_in (`torch.Tensor`): Input RGB image to be encoded. Returns: `torch.Tensor`: Image latent. 
""" if isinstance(self.vae, AutoencoderTiny): rgb_latent = self.vae.encoder(rgb_in) else: h = self.vae.encoder(rgb_in) moments = self.vae.quant_conv(h) mean, _ = torch.chunk(moments, 2, dim=1) rgb_latent = mean * self.latent_scale_factor return rgb_latent def encode_rgb_patched( self, rgb_in: torch.Tensor, patch_size: tuple, show_pbar: bool = False ) -> torch.Tensor: """ Encode depth map into latent. Args: depth_in (`torch.Tensor`): Input depth map to be encoded. patch_size (`Tuple[int, int]`): Size of the patch. Returns: `torch.Tensor`: Depth latent. """ device = self.device ( rgb_patched, num_patches_vert_pix, num_patches_horz_pix, step_height_pix, step_width_pix, ) = self.extract_patches(rgb_in.squeeze(0), patch_size=patch_size) rgb_in = rgb_in.cpu() rgb_patched = rgb_patched.cpu() # forward the patches rgb_latent_patched = [] for rgb in tqdm( rgb_patched, leave=False, disable=not show_pbar, desc=" " * 4 + "Encoding RGB", ): patch = rgb.unsqueeze(0).to(device) rgb_latent_patched.append(self.encode_rgb(patch).cpu()) # reshape the spatial dimensions rgb_latent_patched = torch.concat(rgb_latent_patched, dim=0) rgb_latent_patched = rgb_latent_patched.reshape( num_patches_vert_pix, num_patches_horz_pix, 4, rgb_latent_patched.shape[-2], rgb_latent_patched.shape[-1], ).to(device) # blend the patches rgb_latent = self.blend_patches( rgb_latent_patched, canvas_size=(rgb_in.shape[-2] // 8, rgb_in.shape[-1] // 8), num_patches_vert=num_patches_vert_pix, num_patches_horz=num_patches_horz_pix, step_height=step_height_pix // 8, step_width=step_width_pix // 8, ) if len(rgb_latent.shape) == 3: rgb_latent = rgb_latent.unsqueeze(0) return rgb_latent def encode_depth_patched( self, depth_in: torch.Tensor, patch_size, show_pbar: bool = False, ) -> torch.Tensor: """ Encode depth map into latent, but in a patched and scalable way Args: depth_in (`torch.Tensor`): Input depth map to be encoded. patch_size (`Tuple[int, int]`): Size of the patch. show_pbar (`bool`): Display a progress bar. Returns: `torch.Tensor`: Depth latent. """ device = self.device ensemble_size = depth_in.shape[0] ( depth_in_patched, num_patches_vert_pix, num_patches_horz_pix, step_height_pix, step_width_pix, ) = self.extract_patches(depth_in, patch_size=patch_size) depth_in = depth_in.cpu() depth_in_patched = depth_in_patched.cpu() # forward the patches depth_latent_patched = [] for gpred in tqdm( depth_in_patched, leave=False, disable=not show_pbar, desc=" " * 4 + "Encoding context depth", ): patch = gpred.unsqueeze(0).to(device) depth_latent_patched.append(self.encode_depth(patch).cpu()) # reshape the spatial dimensions depth_latent_patched = torch.concat(depth_latent_patched, dim=0) depth_latent_patched = depth_latent_patched.reshape( ensemble_size, num_patches_vert_pix, num_patches_horz_pix, 4, depth_latent_patched.shape[-2], depth_latent_patched.shape[-1], ).to(device) # blend the patches depth_latent = self.blend_patches( depth_latent_patched, canvas_size=(depth_in.shape[-2] // 8, depth_in.shape[-1] // 8), num_patches_vert=num_patches_vert_pix, num_patches_horz=num_patches_horz_pix, step_height=step_height_pix // 8, step_width=step_width_pix // 8, ) if len(depth_latent.shape) == 3: depth_latent = depth_latent.unsqueeze(0) return depth_latent def decode_depth_patched( self, depth_latent_full: torch.Tensor, canvas_size: tuple, latent_decoder_patch_size: tuple, show_pbar: bool = True, ensemble_size: int = 1, ) -> torch.Tensor: """ Decode depth map from latent in a patched and scalable way. 
Args: depth_latent_full (`torch.Tensor`): Depth latent to be decoded. canvas_size (`tuple`): Size of the canvas. latent_decoder_patch_size (`tuple`): Size of the patch. show_pbar (`bool`): Display a progress bar. ensemble_size (`int`): Ensemble size. Returns: `torch.Tensor`: Decoded depth map. """ encoder_patch_size = ( latent_decoder_patch_size[0] * 8, latent_decoder_patch_size[1] * 8, ) # extract patches ( depth_latent_patched, num_patches_vert, num_patches_horz, step_height, step_width, ) = self.extract_patches( depth_latent_full, patch_size=latent_decoder_patch_size, overlap=0.5 ) # decode patches depthp = [] for patch in tqdm( depth_latent_patched, leave=False, desc=" " * 6 + "Decoding Depth", disable=not show_pbar, ): depthp.append(self.decode_depth(patch.unsqueeze(0))) depthp = torch.concat(depthp, dim=0) depthp = depthp.reshape( ensemble_size, num_patches_vert, num_patches_horz, 1, encoder_patch_size[0], encoder_patch_size[1], ) # blend together depth = self.blend_patches( depthp, canvas_size=canvas_size, num_patches_vert=num_patches_vert, num_patches_horz=num_patches_horz, step_height=step_height * 8, step_width=step_width * 8, ) return depth def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor: """ Decode depth latent into depth map. Args: depth_latent (`torch.Tensor`): Depth latent to be decoded. Returns: `torch.Tensor`: Decoded depth map of the input depth latent. """ # if self.using_tiny_vae: if isinstance(self.vae, AutoencoderTiny): stacked = self.vae.decoder(depth_latent) else: depth_latent = depth_latent / self.latent_scale_factor z = self.vae.post_quant_conv(depth_latent) stacked = self.vae.decoder(z) # mean of output channels depth_mean = stacked.mean(dim=1, keepdim=True) return depth_mean def encode_depth( self, depth_in: torch.Tensor, ) -> torch.Tensor: """ Encode depth map into latent. Args: depth_in (`torch.Tensor`): Input depth map to be encoded. Returns: `torch.Tensor`: Depth latent of the input depth map. 
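        Example (illustrative sketch; `pipe` stands for an instantiated pipeline, and the 8x
        spatial reduction assumes a standard Stable-Diffusion-style VAE):

        >>> d = torch.rand(1, 1, 64, 64) * 2 - 1  # depth in [-1, 1], shape [B, 1, H, W]
        >>> latent = pipe.encode_depth(d.to(pipe.device, pipe.dtype))  # -> [1, 4, 8, 8]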
""" # stack depth into 3-channel stacked = self.stack_depth_images(depth_in) # encode using VAE encoder depth_latent = self.encode_rgb(stacked) return depth_latent @staticmethod def stack_depth_images(depth_in): if 4 == len(depth_in.shape): stacked = depth_in.repeat(1, 3, 1, 1) elif 3 == len(depth_in.shape): stacked = depth_in.unsqueeze(1) stacked = depth_in.repeat(1, 3, 1, 1) return stacked def extract_patches(self, canvas_input_image, patch_size, overlap=0.5): """ Extract patches from an image Args: image (`np.ndarray`): Input image of shape [channels, height, width] patch_size (`tuple`): Size of the patch step_size (`int`): Step size Returns: input_image_patched (`torch.Tensor`): Extracted patches num_patches_vert (`int`): Number of patches in the vertical direction num_patches_horz (`int`): Number of patches in the horizontal direction step_height (`int`): Step size in the vertical direction step_width (`int`): Step size in the horizontal direction """ if len(canvas_input_image.shape) == 4: ensemble_size = canvas_input_image.shape[0] else: canvas_input_image = canvas_input_image.unsqueeze(0) ensemble_size = 1 # Step sizes (50% (overlap) of the original dimensions) h, w = patch_size step_height, step_width = int(h * (1 - overlap)), int(w * (1 - overlap)) # Calculate the number of patches to extract in both dimensions num_patches_vert = (canvas_input_image.shape[-2] - h) // step_height + 1 num_patches_horz = (canvas_input_image.shape[-1] - w) // step_width + 1 # Initialize a list to hold the patches patches = [] for e in range(ensemble_size): for i in range(num_patches_vert): for j in range(num_patches_horz): # Calculate the top left corner of the current patch start_y = i * step_height start_x = j * step_width # Extract the patch patch = canvas_input_image[ e, :, start_y : start_y + h, start_x : start_x + w ] patches.append(patch) # Stack the patches and return input_image_patched = torch.stack(patches) return ( input_image_patched, num_patches_vert, num_patches_horz, step_height, step_width, ) def blend_patches( self, depth_preds, num_patches_vert, num_patches_horz, step_height, step_width, global_depth_pred=None, canvas_size=None, alpha_center=1.0, alpha_edge=1e-4, noise_blend=False, eps=1e-8, ): """ Blend patches of depth maps and apply the transformation to the global depth map Args: global_depth_pred (`torch.Tensor`): Global depth map depth_preds (`torch.Tensor`): Local depth maps num_patches_vert (`int`): Number of patches in the vertical direction num_patches_horz (`int`): Number of patches in the horizontal direction step_height (`int`): Step size in the vertical direction step_width (`int`): Step size in the horizontal direction overlap (`float`): Overlap between patches alpha_center (`float`): Weight at the center of the patch alpha_edge (`float`): Weight at the edge of the patch, should not be exactly 0 to avoid division by zero adjust_LSQ (`bool`): Adjust the transformation using least squares optimization Returns: `torch.Tensor`: Blended depth map """ eps = 1e-8 channels, h, w = ( depth_preds.shape[-3], depth_preds.shape[-2], depth_preds.shape[-1], ) if len(depth_preds.shape) == 6: ensemble_size = depth_preds.shape[0] if len(depth_preds.shape) == 6 else 1 elif len(depth_preds.shape) == 5: ensemble_size = 1 depth_preds = depth_preds.unsqueeze(0) else: raise ValueError("depth_preds should have 5 or 6 dimensions") # Initialize the canvas for blending depth maps blended_depth_map = torch.zeros( (ensemble_size, channels, canvas_size[0], canvas_size[1]), 
device=depth_preds.device, dtype=depth_preds.dtype, ) denominator_map = torch.zeros( (1, canvas_size[0], canvas_size[1]), device=depth_preds.device, dtype=depth_preds.dtype, ) for i in range(num_patches_vert): for j in range(num_patches_horz): # Calculate the top left corner of the current patch start_y = i * step_height start_x = j * step_width # Extract the patch local_patch = depth_preds[:, i, j] # Generate blending weights weights = self.get_linear_weight_map( h, w, device=local_patch.device, alpha_center=alpha_center, alpha_edge=alpha_edge, cosine_blending=True, margin=0.0, ) # accumulate the local patch to the canvas as a linear combination blended_depth_map[ :, :, start_y : start_y + h, start_x : start_x + w ] += (local_patch * weights) if noise_blend: denominator_weights = weights**2 else: denominator_weights = weights denominator_map[ :, start_y : start_y + h, start_x : start_x + w ] += denominator_weights if noise_blend: denominator_map = torch.sqrt(denominator_map) blended_depth_map /= denominator_map + eps # Ensure that blended_depth_map does not have NaN values # by filling them with the global_depth_pred if global_depth_pred is not None: blended_depth_map[torch.isnan(blended_depth_map)] = ( global_depth_pred.repeat(ensemble_size, 1, 1, 1)[ torch.isnan(blended_depth_map) ] ) return blended_depth_map def get_linear_weight_map( self, h, w, device, alpha_center=1.0, alpha_edge=1e-4, margin=0.0, cosine_blending=False, ): """ Generate a linear weight map for blending patches Args: h (`int`): Height of the weight map w (`int`): Width of the weight map device (`torch.device`): Device to use alpha_center (`float`): Weight at the center of the patch alpha_edge (`float`): Weight at the edge of the patch margin (`int`): Perceptage of image dimensions. This margin at the edges to be filled with near 0 values. Returns: `torch.Tensor`: Linear distance weight map (looks like a pyramid) """ x = torch.linspace(-1, 1, h, device=device) y = torch.linspace(-1, 1, w, device=device) xx, yy = torch.meshgrid(x, y, indexing="ij") dist = torch.stack([xx.abs(), yy.abs()]).max(dim=0).values norm_dist = dist / torch.max(dist) # Clamp the distance to the margin norm_dist = torch.clamp(norm_dist + margin, 0, 1) # scale to 0 to 1 mindist = torch.min(norm_dist) maxdist = torch.max(norm_dist) norm_dist = (norm_dist - mindist) / (maxdist - mindist) # Apply a cosine-based blending function for smooth transition if cosine_blending: weights = alpha_edge + (alpha_center - alpha_edge) * ( 0.5 * (1 + torch.cos(norm_dist * math.pi)) ) else: weights = alpha_edge + (alpha_center - alpha_edge) * (1 - norm_dist) return weights def get_tv_resample_method(method_str: str) -> InterpolationMode: resample_method_dict = { "bilinear": InterpolationMode.BILINEAR, "bicubic": InterpolationMode.BICUBIC, "nearest": InterpolationMode.NEAREST_EXACT, "nearest-exact": InterpolationMode.NEAREST_EXACT, } resample_method = resample_method_dict.get(method_str, None) if resample_method is None: raise ValueError(f"Unknown resampling method: {resample_method}") else: return resample_method def resize_max_res( img: torch.Tensor, max_edge_resolution: int, resample_method: InterpolationMode = InterpolationMode.BILINEAR, ) -> torch.Tensor: """ Resize image to limit maximum edge length while keeping aspect ratio. Args: img (`torch.Tensor`): Image tensor to be resized. Expected shape: [B, C, H, W] max_edge_resolution (`int`): Maximum edge length (pixel). resample_method (`PIL.Image.Resampling`): Resampling method used to resize images. 
    Returns:
        `torch.Tensor`: Resized image.
    """
    assert 4 == img.dim(), f"Invalid input shape {img.shape}"

    original_height, original_width = img.shape[-2:]
    downscale_factor = min(
        max_edge_resolution / original_width, max_edge_resolution / original_height
    )

    new_width = int(original_width * downscale_factor)
    new_height = int(original_height * downscale_factor)

    resized_img = resize(img, (new_height, new_width), resample_method, antialias=True)
    return resized_img


def ensemble_depth(
    depth: torch.Tensor,
    scale_invariant: bool = True,
    shift_invariant: bool = True,
    output_uncertainty: bool = False,
    reduction: str = "median",
    regularizer_strength: float = 0.02,
    max_iter: int = 50,
    tol: float = 1e-6,
    max_res: int = 1024,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Ensembles depth maps represented by the `depth` tensor with expected shape `(B, 1, H, W)`, where B is the
    number of ensemble members for a given prediction of size `(H x W)`. Even though the function is designed
    for depth maps, it can also be used with disparity maps as long as the input tensor values are non-negative.
    The alignment happens when the predictions have one or more degrees of freedom, that is when they are either
    affine-invariant (`scale_invariant=True` and `shift_invariant=True`), or just scale-invariant (only
    `scale_invariant=True`). For absolute predictions (`scale_invariant=False` and `shift_invariant=False`)
    alignment is skipped and only ensembling is performed.

    Args:
        depth (`torch.Tensor`):
            Input ensemble depth maps.
        scale_invariant (`bool`, *optional*, defaults to `True`):
            Whether to treat predictions as scale-invariant.
        shift_invariant (`bool`, *optional*, defaults to `True`):
            Whether to treat predictions as shift-invariant.
        output_uncertainty (`bool`, *optional*, defaults to `False`):
            Whether to output uncertainty map.
        reduction (`str`, *optional*, defaults to `"median"`):
            Reduction method used to ensemble aligned predictions. The accepted values are: `"mean"` and
            `"median"`.
        regularizer_strength (`float`, *optional*, defaults to `0.02`):
            Strength of the regularizer that pulls the aligned predictions to the unit range from 0 to 1.
        max_iter (`int`, *optional*, defaults to `50`):
            Maximum number of the alignment solver steps. Refer to the `scipy.optimize.minimize` function,
            `options` argument.
        tol (`float`, *optional*, defaults to `1e-6`):
            Alignment solver tolerance. The solver stops when the tolerance is reached.
        max_res (`int`, *optional*, defaults to `1024`):
            Resolution at which the alignment is performed; `None` matches the `processing_resolution`.

    Returns:
        A tensor of aligned and ensembled depth maps and optionally a tensor of uncertainties of the same shape:
        `(1, 1, H, W)`.
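    Example (illustrative sketch; random tensors stand in for real ensemble members, and `scipy`
    must be installed for the alignment solver):

    >>> preds = torch.rand(4, 1, 240, 320)  # four affine-invariant predictions of the same scene
    >>> fused, uncertainty = ensemble_depth(preds, output_uncertainty=True)
    >>> fused.shape
    torch.Size([1, 1, 240, 320])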
""" if depth.dim() != 4 or depth.shape[1] != 1: raise ValueError(f"Expecting 4D tensor of shape [B,1,H,W]; got {depth.shape}.") if reduction not in ("mean", "median"): raise ValueError(f"Unrecognized reduction method: {reduction}.") if not scale_invariant and shift_invariant: raise ValueError("Pure shift-invariant ensembling is not supported.") def init_param(depth: torch.Tensor): init_min = depth.reshape(ensemble_size, -1).min(dim=1).values init_max = depth.reshape(ensemble_size, -1).max(dim=1).values if scale_invariant and shift_invariant: init_s = 1.0 / (init_max - init_min).clamp(min=1e-6) init_t = -init_s * init_min param = torch.cat((init_s, init_t)).cpu().numpy() elif scale_invariant: init_s = 1.0 / init_max.clamp(min=1e-6) param = init_s.cpu().numpy() else: raise ValueError("Unrecognized alignment.") return param.astype(np.float64) def align(depth: torch.Tensor, param: np.ndarray) -> torch.Tensor: if scale_invariant and shift_invariant: s, t = np.split(param, 2) s = torch.from_numpy(s).to(depth).view(ensemble_size, 1, 1, 1) t = torch.from_numpy(t).to(depth).view(ensemble_size, 1, 1, 1) out = depth * s + t elif scale_invariant: s = torch.from_numpy(param).to(depth).view(ensemble_size, 1, 1, 1) out = depth * s else: raise ValueError("Unrecognized alignment.") return out def ensemble( depth_aligned: torch.Tensor, return_uncertainty: bool = False ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: uncertainty = None if reduction == "mean": prediction = torch.mean(depth_aligned, dim=0, keepdim=True) if return_uncertainty: uncertainty = torch.std(depth_aligned, dim=0, keepdim=True) elif reduction == "median": prediction = torch.median(depth_aligned, dim=0, keepdim=True).values if return_uncertainty: uncertainty = torch.median( torch.abs(depth_aligned - prediction), dim=0, keepdim=True ).values else: raise ValueError(f"Unrecognized reduction method: {reduction}.") return prediction, uncertainty def cost_fn(param: np.ndarray, depth: torch.Tensor) -> float: cost = 0.0 depth_aligned = align(depth, param) for i, j in torch.combinations(torch.arange(ensemble_size)): diff = depth_aligned[i] - depth_aligned[j] cost += (diff**2).mean().sqrt().item() if regularizer_strength > 0: prediction, _ = ensemble(depth_aligned, return_uncertainty=False) err_near = (0.0 - prediction.min()).abs().item() err_far = (1.0 - prediction.max()).abs().item() cost += (err_near + err_far) * regularizer_strength return cost def compute_param(depth: torch.Tensor): import scipy depth_to_align = depth.to(torch.float32) if max_res is not None and max(depth_to_align.shape[2:]) > max_res: depth_to_align = resize_max_res( depth_to_align, max_res, get_tv_resample_method("nearest-exact") ) param = init_param(depth_to_align) res = scipy.optimize.minimize( partial(cost_fn, depth=depth_to_align), param, method="BFGS", tol=tol, options={"maxiter": max_iter, "disp": False}, ) return res.x requires_aligning = scale_invariant or shift_invariant ensemble_size = depth.shape[0] if requires_aligning: param = compute_param(depth) depth = align(depth, param) depth, uncertainty = ensemble(depth, return_uncertainty=output_uncertainty) depth_max = depth.max() if scale_invariant and shift_invariant: depth_min = depth.min() elif scale_invariant: depth_min = 0 else: raise ValueError("Unrecognized alignment.") depth_range = (depth_max - depth_min).clamp(min=1e-6) depth = (depth - depth_min) / depth_range if output_uncertainty: uncertainty /= depth_range return depth, uncertainty # [1,1,H,W], [1,1,H,W] # Search table for suggested max. 
inference batch size bs_search_table = [ # tested on A100-PCIE-80GB {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32}, {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32}, # tested on A100-PCIE-40GB {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32}, {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32}, {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16}, {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16}, # tested on RTX3090, RTX4090 {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32}, {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32}, {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32}, {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16}, {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16}, {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16}, # tested on GTX1080Ti {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32}, {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32}, {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16}, {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16}, {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16}, ] def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int: """ Automatically search for suitable operating batch size. Args: ensemble_size (`int`): Number of predictions to be ensembled. input_res (`int`): Operating resolution of the input image. Returns: `int`: Operating batch size. """ if not torch.cuda.is_available(): return 1 total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3 filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype] for settings in sorted( filtered_bs_search_table, key=lambda k: (k["res"], -k["total_vram"]), ): if input_res <= settings["res"] and total_vram >= settings["total_vram"]: bs = settings["bs"] if bs > ensemble_size: bs = ensemble_size elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size: bs = math.ceil(ensemble_size / 2) return bs return 1
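

# --------------------------------------------------------------------------
# Example usage (illustrative sketch, not part of the pipeline implementation).
# The checkpoint identifier and the input file name below are placeholders;
# substitute a real Marigold checkpoint that ships a `boosting_unet` and any
# RGB image of your choice.
# --------------------------------------------------------------------------
if __name__ == "__main__":
    pipe = MarigoldDepthHRPipeline.from_pretrained(
        "prs-eth/marigold-depth-hr",  # placeholder repository id
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

    image = Image.open("input.jpg").convert("RGB")  # placeholder input path

    # Base prediction at `processing_res`, followed by one boosting round (x2).
    output = pipe(
        image,
        ensemble_size=5,
        boosted_ensemble_size=3,
        upscale_factor=2,
        show_progress_bar=True,
    )

    depth = output.depth_np  # float np.ndarray in [0, 1], matching the input size
    base = output.base_depth_np  # upsampled base depth used as global guidance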