File size: 757 Bytes
36c78b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch

def enforce_parity(bits: torch.Tensor) -> tuple[torch.Tensor, int]:
    """Rewrite the parity bits so every 9-bit chunk has even parity.

    The last dimension of ``bits`` is split into consecutive 9-element
    chunks. In each chunk the first 8 elements are the payload and the 9th
    is its parity bit. The parity bit is overwritten with the payload sum
    modulo 2, which makes the chunk's total even. The input tensor is never
    modified.

    Parameters
    ----------
    bits : torch.Tensor
        Tensor of shape ``(..., length)`` where ``length`` is a multiple of 9.
        Elements are expected to be 0/1 (or ``bool``); other values are
        summed as-is before the modulo. Any memory layout is accepted,
        including non-contiguous (e.g. permuted) tensors.

    Returns
    -------
    tuple[torch.Tensor, int]
        A new contiguous tensor with the same shape and dtype as ``bits`` and
        corrected parity bits, and the number of 9-bit chunks (bytes) whose
        parity bit was changed.

    Raises
    ------
    ValueError
        If ``bits`` is zero-dimensional or its last dimension is not a
        multiple of 9.
    """
    if bits.dim() == 0:
        raise ValueError("Bit stream must have at least one dimension")
    if bits.shape[-1] % 9 != 0:
        raise ValueError("Bit stream length must be multiple of 9")
    # Force a contiguous copy. A plain ``clone()`` preserves the strides of
    # dense non-contiguous inputs (e.g. permuted tensors), which would make
    # ``view(-1, 9)`` fail. The copy also leaves the caller's tensor intact.
    flat = bits.clone(memory_format=torch.contiguous_format).view(-1, 9)
    payload = flat[:, :8]
    parity = flat[:, 8]
    new_parity = payload.sum(dim=1) % 2
    # Count mismatches before overwriting: ``parity`` is a view into ``flat``.
    corrections = int((parity != new_parity).sum().item())
    flat[:, 8] = new_parity
    # ``flat`` is contiguous, so viewing it back to the original shape is safe.
    return flat.view(bits.shape), corrections