|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
from torch.utils.data import Dataset |
|
from pathlib import Path |
|
from datasets import load_dataset |
|
|
|
from bit_transformer import ( |
|
BitTransformerLM, |
|
configure_optimizer, |
|
expand_model, |
|
text_to_bits, |
|
) |
|
from bit_transformer.training import train_loop as basic_train |
|
|
|
|
|
def _build_memmap(lines, path: Path, max_len: int) -> None: |
|
"""Precompute bit tensors into a memory-mapped file.""" |
|
arr = np.memmap(path, mode="w+", shape=(len(lines), max_len), dtype="uint8") |
|
for idx, text in enumerate(lines): |
|
bits = text_to_bits(text)[:max_len] |
|
if len(bits) < max_len: |
|
bits.extend([0] * (max_len - len(bits))) |
|
arr[idx] = np.array(bits, dtype="uint8") |
|
arr.flush() |
|
|
|
|
|
class MemmapDataset(Dataset):
    """Dataset backed by a read-only ``uint8`` memory-mapped array.

    The backing file holds ``length`` rows of ``max_len`` values. Each item is
    one row converted to an ``int64`` tensor.

    When the dataset is pickled (for example, sent to ``DataLoader`` worker
    processes under the spawn start method), the mapping is dropped and
    re-opened on unpickling. Pickling an ``np.memmap`` directly would copy the
    entire array into the pickle and lose the file mapping.
    """

    def __init__(self, path: Path, length: int, max_len: int) -> None:
        self.path = path
        self.length = length
        self.max_len = max_len
        self._arr = self._open()

    def _open(self) -> np.ndarray:
        """Open the backing file read-only with shape ``(length, max_len)``."""
        if self.length == 0:
            # np.memmap cannot map an empty file; an empty array is equivalent.
            return np.zeros((0, self.max_len), dtype="uint8")
        return np.memmap(
            self.path, mode="r", shape=(self.length, self.max_len), dtype="uint8"
        )

    def __getstate__(self) -> dict:
        # Exclude the mapping itself; it is re-created in __setstate__.
        state = self.__dict__.copy()
        state["_arr"] = None
        return state

    def __setstate__(self, state: dict) -> None:
        self.__dict__.update(state)
        self._arr = self._open()

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> torch.Tensor:
        # astype copies the row out of the read-only mapping, so the returned
        # tensor is writable and independent of the file.
        return torch.from_numpy(self._arr[idx].astype("int64"))
|
|
|
|
|
def _model_params(width: int, layers: int, max_len: int) -> dict:
    """Return the BitTransformerLM keyword arguments for one schedule stage."""
    return dict(
        d_model=width,
        nhead=4,
        num_layers=layers,
        dim_feedforward=width * 2,
        max_seq_len=max_len,
        reversible=True,
        chunk_size=max_len,
        use_autocast=True,
        use_act=True,
        act_threshold=0.9,
    )


def _load_bit_matrix(path: Path, rows: int, max_len: int) -> torch.Tensor:
    """Read a ``uint8`` memmap of shape ``(rows, max_len)`` into an int64 tensor."""
    data = np.memmap(path, mode="r", shape=(rows, max_len), dtype="uint8")
    # Copy out of the read-only mapping first: torch.from_numpy on the
    # non-writable memmap itself emits a UserWarning.
    tensor = torch.from_numpy(np.array(data, dtype=np.int64))
    del data
    return tensor


