| from __future__ import annotations | |
| from contextlib import contextmanager | |
| import torch | |
@contextmanager
def cpu_autocast(enabled: bool = True):
    """Context manager for bfloat16 autocast on CPU.

    Wraps ``torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16)``
    so callers can switch mixed precision on or off with a single flag
    instead of branching around the ``with`` statement themselves.

    Parameters
    ----------
    enabled: bool, default True
        Whether to enable autocast. When ``False`` this context manager
        behaves like a no-op: the body runs without entering any autocast
        region.

    Yields
    ------
    None
        Control is handed to the body of the ``with`` statement.

    Examples
    --------
    >>> with cpu_autocast():
    ...     out = torch.mm(torch.ones(2, 2), torch.ones(2, 2))
    >>> out.dtype
    torch.bfloat16
    """
    if enabled:
        # Yielding inside the ``with`` keeps the autocast region open for the
        # caller's body and closes it on exit, including when the body raises.
        with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):
            yield
    else:
        # Deliberately not ``autocast(enabled=False)``: this branch enters no
        # autocast context at all, matching the documented no-op behaviour.
        yield