import torch
import numpy as np

from typing import List, Optional, Union

from .types import BitTensor, BitSequence, TensorLike
def compress_bits(bits: torch.Tensor) -> torch.Tensor:
"""Run-length encode a 1D tensor of bits.
Args:
bits: 1D tensor with values 0 or 1 (bool or uint8).
Returns:
1D uint8 tensor containing interleaved values and run lengths.
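
    Example (illustrative; a run longer than 255 would be split across pairs):
        >>> compress_bits(torch.tensor([0, 0, 1, 1, 1, 0], dtype=torch.uint8))
        tensor([0, 2, 1, 3, 0, 1], dtype=torch.uint8)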
"""
    if bits.dim() != 1:
        raise ValueError("compress_bits expects a 1D tensor")
    b = bits.to(torch.uint8).flatten()
    if b.numel() == 0:
        return b
    # Boundaries where the bit value changes; each boundary starts a new run.
    changes = torch.nonzero(b[1:] != b[:-1]).flatten() + 1
    starts = torch.cat([changes.new_tensor([0]), changes])
    ends = torch.cat([changes, changes.new_tensor([b.numel()])])
    values = b[starts]
    counts = ends - starts
    # uint8 counts cap a run at 255, so longer runs are split into repeats.
    out_vals: List[int] = []
    out_counts: List[int] = []
for v, c in zip(values.tolist(), counts.tolist()):
while c > 255:
out_vals.append(v)
out_counts.append(255)
c -= 255
out_vals.append(v)
out_counts.append(c)
values_tensor = torch.tensor(out_vals, dtype=torch.uint8)
counts_tensor = torch.tensor(out_counts, dtype=torch.uint8)
out = torch.stack([values_tensor, counts_tensor], dim=1).flatten()
return out


def decompress_bits(compressed: torch.Tensor) -> torch.Tensor:
    """Decode a run-length encoded bit tensor produced by compress_bits."""
    if compressed.dim() != 1 or compressed.numel() % 2 != 0:
        raise ValueError("compressed must be a 1D tensor of even length")
    data = compressed.to(torch.uint8)
    values = data[0::2]  # even positions hold bit values
    counts = data[1::2].to(torch.long)  # odd positions hold run lengths
    return torch.repeat_interleave(values, counts)
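
# Round-trip sanity check (illustrative; values worked out by hand):
#
#     bits = torch.tensor([1, 1, 1, 0, 0, 1], dtype=torch.uint8)
#     rle = compress_bits(bits)                   # tensor([1, 3, 0, 2, 1, 1])
#     assert torch.equal(decompress_bits(rle), bits)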


def compress_bits_batch(bits_batch: Union[torch.Tensor, List[torch.Tensor]]) -> List[torch.Tensor]:
    """Run-length encode a batch of 1D bit tensors.

    Args:
        bits_batch: 2D tensor [batch_size, seq_len] or a list of 1D tensors.

    Returns:
        List of compressed tensors, one per sequence in the batch.
    """
    if isinstance(bits_batch, torch.Tensor):
        if bits_batch.dim() == 2:
            batch_size, _ = bits_batch.shape
            compressed_sequences = []
            # Cast the whole batch to uint8 once; each row is then encoded in
            # turn, since RLE outputs are ragged and cannot be stacked.
            bits_batch = bits_batch.to(torch.uint8)
            for i in range(batch_size):
                compressed_sequences.append(compress_bits(bits_batch[i]))
            return compressed_sequences
else:
return [compress_bits(bits_batch)]
else:
# Handle list input
return [compress_bits(seq) for seq in bits_batch]
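
# Example (illustrative): compressing a small two-row batch.
#
#     batch = torch.tensor([[0, 0, 1, 1], [1, 0, 0, 0]], dtype=torch.uint8)
#     [rle.tolist() for rle in compress_bits_batch(batch)]
#     # -> [[0, 2, 1, 2], [1, 1, 0, 3]]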


def model_output_decompress(compressed_batch: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:
    """Decompress a batch of compressed bit sequences into a stacked tensor.

    Rows that fail to decode are replaced by a single-zero placeholder, and
    variable-length results are zero-padded to the longest sequence.
    """
    if isinstance(compressed_batch, torch.Tensor) and compressed_batch.dim() == 1:
        sequences = [decompress_bits(compressed_batch)]
    else:
        sequences = []
        for row in compressed_batch:
            try:
                sequences.append(decompress_bits(row))
            except Exception:
                # Graceful recovery: substitute a single-zero placeholder.
                sequences.append(torch.zeros(1, dtype=torch.uint8))
lengths = [seq.numel() for seq in sequences]
if len(set(lengths)) != 1:
# Handle variable lengths by padding to max length
max_length = max(lengths)
padded_sequences = []
for seq in sequences:
if seq.numel() < max_length:
padding = torch.zeros(max_length - seq.numel(), dtype=seq.dtype, device=seq.device)
seq = torch.cat([seq, padding])
padded_sequences.append(seq)
return torch.stack(padded_sequences)
return torch.stack(sequences)
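
# Example (illustrative): rows decoding to different lengths are zero-padded
# up to the longest row before stacking.
#
#     rows = [compress_bits(torch.tensor([1, 1], dtype=torch.uint8)),
#             compress_bits(torch.tensor([0, 1, 1, 1], dtype=torch.uint8))]
#     model_output_decompress(rows).tolist()
#     # -> [[1, 1, 0, 0], [0, 1, 1, 1]]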


def compress_bits_parallel(bits_batch: torch.Tensor, num_workers: int = 4) -> List[torch.Tensor]:
    """Compress a large batch in parallel using a thread pool.

    Args:
        bits_batch: 2D tensor [batch_size, seq_len].
        num_workers: Number of worker threads.

    Returns:
        List of compressed tensors, in the same order as the input rows.
    """
    import concurrent.futures

    if bits_batch.dim() != 2:
        raise ValueError("bits_batch must be 2D [batch_size, seq_len]")
batch_size = bits_batch.shape[0]
if batch_size < num_workers * 2: # Not worth parallelizing small batches
return compress_bits_batch(bits_batch)
# Split batch into chunks for parallel processing
chunk_size = max(1, batch_size // num_workers)
chunks = [bits_batch[i:i + chunk_size] for i in range(0, batch_size, chunk_size)]
compressed_results = []
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [executor.submit(compress_bits_batch, chunk) for chunk in chunks]
        # Collect results in submission order: as_completed yields futures in
        # completion order, which would scramble row-to-result correspondence.
        for future in futures:
            try:
                compressed_results.extend(future.result())
            except Exception as e:
                # Fall back to single-threaded processing on any worker error.
                print(f"Parallel compression failed: {e}; falling back to sequential processing")
                return compress_bits_batch(bits_batch)
return compressed_results
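
# Usage note (illustrative): threads overlap the torch kernels, but the pure-
# Python run-splitting loop in compress_bits still serializes on the GIL, so
# gains are modest unless rows are long.
#
#     big = torch.randint(0, 2, (64, 100_000), dtype=torch.uint8)
#     compressed = compress_bits_parallel(big, num_workers=4)
#     assert len(compressed) == 64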


def pack_bits(bits: torch.Tensor) -> torch.Tensor:
    """Pack groups of 8 bits into uint8 values using numpy.packbits.

    numpy zero-pads the final byte when the bit count is not a multiple of 8.
    """
if bits.dim() != 1:
raise ValueError("pack_bits expects a 1D tensor")
arr = bits.to(torch.uint8).cpu().numpy()
packed = np.packbits(arr)
return torch.from_numpy(packed)


def unpack_bits(packed: torch.Tensor, *, n_bits: Optional[int] = None) -> torch.Tensor:
    """Unpack uint8 values back into a bit tensor.

    Pass n_bits to trim the zero padding that pack_bits may have added.
    """
if packed.dim() != 1:
raise ValueError("unpack_bits expects a 1D tensor")
arr = np.unpackbits(packed.to(torch.uint8).cpu().numpy())
if n_bits is not None:
arr = arr[:n_bits]
return torch.from_numpy(arr)
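
# Example (illustrative): n_bits trims the padding that packing added.
#
#     bits = torch.tensor([1, 0, 1, 1, 1, 0, 0, 1, 1, 0], dtype=torch.uint8)
#     packed = pack_bits(bits)                    # 2 bytes hold the 10 bits
#     assert torch.equal(unpack_bits(packed, n_bits=bits.numel()), bits)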