import os
import sys

# Put the parent of this file's directory at the front of sys.path before the
# project import below (presumably the repository root that holds the
# `bit_transformer` package -- confirm against the repo layout).
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import torch

from bit_transformer import BitTransformerLM, distill_step, TelemetryLog


def test_distill_prunes_weights():
    """Check that ``distill_step`` with ``scale=0.5`` zeroes enough weights.

    Builds a small ``BitTransformerLM``, passes it through ``distill_step``
    together with a ``TelemetryLog`` carrying random attention maps, and
    asserts that at least half of all ``torch.nn.Linear`` weight entries in
    the returned model are exactly zero.
    """
    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=8,
    )

    # Random stand-in for recorded attention maps.
    # NOTE(review): the (2, 4, 8, 8) shape presumably means
    # (batch, heads, seq, seq), matching nhead=4 and max_seq_len=8 -- confirm.
    attn = torch.rand(2, 4, 8, 8)
    telemetry = TelemetryLog(attention_maps=attn)

    pruned = distill_step(model, scale=0.5, telemetry=telemetry)

    # Count all weight entries, and the exactly-zero ones, over every Linear
    # module reachable from the returned model.
    total = 0
    zeros = 0
    for module in pruned.modules():
        if isinstance(module, torch.nn.Linear):
            weight = module.weight.detach()
            total += weight.numel()
            zeros += (weight == 0).sum().item()

    # Without this guard the threshold check below would pass vacuously
    # (0 >= 0) if the model exposed no Linear modules at all.
    assert total > 0, "distilled model has no torch.nn.Linear weights to inspect"
    assert zeros >= int(total * 0.5), (
        f"expected at least {int(total * 0.5)} zeroed Linear weights, "
        f"found {zeros} of {total}"
    )