File size: 503 Bytes
36c78b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
from __future__ import annotations
from contextlib import contextmanager
import torch
@contextmanager
def cpu_autocast(enabled: bool = True):
    """Run the enclosed block under CPU bfloat16 autocast.

    Parameters
    ----------
    enabled: bool, default True
        If ``True``, the body executes inside
        ``torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16)``.
        If ``False``, nothing is entered and the body runs as-is; any
        autocast context already active around the caller stays in effect.
    """
    if not enabled:
        # Pure pass-through: deliberately do NOT open an autocast region with
        # enabled=False, since that would switch off an enclosing autocast.
        yield
        return

    with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):
        yield
|