from __future__ import annotations | |
from contextlib import contextmanager | |
import torch | |
@contextmanager
def cpu_autocast(enabled: bool = True):
    """Context manager for bfloat16 autocast on CPU.

    Parameters
    ----------
    enabled: bool, default True
        Whether to enable autocast. When ``False`` this context manager
        behaves like a no-op: it does not touch the autocast state at all,
        so an autocast region opened by an enclosing context stays active.

    Yields
    ------
    None
        Control is yielded to the body of the ``with`` statement.
    """
    # Branch explicitly instead of passing ``enabled=enabled`` to
    # ``torch.amp.autocast``: ``autocast(enabled=False)`` would actively
    # *disable* any outer autocast region, which is not a no-op.
    if enabled:
        with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):
            yield
    else:
        yield