import pathlib
import torch
from bit_transformer import BitTransformerLM
DATA_PATH = pathlib.Path('full_bits.pt')
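
# Hedged helper, not part of the original script: `full_bits.pt` is assumed to
# be a 1-D tensor of 0/1 values saved with torch.save (the zipfile format that
# torch.load(..., mmap=True) expects). This writes a random placeholder so the
# pipeline below can be smoke-tested; the real file's provenance is not shown.
def make_placeholder_bits(path: pathlib.Path = DATA_PATH, n: int = 1 << 20) -> None:
    torch.save(torch.randint(0, 2, (n,), dtype=torch.uint8), path)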

class BitSeq(torch.utils.data.IterableDataset):
    """Streams fixed-length windows of a memory-mapped bit tensor."""

    def __init__(self, path: str | pathlib.Path = DATA_PATH, seq: int = 2048) -> None:
        # mmap=True avoids loading the full tensor into RAM; slices are
        # materialised lazily as the iterator walks the file.
        self.bits = torch.load(path, mmap=True)
        self.seq = seq

    def __len__(self) -> int:
        # The last window is dropped so the shifted target slice below
        # (s + 1 : s + seq + 1) never runs past the end of the tensor.
        return (self.bits.numel() // self.seq) - 1

    def __iter__(self):
        for i in range(len(self)):
            s = i * self.seq
            # Standard next-bit objective: targets are the inputs shifted by one.
            yield (
                self.bits[s:s + self.seq].long(),
                self.bits[s + 1:s + self.seq + 1].long(),
            )

def main() -> None:
    dl = torch.utils.data.DataLoader(
        BitSeq(DATA_PATH, seq=2048),
        batch_size=8,
        num_workers=0,  # single-process loading; multiple workers would replay
                        # the same windows unless the iterable is sharded per worker
        pin_memory=False,
    )
    # Small configuration for a quick smoke test; reversible layers and
    # autocast are options exposed by BitTransformerLM itself.
    model = BitTransformerLM(
        d_model=64,
        nhead=4,
        num_layers=2,
        dim_feedforward=256,
        max_seq_len=2048,
        reversible=True,
        use_autocast=True,
    )
    loss_fn = torch.nn.CrossEntropyLoss()

    # Run a single batch through the model. The forward pass returns a
    # (logits, extra) pair; the second element is discarded here.
    xb, yb = next(iter(dl))
    logits, _ = model(xb)

    # Flatten to (batch * seq, 2): the vocabulary is the two bit values.
    pred = logits.reshape(-1, 2)
    target = yb.reshape(-1)
    loss = loss_fn(pred, target)
    print('Batch loss:', float(loss))

if __name__ == '__main__':
    main()
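
# Rough sanity check (an assumption, not stated in the original source): with
# an untrained model the logits carry no information about the 2-way targets,
# so the printed cross-entropy should land near ln(2) ~= 0.693, e.g.:
#
#     $ python bitseq_smoke_test.py    # hypothetical filename
#     Batch loss: 0.69...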