# BitTransformerLM / integration_schedule.py
import os
import time
import math
from itertools import cycle
from typing import Optional
import torch
import torch.nn.functional as F
from bit_transformer import (
BitTransformerLM,
text_to_bits,
quantize_dynamic,
prepare_qat_fx,
convert_qat_fx,
hil_safe_inference,
collapse_submodel,
diffusion_inference,
TelemetrySynthesizer,
save_distilled_model,
)
from bit_transformer.training import train_loop as train
from bit_transformer.optimization import configure_optimizer, adjust_learning_rate
from bit_transformer.utils import save_model, load_model, set_dropout
from bit_transformer.torch_utils import cpu_autocast


def lines_to_tensor(lines, max_len):
    """Bit-encode each text line, truncate or zero-pad to ``max_len``, and stack into a LongTensor."""
seqs = []
for text in lines:
bits = text_to_bits(text)[:max_len]
if len(bits) < max_len:
bits.extend([0] * (max_len - len(bits)))
seqs.append(bits)
return torch.tensor(seqs, dtype=torch.long)
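
# Illustrative usage of lines_to_tensor (shapes only; the exact bit pattern depends on
# bit_transformer.text_to_bits):
#     lines_to_tensor(["hello", "a"], max_len=16)  # -> LongTensor of shape (2, 16),
#     each row truncated or right-padded with zeros to 16 bits.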


def load_wikitext(dataset_size=128, max_len=64):
    """Load WikiText-2 lines as fixed-length bit tensors, falling back to random bits on failure."""
try:
from datasets import load_dataset
ds = load_dataset("wikitext", "wikitext-2-raw-v1")
train_lines = [t for t in ds["train"]["text"] if t.strip()][:dataset_size]
valid_split = max(1, dataset_size // 4)
valid_lines = [t for t in ds["validation"]["text"] if t.strip()][:valid_split]
train = lines_to_tensor(train_lines, max_len)
valid = lines_to_tensor(valid_lines, max_len)
return train, valid, train_lines
except Exception as e:
print("Dataset load failed, using random bits", e)
train = torch.randint(0, 2, (dataset_size, max_len), dtype=torch.long)
valid = torch.randint(0, 2, (max_len, max_len), dtype=torch.long)
return train, valid, ["" for _ in range(len(train))]
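
# Sketch of the return contract (illustrative values; assumes the WikiText download succeeds):
#     train, valid, lines = load_wikitext(dataset_size=8, max_len=32)
#     # train: (8, 32) bit tensor, valid: (2, 32) bit tensor, lines: the raw training text.
# On fallback, random bit tensors of shape (8, 32) and (32, 32) are returned instead.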


def _warmup(
model: BitTransformerLM,
data: torch.Tensor,
steps: int = 5,
freeze_old: bool = False,
old_layers: int = 0,
*,
diffusion: bool = False,
curriculum: bool = False,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Run a short warm-up loop after expansion."""
model.train()
set_dropout(model, 0.1)
if freeze_old:
for idx, layer in enumerate(model.layers):
if idx < old_layers:
for p in layer.parameters():
p.requires_grad_(False)
if optimizer is None or scheduler is None:
optimizer, scheduler = configure_optimizer(model, lr=1e-3, total_steps=steps)
it = iter(data.split(8))
for idx in range(steps):
try:
batch = next(it)
except StopIteration:
it = iter(data.split(8))
batch = next(it)
if diffusion:
p = 0.5 * (1 - idx / max(1, steps - 1)) if curriculum else 0.5
noise = (torch.rand_like(batch.float()) < p).long()
noisy = batch ^ noise
logits, _ = model(noisy, causal=False)
pred = logits.reshape(-1, 2)
target = batch.reshape(-1)
else:
logits, _ = model(batch)
pred = logits[:, :-1, :].reshape(-1, 2)
target = batch[:, 1:].reshape(-1)
loss = F.cross_entropy(pred, target)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
for p in model.parameters():
p.requires_grad_(True)
model.eval()
set_dropout(model, 0.0)
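
# _warmup is invoked by integration_schedule after each expansion, e.g. the "layers"
# branch below does:
#     old_layers = model.num_layers
#     model = model.double_layers()
#     warm_opt, warm_sched = configure_optimizer(model, lr=base_lr, total_steps=100)
#     _warmup(model, train_bits, steps=100, freeze_old=True, old_layers=old_layers,
#             diffusion=diffusion, curriculum=diffusion_curriculum,
#             optimizer=warm_opt, scheduler=warm_sched)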


def integration_schedule(
steps: int = 10,
max_len: int = 64,
dataset_size: int = 128,
*,
weights_path: str = "weights/model.pt.gz",
plateau_steps: int = 0,
collapsed_path: str | None = None,
epochs_per_step: int = 2,
extra_steps: int = 3,
collapse: bool = True,
diffusion: bool = False,
noise_schedule: str = "linear",
diffusion_steps: int = 8,
diffusion_curriculum: bool = False,
use_checkpoint: bool = True,
reversible: bool = True,
improve_thresh: float = 0.01,
qat: bool = False,
):
    """Train a small BitTransformerLM on WikiText bits with a growth schedule.

    Validation loss is tracked each step; when improvement falls below
    ``improve_thresh`` the model is doubled along layers, width, or context
    (cycled in that order). The run ends with quantization, a safety-gated
    inference check, optional submodel collapse, and optional diffusion sampling.
    """
start = time.time()
train_bits, valid_bits, train_lines = load_wikitext(dataset_size, max_len)
    # Shared hyperparameters for a freshly initialized model.
    model_kwargs = dict(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=max_len,
        use_act=True,
        act_threshold=0.7,
        reversible=reversible,
        chunk_size=max_len,
        use_autocast=True,
        use_checkpoint=use_checkpoint,
    )
    if os.path.exists(weights_path):
        try:
            model = load_model(weights_path)
            print(f"Loaded model from {weights_path}")
        except Exception as e:
            print("Failed to load weights, initializing new model", e)
            model = BitTransformerLM(**model_kwargs)
    else:
        model = BitTransformerLM(**model_kwargs)
if qat:
model = prepare_qat_fx(model)
results = []
scale_cycle = cycle(["layers", "width", "context"])
base_lr = 1e-3
prev_val_loss: Optional[float] = None
for step in range(steps):
model.train()
set_dropout(model, 0.1)
opt, sched = configure_optimizer(
model, lr=base_lr, total_steps=epochs_per_step
)
train(
model,
train_bits,
epochs=epochs_per_step,
extra_steps=extra_steps,
compress_prob=0.0 if diffusion else 1.0,
log=True,
diffusion=diffusion,
diffusion_curriculum=diffusion_curriculum,
optimizer=opt,
scheduler=sched,
)
model.eval()
set_dropout(model, 0.0)
with torch.no_grad():
logits, telemetry = model(valid_bits, causal=not diffusion)
if diffusion:
pred = logits.reshape(-1, 2)
target = valid_bits.reshape(-1)
else:
pred = logits[:, :-1, :].reshape(-1, 2)
target = valid_bits[:, 1:].reshape(-1)
val_loss = F.cross_entropy(pred, target).item()
k = telemetry["negentropy_logits"].mean().item()
c = telemetry["lz_complexity_logits"].mean().item()
s = telemetry["symbiosis_score"].mean().item()
print(f"Step {step} validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}")
results.append((step, val_loss, k, c, s))
if prev_val_loss is not None and prev_val_loss - val_loss < improve_thresh:
strategy = next(scale_cycle)
base_lr = adjust_learning_rate(opt, 1 / math.sqrt(2))
if strategy == "layers":
old_layers = model.num_layers
model = model.double_layers()
warm_opt, warm_sched = configure_optimizer(
model, lr=base_lr, total_steps=100
)
_warmup(
model,
train_bits,
steps=100,
freeze_old=True,
old_layers=old_layers,
diffusion=diffusion,
curriculum=diffusion_curriculum,
optimizer=warm_opt,
scheduler=warm_sched,
)
elif strategy == "width":
model = model.double_width()
warm_opt, warm_sched = configure_optimizer(
model, lr=base_lr, total_steps=100
)
_warmup(
model,
train_bits,
steps=100,
diffusion=diffusion,
curriculum=diffusion_curriculum,
optimizer=warm_opt,
scheduler=warm_sched,
)
else:
max_len *= 2
train_bits, valid_bits, train_lines = load_wikitext(
dataset_size, max_len
)
model = model.double_length()
warm_opt, warm_sched = configure_optimizer(
model, lr=base_lr, total_steps=100
)
_warmup(
model,
train_bits,
steps=100,
diffusion=diffusion,
curriculum=diffusion_curriculum,
optimizer=warm_opt,
scheduler=warm_sched,
)
prev_val_loss = val_loss
if time.time() - start > 8 * 60:
print("Time limit reached")
break
# optional plateau phase at final size
for p in range(plateau_steps):
model.train()
set_dropout(model, 0.1)
train(
model,
train_bits,
epochs=epochs_per_step,
extra_steps=extra_steps,
compress_prob=0.0 if diffusion else 1.0,
log=True,
diffusion=diffusion,
diffusion_curriculum=diffusion_curriculum,
)
model.eval()
set_dropout(model, 0.0)
with torch.no_grad():
logits, telemetry = model(valid_bits, causal=not diffusion)
if diffusion:
pred = logits.reshape(-1, 2)
target = valid_bits.reshape(-1)
else:
pred = logits[:, :-1, :].reshape(-1, 2)
target = valid_bits[:, 1:].reshape(-1)
val_loss = F.cross_entropy(pred, target).item()
k = telemetry["negentropy_logits"].mean().item()
c = telemetry["lz_complexity_logits"].mean().item()
s = telemetry["symbiosis_score"].mean().item()
idx = steps + p
print(
f"Plateau {p} validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}"
)
results.append((idx, val_loss, k, c, s))
if time.time() - start > 8 * 60:
print("Time limit reached")
break
# final validation after last step
model.eval()
set_dropout(model, 0.0)
with torch.no_grad():
logits, telemetry = model(valid_bits, causal=not diffusion)
if diffusion:
pred = logits.reshape(-1, 2)
target = valid_bits.reshape(-1)
else:
pred = logits[:, :-1, :].reshape(-1, 2)
target = valid_bits[:, 1:].reshape(-1)
val_loss = F.cross_entropy(pred, target).item()
k = telemetry["negentropy_logits"].mean().item()
c = telemetry["lz_complexity_logits"].mean().item()
s = telemetry["symbiosis_score"].mean().item()
print(f"Final validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}")
results.append((steps + plateau_steps, val_loss, k, c, s))
# persist final model weights for future runs
save_model(model, weights_path)
    input_bits = valid_bits[:1]
    # Quantize for inference: convert the QAT-prepared graph if enabled, otherwise run a
    # CPU-autocast warm pass and apply dynamic quantization.
    if qat:
qmodel = convert_qat_fx(model)
else:
with cpu_autocast():
model(input_bits)
qmodel = quantize_dynamic(model)
qmodel.eval()
try:
hil_safe_inference(
qmodel,
input_bits,
c_floor=0.3,
s_floor=0.5,
causal=not diffusion,
strict=not diffusion,
)
except RuntimeError as e:
print("Safety gate triggered", e)
collapsed = None
    if collapse:
        # Distil a smaller collapsed submodel from telemetry-clustered exemplar
        # sequences, subject to the negentropy / LZ-complexity / symbiosis floors.
synth = TelemetrySynthesizer(n_clusters=8)
reps = synth.cluster_sequences(model, train_bits[:64])
floors = {"negentropy": 0.3, "lz_complexity": 0.35, "symbiosis_score": 0.5}
collapsed, metrics = collapse_submodel(
reps,
target_params=dict(
d_model=16,
nhead=4,
num_layers=1,
dim_feedforward=32,
max_seq_len=max_len,
),
floors=floors,
)
collapsed.eval()
with torch.no_grad():
logits, _ = collapsed(valid_bits)
pred = logits[:, :-1, :].reshape(-1, 2)
target = valid_bits[:, 1:].reshape(-1)
c_loss = F.cross_entropy(pred, target).item()
print("Collapsed model validation loss:", c_loss)
if collapsed_path is not None:
save_distilled_model(
collapsed,
collapsed_path,
{**metrics, "val_loss": c_loss},
floors=floors,
)
if diffusion:
sample = diffusion_inference(
model, length=max_len, steps=diffusion_steps, schedule=noise_schedule
)
print("Diffusion sample:", sample[0].tolist())
return results, collapsed


if __name__ == "__main__":
integration_schedule()
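
# Example invocation with non-default settings (all keyword arguments shown exist in
# integration_schedule's signature above; the values and output path are illustrative only):
#     results, collapsed = integration_schedule(
#         steps=4,
#         max_len=64,
#         diffusion=True,
#         collapsed_path="weights/collapsed.pt.gz",
#     )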