import torch


def enforce_parity(bits: torch.Tensor) -> tuple[torch.Tensor, int]:
    """Fix parity bits so each 9-bit chunk has even parity.

    The last dimension is split into consecutive 9-element chunks, laid out
    as 8 payload bits followed by 1 parity bit. Each parity bit is
    recomputed from its payload so that the chunk's total is even. The
    input tensor is never modified; a corrected copy is returned.

    Parameters
    ----------
    bits: ``torch.Tensor``
        Tensor of shape ``(..., length)`` where ``length`` is a multiple of 9.
        Entries are expected to be 0/1 bits. The tensor does not need to be
        contiguous (e.g. transposed/permuted inputs are accepted).

    Returns
    -------
    tuple[torch.Tensor, int]
        Corrected tensor (same shape and dtype as ``bits``) and number of
        bytes that were adjusted, i.e. chunks whose stored parity bit
        differed from the recomputed one.

    Raises
    ------
    ValueError
        If ``bits`` is 0-dimensional or its last dimension is not a
        multiple of 9.
    """
    if bits.dim() == 0:
        raise ValueError("Bit stream must have at least one dimension")
    if bits.shape[-1] % 9 != 0:
        raise ValueError("Bit stream length must be multiple of 9")
    # reshape (not view): a non-contiguous input cannot always be viewed as
    # (-1, 9). clone() afterwards guarantees we never write into the
    # caller's storage, even when reshape happens to return a view.
    flat = bits.reshape(-1, 9).clone()
    payload = flat[:, :8]
    parity = flat[:, 8]
    # Payload sum mod 2 is exactly the bit that makes the 9-bit total even.
    new_parity = payload.sum(dim=1) % 2
    # Count before the write below: `parity` is a view into `flat`.
    corrections = int((parity != new_parity).sum().item())
    flat[:, 8] = new_parity
    return flat.reshape(bits.shape), corrections