|
import torch |
|
|
|
def get_device(seed: int = 1) -> torch.device:
    """Seed PyTorch's RNGs and return the preferred available compute device.

    Devices are tried in the order CUDA, Apple MPS, CPU. The global RNG is
    always seeded with ``torch.manual_seed``. The backend-specific seed is
    also set for the selected accelerator.

    When CUDA is selected, the GPU name and CUDA version are printed. When
    neither CUDA nor MPS is usable, a message explaining why MPS is
    unavailable is printed before falling back to CPU. This message is not
    printed on CUDA machines, where MPS being absent is expected and not
    worth warning about.

    Args:
        seed: Seed value for ``torch.manual_seed`` and for the selected
            accelerator's seeding function.

    Returns:
        ``torch.device("cuda")``, ``torch.device("mps")`` or
        ``torch.device("cpu")``, whichever is first available.
    """
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        print(f"[INFO] GPU: {torch.cuda.get_device_name(0)}")
        print(f"[INFO] CUDA Version: {torch.version.cuda}\n")
        torch.cuda.manual_seed(seed)
        return torch.device("cuda")

    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)
        return torch.device("mps")

    # Falling back to CPU: explain why MPS could not be used. The reason is
    # either that PyTorch was built without MPS, or that the OS/hardware
    # does not support it.
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")
    return torch.device("cpu")