import argparse
import os
import subprocess
import sys
import time

import torch

from bit_transformer.utils import load_model
from bit_transformer.hf_checkpoint import (
    hf_login,
    save_checkpoint,
    download_checkpoint,
)
from bit_transformer import diffusion_inference
from bit_transformer.cli_standards import create_workflow_parser, BitTransformerCLI
from integration_schedule import integration_schedule


def _launch_dashboard() -> list[subprocess.Popen]:
    """Start the MCP server and the dashboard app as child processes.

    The server (``mcp_server.py``) is started first and given two seconds to
    come up before the dashboard (``bit_transformer.dashboard_app``) is
    launched. The dashboard inherits the current environment, with
    ``MCP_SERVER_ADDR`` defaulting to ``http://127.0.0.1:7000`` when unset.

    Returns:
        ``[server, dashboard]`` process handles; pass them to
        :func:`_terminate` when done.

    Raises:
        Whatever :class:`subprocess.Popen` raises if a process cannot be
        started. If the dashboard fails to start (or startup is
        interrupted), the already-running server is terminated before the
        exception propagates, so it is not orphaned.
    """
    server = subprocess.Popen([sys.executable, "mcp_server.py"])
    try:
        # Fixed grace period for the server to start listening.
        time.sleep(2)
        dash_env = dict(os.environ)
        dash_env.setdefault("MCP_SERVER_ADDR", "http://127.0.0.1:7000")
        dashboard = subprocess.Popen(
            [sys.executable, "-m", "bit_transformer.dashboard_app"],
            env=dash_env,
        )
    except BaseException:
        # Clean up the server we already started, then re-raise unchanged.
        _terminate([server])
        raise
    return [server, dashboard]


def _terminate(procs: list[subprocess.Popen], *, timeout: float = 5.0) -> None:
    """Stop each process, escalating from ``terminate()`` to ``kill()``.

    Args:
        procs: Process handles to stop, e.g. from :func:`_launch_dashboard`.
            An empty list is a no-op.
        timeout: Seconds to wait for each process after ``terminate()``
            before falling back to ``kill()``.
    """
    for p in procs:
        p.terminate()
        try:
            p.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            p.kill()
            # Reap the killed child so it does not linger as a zombie.
            p.wait()


def run_workflow(
    steps: int = 10,
    max_len: int = 64,
    dataset_size: int = 128,
    *,
    launch_ui: bool = False,
    weights_path: str = "weights/model.pt.gz",
    collapsed_path: str = "weights/collapsed.pt.gz",
    plateau_steps: int = 0,
    epochs_per_step: int = 2,
    extra_steps: int = 3,
    collapse: bool = True,
    hf_repo: str | None = None,
    hf_token: str | None = None,
    diffusion: bool = False,
    noise_schedule: str = "linear",
    diffusion_steps: int = 8,
    diffusion_curriculum: bool = False,
    use_checkpoint: bool = True,
    reversible: bool = True,
    qat: bool = False,
) -> tuple:
    """Run the full integration schedule with optional dashboard.

    If ``qat`` is ``True`` the model undergoes 4-bit quantization-aware
    training before being converted to quantized weights for safety checks.

    Steps performed:

    1. If ``launch_ui`` is set, start the MCP server and dashboard.
    2. If ``hf_repo`` is set, log in to Hugging Face with ``hf_token``.
       If ``weights_path`` does not exist locally, download the checkpoint
       there from ``hf_repo``.
    3. Call ``integration_schedule`` with the training options. Every
       argument except ``launch_ui``, ``hf_repo`` and ``hf_token`` is
       forwarded to it.
    4. Reload the model from ``weights_path`` and print the schedule results.
    5. If ``diffusion`` is set, print the bits of one sample from
       ``diffusion_inference``. The sample is generated with ``max_len``,
       ``diffusion_steps`` and ``noise_schedule``.
    6. If ``hf_repo`` is set, upload the reloaded model to it.

    Any dashboard processes started in step 1 are always terminated, even
    if a later step raises.

    Returns:
        ``(model, collapsed)``: ``model`` is the result of
        ``load_model(weights_path)``. ``collapsed`` is the second value
        returned by ``integration_schedule``.
    """
    procs: list[subprocess.Popen] = []
    if launch_ui:
        procs = _launch_dashboard()
    # The try must start right after launching the UI. Otherwise a failing
    # login or download would orphan the dashboard processes.
    try:
        if hf_repo:
            hf_login(token=hf_token)
            if not os.path.exists(weights_path):
                download_checkpoint(weights_path, repo_id=hf_repo)
        results, collapsed = integration_schedule(
            steps=steps,
            max_len=max_len,
            dataset_size=dataset_size,
            weights_path=weights_path,
            plateau_steps=plateau_steps,
            collapsed_path=collapsed_path,
            epochs_per_step=epochs_per_step,
            extra_steps=extra_steps,
            collapse=collapse,
            diffusion=diffusion,
            noise_schedule=noise_schedule,
            diffusion_steps=diffusion_steps,
            diffusion_curriculum=diffusion_curriculum,
            use_checkpoint=use_checkpoint,
            reversible=reversible,
            qat=qat,
        )
        model = load_model(weights_path)
        print("Workflow results:", results)
        if diffusion:
            sample = diffusion_inference(
                model, length=max_len, steps=diffusion_steps, schedule=noise_schedule
            )
            print("Diffusion inference output bits:", sample[0].tolist())
        if hf_repo:
            save_checkpoint(model, repo_id=hf_repo)
    finally:
        # No-op when the UI was not launched (procs is empty).
        _terminate(procs)
    return model, collapsed


