import torch
import numpy as np
from typing import List, Union, Optional
from .types import BitTensor, BitSequence, TensorLike


def compress_bits(bits: torch.Tensor) -> torch.Tensor:
    """Run-length encode a 1D tensor of bits.

    Args:
        bits: 1D tensor with values 0 or 1 (bool or uint8).

    Returns:
        1D uint8 tensor containing interleaved values and run lengths.
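
    Runs longer than 255 are split across multiple (value, count) pairs so each
    count fits in a uint8.

    Example (illustrative):
        >>> compress_bits(torch.tensor([1, 1, 1, 0, 0, 1]))
        tensor([1, 3, 0, 2, 1, 1], dtype=torch.uint8)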
    """
    if bits.dim() != 1:
        raise ValueError("compress_bits expects a 1D tensor")
    b = bits.to(torch.uint8).flatten()
    if b.numel() == 0:
        return b
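    # Indices where the bit value flips; each run spans [start, end).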
    changes = torch.nonzero(b[1:] != b[:-1]).flatten().to(torch.long) + 1
    starts = torch.cat([b.new_tensor([0], dtype=torch.long), changes])
    ends = torch.cat([changes, b.new_tensor([b.numel()], dtype=torch.long)])
    values = b[starts.to(torch.long)]
    counts = ends - starts

    out_vals: List[int] = []
    out_counts: List[int] = []
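    # Split runs longer than 255 so each count fits in a single uint8.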
    for v, c in zip(values.tolist(), counts.tolist()):
        while c > 255:
            out_vals.append(v)
            out_counts.append(255)
            c -= 255
        out_vals.append(v)
        out_counts.append(c)
    values_tensor = torch.tensor(out_vals, dtype=torch.uint8)
    counts_tensor = torch.tensor(out_counts, dtype=torch.uint8)
    out = torch.stack([values_tensor, counts_tensor], dim=1).flatten()
    return out


def decompress_bits(compressed: torch.Tensor) -> torch.Tensor:
    """Decode a run-length encoded bit tensor produced by compress_bits."""
    if compressed.dim() != 1 or compressed.numel() % 2 != 0:
        raise ValueError("compressed tensor must be 1D with an even number of elements")
    data = compressed.to(torch.uint8)
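    # Even positions hold bit values, odd positions hold run lengths.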
    values = data[0::2]
    counts = data[1::2].to(torch.long)
    return torch.repeat_interleave(values, counts)


def compress_bits_batch(bits_batch: Union[torch.Tensor, List[torch.Tensor]]) -> List[torch.Tensor]:
    """Run-length encode a batch of 1D bit tensors.

    Args:
        bits_batch: 2D tensor [batch_size, seq_len] or list of 1D tensors.
        
    Returns:
        List of compressed tensors for each sequence in batch
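
    Example (illustrative):
        >>> batch = torch.tensor([[1, 1, 0], [0, 0, 0]])
        >>> [t.tolist() for t in compress_bits_batch(batch)]
        [[1, 2, 0, 1], [0, 3]]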
    """
    if isinstance(bits_batch, torch.Tensor):
        if bits_batch.dim() == 2:
            # Encode each row independently; run lengths differ per row, so the
            # results cannot be collected into a single dense tensor.
            bits_batch = bits_batch.to(torch.uint8)
            return [compress_bits(bits_batch[i]) for i in range(bits_batch.shape[0])]
        else:
            return [compress_bits(bits_batch)]
    else:
        # Handle list input
        return [compress_bits(seq) for seq in bits_batch]


def model_output_decompress(compressed_batch: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:
    """Decompress a batch of compressed bit sequences.

    Rows that fail to decode become a single zero bit, and sequences of unequal
    length are right-padded with zeros before stacking.
    """
    if isinstance(compressed_batch, torch.Tensor) and compressed_batch.dim() == 1:
        sequences = [decompress_bits(compressed_batch)]
    else:
        sequences = []
        for row in compressed_batch:
            try:
                sequences.append(decompress_bits(row))
            except Exception:
                # Graceful degradation: substitute a single zero bit if a row fails to decode.
                sequences.append(torch.zeros(1, dtype=torch.uint8))
                
    lengths = [seq.numel() for seq in sequences]
    if len(set(lengths)) != 1:
        # Handle variable lengths by padding to max length
        max_length = max(lengths)
        padded_sequences = []
        for seq in sequences:
            if seq.numel() < max_length:
                padding = torch.zeros(max_length - seq.numel(), dtype=seq.dtype, device=seq.device)
                seq = torch.cat([seq, padding])
            padded_sequences.append(seq)
        return torch.stack(padded_sequences)
    return torch.stack(sequences)


def compress_bits_parallel(bits_batch: torch.Tensor, num_workers: int = 4) -> List[torch.Tensor]:
    """Compress very large batches in parallel using a thread pool.

    Args:
        bits_batch: 2D tensor [batch_size, seq_len].
        num_workers: Number of worker threads.

    Returns:
        List of compressed tensors, in the same order as the input rows.
    """
    import concurrent.futures
    
    if bits_batch.dim() != 2:
        raise ValueError("bits_batch must be 2D [batch_size, seq_len]")
    
    batch_size = bits_batch.shape[0]
    if batch_size < num_workers * 2:  # Not worth parallelizing small batches
        return compress_bits_batch(bits_batch)
    
    # Split batch into chunks for parallel processing
    chunk_size = max(1, batch_size // num_workers)
    chunks = [bits_batch[i:i + chunk_size] for i in range(0, batch_size, chunk_size)]
    
    compressed_results = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(compress_bits_batch, chunk) for chunk in chunks]
        # Iterate in submission order so results stay aligned with the input rows.
        for future in futures:
            try:
                result = future.result()
                compressed_results.extend(result)
            except Exception as e:
                # Fallback to single-threaded processing on error
                print(f"Parallel compression failed: {e}, falling back to sequential processing")
                return compress_bits_batch(bits_batch)
    
    return compressed_results


def pack_bits(bits: torch.Tensor) -> torch.Tensor:
    """Pack groups of 8 bits into uint8 values using numpy.packbits.

    The last byte is zero-padded if the bit count is not a multiple of 8.
    """
    if bits.dim() != 1:
        raise ValueError("pack_bits expects a 1D tensor")
    arr = bits.to(torch.uint8).cpu().numpy()
    packed = np.packbits(arr)
    return torch.from_numpy(packed)


def unpack_bits(packed: torch.Tensor, *, n_bits: Optional[int] = None) -> torch.Tensor:
    """Unpack uint8 values back into a bit tensor.

    Pass n_bits to trim the zero padding added by pack_bits.
    """
    if packed.dim() != 1:
        raise ValueError("unpack_bits expects a 1D tensor")
    arr = np.unpackbits(packed.to(torch.uint8).cpu().numpy())
    if n_bits is not None:
        arr = arr[:n_bits]
    return torch.from_numpy(arr)
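

# Usage sketch (illustrative): round-trip checks for the helpers above. Because this
# module uses a relative import, run it with `python -m <package>.<this_module>` (the
# package name is project-specific) or import the functions rather than executing the
# file directly.
if __name__ == "__main__":
    demo = torch.tensor([1, 1, 1, 1, 0, 0, 1, 0, 0, 0], dtype=torch.uint8)

    # Run-length round trip: compress_bits interleaves (value, run length) pairs.
    rle = compress_bits(demo)
    assert torch.equal(decompress_bits(rle), demo)

    # Batched round trip: each row is encoded independently, then stacked back.
    batch = torch.stack([demo, torch.zeros_like(demo)])
    restored = model_output_decompress(compress_bits_batch(batch))
    assert torch.equal(restored, batch)

    # Bit-packing round trip: n_bits trims the zero padding added by numpy.packbits.
    packed = pack_bits(demo)
    assert torch.equal(unpack_bits(packed, n_bits=demo.numel()), demo)

    print("round-trip checks passed")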