BitTransformerLM / tests /test_act.py
WCNegentropy's picture
🤖 Updated BitTransformerLM from development space
36c78b1 verified
raw
history blame
595 Bytes
import os, sys; sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import torch
from bit_transformer.model import BitTransformerLM
def test_act_halting():
    """Check the halting telemetry of an ACT-enabled ``BitTransformerLM``.

    A small model is built with ``use_act=True`` and a low
    ``act_threshold`` of 0.1, then run once on a random 1x8 bit sequence.
    The test passes when the number of recorded halt probabilities that are
    strictly below 1 (taken at index ``[0, 0]`` of every recorded entry) is
    smaller than ``model.num_layers``.
    """
    model_config = {
        "d_model": 16,
        "nhead": 2,
        "num_layers": 3,
        "dim_feedforward": 32,
        "max_seq_len": 8,
        "use_act": True,
        "act_threshold": 0.1,
    }
    # Build the model before sampling the input so the global RNG is
    # consumed in the same order as parameter init -> input draw.
    model = BitTransformerLM(**model_config)

    # One sequence of eight random bits (values 0 or 1).
    sample_bits = torch.randint(0, 2, (1, 8), dtype=torch.long)

    _unused_output, telemetry = model(sample_bits)

    # Stack the per-step halt probabilities along a new leading axis and keep
    # only the entry at position [0, 0] of each one.
    # NOTE(review): presumably the trailing axes are (batch, position) and the
    # leading axis is one entry per executed layer - confirm against the model.
    stacked_halts = torch.stack(telemetry["halt_probs"])
    first_entry_halts = stacked_halts[:, 0, 0]

    below_one_count = (first_entry_halts < 1).sum().item()
    assert below_one_count < model.num_layers