File size: 1,311 Bytes
36c78b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import pathlib
import torch
from bit_transformer import BitTransformerLM

DATA_PATH = pathlib.Path('full_bits.pt')

class BitSeq(torch.utils.data.IterableDataset):
    def __init__(self, path: str | pathlib.Path = DATA_PATH, seq: int = 2048) -> None:
        self.bits = torch.load(path, mmap=True)
        self.seq = seq

    def __len__(self) -> int:
        return (self.bits.numel() // self.seq) - 1

    def __iter__(self):
        N = (self.bits.numel() // self.seq) - 1
        for i in range(N):
            s = i * self.seq
            yield (
                self.bits[s:s+self.seq].long(),
                self.bits[s+1:s+self.seq+1].long(),
            )

def main() -> None:
    """Print a freshly built BitTransformerLM's loss on one batch from ``DATA_PATH``.

    This is a forward-only smoke test. It builds the model, takes the first
    batch of ``BitSeq`` windows and prints the cross-entropy between the
    model's logits and the one-step-shifted targets. No parameters are
    updated.

    Raises:
        RuntimeError: If the dataset yields no batch (the data file is too
            short for a single window of ``seq_len + 1`` elements).
    """
    # Shared so the dataset window and the model's maximum context stay in sync.
    seq_len = 2048

    dl = torch.utils.data.DataLoader(
        BitSeq(DATA_PATH, seq=seq_len),
        batch_size=8,
        num_workers=0,
        pin_memory=False,
    )

    model = BitTransformerLM(
        d_model=64,
        nhead=4,
        num_layers=2,
        dim_feedforward=256,
        max_seq_len=seq_len,
        reversible=True,
        use_autocast=True,
    )

    try:
        xb, yb = next(iter(dl))
    except StopIteration:
        raise RuntimeError(
            f"{DATA_PATH} is too short to form a single window of {seq_len} elements"
        ) from None

    loss_fn = torch.nn.CrossEntropyLoss()
    # Forward pass only: skip autograd graph construction to save memory.
    with torch.no_grad():
        logits, _ = model(xb)
        # NOTE(review): assumes the logits' last dimension holds the 2 bit
        # classes; confirm against BitTransformerLM's output layout.
        pred = logits.reshape(-1, 2)
        target = yb.reshape(-1)
        loss = loss_fn(pred, target)
    print('Batch loss:', loss.item())

if __name__ == "__main__":
    # Run the one-batch smoke test only when executed as a script, not on import.
    main()