Spaces:
Runtime error
Runtime error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """Raw binary format for Encodec compressed audio. Actual compression API is in `encodec.compress`.""" | |
| import io | |
| import json | |
| import struct | |
| import typing as tp | |
# The format is the `ECDC` magic code, followed by an uint8 giving the protocol
# version (currently 0), then the size in bytes of the JSON metadata as an uint32.
# Header integers use network (big-endian) byte order, hence the '!' prefix below.
# The metadata is then written as UTF-8 encoded JSON and should contain all the
# information required for decoding. A raw stream of bytes follows, which should
# be interpretable using the JSON header.
_encodec_header_struct = struct.Struct('!4sBI')
_ENCODEC_MAGIC = b'ECDC'
def write_ecdc_header(fo: tp.IO[bytes], metadata: tp.Any):
    """Serialize `metadata` into an ECDC header and write it to `fo`.

    The fixed-size prefix (magic, version 0, metadata length) is written first,
    then the UTF-8 encoded JSON metadata; `fo` is flushed afterwards.

    Args:
        fo (IO[bytes]): binary file-object receiving the header.
        metadata (Any): JSON-serializable object to store in the header.
    """
    payload = json.dumps(metadata).encode('utf-8')
    prefix = _encodec_header_struct.pack(_ENCODEC_MAGIC, 0, len(payload))
    fo.write(prefix)
    fo.write(payload)
    fo.flush()
def _read_exactly(fo: tp.IO[bytes], size: int) -> bytes:
    """Read exactly `size` bytes from `fo`, looping over short reads.

    A single `read(n)` call may legitimately return fewer than `n` bytes
    (e.g. on pipes, sockets or raw unbuffered streams), so keep reading until
    the requested amount has been collected.

    Args:
        fo (IO[bytes]): binary file-object to read from.
        size (int): number of bytes to read.
    Returns:
        bytes: exactly `size` bytes (empty when `size <= 0`).
    Raises:
        EOFError: if the stream ends before `size` bytes could be read.
    """
    chunks: tp.List[bytes] = []
    # Track the missing byte count on its own: comparing the accumulated length
    # against a shrinking target would stop early after two partial reads.
    remaining = size
    while remaining > 0:
        chunk = fo.read(remaining)
        if not chunk:
            raise EOFError("Impossible to read enough data from the stream, "
                           f"{remaining} bytes remaining.")
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)
def read_ecdc_header(fo: tp.IO[bytes]):
    """Read and validate an ECDC header from `fo` and return its JSON metadata.

    Args:
        fo (IO[bytes]): binary file-object positioned at the start of a header.
    Returns:
        The decoded JSON metadata object.
    Raises:
        EOFError: propagated from `_read_exactly` if the stream is truncated.
        ValueError: if the magic code does not match or the version is not 0.
    """
    prefix = _read_exactly(fo, _encodec_header_struct.size)
    magic, version, meta_size = _encodec_header_struct.unpack(prefix)
    if magic != _ENCODEC_MAGIC:
        raise ValueError("File is not in ECDC format.")
    if version != 0:
        raise ValueError("Version not supported.")
    return json.loads(_read_exactly(fo, meta_size).decode('utf-8'))
