File size: 455 Bytes
36c78b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import sys
import pathlib
import torch

sys.path.append(str(pathlib.Path(__file__).resolve().parents[1]))
from bit_transformer.parity import enforce_parity


def test_enforce_parity_corrects_bits():
    """A 9-bit row whose trailing parity bit is wrong gets fixed.

    The input row holds the payload ``1 0 0 0 0 0 0 0`` (sum 1, odd)
    followed by a parity bit of ``0``, so parity and payload disagree.
    After ``enforce_parity`` the test expects:

    * a returned count equal to 1 (presumably the number of corrections
      made — confirm against ``enforce_parity``);
    * the last bit of the first 9-bit row equal to that row's payload sum
      mod 2.
    """
    # 8 payload bits plus 1 trailing parity bit; the parity bit is deliberately wrong.
    bad_row = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0]])
    fixed, num_fixed = enforce_parity(bad_row)
    assert num_fixed == 1

    # Regroup the output into 9-bit rows and inspect the first one.
    frame = fixed.view(-1, 9)[0]
    data_bits, check_bit = frame[:8], frame[8]
    expected = int(data_bits.sum() % 2)
    assert check_bit.item() == expected