WCNegentropy committed
Commit 3334875 · verified · 1 Parent(s): 7b4c2a6

🔧 Configuration update: unified_workflow.py with optimizations

Files changed (1)
  1. unified_workflow.py +165 -0
unified_workflow.py ADDED
@@ -0,0 +1,165 @@
+ import argparse
+ import os
+ import subprocess
+ import sys
+ import time
+ import torch
+ from bit_transformer.utils import load_model
+ from bit_transformer.hf_checkpoint import (
+     hf_login,
+     save_checkpoint,
+     download_checkpoint,
+ )
+ from bit_transformer import diffusion_inference
+ from bit_transformer.cli_standards import create_workflow_parser, BitTransformerCLI
+
+ from integration_schedule import integration_schedule
+
+
+ def _launch_dashboard() -> list[subprocess.Popen]:
+     """Start MCP server and dashboard processes."""
+     server = subprocess.Popen([sys.executable, "mcp_server.py"])
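+     # Give the MCP server a moment to come up before starting the dashboard.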
+     time.sleep(2)
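+     # Point the dashboard at the local MCP server unless the environment already sets it.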
+     dash_env = dict(os.environ)
+     dash_env.setdefault("MCP_SERVER_ADDR", "http://127.0.0.1:7000")
+     dashboard = subprocess.Popen(
+         [sys.executable, "-m", "bit_transformer.dashboard_app"],
+         env=dash_env,
+     )
+     return [server, dashboard]
+
+
+ def _terminate(procs: list[subprocess.Popen]) -> None:
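+     """Terminate helper processes, killing any that do not exit within 5 seconds."""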
+     for p in procs:
+         p.terminate()
+         try:
+             p.wait(timeout=5)
+         except Exception:
+             p.kill()
+
+
+ def run_workflow(
+     steps: int = 10,
+     max_len: int = 64,
+     dataset_size: int = 128,
+     *,
+     launch_ui: bool = False,
+     weights_path: str = "weights/model.pt.gz",
+     collapsed_path: str = "weights/collapsed.pt.gz",
+     plateau_steps: int = 0,
+     epochs_per_step: int = 2,
+     extra_steps: int = 3,
+     collapse: bool = True,
+     hf_repo: str | None = None,
+     hf_token: str | None = None,
+     diffusion: bool = False,
+     noise_schedule: str = "linear",
+     diffusion_steps: int = 8,
+     diffusion_curriculum: bool = False,
+     use_checkpoint: bool = True,
+     reversible: bool = True,
+     qat: bool = False,
+ ) -> tuple:
+     """Run the full integration schedule with optional dashboard.
+
+     If ``qat`` is ``True`` the model undergoes 4-bit quantization-aware training
+     before being converted to quantized weights for safety checks.
+     """
+     procs: list[subprocess.Popen] = []
+     if launch_ui:
+         procs = _launch_dashboard()
+     if hf_repo:
+         hf_login(token=hf_token)
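+         # Fetch an existing checkpoint from the Hub if no local weights are present.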
+         if not os.path.exists(weights_path):
+             download_checkpoint(weights_path, repo_id=hf_repo)
+     try:
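+         # Run the progressive scale-up schedule; it returns the training results and the collapsed model.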
+         results, collapsed = integration_schedule(
+             steps=steps,
+             max_len=max_len,
+             dataset_size=dataset_size,
+             weights_path=weights_path,
+             plateau_steps=plateau_steps,
+             collapsed_path=collapsed_path,
+             epochs_per_step=epochs_per_step,
+             extra_steps=extra_steps,
+             collapse=collapse,
+             diffusion=diffusion,
+             noise_schedule=noise_schedule,
+             diffusion_steps=diffusion_steps,
+             diffusion_curriculum=diffusion_curriculum,
+             use_checkpoint=use_checkpoint,
+             reversible=reversible,
+             qat=qat,
+         )
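+         # Reload the trained weights for the optional inference and Hub upload below.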
+         model = load_model(weights_path)
+         print("Workflow results:", results)
+         if diffusion:
+             sample = diffusion_inference(
+                 model, length=max_len, steps=diffusion_steps, schedule=noise_schedule
+             )
+             print("Diffusion inference output bits:", sample[0].tolist())
+         if hf_repo:
+             save_checkpoint(model, repo_id=hf_repo)
+     finally:
+         if launch_ui:
+             _terminate(procs)
+     return model, collapsed
+
+
+ if __name__ == "__main__":
+     # Use standardized CLI parser
+     parser = create_workflow_parser()
+
+     # Add workflow-specific arguments
+     workflow_group = parser.add_argument_group('Workflow Configuration')
+     workflow_group.add_argument("--steps", type=int, default=10,
+                                 help="Number of progressive scale-up steps")
+     workflow_group.add_argument("--plateau-steps", type=int, default=0,
+                                 help="Extra training steps at final size")
+     workflow_group.add_argument("--epochs-per-step", type=int, default=2,
+                                 help="Epochs per training step")
+     workflow_group.add_argument("--extra-steps", type=int, default=3,
+                                 help="Optimizer updates after each epoch")
+     workflow_group.add_argument("--no-collapse", action="store_true",
+                                 help="Skip collapsed model generation")
+     workflow_group.add_argument("--dashboard", action="store_true",
+                                 help="Launch MCP server and dashboard UI")
+
+     # Add advanced optimization arguments
+     opt_group = parser.add_argument_group('Advanced Optimization')
+     opt_group.add_argument("--no-checkpoint", action="store_true",
+                            help="Disable gradient checkpointing (faster but more memory)")
+     opt_group.add_argument("--no-reversible", action="store_true",
+                            help="Use standard transformer blocks instead of reversible layers")
+     opt_group.add_argument("--qat", action="store_true",
+                            help="Enable 4-bit quantization-aware training")
+
+     # Override some defaults for workflow context
+     parser.set_defaults(
+         seq_length=64,  # Use seq-length instead of max-len
+         dataset_size=128,
+         weights_path="weights/model.pt.gz",
+     )
+     args = parser.parse_args()
+
+     run_workflow(
+         args.steps,
+         args.seq_length,  # Standardized name
+         args.dataset_size,
+         launch_ui=args.dashboard,
+         weights_path=args.weights_path,
+         collapsed_path=getattr(args, 'collapsed_path', 'weights/collapsed.pt.gz'),
+         plateau_steps=args.plateau_steps,
+         epochs_per_step=args.epochs_per_step,
+         extra_steps=args.extra_steps,
+         collapse=not args.no_collapse,
+         hf_repo=args.hf_repo,
+         hf_token=args.hf_token,
+         diffusion=args.diffusion_mode,  # Standardized name
+         noise_schedule=args.noise_schedule,
+         diffusion_steps=args.diffusion_steps,
+         diffusion_curriculum=args.diffusion_curriculum,
+         use_checkpoint=not args.no_checkpoint,
+         reversible=not args.no_reversible,
+         qat=args.qat,
+     )
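
For reference, a minimal sketch of calling the new entry point programmatically rather than via the CLI. It assumes unified_workflow.py and its bit_transformer / integration_schedule dependencies are on the import path; the argument values are illustrative only, not recommended settings.

from unified_workflow import run_workflow

# Small dry run: a short schedule with tiny sequences, no dashboard, no Hub sync.
model, collapsed = run_workflow(
    steps=2,
    max_len=32,
    dataset_size=64,
    launch_ui=False,  # keep the MCP server and dashboard offline
    collapse=True,    # also produce the collapsed model
    qat=False,        # skip 4-bit quantization-aware training
)
print(type(model).__name__, type(collapsed).__name__)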