|
import torch |
|
from torch.profiler import profile |
|
from bit_transformer import ( |
|
BitTransformerLM, |
|
quantize_dynamic, |
|
hil_safe_inference, |
|
collapse_submodel, |
|
) |
|
from bit_transformer.training import train_loop |
|
from bit_transformer.torch_utils import cpu_autocast |
|
|
|
def train(
    model: BitTransformerLM,
    data: torch.Tensor,
    epochs: int = 3,
    compress_prob: float = 0.5,
    direct_prob: float = 0.0,
    log: bool = False,
    forward_kwargs: dict | None = None,
) -> list[dict]:
    """Train ``model`` on the bit sequences in ``data`` via ``train_loop``.

    This is a pass-through: ``model`` and ``data`` are forwarded
    positionally and every remaining option is forwarded unchanged as a
    keyword argument.

    Behavior implemented by ``train_loop``: ``compress_prob`` controls
    optional random compression of batches. When ``direct_prob`` is
    positive, some batches are fed as their run-length-encoded
    representation packed into bits, and the loss on those
    direct-compressed batches is tracked separately.

    Returns:
        The value produced by ``train_loop``: a list of per-epoch metric
        dictionaries with raw and compressed loss/accuracy statistics and
        the mean compression ratio.
    """
    loop_options = {
        "epochs": epochs,
        "compress_prob": compress_prob,
        "direct_prob": direct_prob,
        "log": log,
        "forward_kwargs": forward_kwargs,
    }
    return train_loop(model, data, **loop_options)
|
|
|
|
|
def _check_telemetry_floors(
    model: torch.nn.Module,
    validation_bits: torch.Tensor,
    step: int,
    floors: dict[str, float],
) -> None:
    """Run ``model`` on ``validation_bits`` and enforce the telemetry floors.

    The mean of each telemetry signal must strictly exceed its floor:
    ``negentropy_logits`` > ``floors["negentropy"]``,
    ``lz_complexity_logits`` > ``floors["lz_complexity"]`` and
    ``symbiosis_score`` > ``floors["symbiosis_score"]``.

    Raises:
        AssertionError: if any floor is not met. It is raised explicitly
            rather than via ``assert``, so the gate still runs under
            ``python -O``, and it keeps the exception type the original
            check produced.
    """
    # The probe only reads telemetry, so skip building an autograd graph.
    # NOTE(review): the model is left in whatever mode train_loop left it
    # (possibly train mode); confirm whether eval() is intended here.
    with torch.no_grad():
        _, telemetry = model(validation_bits)
    negentropy = telemetry["negentropy_logits"].mean().item()
    lz_complexity = telemetry["lz_complexity_logits"].mean().item()
    symbiosis = telemetry["symbiosis_score"].mean().item()
    if not (
        negentropy > floors["negentropy"]
        and lz_complexity > floors["lz_complexity"]
        and symbiosis > floors["symbiosis_score"]
    ):
        raise AssertionError(
            f"Step {step} telemetry floor failure "
            f"(negentropy={negentropy:.3f}, lz_complexity={lz_complexity:.3f}, "
            f"symbiosis={symbiosis:.3f})"
        )


def main() -> None:
    """Run the end-to-end BitTransformerLM demo on random bit data.

    1. Build a small model, then for 12 steps alternately call
       ``double_layers`` (odd steps) and ``double_width`` (even steps).
       Train 3 epochs after each growth and check the telemetry floors on
       a held-out random batch.
    2. Smoke-test a forward pass under ``cpu_autocast``.
    3. Dynamically quantize the model and run ``hil_safe_inference``.
    4. Distill a small student with ``collapse_submodel``, compile it when
       ``torch.compile`` exists, profile one forward pass and write the
       Chrome trace to ``trace12.json``.

    Raises:
        AssertionError: if a growth step misses a telemetry floor.
    """
    # Single source of truth for the telemetry floors. The same numbers
    # feed the per-step gate, hil_safe_inference and collapse_submodel.
    floors = {"negentropy": 0.3, "lz_complexity": 0.35, "symbiosis_score": 0.5}

    # Tensors are drawn in the same order as before, so an external seed
    # still reproduces the same data.
    data = torch.randint(0, 2, (64, 128), dtype=torch.long)
    validation_bits = torch.randint(0, 2, (16, 128), dtype=torch.long)
    input_bits = torch.randint(0, 2, (1, 128), dtype=torch.long)
    bit_sequence_data = data.tolist()

    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=128,
        use_act=True,
        act_threshold=0.7,
        reversible=True,
        chunk_size=128,
    )

    # NOTE(review): six width and six layer doublings presumably grow the
    # model far beyond its starting size; expect long CPU runtimes.
    for step in range(1, 13):
        model = model.double_width() if step % 2 == 0 else model.double_layers()
        train(model, data, epochs=3, compress_prob=0.5, log=True)
        _check_telemetry_floors(model, validation_bits, step, floors)

    # Mixed-precision CPU smoke test; the output is intentionally discarded.
    with torch.no_grad(), cpu_autocast():
        model(input_bits)

    quantized_model = quantize_dynamic(model)
    quantized_model.eval()

    safe_output, _ = hil_safe_inference(
        quantized_model,
        input_bits,
        c_floor=floors["lz_complexity"],
        s_floor=floors["symbiosis_score"],
    )

    student_model, _ = collapse_submodel(
        bit_sequence_data,
        target_params=dict(
            d_model=16,
            nhead=4,
            num_layers=1,
            dim_feedforward=32,
            max_seq_len=128,
        ),
        # Pass a copy so the callee cannot mutate the shared floor table.
        floors=dict(floors),
    )

    # torch.compile only exists on PyTorch 2.x; fall back to eager otherwise.
    compiled_model = (
        torch.compile(student_model) if hasattr(torch, "compile") else student_model
    )
    compiled_model.eval()

    # Profile a pure inference pass, without autograd bookkeeping in the trace.
    with profile() as prof, torch.no_grad():
        compiled_model(input_bits)

    prof.export_chrome_trace("trace12.json")
    print("Safe output bits:", safe_output.squeeze(0).tolist())


if __name__ == "__main__":
    main()
|
|