"""Run the scale-up and unified-workflow demos and print a combined report.

The script calls ``progressive_scale_up_text`` twice (``causal`` True and
False) and ``run_workflow`` twice (``diffusion`` False and True), capturing
whatever each call prints to stdout, and finally prints one report holding
the captured logs and the time each call took.
"""
import contextlib
import io
import sys
import time
from pathlib import Path

import torch

# Make the directory two levels above this file importable before the
# project imports below run (presumably the repository root -- confirm).
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from progressive_scaleup import progressive_scale_up_text
from unified_workflow import run_workflow
from bit_transformer.bit_io import text_to_bits
from bit_transformer.safety import hil_safe_inference


def capture_run(func, *args, **kwargs):
    """Call ``func`` while capturing its stdout and timing the call.

    Args:
        func: Callable to invoke.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        A ``(result, captured_stdout, duration_seconds)`` tuple.

    Raises:
        Exception: Whatever ``func`` raises is re-raised unchanged, after the
            output captured so far has been written to stderr.
    """
    buf = io.StringIO()
    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by wall-clock adjustments the way time.time() can.
    start = time.perf_counter()
    try:
        with contextlib.redirect_stdout(buf):
            result = func(*args, **kwargs)
    except Exception:
        # Without this the captured log would be lost exactly when it is
        # most useful for diagnosing the failure.
        sys.stderr.write(buf.getvalue())
        raise
    duration = time.perf_counter() - start
    return result, buf.getvalue(), duration


def _append_section(
    summary: list[str],
    title: str,
    log: str,
    duration: float,
    extra: tuple[str, ...] = (),
) -> None:
    """Append one report section: heading, log, optional extras, duration."""
    summary.append(f"### {title}\n")
    summary.append(log.strip())
    summary.extend(extra)
    summary.append(f"Duration: {duration:.2f}s\n")


def main() -> None:
    """Run all four demo configurations and print the combined report."""
    summary: list[str] = []

    # Same scale-up settings for both runs; only the causal flag differs.
    for causal in (True, False):
        _, log, dur = capture_run(
            progressive_scale_up_text,
            improve_thresh=0.01,
            steps=10,
            width_mult=2.0,
            max_len=64,
            dataset_size=512,
            forward_kwargs={"causal": causal},
        )
        _append_section(
            summary, f"Progressive Scale-Up (causal={causal})", log, dur
        )

    # Settings shared by both unified-workflow runs.
    workflow_kwargs = {
        "steps": 2,
        "max_len": 32,
        "dataset_size": 32,
        "plateau_steps": 1,
        "epochs_per_step": 1,
        "extra_steps": 1,
    }

    (model, _), log, dur = capture_run(
        run_workflow, diffusion=False, **workflow_kwargs
    )
    # Inference runs outside capture_run, so it is neither timed nor captured.
    bits = text_to_bits("hi")
    tensor = torch.tensor(bits, dtype=torch.long).unsqueeze(0)
    # NOTE(review): floors of 0.0 presumably disable the safety gating --
    # confirm against hil_safe_inference.
    out_bits, _ = hil_safe_inference(model, tensor, c_floor=0.0, s_floor=0.0)
    _append_section(
        summary,
        "Unified Workflow (causal=True)",
        log,
        dur,
        extra=(f"Inference on 'hi': {out_bits.squeeze(0).tolist()}\n",),
    )

    (_, _), log, dur = capture_run(
        run_workflow, diffusion=True, **workflow_kwargs
    )
    _append_section(
        summary, "Unified Workflow (causal=False / Diffusion)", log, dur
    )

    report = "\n".join(summary)
    print(report)


if __name__ == "__main__":
    main()