# BitTransformerLM / full_bits_train.py
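"""Smoke test for BitTransformerLM: stream fixed-length bit windows from
`full_bits.pt` and print the cross-entropy loss of a single forward pass."""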
import pathlib

import torch

from bit_transformer import BitTransformerLM

DATA_PATH = pathlib.Path('full_bits.pt')
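# DATA_PATH is assumed to point at a flat 1-D tensor of 0/1 values (e.g. uint8)
# saved with torch.save, so that it can be memory-mapped below.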

class BitSeq(torch.utils.data.IterableDataset):
    """Streams overlapping (input, target) bit windows, shifted by one bit."""

    def __init__(self, path: str | pathlib.Path = DATA_PATH, seq: int = 2048) -> None:
        # mmap=True keeps the bit tensor on disk instead of loading it into RAM
        # (requires a PyTorch version whose torch.load supports mmap).
        self.bits = torch.load(path, mmap=True)
        self.seq = seq

    def __len__(self) -> int:
        # One window per `seq` bits; the last window is dropped so the shifted
        # target slice below never runs past the end of the tensor.
        return (self.bits.numel() // self.seq) - 1

    def __iter__(self):
        n = len(self)
        for i in range(n):
            s = i * self.seq
            yield (
                self.bits[s:s + self.seq].long(),          # inputs
                self.bits[s + 1:s + self.seq + 1].long(),  # next-bit targets
            )
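
# A minimal sketch for producing a compatible `full_bits.pt` (assumed format,
# inferred from the slicing and `.long()` casts above):
#
#     bits = torch.randint(0, 2, (1 << 20,), dtype=torch.uint8)
#     torch.save(bits, 'full_bits.pt')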

def main() -> None:
    dl = torch.utils.data.DataLoader(
        BitSeq(DATA_PATH, seq=2048),
        batch_size=8,      # 8 consecutive windows per batch
        num_workers=0,     # single-process loading; the IterableDataset is not sharded
        pin_memory=False,
    )
    # Deliberately small model configuration: this is a one-batch sanity check,
    # not a full training run.
    model = BitTransformerLM(
        d_model=64,
        nhead=4,
        num_layers=2,
        dim_feedforward=256,
        max_seq_len=2048,
        reversible=True,
        use_autocast=True,
    )
    loss_fn = torch.nn.CrossEntropyLoss()

    # Single forward pass over one batch: a smoke test, not a training loop.
    xb, yb = next(iter(dl))
    logits, _ = model(xb)           # second return value is unused here
    pred = logits.reshape(-1, 2)    # two logits per position: bit 0 vs. bit 1
    target = yb.reshape(-1)
    loss = loss_fn(pred, target)
    print('Batch loss:', float(loss))


if __name__ == '__main__':
    main()