File size: 3,808 Bytes
36c78b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import torch
import torch.nn.functional as F
from torch.profiler import profile
from bit_transformer import (
    BitTransformerLM,
    quantize_dynamic,
    hil_safe_inference,
    collapse_submodel,
)
from bit_transformer.training import train_loop
from bit_transformer.torch_utils import cpu_autocast


def train(
    model: BitTransformerLM,
    data: torch.Tensor,
    epochs: int = 1,
    compress_prob: float = 0.5,
    log: bool = False,
    forward_kwargs: dict | None = None,
) -> list[dict]:
    """Train ``model`` on ``data`` via ``train_loop`` with random compression.

    This is a thin convenience wrapper: every argument is forwarded to
    ``train_loop`` unchanged, and ``direct_prob`` is always fixed at ``0.0``.

    Args:
        model: The model to train.
        data: Training tensor passed straight through to ``train_loop``.
        epochs: Number of training epochs.
        compress_prob: Forwarded as ``train_loop``'s ``compress_prob``.
        log: Forwarded as ``train_loop``'s ``log`` flag.
        forward_kwargs: Optional extra keyword arguments forwarded to
            ``train_loop``.

    Returns:
        Whatever ``train_loop`` returns (annotated as per-epoch metric dicts).
    """
    loop_options = dict(
        epochs=epochs,
        compress_prob=compress_prob,
        direct_prob=0.0,
        log=log,
        forward_kwargs=forward_kwargs,
    )
    epoch_metrics = train_loop(model, data, **loop_options)
    return epoch_metrics


def _evaluate(
    model: BitTransformerLM, bits: torch.Tensor
) -> tuple[float, float, float, float]:
    """Return ``(val_loss, K, C, S)`` for ``model`` on the bit tensor ``bits``.

    The forward pass runs in eval mode, so any train-only behaviour (e.g.
    dropout, if the model has any) does not perturb the metrics used for
    gating. It also runs without gradients. The model's previous training
    flag is restored afterwards, even if the forward pass raises.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            with cpu_autocast():
                logits, telemetry = model(bits)
            # Next-bit objective: logits at position t are scored against
            # bit t + 1, with two classes per position.
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = bits[:, 1:].reshape(-1)
            val_loss = F.cross_entropy(pred, target).item()
            k = telemetry["negentropy_logits"].mean().item()
            c = telemetry["lz_complexity_logits"].mean().item()
            s = telemetry["symbiosis_score"].mean().item()
    finally:
        model.train(was_training)
    return val_loss, k, c, s


def _compile_or_fallback(
    module: torch.nn.Module, sample: torch.Tensor
) -> torch.nn.Module:
    """Return ``torch.compile(module)`` if it compiles and runs on ``sample``.

    ``torch.compile`` is lazy: graph capture and backend compilation happen on
    the first call, not when ``torch.compile`` returns. A warm-up forward on
    ``sample`` therefore runs inside the ``try`` so that compilation failures
    are caught here and the eager ``module`` is returned instead.

    The warm-up runs under ``torch.no_grad()``. Callers should invoke the
    result under the same grad mode to avoid triggering a recompile.
    """
    if not hasattr(torch, "compile"):
        return module
    try:
        compiled = torch.compile(module)
        with torch.no_grad():
            compiled(sample)
    except RuntimeError as exc:
        print(f"Compilation skipped: {exc}")
        return module
    return compiled


def recursive_integration_flow(
    steps: int = 4,
    max_len: int = 64,
    trace_path: str = "trace12.json",
) -> None:
    """Run a dynamic scale-up loop with telemetry-based gating.

    Pipeline:
      1. Train a small ``BitTransformerLM`` on random bits for ``steps + 1``
         rounds. Between rounds, width is doubled after even steps and layer
         count after odd steps.
      2. After each round, report the validation loss and the K/C/S telemetry
         means. Halt early if any of them falls more than 0.3 below its best
         value so far.
      3. Quantize the final model and run ``hil_safe_inference`` on a single
         random sample.
      4. Build a smaller student with ``collapse_submodel``, compile it when
         possible, and profile one forward pass. The Chrome trace is written
         to ``trace_path``.

    Args:
        steps: Number of scale-up operations; training runs ``steps + 1``
            times.
        max_len: Length of every random bit sequence. Also used as the
            model's ``max_seq_len`` and ``chunk_size``.
        trace_path: Output file for the profiler's Chrome trace. The default
            matches the previously hard-coded name.
    """
    train_bits = torch.randint(0, 2, (64, max_len), dtype=torch.long)
    valid_bits = torch.randint(0, 2, (16, max_len), dtype=torch.long)
    input_bits = torch.randint(0, 2, (1, max_len), dtype=torch.long)

    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=max_len,
        use_act=True,
        act_threshold=0.7,
        reversible=True,
        chunk_size=max_len,
        use_autocast=True,
    )

    best_k = best_c = best_s = 0.0
    for step in range(steps + 1):
        # Train longer as the model grows, capped at 10 epochs.
        epochs = min(10, 2 + step // 2)
        train(model, train_bits, epochs=epochs, compress_prob=0.5, log=True)

        val_loss, k, c, s = _evaluate(model, valid_bits)
        print(f"Step {step} validation loss: {val_loss:.4f} K={k:.3f} C={c:.3f} S={s:.3f}")

        # Step 0 only establishes the baseline bests. NOTE: on a halt, the
        # regressed model is the one carried into the post-scaling stages.
        if step > 0 and (k < best_k - 0.3 or c < best_c - 0.3 or s < best_s - 0.3):
            print(f"\u26a0\ufe0f Step {step} regressed below metric floor. Halting.")
            break
        best_k = max(best_k, k)
        best_c = max(best_c, c)
        best_s = max(best_s, s)

        if step < steps:
            # Alternate growth axes: width on even steps, depth on odd steps.
            if step % 2 == 0:
                model = model.double_width()
            else:
                model = model.double_layers()

    # Post-scaling optimizations. The output of this forward pass is discarded,
    # so no autograd graph is needed.
    # NOTE(review): its purpose (e.g. warming internal state before
    # quantization) is not visible here -- confirm.
    with torch.no_grad(), cpu_autocast():
        model(input_bits)

    qmodel = quantize_dynamic(model)
    qmodel.eval()

    safe_output = hil_safe_inference(
        qmodel, input_bits, c_floor=0.5, s_floor=0.2
    )

    student_model, _ = collapse_submodel(
        train_bits.tolist(),
        target_params=dict(
            d_model=16,
            nhead=4,
            num_layers=1,
            dim_feedforward=32,
            max_seq_len=max_len,
        ),
        floors={"negentropy": 0.2, "lz_complexity": 0.5, "symbiosis_score": 0.2},
    )

    # Switch to eval before compiling, so the warm-up call and the profiled
    # call see the same mode and grad setting and no recompile happens during
    # profiling.
    student_model.eval()
    compiled = _compile_or_fallback(student_model, input_bits)

    with torch.no_grad(), profile() as prof:
        compiled(input_bits)
    prof.export_chrome_trace(trace_path)
    print("Safe output bits:", safe_output[0].tolist())


if __name__ == "__main__":
    # Run the demo pipeline with its default arguments when executed as a script.
    recursive_integration_flow()