BitTransformerLM / tests /test_distil.py
WCNegentropy's picture
🤖 Updated BitTransformerLM from development space
36c78b1 verified
raw
history blame
719 Bytes
import os, sys; sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import torch
from bit_transformer import BitTransformerLM, distill_step, TelemetryLog
def test_distill_prunes_weights():
    """Check that ``distill_step`` zeroes at least a ``scale`` fraction of Linear weights.

    Builds a small BitTransformerLM and a TelemetryLog holding random attention
    maps, runs ``distill_step`` with ``scale``, then counts exactly-zero entries
    across every ``torch.nn.Linear`` weight in the returned model. The test
    requires at least ``int(total * scale)`` of them to be zero.
    """
    # Seed the RNG so model initialisation and the synthetic attention maps are
    # reproducible from run to run.
    torch.manual_seed(0)
    scale = 0.5  # single source of truth for both the call and the threshold

    model = BitTransformerLM(
        d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8
    )
    # Shape (2, 4, 8, 8): presumably (batch, nhead, seq, seq), matching nhead=4
    # and max_seq_len=8 above. TODO: confirm against TelemetryLog.
    attn = torch.rand(2, 4, 8, 8)
    telemetry = TelemetryLog(attention_maps=attn)
    pruned = distill_step(model, scale=scale, telemetry=telemetry)

    weights = [
        m.weight.detach()
        for m in pruned.modules()
        if isinstance(m, torch.nn.Linear)
    ]
    total = sum(w.numel() for w in weights)
    zeros = sum((w == 0).sum().item() for w in weights)

    # Guard against a vacuous pass: with no Linear layers, 0 >= 0 would succeed.
    assert total > 0, "expected the pruned model to contain nn.Linear layers"
    assert zeros >= int(total * scale), (
        f"only {zeros}/{total} Linear weights are zero; "
        f"expected at least {int(total * scale)}"
    )