import os
import sys

# Prepend the parent of this file's directory to sys.path before the project
# import below (presumably so ``bit_transformer`` resolves without being
# installed -- confirm against the repo layout).
_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, _PROJECT_ROOT)

import torch

from bit_transformer.model import BitTransformerLM


def test_act_halting():
    """Check the halting telemetry of an ACT-enabled ``BitTransformerLM``.

    Builds a small model with ``use_act=True`` and ``act_threshold=0.1``,
    runs one forward pass on a random ``(1, 8)`` tensor of 0/1 values, and
    asserts that the number of stacked ``telemetry["halt_probs"]`` entries
    (read at index ``[0, 0]``) that are below 1 is smaller than
    ``model.num_layers``.
    """
    config = dict(
        d_model=16,
        nhead=2,
        num_layers=3,
        dim_feedforward=32,
        max_seq_len=8,
        use_act=True,
        act_threshold=0.1,
    )
    lm = BitTransformerLM(**config)

    # The model is built before the input is sampled so the global RNG is
    # consumed in the same order as before.
    sample = torch.randint(0, 2, (1, 8), dtype=torch.long)
    _, telemetry = lm(sample)

    # NOTE(review): presumably one halt-probability tensor per layer -- confirm
    # against the model's telemetry contract.
    stacked = torch.stack(telemetry["halt_probs"])
    first_entry_probs = stacked[:, 0, 0]

    below_one = (first_entry_probs < 1).sum().item()
    assert below_one < lm.num_layers