from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
from torch.utils.data import Dataset

from bit_transformer import (
    BitTransformerLM,
    configure_optimizer,
    expand_model,
    text_to_bits,
)
from bit_transformer.training import train_loop as basic_train

# Batch size used to size the LR schedule (optimizer steps per epoch).
# NOTE(review): assumed to match the batch size ``train_loop`` uses
# internally — confirm against bit_transformer.training.
_BATCH_SIZE = 8


def _build_memmap(lines, path: Path, max_len: int) -> None:
    """Precompute bit tensors into a memory-mapped file.

    Each text in ``lines`` is converted with ``text_to_bits``, truncated to
    ``max_len`` bits and right-padded with zeros. It is then stored as one
    ``uint8`` row of a ``(len(lines), max_len)`` array at ``path``. An
    existing file at ``path`` is overwritten (``mode="w+"``).

    Raises:
        ValueError: if ``lines`` is empty (a zero-byte file cannot be mapped).
    """
    if not lines:
        raise ValueError(f"cannot build an empty memmap at {path}")
    arr = np.memmap(path, mode="w+", shape=(len(lines), max_len), dtype="uint8")
    for idx, text in enumerate(lines):
        # np.asarray accepts a list or an array-like from text_to_bits alike.
        bits = np.asarray(text_to_bits(text)[:max_len], dtype="uint8")
        arr[idx, : bits.size] = bits
        arr[idx, bits.size :] = 0  # explicit right-padding
    arr.flush()


class MemmapDataset(Dataset):
    """Dataset backed by a memory-mapped array.

    Reads the ``(length, max_len)`` ``uint8`` file written by
    ``_build_memmap``. Each item is one row returned as an ``int64`` tensor.
    """

    def __init__(self, path: Path, length: int, max_len: int) -> None:
        self.path = path
        self.length = length
        self.max_len = max_len
        self._arr = self._open()

    def _open(self) -> np.memmap:
        """Map the backing file read-only with the expected shape."""
        return np.memmap(
            self.path, mode="r", shape=(self.length, self.max_len), dtype="uint8"
        )

    def __getstate__(self) -> dict:
        # Pickling an np.memmap serialises its entire contents. Drop the map
        # so a pickled copy (e.g. a spawned DataLoader worker) reopens the
        # file lazily instead of carrying the data.
        state = self.__dict__.copy()
        state["_arr"] = None
        return state

    def __len__(self) -> int:  # pragma: no cover - trivial
        return self.length

    def __getitem__(self, idx: int) -> torch.Tensor:
        if self._arr is None:
            self._arr = self._open()
        return torch.from_numpy(self._arr[idx].astype("int64"))


def _model_params(width: int, layers: int, max_len: int) -> dict:
    """Return the BitTransformerLM keyword arguments for one scale stage."""
    return dict(
        d_model=width,
        nhead=4,
        num_layers=layers,
        dim_feedforward=width * 2,
        max_seq_len=max_len,
        reversible=True,
        chunk_size=max_len,
        use_autocast=True,
        use_act=True,
        act_threshold=0.9,
    )


