"""Run the full integration workflow with an optional MCP dashboard UI."""
import argparse
import os
import subprocess
import sys
import time
import torch
from bit_transformer.utils import load_model
from bit_transformer.hf_checkpoint import (
hf_login,
save_checkpoint,
download_checkpoint,
)
from bit_transformer import diffusion_inference
from bit_transformer.cli_standards import create_workflow_parser, BitTransformerCLI
from integration_schedule import integration_schedule
def _launch_dashboard() -> list[subprocess.Popen]:
    """Spawn the MCP server and then the dashboard UI process.

    Returns the two ``Popen`` handles so the caller can later shut them
    down via :func:`_terminate`.
    """
    procs: list[subprocess.Popen] = []
    procs.append(subprocess.Popen([sys.executable, "mcp_server.py"]))
    # Give the server a moment to start listening before the dashboard connects.
    time.sleep(2)
    env = os.environ.copy()
    env.setdefault("MCP_SERVER_ADDR", "http://127.0.0.1:7000")
    procs.append(
        subprocess.Popen(
            [sys.executable, "-m", "bit_transformer.dashboard_app"],
            env=env,
        )
    )
    return procs
def _terminate(procs: list[subprocess.Popen]) -> None:
for p in procs:
p.terminate()
try:
p.wait(timeout=5)
except Exception:
p.kill()
def run_workflow(
    steps: int = 10,
    max_len: int = 64,
    dataset_size: int = 128,
    *,
    launch_ui: bool = False,
    weights_path: str = "weights/model.pt.gz",
    collapsed_path: str = "weights/collapsed.pt.gz",
    plateau_steps: int = 0,
    epochs_per_step: int = 2,
    extra_steps: int = 3,
    collapse: bool = True,
    hf_repo: str | None = None,
    hf_token: str | None = None,
    diffusion: bool = False,
    noise_schedule: str = "linear",
    diffusion_steps: int = 8,
    diffusion_curriculum: bool = False,
    use_checkpoint: bool = True,
    reversible: bool = True,
    qat: bool = False,
) -> tuple:
    """Run the full integration schedule with optional dashboard.

    If ``qat`` is ``True`` the model undergoes 4-bit quantization-aware training
    before being converted to quantized weights for safety checks.

    Args:
        steps: Number of progressive scale-up steps for ``integration_schedule``.
        max_len: Maximum sequence length used for training and inference.
        dataset_size: Number of samples in the training dataset.
        launch_ui: When ``True``, spawn the MCP server and dashboard processes
            for the duration of the run.
        weights_path: Where model weights are saved/loaded; downloaded from
            ``hf_repo`` if missing and a repo is given.
        collapsed_path: Output path for the collapsed model variant.
        plateau_steps: Extra training steps at the final model size.
        epochs_per_step: Epochs per training step.
        extra_steps: Additional optimizer updates after each epoch.
        collapse: Whether to produce a collapsed model.
        hf_repo: Optional Hugging Face repo id for checkpoint sync.
        hf_token: Token used with ``hf_login`` when ``hf_repo`` is set.
        diffusion: Enable diffusion-mode training and a sample inference.
        noise_schedule: Noise schedule name passed through (e.g. ``"linear"``).
        diffusion_steps: Number of denoising steps for diffusion inference.
        diffusion_curriculum: Enable curriculum over the noise schedule.
        use_checkpoint: Enable gradient checkpointing.
        reversible: Use reversible transformer layers.
        qat: Enable 4-bit quantization-aware training.

    Returns:
        Tuple ``(model, collapsed)`` — the trained model loaded from
        ``weights_path`` and the collapsed variant from ``integration_schedule``.
    """
    procs: list[subprocess.Popen] = []
    if launch_ui:
        procs = _launch_dashboard()
    if hf_repo:
        # Authenticate once up front; fetch weights only when not cached locally.
        hf_login(token=hf_token)
        if not os.path.exists(weights_path):
            download_checkpoint(weights_path, repo_id=hf_repo)
    try:
        results, collapsed = integration_schedule(
            steps=steps,
            max_len=max_len,
            dataset_size=dataset_size,
            weights_path=weights_path,
            plateau_steps=plateau_steps,
            collapsed_path=collapsed_path,
            epochs_per_step=epochs_per_step,
            extra_steps=extra_steps,
            collapse=collapse,
            diffusion=diffusion,
            noise_schedule=noise_schedule,
            diffusion_steps=diffusion_steps,
            diffusion_curriculum=diffusion_curriculum,
            use_checkpoint=use_checkpoint,
            reversible=reversible,
            qat=qat,
        )
        # Reload from disk so the returned model reflects the saved checkpoint.
        model = load_model(weights_path)
        print("Workflow results:", results)
        if diffusion:
            sample = diffusion_inference(
                model, length=max_len, steps=diffusion_steps, schedule=noise_schedule
            )
            print("Diffusion inference output bits:", sample[0].tolist())
        if hf_repo:
            save_checkpoint(model, repo_id=hf_repo)
    finally:
        # Always tear down the UI processes, even if training raised.
        if launch_ui:
            _terminate(procs)
    return model, collapsed
