File size: 61,683 Bytes
ea5665b 2637d94 ea5665b e6dd6dd ea5665b e6dd6dd ea5665b 2637d94 ea5665b 2637d94 ea5665b 2637d94 ea5665b c076e6d ea5665b c076e6d ea5665b a162382 ea5665b 1502fda ea5665b 2637d94 c076e6d 2637d94 c076e6d 2637d94 e6dd6dd c076e6d e6dd6dd c076e6d ea5665b c076e6d ea5665b 4404838 c076e6d ea5665b e6dd6dd 2637d94 e6dd6dd 2637d94 ea5665b a162382 ea5665b a4b6cb2 9973b4a a4b6cb2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 
1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 |
# Copyright 2023-2025 Marigold Team, ETH Zürich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# More information about Marigold:
# https://marigoldmonodepth.github.io
# https://marigoldcomputervision.github.io
# Efficient inference pipelines are now part of diffusers:
# https://huggingface.co/docs/diffusers/using-diffusers/marigold_usage
# https://huggingface.co/docs/diffusers/api/pipelines/marigold
# Examples of trained models and live demos:
# https://huggingface.co/prs-eth
# Related projects:
# https://rollingdepth.github.io/
# https://marigolddepthcompletion.github.io/
# Citation (BibTeX):
# https://github.com/prs-eth/Marigold#-citation
# If you find Marigold useful, we kindly ask you to cite our papers.
# --------------------------------------------------------------------------
import logging
import numpy as np
import torch
from typing import Dict, Union
import math
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DiffusionPipeline,
LCMScheduler,
UNet2DConditionModel,
AutoencoderTiny,
)
from diffusers.utils import BaseOutput
from PIL import Image
from torch.utils.data import DataLoader, TensorDataset
from torchvision.transforms.functional import resize, pil_to_tensor
from torchvision.transforms import InterpolationMode
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from functools import partial
from typing import Optional, Tuple
class MarigoldDepthOutput(BaseOutput):
    """
    Output class for the Marigold high-resolution monocular depth estimation pipeline.

    Args:
        depth_np (`np.ndarray`):
            Predicted depth map, with depth values in the range of [0, 1].
        base_depth_np (`np.ndarray`, *optional*):
            Base depth map upsampled to the input resolution, with depth values in the range of [0, 1].
            This is the depth map used as global guidance for the boosted inference and is useful for
            visualization and debugging. It is `None` when it has not been attached, e.g. on the value
            returned directly by `MarigoldDepthHRPipeline.boosted_inference`.
    """

    depth_np: np.ndarray
    # Class-level default: instances created without `base_depth_np` expose `None`
    # instead of raising `AttributeError` when the field is read.
    base_depth_np: Optional[np.ndarray] = None
class MarigoldDepthHRPipeline(DiffusionPipeline):
    """
    Pipeline for high resolution monocular depth estimation using Marigold: https://marigoldcomputervision.github.io.

    A low-resolution base depth map (either supplied by the caller or produced by a separately loaded base
    pipeline) is refined by a "boosting" U-Net that denoises overlapping latent patches at increasing upscale
    factors.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        unet (`UNet2DConditionModel`):
            Conditional U-Net to denoise the prediction latent, conditioned on image latent.
        vae (`AutoencoderKL` or `AutoencoderTiny`):
            Variational Auto-Encoder (VAE) Model to encode and decode images and predictions
            to and from latent representations.
        scheduler (`DDIMScheduler` or `LCMScheduler`):
            A scheduler to be used in combination with the U-Nets to denoise the encoded latents.
        text_encoder (`CLIPTextModel`):
            Text-encoder, for empty text embedding.
        tokenizer (`CLIPTokenizer`):
            CLIP tokenizer.
        boosting_unet (`UNet2DConditionModel` or `None`):
            Conditional U-Net to denoise the depth latent, conditioned on image latent and a global depth map.
            It may be `None` (a warning is logged), e.g. while training.
        scale_invariant (`bool`, *optional*):
            A model property specifying whether the predicted depth maps are scale-invariant. This value must be set in
            the model config. When used together with the `shift_invariant=True` flag, the model is also called
            "affine-invariant". NB: overriding this value is not supported.
        shift_invariant (`bool`, *optional*):
            A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in
            the model config. When used together with the `scale_invariant=True` flag, the model is also called
            "affine-invariant". NB: overriding this value is not supported.
        default_denoising_steps (`int`, *optional*):
            The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable
            quality with the given model. This value must be set in the model config. When the pipeline is called
            without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure
            reasonable results with various model flavors compatible with the pipeline, such as those relying on very
            short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`).
        default_boosting_denoising_steps (`int`, *optional*):
            Same as `default_denoising_steps` but for `boosting_unet`.
        default_processing_resolution (`int`, *optional*):
            The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
            the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
            default value is used. This is required to ensure reasonable results with various model flavors trained
            with varying optimal processing resolution values.
        base_depth_model_uri (`str`, *optional*):
            Id or path of a depth pipeline loaded with `DiffusionPipeline.from_pretrained(..., trust_remote_code=True)`.
            It computes the low-resolution base prediction when no `base_depth` is passed to `__call__`.
        variant (`str`, *optional*):
            Weights variant forwarded to `from_pretrained` when loading the base pipeline.
    """
    # Multiplier applied to the VAE posterior mean in `encode_rgb` for `AutoencoderKL`;
    # it is not applied to `AutoencoderTiny` latents.
    latent_scale_factor = 0.18215
def __init__(
    self,
    unet: UNet2DConditionModel,
    vae: AutoencoderKL,
    scheduler: Union[DDIMScheduler, LCMScheduler],
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    boosting_unet: Optional[UNet2DConditionModel],
    scale_invariant: Optional[bool] = True,
    shift_invariant: Optional[bool] = True,
    default_denoising_steps: Optional[int] = None,
    default_boosting_denoising_steps: Optional[int] = None,
    default_processing_resolution: Optional[int] = None,
    base_depth_model_uri: Optional[str] = None,
    variant: Optional[str] = None,
):
    """
    Register the sub-models and model-specific defaults, and optionally load the base depth pipeline.

    Args:
        unet, vae, scheduler, text_encoder, tokenizer, boosting_unet:
            Pipeline components, registered with `register_modules`. `boosting_unet` may be `None`
            (a warning is logged).
        scale_invariant, shift_invariant, default_denoising_steps,
        default_boosting_denoising_steps, default_processing_resolution:
            Model properties, stored in the pipeline config and mirrored as instance attributes.
        base_depth_model_uri (`str`, *optional*):
            Id or path passed to `DiffusionPipeline.from_pretrained` to build `self.base_pipe`; also stored
            in the config. When `None`, `self.base_pipe` is `None` and `__call__` needs an explicit `base_depth`.
        variant (`str`, *optional*):
            Weights variant forwarded to `from_pretrained` for the base pipeline.
    """
    super().__init__()

    if boosting_unet is None:
        logging.warning(
            "Boosting U-Net is not provided. If this message appears during training, it is expected."
        )

    self.register_modules(
        unet=unet,
        vae=vae,
        scheduler=scheduler,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        boosting_unet=boosting_unet,
    )
    # Every value must be stored under its own key. diffusers re-instantiates a saved
    # pipeline from this config, so a wrong entry silently changes the defaults after
    # a save/load round trip.
    self.register_to_config(
        scale_invariant=scale_invariant,
        shift_invariant=shift_invariant,
        default_denoising_steps=default_denoising_steps,
        default_boosting_denoising_steps=default_boosting_denoising_steps,
        default_processing_resolution=default_processing_resolution,
        base_depth_model_uri=base_depth_model_uri,
    )

    self.scale_invariant = scale_invariant
    self.shift_invariant = shift_invariant
    self.default_denoising_steps = default_denoising_steps
    self.default_boosting_denoising_steps = default_boosting_denoising_steps
    self.default_processing_resolution = default_processing_resolution

    if base_depth_model_uri is not None:
        # Load the original low-resolution depth model
        self.base_pipe = DiffusionPipeline.from_pretrained(
            base_depth_model_uri,
            variant=variant,
            torch_dtype=self.dtype,
            trust_remote_code=True,
        )
        self.base_pipe.to(self.device)
    else:
        self.base_pipe = None

    # Lazily computed by `encode_empty_text` on first use
    self.empty_text_embed = None
