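"""Integration schedule for BitTransformerLM.

Trains a small bit-level transformer on WikiText-2 (falling back to random
bits when the dataset is unavailable), tracks telemetry (negentropy K,
LZ complexity C, symbiosis S), and doubles the model's depth, width, or
context whenever validation loss stops improving.  It can then run an
optional plateau phase, distill a collapsed submodel, quantize the result,
and draw a diffusion-mode sample.
"""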
import os
import time
import math
from itertools import cycle
from typing import Optional

import torch
import torch.nn.functional as F

from bit_transformer import (
    BitTransformerLM,
    text_to_bits,
    quantize_dynamic,
    prepare_qat_fx,
    convert_qat_fx,
    hil_safe_inference,
    collapse_submodel,
    diffusion_inference,
    TelemetrySynthesizer,
    save_distilled_model,
)
from bit_transformer.training import train_loop as train
from bit_transformer.optimization import configure_optimizer, adjust_learning_rate
from bit_transformer.utils import save_model, load_model, set_dropout
from bit_transformer.torch_utils import cpu_autocast


def lines_to_tensor(lines, max_len):
    """Encode each text line as bits, truncated or zero-padded to max_len."""
    seqs = []
    for text in lines:
        bits = text_to_bits(text)[:max_len]
        if len(bits) < max_len:
            bits.extend([0] * (max_len - len(bits)))
        seqs.append(bits)
    return torch.tensor(seqs, dtype=torch.long)
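# Usage sketch (assumes text_to_bits returns a flat list of 0/1 ints per string):
#   batch = lines_to_tensor(["hello", "world"], max_len=64)
#   batch.shape  # torch.Size([2, 64])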


def load_wikitext(dataset_size=128, max_len=64):
    """Load a small slice of WikiText-2 as bit tensors, or fall back to random bits."""
    try:
        from datasets import load_dataset

        ds = load_dataset("wikitext", "wikitext-2-raw-v1")
        train_lines = [t for t in ds["train"]["text"] if t.strip()][:dataset_size]
        valid_split = max(1, dataset_size // 4)
        valid_lines = [t for t in ds["validation"]["text"] if t.strip()][:valid_split]
        train = lines_to_tensor(train_lines, max_len)
        valid = lines_to_tensor(valid_lines, max_len)
        return train, valid, train_lines
    except Exception as e:
        print("Dataset load failed, using random bits", e)
        train = torch.randint(0, 2, (dataset_size, max_len), dtype=torch.long)
        # Mirror the validation split size used above instead of max_len rows.
        valid = torch.randint(0, 2, (max(1, dataset_size // 4), max_len), dtype=torch.long)
        return train, valid, ["" for _ in range(len(train))]
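# Returned shapes (given the defaults above):
#   train: (dataset_size, max_len)   valid: (max(1, dataset_size // 4), max_len)
#   train_lines: the raw text lines behind `train` (empty strings in the fallback).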


def _warmup(
    model: BitTransformerLM,
    data: torch.Tensor,
    steps: int = 5,
    freeze_old: bool = False,
    old_layers: int = 0,
    *,
    diffusion: bool = False,
    curriculum: bool = False,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
    """Run a short warm-up loop after expansion."""
    model.train()
    set_dropout(model, 0.1)
    if freeze_old:
        # Only train the newly added layers; the first `old_layers` stay frozen.
        for idx, layer in enumerate(model.layers):
            if idx < old_layers:
                for p in layer.parameters():
                    p.requires_grad_(False)
    if optimizer is None or scheduler is None:
        optimizer, scheduler = configure_optimizer(model, lr=1e-3, total_steps=steps)
    it = iter(data.split(8))
    for idx in range(steps):
        try:
            batch = next(it)
        except StopIteration:
            it = iter(data.split(8))
            batch = next(it)
        if diffusion:
            # Denoising objective: flip each bit with probability p and predict the
            # clean sequence.  With curriculum=True, p decays linearly from 0.5 to 0.
            p = 0.5 * (1 - idx / max(1, steps - 1)) if curriculum else 0.5
            noise = (torch.rand_like(batch.float()) < p).long()
            noisy = batch ^ noise
            logits, _ = model(noisy, causal=False)
            pred = logits.reshape(-1, 2)
            target = batch.reshape(-1)
        else:
            # Autoregressive objective: predict bit t+1 from bits up to t.
            logits, _ = model(batch)
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = batch[:, 1:].reshape(-1)
        loss = F.cross_entropy(pred, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
    for p in model.parameters():
        p.requires_grad_(True)
    model.eval()
    set_dropout(model, 0.0)
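# _warmup is invoked below right after each growth step, e.g. (mirroring
# integration_schedule):
#   old = model.num_layers
#   model = model.double_layers()
#   _warmup(model, train_bits, steps=100, freeze_old=True, old_layers=old)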


def integration_schedule(
    steps: int = 10,
    max_len: int = 64,
    dataset_size: int = 128,
    *,
    weights_path: str = "weights/model.pt.gz",
    plateau_steps: int = 0,
    collapsed_path: str | None = None,
    epochs_per_step: int = 2,
    extra_steps: int = 3,
    collapse: bool = True,
    diffusion: bool = False,
    noise_schedule: str = "linear",
    diffusion_steps: int = 8,
    diffusion_curriculum: bool = False,
    use_checkpoint: bool = True,
    reversible: bool = True,
    improve_thresh: float = 0.01,
    qat: bool = False,
):
    """Train, evaluate, and progressively scale a BitTransformerLM."""
    start = time.time()
    train_bits, valid_bits, train_lines = load_wikitext(dataset_size, max_len)
    # Resume from saved weights when possible; otherwise start from a small model.
    model = None
    if os.path.exists(weights_path):
        try:
            model = load_model(weights_path)
            print(f"Loaded model from {weights_path}")
        except Exception as e:
            print("Failed to load weights, initializing new model", e)
    if model is None:
        model = BitTransformerLM(
            d_model=32,
            nhead=4,
            num_layers=1,
            dim_feedforward=64,
            max_seq_len=max_len,
            use_act=True,
            act_threshold=0.7,
            reversible=reversible,
            chunk_size=max_len,
            use_autocast=True,
            use_checkpoint=use_checkpoint,
        )
    if qat:
        model = prepare_qat_fx(model)
    results = []
    scale_cycle = cycle(["layers", "width", "context"])
    base_lr = 1e-3
    prev_val_loss: Optional[float] = None
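    # Each step: train for epochs_per_step, then measure validation loss and the
    # K/C/S telemetry.  If the loss improves by less than improve_thresh, scale the
    # learning rate by 1/sqrt(2) and grow the model, cycling through strategies:
    # "layers" (double depth), "width" (double the model width), "context" (double
    # max_len and reload the data).  A newly grown model gets a 100-step warm-up.
    # The whole schedule is capped at roughly 8 minutes of wall-clock time.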
    for step in range(steps):
        model.train()
        set_dropout(model, 0.1)
        opt, sched = configure_optimizer(
            model, lr=base_lr, total_steps=epochs_per_step
        )
        train(
            model,
            train_bits,
            epochs=epochs_per_step,
            extra_steps=extra_steps,
            compress_prob=0.0 if diffusion else 1.0,
            log=True,
            diffusion=diffusion,
            diffusion_curriculum=diffusion_curriculum,
            optimizer=opt,
            scheduler=sched,
        )

        model.eval()
        set_dropout(model, 0.0)
        with torch.no_grad():
            logits, telemetry = model(valid_bits, causal=not diffusion)
            if diffusion:
                pred = logits.reshape(-1, 2)
                target = valid_bits.reshape(-1)
            else:
                pred = logits[:, :-1, :].reshape(-1, 2)
                target = valid_bits[:, 1:].reshape(-1)
            val_loss = F.cross_entropy(pred, target).item()
        k = telemetry["negentropy_logits"].mean().item()
        c = telemetry["lz_complexity_logits"].mean().item()
        s = telemetry["symbiosis_score"].mean().item()
        print(f"Step {step} validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}")
        results.append((step, val_loss, k, c, s))

        if prev_val_loss is not None and prev_val_loss - val_loss < improve_thresh:
            strategy = next(scale_cycle)
            base_lr = adjust_learning_rate(opt, 1 / math.sqrt(2))
            if strategy == "layers":
                old_layers = model.num_layers
                model = model.double_layers()
                warm_opt, warm_sched = configure_optimizer(
                    model, lr=base_lr, total_steps=100
                )
                _warmup(
                    model,
                    train_bits,
                    steps=100,
                    freeze_old=True,
                    old_layers=old_layers,
                    diffusion=diffusion,
                    curriculum=diffusion_curriculum,
                    optimizer=warm_opt,
                    scheduler=warm_sched,
                )
            elif strategy == "width":
                model = model.double_width()
                warm_opt, warm_sched = configure_optimizer(
                    model, lr=base_lr, total_steps=100
                )
                _warmup(
                    model,
                    train_bits,
                    steps=100,
                    diffusion=diffusion,
                    curriculum=diffusion_curriculum,
                    optimizer=warm_opt,
                    scheduler=warm_sched,
                )
            else:
                max_len *= 2
                train_bits, valid_bits, train_lines = load_wikitext(
                    dataset_size, max_len
                )
                model = model.double_length()
                warm_opt, warm_sched = configure_optimizer(
                    model, lr=base_lr, total_steps=100
                )
                _warmup(
                    model,
                    train_bits,
                    steps=100,
                    diffusion=diffusion,
                    curriculum=diffusion_curriculum,
                    optimizer=warm_opt,
                    scheduler=warm_sched,
                )

        prev_val_loss = val_loss
        if time.time() - start > 8 * 60:
            print("Time limit reached")
            break

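    # Optional plateau phase: keep training at the final size for plateau_steps
    # extra rounds (no further scaling), still subject to the time limit.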
    for p in range(plateau_steps):
        model.train()
        set_dropout(model, 0.1)
        train(
            model,
            train_bits,
            epochs=epochs_per_step,
            extra_steps=extra_steps,
            compress_prob=0.0 if diffusion else 1.0,
            log=True,
            diffusion=diffusion,
            diffusion_curriculum=diffusion_curriculum,
        )
        model.eval()
        set_dropout(model, 0.0)
        with torch.no_grad():
            logits, telemetry = model(valid_bits, causal=not diffusion)
            if diffusion:
                pred = logits.reshape(-1, 2)
                target = valid_bits.reshape(-1)
            else:
                pred = logits[:, :-1, :].reshape(-1, 2)
                target = valid_bits[:, 1:].reshape(-1)
            val_loss = F.cross_entropy(pred, target).item()
        k = telemetry["negentropy_logits"].mean().item()
        c = telemetry["lz_complexity_logits"].mean().item()
        s = telemetry["symbiosis_score"].mean().item()
        idx = steps + p
        print(
            f"Plateau {p} validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}"
        )
        results.append((idx, val_loss, k, c, s))
        if time.time() - start > 8 * 60:
            print("Time limit reached")
            break

    model.eval()
    set_dropout(model, 0.0)
    with torch.no_grad():
        logits, telemetry = model(valid_bits, causal=not diffusion)
        if diffusion:
            pred = logits.reshape(-1, 2)
            target = valid_bits.reshape(-1)
        else:
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = valid_bits[:, 1:].reshape(-1)
        val_loss = F.cross_entropy(pred, target).item()
    k = telemetry["negentropy_logits"].mean().item()
    c = telemetry["lz_complexity_logits"].mean().item()
    s = telemetry["symbiosis_score"].mean().item()

    print(f"Final validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}")
    results.append((steps + plateau_steps, val_loss, k, c, s))

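    # Export path: save full-precision weights, then either convert the QAT graph
    # or apply post-training dynamic quantization, and finally run the quantized
    # model through the HIL safety gate (C and S telemetry floors).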
    save_model(model, weights_path)

    input_bits = valid_bits[:1]
    if qat:
        qmodel = convert_qat_fx(model)
    else:
        with cpu_autocast():
            model(input_bits)
        qmodel = quantize_dynamic(model)
    qmodel.eval()
    try:
        hil_safe_inference(
            qmodel,
            input_bits,
            c_floor=0.3,
            s_floor=0.5,
            causal=not diffusion,
            strict=not diffusion,
        )
    except RuntimeError as e:
        print("Safety gate triggered", e)

    collapsed = None
    if collapse:
        # Distill a smaller "collapsed" submodel from cluster-representative
        # sequences, passing telemetry floors to the distillation.
        synth = TelemetrySynthesizer(n_clusters=8)
        reps = synth.cluster_sequences(model, train_bits[:64])
        floors = {"negentropy": 0.3, "lz_complexity": 0.35, "symbiosis_score": 0.5}
        collapsed, metrics = collapse_submodel(
            reps,
            target_params=dict(
                d_model=16,
                nhead=4,
                num_layers=1,
                dim_feedforward=32,
                max_seq_len=max_len,
            ),
            floors=floors,
        )
        collapsed.eval()
        with torch.no_grad():
            logits, _ = collapsed(valid_bits)
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = valid_bits[:, 1:].reshape(-1)
            c_loss = F.cross_entropy(pred, target).item()
        print("Collapsed model validation loss:", c_loss)
        if collapsed_path is not None:
            save_distilled_model(
                collapsed,
                collapsed_path,
                {**metrics, "val_loss": c_loss},
                floors=floors,
            )

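    # In diffusion mode, draw one sample by iterative denoising (this assumes
    # diffusion_inference starts from random bits and refines them over
    # `diffusion_steps` passes using the configured noise schedule).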
    if diffusion:
        sample = diffusion_inference(
            model, length=max_len, steps=diffusion_steps, schedule=noise_schedule
        )
        print("Diffusion sample:", sample[0].tolist())
    return results, collapsed


if __name__ == "__main__":
    integration_schedule()
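# Example invocations (hypothetical parameter choices, not part of the original run):
#   integration_schedule(steps=4, dataset_size=32, collapse=False)
#   integration_schedule(diffusion=True, diffusion_steps=8, plateau_steps=2)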