# BitTransformerLM / wikitext_schedule.py
# Author: WCNegentropy
# Commit 36c78b1 (verified): "🤖 Updated BitTransformerLM from development space"
# Original file size: 4.46 kB
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from pathlib import Path
from datasets import load_dataset
from bit_transformer import (
BitTransformerLM,
configure_optimizer,
expand_model,
text_to_bits,
)
from bit_transformer.training import train_loop as basic_train
def _build_memmap(lines, path: Path, max_len: int) -> None:
"""Precompute bit tensors into a memory-mapped file."""
arr = np.memmap(path, mode="w+", shape=(len(lines), max_len), dtype="uint8")
for idx, text in enumerate(lines):
bits = text_to_bits(text)[:max_len]
if len(bits) < max_len:
bits.extend([0] * (max_len - len(bits)))
arr[idx] = np.array(bits, dtype="uint8")
arr.flush()
class MemmapDataset(Dataset):
    """Dataset backed by a read-only memory-mapped ``uint8`` array.

    Each item is one row of length ``max_len``, returned as an ``int64``
    tensor. When the dataset is pickled (for example, when sent to DataLoader
    worker processes under the spawn start method), the memmap handle is
    dropped and reopened on unpickling. Each process then maps the file
    itself instead of receiving a serialized copy of the whole array.
    """

    def __init__(self, path: Path, length: int, max_len: int) -> None:
        self.path = path
        self.length = length
        self.max_len = max_len
        self._arr = self._open()

    def _open(self) -> np.memmap:
        """Map ``self.path`` read-only with shape ``(length, max_len)``."""
        return np.memmap(
            self.path, mode="r", shape=(self.length, self.max_len), dtype="uint8"
        )

    def __getstate__(self) -> dict:
        # Pickling an np.memmap copies its data; ship only the metadata.
        state = self.__dict__.copy()
        state.pop("_arr", None)
        return state

    def __setstate__(self, state: dict) -> None:
        self.__dict__.update(state)
        self._arr = self._open()

    def __len__(self) -> int:  # pragma: no cover - trivial
        return self.length

    def __getitem__(self, idx: int) -> torch.Tensor:
        # astype() returns a writable copy, which torch.from_numpy can wrap
        # without warning about a read-only source.
        return torch.from_numpy(self._arr[idx].astype("int64"))
def _scale_params(width, layers, max_len):
    """Return BitTransformerLM constructor kwargs for a given width and depth."""
    return dict(
        d_model=width,
        nhead=4,
        num_layers=layers,
        dim_feedforward=width * 2,
        max_seq_len=max_len,
        reversible=True,
        chunk_size=max_len,
        use_autocast=True,
        use_act=True,
        act_threshold=0.9,
    )


def _validation_loss(model, valid: torch.Tensor) -> float:
    """Return the next-bit cross-entropy of ``model`` on ``valid``.

    The model is put in eval mode for the measurement (disabling any
    training-only layers) and restored to its previous mode afterwards.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits, _ = model(valid)
            # Position t predicts bit t+1; logits hold two classes (bit 0/1).
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = valid[:, 1:].reshape(-1)
            # Upcast in case autocast produced reduced-precision logits.
            return F.cross_entropy(pred.float(), target).item()
    finally:
        model.train(was_training)


def progressive_scale_schedule(steps=12, max_len=64, dataset_size=128):
    """Run a fixed scale-up schedule on WikiText-2 data.

    Trains a small BitTransformerLM for one epoch, measures validation loss,
    then grows the model. Depth doubles on even steps and width doubles on
    odd steps. This repeats for ``steps`` expansions (``steps + 1`` training
    epochs in total). The schedule itself is deterministic, but no RNG seed
    is set here, so losses can differ between runs.

    Side effects: downloads/loads ``wikitext-2-raw-v1`` via ``load_dataset``
    and writes ``wikitext_train.memmap`` / ``wikitext_valid.memmap`` in the
    current working directory.

    Args:
        steps: Number of scale-up steps (must be >= 0).
        max_len: Bits per sequence (must be >= 1).
        dataset_size: Number of non-empty training lines to use (must be
            >= 1). A quarter of this number (at least one) of validation
            lines is used.

    Returns:
        List of ``(step, validation_loss)`` tuples, one per trained epoch.

    Raises:
        ValueError: If any argument is out of range.
    """
    if steps < 0:
        raise ValueError(f"steps must be >= 0, got {steps}")
    if max_len < 1:
        raise ValueError(f"max_len must be >= 1, got {max_len}")
    if dataset_size < 1:
        raise ValueError(f"dataset_size must be >= 1, got {dataset_size}")

    ds = load_dataset("wikitext", "wikitext-2-raw-v1")
    train_lines = [t for t in ds["train"]["text"] if t.strip()][:dataset_size]
    # Keep at least one validation line: an empty memmap cannot be created.
    valid_count = max(1, dataset_size // 4)
    valid_lines = [t for t in ds["validation"]["text"] if t.strip()][:valid_count]

    train_path = Path("wikitext_train.memmap")
    valid_path = Path("wikitext_valid.memmap")
    _build_memmap(train_lines, train_path, max_len)
    _build_memmap(valid_lines, valid_path, max_len)
    train = MemmapDataset(train_path, len(train_lines), max_len)
    valid_map = np.memmap(
        valid_path, mode="r", shape=(len(valid_lines), max_len), dtype="uint8"
    )
    # Copy into a writable int64 array; torch.from_numpy warns on a
    # read-only memmap.
    valid = torch.from_numpy(np.array(valid_map, dtype=np.int64))

    layers = 1
    width = 32
    model = BitTransformerLM(**_scale_params(width, layers, max_len))
    # Batches per epoch, assuming train_loop's default batch size of 8
    # (TODO confirm against bit_transformer.training.train_loop).
    steps_per_epoch = max(1, (len(train) + 7) // 8)
    optimizer, scheduler = configure_optimizer(
        model, lr=1e-3, total_steps=(steps + 1) * steps_per_epoch
    )

    results = []
    for step in range(steps + 1):
        # Pass the configured optimizer/scheduler so the LR schedule sized
        # above is actually the one used for training.
        basic_train(
            model,
            train,
            epochs=1,
            compress_prob=0.5,
            log=False,
            forward_kwargs=None,
            num_workers=2,
            optimizer=optimizer,
            scheduler=scheduler,
        )
        val_loss = _validation_loss(model, valid)
        print(f"Step {step} validation loss: {val_loss:.4f}")
        results.append((step, val_loss))

        if step < steps:
            if step % 2 == 0:
                layers *= 2
            else:
                width *= 2
            model = expand_model(model, _scale_params(width, layers, max_len))
            # The expanded model has new parameters, so it needs a fresh
            # optimizer sized for the remaining (steps - step) epochs.
            optimizer, scheduler = configure_optimizer(
                model, lr=1e-3, total_steps=(steps - step) * steps_per_epoch
            )
            print(f"Scaled model to {layers} layers and width {width}")
    return results
if __name__ == "__main__":
    import argparse

    # Command-line entry point: parse the three integer knobs and run the
    # scale-up benchmark with them.
    cli = argparse.ArgumentParser(description="Deterministic scale-up benchmark")
    for flag, default_value, help_text in (
        ("--steps", 12, "number of scale-up steps"),
        ("--max-len", 64, "sequence length"),
        ("--dataset-size", 128, "number of training lines"),
    ):
        cli.add_argument(flag, type=int, default=default_value, help=help_text)
    opts = cli.parse_args()
    progressive_scale_schedule(
        steps=opts.steps,
        max_len=opts.max_len,
        dataset_size=opts.dataset_size,
    )