File size: 719 Bytes
36c78b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import os, sys; sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import torch

from bit_transformer import BitTransformerLM, distill_step, TelemetryLog


def test_distill_prunes_weights():
    """Check that ``distill_step`` zeroes at least ``scale`` of the Linear weights.

    Builds a small ``BitTransformerLM``, passes random attention maps through a
    ``TelemetryLog``, runs one ``distill_step`` and counts the entries that are
    exactly zero across every ``torch.nn.Linear`` weight of the returned model.
    """
    # Seed so both model initialisation and the random attention maps are
    # reproducible; otherwise a borderline pruning ratio could make this flaky.
    torch.manual_seed(0)
    scale = 0.5
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    # NOTE(review): presumably (batch, heads, seq, seq) — 4 and 8 match nhead and
    # max_seq_len above; confirm against TelemetryLog's expectations.
    attn = torch.rand(2, 4, 8, 8)
    telemetry = TelemetryLog(attention_maps=attn)
    pruned = distill_step(model, scale=scale, telemetry=telemetry)

    weights = [m.weight.detach() for m in pruned.modules() if isinstance(m, torch.nn.Linear)]
    # Guard against a vacuous pass: with no Linear layers both counts would be 0
    # and the ratio check below would trivially succeed.
    assert weights, "expected the pruned model to contain torch.nn.Linear layers"
    total = sum(w.numel() for w in weights)
    zeros = sum(int((w == 0).sum().item()) for w in weights)
    assert zeros >= int(total * scale), f"only {zeros}/{total} Linear weights are zero"