import torch
import numpy as np
from typing import List, Optional, Union

from .types import BitTensor, BitSequence, TensorLike


def compress_bits(bits: torch.Tensor) -> torch.Tensor:
    """Run-length encode a 1D tensor of bits.

    Args:
        bits: 1D tensor with values 0 or 1 (bool or uint8).

    Returns:
        1D uint8 tensor of interleaved (value, run length) pairs. Runs
        longer than 255 are split across multiple pairs so each count
        fits in a uint8.
    """
    if bits.dim() != 1:
        raise ValueError("compress_bits expects a 1D tensor")
    b = bits.to(torch.uint8).flatten()
    if b.numel() == 0:
        return b
    # Positions where the bit value changes; each marks the start of a new run.
    changes = torch.nonzero(b[1:] != b[:-1]).flatten() + 1
    starts = torch.cat([b.new_tensor([0], dtype=torch.long), changes])
    ends = torch.cat([changes, b.new_tensor([b.numel()], dtype=torch.long)])
    values = b[starts]
    counts = ends - starts
    out_vals: List[int] = []
    out_counts: List[int] = []
    for v, c in zip(values.tolist(), counts.tolist()):
        # Split runs longer than 255 so every count fits in a uint8.
        while c > 255:
            out_vals.append(v)
            out_counts.append(255)
            c -= 255
        out_vals.append(v)
        out_counts.append(c)
    values_tensor = torch.tensor(out_vals, dtype=torch.uint8)
    counts_tensor = torch.tensor(out_counts, dtype=torch.uint8)
    # Interleave as [value, count, value, count, ...].
    return torch.stack([values_tensor, counts_tensor], dim=1).flatten()
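
# Illustrative round-trip (a hand-checked sketch, not part of the module API):
#
#     >>> bits = torch.tensor([0, 0, 1, 1, 1, 0], dtype=torch.uint8)
#     >>> compress_bits(bits)
#     tensor([0, 2, 1, 3, 0, 1], dtype=torch.uint8)
#
# The output reads as (value, run length) pairs: two 0s, three 1s, one 0.
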
def decompress_bits(compressed: torch.Tensor) -> torch.Tensor:
    """Decode a run-length encoded bit tensor produced by compress_bits."""
    if compressed.dim() != 1 or compressed.numel() % 2 != 0:
        raise ValueError("compressed tensor must be 1D with an even number of elements")
    data = compressed.to(torch.uint8)
    values = data[0::2]
    counts = data[1::2].to(torch.long)
    # Expand each (value, count) pair back into its run of repeated bits.
    return torch.repeat_interleave(values, counts)
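
# Sketch of how oversized runs decode (values hand-checked): a run of 300
# ones is stored as the pairs (1, 255) and (1, 45), and repeat_interleave
# stitches the pieces back together:
#
#     >>> decompress_bits(torch.tensor([1, 255, 1, 45], dtype=torch.uint8)).numel()
#     300
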
def compress_bits_batch(
    bits_batch: Union[torch.Tensor, List[torch.Tensor]],
) -> List[torch.Tensor]:
    """Run-length encode a batch of 1D bit tensors.

    Args:
        bits_batch: 2D tensor of shape [batch_size, seq_len], a single 1D
            tensor, or a list of 1D tensors.

    Returns:
        List of compressed tensors, one per sequence, in input order.
    """
    if isinstance(bits_batch, torch.Tensor):
        if bits_batch.dim() == 2:
            # Convert once up front so each row is already uint8; the per-row
            # encoding itself stays sequential because the compressed length
            # varies from row to row.
            bits_batch = bits_batch.to(torch.uint8)
            return [compress_bits(row) for row in bits_batch]
        return [compress_bits(bits_batch)]
    # Handle list input.
    return [compress_bits(seq) for seq in bits_batch]
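
# Batch usage sketch (hand-checked): rows compress independently, so the
# encoded lengths can differ per row:
#
#     >>> batch = torch.tensor([[0, 0, 1], [1, 1, 1]], dtype=torch.uint8)
#     >>> [t.tolist() for t in compress_bits_batch(batch)]
#     [[0, 2, 1, 1], [1, 3]]
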
def model_output_decompress(compressed_batch: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:
    """Decompress a batch of compressed bit sequences into a 2D tensor.

    Sequences that fail to decode are replaced with a single zero bit, and
    variable-length results are right-padded with zeros to the longest
    sequence in the batch.
    """
    if isinstance(compressed_batch, torch.Tensor) and compressed_batch.dim() == 1:
        sequences = [decompress_bits(compressed_batch)]
    else:
        sequences = []
        for row in compressed_batch:
            try:
                sequences.append(decompress_bits(row))
            except Exception:
                # Graceful recovery: substitute a single zero bit if a row
                # fails to decompress rather than aborting the whole batch.
                sequences.append(torch.zeros(1, dtype=torch.uint8))
    lengths = [seq.numel() for seq in sequences]
    if len(set(lengths)) != 1:
        # Right-pad shorter sequences with zeros to the maximum length.
        max_length = max(lengths)
        padded_sequences = []
        for seq in sequences:
            if seq.numel() < max_length:
                padding = torch.zeros(max_length - seq.numel(), dtype=seq.dtype, device=seq.device)
                seq = torch.cat([seq, padding])
            padded_sequences.append(seq)
        return torch.stack(padded_sequences)
    return torch.stack(sequences)
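
# Padding sketch (hand-checked): decoded sequences of different lengths come
# back as one zero-padded 2D tensor. Caveat: padding zeros are
# indistinguishable from genuine trailing zero bits.
#
#     >>> enc = [compress_bits(torch.tensor([1, 1, 1], dtype=torch.uint8)),
#     ...        compress_bits(torch.tensor([0, 0, 1, 1, 1], dtype=torch.uint8))]
#     >>> model_output_decompress(enc).tolist()
#     [[1, 1, 1, 0, 0], [0, 0, 1, 1, 1]]
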
def compress_bits_parallel(bits_batch: torch.Tensor, num_workers: int = 4) -> List[torch.Tensor]:
    """Compress a large batch in parallel using a thread pool.

    Note: the workers run Python-level loops, so the GIL limits the
    achievable speedup; this mainly helps once batches are large enough to
    amortise the scheduling overhead.

    Args:
        bits_batch: 2D tensor [batch_size, seq_len].
        num_workers: Number of worker threads.

    Returns:
        List of compressed tensors, in the same order as the input rows.
    """
    import concurrent.futures

    if bits_batch.dim() != 2:
        raise ValueError("bits_batch must be 2D [batch_size, seq_len]")
    batch_size = bits_batch.shape[0]
    if batch_size < num_workers * 2:
        # Small batches are not worth the thread-pool overhead.
        return compress_bits_batch(bits_batch)
    # Split the batch into contiguous chunks, one unit of work per future.
    chunk_size = max(1, batch_size // num_workers)
    chunks = [bits_batch[i:i + chunk_size] for i in range(0, batch_size, chunk_size)]
    compressed_results: List[torch.Tensor] = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(compress_bits_batch, chunk) for chunk in chunks]
        # Collect in submission order (not completion order) so the output
        # stays aligned with the input rows.
        for future in futures:
            try:
                compressed_results.extend(future.result())
            except Exception as e:
                # Fall back to single-threaded processing on any error.
                print(f"Parallel compression failed: {e}, falling back to sequential processing")
                return compress_bits_batch(bits_batch)
    return compressed_results
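
# Usage sketch: only batches with at least num_workers * 2 rows take the
# threaded path; smaller batches fall through to compress_bits_batch.
#
#     >>> big = torch.randint(0, 2, (64, 1024), dtype=torch.uint8)
#     >>> len(compress_bits_parallel(big, num_workers=4))
#     64
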
def pack_bits(bits: torch.Tensor) -> torch.Tensor:
    """Pack groups of 8 bits into uint8 values via numpy.packbits.

    The input is zero-padded on the right to a multiple of 8 bits; pass the
    original length to unpack_bits to recover it exactly.
    """
    if bits.dim() != 1:
        raise ValueError("pack_bits expects a 1D tensor")
    arr = bits.to(torch.uint8).cpu().numpy()
    packed = np.packbits(arr)
    return torch.from_numpy(packed)


def unpack_bits(packed: torch.Tensor, *, n_bits: Optional[int] = None) -> torch.Tensor:
    """Unpack uint8 values back into a bit tensor.

    Args:
        packed: 1D uint8 tensor produced by pack_bits.
        n_bits: If given, truncate the result to this many bits (drops the
            zero padding added during packing).
    """
    if packed.dim() != 1:
        raise ValueError("unpack_bits expects a 1D tensor")
    arr = np.unpackbits(packed.to(torch.uint8).cpu().numpy())
    if n_bits is not None:
        arr = arr[:n_bits]
    return torch.from_numpy(arr)
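
if __name__ == "__main__":
    # Minimal smoke test (illustrative values, hand-checked): pack_bits
    # zero-pads to a byte boundary, so n_bits is needed to recover the
    # original odd length exactly.
    example = torch.tensor([1, 0, 1, 1, 0, 0, 0, 1, 1, 0], dtype=torch.uint8)
    packed = pack_bits(example)  # 10 bits -> 2 bytes: 0b10110001, 0b10000000
    assert packed.tolist() == [177, 128]
    assert torch.equal(unpack_bits(packed, n_bits=10), example)
    # Round-trip the run-length coder on the same pattern.
    assert torch.equal(decompress_bits(compress_bits(example)), example)
    print("compression.py smoke test passed")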