"""Unit tests for ``BitTransformerLM`` and the helpers exported by ``bit_transformer``."""

import os
import sys

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

# Make the repository root importable so the in-tree ``bit_transformer`` package
# is found when the tests are run without installing it.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from bit_transformer import (  # noqa: E402  (must follow the sys.path tweak)
    BitTransformerLM,
    hil_safe_inference,
    text_to_bits,
    bits_to_text,
    plot_telemetry,
    infer_long_sequence,
    diffusion_inference,
    compress_bits,
)
from bit_transformer.safety import SafetyGate  # noqa: E402


def test_forward_pass():
    """Forward pass yields (B, L, 2) logits, the expected telemetry keys, and a finite loss."""
    batch, seq_len = 2, 8
    model = BitTransformerLM(
        d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=seq_len
    )
    bits = torch.randint(0, 2, (batch, seq_len), dtype=torch.long)
    logits, telemetry = model(bits)
    assert logits.shape == (batch, seq_len, 2)

    required_keys = {
        "negentropy_input",
        "lz_complexity_input",
        "negentropy_logits",
        "lz_complexity_logits",
        "symbiosis_kl",
        "symbiosis_score",
        "attention_entropy",
        "attention_entropy_mean",
    }
    assert required_keys.issubset(telemetry.keys())

    # Next-bit objective: position t predicts bit t + 1.
    pred = logits[:, :-1, :].reshape(-1, 2)
    target = bits[:, 1:].reshape(-1)
    loss = F.cross_entropy(pred, target)
    assert torch.isfinite(loss)


def test_autocast_forward():
    """A model built with ``use_autocast=True`` still returns (1, 8, 2) logits."""
    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=8,
        use_autocast=True,
    )
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    logits, _ = model(bits)
    assert logits.shape == (1, 8, 2)


def test_act_forward():
    """With ``use_act=True`` the telemetry includes a ``halt_probs`` entry."""
    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=2,
        dim_feedforward=64,
        max_seq_len=8,
        use_act=True,
    )
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    logits, tele = model(bits)
    assert logits.shape == (1, 8, 2)
    assert "halt_probs" in tele


def test_act_skips_layers():
    """Strongly biased halting projections produce fewer ``halt_probs`` than layers."""
    model = BitTransformerLM(
        d_model=16,
        nhead=4,
        num_layers=3,
        dim_feedforward=32,
        max_seq_len=8,
        use_act=True,
        act_threshold=0.5,
    )
    # Zero weights plus a large positive bias make every halting projection
    # output the same large value regardless of input (presumably pushing the
    # halt probability above ``act_threshold`` -- confirm against the model).
    for proj in model.halt_projs:
        nn.init.constant_(proj.weight, 0.0)
        nn.init.constant_(proj.bias, 10.0)
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    _, tele = model(bits)
    assert len(tele["halt_probs"]) < model.num_layers


def test_hil_safety_gate():
    """``hil_safe_inference`` raises ``RuntimeError`` when both floors are 1.0."""
    model = BitTransformerLM(
        d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8
    )
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    # Expect gate triggered with high floors
    with pytest.raises(RuntimeError):
        hil_safe_inference(model, bits, c_floor=1.0, s_floor=1.0)


def test_hil_safety_non_strict():
    """With ``strict=False`` the same floors return output shaped like the input."""
    model = BitTransformerLM(
        d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8
    )
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    out, _ = hil_safe_inference(model, bits, c_floor=1.0, s_floor=1.0, strict=False)
    assert out.shape == bits.shape


def test_safety_gate_burn_in():
    """A gate with ``burn_in=1`` lets the first call through and raises on the second."""
    model = BitTransformerLM(
        d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8
    )
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    gate = SafetyGate(c_floor=1.0, s_floor=1.0, burn_in=1)
    hil_safe_inference(model, bits, gate=gate)
    with pytest.raises(RuntimeError):
        hil_safe_inference(model, bits, gate=gate)