def _validation_loss(model, valid: torch.Tensor) -> float:
    """Return next-bit cross-entropy of *model* on *valid*, computed in eval mode.

    The model's previous train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits, _ = model(valid)
            # Position t predicts bit t + 1. Assumes logits are
            # (batch, seq, 2) — TODO confirm against BitTransformerLM.
            # .float() keeps the loss in full precision under autocast.
            pred = logits[:, :-1, :].reshape(-1, 2).float()
            target = valid[:, 1:].reshape(-1)
            return F.cross_entropy(pred, target).item()
    finally:
        model.train(was_training)


def progressive_scale_schedule(steps=12, max_len=64, dataset_size=128):
    """Run deterministic scale-up on WikiText data.

    Trains a small BitTransformerLM for one epoch per stage on the first
    ``dataset_size`` non-blank WikiText-2 training lines and records the
    validation loss. Between stages it alternately doubles the layer count
    (even stages) and the width (odd stages) via ``expand_model``, for
    ``steps + 1`` training stages in total.

    Args:
        steps: Number of scale-up operations (training stages minus one).
        max_len: Bits per sequence; longer encodings are truncated and
            shorter ones zero-padded.
        dataset_size: Number of training lines. A quarter of that (at least
            one line) is taken from the validation split.

    Returns:
        List of ``(stage, validation_loss)`` tuples, one per stage.

    Raises:
        ValueError: If ``dataset_size`` is less than 1.
    """
    if dataset_size < 1:
        raise ValueError(f"dataset_size must be >= 1, got {dataset_size}")

    ds = load_dataset("wikitext", "wikitext-2-raw-v1")
    train_lines = [t for t in ds["train"]["text"] if t.strip()][:dataset_size]
    # Keep at least one validation line: an empty memmap cannot be created.
    valid_count = max(1, dataset_size // 4)
    valid_lines = [t for t in ds["validation"]["text"] if t.strip()][:valid_count]

    train_path = Path("wikitext_train.memmap")
    valid_path = Path("wikitext_valid.memmap")
    _build_memmap(train_lines, train_path, max_len)
    _build_memmap(valid_lines, valid_path, max_len)

    train = MemmapDataset(train_path, len(train_lines), max_len)
    valid = _load_bit_matrix(valid_path, len(valid_lines), max_len)

    layers = 1
    width = 32
    model = BitTransformerLM(**_model_params(width, layers, max_len))
    # Batches per epoch, assuming basic_train uses batch size 8 — TODO confirm.
    steps_per_epoch = max(1, (len(train) + 7) // 8)
    # NOTE(review): optimizer/scheduler are never passed to basic_train, so
    # this schedule has no effect as written. Confirm basic_train's API
    # (and whether each per-stage scheduler should span only one epoch,
    # since it is recreated after every expansion) before wiring them in.
    optimizer, scheduler = configure_optimizer(
        model, lr=1e-3, total_steps=(steps + 1) * steps_per_epoch
    )

    results = []
    for step in range(steps + 1):
        basic_train(
            model,
            train,
            epochs=1,
            compress_prob=0.5,
            log=False,
            forward_kwargs=None,
            num_workers=2,
        )

        val_loss = _validation_loss(model, valid)
        print(f"Step {step} validation loss: {val_loss:.4f}")
        results.append((step, val_loss))

        if step < steps:
            if step % 2 == 0:
                layers *= 2
            else:
                width *= 2
            model = expand_model(model, _model_params(width, layers, max_len))
            # Same NOTE(review) as above: not currently consumed by basic_train.
            optimizer, scheduler = configure_optimizer(
                model, lr=1e-3, total_steps=(steps - step) * steps_per_epoch
            )
            print(f"Scaled model to {layers} layers and width {width}")
    return results
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # (flag, default, help text) for each integer option, in declaration order.
    _cli_options = (
        ("--steps", 12, "number of scale-up steps"),
        ("--max-len", 64, "sequence length"),
        ("--dataset-size", 128, "number of training lines"),
    )

    cli = argparse.ArgumentParser(description="Deterministic scale-up benchmark")
    for flag, default, help_text in _cli_options:
        cli.add_argument(flag, type=int, default=default, help=help_text)
    opts = cli.parse_args()

    progressive_scale_schedule(
        steps=opts.steps,
        max_len=opts.max_len,
        dataset_size=opts.dataset_size,
    )
|
|