File size: 2,291 Bytes
36c78b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import torch
from bit_transformer import BitTransformerLM, compress_bits, decompress_bits, model_output_decompress


def test_compress_roundtrip():
    """Check that ``decompress_bits`` undoes ``compress_bits``.

    A random 16-element uint8 vector of 0/1 values is compressed and then
    decompressed; the result must equal the input exactly.
    """
    original = torch.randint(0, 2, (16,), dtype=torch.uint8)
    restored = decompress_bits(compress_bits(original))
    assert torch.equal(original, restored)


def test_forward_compressed_equivalence():
    """Check that ``forward_compressed`` agrees with the plain forward pass.

    The same random bit batch is fed to the model twice: once directly and
    once as a list of per-row compressed tensors. The logits must be close,
    and so must every entry of the second return value that is not a list.
    """
    batch_size, seq_len = 2, 8
    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=seq_len,
    )
    model.eval()

    raw_bits = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.long)
    raw_logits, raw_aux = model(raw_bits)

    # compress_bits is given uint8 rows, so each row is cast before encoding.
    packed_rows = [compress_bits(row.to(torch.uint8)) for row in raw_bits]
    packed_logits, packed_aux = model.forward_compressed(packed_rows)

    assert torch.allclose(raw_logits, packed_logits)
    for name in raw_aux:
        reference = raw_aux[name]
        # List-valued entries are skipped; only the remaining ones are compared.
        if not isinstance(reference, list):
            assert torch.allclose(reference, packed_aux[name])


def test_model_output_decompress():
    """Check that ``model_output_decompress`` rebuilds the original batch.

    Each row of a random 2x8 uint8 bit matrix is compressed on its own; the
    resulting list, once decompressed, must equal the original matrix.
    """
    original = torch.randint(0, 2, (2, 8), dtype=torch.uint8)
    packed_rows = list(map(compress_bits, original))
    restored = model_output_decompress(packed_rows)
    assert torch.equal(restored, original)


def test_metrics_on_compressed():
    """Check that ``negentropy_kpi`` yields one entry per padded compressed row.

    Rows of a random 2x8 bit matrix are compressed individually, padded into
    one batch tensor, and passed to the metric. The leading dimension of the
    result must match the number of input rows.
    """
    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=8,
    )
    bits = torch.randint(0, 2, (2, 8), dtype=torch.uint8)
    packed_rows = [compress_bits(row) for row in bits]
    # Compressed rows may differ in length, so pad them to stack into a batch.
    padded = torch.nn.utils.rnn.pad_sequence(packed_rows, batch_first=True)
    scores = model.negentropy_kpi(padded)
    assert scores.shape[0] == bits.size(0)


def test_compress_long_run_split():
    """Check the encoding of a single run of 300 zeros.

    The compressed tensor must be exactly ``[0, 255, 0, 45]`` (255 + 45 = 300),
    and decompressing it must give back the 300 zeros.
    """
    zeros = torch.zeros(300, dtype=torch.uint8)
    encoded = compress_bits(zeros)
    assert torch.equal(
        encoded, torch.tensor([0, 255, 0, 45], dtype=torch.uint8)
    )
    assert torch.equal(decompress_bits(encoded), zeros)


def test_compress_long_run_with_change():
    """Check the encoding of 260 ones followed by 10 zeros.

    The compressed tensor must be exactly ``[1, 255, 1, 5, 0, 10]``
    (255 + 5 = 260 ones, then 10 zeros), and decompressing it must give back
    the original sequence.
    """
    bits = torch.cat(
        (
            torch.ones(260, dtype=torch.uint8),
            torch.zeros(10, dtype=torch.uint8),
        )
    )
    encoded = compress_bits(bits)
    wanted = torch.tensor([1, 255, 1, 5, 0, 10], dtype=torch.uint8)
    assert torch.equal(encoded, wanted)
    assert torch.equal(decompress_bits(encoded), bits)