Spaces:
Sleeping
Sleeping
import torch | |
def get_device(seed=1):
    """Seed PyTorch's random number generators and pick a compute device.

    Device preference is CUDA, then Apple MPS, then CPU. When CUDA is
    selected, the GPU name and CUDA version are printed. Only when no
    accelerator is usable is a message printed explaining why MPS could not
    be used (on a CUDA machine that message would be irrelevant noise).

    Args:
        seed: Seed passed to ``torch.manual_seed`` and to the backend-specific
            seeding calls so that random data is reproducible across runs.
            Defaults to 1.

    Returns:
        torch.device: ``cuda``, ``mps`` or ``cpu``.
    """
    # Seed the default generator so random data is identical on every run.
    torch.manual_seed(seed)

    # Query availability once; the result drives both selection and reporting.
    cuda_ok = torch.cuda.is_available()
    mps_ok = torch.backends.mps.is_available()

    if cuda_ok:
        print(f"[INFO] GPU: {torch.cuda.get_device_name(0)}")
        print(f"[INFO] CUDA Version: {torch.version.cuda}\n")
        # Seed every visible GPU, not just the current one.
        torch.cuda.manual_seed_all(seed)
        return torch.device("cuda")

    if mps_ok:
        # NOTE(review): the torch.mps module may be missing in older PyTorch
        # builds that still expose torch.backends.mps; guard to avoid an
        # AttributeError (torch.manual_seed above has already run).
        if hasattr(torch, "mps"):
            torch.mps.manual_seed(seed)
        return torch.device("mps")

    # No accelerator usable: explain why MPS was not chosen.
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")
    return torch.device("cpu")