"""Legacy progressive scale-up demo. This script is retained for historical reference but has been superseded by ``integration_schedule.py`` which provides a more flexible scaling workflow. """ import argparse import warnings import torch import torch.nn.functional as F from bit_transformer import ( BitTransformerLM, configure_optimizer, expand_model, text_to_bits, ) from bit_transformer.training import train_loop as basic_train warnings.warn( "progressive_scaleup.py is deprecated; use integration_schedule.py instead.", DeprecationWarning, stacklevel=2, ) def progressive_scale_up( eps: float = 0.65, steps: int = 2, width_mult: float = 1.0, forward_kwargs: dict | None = None, ) -> None: """Demonstrate automatic scaling of the model on random data.""" params = dict(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=16) model = BitTransformerLM(**params) steps_per_epoch = 64 // 8 optimizer, scheduler = configure_optimizer( model, lr=1e-3, total_steps=steps * steps_per_epoch ) train = torch.randint(0, 2, (64, params["max_seq_len"]), dtype=torch.long) valid = torch.randint(0, 2, (16, params["max_seq_len"]), dtype=torch.long) for step in range(steps): # one epoch over train basic_train( model, train, epochs=1, compress_prob=0.5, log=False, forward_kwargs=forward_kwargs, ) with torch.no_grad(): logits, _ = model(valid, **(forward_kwargs or {})) pred = logits[:, :-1, :].reshape(-1, 2) target = valid[:, 1:].reshape(-1) val_loss = F.cross_entropy(pred, target).item() print(f"Step {step} validation loss: {val_loss:.4f}") if val_loss < eps: params["num_layers"] *= 2 params["d_model"] = int(params["d_model"] * width_mult) params["dim_feedforward"] = int(params["dim_feedforward"] * width_mult) model = expand_model(model, params) optimizer, scheduler = configure_optimizer( model, lr=1e-3, total_steps=steps * steps_per_epoch ) print( "Scaled model to", params["num_layers"], "layers and width", params["d_model"] ) def progressive_scale_up_text( improve_thresh: float = 0.01, steps: int = 2, width_mult: float = 2.0, max_len: int = 64, dataset_size: int = 512, forward_kwargs: dict | None = None, ) -> None: """Scale up using WikiText2 lines converted to bits. Parameters ---------- improve_thresh: float Relative validation loss improvement required to avoid scaling. If improvement is <= this threshold, model size is increased. steps: int Number of training steps. width_mult: float Multiplier applied when increasing model width. max_len: int Initial sequence length. dataset_size: int Number of training lines to load from WikiText2. forward_kwargs: dict | None Extra keyword arguments for the forward pass. 
""" from datasets import load_dataset ds = load_dataset("wikitext", "wikitext-2-raw-v1") train_iter = ds["train"]["text"] valid_iter = ds["validation"]["text"] train_lines = [] for line in train_iter: train_lines.append(line) if len(train_lines) >= dataset_size: break valid_lines = [] for line in valid_iter: valid_lines.append(line) if len(valid_lines) >= dataset_size // 4: break def lines_to_tensor(lines: list[str], length: int) -> torch.Tensor: seqs = [] for text in lines: bits = text_to_bits(text)[:length] if len(bits) < length: bits.extend([0] * (length - len(bits))) seqs.append(bits) return torch.tensor(seqs, dtype=torch.long) train = lines_to_tensor(train_lines, max_len) valid = lines_to_tensor(valid_lines, max_len) params = dict( d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=max_len, ) model = BitTransformerLM(**params) steps_per_epoch = len(train) // 8 optimizer, scheduler = configure_optimizer( model, lr=1e-3, total_steps=steps * max(1, steps_per_epoch) ) prev_loss: float | None = None scale_length = True for step in range(steps): basic_train( model, train, epochs=1, compress_prob=0.5, log=False, forward_kwargs=forward_kwargs, ) with torch.no_grad(): logits, _ = model(valid, **(forward_kwargs or {})) pred = logits[:, :-1, :].reshape(-1, 2) target = valid[:, 1:].reshape(-1) val_loss = F.cross_entropy(pred, target).item() print(f"Step {step} validation loss: {val_loss:.4f}") if prev_loss is not None: improvement = (prev_loss - val_loss) / max(prev_loss, 1e-8) if improvement <= improve_thresh: if scale_length: params["max_seq_len"] *= 2 train = lines_to_tensor(train_lines, params["max_seq_len"]) valid = lines_to_tensor(valid_lines, params["max_seq_len"]) model = model.double_length() steps_per_epoch = len(train) // 8 scale_type = "length" else: params["d_model"] = int(params["d_model"] * width_mult) params["dim_feedforward"] = int(params["dim_feedforward"] * width_mult) model = expand_model(model, params) scale_type = "width" optimizer, scheduler = configure_optimizer( model, lr=1e-3, total_steps=steps * max(1, steps_per_epoch) ) scale_length = not scale_length param_count = sum(p.numel() for p in model.parameters()) print( f"Scaled {scale_type}; seq_len={params['max_seq_len']} width={params['d_model']} params={param_count}" ) prev_loss = val_loss if __name__ == "__main__": parser = argparse.ArgumentParser(description="Progressively scale model length and width") parser.add_argument("--steps", type=int, default=2, help="number of training steps") parser.add_argument( "--improve-thresh", type=float, default=0.01, help="relative loss improvement required to avoid scaling", ) parser.add_argument( "--width-mult", type=float, default=2.0, help="width multiplier when scaling" ) parser.add_argument("--causal", action="store_true", help="use causal attention during training") parser.add_argument("--wikitext", action="store_true", help="use WikiText2 dataset") args = parser.parse_args() if args.wikitext: progressive_scale_up_text( improve_thresh=args.improve_thresh, steps=args.steps, width_mult=args.width_mult, forward_kwargs={"causal": args.causal} if args.causal else None, ) else: progressive_scale_up( eps=args.improve_thresh, steps=args.steps, width_mult=args.width_mult, forward_kwargs={"causal": args.causal} if args.causal else None, )