|
import torch |
|
from typing import List, Union, Optional |
|
from .types import BitTensor, BitSequence, TensorLike |
|
|
|
|
|
def compress_bits(bits: torch.Tensor) -> torch.Tensor:
    """Run-length encode a 1D tensor of bits.

    The input is cast to ``uint8`` and split into maximal runs of equal
    values. Each run is emitted as a ``(value, count)`` pair. Because counts
    are stored as ``uint8``, a run longer than 255 elements is emitted as
    several consecutive pairs with the same value (each count at most 255).

    Args:
        bits: 1D tensor with values 0 or 1 (bool or uint8).

    Returns:
        1D uint8 tensor ``[v0, c0, v1, c1, ...]`` of interleaved values and
        run lengths, located on the same device as ``bits``. An empty input
        yields an empty uint8 tensor.

    Raises:
        ValueError: If ``bits`` is not 1-dimensional.
    """
    if bits.dim() != 1:
        raise ValueError("compress_bits expects a 1D tensor")

    b = bits.to(torch.uint8)
    total = b.numel()
    if total == 0:
        return b

    # Largest run length representable in a single uint8 count.
    max_run = 255

    # Positions (>= 1) at which the value differs from its predecessor,
    # i.e. the first index of every run except the first one.
    changes = torch.nonzero(b[1:] != b[:-1]).flatten() + 1
    starts = torch.cat([changes.new_zeros(1), changes])
    ends = torch.cat([changes, changes.new_full((1,), total)])
    values = b[starts]
    counts = ends - starts

    # Number of (value, count) pairs each run needs: ceil(count / max_run).
    pieces = torch.div(counts + (max_run - 1), max_run, rounding_mode="floor")

    out_vals = torch.repeat_interleave(values, pieces)
    # Every pair of a run is full (max_run) except the last one, which
    # carries the remainder (between 1 and max_run inclusive).
    out_counts = torch.full_like(out_vals, max_run)
    last_piece = torch.cumsum(pieces, dim=0) - 1
    out_counts[last_piece] = (counts - (pieces - 1) * max_run).to(torch.uint8)

    return torch.stack([out_vals, out_counts], dim=1).flatten()
|
|
|
|
|
def decompress_bits(compressed: torch.Tensor) -> torch.Tensor:
    """Expand a run-length encoded tensor back into its element sequence.

    Args:
        compressed: 1D tensor of even length laid out as
            ``[v0, c0, v1, c1, ...]``; it is cast to uint8 before decoding.

    Returns:
        1D uint8 tensor in which each value ``v_i`` is repeated ``c_i`` times.

    Raises:
        ValueError: If ``compressed`` is not 1D or has an odd number of
            elements.
    """
    is_flat = compressed.dim() == 1
    if not (is_flat and compressed.numel() % 2 == 0):
        raise ValueError("compressed tensor must be 1D even-length")

    pairs = compressed.to(torch.uint8)
    # Even positions hold the run values, odd positions the run lengths.
    run_values, run_lengths = pairs[0::2], pairs[1::2]
    return torch.repeat_interleave(run_values, run_lengths.to(torch.long))
|
|
|
|
|
def compress_bits_batch(bits_batch: Union[torch.Tensor, List[torch.Tensor]]) -> List[torch.Tensor]:
    """Run-length encode every sequence of a batch with ``compress_bits``.

    Args:
        bits_batch: Either a 2D tensor ``[batch_size, seq_len]`` (each row is
            one sequence), a list (or other iterable) of 1D tensors, or a
            single tensor that is not 2D, which is treated as one sequence.

    Returns:
        List with one compressed tensor per sequence, in input order. A
        single non-2D tensor yields a one-element list.

    Raises:
        ValueError: Propagated from ``compress_bits`` when a sequence is not
            1-dimensional.
    """
    if isinstance(bits_batch, torch.Tensor) and bits_batch.dim() != 2:
        # Single sequence; compress_bits itself rejects anything that is not 1D.
        return [compress_bits(bits_batch)]

    # Iterating a 2D tensor yields its rows; iterating a list yields its
    # items. compress_bits performs the uint8 cast, so none is needed here.
    return [compress_bits(seq) for seq in bits_batch]