class BitPacker:
    """Simple bit packer to handle ints with a non standard width, e.g. 10 bits.

    Values are packed least-significant bit first: each pushed value is placed
    just above the bits already pending, and complete bytes are emitted from the
    low end as soon as they are available.
    Note that for some bandwidths (1.5, 3), the codebook representation
    will not cover an integer number of bytes.

    Args:
        bits (int): number of bits per value that will be pushed.
        fo (IO[bytes]): file-object to push the bytes to.
    """

    def __init__(self, bits: int, fo: tp.IO[bytes]):
        self._current_value = 0  # pending bits not yet written, LSB first
        self._current_bits = 0   # number of pending bits in `_current_value`
        self.bits = bits
        self.fo = fo

    def push(self, value: int):
        """Push a new value to the stream. This will immediately
        write as many uint8 as possible to the underlying file-object.

        Args:
            value (int): unsigned integer that must fit in `self.bits` bits.
        Raises:
            ValueError: if `value` is negative or wider than `self.bits` bits;
                the packer state is left unchanged in that case.
        """
        # An out-of-range value would overlap the bits of the following values
        # and silently corrupt the stream, so reject it up front.
        if not 0 <= value < (1 << self.bits):
            raise ValueError(f"Value {value} does not fit in {self.bits} bits.")
        self._current_value += (value << self._current_bits)
        self._current_bits += self.bits
        while self._current_bits >= 8:
            lower_8bits = self._current_value & 0xff
            self._current_bits -= 8
            self._current_value >>= 8
            self.fo.write(bytes([lower_8bits]))

    def flush(self):
        """Flushes the remaining partial uint8, call this at the end
        of the stream to encode.

        The unused high bits of the final byte are zero, so a `BitUnpacker`
        reading the stream may return extra trailing zero values.
        """
        if self._current_bits:
            self.fo.write(bytes([self._current_value]))
            self._current_value = 0
            self._current_bits = 0
        self.fo.flush()
class BitUnpacker:
    """Inverse of `BitPacker`: decode fixed-width unsigned ints from a byte stream.

    Bytes are consumed one at a time and appended above the bits already
    buffered; values are taken from the low end of that buffer.

    Args:
        bits (int): number of bits of the values to decode.
        fo (IO[bytes]): file-object to read the bytes from.
    """

    def __init__(self, bits: int, fo: tp.IO[bytes]):
        self.bits = bits
        self.fo = fo
        self._mask = (1 << bits) - 1
        self._current_value = 0  # buffered bits, least significant first
        self._current_bits = 0   # how many bits are currently buffered

    def pull(self) -> tp.Optional[int]:
        """Decode and return the next value, reading bytes from `fo` as needed.

        Returns:
            The next `self.bits`-wide value, or `None` once the stream runs out
            before a full value could be assembled.
        """
        # Top up the buffer one byte at a time until a whole value is present.
        while self._current_bits < self.bits:
            chunk = self.fo.read(1)
            if not chunk:
                return None
            # The new byte lands above the buffered bits, which never overlap it.
            self._current_value |= chunk[0] << self._current_bits
            self._current_bits += 8
        # Split the buffer into (bits above the value, the value itself).
        self._current_value, decoded = divmod(self._current_value, self._mask + 1)
        self._current_bits -= self.bits
        return decoded
def test():
    """Round-trip random token streams through `BitPacker` / `BitUnpacker`.

    Uses a fixed torch seed; for four random (length, bit-width) pairs it packs
    random tokens, unpacks them again and checks that the decoded prefix matches.
    """
    import torch
    torch.manual_seed(1234)
    for _ in range(4):
        # Draw order (length, bits, tokens) matters for reproducibility.
        length: int = torch.randint(10, 2_000, (1, )).item()
        bits: int = torch.randint(1, 16, (1, )).item()
        tokens: tp.List[int] = torch.randint(2**bits, (length, )).tolist()

        stream = io.BytesIO()
        packer = BitPacker(bits, stream)
        for tok in tokens:
            packer.push(tok)
        packer.flush()

        stream.seek(0)
        unpacker = BitUnpacker(bits, stream)
        decoded: tp.List[int] = []
        item = unpacker.pull()
        while item is not None:
            decoded.append(item)
            item = unpacker.pull()

        n_decoded, n_tokens = len(decoded), len(tokens)
        assert n_decoded >= n_tokens, (n_decoded, n_tokens)
        # Zero padding from `flush` can surface as a few trailing "ghost" values.
        assert n_decoded <= n_tokens + 8 // bits, (n_decoded, n_tokens, bits)
        for pos, (expected, got) in enumerate(zip(tokens, decoded)):
            assert expected == got, (pos, expected, got)
if __name__ == "__main__":
    # Executing the module directly runs the packing round-trip self-test.
    test()