🔧 Configuration update: unified_workflow.py with optimizations
Files changed: unified_workflow.py  +165 −0

unified_workflow.py (ADDED)
@@ -0,0 +1,165 @@
import argparse
import os
import subprocess
import sys
import time
import torch
from bit_transformer.utils import load_model
from bit_transformer.hf_checkpoint import (
    hf_login,
    save_checkpoint,
    download_checkpoint,
)
from bit_transformer import diffusion_inference
from bit_transformer.cli_standards import create_workflow_parser, BitTransformerCLI

from integration_schedule import integration_schedule


def _launch_dashboard() -> list[subprocess.Popen]:
    """Start MCP server and dashboard processes."""
    server = subprocess.Popen([sys.executable, "mcp_server.py"])
    time.sleep(2)
    dash_env = dict(os.environ)
    dash_env.setdefault("MCP_SERVER_ADDR", "http://127.0.0.1:7000")
    dashboard = subprocess.Popen(
        [sys.executable, "-m", "bit_transformer.dashboard_app"],
        env=dash_env,
    )
    return [server, dashboard]


def _terminate(procs: list[subprocess.Popen]) -> None:
    for p in procs:
        p.terminate()
        try:
            p.wait(timeout=5)
        except Exception:
            p.kill()


def run_workflow(
    steps: int = 10,
    max_len: int = 64,
    dataset_size: int = 128,
    *,
    launch_ui: bool = False,
    weights_path: str = "weights/model.pt.gz",
    collapsed_path: str = "weights/collapsed.pt.gz",
    plateau_steps: int = 0,
    epochs_per_step: int = 2,
    extra_steps: int = 3,
    collapse: bool = True,
    hf_repo: str | None = None,
    hf_token: str | None = None,
    diffusion: bool = False,
    noise_schedule: str = "linear",
    diffusion_steps: int = 8,
    diffusion_curriculum: bool = False,
    use_checkpoint: bool = True,
    reversible: bool = True,
    qat: bool = False,
) -> tuple:
    """Run the full integration schedule with optional dashboard.

    If ``qat`` is ``True`` the model undergoes 4-bit quantization-aware training
    before being converted to quantized weights for safety checks.
    """
    procs: list[subprocess.Popen] = []
    if launch_ui:
        procs = _launch_dashboard()
    if hf_repo:
        hf_login(token=hf_token)
        if not os.path.exists(weights_path):
            download_checkpoint(weights_path, repo_id=hf_repo)
    try:
        results, collapsed = integration_schedule(
            steps=steps,
            max_len=max_len,
            dataset_size=dataset_size,
            weights_path=weights_path,
            plateau_steps=plateau_steps,
            collapsed_path=collapsed_path,
            epochs_per_step=epochs_per_step,
            extra_steps=extra_steps,
            collapse=collapse,
            diffusion=diffusion,
            noise_schedule=noise_schedule,
            diffusion_steps=diffusion_steps,
            diffusion_curriculum=diffusion_curriculum,
            use_checkpoint=use_checkpoint,
            reversible=reversible,
            qat=qat,
        )
        model = load_model(weights_path)
        print("Workflow results:", results)
        if diffusion:
            sample = diffusion_inference(
                model, length=max_len, steps=diffusion_steps, schedule=noise_schedule
            )
            print("Diffusion inference output bits:", sample[0].tolist())
        if hf_repo:
            save_checkpoint(model, repo_id=hf_repo)
    finally:
        if launch_ui:
            _terminate(procs)
    return model, collapsed


if __name__ == "__main__":
    # Use standardized CLI parser
    parser = create_workflow_parser()

    # Add workflow-specific arguments
    workflow_group = parser.add_argument_group('Workflow Configuration')
    workflow_group.add_argument("--steps", type=int, default=10,
                                help="Number of progressive scale-up steps")
    workflow_group.add_argument("--plateau-steps", type=int, default=0,
                                help="Extra training steps at final size")
    workflow_group.add_argument("--epochs-per-step", type=int, default=2,
                                help="Epochs per training step")
    workflow_group.add_argument("--extra-steps", type=int, default=3,
                                help="Optimizer updates after each epoch")
    workflow_group.add_argument("--no-collapse", action="store_true",
                                help="Skip collapsed model generation")
    workflow_group.add_argument("--dashboard", action="store_true",
                                help="Launch MCP server and dashboard UI")

    # Add advanced optimization arguments
    opt_group = parser.add_argument_group('Advanced Optimization')
    opt_group.add_argument("--no-checkpoint", action="store_true",
                           help="Disable gradient checkpointing (faster but more memory)")
    opt_group.add_argument("--no-reversible", action="store_true",
                           help="Use standard transformer blocks instead of reversible layers")
    opt_group.add_argument("--qat", action="store_true",
                           help="Enable 4-bit quantization-aware training")

    # Override some defaults for workflow context
    parser.set_defaults(
        seq_length=64,  # Use seq-length instead of max-len
        dataset_size=128,
        weights_path="weights/model.pt.gz"
    )
    args = parser.parse_args()

    run_workflow(
        args.steps,
        args.seq_length,  # Standardized name
        args.dataset_size,
        launch_ui=args.dashboard,
        weights_path=args.weights_path,
        collapsed_path=getattr(args, 'collapsed_path', 'weights/collapsed.pt.gz'),
        plateau_steps=args.plateau_steps,
        epochs_per_step=args.epochs_per_step,
        extra_steps=args.extra_steps,
        collapse=not args.no_collapse,
        hf_repo=args.hf_repo,
        hf_token=args.hf_token,
        diffusion=args.diffusion_mode,  # Standardized name
        noise_schedule=args.noise_schedule,
        diffusion_steps=args.diffusion_steps,
        diffusion_curriculum=args.diffusion_curriculum,
        use_checkpoint=not args.no_checkpoint,
        reversible=not args.no_reversible,
        qat=args.qat,
    )
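Usage note: a minimal sketch of driving the workflow programmatically rather than through the CLI, assuming the file above is importable as `unified_workflow` from the repository root. Every keyword below matches a parameter of `run_workflow()`; the small step, sequence, and dataset values are illustrative only, not recommended settings.

# Sketch: quick smoke-test pass of the integration schedule, no dashboard,
# no Hugging Face sync. All keywords mirror run_workflow() above.
from unified_workflow import run_workflow

model, collapsed = run_workflow(
    steps=2,           # two progressive scale-up steps (illustrative)
    max_len=32,        # short sequences for a fast check
    dataset_size=64,   # small dataset for a fast check
    launch_ui=False,   # do not spawn mcp_server.py or the dashboard app
    collapse=True,     # also write the collapsed model to collapsed_path
    qat=False,         # skip 4-bit quantization-aware training
)
print("trained model:", type(model).__name__, "| collapsed:", type(collapsed).__name__)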