from __future__ import annotations

from contextlib import contextmanager

import torch


@contextmanager
def cpu_autocast(enabled: bool = True):
    """Run the enclosed code under CPU autocast with ``torch.bfloat16``.

    Parameters
    ----------
    enabled: bool, default True
        If ``True``, a ``torch.amp.autocast`` region for device type
        ``"cpu"`` and dtype ``torch.bfloat16`` is entered for the duration
        of the ``with`` block. If ``False``, no autocast context is entered
        at all, so the manager is a pure no-op and any surrounding autocast
        state is left exactly as it was.
    """
    # Guard clause: deliberately do NOT pass ``enabled=False`` to
    # ``torch.amp.autocast`` -- that would switch off an outer autocast
    # region instead of leaving it alone.
    if not enabled:
        yield
        return

    bf16_region = torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16)
    with bf16_region:
        yield