# SmolLM2-135m / utils.py
# Uploaded by gitesh-grover ("Upload 6 files"), commit 960a17b (verified); 959 bytes.
import torch
def get_device(seed: int = 1) -> torch.device:
    """Seed PyTorch's random number generators and pick a compute device.

    Device preference order is CUDA, then Apple MPS, then CPU.

    Args:
        seed: Seed passed to ``torch.manual_seed`` and to the seeding call
            of the accelerator that gets selected, so that repeated runs
            produce the same random data.

    Returns:
        The selected ``torch.device`` (``"cuda"``, ``"mps"`` or ``"cpu"``).
    """
    # Seed the global generator first, for reproducibility regardless of
    # which device is chosen below.
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        print(f"[INFO] GPU: {torch.cuda.get_device_name(0)}")
        print(f"[INFO] CUDA Version: {torch.version.cuda}\n")
        torch.cuda.manual_seed(seed)
        return torch.device("cuda")

    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)
        return torch.device("mps")

    # No accelerator is usable. Explain why MPS was skipped before falling
    # back to the CPU. This diagnostic only matters on the fallback path,
    # so it is not printed when CUDA was selected above.
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")
    return torch.device("cpu")