def test_bit_io_roundtrip():
    """``bits_to_text`` inverts ``text_to_bits`` for a short ASCII string."""
    text = "hello"
    bits = text_to_bits(text)
    assert bits_to_text(bits) == text


def test_plot_telemetry():
    """``plot_telemetry`` returns a figure and three axes for a four-series log."""
    log = {
        "negentropy": [0.6, 0.7, 0.4],
        "lz_complexity": [0.5, 0.45, 0.6],
        "symbiosis_score": [0.55, 0.6, 0.3],
        "clusters": [0, 0, 1],
    }
    fig, axes = plot_telemetry(log)
    assert len(axes) == 3
    fig.clf()


def test_metric_no_gradient_flow():
    """Logit-based metrics are detached: no ``requires_grad`` and ``backward`` fails."""
    model = BitTransformerLM(
        d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8
    )
    bits = torch.randint(0, 2, (2, 8), dtype=torch.long)
    logits, _ = model(bits)
    loss = (
        model.negentropy_logits(logits).mean()
        + model.lz_complexity_logits(logits).mean()
    )
    assert not loss.requires_grad
    with pytest.raises(RuntimeError):
        loss.backward()


def test_negentropy_decompression_edge_case():
    """``negentropy_kpi`` agrees on compressed and raw forms of an alternating pattern."""
    bits = torch.tensor([0, 1] * 8, dtype=torch.uint8)
    comp = compress_bits(bits)
    model = BitTransformerLM(
        d_model=16,
        nhead=2,
        num_layers=1,
        dim_feedforward=32,
        max_seq_len=bits.numel(),
    )
    neg_comp = model.negentropy_kpi(comp.unsqueeze(0))
    neg_raw = model.negentropy_kpi(bits.unsqueeze(0))
    assert torch.allclose(neg_comp, neg_raw, atol=1e-6)


def test_dynamic_quantization():
    """A dynamically quantized model still returns (1, 8, 2) logits."""
    # Imported locally so a missing quantization helper only fails this test.
    from bit_transformer import quantize_dynamic

    model = BitTransformerLM(
        d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8
    )
    qmodel = quantize_dynamic(model)
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    logits, _ = qmodel(bits)
    assert logits.shape == (1, 8, 2)


def test_qat_fx_roundtrip():
    """prepare -> eval -> convert through the QAT FX helpers keeps the logits shape."""
    from bit_transformer import prepare_qat_fx, convert_qat_fx

    model = BitTransformerLM(
        d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8
    )
    example_bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    qat_model = prepare_qat_fx(model)
    qat_model.eval()
    qmodel = convert_qat_fx(qat_model)
    logits, _ = qmodel(example_bits)
    assert logits.shape == (1, 8, 2)


def test_fsdp_wrap():
    """An FSDP-wrapped model runs a forward pass; skipped when CUDA is unavailable."""
    import torch.distributed as dist
    from bit_transformer import wrap_fsdp

    # Decide whether to skip *before* touching the process group so a skipped
    # run does not leave an initialised group behind for later tests.
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    # Only tear down a process group that this test created itself.
    owns_group = not dist.is_initialized()
    if owns_group:
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group("gloo", rank=0, world_size=1)
    try:
        model = BitTransformerLM(
            d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8
        )
        fsdp_model = wrap_fsdp(model)
        bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
        logits, _ = fsdp_model(bits)
        assert logits.shape == (1, 8, 2)
    finally:
        # Release the group even when the wrap or an assertion fails.
        if owns_group:
            dist.destroy_process_group()


def test_make_pipeline():
    """A pipeline-wrapped model runs a forward pass; skipped when RPC is not set up."""
    import torch.distributed.rpc as rpc
    from bit_transformer import make_pipeline

    if not rpc._is_current_rpc_agent_set():
        pytest.skip("RPC not initialized")
    model = BitTransformerLM(
        d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8
    )
    pipe_model = make_pipeline(model, chunks=1)
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    logits, _ = pipe_model(bits)
    assert logits.shape == (1, 8, 2)


