"""Legacy progressive scale-up demo.
This script is retained for historical reference but has been superseded by
``integration_schedule.py`` which provides a more flexible scaling workflow.
"""
import argparse
import warnings
import torch
import torch.nn.functional as F
from bit_transformer import (
BitTransformerLM,
configure_optimizer,
expand_model,
text_to_bits,
)
from bit_transformer.training import train_loop as basic_train
warnings.warn(
"progressive_scaleup.py is deprecated; use integration_schedule.py instead.",
DeprecationWarning,
stacklevel=2,
)
def progressive_scale_up(
eps: float = 0.65,
steps: int = 2,
width_mult: float = 1.0,
forward_kwargs: dict | None = None,
) -> None:
"""Demonstrate automatic scaling of the model on random data."""
params = dict(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=16)
model = BitTransformerLM(**params)
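    # Estimate optimizer steps assuming 64 training rows and a batch size of 8
    # inside the training helper.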
steps_per_epoch = 64 // 8
optimizer, scheduler = configure_optimizer(
model, lr=1e-3, total_steps=steps * steps_per_epoch
)
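    # Synthetic data: uniform random bits shaped (num_sequences, max_seq_len).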
train = torch.randint(0, 2, (64, params["max_seq_len"]), dtype=torch.long)
valid = torch.randint(0, 2, (16, params["max_seq_len"]), dtype=torch.long)
for step in range(steps):
# one epoch over train
basic_train(
model,
train,
epochs=1,
compress_prob=0.5,
log=False,
forward_kwargs=forward_kwargs,
)
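        # Evaluate next-bit prediction loss on the held-out set.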
with torch.no_grad():
logits, _ = model(valid, **(forward_kwargs or {}))
pred = logits[:, :-1, :].reshape(-1, 2)
target = valid[:, 1:].reshape(-1)
val_loss = F.cross_entropy(pred, target).item()
print(f"Step {step} validation loss: {val_loss:.4f}")
if val_loss < eps:
params["num_layers"] *= 2
params["d_model"] = int(params["d_model"] * width_mult)
params["dim_feedforward"] = int(params["dim_feedforward"] * width_mult)
model = expand_model(model, params)
optimizer, scheduler = configure_optimizer(
model, lr=1e-3, total_steps=steps * steps_per_epoch
)
print(
"Scaled model to", params["num_layers"], "layers and width", params["d_model"]
)
def progressive_scale_up_text(
improve_thresh: float = 0.01,
steps: int = 2,
width_mult: float = 2.0,
max_len: int = 64,
dataset_size: int = 512,
forward_kwargs: dict | None = None,
) -> None:
"""Scale up using WikiText2 lines converted to bits.
Parameters
----------
improve_thresh: float
Relative validation loss improvement required to avoid scaling.
If improvement is <= this threshold, model size is increased.
steps: int
Number of training steps.
width_mult: float
Multiplier applied when increasing model width.
max_len: int
Initial sequence length.
dataset_size: int
Number of training lines to load from WikiText2.
forward_kwargs: dict | None
Extra keyword arguments for the forward pass.
"""
from datasets import load_dataset
ds = load_dataset("wikitext", "wikitext-2-raw-v1")
    # Take the first ``dataset_size`` training lines and a quarter as many
    # validation lines.
    train_lines = ds["train"]["text"][:dataset_size]
    valid_lines = ds["validation"]["text"][: dataset_size // 4]
def lines_to_tensor(lines: list[str], length: int) -> torch.Tensor:
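        """Truncate or zero-pad each line's bit encoding to ``length`` bits."""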
seqs = []
for text in lines:
bits = text_to_bits(text)[:length]
if len(bits) < length:
bits.extend([0] * (length - len(bits)))
seqs.append(bits)
return torch.tensor(seqs, dtype=torch.long)
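    # For illustration: lines_to_tensor(["hi"], 16) yields a (1, 16) tensor,
    # assuming ``text_to_bits`` returns a flat list of 0/1 integers per line.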
train = lines_to_tensor(train_lines, max_len)
valid = lines_to_tensor(valid_lines, max_len)
params = dict(
d_model=32,
nhead=4,
num_layers=1,
dim_feedforward=64,
max_seq_len=max_len,
)
model = BitTransformerLM(**params)
steps_per_epoch = len(train) // 8
optimizer, scheduler = configure_optimizer(
model, lr=1e-3, total_steps=steps * max(1, steps_per_epoch)
)
prev_loss: float | None = None
scale_length = True
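    # Alternate between doubling the sequence length and widening the model
    # whenever the relative improvement stalls.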
for step in range(steps):
basic_train(
model,
train,
epochs=1,
compress_prob=0.5,
log=False,
forward_kwargs=forward_kwargs,
)
with torch.no_grad():
logits, _ = model(valid, **(forward_kwargs or {}))
pred = logits[:, :-1, :].reshape(-1, 2)
target = valid[:, 1:].reshape(-1)
val_loss = F.cross_entropy(pred, target).item()
print(f"Step {step} validation loss: {val_loss:.4f}")
if prev_loss is not None:
improvement = (prev_loss - val_loss) / max(prev_loss, 1e-8)
if improvement <= improve_thresh:
if scale_length:
params["max_seq_len"] *= 2
train = lines_to_tensor(train_lines, params["max_seq_len"])
valid = lines_to_tensor(valid_lines, params["max_seq_len"])
model = model.double_length()
steps_per_epoch = len(train) // 8
scale_type = "length"
else:
params["d_model"] = int(params["d_model"] * width_mult)
params["dim_feedforward"] = int(params["dim_feedforward"] * width_mult)
model = expand_model(model, params)
scale_type = "width"
optimizer, scheduler = configure_optimizer(
model, lr=1e-3, total_steps=steps * max(1, steps_per_epoch)
)
scale_length = not scale_length
param_count = sum(p.numel() for p in model.parameters())
print(
f"Scaled {scale_type}; seq_len={params['max_seq_len']} width={params['d_model']} params={param_count}"
)
prev_loss = val_loss
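
# Example invocation (this script is deprecated; prefer integration_schedule.py):
#   python progressive_scaleup.py --wikitext --steps 4 --improve-thresh 0.02 --causal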
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Progressively scale model length and width")
parser.add_argument("--steps", type=int, default=2, help="number of training steps")
parser.add_argument(
"--improve-thresh",
type=float,
default=0.01,
help="relative loss improvement required to avoid scaling",
)
parser.add_argument(
"--width-mult", type=float, default=2.0, help="width multiplier when scaling"
)
parser.add_argument("--causal", action="store_true", help="use causal attention during training")
parser.add_argument("--wikitext", action="store_true", help="use WikiText2 dataset")
args = parser.parse_args()
if args.wikitext:
progressive_scale_up_text(
improve_thresh=args.improve_thresh,
steps=args.steps,
width_mult=args.width_mult,
forward_kwargs={"causal": args.causal} if args.causal else None,
)
else:
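        # For the random-data demo the threshold acts as an absolute loss
        # target (``eps``) rather than a relative improvement.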
progressive_scale_up(
eps=args.improve_thresh,
steps=args.steps,
width_mult=args.width_mult,
forward_kwargs={"causal": args.causal} if args.causal else None,
)