@torch.no_grad()
def __call__(
    self,
    input_image: Union[Image.Image, torch.Tensor],
    *,
    base_depth: Optional[
        Union[Image.Image, np.ndarray, torch.Tensor, MarigoldDepthOutput]
    ] = None,
    denoising_steps: Optional[int] = None,
    boosted_denoising_steps: Optional[int] = None,
    ensemble_size: int = 10,
    boosted_ensemble_size: int = 5,
    processing_res: Optional[int] = None,
    match_input_res: bool = True,
    resample_method: str = "bilinear",
    batch_size: int = 0,
    show_progress_bar: bool = True,
    ensemble_kwargs: Optional[Dict] = None,
    upscale_factor: int = 2,
) -> MarigoldDepthOutput:
    """
    Run high-resolution depth estimation.

    A low-resolution base depth map is obtained first, either from `base_depth` or by running the base
    pipeline. It is then refined by iterative boosted inference at upscale factors 2, 4, ..., `upscale_factor`.

    Args:
        input_image (`Image` or `torch.Tensor`):
            Input RGB (or gray-scale) image. Tensors are expected as `[3, H, W]` (a leading singleton batch
            dimension is allowed) with values in [0, 255].
        base_depth (`Image`, `np.ndarray`, `torch.Tensor` or `MarigoldDepthOutput`, *optional*):
            Base depth map to be used as global guidance for the boosted inference. It is min-max normalized
            to [0, 1]. High bit-depth PIL images (modes `I`, `I;16*`, `F`) are read without 8-bit conversion.
            If omitted, the base pipeline loaded from `base_depth_model_uri` computes it.
        denoising_steps (`int`, *optional*):
            Denoising steps of the base pipeline. Defaults to `default_denoising_steps`.
        boosted_denoising_steps (`int`, *optional*):
            Denoising steps of each boosted inference. Defaults to `default_boosting_denoising_steps`.
        ensemble_size (`int`, *optional*, defaults to `10`):
            Number of base predictions to be ensembled.
        boosted_ensemble_size (`int`, *optional*, defaults to `5`):
            Number of predictions to be ensembled in each boosted inference step.
        processing_res (`int`, *optional*):
            Processing resolution. Defaults to `default_processing_resolution`. A value of `0` is forwarded
            to the base pipeline as-is and treated as `768` by the boosted inference.
        match_input_res (`bool`, *optional*, defaults to `True`):
            Resize the final boosted prediction to match the input resolution.
        resample_method (`str`, *optional*, defaults to `bilinear`):
            Resampling method used to resize images and depth predictions. This can be one of `bilinear`,
            `bicubic` or `nearest`.
        batch_size (`int`, *optional*, defaults to `0`):
            Patch batch size of the boosted inference. If set to 0, it is chosen automatically.
        show_progress_bar (`bool`, *optional*, defaults to `True`):
            Display progress bars of the diffusion denoising.
        ensemble_kwargs (`dict`, *optional*, defaults to `None`):
            Arguments for detailed ensembling settings.
        upscale_factor (`int`, *optional*, defaults to `2`):
            Power of two. `1` returns the base prediction without boosting.

    Returns:
        `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
        - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
          (the raw base prediction when `upscale_factor == 1`).
        - **base_depth_np** (`np.ndarray`) Base depth map upsampled to the input resolution, with depth
          values in the range of [0, 1]. This is the map used as global guidance for the boosted inference.
    """
    # Model-specific optimal default values leading to fast and reasonable results.
    if denoising_steps is None:
        denoising_steps = self.default_denoising_steps
    if boosted_denoising_steps is None:
        boosted_denoising_steps = self.default_boosting_denoising_steps
    if processing_res is None:
        processing_res = self.default_processing_resolution

    # Asserts. The lower bound of `upscale_factor` is checked before `math.log2`,
    # so that 0 or negative values fail with a clear message instead of a
    # "math domain error".
    assert processing_res >= 0, "Processing resolution must be non-negative."
    assert ensemble_size >= 1, "Ensemble size must be at least 1."
    assert boosted_ensemble_size >= 1, "Boosted ensemble size must be at least 1."
    assert upscale_factor >= 1, "Upscale factor must be at least 1."
    assert math.log2(
        upscale_factor
    ).is_integer(), "Upscale factor must be a power of 2."
    assert batch_size >= 0, "Batch size must be non-negative."

    # Warnings
    if upscale_factor >= 8 and self.dtype == torch.float16:
        logging.warning(
            "Warning: Upscaling factors of 8 (and more) with half precision or more may lead to artifacts in the final prediction."
        )
    if upscale_factor >= 4 and isinstance(self.vae, AutoencoderTiny):
        logging.warning(
            "Warning: Upscaling factors of 4 (and more) with the Tiny VAE may lead to instabilities."
        )

    # Resolution of the input RGB image. PIL reports (W, H), whereas tensors are
    # laid out as [..., H, W], so the last two dimensions must be swapped.
    if isinstance(input_image, Image.Image):
        input_width, input_height = input_image.size
    else:
        input_height, input_width = input_image.shape[-2:]

    # 1) Get the base prediction
    if base_depth is not None:
        # Load into a float32 np.ndarray
        if isinstance(base_depth, Image.Image):
            if base_depth.mode in ("I", "F") or base_depth.mode.startswith("I;16"):
                # High bit-depth maps (e.g. 16-bit PNG depth): keep the raw values.
                # Converting them to "L" would clip everything above 255.
                lowres = np.asarray(base_depth, dtype=np.float32)
            else:
                lowres = np.asarray(base_depth.convert("L"), dtype=np.float32)
        elif isinstance(base_depth, torch.Tensor):
            lowres = base_depth.squeeze().cpu().numpy().astype(np.float32)
        elif isinstance(base_depth, np.ndarray):
            lowres = base_depth.squeeze().astype(np.float32)
        elif isinstance(base_depth, MarigoldDepthOutput):
            lowres = base_depth.depth_np.astype(np.float32)
        else:
            raise TypeError(f"Unsupported base_depth type: {type(base_depth)}")

        # Min-max normalize to [0, 1]; a flat map becomes all zeros
        min_v, max_v = lowres.min(), lowres.max()
        eps = 1e-8
        if max_v - min_v > eps:
            lowres = (lowres - min_v) / (max_v - min_v + eps)
        else:
            lowres = np.zeros_like(lowres, dtype=np.float32)
    else:
        assert self.base_pipe is not None, (
            "Either pass `base_depth` or create the pipeline with `base_depth_model_uri`."
        )
        if self.base_pipe.device != self.device:
            # Move the base pipe to the correct device
            self.base_pipe.to(self.device)
        base_out = self.base_pipe(
            input_image,
            num_inference_steps=denoising_steps,
            ensemble_size=ensemble_size,
            processing_resolution=processing_res,
            match_input_resolution=False,
            batch_size=1,  # base inference is always done in batch size 1
            resample_method_input=resample_method,
            resample_method_output=resample_method,
            ensembling_kwargs=ensemble_kwargs,
        )
        lowres = base_out.prediction[0, :, :, 0]  # [H, W]

    # 2) Upsample the base prediction to the input resolution for the output
    lowres_t = torch.from_numpy(lowres[None, None])
    upsampled = resize(
        lowres_t,
        (input_height, input_width),
        interpolation=InterpolationMode.NEAREST_EXACT,
        antialias=True,
    )
    base_depth_np_upsampled = upsampled.squeeze().cpu().numpy()

    # 3) If no boosting is requested, return the base prediction
    if upscale_factor == 1:
        return MarigoldDepthOutput(
            depth_np=lowres,
            base_depth_np=base_depth_np_upsampled,
        )

    # 4) Normalize the global guidance to [0, 1]; guard against a flat map (0 / 0 -> NaN)
    global_pred = torch.from_numpy(lowres).to(self.device)
    g_min, g_max = global_pred.min(), global_pred.max()
    if g_max > g_min:
        global_pred = (global_pred - g_min) / (g_max - g_min)
    else:
        global_pred = torch.zeros_like(global_pred)

    # All upscale factors up to the target: 2, 4, ..., upscale_factor
    upscale_factors = [2**i for i in range(1, int(math.log2(upscale_factor)) + 1)]

    # Patch dimensions, used here only to warn about targets larger than the input
    if processing_res == 0:
        processing_res = 768
    df = min(processing_res / input_width, processing_res / input_height)
    patch_height = int(input_height * df)
    patch_width = int(input_width * df)

    # Pre-warn about any factor that exceeds the 1.1x threshold
    for factor in upscale_factors:
        tw = patch_width * factor
        th = patch_height * factor
        if tw > input_width * 1.1 or th > input_height * 1.1:
            logging.warning(
                f"Warning: Attempting to upsample to {tw}×{th}, "
                f"which exceeds the original input of {input_width}×{input_height}. "
                "This technically works, but may lead to suboptimal results."
            )

    # 5) Iterative boosted inference: each step refines the previous prediction
    current_pred = global_pred
    boosted_output = None
    with tqdm(
        total=len(upscale_factors),
        desc=" Upscaling Progress",
        unit="step",
        leave=False,
    ) as pbar:
        for current_factor in upscale_factors:
            pbar.set_description(f" Upscaling x{current_factor}")
            # Only the final step is resized to the input resolution (if requested)
            is_final_step = current_factor == upscale_factor
            boosted_output = self.boosted_inference(
                input_image=input_image,
                denoising_steps=boosted_denoising_steps,
                ensemble_size=boosted_ensemble_size,
                processing_res=processing_res,
                match_input_res=match_input_res and is_final_step,
                batch_size=batch_size,
                resample_method=resample_method,
                show_progress_bar=show_progress_bar,
                ensemble_kwargs=ensemble_kwargs,
                global_pred=current_pred,
                upscale_factor=current_factor,
            )
            current_pred = torch.from_numpy(boosted_output.depth_np)
            # Release cached GPU memory between steps
            torch.cuda.empty_cache()
            pbar.update(1)

    # Return the final output, with the upsampled base depth map attached
    out = boosted_output
    out.base_depth_np = base_depth_np_upsampled
    return out
