"""Build a bit-level tensor from the WikiText-2 (raw) training split.

The script concatenates the raw text lines, truncates the buffer at
TXT_MB MiB, expands every byte into its 8 bits (most significant bit
first), and saves the result as a flat uint8 tensor.
"""
import pathlib

import torch
from datasets import load_dataset

# Size cap on the raw text, in MiB. The wikitext-2 train split is smaller
# than this, so in practice the whole split is used.
TXT_MB = 100
OUT = pathlib.Path('full_bits.pt')


def build_bits(out: pathlib.Path = OUT, txt_mb: int = TXT_MB) -> None:
    ds = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')

    # Accumulate UTF-8 encoded, newline-terminated lines up to the size cap.
    buf = bytearray()
    for line in ds['text']:
        buf.extend(line.encode() + b"\n")
        if len(buf) >= txt_mb * 2 ** 20:
            break

    # Expand each byte into its 8 bits, most significant bit first.
    bits = []
    for byte in buf:
        bits.extend(int(b) for b in f'{byte:08b}')

    tensor = torch.tensor(bits, dtype=torch.uint8)
    torch.save(tensor, out)


if __name__ == '__main__':
    build_bits()
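
# Example usage (a minimal sketch, assuming the default OUT path): the saved
# file is a flat uint8 tensor of 0/1 values, 8 bits per original byte, so it
# can be loaded back and regrouped into one row of 8 bits per byte:
#
#     bits = torch.load('full_bits.pt')
#     per_byte = bits.view(-1, 8)  # MSB-first bits of each byte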