File size: 503 Bytes
36c78b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from __future__ import annotations

from contextlib import contextmanager
import torch


@contextmanager
def cpu_autocast(enabled: bool = True):
    """Optionally run the wrapped block under CPU bfloat16 autocast.

    Parameters
    ----------
    enabled: bool, default True
        If truthy, the body of the ``with`` statement executes inside
        ``torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16)``.
        If falsy, no autocast context is entered at all and the body runs
        exactly as it would outside this context manager.

    Yields
    ------
    None
        Nothing is bound by ``with cpu_autocast() as x``.
    """
    # Disabled path first: hand control to the caller without touching
    # torch's autocast state, then finish once the ``with`` body is done.
    if not enabled:
        yield
        return

    autocast_ctx = torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16)
    with autocast_ctx:
        yield