def boosted_inference(
    self,
    input_image: Union[Image.Image, torch.Tensor],
    denoising_steps: int = 10,
    ensemble_size: int = 10,
    processing_res: int = 768,
    match_input_res: bool = True,
    batch_size: int = 0,
    resample_method: str = "bilinear",
    seed: Union[int, None] = None,
    show_progress_bar: bool = True,
    ensemble_kwargs: Dict = None,
    global_pred: torch.Tensor = None,
    upscale_factor: int = 2,
) -> MarigoldDepthOutput:
    """
    Run a single boosted inference step at the given upscale factor.

    Args:
        input_image (`Image` or `torch.Tensor`):
            Input RGB image. A tensor is squeezed and must then be `[3, H, W]` with values in [0, 255].
        denoising_steps (`int`, *optional*, defaults to `10`):
            Number of diffusion denoising steps during inference.
        ensemble_size (`int`, *optional*, defaults to `10`):
            Number of predictions to be ensembled.
        processing_res (`int`, *optional*, defaults to `768`):
            Resolution of the longer side of a denoising patch before rounding to a multiple of 16.
            `0` is treated as `768`.
        match_input_res (`bool`, *optional*, defaults to `True`):
            Resize the depth prediction to match the input resolution.
        batch_size (`int`, *optional*, defaults to `0`):
            Patch batch size. If set to 0, it is chosen automatically by `find_batch_size`.
        resample_method: (`str`, *optional*, defaults to `bilinear`):
            Resampling method used to resize images and depth predictions. This can be one of `bilinear`,
            `bicubic` or `nearest`.
        seed (`int`, *optional*, defaults to `None`):
            Random seed for the diffusion process.
        show_progress_bar (`bool`, *optional*, defaults to `True`):
            Display a progress bar of diffusion denoising.
        ensemble_kwargs (`dict`, *optional*, defaults to `None`):
            Arguments for detailed ensembling settings.
        global_pred (`torch.Tensor`):
            Global depth map in [0, 1] used as guidance.
        upscale_factor (`int`, *optional*, defaults to `2`):
            Upscale factor of the global depth map relative to the patch size.

    Returns:
        `MarigoldDepthOutput` whose **depth_np** (`np.ndarray`) is the predicted depth map in [0, 1]
        (`base_depth_np` is not set here).

    Raises:
        TypeError: If `input_image` or `global_pred` has an unsupported type.
    """
    device = self.device
    self._check_inference_step(denoising_steps)
    interpolation: InterpolationMode = get_tv_resample_method(resample_method)

    # Convert the input to a [3, H, W] tensor
    if isinstance(input_image, Image.Image):
        # PIL [H, W, rgb] -> tensor [rgb, H, W]
        input_image = pil_to_tensor(input_image.convert("RGB"))
    elif isinstance(input_image, torch.Tensor):
        input_image = input_image.squeeze()
    else:
        raise TypeError(f"Unknown input type: {type(input_image) = }")
    input_size = input_image.shape
    assert (
        3 == input_image.dim() and 3 == input_size[0]
    ), f"Wrong input shape {input_size}, expected [rgb, H, W]"

    if isinstance(global_pred, torch.Tensor):
        # Bring the guidance to a [1, H, W] layout
        global_pred = global_pred.squeeze().unsqueeze(0).to(device)
    else:
        raise TypeError(f"Unknown global_pred type: {type(global_pred) = }")

    # A processing resolution of 0 falls back to a fixed 768 px patch size
    if processing_res == 0:
        processing_res = 768
    df = min(
        processing_res / input_image.shape[2], processing_res / input_image.shape[1]
    )
    patch_height = int(input_image.shape[1] * df)
    patch_width = int(input_image.shape[2] * df)
    # Round to the nearest multiple of 16 (at least 16), so that the patches and
    # their 1/8 latent counterparts have integral, non-empty sizes.
    patch_height = max(16, round(patch_height / 16) * 16)
    patch_width = max(16, round(patch_width / 16) * 16)
    patch_size = (patch_height, patch_width)
    global_size = (patch_height * upscale_factor, patch_width * upscale_factor)

    # RGB and guidance must share the same `global_size` canvas. Both checks compare
    # against the resize target itself. Comparing the RGB against `patch_size` would
    # skip the resize for inputs already at patch resolution and leave the two
    # canvases mismatched.
    if global_pred.shape[1:] != global_size:
        global_pred = resize(
            global_pred,
            global_size,
            interpolation=interpolation,
            antialias=True,
        )
    if input_image.shape[1:] != global_size:
        input_image = resize(
            input_image,
            global_size,
            interpolation=interpolation,
            antialias=True,
        ).squeeze()

    # [0, 255] -> [-1, 1]
    input_image = input_image.unsqueeze(0) / 255.0 * 2.0 - 1.0
    input_image = input_image.to(self.dtype).to(device)
    assert input_image.min() >= -1.0 and input_image.max() <= 1.0
    global_pred = global_pred.to(self.dtype).to(device)

    if batch_size > 0:
        _bs = batch_size
    else:
        _bs = find_batch_size(
            ensemble_size=ensemble_size
            * (2 * upscale_factor - 1)
            * (2 * upscale_factor - 1),
            input_res=max(patch_size),
            dtype=self.dtype,
        )

    # Create a small buffer in the z-dimension: [0, 1] -> [0.05, 0.95]
    global_pred = 0.9 * (global_pred * 2 - 1)
    global_pred = (global_pred + 1) / 2

    depth_pred, _pred_uncert = self.multidiffusion_inference(
        rgb_norm=input_image,
        global_pred=global_pred,
        num_inference_steps=denoising_steps,
        patch_size=patch_size,
        seed=seed,
        show_pbar=show_progress_bar,
        ensemble_size=ensemble_size,
        batch_size=_bs,
        ensemble_kwargs=ensemble_kwargs,
    )
    depth_pred = depth_pred.squeeze(0)

    # Rescale to [0, 1]; a flat prediction would otherwise become NaN (0 / 0)
    min_d = torch.min(depth_pred)
    max_d = torch.max(depth_pred)
    if max_d > min_d:
        depth_pred = (depth_pred - min_d) / (max_d - min_d)
    else:
        depth_pred = torch.zeros_like(depth_pred)

    if depth_pred.shape[1:] != input_size[1:] and match_input_res:
        depth_pred = resize(
            depth_pred.unsqueeze(0),
            input_size[1:],
            interpolation=interpolation,
            antialias=True,
        ).squeeze()

    depth_pred = depth_pred.squeeze().cpu().numpy()
    depth_pred = depth_pred.clip(0, 1)

    return MarigoldDepthOutput(
        depth_np=depth_pred,
    )
