import os
import sys

# Allow running the tests without installing the package: put the repo root
# on sys.path so "bit_transformer" resolves.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

from bit_transformer import (
    BitTransformerLM,
    hil_safe_inference,
    text_to_bits,
    bits_to_text,
    plot_telemetry,
    infer_long_sequence,
    diffusion_inference,
    compress_bits,
)
from bit_transformer.safety import SafetyGate

def test_forward_pass():
    B, L = 2, 8
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=L)
    bits = torch.randint(0, 2, (B, L), dtype=torch.long)
    logits, telemetry = model(bits)
    assert logits.shape == (B, L, 2)
    required_keys = {
        "negentropy_input",
        "lz_complexity_input",
        "negentropy_logits",
        "lz_complexity_logits",
        "symbiosis_kl",
        "symbiosis_score",
        "attention_entropy",
        "attention_entropy_mean",
    }
    assert required_keys.issubset(telemetry.keys())
    # Next-bit prediction loss: logits at position t predict the bit at t+1.
    pred = logits[:, :-1, :].reshape(-1, 2)
    target = bits[:, 1:].reshape(-1)
    loss = F.cross_entropy(pred, target)
    assert torch.isfinite(loss)

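# Added sketch (not part of the original suite): one SGD step on the same
# next-bit loss as test_forward_pass. It assumes logits returned in train
# mode carry gradients, which no other test here confirms directly, so it
# skips rather than fails if they come back detached.
def test_single_training_step():
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    model.train()
    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    bits = torch.randint(0, 2, (2, 8), dtype=torch.long)
    logits, _ = model(bits)
    if not logits.requires_grad:
        pytest.skip("model returned detached logits")
    loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, 2), bits[:, 1:].reshape(-1))
    opt.zero_grad()
    loss.backward()
    opt.step()
    assert torch.isfinite(loss)
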
def test_autocast_forward():
    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=8,
        use_autocast=True,
    )
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    logits, _ = model(bits)
    assert logits.shape == (1, 8, 2)

def test_act_forward():
    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=2,
        dim_feedforward=64,
        max_seq_len=8,
        use_act=True,
    )
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    logits, tele = model(bits)
    assert logits.shape == (1, 8, 2)
    assert "halt_probs" in tele

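# Added sketch: with ACT enabled but no forced halting, "halt_probs" should
# record at most one entry per layer. The bound is an assumption inferred
# from test_act_skips_layers below, where forced halting yields fewer entries.
def test_act_halt_probs_bounded():
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=2, dim_feedforward=64, max_seq_len=8, use_act=True)
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    _, tele = model(bits)
    assert len(tele["halt_probs"]) <= model.num_layers
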
def test_act_skips_layers():
    model = BitTransformerLM(
        d_model=16,
        nhead=4,
        num_layers=3,
        dim_feedforward=32,
        max_seq_len=8,
        use_act=True,
        act_threshold=0.5,
    )
    # Force every halting projection to emit a large positive logit so the
    # halt probability saturates and later layers are skipped.
    for proj in model.halt_projs:
        nn.init.constant_(proj.weight, 0.0)
        nn.init.constant_(proj.bias, 10.0)
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    _, tele = model(bits)
    assert len(tele["halt_probs"]) < model.num_layers

def test_hil_safety_gate():
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    # Floors of 1.0 are unattainable for an untrained model, so strict
    # inference must raise.
    with pytest.raises(RuntimeError):
        hil_safe_inference(model, bits, c_floor=1.0, s_floor=1.0)

def test_hil_safety_non_strict():
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    out, _ = hil_safe_inference(model, bits, c_floor=1.0, s_floor=1.0, strict=False)
    assert out.shape == bits.shape

def test_safety_gate_burn_in():
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    gate = SafetyGate(c_floor=1.0, s_floor=1.0, burn_in=1)
    # The first call falls inside the burn-in window and passes; the second
    # is gated against the (unattainable) floors and must raise.
    hil_safe_inference(model, bits, gate=gate)
    with pytest.raises(RuntimeError):
        hil_safe_inference(model, bits, gate=gate)

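# Added sketch: assumes the burn-in window counts whole calls, so burn_in=2
# should let exactly two calls through before the floors are enforced.
def test_safety_gate_longer_burn_in():
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    gate = SafetyGate(c_floor=1.0, s_floor=1.0, burn_in=2)
    hil_safe_inference(model, bits, gate=gate)
    hil_safe_inference(model, bits, gate=gate)
    with pytest.raises(RuntimeError):
        hil_safe_inference(model, bits, gate=gate)
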
def test_bit_io_roundtrip():
    text = "hello"
    bits = text_to_bits(text)
    assert bits_to_text(bits) == text

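# Added sketch: assumes text_to_bits/bits_to_text round-trip arbitrary ASCII
# strings, not just the "hello" case exercised above.
def test_bit_io_roundtrip_longer_text():
    text = "BitTransformerLM round-trip test string!"
    assert bits_to_text(text_to_bits(text)) == text
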
def test_plot_telemetry():
    log = {
        "negentropy": [0.6, 0.7, 0.4],
        "lz_complexity": [0.5, 0.45, 0.6],
        "symbiosis_score": [0.55, 0.6, 0.3],
        "clusters": [0, 0, 1],
    }
    # Only the three metric series get axes; the cluster labels do not.
    fig, axes = plot_telemetry(log)
    assert len(axes) == 3
    fig.clf()

def test_metric_no_gradient_flow():
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    bits = torch.randint(0, 2, (2, 8), dtype=torch.long)
    logits, _ = model(bits)
    # Telemetry metrics are detached from the graph: using them as a "loss"
    # must not be differentiable.
    loss = model.negentropy_logits(logits).mean() + model.lz_complexity_logits(logits).mean()
    assert not loss.requires_grad
    with pytest.raises(RuntimeError):
        loss.backward()

def test_negentropy_decompression_edge_case():
    bits = torch.tensor([0, 1] * 8, dtype=torch.uint8)
    comp = compress_bits(bits)
    model = BitTransformerLM(d_model=16, nhead=2, num_layers=1, dim_feedforward=32, max_seq_len=bits.numel())
    neg_comp = model.negentropy_kpi(comp.unsqueeze(0))
    neg_raw = model.negentropy_kpi(bits.unsqueeze(0))
    assert torch.allclose(neg_comp, neg_raw, atol=1e-6)

def test_dynamic_quantization():
    from bit_transformer import quantize_dynamic

    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    qmodel = quantize_dynamic(model)
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    logits, _ = qmodel(bits)
    assert logits.shape == (1, 8, 2)

def test_qat_fx_roundtrip():
    from bit_transformer import prepare_qat_fx, convert_qat_fx

    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    example_bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    qat_model = prepare_qat_fx(model)
    qat_model.eval()
    qmodel = convert_qat_fx(qat_model)

    logits, _ = qmodel(example_bits)
    assert logits.shape == (1, 8, 2)

def test_fsdp_wrap():
    import torch.distributed as dist
    from bit_transformer import wrap_fsdp

    # Skip before touching the process group, otherwise CPU-only runs would
    # leave an initialized gloo group behind (pytest.skip raises, so the
    # destroy_process_group call at the end would never be reached).
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    if not dist.is_initialized():
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group("gloo", rank=0, world_size=1)
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    fsdp_model = wrap_fsdp(model)
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    logits, _ = fsdp_model(bits)
    assert logits.shape == (1, 8, 2)
    dist.destroy_process_group()

def test_make_pipeline():
    import torch.distributed.rpc as rpc
    from bit_transformer import make_pipeline

    # Pipeline parallelism needs an RPC agent; skip when none is set up.
    if not rpc._is_current_rpc_agent_set():
        pytest.skip("RPC not initialized")

    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    pipe_model = make_pipeline(model, chunks=1)
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    logits, _ = pipe_model(bits)
    assert logits.shape == (1, 8, 2)

def test_causal_attention():
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    logits, tele = model(bits, causal=True)
    assert logits.shape == (1, 8, 2)
    # Under causal masking no position may attend to a later one, so the
    # strict upper triangle of the attention map must be all zeros.
    attn = tele["attention_maps"][0]
    upper = attn.triu(1)
    assert torch.allclose(upper, torch.zeros_like(upper))

def test_scaling_helpers():
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    model = model.double_width()
    assert model.d_model == 64
    model = model.double_layers()
    assert model.num_layers == 2

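# Added sketch: assumes double_width/double_layers preserve the forward
# contract, so the scaled model still maps (B, L) bits to (B, L, 2) logits.
def test_scaled_model_forward():
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    model = model.double_width().double_layers()
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    logits, _ = model(bits)
    assert logits.shape == (1, 8, 2)
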
def test_expand_positional_encoding():
    model = BitTransformerLM(d_model=16, nhead=4, num_layers=1, dim_feedforward=32, max_seq_len=8)
    model.expand_positional_encoding(16)
    assert model.pos_enc.pe.size(0) == 16

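# Added sketch: assumes expanding the positional encoding is all the model
# needs to accept sequences beyond the original max_seq_len; if other length
# checks exist, this will need adjusting.
def test_forward_after_positional_expansion():
    model = BitTransformerLM(d_model=16, nhead=4, num_layers=1, dim_feedforward=32, max_seq_len=8)
    model.expand_positional_encoding(16)
    bits = torch.randint(0, 2, (1, 16), dtype=torch.long)
    logits, _ = model(bits)
    assert logits.shape == (1, 16, 2)
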
def test_infer_long_sequence():
    model = BitTransformerLM(d_model=16, nhead=4, num_layers=1, dim_feedforward=32, max_seq_len=8)
    bits = torch.randint(0, 2, (12,), dtype=torch.long)
    preds, logs = infer_long_sequence(model, bits, ctx_bits=8, overlap=4)
    assert len(preds) == 12
    assert len(logs) >= 2

def test_chunking_disabled_when_non_causal():
    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=8,
        chunk_size=2,
        full_attn_logging=True,
    )

    # Zero the attention projections so every query/key is identical and the
    # softmax yields uniform weights over whatever positions are visible.
    nn.init.constant_(model.layers[0].self_attn.in_proj_weight, 0.0)
    nn.init.constant_(model.layers[0].self_attn.in_proj_bias, 0.0)

    # Disable dropout everywhere for a deterministic comparison.
    model.eval()
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.p = 0.0
    model.layers[0].self_attn.dropout = 0.0

    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    _, tele_causal = model(bits, causal=True)
    _, tele_noncausal = model(bits, causal=False)
    attn_causal = tele_causal["attention_maps"][0]
    attn_noncausal = tele_noncausal["attention_maps"][0]

    # Position 0 cannot see position 4 under the causal mask, but with
    # chunking disabled in the non-causal pass it attends there with
    # non-zero (uniform) weight.
    assert attn_causal[0, 0, 0, 4] == 0
    assert attn_noncausal[0, 0, 0, 4] > 0

def test_diffusion_inference_generates_bits():
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    out = diffusion_inference(model, length=8, steps=2, batch_size=2)
    assert out.shape == (2, 8)
    assert set(out.unique().tolist()).issubset({0, 1})

def test_diffusion_inference_cosine_schedule():
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    out = diffusion_inference(model, length=8, steps=2, schedule="cosine")
    assert out.shape == (1, 8)

def test_chunking_restored_after_diffusion():
    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=8,
        chunk_size=2,
        full_attn_logging=True,
    )
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    # A non-causal (diffusion-style) pass temporarily disables chunking; it
    # must be restored afterwards, and causal masking must still hold.
    _ = model(bits, causal=False)
    assert model.layers[0].chunk_size == 2
    _, tele = model(bits, causal=True)
    attn = tele["attention_maps"][0]
    assert attn[0, 0, 0, 4] == 0

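# Added sketch: assumes diffusion_inference accepts any length up to
# max_seq_len and defaults to batch_size=1, as the cosine test suggests.
def test_diffusion_inference_shorter_length():
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
    out = diffusion_inference(model, length=4, steps=2)
    assert out.shape == (1, 4)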