import torch
import torch.nn as nn

from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.qconfig import QConfig

from .model import BitTransformerLM


def quantize_dynamic(model: BitTransformerLM, dtype: torch.dtype = torch.qint8) -> BitTransformerLM:
    """Return a dynamically quantized copy of the model for inference."""
    quantized = torch.ao.quantization.quantize_dynamic(
        model, {nn.Linear}, dtype=dtype
    )
    return quantized

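# Usage sketch (kept as comments so importing this module has no side effects).
# The BitTransformerLM constructor arguments are project-specific and shown
# only as a placeholder:
#
#     model = BitTransformerLM(...)        # construct with your hyperparameters
#     model.eval()
#     quantized = quantize_dynamic(model)  # nn.Linear weights become int8
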

class FourBitObserver(MinMaxObserver):
    """Min-max observer configured for 4-bit quantization."""

    def __init__(self, **kwargs):
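        # The 4-bit range [0, 15] is expressed through quant_min/quant_max on a
        # quint8 tensor, a common way to emulate sub-8-bit widths in eager-mode
        # quantization.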
        super().__init__(
            quant_min=0,
            quant_max=15,
            dtype=torch.quint8,
            qscheme=torch.per_tensor_affine,
            **kwargs,
        )


FourBitFakeQuantize = FakeQuantize.with_args(observer=FourBitObserver)

four_bit_qconfig = QConfig(activation=FourBitFakeQuantize, weight=FourBitFakeQuantize)

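# Sketch of what the 4-bit fake-quantize does to a tensor: values are mapped to
# one of 16 levels via round(x / scale + zero_point), clamped to [0, 15], and
# then dequantized back to float. Illustrative only:
#
#     fq = FourBitFakeQuantize()
#     x = torch.linspace(-1.0, 1.0, steps=64)
#     y = fq(x)                  # observer records min/max, then fake-quantizes
#     assert torch.unique(y).numel() <= 16
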

class QATLinear(nn.Linear):
    """Linear layer with fake quantization for QAT."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super().__init__(in_features, out_features, bias)
        self.weight_fake_quant = FourBitFakeQuantize()
        self.activation_post_process = FourBitFakeQuantize()

    @classmethod
    def from_float(cls, mod: nn.Linear) -> "QATLinear":
        qat = cls(mod.in_features, mod.out_features, mod.bias is not None)
        qat.weight = mod.weight
        qat.bias = mod.bias
        return qat

    def forward(self, x: torch.Tensor) -> torch.Tensor:
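        # Fake-quantize the input activations and the weights on the fly;
        # FakeQuantize uses a straight-through estimator, so gradients still
        # flow to the underlying float weights during QAT.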
        x = self.activation_post_process(x)
        w = self.weight_fake_quant(self.weight)
        return nn.functional.linear(x, w, self.bias)

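# Sketch: wrapping a single float linear layer for 4-bit QAT. The sizes are
# arbitrary; from_float shares the original weight and bias, so the wrapped
# layer trains in place of the original:
#
#     float_lin = nn.Linear(16, 8)
#     qat_lin = QATLinear.from_float(float_lin)
#     out = qat_lin(torch.randn(2, 16))  # forward runs through fake quantization
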

def prepare_qat_fx(model: BitTransformerLM) -> BitTransformerLM:
    """Prepare BitTransformerLM for quantization-aware training."""

    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            setattr(model, name, QATLinear.from_float(module))
        else:
            prepare_qat_fx(module)
    return model

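# Sketch: after preparation every former nn.Linear in the tree should be a
# QATLinear. A quick structural check (assuming `model` is a prepared
# BitTransformerLM instance); see the __main__ smoke test at the bottom of
# this file for a full round trip:
#
#     swapped = [n for n, m in model.named_modules() if isinstance(m, QATLinear)]
#     assert swapped, "no linear layers were wrapped for QAT"
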

def convert_qat_fx(model: BitTransformerLM) -> BitTransformerLM:
    """Convert a QAT-prepared model to a quantized version."""

    for name, module in model.named_children():
        if isinstance(module, QATLinear):
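            # Per-tensor affine quantize-then-dequantize: round the float
            # weights onto the 16 representable 4-bit levels.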
            w = module.weight.data
            qmin, qmax = 0, 15
            min_w = w.min()
            max_w = w.max()
            # Clamp the scale away from zero so a constant weight tensor does
            # not cause a division by zero below.
            scale = torch.clamp((max_w - min_w) / (qmax - qmin), min=1e-8)
            zero_point = qmin - torch.round(min_w / scale)
            q_w = torch.clamp(torch.round(w / scale + zero_point), qmin, qmax)
            new_mod = nn.Linear(module.in_features, module.out_features, module.bias is not None)
            new_mod.weight = nn.Parameter((q_w - zero_point) * scale)
            if module.bias is not None:
                new_mod.bias = nn.Parameter(module.bias.data)
            setattr(model, name, new_mod)
        else:
            convert_qat_fx(module)
    return model
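

# Optional smoke test: a sketch of the full 4-bit QAT round trip on a small
# stand-in module. _Stub is hypothetical and exists only to exercise the
# helpers above; it is not part of BitTransformerLM. Run the module with
# "python -m" from the package root so the relative import above resolves.
if __name__ == "__main__":
    class _Stub(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.proj = nn.Linear(8, 4)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.proj(x)

    stub = prepare_qat_fx(_Stub())  # type: ignore[arg-type]
    _ = stub(torch.randn(2, 8))     # stands in for a QAT fine-tuning step
    stub = convert_qat_fx(stub)     # type: ignore[arg-type]
    assert not isinstance(stub.proj, QATLinear)
    print("4-bit levels in converted weight:", stub.proj.weight.unique().numel())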