def test_causal_attention():
    """In causal mode the strictly upper triangle of the first attention map is zero."""
    model = BitTransformerLM(
        d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8
    )
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    logits, tele = model(bits, causal=True)
    assert logits.shape == (1, 8, 2)
    attn = tele["attention_maps"][0]
    upper = attn.triu(1)
    assert torch.allclose(upper, torch.zeros_like(upper))


def test_scaling_helpers():
    """``double_width`` doubles ``d_model`` and ``double_layers`` doubles ``num_layers``."""
    model = BitTransformerLM(
        d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8
    )
    model = model.double_width()
    assert model.d_model == 64
    model = model.double_layers()
    assert model.num_layers == 2


def test_expand_positional_encoding():
    """``expand_positional_encoding(16)`` grows the positional table to 16 rows."""
    model = BitTransformerLM(
        d_model=16, nhead=4, num_layers=1, dim_feedforward=32, max_seq_len=8
    )
    model.expand_positional_encoding(16)
    assert model.pos_enc.pe.size(0) == 16


def test_infer_long_sequence():
    """A 12-bit input with an 8-bit context yields 12 predictions and >= 2 logs."""
    model = BitTransformerLM(
        d_model=16, nhead=4, num_layers=1, dim_feedforward=32, max_seq_len=8
    )
    bits = torch.randint(0, 2, (12,), dtype=torch.long)
    preds, logs = infer_long_sequence(model, bits, ctx_bits=8, overlap=4)
    assert len(preds) == 12
    assert len(logs) >= 2


def test_chunking_disabled_when_non_causal():
    """Chunked attention masks position 0 -> 4 in causal mode but not in non-causal mode."""
    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=8,
        chunk_size=2,
        full_attn_logging=True,
    )

    # Zero query/key/value projections so attention is uniformly distributed.
    # This makes the test deterministic: any non-masked position receives equal
    # weight, allowing us to rely solely on the chunking mask for the check.
    nn.init.constant_(model.layers[0].self_attn.in_proj_weight, 0.0)
    nn.init.constant_(model.layers[0].self_attn.in_proj_bias, 0.0)

    # Disable dropout for deterministic attention weights.
    model.eval()
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.p = 0.0
    model.layers[0].self_attn.dropout = 0.0

    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    _, tele_causal = model(bits, causal=True)
    _, tele_noncausal = model(bits, causal=False)

    attn_causal = tele_causal["attention_maps"][0]
    attn_noncausal = tele_noncausal["attention_maps"][0]

    # Causal mode keeps attention within chunk boundaries, while non-causal mode
    # should permit cross-chunk attention.
    assert attn_causal[0, 0, 0, 4] == 0
    assert attn_noncausal[0, 0, 0, 4] > 0


def test_diffusion_inference_generates_bits():
    """``diffusion_inference`` returns a (batch, length) tensor containing only 0s and 1s."""
    model = BitTransformerLM(
        d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8
    )
    out = diffusion_inference(model, length=8, steps=2, batch_size=2)
    assert out.shape == (2, 8)
    assert set(out.unique().tolist()).issubset({0, 1})


def test_diffusion_inference_cosine_schedule():
    """The ``"cosine"`` schedule is accepted and the output has shape (1, 8)."""
    model = BitTransformerLM(
        d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8
    )
    out = diffusion_inference(model, length=8, steps=2, schedule="cosine")
    assert out.shape == (1, 8)


def test_chunking_restored_after_diffusion():
    """After a non-causal forward pass, ``chunk_size`` and causal chunk masking are intact."""
    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=8,
        chunk_size=2,
        full_attn_logging=True,
    )
    bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
    # NOTE(review): ``diffusion_inference`` is not called here; a plain
    # non-causal forward pass stands in for it -- confirm that is intended.
    _ = model(bits, causal=False)
    assert model.layers[0].chunk_size == 2
    _, tele = model(bits, causal=True)
    attn = tele["attention_maps"][0]
    assert attn[0, 0, 0, 4] == 0