import os
import time
import math
from itertools import cycle
from typing import Optional

import torch
import torch.nn.functional as F

from bit_transformer import (
    BitTransformerLM,
    text_to_bits,
    quantize_dynamic,
    prepare_qat_fx,
    convert_qat_fx,
    hil_safe_inference,
    collapse_submodel,
    diffusion_inference,
    TelemetrySynthesizer,
    save_distilled_model,
)
from bit_transformer.training import train_loop as train
from bit_transformer.optimization import configure_optimizer, adjust_learning_rate
from bit_transformer.utils import save_model, load_model, set_dropout
from bit_transformer.torch_utils import cpu_autocast


def lines_to_tensor(lines, max_len):
    seqs = []
    for text in lines:
        bits = text_to_bits(text)[:max_len]
        if len(bits) < max_len:
            bits.extend([0] * (max_len - len(bits)))
        seqs.append(bits)
    return torch.tensor(seqs, dtype=torch.long)


def load_wikitext(dataset_size=128, max_len=64):
    try:
        from datasets import load_dataset

        ds = load_dataset("wikitext", "wikitext-2-raw-v1")
        train_lines = [t for t in ds["train"]["text"] if t.strip()][:dataset_size]
        valid_split = max(1, dataset_size // 4)
        valid_lines = [t for t in ds["validation"]["text"] if t.strip()][:valid_split]
        train = lines_to_tensor(train_lines, max_len)
        valid = lines_to_tensor(valid_lines, max_len)
        return train, valid, train_lines
    except Exception as e:
        print("Dataset load failed, using random bits", e)
        train = torch.randint(0, 2, (dataset_size, max_len), dtype=torch.long)
        valid = torch.randint(0, 2, (max_len, max_len), dtype=torch.long)
        return train, valid, ["" for _ in range(len(train))]


def _warmup(
    model: BitTransformerLM,
    data: torch.Tensor,
    steps: int = 5,
    freeze_old: bool = False,
    old_layers: int = 0,
    *,
    diffusion: bool = False,
    curriculum: bool = False,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
    """Run a short warm-up loop after expansion."""
    model.train()
    set_dropout(model, 0.1)
    if freeze_old:
        for idx, layer in enumerate(model.layers):
            if idx < old_layers:
                for p in layer.parameters():
                    p.requires_grad_(False)
    if optimizer is None or scheduler is None:
        optimizer, scheduler = configure_optimizer(model, lr=1e-3, total_steps=steps)
    it = iter(data.split(8))
    for idx in range(steps):
        try:
            batch = next(it)
        except StopIteration:
            it = iter(data.split(8))
            batch = next(it)
        if diffusion:
            p = 0.5 * (1 - idx / max(1, steps - 1)) if curriculum else 0.5
            noise = (torch.rand_like(batch.float()) < p).long()
            noisy = batch ^ noise
            logits, _ = model(noisy, causal=False)
            pred = logits.reshape(-1, 2)
            target = batch.reshape(-1)
        else:
            logits, _ = model(batch)
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = batch[:, 1:].reshape(-1)
        loss = F.cross_entropy(pred, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
    for p in model.parameters():
        p.requires_grad_(True)
    model.eval()
    set_dropout(model, 0.0)


def integration_schedule(
    steps: int = 10,
    max_len: int = 64,
    dataset_size: int = 128,
    *,
    weights_path: str = "weights/model.pt.gz",
    plateau_steps: int = 0,
    collapsed_path: str | None = None,
    epochs_per_step: int = 2,
    extra_steps: int = 3,
    collapse: bool = True,
    diffusion: bool = False,
    noise_schedule: str = "linear",
    diffusion_steps: int = 8,
    diffusion_curriculum: bool = False,
    use_checkpoint: bool = True,
    reversible: bool = True,
    improve_thresh: float = 0.01,
    qat: bool = False,
):
    start = time.time()
    train_bits, valid_bits, train_lines = load_wikitext(dataset_size, max_len)
    if os.path.exists(weights_path):
        try:
            model = load_model(weights_path)
            print(f"Loaded model from {weights_path}")
        except Exception as e:
            print("Failed to load weights, initializing new model", e)
            model = BitTransformerLM(
                d_model=32,
                nhead=4,
                num_layers=1,
                dim_feedforward=64,
                max_seq_len=max_len,
                use_act=True,
                act_threshold=0.7,
                reversible=reversible,
                chunk_size=max_len,
                use_autocast=True,
                use_checkpoint=use_checkpoint,
            )
    else:
        model = BitTransformerLM(
            d_model=32,
            nhead=4,
            num_layers=1,
            dim_feedforward=64,
            max_seq_len=max_len,
            use_act=True,
            act_threshold=0.7,
            reversible=reversible,
            chunk_size=max_len,
            use_autocast=True,
            use_checkpoint=use_checkpoint,
        )
    if qat:
        model = prepare_qat_fx(model)
    results = []
    scale_cycle = cycle(["layers", "width", "context"])
    base_lr = 1e-3
    prev_val_loss: Optional[float] = None
    for step in range(steps):
        model.train()
        set_dropout(model, 0.1)
        opt, sched = configure_optimizer(
            model, lr=base_lr, total_steps=epochs_per_step
        )
        train(
            model,
            train_bits,
            epochs=epochs_per_step,
            extra_steps=extra_steps,
            compress_prob=0.0 if diffusion else 1.0,
            log=True,
            diffusion=diffusion,
            diffusion_curriculum=diffusion_curriculum,
            optimizer=opt,
            scheduler=sched,
        )
        model.eval()
        set_dropout(model, 0.0)
        with torch.no_grad():
            logits, telemetry = model(valid_bits, causal=not diffusion)
            if diffusion:
                pred = logits.reshape(-1, 2)
                target = valid_bits.reshape(-1)
            else:
                pred = logits[:, :-1, :].reshape(-1, 2)
                target = valid_bits[:, 1:].reshape(-1)
            val_loss = F.cross_entropy(pred, target).item()
        k = telemetry["negentropy_logits"].mean().item()
        c = telemetry["lz_complexity_logits"].mean().item()
        s = telemetry["symbiosis_score"].mean().item()
        print(f"Step {step} validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}")
        results.append((step, val_loss, k, c, s))
        if prev_val_loss is not None and prev_val_loss - val_loss < improve_thresh:
            strategy = next(scale_cycle)
            base_lr = adjust_learning_rate(opt, 1 / math.sqrt(2))
            if strategy == "layers":
                old_layers = model.num_layers
                model = model.double_layers()
                warm_opt, warm_sched = configure_optimizer(
                    model, lr=base_lr, total_steps=100
                )
                _warmup(
                    model,
                    train_bits,
                    steps=100,
                    freeze_old=True,
                    old_layers=old_layers,
                    diffusion=diffusion,
                    curriculum=diffusion_curriculum,
                    optimizer=warm_opt,
                    scheduler=warm_sched,
                )
            elif strategy == "width":
                model = model.double_width()
                warm_opt, warm_sched = configure_optimizer(
                    model, lr=base_lr, total_steps=100
                )
                _warmup(
                    model,
                    train_bits,
                    steps=100,
                    diffusion=diffusion,
                    curriculum=diffusion_curriculum,
                    optimizer=warm_opt,
                    scheduler=warm_sched,
                )
            else:
                max_len *= 2
                train_bits, valid_bits, train_lines = load_wikitext(
                    dataset_size, max_len
                )
                model = model.double_length()
                warm_opt, warm_sched = configure_optimizer(
                    model, lr=base_lr, total_steps=100
                )
                _warmup(
                    model,
                    train_bits,
                    steps=100,
                    diffusion=diffusion,
                    curriculum=diffusion_curriculum,
                    optimizer=warm_opt,
                    scheduler=warm_sched,
                )
        prev_val_loss = val_loss
        if time.time() - start > 8 * 60:
            print("Time limit reached")
            break

    # optional plateau phase at final size
    for p in range(plateau_steps):
        model.train()
        set_dropout(model, 0.1)
        train(
            model,
            train_bits,
            epochs=epochs_per_step,
            extra_steps=extra_steps,
            compress_prob=0.0 if diffusion else 1.0,
            log=True,
            diffusion=diffusion,
            diffusion_curriculum=diffusion_curriculum,
        )
        model.eval()
        set_dropout(model, 0.0)
        with torch.no_grad():
            logits, telemetry = model(valid_bits, causal=not diffusion)
            if diffusion:
                pred = logits.reshape(-1, 2)
                target = valid_bits.reshape(-1)
            else:
                pred = logits[:, :-1, :].reshape(-1, 2)
                target = valid_bits[:, 1:].reshape(-1)
            val_loss = F.cross_entropy(pred, target).item()
        k = telemetry["negentropy_logits"].mean().item()
        c = telemetry["lz_complexity_logits"].mean().item()
        s = telemetry["symbiosis_score"].mean().item()
        idx = steps + p
        print(
            f"Plateau {p} validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}"
        )
        results.append((idx, val_loss, k, c, s))
        if time.time() - start > 8 * 60:
            print("Time limit reached")
            break

    # final validation after last step
    model.eval()
    set_dropout(model, 0.0)
    with torch.no_grad():
        logits, telemetry = model(valid_bits, causal=not diffusion)
        if diffusion:
            pred = logits.reshape(-1, 2)
            target = valid_bits.reshape(-1)
        else:
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = valid_bits[:, 1:].reshape(-1)
        val_loss = F.cross_entropy(pred, target).item()
    k = telemetry["negentropy_logits"].mean().item()
    c = telemetry["lz_complexity_logits"].mean().item()
    s = telemetry["symbiosis_score"].mean().item()
    print(f"Final validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}")
    results.append((steps + plateau_steps, val_loss, k, c, s))

    # persist final model weights for future runs
    save_model(model, weights_path)

    input_bits = valid_bits[:1]
    if qat:
        qmodel = convert_qat_fx(model)
    else:
        with cpu_autocast():
            model(input_bits)
        qmodel = quantize_dynamic(model)
    qmodel.eval()
    try:
        hil_safe_inference(
            qmodel,
            input_bits,
            c_floor=0.3,
            s_floor=0.5,
            causal=not diffusion,
            strict=not diffusion,
        )
    except RuntimeError as e:
        print("Safety gate triggered", e)

    collapsed = None
    if collapse:
        synth = TelemetrySynthesizer(n_clusters=8)
        reps = synth.cluster_sequences(model, train_bits[:64])
        floors = {"negentropy": 0.3, "lz_complexity": 0.35, "symbiosis_score": 0.5}
        collapsed, metrics = collapse_submodel(
            reps,
            target_params=dict(
                d_model=16,
                nhead=4,
                num_layers=1,
                dim_feedforward=32,
                max_seq_len=max_len,
            ),
            floors=floors,
        )
        collapsed.eval()
        with torch.no_grad():
            logits, _ = collapsed(valid_bits)
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = valid_bits[:, 1:].reshape(-1)
            c_loss = F.cross_entropy(pred, target).item()
        print("Collapsed model validation loss:", c_loss)
        if collapsed_path is not None:
            save_distilled_model(
                collapsed,
                collapsed_path,
                {**metrics, "val_loss": c_loss},
                floors=floors,
            )

    if diffusion:
        sample = diffusion_inference(
            model, length=max_len, steps=diffusion_steps, schedule=noise_schedule
        )
        print("Diffusion sample:", sample[0].tolist())
    return results, collapsed


if __name__ == "__main__":
    integration_schedule()
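
# Usage sketch (illustrative only, not part of the default entry point above):
# the same schedule can be driven with diffusion-style training, a short plateau
# phase, and a distilled checkpoint. All keyword arguments shown exist in the
# integration_schedule signature; the output path is a placeholder.
#
#     results, collapsed = integration_schedule(
#         steps=4,
#         max_len=64,
#         diffusion=True,
#         noise_schedule="linear",
#         plateau_steps=2,
#         collapsed_path="weights/collapsed.pt.gz",
#     )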