# Provenance: BitTransformerLM (WCNegentropy), commit 36c78b1
# Commit message: "Updated BitTransformerLM from development space"
from __future__ import annotations
from contextlib import contextmanager
import torch
@contextmanager
def cpu_autocast(enabled: bool = True):
    """Context manager for bfloat16 autocast on CPU.

    Parameters
    ----------
    enabled: bool, default True
        Whether to enable autocast. When ``False`` this context manager
        behaves like a no-op: it yields immediately without entering any
        autocast context, so whatever autocast state is already active
        around the call is left untouched.

    Yields
    ------
    None
        Control is handed to the body of the ``with`` statement; no value
        is bound by ``as``.
    """
    if not enabled:
        # Deliberately a bare yield rather than ``autocast(enabled=False)``:
        # the disabled path must not enter an autocast context at all.
        yield
        return

    with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):
        yield