import os, sys; sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) | |
import torch | |
from bit_transformer import BitTransformerLM, distill_step, TelemetryLog | |
def test_distill_prunes_weights():
    """Check that ``distill_step`` zeroes at least ``scale`` of all Linear weights.

    Builds a small ``BitTransformerLM``, passes ``distill_step`` a
    ``TelemetryLog`` holding synthetic attention maps, and asserts that,
    summed over every ``torch.nn.Linear`` in the returned model, the number
    of exactly-zero weight entries is at least ``int(total * scale)``.
    """
    # Seed the global RNG so the synthetic attention maps (and any random
    # parameter initialisation drawn from the global generator) are
    # reproducible; an unseeded run could make the test flaky.
    torch.manual_seed(0)
    scale = 0.5
    model = BitTransformerLM(
        d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8
    )
    # NOTE(review): presumably (batch, heads, seq, seq) attention maps, since
    # dims 1..3 match nhead=4 and max_seq_len=8 above; confirm against TelemetryLog.
    attn = torch.rand(2, 4, 8, 8)
    telemetry = TelemetryLog(attention_maps=attn)
    pruned = distill_step(model, scale=scale, telemetry=telemetry)

    weights = [
        m.weight.detach()
        for m in pruned.modules()
        if isinstance(m, torch.nn.Linear)
    ]
    total = sum(w.numel() for w in weights)
    zeros = sum(int((w == 0).sum().item()) for w in weights)

    # Guard against a vacuous pass: with no Linear layers, 0 >= 0 holds trivially.
    assert total > 0, "expected the pruned model to contain nn.Linear layers"
    threshold = int(total * scale)
    assert zeros >= threshold, (
        f"expected at least {threshold} zeroed Linear weights, got {zeros}/{total}"
    )