kernel
rotary / tests /utils.py
danieldk's picture
danieldk HF Staff
Add support for XPU (sycl) (#3)
e94ff91 verified
raw
history blame
517 Bytes
import torch
def infer_device():
    """
    Get current device name based on available devices.

    Returns:
        "cuda" if ``torch.cuda`` reports an available device (this covers
        both Nvidia and AMD builds), "xpu" if ``torch.xpu`` reports an
        available device, otherwise None.
    """
    if torch.cuda.is_available():  # Works for both Nvidia and AMD
        return "cuda"
    # ``torch.xpu`` is absent from older PyTorch builds; look it up defensively
    # so those installs fall through to None instead of raising AttributeError.
    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():
        return "xpu"
    return None
def supports_bfloat16():
    """
    Report whether the detected accelerator is treated as bfloat16-capable.

    Returns:
        True when ``infer_device()`` reports "xpu"; for "cuda", True only when
        the current device's compute capability is at least (8, 0); False in
        every other case, including when no accelerator is detected.
    """
    backend = infer_device()
    if backend == "xpu":
        return True
    if backend != "cuda":
        return False
    # (8, 0) is Ampere and newer on Nvidia.
    # NOTE(review): on AMD/ROCm builds the capability tuple is derived from the
    # GPU architecture, so this threshold may mean something different there — confirm.
    capability = torch.cuda.get_device_capability()
    return capability >= (8, 0)