|
|
|
|
|
def model_output_decompress(compressed_batch: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:
    """Decompress a batch of run-length encoded sequences into one 2D tensor.

    Args:
        compressed_batch: A single 1D compressed tensor, or an iterable of
            compressed tensors (a list, or a 2D tensor whose rows are
            compressed sequences).

    Returns:
        2D tensor with one decoded sequence per row. Rows shorter than the
        longest decoded sequence are right-padded with zeros. A row that
        cannot be decoded is replaced by a single zero (then padded) and a
        warning is logged. An empty batch yields an empty ``(0, 0)`` uint8
        tensor.

    Raises:
        ValueError: Propagated from ``decompress_bits`` when a single 1D
            tensor is passed and it has odd length. Failures of individual
            rows in a batch are not raised.
    """
    import logging

    if isinstance(compressed_batch, torch.Tensor) and compressed_batch.dim() == 1:
        # Single sequence: decoding errors propagate to the caller.
        return torch.stack([decompress_bits(compressed_batch)])

    sequences: List[torch.Tensor] = []
    for index, row in enumerate(compressed_batch):
        try:
            sequences.append(decompress_bits(row))
        except Exception as exc:
            # Deliberate best-effort: one bad row must not fail the whole
            # batch, but the substitution is reported instead of hidden.
            logging.getLogger(__name__).warning(
                "Failed to decompress row %d (%s); substituting a zero placeholder",
                index,
                exc,
            )
            # Keep the placeholder on the row's device so it can be
            # concatenated/stacked with the successfully decoded rows.
            device = row.device if isinstance(row, torch.Tensor) else None
            sequences.append(torch.zeros(1, dtype=torch.uint8, device=device))

    if not sequences:
        # Nothing to stack; torch.stack and max() both reject empty input.
        return torch.empty((0, 0), dtype=torch.uint8)

    max_length = max(seq.numel() for seq in sequences)
    padded_sequences = [
        seq
        if seq.numel() == max_length
        else torch.cat([seq, seq.new_zeros(max_length - seq.numel())])
        for seq in sequences
    ]
    return torch.stack(padded_sequences)
|
|
|
|
|
def compress_bits_parallel(bits_batch: torch.Tensor, num_workers: int = 4) -> List[torch.Tensor]:
    """Compress a large batch by splitting it into chunks run on a thread pool.

    The batch is cut into contiguous row chunks, each chunk is handed to
    ``compress_bits_batch`` on a ``ThreadPoolExecutor`` worker, and the
    per-chunk results are concatenated in input order. Batches with fewer
    than ``2 * num_workers`` rows are processed sequentially.

    Args:
        bits_batch: 2D tensor [batch_size, seq_len]
        num_workers: Number of worker threads; must be at least 1.

    Returns:
        List of compressed tensors, one per row, in the same order as the
        rows of ``bits_batch``.

    Raises:
        ValueError: If ``bits_batch`` is not 2D or ``num_workers`` is less
            than 1.
    """
    import concurrent.futures
    import logging

    if bits_batch.dim() != 2:
        raise ValueError("bits_batch must be 2D [batch_size, seq_len]")
    if num_workers < 1:
        raise ValueError("num_workers must be a positive integer")

    batch_size = bits_batch.shape[0]
    if batch_size < num_workers * 2:
        return compress_bits_batch(bits_batch)

    # batch_size >= 2 * num_workers here, so every chunk has at least 2 rows.
    chunk_size = batch_size // num_workers
    chunks = [bits_batch[i:i + chunk_size] for i in range(0, batch_size, chunk_size)]

    try:
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
            # Executor.map yields results in submission order, which keeps
            # output row i aligned with input row i (as_completed would not).
            per_chunk = list(executor.map(compress_bits_batch, chunks))
    except Exception as exc:
        # Best-effort fallback: redo the whole batch sequentially.
        logging.getLogger(__name__).warning(
            "Parallel compression failed: %s, falling back to sequential processing",
            exc,
        )
        return compress_bits_batch(bits_batch)

    return [compressed for chunk_result in per_chunk for compressed in chunk_result]
|
|
|
|
|
import numpy as np |
|
|
|
|
|
def pack_bits(bits: torch.Tensor) -> torch.Tensor:
    """Pack a 1D bit tensor into bytes, eight bits per uint8 value.

    The tensor is cast to uint8, moved to the CPU and handed to
    ``numpy.packbits``.

    Args:
        bits: 1D tensor of bits.

    Returns:
        1D uint8 CPU tensor holding the packed bytes.

    Raises:
        ValueError: If ``bits`` is not 1-dimensional.
    """
    if bits.dim() == 1:
        bit_array = bits.to(torch.uint8).cpu().numpy()
        return torch.from_numpy(np.packbits(bit_array))
    raise ValueError("pack_bits expects a 1D tensor")
|
|
|
|
|
def unpack_bits(packed: torch.Tensor, *, n_bits: Optional[int] = None) -> torch.Tensor:
    """Expand packed uint8 values into a tensor of individual bits.

    The tensor is cast to uint8, moved to the CPU and handed to
    ``numpy.unpackbits``.

    Args:
        packed: 1D tensor of packed bytes.
        n_bits: When given, the unpacked result is sliced as
            ``result[:n_bits]``.

    Returns:
        1D uint8 CPU tensor of bits (eight per input byte unless sliced by
        ``n_bits``).

    Raises:
        ValueError: If ``packed`` is not 1-dimensional.
    """
    if packed.dim() != 1:
        raise ValueError("unpack_bits expects a 1D tensor")

    byte_array = packed.to(torch.uint8).cpu().numpy()
    bit_array = np.unpackbits(byte_array)
    selected = bit_array if n_bits is None else bit_array[:n_bits]
    return torch.from_numpy(selected)
|
|
|
|