if __name__ == "__main__":
    # Use standardized CLI parser shared across project entry points.
    parser = create_workflow_parser()

    # Workflow-specific arguments not covered by the shared parser.
    workflow_group = parser.add_argument_group('Workflow Configuration')
    workflow_group.add_argument("--steps", type=int, default=10,
                               help="Number of progressive scale-up steps")
    workflow_group.add_argument("--plateau-steps", type=int, default=0,
                               help="Extra training steps at final size")
    workflow_group.add_argument("--epochs-per-step", type=int, default=2,
                               help="Epochs per training step")
    workflow_group.add_argument("--extra-steps", type=int, default=3,
                               help="Optimizer updates after each epoch")
    workflow_group.add_argument("--no-collapse", action="store_true",
                               help="Skip collapsed model generation")
    workflow_group.add_argument("--dashboard", action="store_true",
                               help="Launch MCP server and dashboard UI")

    # Advanced optimization toggles; flags are negative ("--no-...") so the
    # optimizations default to ON.
    opt_group = parser.add_argument_group('Advanced Optimization')
    opt_group.add_argument("--no-checkpoint", action="store_true",
                          help="Disable gradient checkpointing (faster but more memory)")
    opt_group.add_argument("--no-reversible", action="store_true",
                          help="Use standard transformer blocks instead of reversible layers")
    opt_group.add_argument("--qat", action="store_true",
                          help="Enable 4-bit quantization-aware training")

    # Override some defaults for workflow context
    parser.set_defaults(
        seq_length=64,  # Use seq-length instead of max-len
        dataset_size=128,
        weights_path="weights/model.pt.gz"
    )

    args = parser.parse_args()

    # NOTE(review): assumes create_workflow_parser defines hf_repo/hf_token,
    # diffusion_mode, noise_schedule, diffusion_steps and diffusion_curriculum
    # -- confirm against bit_transformer.cli_standards.
    run_workflow(
        args.steps,
        args.seq_length,  # Standardized name
        args.dataset_size,
        launch_ui=args.dashboard,
        weights_path=args.weights_path,
        collapsed_path=getattr(args, 'collapsed_path', 'weights/collapsed.pt.gz'),
        plateau_steps=args.plateau_steps,
        epochs_per_step=args.epochs_per_step,
        extra_steps=args.extra_steps,
        collapse=not args.no_collapse,
        hf_repo=args.hf_repo,
        hf_token=args.hf_token,
        diffusion=args.diffusion_mode,  # Standardized name
        noise_schedule=args.noise_schedule,
        diffusion_steps=args.diffusion_steps,
        diffusion_curriculum=args.diffusion_curriculum,
        use_checkpoint=not args.no_checkpoint,
        reversible=not args.no_reversible,
        qat=args.qat,
    )