import pathlib

import torch

from bit_transformer import BitTransformerLM

# Path to the pre-generated bit tensor (saved via torch.save).
DATA_PATH = pathlib.Path('full_bits.pt')


class BitSeq(torch.utils.data.IterableDataset):
    """Iterable dataset of next-bit-prediction windows over one long bit tensor.

    Yields ``(input, target)`` pairs of length ``seq`` where ``target`` is the
    input shifted one position to the right (standard LM teacher forcing).
    """

    def __init__(self, path: str | pathlib.Path = DATA_PATH, seq: int = 2048) -> None:
        # mmap=True keeps the (potentially huge) bit tensor on disk instead of RAM.
        # NOTE(review): torch.load unpickles; pass weights_only=True if this file
        # could ever come from an untrusted source.
        self.bits = torch.load(path, mmap=True)
        self.seq = seq

    def _num_windows(self) -> int:
        # A window needs bits[s : s + seq + 1] to exist, i.e. s + seq + 1 <= numel,
        # giving (numel - 1) // seq windows.  The original (numel // seq) - 1
        # dropped the last usable window whenever numel % seq != 0, and went
        # negative (breaking len()) when numel < seq; clamp to 0 for tiny inputs.
        return max(0, (self.bits.numel() - 1) // self.seq)

    def __len__(self) -> int:
        return self._num_windows()

    def __iter__(self):
        seq = self.seq
        for i in range(self._num_windows()):
            s = i * seq
            yield (
                self.bits[s:s + seq].long(),           # input bits
                self.bits[s + 1:s + seq + 1].long(),   # next-bit targets (shifted by 1)
            )


def main() -> None:
    """Load one batch from DATA_PATH, run it through a small BitTransformerLM,
    and print the cross-entropy loss of that single batch."""
    dl = torch.utils.data.DataLoader(
        BitSeq(DATA_PATH, seq=2048),
        batch_size=8,
        num_workers=0,
        pin_memory=False,
    )
    model = BitTransformerLM(
        d_model=64,
        nhead=4,
        num_layers=2,
        dim_feedforward=256,
        max_seq_len=2048,
        reversible=True,
        use_autocast=True,
    )
    loss_fn = torch.nn.CrossEntropyLoss()
    xb, yb = next(iter(dl))  # first batch only; raises StopIteration if the dataset is empty
    logits, _ = model(xb)
    # Binary vocabulary: flatten to (batch*seq, 2) logits vs (batch*seq,) targets.
    # assumes logits' trailing dim is 2 and model returns (logits, aux) — TODO confirm
    pred = logits.reshape(-1, 2)
    target = yb.reshape(-1)
    loss = loss_fn(pred, target)
    print('Batch loss:', float(loss))


if __name__ == '__main__':
    main()