nandometzger committed on
Commit
9973b4a
·
1 Parent(s): a4b6cb2

add find bs func2

Browse files
Files changed (1) hide show
  1. pipeline.py +26 -0
pipeline.py CHANGED
@@ -1820,6 +1820,32 @@ def ensemble_depth(
1820
  return depth, uncertainty # [1,1,H,W], [1,1,H,W]
1821
 
1822
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1823
  def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
1824
  """
1825
  Automatically search for suitable operating batch size.
 
1820
  return depth, uncertainty # [1,1,H,W], [1,1,H,W]
1821
 
1822
 
1823
+ # Search table for suggested max. inference batch size
1824
+ bs_search_table = [
1825
+ # tested on A100-PCIE-80GB
1826
+ {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
1827
+ {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
1828
+ # tested on A100-PCIE-40GB
1829
+ {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
1830
+ {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
1831
+ {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
1832
+ {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
1833
+ # tested on RTX3090, RTX4090
1834
+ {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
1835
+ {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
1836
+ {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
1837
+ {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
1838
+ {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
1839
+ {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
1840
+ # tested on GTX1080Ti
1841
+ {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
1842
+ {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
1843
+ {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
1844
+ {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
1845
+ {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
1846
+ ]
1847
+
1848
+
1849
  def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
1850
  """
1851
  Automatically search for suitable operating batch size.