nandometzger committed
Commit 9973b4a
Parent(s): a4b6cb2

add find bs func2

pipeline.py  +26 -0

pipeline.py CHANGED
@@ -1820,6 +1820,32 @@ def ensemble_depth(
     return depth, uncertainty  # [1,1,H,W], [1,1,H,W]
 
 
+# Search table for suggested max. inference batch size
+bs_search_table = [
+    # tested on A100-PCIE-80GB
+    {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
+    {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
+    # tested on A100-PCIE-40GB
+    {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
+    {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
+    {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
+    {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
+    # tested on RTX3090, RTX4090
+    {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
+    {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
+    {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
+    {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
+    {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
+    {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
+    # tested on GTX1080Ti
+    {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
+    {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
+    {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
+    {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
+    {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
+]
+
+
 def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
     """
     Automatically search for suitable operating batch size.
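The hunk is truncated before the body of find_batch_size. For context, here is a minimal sketch of how a lookup against bs_search_table could work; the entry-selection order, the torch.cuda.mem_get_info VRAM query, and the ensemble_size cap are assumptions for illustration, not the committed implementation.

import torch

def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
    # Sketch only: assumes the bs_search_table added in the hunk above.
    if not torch.cuda.is_available():
        return 1  # no table entry applies on CPU

    # torch.cuda.mem_get_info() returns (free, total) in bytes; use total GiB.
    total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3

    # Only entries measured with the requested dtype are comparable.
    candidates = [s for s in bs_search_table if s["dtype"] == dtype]

    # Try the smallest tested resolution that covers input_res; among equal
    # resolutions, prefer the largest-VRAM measurement this card satisfies.
    for entry in sorted(candidates, key=lambda s: (s["res"], -s["total_vram"])):
        if input_res <= entry["res"] and total_vram >= entry["total_vram"]:
            return min(entry["bs"], ensemble_size)  # never exceed the ensemble

    return 1  # conservative fallback: no tested configuration matches

On a 24 GB card with float16 at 768 px, for instance, this lookup would land on the RTX3090/RTX4090 row and return 18, further capped by ensemble_size.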