if __name__ == "__main__":
    # Use standardized CLI parser
    parser = create_workflow_parser()

    # Add workflow-specific arguments
    workflow_group = parser.add_argument_group('Workflow Configuration')
    workflow_group.add_argument("--steps", type=int, default=10,
                                help="Number of progressive scale-up steps")
    workflow_group.add_argument("--plateau-steps", type=int, default=0,
                                help="Extra training steps at final size")
    workflow_group.add_argument("--epochs-per-step", type=int, default=2,
                                help="Epochs per training step")
    workflow_group.add_argument("--extra-steps", type=int, default=3,
                                help="Optimizer updates after each epoch")
    workflow_group.add_argument("--no-collapse", action="store_true",
                                help="Skip collapsed model generation")
    workflow_group.add_argument("--dashboard", action="store_true",
                                help="Launch MCP server and dashboard UI")

    # Add advanced optimization arguments
    opt_group = parser.add_argument_group('Advanced Optimization')
    opt_group.add_argument("--no-checkpoint", action="store_true",
                           help="Disable gradient checkpointing (faster but more memory)")
    opt_group.add_argument("--no-reversible", action="store_true",
                           help="Use standard transformer blocks instead of reversible layers")
    opt_group.add_argument("--qat", action="store_true",
                           help="Enable 4-bit quantization-aware training")

    # Override some defaults for workflow context
    parser.set_defaults(
        seq_length=64,  # Use seq-length instead of max-len
        dataset_size=128,
        weights_path="weights/model.pt.gz"
    )

    args = parser.parse_args()

    run_workflow(
        args.steps,
        args.seq_length,  # Standardized name
        args.dataset_size,
        launch_ui=args.dashboard,
        weights_path=args.weights_path,
        # Fall back to the function default if the shared parser does not
        # define --collapsed-path.
        collapsed_path=getattr(args, 'collapsed_path', 'weights/collapsed.pt.gz'),
        plateau_steps=args.plateau_steps,
        epochs_per_step=args.epochs_per_step,
        extra_steps=args.extra_steps,
        collapse=not args.no_collapse,
        hf_repo=args.hf_repo,
        hf_token=args.hf_token,
        diffusion=args.diffusion_mode,  # Standardized name
        noise_schedule=args.noise_schedule,
        diffusion_steps=args.diffusion_steps,
        diffusion_curriculum=args.diffusion_curriculum,
        use_checkpoint=not args.no_checkpoint,
        reversible=not args.no_reversible,
        qat=args.qat,
    )