import torch
import torch.nn as nn

from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.qconfig import QConfig

from .model import BitTransformerLM


def quantize_dynamic(model: BitTransformerLM, dtype: torch.dtype = torch.qint8) -> BitTransformerLM:
    """Return a dynamically quantized copy of the model for inference."""
    # Store nn.Linear weights in the given integer dtype and quantize
    # activations on the fly at inference time.
    quantized = torch.ao.quantization.quantize_dynamic(
        model, {nn.Linear}, dtype=dtype
    )
    return quantized
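
# Example usage (sketch; assumes a trained BitTransformerLM instance -- the
# constructor arguments and the input tensor are illustrative placeholders):
#
#     model = BitTransformerLM(...)          # build or load the float model
#     model.eval()
#     int8_model = quantize_dynamic(model)   # torch.qint8 weights by default
#     output = int8_model(bits)              # CPU inference with the quantized copy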


class FourBitObserver(MinMaxObserver):
    """Min-max observer configured for 4-bit quantization."""

    def __init__(self, **kwargs):
        # 4-bit affine range [0, 15], stored in a quint8 container tensor.
        super().__init__(
            quant_min=0,
            quant_max=15,
            dtype=torch.quint8,
            qscheme=torch.per_tensor_affine,
            **kwargs,
        )


# Fake-quantize module driven by the 4-bit observer, and a QConfig that applies
# it to both activations and weights.
FourBitFakeQuantize = FakeQuantize.with_args(observer=FourBitObserver)

four_bit_qconfig = QConfig(activation=FourBitFakeQuantize, weight=FourBitFakeQuantize)
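
# Example: the same qconfig can also be attached for PyTorch's stock eager-mode
# QAT flow instead of the manual helpers below (sketch; assumes the model's
# layers are covered by the default QAT module mappings):
#
#     model.train()
#     model.qconfig = four_bit_qconfig
#     torch.ao.quantization.prepare_qat(model, inplace=True)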


class QATLinear(nn.Linear):
    """Linear layer with fake quantization for QAT."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super().__init__(in_features, out_features, bias)
        self.weight_fake_quant = FourBitFakeQuantize()
        self.activation_post_process = FourBitFakeQuantize()

    @classmethod
    def from_float(cls, mod: nn.Linear) -> "QATLinear":
        """Build a QAT layer that shares parameters with a float nn.Linear."""
        qat = cls(mod.in_features, mod.out_features, mod.bias is not None)
        qat.weight = mod.weight
        qat.bias = mod.bias
        return qat

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fake-quantize inputs and weights so training sees 4-bit rounding noise.
        x = self.activation_post_process(x)
        w = self.weight_fake_quant(self.weight)
        return nn.functional.linear(x, w, self.bias)


def prepare_qat_fx(model: BitTransformerLM) -> BitTransformerLM:
    """Prepare BitTransformerLM for quantization-aware training."""
    # Recursively swap every nn.Linear for a QATLinear that fake-quantizes
    # weights and activations; other submodules are traversed in place.
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            setattr(model, name, QATLinear.from_float(module))
        else:
            prepare_qat_fx(module)
    return model
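
# Example QAT round trip (sketch; ``fine_tune`` is a placeholder for whatever
# training loop the project uses, and ``convert_qat_fx`` is defined below):
#
#     model = prepare_qat_fx(BitTransformerLM(...))
#     fine_tune(model)               # train with 4-bit fake quantization active
#     model = convert_qat_fx(model)  # fold the 4-bit rounding into float weights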


def convert_qat_fx(model: BitTransformerLM) -> BitTransformerLM:
    """Convert a QAT-prepared model to a quantized version."""
    # Replace each QATLinear with a plain nn.Linear whose weights have been
    # rounded to the 4-bit grid and dequantized back to float.
    for name, module in model.named_children():
        if isinstance(module, QATLinear):
            w = module.weight.data
            qmin, qmax = 0, 15
            min_w = w.min()
            max_w = w.max()
            # Per-tensor affine parameters; clamp the scale so a constant weight
            # tensor (max_w == min_w) does not cause division by zero.
            scale = (max_w - min_w) / (qmax - qmin)
            scale = torch.clamp(scale, min=1e-8)
            zero_point = qmin - torch.round(min_w / scale)
            q_w = torch.clamp(torch.round(w / scale + zero_point), qmin, qmax)
            new_mod = nn.Linear(module.in_features, module.out_features, module.bias is not None)
            new_mod.weight = nn.Parameter((q_w - zero_point) * scale)
            if module.bias is not None:
                new_mod.bias = nn.Parameter(module.bias.data)
            setattr(model, name, new_mod)
        else:
            convert_qat_fx(module)
    return model
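
# Worked example of the 4-bit affine mapping used above (illustrative numbers):
# for weights spanning [-0.30, 0.45], scale = 0.75 / 15 = 0.05 and
# zero_point = 0 - round(-0.30 / 0.05) = 6, so a weight of 0.10 maps to
# round(0.10 / 0.05 + 6) = 8 and dequantizes back to (8 - 6) * 0.05 = 0.10.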