"""Legacy progressive scale-up demo.

This script is retained for historical reference but has been superseded by
``integration_schedule.py``, which provides a more flexible scaling workflow.
"""

import argparse
import warnings

import torch
import torch.nn.functional as F

from bit_transformer import (
    BitTransformerLM,
    configure_optimizer,
    expand_model,
    text_to_bits,
)
from bit_transformer.training import train_loop as basic_train

warnings.warn(
    "progressive_scaleup.py is deprecated; use integration_schedule.py instead.",
    DeprecationWarning,
    stacklevel=2,
)
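# Example invocation (illustrative flag values; the flags mirror the argparse
# setup at the bottom of this file, and --wikitext additionally requires the
# ``datasets`` package):
#
#   python progressive_scaleup.py --steps 4 --width-mult 2.0 --wikitext --causal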
|
|
def progressive_scale_up(
    eps: float = 0.65,
    steps: int = 2,
    width_mult: float = 1.0,
    forward_kwargs: dict | None = None,
) -> None:
    """Demonstrate automatic scaling of the model on random data."""
    params = dict(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=16)
    model = BitTransformerLM(**params)
    steps_per_epoch = 64 // 8
    optimizer, scheduler = configure_optimizer(
        model, lr=1e-3, total_steps=steps * steps_per_epoch
    )

    # Synthetic 0/1 sequences stand in for real data in this demo.
    train = torch.randint(0, 2, (64, params["max_seq_len"]), dtype=torch.long)
    valid = torch.randint(0, 2, (16, params["max_seq_len"]), dtype=torch.long)

    for step in range(steps):
        basic_train(
            model,
            train,
            epochs=1,
            compress_prob=0.5,
            log=False,
            forward_kwargs=forward_kwargs,
        )

        # Next-bit prediction loss on the held-out set.
        with torch.no_grad():
            logits, _ = model(valid, **(forward_kwargs or {}))
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = valid[:, 1:].reshape(-1)
            val_loss = F.cross_entropy(pred, target).item()
        print(f"Step {step} validation loss: {val_loss:.4f}")

        # Once the loss drops below ``eps``, double the depth and rescale the width.
        if val_loss < eps:
            params["num_layers"] *= 2
            params["d_model"] = int(params["d_model"] * width_mult)
            params["dim_feedforward"] = int(params["dim_feedforward"] * width_mult)
            model = expand_model(model, params)
            optimizer, scheduler = configure_optimizer(
                model, lr=1e-3, total_steps=steps * steps_per_epoch
            )
            print(
                "Scaled model to", params["num_layers"], "layers and width", params["d_model"]
            )
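# Minimal programmatic usage sketch for the random-data demo above (illustrative
# values; passing ``{"causal": True}`` has the same effect as the --causal flag):
#
#   progressive_scale_up(eps=0.5, steps=3, width_mult=1.5,
#                        forward_kwargs={"causal": True})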
|
|
def progressive_scale_up_text(
    improve_thresh: float = 0.01,
    steps: int = 2,
    width_mult: float = 2.0,
    max_len: int = 64,
    dataset_size: int = 512,
    forward_kwargs: dict | None = None,
) -> None:
    """Scale up using WikiText2 lines converted to bits.

    Parameters
    ----------
    improve_thresh: float
        Relative validation loss improvement required to avoid scaling.
        If improvement is <= this threshold, model size is increased.
    steps: int
        Number of training steps.
    width_mult: float
        Multiplier applied when increasing model width.
    max_len: int
        Initial sequence length.
    dataset_size: int
        Number of training lines to load from WikiText2.
    forward_kwargs: dict | None
        Extra keyword arguments for the forward pass.
    """
|
    from datasets import load_dataset

    ds = load_dataset("wikitext", "wikitext-2-raw-v1")
    train_iter = ds["train"]["text"]
    valid_iter = ds["validation"]["text"]

    # Take the first ``dataset_size`` training lines and a quarter as many
    # validation lines.
    train_lines = []
    for line in train_iter:
        train_lines.append(line)
        if len(train_lines) >= dataset_size:
            break

    valid_lines = []
    for line in valid_iter:
        valid_lines.append(line)
        if len(valid_lines) >= dataset_size // 4:
            break

    def lines_to_tensor(lines: list[str], length: int) -> torch.Tensor:
        # Convert each line to a bit sequence, truncated or zero-padded to ``length``.
        seqs = []
        for text in lines:
            bits = text_to_bits(text)[:length]
            if len(bits) < length:
                bits.extend([0] * (length - len(bits)))
            seqs.append(bits)
        return torch.tensor(seqs, dtype=torch.long)
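    # For example, lines_to_tensor(train_lines[:2], 16) yields a (2, 16) LongTensor;
    # assuming text_to_bits returns a flat list of 0/1 ints (as its use here implies),
    # longer bit sequences are truncated and shorter ones right-padded with zeros.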
|
    train = lines_to_tensor(train_lines, max_len)
    valid = lines_to_tensor(valid_lines, max_len)

    params = dict(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=max_len,
    )
    model = BitTransformerLM(**params)
    steps_per_epoch = len(train) // 8
    optimizer, scheduler = configure_optimizer(
        model, lr=1e-3, total_steps=steps * max(1, steps_per_epoch)
    )

    prev_loss: float | None = None
    # Alternate between doubling the sequence length and widening the model.
    scale_length = True

    for step in range(steps):
        basic_train(
            model,
            train,
            epochs=1,
            compress_prob=0.5,
            log=False,
            forward_kwargs=forward_kwargs,
        )

        # Next-bit prediction loss on the held-out set.
        with torch.no_grad():
            logits, _ = model(valid, **(forward_kwargs or {}))
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = valid[:, 1:].reshape(-1)
            val_loss = F.cross_entropy(pred, target).item()
        print(f"Step {step} validation loss: {val_loss:.4f}")

        # Scale up whenever the relative improvement over the previous step is
        # at or below ``improve_thresh``.
        if prev_loss is not None:
            improvement = (prev_loss - val_loss) / max(prev_loss, 1e-8)
            if improvement <= improve_thresh:
                if scale_length:
                    params["max_seq_len"] *= 2
                    train = lines_to_tensor(train_lines, params["max_seq_len"])
                    valid = lines_to_tensor(valid_lines, params["max_seq_len"])
                    model = model.double_length()
                    steps_per_epoch = len(train) // 8
                    scale_type = "length"
                else:
                    params["d_model"] = int(params["d_model"] * width_mult)
                    params["dim_feedforward"] = int(params["dim_feedforward"] * width_mult)
                    model = expand_model(model, params)
                    scale_type = "width"
                optimizer, scheduler = configure_optimizer(
                    model, lr=1e-3, total_steps=steps * max(1, steps_per_epoch)
                )
                scale_length = not scale_length
                param_count = sum(p.numel() for p in model.parameters())
                print(
                    f"Scaled {scale_type}; seq_len={params['max_seq_len']} width={params['d_model']} params={param_count}"
                )
        prev_loss = val_loss
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser(description="Progressively scale model length and width") |
|
parser.add_argument("--steps", type=int, default=2, help="number of training steps") |
|
parser.add_argument( |
|
"--improve-thresh", |
|
type=float, |
|
default=0.01, |
|
help="relative loss improvement required to avoid scaling", |
|
) |
|
parser.add_argument( |
|
"--width-mult", type=float, default=2.0, help="width multiplier when scaling" |
|
) |
|
parser.add_argument("--causal", action="store_true", help="use causal attention during training") |
|
parser.add_argument("--wikitext", action="store_true", help="use WikiText2 dataset") |
|
args = parser.parse_args() |
|
if args.wikitext: |
|
progressive_scale_up_text( |
|
improve_thresh=args.improve_thresh, |
|
steps=args.steps, |
|
width_mult=args.width_mult, |
|
forward_kwargs={"causal": args.causal} if args.causal else None, |
|
) |
|
else: |
|
progressive_scale_up( |
|
eps=args.improve_thresh, |
|
steps=args.steps, |
|
width_mult=args.width_mult, |
|
forward_kwargs={"causal": args.causal} if args.causal else None, |
|
) |