# Test: distillation pruning of BitTransformerLM linear weights.
import os, sys; sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import torch
from bit_transformer import BitTransformerLM, distill_step, TelemetryLog
def test_distill_prunes_weights():
    """Distilling at scale=0.5 should zero out at least half of all Linear weights."""
    model = BitTransformerLM(
        d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8
    )
    # Synthetic attention maps: (batch, heads, seq, seq) for the telemetry log.
    attn_maps = torch.rand(2, 4, 8, 8)
    log = TelemetryLog(attention_maps=attn_maps)
    student = distill_step(model, scale=0.5, telemetry=log)

    # Gather every Linear weight tensor in the distilled model.
    linear_weights = [
        module.weight.detach()
        for module in student.modules()
        if isinstance(module, torch.nn.Linear)
    ]
    weight_count = sum(w.numel() for w in linear_weights)
    zero_count = sum((w == 0).sum().item() for w in linear_weights)

    # At scale 0.5, at least half of all Linear parameters must be pruned to zero.
    assert zero_count >= int(weight_count * 0.5)