# File size: 2,291 Bytes
# Revision: 36c78b1
import os, sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import torch
from bit_transformer import BitTransformerLM, compress_bits, decompress_bits, model_output_decompress
def test_compress_roundtrip():
    """compress_bits followed by decompress_bits must reproduce the input bits."""
    # Local fixed-seed generator: a failing input is reproducible, and the
    # global RNG state seen by other tests is left untouched.
    gen = torch.Generator().manual_seed(0)
    bits = torch.randint(0, 2, (16,), dtype=torch.uint8, generator=gen)
    comp = compress_bits(bits)
    decomp = decompress_bits(comp)
    assert torch.equal(bits, decomp)
def test_forward_compressed_equivalence():
    """forward_compressed on compressed rows must match forward on the raw bits.

    Logits and every non-list telemetry entry returned by ``model(bits)`` are
    compared against those returned by ``model.forward_compressed``.
    """
    B, L = 2, 8
    # Seed model initialisation and inputs for reproducibility; fork_rng
    # restores the global CPU RNG state when the block exits.
    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(0)
        model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=L)
        bits = torch.randint(0, 2, (B, L), dtype=torch.long)
    # eval() switches off training-only behaviour (e.g. dropout, if the model
    # has any) so the two forward passes are comparable.
    model.eval()
    logits_a, tele_a = model(bits)
    # Compress each row independently; rows are cast to uint8, the dtype the
    # other compression tests in this file feed to compress_bits.
    compressed = [compress_bits(row.to(torch.uint8)) for row in bits]
    logits_b, tele_b = model.forward_compressed(compressed)
    assert torch.allclose(logits_a, logits_b)
    for key, value_a in tele_a.items():
        # List-valued telemetry entries are skipped; only tensors are compared.
        if isinstance(value_a, list):
            continue
        assert torch.allclose(value_a, tele_b[key])
def test_model_output_decompress():
    """model_output_decompress must rebuild the batch from per-row compressed tensors."""
    # Local fixed-seed generator keeps the input reproducible without
    # touching the global RNG.
    gen = torch.Generator().manual_seed(0)
    bits = torch.randint(0, 2, (2, 8), dtype=torch.uint8, generator=gen)
    comp = [compress_bits(row) for row in bits]
    decomp = model_output_decompress(comp)
    assert torch.equal(decomp, bits)
def test_metrics_on_compressed():
    """negentropy_kpi must accept a padded batch of compressed rows.

    The result's leading dimension must equal the number of input rows.
    """
    # Seed model initialisation and inputs for reproducibility; fork_rng
    # restores the global CPU RNG state when the block exits.
    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(0)
        model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
        bits = torch.randint(0, 2, (2, 8), dtype=torch.uint8)
    comps = [compress_bits(row) for row in bits]
    # Compressed rows may differ in length; pad_sequence right-pads them with
    # zeros into a single (rows, max_len) tensor.
    comp_batch = torch.nn.utils.rnn.pad_sequence(comps, batch_first=True)
    neg = model.negentropy_kpi(comp_batch)
    assert neg.shape[0] == bits.size(0)
def test_compress_long_run_split():
    """A 300-bit run of zeros must encode as two (bit, count) pairs: 255, then 45."""
    run_length = 300
    zeros = torch.zeros(run_length, dtype=torch.uint8)
    # The expected encoding implies counts are capped at 255, so a longer run
    # is split into a full 255-count pair followed by the remainder.
    want = torch.tensor([0, 255, 0, run_length - 255], dtype=torch.uint8)
    encoded = compress_bits(zeros)
    assert torch.equal(encoded, want)
    assert torch.equal(decompress_bits(encoded), zeros)
def test_compress_long_run_with_change():
    """A split run of ones followed by a short run of zeros must encode as three pairs."""
    ones_then_zeros = torch.cat(
        [torch.ones(260, dtype=torch.uint8), torch.zeros(10, dtype=torch.uint8)]
    )
    encoded = compress_bits(ones_then_zeros)
    # 260 ones -> (1, 255), (1, 5); then 10 zeros -> (0, 10).
    want = torch.tensor([1, 255, 1, 5, 0, 10], dtype=torch.uint8)
    assert torch.equal(encoded, want)
    assert torch.equal(decompress_bits(encoded), ones_then_zeros)
|