import torch
from torch.profiler import profile

from bit_transformer import (
    BitTransformerLM,
    quantize_dynamic,
    hil_safe_inference,
    collapse_submodel,
)
from bit_transformer.training import train_loop
from bit_transformer.torch_utils import cpu_autocast


def train(
    model: BitTransformerLM,
    data: torch.Tensor,
    epochs: int = 3,
    compress_prob: float = 0.5,
    direct_prob: float = 0.0,
    log: bool = False,
    forward_kwargs: dict | None = None,
) -> list[dict]:
| """Train on bit sequences with optional random compression. | |
| If ``direct_prob`` is positive, some batches are fed using their | |
| run-length encoded representation packed into bits. Loss on these | |
| direct-compressed batches is tracked separately. | |
| Returns a list of per-epoch metric dictionaries containing raw and | |
| compressed loss/accuracy statistics and the mean compression ratio. | |
| """ | |
    return train_loop(
        model,
        data,
        epochs=epochs,
        compress_prob=compress_prob,
        direct_prob=direct_prob,
        log=log,
        forward_kwargs=forward_kwargs,
    )


def main() -> None:
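    # Synthetic bit data: a training batch, a held-out validation batch,
    # and a single probe sequence used for the inference checks below.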
    data = torch.randint(0, 2, (64, 128), dtype=torch.long)
    validation_bits = torch.randint(0, 2, (16, 128), dtype=torch.long)
    input_bits = torch.randint(0, 2, (1, 128), dtype=torch.long)
    bit_sequence_data = data.tolist()
    model = BitTransformerLM(
        d_model=32,
        nhead=4,
        num_layers=1,
        dim_feedforward=64,
        max_seq_len=128,
        use_act=True,
        act_threshold=0.7,
        reversible=True,
        chunk_size=128,
    )
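
    # Scale the model up over 12 steps, alternating width and depth doubling,
    # retraining after each growth step and asserting that the negentropy (K),
    # LZ-complexity (C), and symbiosis (S) telemetry stay above their floors.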
    for step in range(1, 13):
        if step % 2 == 0:
            model = model.double_width()
        else:
            model = model.double_layers()
        train(model, data, epochs=3, compress_prob=0.5, log=True)
        _, telemetry = model(validation_bits)
        K = telemetry["negentropy_logits"].mean().item()
        C = telemetry["lz_complexity_logits"].mean().item()
        S = telemetry["symbiosis_score"].mean().item()
        assert (
            K > 0.3 and C > 0.35 and S > 0.5
        ), f"Step {step} telemetry floor failure"
        with cpu_autocast():
            model(input_bits)
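
    # Dynamically quantize the trained model, presumably to shrink it for
    # lighter-weight CPU inference.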
    quantized_model = quantize_dynamic(model)
    quantized_model.eval()
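
    # Human-in-the-loop safe inference, gated on the complexity (C) and
    # symbiosis (S) telemetry floors.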
    safe_output, _ = hil_safe_inference(
        quantized_model, input_bits, c_floor=0.35, s_floor=0.5
    )
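
    # Collapse (distill) the bit corpus into a smaller student model that
    # must still clear the same telemetry floors.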
    student_model, _ = collapse_submodel(
        bit_sequence_data,
        target_params=dict(
            d_model=16,
            nhead=4,
            num_layers=1,
            dim_feedforward=32,
            max_seq_len=128,
        ),
        floors={"negentropy": 0.3, "lz_complexity": 0.35, "symbiosis_score": 0.5},
    )
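
    # Compile the student with torch.compile when available (PyTorch >= 2.0);
    # otherwise fall back to eager execution.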
    compiled_model = (
        torch.compile(student_model)
        if hasattr(torch, "compile")
        else student_model
    )
    compiled_model.eval()
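
    # Profile a single forward pass and export a Chrome trace, which can be
    # inspected in chrome://tracing or Perfetto.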
    with profile() as prof:
        compiled_model(input_bits)
    prof.export_chrome_trace("trace12.json")
    print("Safe output bits:", safe_output.squeeze(0).tolist())


if __name__ == "__main__":
    main()