import pathlib
import sys

import torch

# Put the directory two levels above this file (presumably the repository
# root) on the import path so the in-tree ``bit_transformer`` package can be
# imported without installing it.
# NOTE(review): ``append`` places it last on ``sys.path``, so an installed
# copy of ``bit_transformer`` would take precedence over the local sources.
sys.path.append(
    str(pathlib.Path(__file__).resolve().parents[1])
)

from bit_transformer.parity import enforce_parity  # noqa: E402
def test_enforce_parity_corrects_bits():
    """A 9-bit group whose parity bit disagrees with its payload is corrected.

    The test treats each group as 8 payload bits followed by 1 parity bit.
    The payload below has a single set bit (odd sum), but its parity slot
    holds 0. The test expects ``enforce_parity`` to report exactly one
    correction and to return a group whose parity bit equals the payload
    sum modulo 2.
    """
    # Single group: payload 1,0,0,0,0,0,0,0 followed by a (wrong) parity bit 0.
    group = torch.tensor([[1] + [0] * 8])

    fixed, num_fixed = enforce_parity(group)
    assert num_fixed == 1

    first_group = fixed.view(-1, 9)[0]
    data_bits, check_bit = first_group[:8], first_group[8]
    # NOTE(review): only the parity relation is checked here; an
    # implementation that flipped a payload bit instead would also pass.
    assert check_bit.item() == int(data_bits.sum() % 2)