"""Tests for ``bit_transformer.parity.enforce_parity``."""

import sys
import pathlib

import torch

# Make the repository root importable so ``bit_transformer`` resolves when
# this test file is run directly from its own directory.
sys.path.append(str(pathlib.Path(__file__).resolve().parents[1]))

from bit_transformer.parity import enforce_parity


def test_enforce_parity_corrects_bits():
    """A group whose parity bit disagrees with its payload gets corrected.

    The input is one 9-bit group: 8 payload bits followed by 1 parity bit.
    The payload ``[1, 0, 0, 0, 0, 0, 0, 0]`` sums to 1, but the trailing
    parity bit is 0, so exactly one correction is expected.
    """
    bits = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0]])

    corrected, count = enforce_parity(bits)

    # Exactly one group was inconsistent, so one correction is reported.
    assert count == 1

    # Regroup into 9-bit chunks; ``reshape`` (unlike ``view``) also works if
    # the returned tensor happens to be non-contiguous.
    groups = corrected.reshape(-1, 9)
    payload = groups[0, :8]
    parity = groups[0, 8]

    # After correction, the parity bit must equal the payload sum mod 2.
    assert parity.item() == int(payload.sum() % 2)