BitTransformerLM / tests /test_bit_io.py
WCNegentropy's picture
🤖 Updated BitTransformerLM from development space
36c78b1 verified
raw
history blame
455 Bytes
import sys
import pathlib
import torch
sys.path.append(str(pathlib.Path(__file__).resolve().parents[1]))
from bit_transformer.parity import enforce_parity
def test_enforce_parity_corrects_bits():
    """Check that a frame with a mismatched parity bit is repaired and counted.

    The single input frame has 9 bits: an 8-bit payload with one set bit,
    followed by a parity bit of 0. Under the rule asserted below (parity bit
    == payload sum mod 2), that frame is inconsistent. ``enforce_parity``
    should therefore report exactly one correction, and the returned frame
    must satisfy the rule.
    """
    frame = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0]])

    fixed, n_fixed = enforce_parity(frame)
    assert n_fixed == 1

    # Flatten to rows of 9 bits (8 payload + 1 parity) and inspect the first row.
    first_row = fixed.view(-1, 9)[0]
    data_bits, check_bit = first_row[:8], first_row[8]
    assert check_bit.item() == int(data_bits.sum() % 2)