import torch
import torch.nn.functional as F
from torch.profiler import profile

from bit_transformer import (
    BitTransformerLM,
    quantize_dynamic,
    hil_safe_inference,
    collapse_submodel,
)
from bit_transformer.training import train_loop
from bit_transformer.torch_utils import cpu_autocast


def train(
    model: BitTransformerLM,
    data: torch.Tensor,
    epochs: int = 1,
    compress_prob: float = 0.5,
    log: bool = False,
    forward_kwargs: dict | None = None,
) -> list[dict]:
    """Train with random compression; returns per-epoch metrics."""
    return train_loop(
        model,
        data,
        epochs=epochs,
        compress_prob=compress_prob,
        direct_prob=0.0,
        log=log,
        forward_kwargs=forward_kwargs,
    )
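# Hypothetical usage sketch (the exact metric keys depend on what train_loop
# emits in the installed bit_transformer version):
#     metrics = train(model, bits, epochs=2, log=True)
#     final_epoch = metrics[-1]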
def recursive_integration_flow(steps: int = 4, max_len: int = 64) -> None:
    """Run a dynamic scale-up loop with telemetry-based gating."""
    train_bits = torch.randint(0, 2, (64, max_len), dtype=torch.long)
    valid_bits = torch.randint(0, 2, (16, max_len), dtype=torch.long)
    input_bits = torch.randint(0, 2, (1, max_len), dtype=torch.long)
    bit_sequence_data = train_bits.tolist()
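    # Best telemetry scores seen so far: K (negentropy), C (LZ complexity),
    # S (symbiosis). Later steps are gated against these as a regression floor.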
    best_K = best_C = best_S = 0.0

    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=max_len,
        use_act=True,
        act_threshold=0.7,
        reversible=True,
        chunk_size=max_len,
        use_autocast=True,
    )
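    # Scale-up loop: train, validate with telemetry, gate on K/C/S regressions,
    # and grow the model between steps.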
    results = []
    for step in range(steps + 1):
        epochs = min(10, 2 + step // 2)
        train(model, train_bits, epochs=epochs, compress_prob=0.5, log=True)
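        # Validation pass: next-bit cross-entropy plus mean telemetry scores.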
        with torch.no_grad():
            with cpu_autocast():
                logits, telemetry = model(valid_bits)
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = valid_bits[:, 1:].reshape(-1)
            val_loss = F.cross_entropy(pred, target).item()
            k = telemetry["negentropy_logits"].mean().item()
            c = telemetry["lz_complexity_logits"].mean().item()
            s = telemetry["symbiosis_score"].mean().item()
        print(f"Step {step} validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}")
        results.append((step, val_loss, k, c, s))
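        # Gate: halt if any metric falls more than 0.3 below its best value so far.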
        if step > 0:
            if k < best_K - 0.3 or c < best_C - 0.3 or s < best_S - 0.3:
                print(f"\u26a0\ufe0f Step {step} regressed below metric floor. Halting.")
                break
        best_K = max(best_K, k)
        best_C = max(best_C, c)
        best_S = max(best_S, s)
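        # Alternate growth: double model width on even steps, depth on odd steps.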
        if step < steps:
            if step % 2 == 0:
                model = model.double_width()
            else:
                model = model.double_layers()
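    # Post-loop: one forward pass under cpu_autocast, then dynamic quantization.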
    with cpu_autocast():
        model(input_bits)

    qmodel = quantize_dynamic(model)
    qmodel.eval()
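    # Safety-gated inference: output is accepted only if the C and S telemetry
    # clear the given floors.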
    safe_output = hil_safe_inference(
        qmodel, input_bits, c_floor=0.5, s_floor=0.2
    )
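    # Collapse the training bit sequences into a smaller student model,
    # subject to the given telemetry floors.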
    student_model, _ = collapse_submodel(
        bit_sequence_data,
        target_params=dict(
            d_model=16,
            nhead=4,
            num_layers=1,
            dim_feedforward=32,
            max_seq_len=max_len,
        ),
        floors={"negentropy": 0.2, "lz_complexity": 0.5, "symbiosis_score": 0.2},
    )
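    # torch.compile exists only on PyTorch >= 2.0 and can raise RuntimeError on
    # unsupported setups, so fall back to the eager student model if needed.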
    if hasattr(torch, "compile"):
        try:
            compiled = torch.compile(student_model)
        except RuntimeError as exc:
            print(f"Compilation skipped: {exc}")
            compiled = student_model
    else:
        compiled = student_model
    compiled.eval()
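    # Profile a single student forward pass and export a Chrome trace.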
    with profile() as prof:
        compiled(input_bits)
    prof.export_chrome_trace("trace12.json")
    print("Safe output bits:", safe_output[0].tolist())


if __name__ == "__main__":
    recursive_integration_flow()