"""Quantization utilities for BitTransformerLM.

Provides dynamic int8 quantization for inference and a simple 4-bit
quantization-aware training (QAT) path based on module swapping.
"""

import torch
import torch.nn as nn
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.qconfig import QConfig

from .model import BitTransformerLM


def quantize_dynamic(
    model: BitTransformerLM, dtype: torch.dtype = torch.qint8
) -> BitTransformerLM:
    """Return a dynamically quantized copy of the model for inference."""
    # Dynamic quantization rewrites nn.Linear layers to use int8 weights;
    # activations are quantized on the fly at runtime.
    quantized = torch.ao.quantization.quantize_dynamic(
        model, {nn.Linear}, dtype=dtype
    )
    return quantized


class FourBitObserver(MinMaxObserver):
    """Min-max observer configured for 4-bit quantization."""

    def __init__(self, **kwargs):
        # Restrict the quantization range to [0, 15] (2**4 levels) while
        # storing values in quint8, which PyTorch supports natively.
        super().__init__(
            quant_min=0,
            quant_max=15,
            dtype=torch.quint8,
            qscheme=torch.per_tensor_affine,
            **kwargs,
        )


FourBitFakeQuantize = FakeQuantize.with_args(observer=FourBitObserver)
four_bit_qconfig = QConfig(activation=FourBitFakeQuantize, weight=FourBitFakeQuantize)


class QATLinear(nn.Linear):
    """Linear layer with fake quantization for QAT."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super().__init__(in_features, out_features, bias)
        self.weight_fake_quant = FourBitFakeQuantize()
        self.activation_post_process = FourBitFakeQuantize()

    @classmethod
    def from_float(cls, mod: nn.Linear) -> "QATLinear":
        """Build a QATLinear that shares parameters with a float nn.Linear."""
        qat = cls(mod.in_features, mod.out_features, mod.bias is not None)
        qat.weight = mod.weight
        qat.bias = mod.bias
        return qat

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fake-quantize both the incoming activations and the weights so the
        # network learns to tolerate 4-bit rounding during training.
        x = self.activation_post_process(x)
        w = self.weight_fake_quant(self.weight)
        return nn.functional.linear(x, w, self.bias)


def prepare_qat_fx(model: BitTransformerLM) -> BitTransformerLM:
    """Prepare BitTransformerLM for quantization-aware training.

    Recursively swaps every ``nn.Linear`` child for a ``QATLinear``.
    """
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            setattr(model, name, QATLinear.from_float(module))
        else:
            prepare_qat_fx(module)
    return model


def convert_qat_fx(model: BitTransformerLM) -> BitTransformerLM:
    """Convert a QAT-prepared model to a quantized version.

    Each ``QATLinear`` is replaced by a plain ``nn.Linear`` whose weights have
    been rounded to the 4-bit grid and dequantized, baking the quantization
    error into the stored float weights.
    """
    for name, module in model.named_children():
        if isinstance(module, QATLinear):
            w = module.weight.data
            qmin, qmax = 0, 15
            min_w = w.min()
            max_w = w.max()
            # Affine quantization parameters; clamp the scale so a constant
            # weight tensor (max_w == min_w) does not divide by zero.
            scale = (max_w - min_w) / (qmax - qmin)
            scale = torch.clamp(scale, min=1e-8)
            zero_point = qmin - torch.round(min_w / scale)
            q_w = torch.clamp(torch.round(w / scale + zero_point), qmin, qmax)
            new_mod = nn.Linear(
                module.in_features, module.out_features, module.bias is not None
            )
            # Store the dequantized weights in a standard nn.Linear.
            new_mod.weight = nn.Parameter((q_w - zero_point) * scale)
            if module.bias is not None:
                new_mod.bias = nn.Parameter(module.bias.data)
            setattr(model, name, new_mod)
        else:
            convert_qat_fx(module)
    return model
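

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal end-to-end QAT round trip using the module-swap helpers above.
# It uses a small stand-in network of nn.Linear layers rather than a real
# BitTransformerLM so that no assumptions are made about that class's
# constructor; substitute your actual model, data, and training loop.
# For plain int8 inference of a float model, quantize_dynamic() can be used
# instead of the QAT path.
if __name__ == "__main__":
    toy = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))

    # Swap every nn.Linear for a QATLinear with 4-bit fake quantization.
    qat_model = prepare_qat_fx(toy)

    # A few dummy training steps so the observers see weight/activation ranges
    # and the model adapts to the simulated rounding.
    opt = torch.optim.SGD(qat_model.parameters(), lr=1e-3)
    for _ in range(10):
        x = torch.randn(8, 16)
        loss = qat_model(x).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()

    # Bake the simulated 4-bit rounding into plain nn.Linear weights.
    deployed = convert_qat_fx(qat_model)
    print(deployed)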