# BitTransformerLM — tests/test_compression.py
# Tests for the bit-level compression utilities and compressed model inputs.
import os, sys
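# Make the repository root importable so `bit_transformer` resolves when the
# tests are run directly from this directory.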
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import torch
from bit_transformer import BitTransformerLM, compress_bits, decompress_bits, model_output_decompress


def test_compress_roundtrip():
    """Compressing a bit tensor and decompressing it should be lossless."""
    bits = torch.randint(0, 2, (16,), dtype=torch.uint8)
    comp = compress_bits(bits)
    decomp = decompress_bits(comp)
    assert torch.equal(bits, decomp)


def test_forward_compressed_equivalence():
    """Running the model on compressed inputs should match the uncompressed forward pass."""
    B, L = 2, 8
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=L)
    model.eval()
    bits = torch.randint(0, 2, (B, L), dtype=torch.long)
    logits_a, tele_a = model(bits)
    compressed = [compress_bits(row.to(torch.uint8)) for row in bits]
    logits_b, tele_b = model.forward_compressed(compressed)
    assert torch.allclose(logits_a, logits_b)
    # Compare tensor-valued telemetry entries; list-valued entries are skipped
    # because torch.allclose does not apply to them.
    for key in tele_a:
        if isinstance(tele_a[key], list):
            continue
        assert torch.allclose(tele_a[key], tele_b[key])


def test_model_output_decompress():
    """model_output_decompress should reverse per-row compression of a batch."""
    bits = torch.randint(0, 2, (2, 8), dtype=torch.uint8)
    comp = [compress_bits(row) for row in bits]
    decomp = model_output_decompress(comp)
    assert torch.equal(decomp, bits)


def test_metrics_on_compressed():
    """Telemetry KPIs should accept a padded batch of compressed sequences."""
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    bits = torch.randint(0, 2, (2, 8), dtype=torch.uint8)
    comps = [compress_bits(row) for row in bits]
    # Compressed rows can differ in length, so pad them into a single batch tensor.
    comp_batch = torch.nn.utils.rnn.pad_sequence(comps, batch_first=True)
    neg = model.negentropy_kpi(comp_batch)
    assert neg.shape[0] == bits.size(0)
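

# The long-run tests below pin down the run-length layout implied by their
# expected outputs: (value, count) byte pairs with counts capped at 255, so a
# run longer than 255 bits is split across multiple pairs.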


def test_compress_long_run_split():
    """A run longer than 255 bits is split into multiple (value, count) pairs."""
    bits = torch.zeros(300, dtype=torch.uint8)
    comp = compress_bits(bits)
    # 300 zeros -> a full chunk of 255 plus a remainder of 45.
    expected = torch.tensor([0, 255, 0, 45], dtype=torch.uint8)
    assert torch.equal(comp, expected)
    decomp = decompress_bits(comp)
    assert torch.equal(decomp, bits)


def test_compress_long_run_with_change():
    """An over-long run followed by a value change is encoded as separate pairs."""
    run1 = torch.ones(260, dtype=torch.uint8)
    run2 = torch.zeros(10, dtype=torch.uint8)
    bits = torch.cat([run1, run2])
    comp = compress_bits(bits)
    # 260 ones -> (1, 255) + (1, 5); the trailing 10 zeros -> (0, 10).
    expected = torch.tensor([1, 255, 1, 5, 0, 10], dtype=torch.uint8)
    assert torch.equal(comp, expected)
    decomp = decompress_bits(comp)
    assert torch.equal(decomp, bits)