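"""Quantization utilities for BitTransformerLM.

Provides dynamic int8 quantization for inference plus a minimal 4-bit
quantization-aware-training (QAT) path built on fake quantization.
"""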
import torch
import torch.nn as nn
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.qconfig import QConfig
from .model import BitTransformerLM


def quantize_dynamic(model: BitTransformerLM, dtype: torch.dtype = torch.qint8) -> BitTransformerLM:
"""Return a dynamically quantized copy of the model for inference."""
    quantized = torch.ao.quantization.quantize_dynamic(
model, {nn.Linear}, dtype=dtype
)
return quantized


class FourBitObserver(MinMaxObserver):
"""Min-max observer configured for 4-bit quantization."""
def __init__(self, **kwargs):
super().__init__(
quant_min=0,
quant_max=15,
dtype=torch.quint8,
qscheme=torch.per_tensor_affine,
**kwargs,
)
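

# Factory and QConfig that apply 4-bit fake quantization to both activations and weights.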
FourBitFakeQuantize = FakeQuantize.with_args(observer=FourBitObserver)
four_bit_qconfig = QConfig(activation=FourBitFakeQuantize, weight=FourBitFakeQuantize)


class QATLinear(nn.Linear):
"""Linear layer with fake quantization for QAT."""
def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
super().__init__(in_features, out_features, bias)
self.weight_fake_quant = FourBitFakeQuantize()
self.activation_post_process = FourBitFakeQuantize()
@classmethod
def from_float(cls, mod: nn.Linear) -> "QATLinear":
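        # Share the float module's parameters so QAT keeps fine-tuning the original weights.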
qat = cls(mod.in_features, mod.out_features, mod.bias is not None)
qat.weight = mod.weight
qat.bias = mod.bias
return qat
def forward(self, x: torch.Tensor) -> torch.Tensor:
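        # Fake-quantize the incoming activations and the weights to the 4-bit grid
        # so training observes the rounding error that quantization introduces.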
x = self.activation_post_process(x)
w = self.weight_fake_quant(self.weight)
return nn.functional.linear(x, w, self.bias)


def prepare_qat_fx(model: BitTransformerLM) -> BitTransformerLM:
"""Prepare BitTransformerLM for quantization-aware training."""
for name, module in model.named_children():
if isinstance(module, nn.Linear):
setattr(model, name, QATLinear.from_float(module))
else:
prepare_qat_fx(module)
return model


def convert_qat_fx(model: BitTransformerLM) -> BitTransformerLM:
"""Convert a QAT-prepared model to a quantized version."""
for name, module in model.named_children():
if isinstance(module, QATLinear):
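            # Compute per-tensor affine parameters over the 4-bit range [0, 15].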
w = module.weight.data
qmin, qmax = 0, 15
min_w = w.min()
max_w = w.max()
            # Guard against max_w == min_w, which would otherwise give a zero scale.
            scale = torch.clamp(max_w - min_w, min=1e-8) / (qmax - qmin)
zero_point = qmin - torch.round(min_w / scale)
q_w = torch.clamp(torch.round(w / scale + zero_point), qmin, qmax)
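            # Dequantize back to float so the converted module is a plain nn.Linear
            # whose weights are restricted to 16 representable levels.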
new_mod = nn.Linear(module.in_features, module.out_features, module.bias is not None)
new_mod.weight = nn.Parameter((q_w - zero_point) * scale)
if module.bias is not None:
new_mod.bias = nn.Parameter(module.bias.data)
setattr(model, name, new_mod)
else:
convert_qat_fx(module)
return model
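
# Usage sketch (hypothetical; BitTransformerLM constructor arguments are assumed):
#
#     model = BitTransformerLM(...)          # build the float model
#     model = prepare_qat_fx(model)          # swap nn.Linear layers for QATLinear
#     ...                                    # fine-tune with 4-bit fake quantization
#     model = convert_qat_fx(model)          # bake the 4-bit grid into float weights
#     int8_model = quantize_dynamic(model)   # optional dynamic int8 for inference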