def _validation_loss(model: BitTransformerLM, valid: torch.Tensor) -> float:
    """Return next-bit cross-entropy of ``model`` on ``valid``.

    The model runs in eval mode without gradients, so train-only behaviour
    such as dropout does not distort the metric. Its previous train/eval
    mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits, _ = model(valid)
    finally:
        model.train(was_training)
    # Position t predicts bit t+1; two classes (bit 0 / bit 1).
    pred = logits[:, :-1, :].reshape(-1, 2)
    target = valid[:, 1:].reshape(-1)
    return F.cross_entropy(pred, target).item()


def progressive_scale_schedule(steps=12, max_len=64, dataset_size=128):
    """Run deterministic scale-up on WikiText data.

    Loads WikiText-2 (raw) through ``datasets.load_dataset``. It keeps the
    first ``dataset_size`` non-blank training lines and
    ``max(1, dataset_size // 4)`` non-blank validation lines, and caches
    their bit encodings in ``wikitext_train.memmap`` and
    ``wikitext_valid.memmap`` in the current working directory.

    The schedule then runs ``steps + 1`` rounds. Each round trains one
    epoch and records the validation loss. After every round except the
    last, the model is grown with ``expand_model``: layers double after
    even rounds and width doubles after odd rounds. A fresh
    optimizer/scheduler is then created for the remaining epochs.

    Args:
        steps: Number of scale-up operations (``steps + 1`` training rounds).
        max_len: Bits per sequence; each text is truncated/padded to this.
        dataset_size: Number of training lines to use.

    Returns:
        List of ``(step, validation_loss)`` tuples, one per round.

    Raises:
        ValueError: if ``steps < 0``, ``max_len < 2`` or ``dataset_size < 1``.
    """
    if steps < 0:
        raise ValueError(f"steps must be >= 0, got {steps}")
    if max_len < 2:
        # Next-bit loss needs at least one (input, target) pair per row.
        raise ValueError(f"max_len must be >= 2, got {max_len}")
    if dataset_size < 1:
        raise ValueError(f"dataset_size must be >= 1, got {dataset_size}")

    ds = load_dataset("wikitext", "wikitext-2-raw-v1")
    train_lines = [t for t in ds["train"]["text"] if t.strip()][:dataset_size]
    # Keep at least one validation line: dataset_size // 4 is 0 for small
    # sizes, and an empty memmap cannot be created.
    n_valid = max(1, dataset_size // 4)
    valid_lines = [t for t in ds["validation"]["text"] if t.strip()][:n_valid]

    train_path = Path("wikitext_train.memmap")
    valid_path = Path("wikitext_valid.memmap")
    _build_memmap(train_lines, train_path, max_len)
    _build_memmap(valid_lines, valid_path, max_len)

    train = MemmapDataset(train_path, len(train_lines), max_len)
    valid_mm = np.memmap(
        valid_path, mode="r", shape=(len(valid_lines), max_len), dtype="uint8"
    )
    # Copy into a writable int64 array: torch.from_numpy on a read-only
    # memmap warns about non-writable memory.
    valid = torch.from_numpy(np.array(valid_mm, dtype=np.int64))
    del valid_mm

    layers = 1
    width = 32
    model = BitTransformerLM(**_model_params(width, layers, max_len))
    steps_per_epoch = max(1, (len(train) + _BATCH_SIZE - 1) // _BATCH_SIZE)
    optimizer, scheduler = configure_optimizer(
        model, lr=1e-3, total_steps=(steps + 1) * steps_per_epoch
    )

    results = []
    for step in range(steps + 1):
        basic_train(
            model,
            train,
            epochs=1,
            compress_prob=0.5,
            log=False,
            forward_kwargs=None,
            num_workers=2,
            # Previously the optimizer/scheduler built above were never used.
            # NOTE(review): assumes train_loop accepts these keyword
            # arguments — confirm against bit_transformer.training.
            optimizer=optimizer,
            scheduler=scheduler,
        )
        val_loss = _validation_loss(model, valid)
        print(f"Step {step} validation loss: {val_loss:.4f}")
        results.append((step, val_loss))

        if step < steps:
            if step % 2 == 0:
                layers *= 2
            else:
                width *= 2
            model = expand_model(model, _model_params(width, layers, max_len))
            # The new schedule spans only the rounds that remain.
            optimizer, scheduler = configure_optimizer(
                model, lr=1e-3, total_steps=(steps - step) * steps_per_epoch
            )
            print(f"Scaled model to {layers} layers and width {width}")
    return results


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Deterministic scale-up benchmark")
    parser.add_argument("--steps", type=int, default=12, help="number of scale-up steps")
    parser.add_argument("--max-len", type=int, default=64, help="sequence length")
    parser.add_argument("--dataset-size", type=int, default=128, help="number of training lines")
    args = parser.parse_args()
    progressive_scale_schedule(steps=args.steps, max_len=args.max_len, dataset_size=args.dataset_size)