|
import pathlib |
|
import torch |
|
from bit_transformer import BitTransformerLM |
|
|
|
# Default location of the pre-packed 1-D bit tensor consumed by BitSeq.
DATA_PATH = pathlib.Path('full_bits.pt')
|
|
|
class BitSeq(torch.utils.data.IterableDataset): |
|
def __init__(self, path: str | pathlib.Path = DATA_PATH, seq: int = 2048) -> None: |
|
self.bits = torch.load(path, mmap=True) |
|
self.seq = seq |
|
|
|
def __len__(self) -> int: |
|
return (self.bits.numel() // self.seq) - 1 |
|
|
|
def __iter__(self): |
|
N = (self.bits.numel() // self.seq) - 1 |
|
for i in range(N): |
|
s = i * self.seq |
|
yield ( |
|
self.bits[s:s+self.seq].long(), |
|
self.bits[s+1:s+self.seq+1].long(), |
|
) |
|
|
|
def main() -> None:
    """Smoke test: push one batch through a small BitTransformerLM and print the loss.

    Loads bit windows from ``DATA_PATH``, runs a single forward pass in
    evaluation mode, and prints the cross-entropy loss over the binary
    vocabulary.
    """
    dl = torch.utils.data.DataLoader(
        BitSeq(DATA_PATH, seq=2048),
        batch_size=8,
        num_workers=0,  # single process: IterableDataset is not worker-sharded
        pin_memory=False,
    )

    model = BitTransformerLM(
        d_model=64,
        nhead=4,
        num_layers=2,
        dim_feedforward=256,
        max_seq_len=2048,
        reversible=True,
        use_autocast=True,
    )

    loss_fn = torch.nn.CrossEntropyLoss()
    xb, yb = next(iter(dl))

    # Pure evaluation: switch off training-mode layers (dropout etc.) and skip
    # autograd bookkeeping -- the previous version built a gradient graph for a
    # forward pass whose result is only printed.
    model.eval()
    with torch.no_grad():
        logits, _ = model(xb)
        # Binary vocabulary: flatten logits to (batch*seq, 2) and targets to
        # (batch*seq,) for CrossEntropyLoss.
        pred = logits.reshape(-1, 2)
        target = yb.reshape(-1)
        loss = loss_fn(pred, target)
    print('Batch loss:', float(loss))
|
|
|
# Run the smoke test only when executed as a script, not on import.
if __name__ == '__main__':

    main()
|
|