# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# More information about Marigold:
# https://marigoldmonodepth.github.io
# https://marigoldcomputervision.github.io
# Efficient inference pipelines are now part of diffusers:
# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
# https://huggingface.co/docs/diffusers/api/pipelines/marigold
# Examples of trained models and live demos:
# https://huggingface.co/prs-eth
# Related projects:
# https://rollingdepth.github.io/
# https://marigolddepthcompletion.github.io/
# Citation (BibTeX):
# https://github.com/prs-eth/Marigold#-citation
# If you find Marigold useful, we kindly ask you to cite our papers.
# --------------------------------------------------------------------------
import logging
import numpy as np
import torch
from typing import Dict, Union
import math
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DiffusionPipeline,
LCMScheduler,
UNet2DConditionModel,
AutoencoderTiny,
)
from diffusers.utils import BaseOutput
from PIL import Image
from torch.utils.data import DataLoader, TensorDataset
from torchvision.transforms.functional import resize, pil_to_tensor
from torchvision.transforms import InterpolationMode
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from functools import partial
from typing import Optional, Tuple
class MarigoldDepthOutput(BaseOutput):
"""
Output class for Marigold Monocular Depth Estimation pipeline.
Args:
depth_np (`np.ndarray`):
Predicted depth map, with depth values in the range of [0, 1].
base_depth_np (`np.ndarray`):
Upsampled base depth map, with depth values in the range of [0, 1].
This is the depth map used as a global guidance for the boosted inference.
It is upsampled to the same resolution as the final depth map.
This is useful for visualization and debugging purposes.
"""
depth_np: np.ndarray
    base_depth_np: np.ndarray  # upsampled base depth used as global guidance
class MarigoldDepthHRPipeline(DiffusionPipeline):
"""
Pipeline for high resolution monocular depth estimation using Marigold: https://marigoldcomputervision.github.io.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
unet (`UNet2DConditionModel`):
Conditional U-Net to denoise the prediction latent, conditioned on image latent.
vae (`AutoencoderKL`):
Variational Auto-Encoder (VAE) Model to encode and decode images and predictions
to and from latent representations.
scheduler (`DDIMScheduler`):
A scheduler to be used in combination with `unet` to denoise the encoded image latents.
text_encoder (`CLIPTextModel`):
Text-encoder, for empty text embedding.
tokenizer (`CLIPTokenizer`):
CLIP tokenizer.
boosting_unet (`UNet2DConditionModel`):
Conditional U-Net to denoise the depth latent, conditioned on image latent and a global depth map.
scale_invariant (`bool`, *optional*):
A model property specifying whether the predicted depth maps are scale-invariant. This value must be set in
the model config. When used together with the `shift_invariant=True` flag, the model is also called
"affine-invariant". NB: overriding this value is not supported.
shift_invariant (`bool`, *optional*):
A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in
the model config. When used together with the `scale_invariant=True` flag, the model is also called
"affine-invariant". NB: overriding this value is not supported.
default_denoising_steps (`int`, *optional*):
The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
quality with the given model. This value must be set in the model config. When the pipeline is called
without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
reasonable results with various model flavors compatible with the pipeline, such as those relying on very
short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
default_boosting_denoising_steps (`int`, *optional*):
Same as `default_denoising_steps` but for `boosting_unet`.
default_processing_resolution (`int`, *optional*):
The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
default value is used. This is required to ensure reasonable results with various model flavors trained
with varying optimal processing resolution values.
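    Example:
        A minimal usage sketch (the checkpoint identifier below is illustrative and assumes the custom
        pipeline code is hosted alongside the weights):
        >>> import torch
        >>> from diffusers import DiffusionPipeline
        >>> from PIL import Image
        >>> pipe = DiffusionPipeline.from_pretrained(
        ...     "prs-eth/marigold-depth-hr-v1-1",  # assumed checkpoint id
        ...     trust_remote_code=True,
        ...     torch_dtype=torch.float16,
        ... ).to("cuda")
        >>> image = Image.open("input.jpg")
        >>> output = pipe(image, upscale_factor=2)
        >>> depth = output.depth_np  # np.ndarray in [0, 1]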
"""
latent_scale_factor = 0.18215
def __init__(
self,
unet: UNet2DConditionModel,
vae: AutoencoderKL,
scheduler: Union[DDIMScheduler, LCMScheduler],
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
boosting_unet: Optional[UNet2DConditionModel],
scale_invariant: Optional[bool] = True,
shift_invariant: Optional[bool] = True,
default_denoising_steps: Optional[int] = None,
default_boosting_denoising_steps: Optional[int] = None,
default_processing_resolution: Optional[int] = None,
base_depth_model_uri: Optional[str] = None,
variant: Optional[str] = None,
):
super().__init__()
if boosting_unet is None:
logging.warning(
"Boosting U-Net is not provided. If this message appears during training, it is expected."
)
self.register_modules(
unet=unet,
vae=vae,
scheduler=scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
boosting_unet=boosting_unet,
)
self.register_to_config(
scale_invariant=scale_invariant,
shift_invariant=shift_invariant,
default_denoising_steps=default_denoising_steps,
            default_boosting_denoising_steps=default_boosting_denoising_steps,
default_processing_resolution=default_processing_resolution,
)
self.register_to_config(base_depth_model_uri=base_depth_model_uri)
self.scale_invariant = scale_invariant
self.shift_invariant = shift_invariant
self.default_denoising_steps = default_denoising_steps
self.default_boosting_denoising_steps = default_boosting_denoising_steps
self.default_processing_resolution = default_processing_resolution
if base_depth_model_uri is not None:
# load the original LR depth model
self.base_pipe = DiffusionPipeline.from_pretrained(
base_depth_model_uri,
variant=variant,
torch_dtype=self.dtype,
trust_remote_code=True,
)
self.base_pipe.to(self.device)
else:
self.base_pipe = None
self.empty_text_embed = None
@torch.no_grad()
def __call__(
self,
input_image: Union[Image.Image, torch.Tensor],
*,
        base_depth: Optional[
            Union[Image.Image, np.ndarray, torch.Tensor, MarigoldDepthOutput]
        ] = None,
denoising_steps: Optional[int] = None,
boosted_denoising_steps: Optional[int] = None,
ensemble_size: int = 10,
boosted_ensemble_size: int = 5,
processing_res: Optional[int] = None,
match_input_res: bool = True,
resample_method: str = "bilinear",
batch_size: int = 0,
show_progress_bar: bool = True,
ensemble_kwargs: Dict = None,
upscale_factor: int = 2,
) -> MarigoldDepthOutput:
"""
Function invoked when calling the pipeline.
Args:
input_image (`Image` or `torch.Tensor`):
Input RGB (or gray-scale) image.
base_depth (`Image`, `np.ndarray`, `torch.Tensor` or `MarigoldDepthOutput`, *optional*):
Base depth map to be used as a global guidance for the boosted inference.
            denoising_steps (`int`, *optional*, defaults to `None`):
                Number of diffusion denoising steps (DDIM) during the base inference. When `None`, the
                model's `default_denoising_steps` is used.
            boosted_denoising_steps (`int`, *optional*, defaults to `None`):
                Number of diffusion denoising steps during the boosted inference. When `None`, the model's
                `default_boosting_denoising_steps` is used.
            ensemble_size (`int`, *optional*, defaults to `10`):
                Number of predictions to be ensembled.
            boosted_ensemble_size (`int`, *optional*, defaults to `5`):
                Number of predictions to be ensembled in the boosted inference.
            processing_res (`int`, *optional*, defaults to `None`):
                Maximum resolution of processing. When `None`, the model's `default_processing_resolution`
                is used. If set to 0, the base prediction is computed at the input resolution and the
                boosted passes fall back to a patch resolution of 768.
            match_input_res (`bool`, *optional*, defaults to `True`):
                Resize depth prediction to match input resolution.
                Only valid if `processing_res` > 0.
            resample_method (`str`, *optional*, defaults to `bilinear`):
                Resampling method used to resize images and depth predictions. This can be one of
                `bilinear`, `bicubic` or `nearest`.
            batch_size (`int`, *optional*, defaults to `0`):
                Inference batch size, no bigger than `ensemble_size`. If set to 0, the script will
                automatically decide the proper batch size.
            show_progress_bar (`bool`, *optional*, defaults to `True`):
                Display a progress bar of diffusion denoising.
            ensemble_kwargs (`dict`, *optional*, defaults to `None`):
                Arguments for detailed ensembling settings.
            upscale_factor (`int`, *optional*, defaults to `2`):
                Target upscaling factor applied to the processing resolution. Must be a power of 2; a value
                of 1 returns the base prediction without boosting.
Returns:
`MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
- **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
- **base_depth_np** (`np.ndarray`) Upsampled base depth map, with depth values in the range of [0, 1].
This is the depth map used as a global guidance for the boosted inference.
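        Example:
            Reusing a previously computed base prediction to skip the base inference (a sketch; the
            variable names are illustrative):
            >>> out = pipe(image, upscale_factor=2)
            >>> out_x4 = pipe(image, base_depth=out.base_depth_np, upscale_factor=4)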
"""
# Model-specific optimal default values leading to fast and reasonable results.
if denoising_steps is None:
denoising_steps = self.default_denoising_steps
if boosted_denoising_steps is None:
boosted_denoising_steps = self.default_boosting_denoising_steps
if processing_res is None:
processing_res = self.default_processing_resolution
        # Asserts
        assert processing_res >= 0, "Processing resolution must be non-negative."
        assert ensemble_size >= 1, "Ensemble size must be at least 1."
        assert boosted_ensemble_size >= 1, "Boosted ensemble size must be at least 1."
        assert upscale_factor >= 1, "Upscale factor must be at least 1."
        assert math.log2(
            upscale_factor
        ).is_integer(), "Upscale factor must be a power of 2."
        assert batch_size >= 0, "Batch size must be non-negative."
# Warnings
        if upscale_factor >= 8 and self.dtype == torch.float16:
            logging.warning(
                "Upscaling factors of 8 and above in half precision may lead to artifacts in the final prediction."
            )
        if upscale_factor >= 4 and isinstance(self.vae, AutoencoderTiny):
            logging.warning(
                "Upscaling factors of 4 and above with the Tiny VAE may lead to instabilities."
            )
        # Get the resolution of the input RGB image
        if isinstance(input_image, Image.Image):
            input_width, input_height = input_image.size
        else:
            input_height, input_width = input_image.shape[-2:]
# 1) get base prediction
if base_depth is not None:
# load into float32 np.ndarray
if isinstance(base_depth, Image.Image):
lowres = np.asarray(base_depth.convert("L"), dtype=np.float32)
elif isinstance(base_depth, torch.Tensor):
lowres = base_depth.squeeze().cpu().numpy().astype(np.float32)
elif isinstance(base_depth, np.ndarray):
lowres = base_depth.squeeze().astype(np.float32)
elif isinstance(base_depth, MarigoldDepthOutput):
lowres = base_depth.depth_np.astype(np.float32)
else:
raise TypeError(f"Unsupported base_depth type: {type(base_depth)}")
# *** min–max normalize to [0,1] ***
min_v, max_v = lowres.min(), lowres.max()
eps = 1e-8
if max_v - min_v > eps:
lowres = (lowres - min_v) / (max_v - min_v + eps)
else:
# flat image → all zeros
lowres = np.zeros_like(lowres, dtype=np.float32)
else:
assert self.base_pipe is not None
if self.base_pipe.device != self.device:
# Move the base pipe to the correct device
self.base_pipe.to(self.device)
base_out = self.base_pipe(
input_image,
num_inference_steps=denoising_steps,
ensemble_size=ensemble_size,
processing_resolution=processing_res,
match_input_resolution=False,
batch_size=1, # base inference is always done in batch size 1
# show_progress_bar=show_progress_bar,
resample_method_input=resample_method,
resample_method_output=resample_method,
ensembling_kwargs=ensemble_kwargs,
)
lowres = base_out.prediction[0,:,:,0] # [H, W]
base_out.depth_np = lowres
# 2) Upsample base for output
t = torch.from_numpy(lowres[None,None])
up = resize(t, (input_height, input_width),
interpolation=InterpolationMode.NEAREST_EXACT,
antialias=True)
base_depth_np_upsampled = up.squeeze().cpu().numpy()
# 3) If no boosting requested, return early
if upscale_factor == 1:
# If no upscaling is needed, return the base prediction
return MarigoldDepthOutput(
depth_np=lowres,
base_depth_np=base_depth_np_upsampled
)
# 4) Normalize and run boosted inference (unchanged)
global_pred = torch.from_numpy(lowres).to(self.device)
global_pred = (global_pred - global_pred.min()) / (global_pred.max() - global_pred.min())
        # Iterative refinement logic
        current_pred = global_pred
# Create a list of all upscale factors up to the target
upscale_factors = [2**i for i in range(1, int(math.log2(upscale_factor)) + 1)]
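        # For example, upscale_factor=8 yields upscale_factors=[2, 4, 8]: each pass refines the previous
        # prediction at twice the resolution, using it as the global guidance for the next boosted pass.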
        # precalculate patch dimensions
        if processing_res == 0:
            processing_res = 768
        df = min(
            processing_res / input_width, processing_res / input_height
        )
        patch_height = int(input_height * df)
        patch_width = int(input_width * df)
        # Pre-warn about any upscaling target that exceeds the input resolution by more than 10%
for factor in upscale_factors:
tw = patch_width * factor
th = patch_height * factor
if tw > input_width * 1.1 or th > input_height * 1.1:
logging.warning(
f"Warning: Attempting to upsample to {tw}×{th}, "
f"which exceeds the original input of {input_width}×{input_height}. "
"This technically works, but may lead to suboptimal results."
)
# 5) Perform iterative boosted inference
with tqdm(
total=len(upscale_factors),
desc=" Upscaling Progress",
unit="step",
leave=False,
) as pbar:
for current_factor in upscale_factors:
# Update the description with the current upscaling factor
pbar.set_description(f" Upscaling x{current_factor}")
# Determine if this is the final step
is_final_step = current_factor == upscale_factor
                # Perform a single boosted inference step
                boosted_output = self.boosted_inference(
                    input_image=input_image,
                    denoising_steps=boosted_denoising_steps,
                    ensemble_size=boosted_ensemble_size,
processing_res=processing_res,
match_input_res=match_input_res and is_final_step,
batch_size=batch_size,
resample_method=resample_method,
show_progress_bar=show_progress_bar,
ensemble_kwargs=ensemble_kwargs,
global_pred=current_pred,
upscale_factor=current_factor,
)
# Update predictions
current_pred = torch.from_numpy(boosted_output.depth_np)
# Clean up GPU memory
torch.cuda.empty_cache()
# Update the progress bar
pbar.update(1)
# Return the final output, and attach base depth map
out = boosted_output
out.base_depth_np = base_depth_np_upsampled
return out
def boosted_inference(
self,
        input_image: Union[Image.Image, torch.Tensor],
denoising_steps: int = 10,
ensemble_size: int = 10,
processing_res: int = 768,
match_input_res: bool = True,
batch_size: int = 0,
resample_method: str = "bilinear",
seed: Union[int, None] = None,
show_progress_bar: bool = True,
ensemble_kwargs: Dict = None,
global_pred: torch.Tensor = None,
upscale_factor: int = 2,
) -> MarigoldDepthOutput:
"""
Function invoked when calling the pipeline with boosted inference.
Args:
            input_image (`Image` or `torch.Tensor`):
                Input RGB image.
denoising_steps (`int`, *optional*, defaults to `10`):
Number of diffusion denoising steps (DDIM) during inference.
ensemble_size (`int`, *optional*, defaults to `10`):
Number of predictions to be ensembled.
processing_res (`int`, *optional*, defaults to `768`):
Maximum resolution of processing.
If set to 0: will not resize at all.
match_input_res (`bool`, *optional*, defaults to `True`):
Resize depth prediction to match input resolution.
Only valid if `processing_res` > 0.
            resample_method (`str`, *optional*, defaults to `bilinear`):
                Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`.
            batch_size (`int`, *optional*, defaults to `0`):
                Inference batch size, no bigger than `ensemble_size`. If set to 0, the script will automatically decide the proper batch size.
seed (`int`, *optional*, defaults to `None`):
Random seed for the diffusion process.
show_progress_bar (`bool`, *optional*, defaults to `True`):
Display a progress bar of diffusion denoising.
ensemble_kwargs (`dict`, *optional*, defaults to `None`):
Arguments for detailed ensembling settings.
global_pred (`torch.Tensor`):
Global depth map to be used as guidance.
upscale_factor (`int`, *optional*, defaults to `2`):
Upscale factor of the global depth map.
Returns:
            `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
            - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1].
"""
device = self.device
self._check_inference_step(denoising_steps)
resample_method: InterpolationMode = get_tv_resample_method(resample_method)
# Convert to torch tensor
if isinstance(input_image, Image.Image):
input_image = input_image.convert("RGB")
# convert to torch tensor [H, W, rgb] -> [rgb, H, W]
input_image = pil_to_tensor(input_image)
elif isinstance(input_image, torch.Tensor):
input_image = input_image.squeeze()
# pass
else:
raise TypeError(f"Unknown input type: {type(input_image) = }")
input_size = input_image.shape
assert (
3 == input_image.dim() and 3 == input_size[0]
), f"Wrong input shape {input_size}, expected [rgb, H, W]"
if isinstance(global_pred, torch.Tensor):
global_pred = global_pred.squeeze().unsqueeze(0).to(device)
else:
raise TypeError(f"Unknown global_pred type: {type(global_pred) = }")
        if processing_res == 0:
            # fall back to the default processing resolution
            processing_res = 768
df = min(
processing_res / input_image.shape[2], processing_res / input_image.shape[1]
)
patch_height = int(input_image.shape[1] * df)
patch_width = int(input_image.shape[2] * df)
        # round patch dimensions to multiples of 16 so the latent patch dimensions stay even
patch_height = round(patch_height / 16) * 16
patch_width = round(patch_width / 16) * 16
patch_size = patch_height, patch_width
global_size = (patch_size[0] * upscale_factor, patch_size[1] * upscale_factor)
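        # The RGB input and the guidance are brought to `global_size`; patches of `patch_size` are then
        # processed at the model's native resolution and blended back onto the larger canvas.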
if global_pred.shape[1:] != global_size:
global_pred = resize(
global_pred,
global_size,
interpolation=resample_method,
antialias=True,
)
        if input_image.shape[1:] != global_size:
input_image = resize(
input_image,
global_size,
interpolation=resample_method,
antialias=True,
).squeeze()
input_image = (
input_image.unsqueeze(0) / 255.0 * 2.0 - 1.0
) # [0, 255] -> [-1, 1]
input_image = input_image.to(self.dtype).to(device)
assert input_image.min() >= -1.0 and input_image.max() <= 1.0
global_pred = global_pred.to(self.dtype).to(device)
if batch_size > 0:
_bs = batch_size
else:
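            # With 50% patch overlap, an upscale factor f yields (2f - 1) patches per spatial dimension,
            # so the effective number of forward passes is ensemble_size * (2f - 1)^2.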
_bs = find_batch_size(
ensemble_size=ensemble_size
* (2 * upscale_factor - 1)
* (2 * upscale_factor - 1),
input_res=max(patch_size),
dtype=self.dtype,
)
# create a small buffer in z-dimension
global_pred = 0.9 * (global_pred * 2 - 1)
global_pred = (global_pred + 1) / 2
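        # The 0.9 shrink maps [0, 1] to [0.05, 0.95], leaving a small margin at the near and far planes
        # before the guidance is encoded.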
depth_pred, pred_uncert = self.multidiffusion_inference(
rgb_norm=input_image,
global_pred=global_pred,
num_inference_steps=denoising_steps,
patch_size=patch_size,
seed=seed,
show_pbar=show_progress_bar,
ensemble_size=ensemble_size,
batch_size=_bs,
ensemble_kwargs=ensemble_kwargs,
)
depth_pred = depth_pred.squeeze(0)
        # rescale to [0, 1]
min_d = torch.min(depth_pred)
max_d = torch.max(depth_pred)
depth_pred = (depth_pred - min_d) / (max_d - min_d)
if depth_pred.shape[1:] != input_size[1:] and match_input_res:
depth_pred = resize(
depth_pred.unsqueeze(0),
input_size[1:],
interpolation=resample_method,
antialias=True,
).squeeze()
depth_pred = depth_pred.squeeze().cpu().numpy()
depth_pred = depth_pred.clip(0, 1)
return MarigoldDepthOutput(
depth_np=depth_pred,
)
def _check_inference_step(self, n_step: int) -> None:
"""
Check if denoising step is reasonable
Args:
n_step (`int`): denoising steps
"""
        assert n_step >= 1, "Number of denoising steps must be at least 1."
def encode_empty_text(self):
"""
Encode text embedding for empty prompt
"""
prompt = ""
text_inputs = self.tokenizer(
prompt,
padding="do_not_pad",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
@torch.no_grad()
def multidiffusion_inference(
self,
rgb_norm: torch.Tensor,
global_pred: torch.Tensor,
num_inference_steps: int,
seed: Union[int, None],
show_pbar: bool,
patch_size=(512, 768),
encoder_patch_size=None,
ensemble_size=1,
batch_size=1,
ensemble_kwargs: Dict = None,
) -> torch.Tensor:
"""
Perform an individual depth prediction without ensembling.
Args:
rgb_norm (`torch.Tensor`):
Input RGB image.
num_inference_steps (`int`):
Number of diffusion denoisign steps (DDIM) during inference.
num_patches_vert (`int`):
Number of vertical patches.
num_patches_horz (`int`):
Number of horizontal patches.
step_height (`int`):
Height of the patch.
step_width (`int`):
Width of the patch.
Returns:
`torch.Tensor`: Predicted depth map.
`torch.Tensor`: Uncertainty map.
"""
device = self.device
rgb_norm = rgb_norm.to(device)
global_pred = global_pred.to(device)
# Set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
latent_patch_size = (patch_size[0] // 8, patch_size[1] // 8)
latent_full_size = (rgb_norm.shape[-2] // 8, rgb_norm.shape[-1] // 8)
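        # The VAE downsamples by a factor of 8, so latent patch and canvas sizes are 1/8 of the pixel sizes.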
# Normalize the global prediction to [-1, 1]
global_pred = global_pred * 2 - 1
# Encode rgb and depth map
if encoder_patch_size is None:
encoder_patch_size = patch_size
rgb_latent = self.encode_rgb_patched(
rgb_norm, patch_size=encoder_patch_size, show_pbar=show_pbar
)
global_pred_latent = self.encode_depth_patched(
global_pred, patch_size=encoder_patch_size, show_pbar=show_pbar
)
# offload to cpu for memory efficiency
rgb_norm = rgb_norm.cpu()
global_pred = global_pred.cpu()
        # Random generator for reproducible initial noise
if seed is None:
rand_num_generator = None
else:
rand_num_generator = torch.Generator(device=device)
rand_num_generator.manual_seed(seed)
# patch the input images
(
rgb_latent_patched,
num_patches_vert,
num_patches_horz,
step_height,
step_width,
) = self.extract_patches(rgb_latent[0], patch_size=latent_patch_size)
# also patch the global depth map
global_pred_latent_patched, _, _, _, _ = self.extract_patches(
global_pred_latent[0], patch_size=latent_patch_size
)
# Batched empty text embedding
if self.empty_text_embed is None:
self.encode_empty_text()
batch_empty_text_embed = self.empty_text_embed.repeat((batch_size, 1, 1)).to(
device
)
if hasattr(self, "pooled_empty_text_embeds"):
batch_pooled_empty_text_embed = self.pooled_empty_text_embeds.repeat(
(batch_size, 1, 1)
).to(device)
# enlarge the variable according to the ensemble size
if ensemble_size > 1:
rgb_latent_patched = rgb_latent_patched.repeat(ensemble_size, 1, 1, 1)
global_pred_latent_patched = global_pred_latent_patched.repeat(
ensemble_size, 1, 1, 1
)
batch_empty_text_embed = batch_empty_text_embed.repeat(ensemble_size, 1, 1)
if hasattr(self, "pooled_empty_text_embeds"):
batch_pooled_empty_text_embed = batch_pooled_empty_text_embed.repeat(
ensemble_size, 1, 1
)
# Initialize the canvas and split it to get identical noise on overlaps
depth_latent = torch.randn(
(ensemble_size, 4, latent_full_size[0], latent_full_size[1]),
device=device,
dtype=self.dtype,
generator=rand_num_generator,
)
(
depth_latent_patched,
num_patches_vert,
num_patches_horz,
step_height,
step_width,
) = self.extract_patches(depth_latent, patch_size=latent_patch_size)
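        # Multidiffusion denoising: at every step, predict noise for each patch, advance each patch with the
        # scheduler, blend the updated patch latents onto a shared canvas with distance-based weights, and
        # re-extract the patches from the blended canvas so that overlapping regions stay consistent.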
# Denoising loop
if show_pbar:
iterable = tqdm(
enumerate(timesteps),
total=len(timesteps),
leave=False,
desc=" " * 4 + "Diffusion denoising",
)
else:
iterable = enumerate(timesteps)
for _, t in iterable:
# 1. inference all the patches with unet
assert (
self.boosting_unet.conv_in.in_channels == 12
), "The input channels of the boosting unet must be 12."
unet_input = torch.cat(
[rgb_latent_patched, global_pred_latent_patched, depth_latent_patched],
dim=1,
)
# 2. Create a dataloader and predict the noise
dataset = TensorDataset(unet_input)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
noise_preds = []
for batch in tqdm(
loader,
leave=False,
disable=not show_pbar,
desc=" " * 6 + "UNet patch inference",
):
(unet_input,) = batch
noise_pred = self.boosting_unet(
unet_input,
t,
encoder_hidden_states=batch_empty_text_embed[: unet_input.shape[0]],
).sample
noise_preds.append(noise_pred)
noise_preds = torch.concat(noise_preds, dim=0)
# 3. Default ddim scheduler step for each patch
scheduler_out = self.scheduler.step(
noise_preds, t, depth_latent_patched, generator=rand_num_generator
)
# 4. Reshape patches to patch-spatial dimension
depth_latent = scheduler_out.prev_sample.reshape(
ensemble_size,
num_patches_vert,
num_patches_horz,
4,
latent_patch_size[0],
latent_patch_size[1],
)
# 5. Blend the patches with multidiffusion formula
depth_latent_full = self.blend_patches(
depth_latent,
canvas_size=global_pred_latent.shape[2:],
num_patches_vert=num_patches_vert,
num_patches_horz=num_patches_horz,
step_height=step_height,
step_width=step_width,
)
# 6. Update the depth_latent_patched
if t != timesteps[-1]:
depth_latent_patched, _, _, _, _ = self.extract_patches(
depth_latent_full, patch_size=latent_patch_size
)
        # After the final denoising step, decode the full blended latent into a depth map
depth = self.decode_depth_patched(
depth_latent_full=depth_latent_full,
canvas_size=global_pred.shape[1:],
latent_decoder_patch_size=(
encoder_patch_size[0] // 8,
encoder_patch_size[1] // 8,
),
show_pbar=show_pbar,
ensemble_size=ensemble_size,
)
if ensemble_size > 1:
depth, pred_uncert = ensemble_depth(
depth,
scale_invariant=self.scale_invariant,
shift_invariant=self.shift_invariant,
**(ensemble_kwargs or {}),
)
else:
depth = (depth + 1.0) / 2.0
pred_uncert = None
return depth, pred_uncert
def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
"""
Encode RGB image into latent.
Args:
rgb_in (`torch.Tensor`):
Input RGB image to be encoded.
Returns:
`torch.Tensor`: Image latent.
"""
if isinstance(self.vae, AutoencoderTiny):
rgb_latent = self.vae.encoder(rgb_in)
else:
h = self.vae.encoder(rgb_in)
moments = self.vae.quant_conv(h)
mean, _ = torch.chunk(moments, 2, dim=1)
rgb_latent = mean * self.latent_scale_factor
return rgb_latent
def encode_rgb_patched(
self, rgb_in: torch.Tensor, patch_size: tuple, show_pbar: bool = False
) -> torch.Tensor:
"""
Encode depth map into latent.
Args:
depth_in (`torch.Tensor`):
Input depth map to be encoded.
patch_size (`Tuple[int, int]`):
Size of the patch.
Returns:
`torch.Tensor`: Depth latent.
"""
device = self.device
(
rgb_patched,
num_patches_vert_pix,
num_patches_horz_pix,
step_height_pix,
step_width_pix,
) = self.extract_patches(rgb_in.squeeze(0), patch_size=patch_size)
rgb_in = rgb_in.cpu()
rgb_patched = rgb_patched.cpu()
# forward the patches
rgb_latent_patched = []
for rgb in tqdm(
rgb_patched,
leave=False,
disable=not show_pbar,
desc=" " * 4 + "Encoding RGB",
):
patch = rgb.unsqueeze(0).to(device)
rgb_latent_patched.append(self.encode_rgb(patch).cpu())
# reshape the spatial dimensions
rgb_latent_patched = torch.concat(rgb_latent_patched, dim=0)
rgb_latent_patched = rgb_latent_patched.reshape(
num_patches_vert_pix,
num_patches_horz_pix,
4,
rgb_latent_patched.shape[-2],
rgb_latent_patched.shape[-1],
).to(device)
# blend the patches
rgb_latent = self.blend_patches(
rgb_latent_patched,
canvas_size=(rgb_in.shape[-2] // 8, rgb_in.shape[-1] // 8),
num_patches_vert=num_patches_vert_pix,
num_patches_horz=num_patches_horz_pix,
step_height=step_height_pix // 8,
step_width=step_width_pix // 8,
)
if len(rgb_latent.shape) == 3:
rgb_latent = rgb_latent.unsqueeze(0)
return rgb_latent
def encode_depth_patched(
self,
depth_in: torch.Tensor,
patch_size,
show_pbar: bool = False,
) -> torch.Tensor:
"""
Encode depth map into latent, but in a patched and scalable way
Args:
depth_in (`torch.Tensor`):
Input depth map to be encoded.
patch_size (`Tuple[int, int]`):
Size of the patch.
show_pbar (`bool`):
Display a progress bar.
Returns:
`torch.Tensor`: Depth latent.
"""
device = self.device
ensemble_size = depth_in.shape[0]
(
depth_in_patched,
num_patches_vert_pix,
num_patches_horz_pix,
step_height_pix,
step_width_pix,
) = self.extract_patches(depth_in, patch_size=patch_size)
depth_in = depth_in.cpu()
depth_in_patched = depth_in_patched.cpu()
# forward the patches
depth_latent_patched = []
for gpred in tqdm(
depth_in_patched,
leave=False,
disable=not show_pbar,
desc=" " * 4 + "Encoding context depth",
):
patch = gpred.unsqueeze(0).to(device)
depth_latent_patched.append(self.encode_depth(patch).cpu())
# reshape the spatial dimensions
depth_latent_patched = torch.concat(depth_latent_patched, dim=0)
depth_latent_patched = depth_latent_patched.reshape(
ensemble_size,
num_patches_vert_pix,
num_patches_horz_pix,
4,
depth_latent_patched.shape[-2],
depth_latent_patched.shape[-1],
).to(device)
# blend the patches
depth_latent = self.blend_patches(
depth_latent_patched,
canvas_size=(depth_in.shape[-2] // 8, depth_in.shape[-1] // 8),
num_patches_vert=num_patches_vert_pix,
num_patches_horz=num_patches_horz_pix,
step_height=step_height_pix // 8,
step_width=step_width_pix // 8,
)
if len(depth_latent.shape) == 3:
depth_latent = depth_latent.unsqueeze(0)
return depth_latent
def decode_depth_patched(
self,
depth_latent_full: torch.Tensor,
canvas_size: tuple,
latent_decoder_patch_size: tuple,
show_pbar: bool = True,
ensemble_size: int = 1,
) -> torch.Tensor:
"""
Decode depth map from latent in a patched and scalable way.
Args:
depth_latent_full (`torch.Tensor`):
Depth latent to be decoded.
canvas_size (`tuple`):
Size of the canvas.
latent_decoder_patch_size (`tuple`):
Size of the patch.
show_pbar (`bool`):
Display a progress bar.
ensemble_size (`int`):
Ensemble size.
Returns:
`torch.Tensor`: Decoded depth map.
"""
encoder_patch_size = (
latent_decoder_patch_size[0] * 8,
latent_decoder_patch_size[1] * 8,
)
# extract patches
(
depth_latent_patched,
num_patches_vert,
num_patches_horz,
step_height,
step_width,
) = self.extract_patches(
depth_latent_full, patch_size=latent_decoder_patch_size, overlap=0.5
)
# decode patches
depthp = []
for patch in tqdm(
depth_latent_patched,
leave=False,
desc=" " * 6 + "Decoding Depth",
disable=not show_pbar,
):
depthp.append(self.decode_depth(patch.unsqueeze(0)))
depthp = torch.concat(depthp, dim=0)
depthp = depthp.reshape(
ensemble_size,
num_patches_vert,
num_patches_horz,
1,
encoder_patch_size[0],
encoder_patch_size[1],
)
# blend together
depth = self.blend_patches(
depthp,
canvas_size=canvas_size,
num_patches_vert=num_patches_vert,
num_patches_horz=num_patches_horz,
step_height=step_height * 8,
step_width=step_width * 8,
)
return depth
def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
"""
Decode depth latent into depth map.
Args:
depth_latent (`torch.Tensor`):
Depth latent to be decoded.
Returns:
`torch.Tensor`: Decoded depth map of the input depth latent.
"""
# if self.using_tiny_vae:
if isinstance(self.vae, AutoencoderTiny):
stacked = self.vae.decoder(depth_latent)
else:
depth_latent = depth_latent / self.latent_scale_factor
z = self.vae.post_quant_conv(depth_latent)
stacked = self.vae.decoder(z)
# mean of output channels
depth_mean = stacked.mean(dim=1, keepdim=True)
return depth_mean
def encode_depth(
self,
depth_in: torch.Tensor,
) -> torch.Tensor:
"""
Encode depth map into latent.
Args:
depth_in (`torch.Tensor`):
Input depth map to be encoded.
Returns:
`torch.Tensor`: Depth latent of the input depth map.
"""
# stack depth into 3-channel
stacked = self.stack_depth_images(depth_in)
# encode using VAE encoder
depth_latent = self.encode_rgb(stacked)
return depth_latent
    @staticmethod
    def stack_depth_images(depth_in):
        # Replicate the single-channel depth map into 3 channels so it can be encoded by the RGB VAE.
        if 4 == len(depth_in.shape):
            stacked = depth_in.repeat(1, 3, 1, 1)
        elif 3 == len(depth_in.shape):
            stacked = depth_in.unsqueeze(1).repeat(1, 3, 1, 1)
        else:
            raise ValueError(f"Unexpected depth shape: {depth_in.shape}")
        return stacked
def extract_patches(self, canvas_input_image, patch_size, overlap=0.5):
"""
Extract patches from an image
Args:
image (`np.ndarray`): Input image of shape [channels, height, width]
patch_size (`tuple`): Size of the patch
step_size (`int`): Step size
Returns:
input_image_patched (`torch.Tensor`): Extracted patches
num_patches_vert (`int`): Number of patches in the vertical direction
num_patches_horz (`int`): Number of patches in the horizontal direction
step_height (`int`): Step size in the vertical direction
step_width (`int`): Step size in the horizontal direction
"""
if len(canvas_input_image.shape) == 4:
ensemble_size = canvas_input_image.shape[0]
else:
canvas_input_image = canvas_input_image.unsqueeze(0)
ensemble_size = 1
        # Step sizes: patches advance by (1 - overlap) of the patch dimensions
h, w = patch_size
step_height, step_width = int(h * (1 - overlap)), int(w * (1 - overlap))
# Calculate the number of patches to extract in both dimensions
num_patches_vert = (canvas_input_image.shape[-2] - h) // step_height + 1
num_patches_horz = (canvas_input_image.shape[-1] - w) // step_width + 1
# Initialize a list to hold the patches
patches = []
for e in range(ensemble_size):
for i in range(num_patches_vert):
for j in range(num_patches_horz):
# Calculate the top left corner of the current patch
start_y = i * step_height
start_x = j * step_width
# Extract the patch
patch = canvas_input_image[
e, :, start_y : start_y + h, start_x : start_x + w
]
patches.append(patch)
# Stack the patches and return
input_image_patched = torch.stack(patches)
return (
input_image_patched,
num_patches_vert,
num_patches_horz,
step_height,
step_width,
)
def blend_patches(
self,
depth_preds,
num_patches_vert,
num_patches_horz,
step_height,
step_width,
global_depth_pred=None,
canvas_size=None,
alpha_center=1.0,
alpha_edge=1e-4,
noise_blend=False,
eps=1e-8,
):
"""
Blend patches of depth maps and apply the transformation to the global depth map
Args:
global_depth_pred (`torch.Tensor`): Global depth map
depth_preds (`torch.Tensor`): Local depth maps
num_patches_vert (`int`): Number of patches in the vertical direction
num_patches_horz (`int`): Number of patches in the horizontal direction
step_height (`int`): Step size in the vertical direction
step_width (`int`): Step size in the horizontal direction
overlap (`float`): Overlap between patches
alpha_center (`float`): Weight at the center of the patch
alpha_edge (`float`): Weight at the edge of the patch, should not be exactly 0 to avoid division by zero
adjust_LSQ (`bool`): Adjust the transformation using least squares optimization
Returns:
`torch.Tensor`: Blended depth map
"""
        channels, h, w = (
            depth_preds.shape[-3],
            depth_preds.shape[-2],
            depth_preds.shape[-1],
        )
        if len(depth_preds.shape) == 6:
            ensemble_size = depth_preds.shape[0]
elif len(depth_preds.shape) == 5:
ensemble_size = 1
depth_preds = depth_preds.unsqueeze(0)
else:
raise ValueError("depth_preds should have 5 or 6 dimensions")
# Initialize the canvas for blending depth maps
blended_depth_map = torch.zeros(
(ensemble_size, channels, canvas_size[0], canvas_size[1]),
device=depth_preds.device,
dtype=depth_preds.dtype,
)
denominator_map = torch.zeros(
(1, canvas_size[0], canvas_size[1]),
device=depth_preds.device,
dtype=depth_preds.dtype,
)
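        # Weighted blending: canvas(x) = sum_i w_i(x) * patch_i(x) / sum_i w_i(x), where w_i peaks at the
        # patch center and decays towards its edges.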
for i in range(num_patches_vert):
for j in range(num_patches_horz):
# Calculate the top left corner of the current patch
start_y = i * step_height
start_x = j * step_width
# Extract the patch
local_patch = depth_preds[:, i, j]
# Generate blending weights
weights = self.get_linear_weight_map(
h,
w,
device=local_patch.device,
alpha_center=alpha_center,
alpha_edge=alpha_edge,
cosine_blending=True,
margin=0.0,
)
# accumulate the local patch to the canvas as a linear combination
blended_depth_map[
:, :, start_y : start_y + h, start_x : start_x + w
] += (local_patch * weights)
if noise_blend:
denominator_weights = weights**2
else:
denominator_weights = weights
denominator_map[
:, start_y : start_y + h, start_x : start_x + w
] += denominator_weights
if noise_blend:
denominator_map = torch.sqrt(denominator_map)
blended_depth_map /= denominator_map + eps
# Ensure that blended_depth_map does not have NaN values
# by filling them with the global_depth_pred
if global_depth_pred is not None:
blended_depth_map[torch.isnan(blended_depth_map)] = (
global_depth_pred.repeat(ensemble_size, 1, 1, 1)[
torch.isnan(blended_depth_map)
]
)
return blended_depth_map
def get_linear_weight_map(
self,
h,
w,
device,
alpha_center=1.0,
alpha_edge=1e-4,
margin=0.0,
cosine_blending=False,
):
"""
Generate a linear weight map for blending patches
Args:
h (`int`): Height of the weight map
w (`int`): Width of the weight map
device (`torch.device`): Device to use
alpha_center (`float`): Weight at the center of the patch
alpha_edge (`float`): Weight at the edge of the patch
margin (`int`): Perceptage of image dimensions. This margin at the edges to be filled with near 0 values.
Returns:
`torch.Tensor`: Linear distance weight map (looks like a pyramid)
"""
x = torch.linspace(-1, 1, h, device=device)
y = torch.linspace(-1, 1, w, device=device)
xx, yy = torch.meshgrid(x, y, indexing="ij")
dist = torch.stack([xx.abs(), yy.abs()]).max(dim=0).values
norm_dist = dist / torch.max(dist)
# Clamp the distance to the margin
norm_dist = torch.clamp(norm_dist + margin, 0, 1)
# scale to 0 to 1
mindist = torch.min(norm_dist)
maxdist = torch.max(norm_dist)
norm_dist = (norm_dist - mindist) / (maxdist - mindist)
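        # Weight falloff from center (norm_dist = 0) to edge (norm_dist = 1):
        #   cosine: w = alpha_edge + (alpha_center - alpha_edge) * 0.5 * (1 + cos(pi * norm_dist))
        #   linear: w = alpha_edge + (alpha_center - alpha_edge) * (1 - norm_dist)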
# Apply a cosine-based blending function for smooth transition
if cosine_blending:
weights = alpha_edge + (alpha_center - alpha_edge) * (
0.5 * (1 + torch.cos(norm_dist * math.pi))
)
else:
weights = alpha_edge + (alpha_center - alpha_edge) * (1 - norm_dist)
return weights
def get_tv_resample_method(method_str: str) -> InterpolationMode:
resample_method_dict = {
"bilinear": InterpolationMode.BILINEAR,
"bicubic": InterpolationMode.BICUBIC,
"nearest": InterpolationMode.NEAREST_EXACT,
"nearest-exact": InterpolationMode.NEAREST_EXACT,
}
    resample_method = resample_method_dict.get(method_str, None)
    if resample_method is None:
        raise ValueError(f"Unknown resampling method: {method_str}")
else:
return resample_method
def resize_max_res(
img: torch.Tensor,
max_edge_resolution: int,
resample_method: InterpolationMode = InterpolationMode.BILINEAR,
) -> torch.Tensor:
"""
Resize image to limit maximum edge length while keeping aspect ratio.
Args:
img (`torch.Tensor`):
Image tensor to be resized. Expected shape: [B, C, H, W]
max_edge_resolution (`int`):
Maximum edge length (pixel).
resample_method (`PIL.Image.Resampling`):
Resampling method used to resize images.
Returns:
`torch.Tensor`: Resized image.
"""
assert 4 == img.dim(), f"Invalid input shape {img.shape}"
original_height, original_width = img.shape[-2:]
downscale_factor = min(
max_edge_resolution / original_width, max_edge_resolution / original_height
)
new_width = int(original_width * downscale_factor)
new_height = int(original_height * downscale_factor)
resized_img = resize(img, (new_height, new_width), resample_method, antialias=True)
return resized_img
def ensemble_depth(
depth: torch.Tensor,
scale_invariant: bool = True,
shift_invariant: bool = True,
output_uncertainty: bool = False,
reduction: str = "median",
regularizer_strength: float = 0.02,
max_iter: int = 50,
tol: float = 1e-6,
max_res: int = 1024,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Ensembles depth maps represented by the `depth` tensor with expected shape `(B, 1, H, W)`, where B is the
number of ensemble members for a given prediction of size `(H x W)`. Even though the function is designed for
depth maps, it can also be used with disparity maps as long as the input tensor values are non-negative. The
alignment happens when the predictions have one or more degrees of freedom, that is when they are either
affine-invariant (`scale_invariant=True` and `shift_invariant=True`), or just scale-invariant (only
`scale_invariant=True`). For absolute predictions (`scale_invariant=False` and `shift_invariant=False`)
alignment is skipped and only ensembling is performed.
Args:
depth (`torch.Tensor`):
Input ensemble depth maps.
scale_invariant (`bool`, *optional*, defaults to `True`):
Whether to treat predictions as scale-invariant.
shift_invariant (`bool`, *optional*, defaults to `True`):
Whether to treat predictions as shift-invariant.
output_uncertainty (`bool`, *optional*, defaults to `False`):
Whether to output uncertainty map.
reduction (`str`, *optional*, defaults to `"median"`):
Reduction method used to ensemble aligned predictions. The accepted values are: `"mean"` and
`"median"`.
regularizer_strength (`float`, *optional*, defaults to `0.02`):
Strength of the regularizer that pulls the aligned predictions to the unit range from 0 to 1.
        max_iter (`int`, *optional*, defaults to `50`):
            Maximum number of the alignment solver steps. Refer to `scipy.optimize.minimize` function, `options`
            argument.
        tol (`float`, *optional*, defaults to `1e-6`):
            Alignment solver tolerance. The solver stops when the tolerance is reached.
max_res (`int`, *optional*, defaults to `1024`):
Resolution at which the alignment is performed; `None` matches the `processing_resolution`.
Returns:
A tensor of aligned and ensembled depth maps and optionally a tensor of uncertainties of the same shape:
`(1, 1, H, W)`.
"""
if depth.dim() != 4 or depth.shape[1] != 1:
raise ValueError(f"Expecting 4D tensor of shape [B,1,H,W]; got {depth.shape}.")
if reduction not in ("mean", "median"):
raise ValueError(f"Unrecognized reduction method: {reduction}.")
if not scale_invariant and shift_invariant:
raise ValueError("Pure shift-invariant ensembling is not supported.")
def init_param(depth: torch.Tensor):
init_min = depth.reshape(ensemble_size, -1).min(dim=1).values
init_max = depth.reshape(ensemble_size, -1).max(dim=1).values
if scale_invariant and shift_invariant:
init_s = 1.0 / (init_max - init_min).clamp(min=1e-6)
init_t = -init_s * init_min
param = torch.cat((init_s, init_t)).cpu().numpy()
elif scale_invariant:
init_s = 1.0 / init_max.clamp(min=1e-6)
param = init_s.cpu().numpy()
else:
raise ValueError("Unrecognized alignment.")
return param.astype(np.float64)
def align(depth: torch.Tensor, param: np.ndarray) -> torch.Tensor:
if scale_invariant and shift_invariant:
s, t = np.split(param, 2)
s = torch.from_numpy(s).to(depth).view(ensemble_size, 1, 1, 1)
t = torch.from_numpy(t).to(depth).view(ensemble_size, 1, 1, 1)
out = depth * s + t
elif scale_invariant:
s = torch.from_numpy(param).to(depth).view(ensemble_size, 1, 1, 1)
out = depth * s
else:
raise ValueError("Unrecognized alignment.")
return out
def ensemble(
depth_aligned: torch.Tensor, return_uncertainty: bool = False
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
uncertainty = None
if reduction == "mean":
prediction = torch.mean(depth_aligned, dim=0, keepdim=True)
if return_uncertainty:
uncertainty = torch.std(depth_aligned, dim=0, keepdim=True)
elif reduction == "median":
prediction = torch.median(depth_aligned, dim=0, keepdim=True).values
if return_uncertainty:
uncertainty = torch.median(
torch.abs(depth_aligned - prediction), dim=0, keepdim=True
).values
else:
raise ValueError(f"Unrecognized reduction method: {reduction}.")
return prediction, uncertainty
def cost_fn(param: np.ndarray, depth: torch.Tensor) -> float:
cost = 0.0
depth_aligned = align(depth, param)
for i, j in torch.combinations(torch.arange(ensemble_size)):
diff = depth_aligned[i] - depth_aligned[j]
cost += (diff**2).mean().sqrt().item()
if regularizer_strength > 0:
prediction, _ = ensemble(depth_aligned, return_uncertainty=False)
err_near = (0.0 - prediction.min()).abs().item()
err_far = (1.0 - prediction.max()).abs().item()
cost += (err_near + err_far) * regularizer_strength
return cost
def compute_param(depth: torch.Tensor):
        import scipy.optimize
depth_to_align = depth.to(torch.float32)
if max_res is not None and max(depth_to_align.shape[2:]) > max_res:
depth_to_align = resize_max_res(
depth_to_align, max_res, get_tv_resample_method("nearest-exact")
)
param = init_param(depth_to_align)
res = scipy.optimize.minimize(
partial(cost_fn, depth=depth_to_align),
param,
method="BFGS",
tol=tol,
options={"maxiter": max_iter, "disp": False},
)
return res.x
requires_aligning = scale_invariant or shift_invariant
ensemble_size = depth.shape[0]
if requires_aligning:
param = compute_param(depth)
depth = align(depth, param)
depth, uncertainty = ensemble(depth, return_uncertainty=output_uncertainty)
depth_max = depth.max()
if scale_invariant and shift_invariant:
depth_min = depth.min()
elif scale_invariant:
depth_min = 0
else:
raise ValueError("Unrecognized alignment.")
depth_range = (depth_max - depth_min).clamp(min=1e-6)
depth = (depth - depth_min) / depth_range
if output_uncertainty:
uncertainty /= depth_range
return depth, uncertainty # [1,1,H,W], [1,1,H,W]
# Search table for suggested max. inference batch size
bs_search_table = [
# tested on A100-PCIE-80GB
{"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
{"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
# tested on A100-PCIE-40GB
{"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
{"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
{"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
{"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
# tested on RTX3090, RTX4090
{"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
{"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
{"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
{"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
{"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
{"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
# tested on GTX1080Ti
{"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
{"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
{"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
{"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
{"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
]
def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
"""
Automatically search for suitable operating batch size.
Args:
ensemble_size (`int`):
Number of predictions to be ensembled.
input_res (`int`):
Operating resolution of the input image.
Returns:
`int`: Operating batch size.
"""
if not torch.cuda.is_available():
return 1
total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
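    # Entries are scanned in order of increasing resolution (largest VRAM tier first within each resolution);
    # the first entry whose resolution covers the input and whose tested VRAM is available sets the batch size.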
for settings in sorted(
filtered_bs_search_table,
key=lambda k: (k["res"], -k["total_vram"]),
):
if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
bs = settings["bs"]
if bs > ensemble_size:
bs = ensemble_size
elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
bs = math.ceil(ensemble_size / 2)
return bs
return 1