def _check_inference_step(self, n_step: int) -> None:
"""
Check if denoising step is reasonable
Args:
n_step (`int`): denoising steps
"""
assert n_step >= 1
def encode_empty_text(self):
    """
    Compute and cache the text-encoder embedding of the empty prompt.

    The empty string is tokenized without padding (truncated to the tokenizer's
    `model_max_length`), encoded on the text encoder's device, and the first encoder
    output is cast to `self.dtype` and stored in `self.empty_text_embed`.
    """
    tokenized = self.tokenizer(
        "",
        padding="do_not_pad",
        max_length=self.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    token_ids = tokenized.input_ids.to(self.text_encoder.device)
    encoder_outputs = self.text_encoder(token_ids)
    self.empty_text_embed = encoder_outputs[0].to(self.dtype)
@torch.no_grad()
def multidiffusion_inference(
    self,
    rgb_norm: torch.Tensor,
    global_pred: torch.Tensor,
    num_inference_steps: int,
    seed: Union[int, None],
    show_pbar: bool,
    patch_size=(512, 768),
    encoder_patch_size=None,
    ensemble_size=1,
    batch_size=1,
    ensemble_kwargs: Dict = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Predict depth by denoising overlapping latent patches and blending them on a shared canvas.

    RGB and global-depth latents are encoded patch-wise and cut into latent patches. At every timestep,
    each noisy depth patch is denoised by `boosting_unet`, stepped with the scheduler, and blended back
    into the full latent canvas. The canvas is then re-cut into patches, so overlapping regions stay
    consistent. Finally, the full latent is decoded and, for `ensemble_size > 1`, ensembled.

    Args:
        rgb_norm (`torch.Tensor`):
            Input RGB image normalized to [-1, 1] (`boosted_inference` passes `[1, 3, H, W]`).
        global_pred (`torch.Tensor`):
            Global depth guidance in [0, 1]; it is mapped to [-1, 1] before encoding.
        num_inference_steps (`int`):
            Number of diffusion denoising steps during inference.
        seed (`int` or `None`):
            Seed of the noise generator; `None` uses the default RNG.
        show_pbar (`bool`):
            Whether to display progress bars.
        patch_size (`Tuple[int, int]`, defaults to `(512, 768)`):
            Pixel size (H, W) of the denoising patches.
        encoder_patch_size (`Tuple[int, int]`, *optional*):
            Pixel size of the VAE encoding/decoding patches; defaults to `patch_size`.
        ensemble_size (`int`, defaults to `1`):
            Number of predictions denoised jointly and then ensembled.
        batch_size (`int`, defaults to `1`):
            Number of patches passed through the U-Net at once.
        ensemble_kwargs (`dict`, *optional*):
            Extra keyword arguments forwarded to `ensemble_depth`.

    Returns:
        `torch.Tensor`: Predicted depth map.
        `torch.Tensor` or `None`: Uncertainty map; `None` when `ensemble_size == 1`.
    """
    # Fail fast, before any expensive encoding. The U-Net input is the channel-wise
    # concatenation of the RGB, global-depth and noisy-depth latents.
    assert (
        self.boosting_unet.conv_in.in_channels == 12
    ), "The input channels of the boosting unet must be 12."

    device = self.device
    rgb_norm = rgb_norm.to(device)
    global_pred = global_pred.to(device)

    # Set timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    latent_patch_size = (patch_size[0] // 8, patch_size[1] // 8)
    latent_full_size = (rgb_norm.shape[-2] // 8, rgb_norm.shape[-1] // 8)

    # Normalize the global prediction to [-1, 1]
    global_pred = global_pred * 2 - 1

    # Encode the RGB image and the global depth map patch-wise
    if encoder_patch_size is None:
        encoder_patch_size = patch_size
    rgb_latent = self.encode_rgb_patched(
        rgb_norm, patch_size=encoder_patch_size, show_pbar=show_pbar
    )
    global_pred_latent = self.encode_depth_patched(
        global_pred, patch_size=encoder_patch_size, show_pbar=show_pbar
    )

    # Offload pixel-space inputs to the CPU; only their shapes are needed from here on
    rgb_norm = rgb_norm.cpu()
    global_pred = global_pred.cpu()

    # Noise generator
    if seed is None:
        rand_num_generator = None
    else:
        rand_num_generator = torch.Generator(device=device)
        rand_num_generator.manual_seed(seed)

    # Cut the conditioning latents into patches (the patch grid is taken from the
    # depth canvas below, which has the same latent size)
    rgb_latent_patched, _, _, _, _ = self.extract_patches(
        rgb_latent[0], patch_size=latent_patch_size
    )
    global_pred_latent_patched, _, _, _, _ = self.extract_patches(
        global_pred_latent[0], patch_size=latent_patch_size
    )

    # Batched empty text embedding
    if self.empty_text_embed is None:
        self.encode_empty_text()
    batch_empty_text_embed = self.empty_text_embed.repeat((batch_size, 1, 1)).to(
        device
    )

    # Enlarge the conditioning according to the ensemble size
    if ensemble_size > 1:
        rgb_latent_patched = rgb_latent_patched.repeat(ensemble_size, 1, 1, 1)
        global_pred_latent_patched = global_pred_latent_patched.repeat(
            ensemble_size, 1, 1, 1
        )
        batch_empty_text_embed = batch_empty_text_embed.repeat(ensemble_size, 1, 1)

    # Initialize the noise on the full canvas and split it, so overlaps start identical
    depth_latent = torch.randn(
        (ensemble_size, 4, latent_full_size[0], latent_full_size[1]),
        device=device,
        dtype=self.dtype,
        generator=rand_num_generator,
    )
    (
        depth_latent_patched,
        num_patches_vert,
        num_patches_horz,
        step_height,
        step_width,
    ) = self.extract_patches(depth_latent, patch_size=latent_patch_size)

    # Denoising loop
    if show_pbar:
        iterable = tqdm(
            enumerate(timesteps),
            total=len(timesteps),
            leave=False,
            desc=" " * 4 + "Diffusion denoising",
        )
    else:
        iterable = enumerate(timesteps)

    for _, t in iterable:
        # 1. Stack the conditioning and the noisy patches along the channel axis
        unet_input = torch.cat(
            [rgb_latent_patched, global_pred_latent_patched, depth_latent_patched],
            dim=1,
        )

        # 2. Predict the noise for all patches in mini-batches
        loader = DataLoader(
            TensorDataset(unet_input), batch_size=batch_size, shuffle=False
        )
        noise_preds = []
        for (unet_batch,) in tqdm(
            loader,
            leave=False,
            disable=not show_pbar,
            desc=" " * 6 + "UNet patch inference",
        ):
            noise_pred = self.boosting_unet(
                unet_batch,
                t,
                encoder_hidden_states=batch_empty_text_embed[: unet_batch.shape[0]],
            ).sample
            noise_preds.append(noise_pred)
        noise_preds = torch.concat(noise_preds, dim=0)

        # 3. Scheduler step for every patch
        scheduler_out = self.scheduler.step(
            noise_preds, t, depth_latent_patched, generator=rand_num_generator
        )

        # 4. Reshape the patches to the patch-grid layout
        depth_latent = scheduler_out.prev_sample.reshape(
            ensemble_size,
            num_patches_vert,
            num_patches_horz,
            4,
            latent_patch_size[0],
            latent_patch_size[1],
        )

        # 5. Blend the patches onto the full canvas (multidiffusion)
        depth_latent_full = self.blend_patches(
            depth_latent,
            canvas_size=global_pred_latent.shape[2:],
            num_patches_vert=num_patches_vert,
            num_patches_horz=num_patches_horz,
            step_height=step_height,
            step_width=step_width,
        )

        # 6. Re-cut the blended canvas for the next step (not needed after the last one)
        if t != timesteps[-1]:
            depth_latent_patched, _, _, _, _ = self.extract_patches(
                depth_latent_full, patch_size=latent_patch_size
            )

    # Decode the full latent once the loop is finished
    depth = self.decode_depth_patched(
        depth_latent_full=depth_latent_full,
        canvas_size=global_pred.shape[1:],
        latent_decoder_patch_size=(
            encoder_patch_size[0] // 8,
            encoder_patch_size[1] // 8,
        ),
        show_pbar=show_pbar,
        ensemble_size=ensemble_size,
    )

    if ensemble_size > 1:
        depth, pred_uncert = ensemble_depth(
            depth,
            scale_invariant=self.scale_invariant,
            shift_invariant=self.shift_invariant,
            **(ensemble_kwargs or {}),
        )
    else:
        # [-1, 1] -> [0, 1]
        depth = (depth + 1.0) / 2.0
        pred_uncert = None

    return depth, pred_uncert
def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
    """
    Map an RGB image tensor into the VAE latent space.

    With an `AutoencoderTiny` VAE the raw encoder output is returned as is.
    With any other VAE the encoder output goes through `quant_conv`, the result
    is split into two halves along dim 1, and the first half is multiplied by
    `self.latent_scale_factor`.

    Args:
        rgb_in (`torch.Tensor`):
            RGB image tensor to encode.
    Returns:
        `torch.Tensor`: Latent representation of `rgb_in`.
    """
    vae = self.vae
    if not isinstance(vae, AutoencoderTiny):
        moments = vae.quant_conv(vae.encoder(rgb_in))
        # Keep only the first chunk (the distribution mean); the other is dropped.
        latent_mean, _ = torch.chunk(moments, 2, dim=1)
        return latent_mean * self.latent_scale_factor
    return vae.encoder(rgb_in)
def encode_rgb_patched(
    self, rgb_in: torch.Tensor, patch_size: tuple, show_pbar: bool = False
) -> torch.Tensor:
    """
    Encode an RGB image into latent space patch by patch, then blend the
    patch latents back into a single latent canvas.

    Patches are cut with `self.extract_patches` (its default overlap), staged on
    the CPU, moved to `self.device` one at a time for `self.encode_rgb`, and the
    resulting latents are blended with `self.blend_patches` onto a canvas of
    `(H // 8, W // 8)`.

    Args:
        rgb_in (`torch.Tensor`):
            Input RGB image to be encoded. A leading singleton batch dimension
            is squeezed before patch extraction (assumed shape [1, 3, H, W] —
            confirm with callers).
        patch_size (`Tuple[int, int]`):
            Size (height, width) of the pixel-space patches.
        show_pbar (`bool`):
            Display a progress bar.
    Returns:
        `torch.Tensor`: RGB latent of shape [1, C_latent, H // 8, W // 8].
    """
    device = self.device
    (
        rgb_patched,
        num_patches_vert_pix,
        num_patches_horz_pix,
        step_height_pix,
        step_width_pix,
    ) = self.extract_patches(rgb_in.squeeze(0), patch_size=patch_size)
    # Only the spatial size of the input is needed from here on: drop the
    # reference instead of copying the whole image to host memory.
    canvas_height, canvas_width = rgb_in.shape[-2], rgb_in.shape[-1]
    del rgb_in
    # Stage the patch stack on the CPU; patches visit `device` one at a time.
    rgb_patched = rgb_patched.cpu()
    # forward the patches
    rgb_latent_patched = []
    for rgb in tqdm(
        rgb_patched,
        leave=False,
        disable=not show_pbar,
        desc=" " * 4 + "Encoding RGB",
    ):
        patch = rgb.unsqueeze(0).to(device)
        rgb_latent_patched.append(self.encode_rgb(patch).cpu())
    # reshape to [vert, horz, channels, h, w]; take the channel count from the
    # encoder output instead of assuming 4 latent channels
    rgb_latent_patched = torch.concat(rgb_latent_patched, dim=0)
    rgb_latent_patched = rgb_latent_patched.reshape(
        num_patches_vert_pix,
        num_patches_horz_pix,
        rgb_latent_patched.shape[1],
        rgb_latent_patched.shape[-2],
        rgb_latent_patched.shape[-1],
    ).to(device)
    # blend the patches; pixel steps map to latent steps by the same factor 8
    rgb_latent = self.blend_patches(
        rgb_latent_patched,
        canvas_size=(canvas_height // 8, canvas_width // 8),
        num_patches_vert=num_patches_vert_pix,
        num_patches_horz=num_patches_horz_pix,
        step_height=step_height_pix // 8,
        step_width=step_width_pix // 8,
    )
    if len(rgb_latent.shape) == 3:
        rgb_latent = rgb_latent.unsqueeze(0)
    return rgb_latent
def encode_depth_patched(
    self,
    depth_in: torch.Tensor,
    patch_size,
    show_pbar: bool = False,
) -> torch.Tensor:
    """
    Encode depth map into latent, but in a patched and scalable way.

    The depth map is cut into overlapping patches with `self.extract_patches`
    (its default overlap), the patches are staged on the CPU and sent to
    `self.device` one at a time for `self.encode_depth`, and the patch latents
    are blended with `self.blend_patches` onto a canvas of `(H // 8, W // 8)`.

    Args:
        depth_in (`torch.Tensor`):
            Input depth map to be encoded. Its first dimension is treated as
            the ensemble dimension (assumed shape [E, 1, H, W] — confirm with
            callers).
        patch_size (`Tuple[int, int]`):
            Size (height, width) of the pixel-space patches.
        show_pbar (`bool`):
            Display a progress bar.
    Returns:
        `torch.Tensor`: Depth latent of shape [E, C_latent, H // 8, W // 8].
    """
    device = self.device
    ensemble_size = depth_in.shape[0]
    (
        depth_in_patched,
        num_patches_vert_pix,
        num_patches_horz_pix,
        step_height_pix,
        step_width_pix,
    ) = self.extract_patches(depth_in, patch_size=patch_size)
    # Only the spatial size is needed below: release the reference rather than
    # copying the whole map to host memory.
    canvas_height, canvas_width = depth_in.shape[-2], depth_in.shape[-1]
    del depth_in
    # Stage the patch stack on the CPU; patches visit `device` one at a time.
    depth_in_patched = depth_in_patched.cpu()
    # forward the patches
    depth_latent_patched = []
    for gpred in tqdm(
        depth_in_patched,
        leave=False,
        disable=not show_pbar,
        desc=" " * 4 + "Encoding context depth",
    ):
        patch = gpred.unsqueeze(0).to(device)
        depth_latent_patched.append(self.encode_depth(patch).cpu())
    # reshape to [ensemble, vert, horz, channels, h, w]; take the channel count
    # from the encoder output instead of assuming 4 latent channels
    depth_latent_patched = torch.concat(depth_latent_patched, dim=0)
    depth_latent_patched = depth_latent_patched.reshape(
        ensemble_size,
        num_patches_vert_pix,
        num_patches_horz_pix,
        depth_latent_patched.shape[1],
        depth_latent_patched.shape[-2],
        depth_latent_patched.shape[-1],
    ).to(device)
    # blend the patches; pixel steps map to latent steps by the same factor 8
    depth_latent = self.blend_patches(
        depth_latent_patched,
        canvas_size=(canvas_height // 8, canvas_width // 8),
        num_patches_vert=num_patches_vert_pix,
        num_patches_horz=num_patches_horz_pix,
        step_height=step_height_pix // 8,
        step_width=step_width_pix // 8,
    )
    if len(depth_latent.shape) == 3:
        depth_latent = depth_latent.unsqueeze(0)
    return depth_latent
def decode_depth_patched(
    self,
    depth_latent_full: torch.Tensor,
    canvas_size: tuple,
    latent_decoder_patch_size: tuple,
    show_pbar: bool = True,
    ensemble_size: int = 1,
) -> torch.Tensor:
    """
    Decode a full depth latent to pixel space tile by tile and blend the tiles.

    The latent is split into 50%-overlapping tiles of `latent_decoder_patch_size`.
    Each tile is decoded with `self.decode_depth` into a single-channel tile 8x
    larger per side, and the tiles are recombined with `self.blend_patches`.

    Args:
        depth_latent_full (`torch.Tensor`):
            Depth latent to be decoded.
        canvas_size (`tuple`):
            Output (height, width) in pixels.
        latent_decoder_patch_size (`tuple`):
            Tile size (height, width) in latent space.
        show_pbar (`bool`):
            Display a progress bar.
        ensemble_size (`int`):
            Number of ensemble members contained in `depth_latent_full`.
    Returns:
        `torch.Tensor`: Decoded depth map.
    """
    latent_to_pixel = 8  # pixel/latent size ratio used throughout this file
    pixel_tile_h = latent_decoder_patch_size[0] * latent_to_pixel
    pixel_tile_w = latent_decoder_patch_size[1] * latent_to_pixel

    latent_tiles, n_vert, n_horz, stride_h, stride_w = self.extract_patches(
        depth_latent_full, patch_size=latent_decoder_patch_size, overlap=0.5
    )

    progress = tqdm(
        latent_tiles,
        leave=False,
        desc=" " * 6 + "Decoding Depth",
        disable=not show_pbar,
    )
    decoded_tiles = torch.concat(
        [self.decode_depth(tile.unsqueeze(0)) for tile in progress], dim=0
    )
    # Lay the flat tile stack out as [ensemble, vert, horz, 1, h, w].
    decoded_tiles = decoded_tiles.reshape(
        ensemble_size, n_vert, n_horz, 1, pixel_tile_h, pixel_tile_w
    )

    return self.blend_patches(
        decoded_tiles,
        canvas_size=canvas_size,
        num_patches_vert=n_vert,
        num_patches_horz=n_horz,
        step_height=stride_h * latent_to_pixel,
        step_width=stride_w * latent_to_pixel,
    )
def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
    """
    Turn a depth latent back into a one-channel depth map.

    An `AutoencoderTiny` VAE decodes the latent directly. Any other VAE first
    divides it by `self.latent_scale_factor` and runs it through
    `post_quant_conv` before decoding. The decoder output is averaged over
    dim 1 (keeping that dimension).

    Args:
        depth_latent (`torch.Tensor`):
            Depth latent to be decoded.
    Returns:
        `torch.Tensor`: Decoded depth map with a single channel.
    """
    vae = self.vae
    if isinstance(vae, AutoencoderTiny):
        decoded = vae.decoder(depth_latent)
    else:
        unscaled_latent = depth_latent / self.latent_scale_factor
        decoded = vae.decoder(vae.post_quant_conv(unscaled_latent))
    # Collapse the decoder's output channels into one by averaging.
    return decoded.mean(dim=1, keepdim=True)
def encode_depth(
    self,
    depth_in: torch.Tensor,
) -> torch.Tensor:
    """
    Encode a depth map into the latent space.

    The depth map is first replicated to three channels with
    `self.stack_depth_images` so it can go through the same path as RGB input
    (`self.encode_rgb`).

    Args:
        depth_in (`torch.Tensor`):
            Depth map to encode.
    Returns:
        `torch.Tensor`: Latent of the input depth map.
    """
    three_channel_depth = self.stack_depth_images(depth_in)
    return self.encode_rgb(three_channel_depth)
@staticmethod
def stack_depth_images(depth_in):
    """
    Replicate a depth tensor three times along the channel axis.

    Args:
        depth_in (`torch.Tensor`): Depth tensor of shape [B, C, H, W], or
            [B, H, W] (a channel axis of size 1 is inserted at dim 1).
    Returns:
        `torch.Tensor`: Tensor of shape [B, 3 * C, H, W] for 4-D input and
        [B, 3, H, W] for 3-D input.
    Raises:
        ValueError: If `depth_in` is neither 3-D nor 4-D.
    """
    if depth_in.dim() == 3:
        # Insert the channel axis so the repeat below works per sample.
        # (Repeating the un-squeezed 3-D tensor would fold the batch into channels.)
        depth_in = depth_in.unsqueeze(1)
    elif depth_in.dim() != 4:
        raise ValueError(
            f"Expecting a 3D or 4D depth tensor; got shape {tuple(depth_in.shape)}."
        )
    return depth_in.repeat(1, 3, 1, 1)
def extract_patches(self, canvas_input_image, patch_size, overlap=0.5):
    """
    Extract a regular grid of (possibly overlapping) patches from a tensor.

    Patches are ordered by ensemble member, then row (top to bottom), then
    column (left to right), starting at the top-left corner. Only full patches
    are extracted, so when `(height - h)` is not a multiple of the vertical
    step (or likewise horizontally) the last rows/columns are not covered.

    Args:
        canvas_input_image (`torch.Tensor`): Input of shape
            [channels, height, width] or [ensemble, channels, height, width].
        patch_size (`tuple`): Patch size `(h, w)`.
        overlap (`float`): Fraction of a patch shared with its neighbour. The
            step is `int(size * (1 - overlap))`, clamped to at least 1 pixel.
            Must be smaller than 1.
    Returns:
        input_image_patched (`torch.Tensor`): Patches of shape
            [ensemble * num_patches_vert * num_patches_horz, channels, h, w]
            (ensemble is 1 for 3-D input).
        num_patches_vert (`int`): Number of patches in the vertical direction
        num_patches_horz (`int`): Number of patches in the horizontal direction
        step_height (`int`): Step size in the vertical direction
        step_width (`int`): Step size in the horizontal direction
    Raises:
        ValueError: If `overlap >= 1` or the patch is larger than the canvas.
    """
    if overlap >= 1:
        raise ValueError(f"overlap must be smaller than 1; got {overlap}.")
    if len(canvas_input_image.shape) == 4:
        ensemble_size = canvas_input_image.shape[0]
    else:
        canvas_input_image = canvas_input_image.unsqueeze(0)
        ensemble_size = 1
    h, w = patch_size
    canvas_h, canvas_w = canvas_input_image.shape[-2], canvas_input_image.shape[-1]
    if h > canvas_h or w > canvas_w:
        raise ValueError(
            f"Patch size {(h, w)} exceeds canvas size {(canvas_h, canvas_w)}."
        )
    # Clamp so tiny patches (e.g. 1 pixel at 50% overlap) cannot produce a
    # zero step, which would divide by zero below.
    step_height = max(1, int(h * (1 - overlap)))
    step_width = max(1, int(w * (1 - overlap)))
    # Number of full patches that fit in each dimension
    num_patches_vert = (canvas_h - h) // step_height + 1
    num_patches_horz = (canvas_w - w) // step_width + 1
    patches = [
        canvas_input_image[
            e,
            :,
            i * step_height : i * step_height + h,
            j * step_width : j * step_width + w,
        ]
        for e in range(ensemble_size)
        for i in range(num_patches_vert)
        for j in range(num_patches_horz)
    ]
    input_image_patched = torch.stack(patches)
    return (
        input_image_patched,
        num_patches_vert,
        num_patches_horz,
        step_height,
        step_width,
    )
def blend_patches(
    self,
    depth_preds,
    num_patches_vert,
    num_patches_horz,
    step_height,
    step_width,
    global_depth_pred=None,
    canvas_size=None,
    alpha_center=1.0,
    alpha_edge=1e-4,
    noise_blend=False,
    eps=1e-8,
):
    """
    Blend a grid of overlapping patches back onto one canvas.

    Each patch is multiplied by a cosine weight map from
    `self.get_linear_weight_map` and accumulated at its grid position. The sum
    is then divided by the accumulated weights, or, with `noise_blend`, by the
    square root of the accumulated squared weights.

    Args:
        depth_preds (`torch.Tensor`): Patches of shape
            [ensemble, num_patches_vert, num_patches_horz, channels, h, w] or
            [num_patches_vert, num_patches_horz, channels, h, w].
        num_patches_vert (`int`): Number of patches in the vertical direction
        num_patches_horz (`int`): Number of patches in the horizontal direction
        step_height (`int`): Step size in the vertical direction
        step_width (`int`): Step size in the horizontal direction
        global_depth_pred (`torch.Tensor`, optional): Fallback used to fill NaN
            entries of the result. After `.repeat(ensemble, 1, 1, 1)` it must
            match the output shape (presumably [1, channels, H, W] — confirm).
        canvas_size (`tuple`, optional): Output (height, width). Defaults to
            the smallest canvas containing every patch of the grid.
        alpha_center (`float`): Weight at the center of the patch
        alpha_edge (`float`): Weight at the edge of the patch; should not be
            exactly 0 to avoid division by zero
        noise_blend (`bool`): Normalize by sqrt(sum of squared weights)
            instead of the sum of weights.
        eps (`float`): Added to the denominator to avoid division by zero.
    Returns:
        `torch.Tensor`: Blended map of shape [ensemble, channels, H, W]
        (ensemble is 1 for 5-D input).
    Raises:
        ValueError: If `depth_preds` is not 5-D or 6-D.
    """
    if depth_preds.dim() == 5:
        # Single-member input: add an ensemble axis of size 1.
        depth_preds = depth_preds.unsqueeze(0)
    elif depth_preds.dim() != 6:
        raise ValueError("depth_preds should have 5 or 6 dimensions")
    ensemble_size = depth_preds.shape[0]
    channels, h, w = depth_preds.shape[-3:]
    if canvas_size is None:
        canvas_size = (
            (num_patches_vert - 1) * step_height + h,
            (num_patches_horz - 1) * step_width + w,
        )
    # Initialize the canvas for blending and the per-pixel normalizer
    blended_depth_map = torch.zeros(
        (ensemble_size, channels, canvas_size[0], canvas_size[1]),
        device=depth_preds.device,
        dtype=depth_preds.dtype,
    )
    denominator_map = torch.zeros(
        (1, canvas_size[0], canvas_size[1]),
        device=depth_preds.device,
        dtype=depth_preds.dtype,
    )
    # The weight map depends only on the patch size and alphas, so build it
    # once instead of once per patch.
    weights = self.get_linear_weight_map(
        h,
        w,
        device=depth_preds.device,
        alpha_center=alpha_center,
        alpha_edge=alpha_edge,
        cosine_blending=True,
        margin=0.0,
    )
    denominator_weights = weights**2 if noise_blend else weights
    for i in range(num_patches_vert):
        start_y = i * step_height
        for j in range(num_patches_horz):
            start_x = j * step_width
            # accumulate the weighted patch and its weights at its grid position
            blended_depth_map[
                :, :, start_y : start_y + h, start_x : start_x + w
            ] += depth_preds[:, i, j] * weights
            denominator_map[
                :, start_y : start_y + h, start_x : start_x + w
            ] += denominator_weights
    if noise_blend:
        denominator_map = torch.sqrt(denominator_map)
    # Honour the caller's eps (it was previously overwritten with 1e-8).
    blended_depth_map /= denominator_map + eps
    # Fill any NaN entries from the global prediction, when one is given.
    if global_depth_pred is not None:
        nan_mask = torch.isnan(blended_depth_map)
        if nan_mask.any():
            blended_depth_map[nan_mask] = global_depth_pred.repeat(
                ensemble_size, 1, 1, 1
            )[nan_mask]
    return blended_depth_map
def get_linear_weight_map(
    self,
    h,
    w,
    device,
    alpha_center=1.0,
    alpha_edge=1e-4,
    margin=0.0,
    cosine_blending=False,
):
    """
    Generate an (h, w) weight map that peaks at the center and decays to the border.

    The Chebyshev distance `max(|x|, |y|)` on a [-1, 1] grid is normalized,
    shifted by `margin`, clamped to [0, 1], and rescaled to span [0, 1]. It is
    then mapped to weights between `alpha_center` (distance 0) and `alpha_edge`
    (distance 1), either linearly or with a raised-cosine profile.

    Args:
        h (`int`): Height of the weight map
        w (`int`): Width of the weight map
        device (`torch.device`): Device to allocate the map on
        alpha_center (`float`): Weight at the center of the patch
        alpha_edge (`float`): Weight at the edge of the patch
        margin (`float`): Offset added to the normalized distance before
            clamping to [0, 1]; larger values push more of the map towards the
            edge weight.
        cosine_blending (`bool`): Use a raised-cosine falloff instead of a
            linear one.
    Returns:
        `torch.Tensor`: Weight map of shape (h, w). When every pixel has the
        same distance (e.g. a 1-pixel-wide map), all weights equal
        `alpha_center` instead of NaN.
    """
    x = torch.linspace(-1, 1, h, device=device)
    y = torch.linspace(-1, 1, w, device=device)
    xx, yy = torch.meshgrid(x, y, indexing="ij")
    dist = torch.stack([xx.abs(), yy.abs()]).max(dim=0).values
    norm_dist = dist / torch.max(dist)
    # Shift by the margin and clamp to [0, 1]
    norm_dist = torch.clamp(norm_dist + margin, 0, 1)
    # Rescale to span exactly [0, 1]; guard the degenerate case (all distances
    # equal), which would otherwise compute 0 / 0 = NaN.
    mindist = torch.min(norm_dist)
    dist_range = torch.max(norm_dist) - mindist
    if dist_range > 0:
        norm_dist = (norm_dist - mindist) / dist_range
    else:
        norm_dist = torch.zeros_like(norm_dist)
    if cosine_blending:
        # Raised cosine: 1 at distance 0, 0 at distance 1, smooth in between.
        weights = alpha_edge + (alpha_center - alpha_edge) * (
            0.5 * (1 + torch.cos(norm_dist * math.pi))
        )
    else:
        weights = alpha_edge + (alpha_center - alpha_edge) * (1 - norm_dist)
    return weights
def get_tv_resample_method(method_str: str) -> InterpolationMode:
    """
    Map a resampling method name to a torchvision `InterpolationMode`.

    Args:
        method_str (`str`): One of "bilinear", "bicubic", "nearest" or
            "nearest-exact". Both nearest variants map to `NEAREST_EXACT`.
    Returns:
        `InterpolationMode`: The matching interpolation mode.
    Raises:
        ValueError: If `method_str` is not a known method name.
    """
    resample_method_dict = {
        "bilinear": InterpolationMode.BILINEAR,
        "bicubic": InterpolationMode.BICUBIC,
        "nearest": InterpolationMode.NEAREST_EXACT,
        "nearest-exact": InterpolationMode.NEAREST_EXACT,
    }
    try:
        return resample_method_dict[method_str]
    except KeyError:
        # Report the name the caller passed, not the (always None) lookup result.
        raise ValueError(f"Unknown resampling method: {method_str}") from None
def resize_max_res(
    img: torch.Tensor,
    max_edge_resolution: int,
    resample_method: InterpolationMode = InterpolationMode.BILINEAR,
) -> torch.Tensor:
    """
    Resize image so its longer edge becomes `max_edge_resolution`, keeping aspect ratio.

    The scale factor is `min(max_edge / width, max_edge / height)`, so images
    smaller than the limit are scaled up as well. New sizes are truncated to
    integers and never drop below 1 pixel.

    Args:
        img (`torch.Tensor`):
            Image tensor to be resized. Expected shape: [B, C, H, W]
        max_edge_resolution (`int`):
            Maximum edge length (pixel).
        resample_method (`InterpolationMode`):
            Resampling method used to resize images.
    Returns:
        `torch.Tensor`: Resized image.
    Raises:
        ValueError: If `img` is not 4-D.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if img.dim() != 4:
        raise ValueError(f"Invalid input shape {img.shape}")
    original_height, original_width = img.shape[-2:]
    downscale_factor = min(
        max_edge_resolution / original_width, max_edge_resolution / original_height
    )
    # Extreme aspect ratios could otherwise truncate the short edge to 0.
    new_width = max(1, int(original_width * downscale_factor))
    new_height = max(1, int(original_height * downscale_factor))
    resized_img = resize(img, (new_height, new_width), resample_method, antialias=True)
    return resized_img
def ensemble_depth(
    depth: torch.Tensor,
    scale_invariant: bool = True,
    shift_invariant: bool = True,
    output_uncertainty: bool = False,
    reduction: str = "median",
    regularizer_strength: float = 0.02,
    max_iter: int = 50,
    tol: float = 1e-6,
    max_res: int = 1024,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Ensembles depth maps represented by the `depth` tensor with expected shape `(B, 1, H, W)`, where B is the
    number of ensemble members for a given prediction of size `(H x W)`. Even though the function is designed for
    depth maps, it can also be used with disparity maps as long as the input tensor values are non-negative. The
    alignment happens when the predictions have one or more degrees of freedom, that is when they are either
    affine-invariant (`scale_invariant=True` and `shift_invariant=True`), or just scale-invariant (only
    `scale_invariant=True`). For absolute predictions (`scale_invariant=False` and `shift_invariant=False`)
    alignment and output normalization are skipped and only ensembling is performed.

    Alignment parameters (a per-member scale, plus a per-member shift in the affine case) are found with
    `scipy.optimize.minimize` (BFGS). The cost is the sum of pairwise RMS differences between aligned members, plus a
    regularizer pulling the ensembled prediction's minimum towards 0 and maximum towards 1.

    Args:
        depth (`torch.Tensor`):
            Input ensemble depth maps.
        scale_invariant (`bool`, *optional*, defaults to `True`):
            Whether to treat predictions as scale-invariant.
        shift_invariant (`bool`, *optional*, defaults to `True`):
            Whether to treat predictions as shift-invariant.
        output_uncertainty (`bool`, *optional*, defaults to `False`):
            Whether to output uncertainty map.
        reduction (`str`, *optional*, defaults to `"median"`):
            Reduction method used to ensemble aligned predictions. The accepted values are: `"mean"` and
            `"median"`.
        regularizer_strength (`float`, *optional*, defaults to `0.02`):
            Strength of the regularizer that pulls the aligned predictions to the unit range from 0 to 1.
        max_iter (`int`, *optional*, defaults to `50`):
            Maximum number of the alignment solver steps. Refer to `scipy.optimize.minimize` function, `options`
            argument.
        tol (`float`, *optional*, defaults to `1e-6`):
            Alignment solver tolerance. The solver stops when the tolerance is reached.
        max_res (`int`, *optional*, defaults to `1024`):
            Maximum edge length at which the alignment is solved; larger inputs are resized (nearest-exact) for the
            solver only. `None` solves at the input resolution.
    Returns:
        A tensor of ensembled depth maps and optionally a tensor of uncertainties of the same shape:
        `(1, 1, H, W)`. Affine-invariant results are min-max normalized to [0, 1], scale-invariant results are
        divided by their maximum, and absolute results are returned without normalization.
    Raises:
        ValueError: If `depth` is not shaped `(B, 1, H, W)`, `reduction` is unknown, or only `shift_invariant`
            is requested.
    """
    if depth.dim() != 4 or depth.shape[1] != 1:
        raise ValueError(f"Expecting 4D tensor of shape [B,1,H,W]; got {depth.shape}.")
    if reduction not in ("mean", "median"):
        raise ValueError(f"Unrecognized reduction method: {reduction}.")
    if not scale_invariant and shift_invariant:
        raise ValueError("Pure shift-invariant ensembling is not supported.")

    def init_param(depth: torch.Tensor):
        # Start from the per-member transform that maps each member onto [0, 1]
        # (or divides it by its max in the scale-only case).
        init_min = depth.reshape(ensemble_size, -1).min(dim=1).values
        init_max = depth.reshape(ensemble_size, -1).max(dim=1).values
        if scale_invariant and shift_invariant:
            init_s = 1.0 / (init_max - init_min).clamp(min=1e-6)
            init_t = -init_s * init_min
            param = torch.cat((init_s, init_t)).cpu().numpy()
        elif scale_invariant:
            init_s = 1.0 / init_max.clamp(min=1e-6)
            param = init_s.cpu().numpy()
        else:
            raise ValueError("Unrecognized alignment.")
        return param.astype(np.float64)

    def align(depth: torch.Tensor, param: np.ndarray) -> torch.Tensor:
        if scale_invariant and shift_invariant:
            s, t = np.split(param, 2)
            s = torch.from_numpy(s).to(depth).view(ensemble_size, 1, 1, 1)
            t = torch.from_numpy(t).to(depth).view(ensemble_size, 1, 1, 1)
            out = depth * s + t
        elif scale_invariant:
            s = torch.from_numpy(param).to(depth).view(ensemble_size, 1, 1, 1)
            out = depth * s
        else:
            raise ValueError("Unrecognized alignment.")
        return out

    def ensemble(
        depth_aligned: torch.Tensor, return_uncertainty: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        uncertainty = None
        if reduction == "mean":
            prediction = torch.mean(depth_aligned, dim=0, keepdim=True)
            if return_uncertainty:
                uncertainty = torch.std(depth_aligned, dim=0, keepdim=True)
        elif reduction == "median":
            prediction = torch.median(depth_aligned, dim=0, keepdim=True).values
            if return_uncertainty:
                uncertainty = torch.median(
                    torch.abs(depth_aligned - prediction), dim=0, keepdim=True
                ).values
        else:
            raise ValueError(f"Unrecognized reduction method: {reduction}.")
        return prediction, uncertainty

    def cost_fn(param: np.ndarray, depth: torch.Tensor) -> float:
        cost = 0.0
        depth_aligned = align(depth, param)
        # Pairwise RMS disagreement between aligned members.
        for i, j in torch.combinations(torch.arange(ensemble_size)):
            diff = depth_aligned[i] - depth_aligned[j]
            cost += (diff**2).mean().sqrt().item()
        if regularizer_strength > 0:
            prediction, _ = ensemble(depth_aligned, return_uncertainty=False)
            err_near = (0.0 - prediction.min()).abs().item()
            err_far = (1.0 - prediction.max()).abs().item()
            cost += (err_near + err_far) * regularizer_strength
        return cost

    def compute_param(depth: torch.Tensor):
        # Import the submodule explicitly: `import scipy` alone does not
        # guarantee that `scipy.optimize` is loaded on every SciPy version.
        from scipy.optimize import minimize

        depth_to_align = depth.to(torch.float32)
        if max_res is not None and max(depth_to_align.shape[2:]) > max_res:
            depth_to_align = resize_max_res(
                depth_to_align, max_res, get_tv_resample_method("nearest-exact")
            )
        param = init_param(depth_to_align)
        res = minimize(
            partial(cost_fn, depth=depth_to_align),
            param,
            method="BFGS",
            tol=tol,
            options={"maxiter": max_iter, "disp": False},
        )
        return res.x

    requires_aligning = scale_invariant or shift_invariant
    ensemble_size = depth.shape[0]
    if requires_aligning:
        param = compute_param(depth)
        depth = align(depth, param)
    depth, uncertainty = ensemble(depth, return_uncertainty=output_uncertainty)
    if not requires_aligning:
        # Absolute predictions: as documented, only ensembling is performed, so
        # the result is returned without normalization.
        return depth, uncertainty
    # Here scale_invariant is True (pure shift-invariance was rejected above).
    depth_max = depth.max()
    depth_min = depth.min() if shift_invariant else 0
    depth_range = (depth_max - depth_min).clamp(min=1e-6)
    depth = (depth - depth_min) / depth_range
    if output_uncertainty:
        uncertainty /= depth_range
    return depth, uncertainty  # [1,1,H,W], [1,1,H,W]
# Search table for suggested max. inference batch size
bs_search_table = [
    # tested on A100-PCIE-80GB
    {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
    {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
    # tested on A100-PCIE-40GB
    {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
    {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
    {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
    {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
    # tested on RTX3090, RTX4090
    {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
    {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
    {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
    {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
    {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
    {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
    # tested on GTX1080Ti
    {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
    {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
    {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
    {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
    {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
]


def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
    """
    Pick an inference batch size from `bs_search_table` for the current GPU.

    Entries with a matching `dtype` are scanned from the smallest tested
    resolution upward and, within one resolution, from the largest tested VRAM
    downward. The first entry whose resolution covers `input_res` and whose VRAM
    requirement fits the device's total memory wins. Its batch size is capped at
    `ensemble_size`. If it would split the ensemble into one full batch and one
    smaller remainder batch, it is reduced to `ceil(ensemble_size / 2)`.

    Args:
        ensemble_size (`int`):
            Number of predictions to be ensembled.
        input_res (`int`):
            Operating resolution of the input image.
        dtype (`torch.dtype`):
            Inference dtype; only table entries with this dtype are considered.
    Returns:
        `int`: Operating batch size; 1 when CUDA is unavailable or nothing fits.
    """
    if not torch.cuda.is_available():
        return 1

    total_vram_gib = torch.cuda.mem_get_info()[1] / 1024.0**3
    candidates = sorted(
        (entry for entry in bs_search_table if entry["dtype"] == dtype),
        key=lambda entry: (entry["res"], -entry["total_vram"]),
    )
    for entry in candidates:
        too_small_res = input_res > entry["res"]
        too_little_vram = total_vram_gib < entry["total_vram"]
        if too_small_res or too_little_vram:
            continue
        bs = entry["bs"]
        if bs > ensemble_size:
            return ensemble_size
        half_ensemble = math.ceil(ensemble_size / 2)
        if half_ensemble < bs < ensemble_size:
            return half_ensemble
        return bs
    return 1
|