WCNegentropy committed
Commit 2b42db0 · verified · 1 parent: 1cddbe9

🚀 OS Launch: Clean documentation and refined licensing


This OS launch commit includes:

✅ **Cleaned Documentation**
- Removed inflated claims and marketing language
- Added honest research status and limitations
- Created professional model card and validation reports
- Streamlined licensing to AGPLv3 plus a commercial-licensing contact

✅ **Refined Codebase**
- Complete experimental bit-native transformer implementation
- 57 Python files forming a comprehensive research framework
- Safety telemetry and monitoring systems
- Distributed training and development tools

✅ **Professional Standards**
- Empirical validation of all claims
- Clear distinction between experimental and production code
- Rigorous research-methodology requirements
- Community contribution framework

Ready for serious research evaluation and academic investigation.

Files changed (1)
  1. gradio_dashboard.py +778 -0
gradio_dashboard.py ADDED
@@ -0,0 +1,778 @@
#!/usr/bin/env python3
"""
BitTransformerLM Gradio Dashboard
=================================

Comprehensive Gradio interface for BitTransformerLM with full feature parity to the Flask dashboard.
Supports both local deployment and HuggingFace Spaces integration while maintaining MCP server compatibility.
"""

import io
import json
import os
import sys
import traceback
import warnings
from typing import Any, Dict, List, Optional, Union, Tuple
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import gradio as gr
import numpy as np
from pathlib import Path
import threading
import time
import requests
from concurrent.futures import ThreadPoolExecutor
import uuid

# Add BitTransformerLM to path
sys.path.insert(0, str(Path(__file__).parent))

# BitTransformerLM imports
from bit_transformer.model import BitTransformerLM, infer_long_sequence
from bit_transformer.optimization import configure_optimizer
from bit_transformer.collapse import collapse_submodel
from bit_transformer.dashboard import plot_telemetry
from bit_transformer.scale import expand_model
from bit_transformer.bit_io import text_to_bits, bits_to_text
from bit_transformer.safety import hil_safe_inference
from bit_transformer.compression import model_output_decompress, compress_bits
from bit_transformer.distributed import wrap_fsdp
from bit_transformer.training import train_loop
from bit_transformer.telemetry import detect_metric_drift
from bit_transformer.quantization import prepare_qat_fx, convert_qat_fx
from bit_transformer.hf_checkpoint import hf_login, save_checkpoint, download_checkpoint
from bit_transformer.dataset_builder import BitTransformerDatasetBuilder, create_bittransformerlm_dataset


# Global state management
class GradioModelManager:
    """Enhanced ModelManager for Gradio interface with thread safety."""

    def __init__(self):
        self.model = None
        self.config = {}
        self.telemetry_log = {
            "negentropy": [],
            "lz_complexity": [],
            "symbiosis_score": [],
            "steps": []
        }
        self.c_floor = 0.3
        self.s_floor = 0.5
        self.lambda_weights = {"K": 1.0, "C": 1.0, "S": 1.0}
        self.compression_enabled = False
        self.qat_enabled = False
        self.diffusion_enabled = False
        self.gpu_enabled = False

        # Background job management
        self.executor = ThreadPoolExecutor(max_workers=4)
        self.jobs = {}
        self.mcp_server_addr = os.getenv("MCP_SERVER_ADDR")

        # Thread safety
        self.lock = threading.Lock()

    def init_model(self, model_config: dict):
        """Initialize BitTransformerLM model with given configuration."""
        with self.lock:
            try:
                # Clean config - remove None values
                clean_config = {k: v for k, v in model_config.items() if v is not None and v != ""}

                self.model = BitTransformerLM(**clean_config)
                self.config = clean_config

                # Apply transformations
                if self.qat_enabled:
                    self.model = prepare_qat_fx(self.model)
                if self.gpu_enabled and torch.cuda.is_available():
                    self.model = self.model.cuda()

                return f"✅ Model initialized with config: {clean_config}"
            except Exception as e:
                return f"❌ Model initialization failed: {str(e)}"

    def train_step(self, bits_input, epochs=1):
        """Execute training step(s) with given bit input."""
        if self.model is None:
            return "❌ Model not initialized", None, None

        try:
            # Parse bits input
            if isinstance(bits_input, str):
                if bits_input.strip().startswith('['):
                    # JSON format
                    bits = json.loads(bits_input)
                else:
                    # Space-separated format
                    bits = [int(x) for x in bits_input.strip().split()]
            else:
                bits = bits_input

            tensor = torch.tensor(bits, dtype=torch.long)
            if self.gpu_enabled and torch.cuda.is_available():
                tensor = tensor.cuda()

            # Lazily create an optimizer so the backward pass actually updates
            # weights (plain AdamW as a default; configure_optimizer could be
            # swapped in here).
            if getattr(self, "optimizer", None) is None:
                self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-3)

            # Training loop
            total_loss = 0
            compression_ratio = 1.0

            for epoch in range(epochs):
                self.model.train()

                # Forward pass with telemetry
                if self.compression_enabled:
                    compressed_bits, ratio = compress_bits(bits)
                    tensor = torch.tensor(compressed_bits, dtype=torch.long, device=tensor.device)
                    compression_ratio = ratio

                self.optimizer.zero_grad()
                output, telemetry = self.model(tensor.unsqueeze(0))

                # Compute loss: logits at position t predict the bit at t+1,
                # so drop the final logit and the first target bit to align shapes.
                if output.dim() == 3:
                    loss = F.cross_entropy(
                        output[:, :-1, :].reshape(-1, output.size(-1)),
                        tensor[1:].unsqueeze(0).reshape(-1),
                        ignore_index=-1
                    )
                else:
                    loss = F.cross_entropy(output, tensor.unsqueeze(0))

                # Backward pass and parameter update
                loss.backward()
                self.optimizer.step()

                # Update telemetry
                self._update_telemetry(telemetry)
                total_loss += loss.item()

            avg_loss = total_loss / epochs
            return f"✅ Training completed. Average Loss: {avg_loss:.4f}", avg_loss, compression_ratio

        except Exception as e:
            return f"❌ Training failed: {str(e)}", None, None

    def inference(self, bits_input, long_inference=False, ctx_bits=4096, overlap=256):
        """Run inference on bit input."""
        if self.model is None:
            return "❌ Model not initialized", None

        try:
            # Parse bits input
            if isinstance(bits_input, str):
                if bits_input.strip().startswith('['):
                    bits = json.loads(bits_input)
                else:
                    bits = [int(x) for x in bits_input.strip().split()]
            else:
                bits = bits_input

            tensor = torch.tensor(bits, dtype=torch.long)
            if self.gpu_enabled and torch.cuda.is_available():
                tensor = tensor.cuda()

            self.model.eval()

            with torch.inference_mode():
                if long_inference or len(bits) > ctx_bits:
                    # Long sequence inference
                    output, telemetry = infer_long_sequence(
                        self.model, tensor.unsqueeze(0),
                        ctx_bits=ctx_bits, overlap=overlap
                    )
                else:
                    # Standard inference with safety gates
                    output, telemetry = hil_safe_inference(
                        self.model, tensor.unsqueeze(0),
                        c_floor=self.c_floor, s_floor=self.s_floor
                    )

            # Update telemetry
            self._update_telemetry(telemetry)

            output_bits = output.squeeze(0).cpu().tolist()
            return f"✅ Inference completed. Output length: {len(output_bits)}", output_bits

        except Exception as e:
            return f"❌ Inference failed: {str(e)}", None

    def text_inference(self, text_input):
        """Convert text to bits, run inference, convert back to text."""
        try:
            # Text to bits
            bits = text_to_bits(text_input)

            # Run inference
            result, output_bits = self.inference(bits)

            if output_bits is None:
                return result, None

            # Convert back to text
            try:
                output_text = bits_to_text(output_bits)
                return "✅ Text inference completed.", output_text
            except Exception as e:
                return f"✅ Inference completed, but text conversion failed: {str(e)}", str(output_bits)

        except Exception as e:
            return f"❌ Text inference failed: {str(e)}", None

    def scale_model(self, width_multiplier):
        """Scale up model width."""
        if self.model is None:
            return "❌ Model not initialized"

        try:
            with self.lock:
                self.model = expand_model(self.model, width_multiplier)
            return f"✅ Model scaled by factor {width_multiplier}"
        except Exception as e:
            return f"❌ Model scaling failed: {str(e)}"

    def collapse_model(self, cluster_bits, target_params, width_scale=1.0):
        """Collapse model using cluster analysis."""
        if self.model is None:
            return "❌ Model not initialized"

        try:
            # Parse inputs
            if isinstance(cluster_bits, str):
                clusters = json.loads(cluster_bits)
            else:
                clusters = cluster_bits

            if isinstance(target_params, str):
                params = json.loads(target_params)
            else:
                params = target_params

            with self.lock:
                collapsed_model = collapse_submodel(
                    self.model, clusters, params, width_scale
                )
                self.model = collapsed_model
            return "✅ Model collapsed successfully"
        except Exception as e:
            return f"❌ Model collapse failed: {str(e)}"

    def get_model_status(self):
        """Get current model status and configuration."""
        if self.model is None:
            return "❌ No model initialized"

        try:
            param_count = sum(p.numel() for p in self.model.parameters())
            status = {
                "initialized": True,
                "parameters": param_count,
                "config": self.config,
                "gpu_enabled": self.gpu_enabled,
                "qat_enabled": self.qat_enabled,
                "compression_enabled": self.compression_enabled,
                "diffusion_enabled": self.diffusion_enabled,
            }
            return json.dumps(status, indent=2)
        except Exception as e:
            return f"❌ Status check failed: {str(e)}"

    def get_telemetry_plot(self):
        """Generate telemetry plot."""
        try:
            if not any(self.telemetry_log.values()):
                # Return empty plot
                fig, ax = plt.subplots(figsize=(10, 6))
                ax.text(0.5, 0.5, 'No telemetry data yet', ha='center', va='center', transform=ax.transAxes)
                ax.set_title('Telemetry Metrics')
                return fig

            fig, axes = plot_telemetry(
                self.telemetry_log,
                k_floor=0.5,  # Negentropy floor
                c_floor=self.c_floor,
                s_floor=self.s_floor
            )
            return fig
        except Exception as e:
            # Return error plot
            fig, ax = plt.subplots(figsize=(10, 6))
            ax.text(0.5, 0.5, f'Plot error: {str(e)}', ha='center', va='center', transform=ax.transAxes)
            ax.set_title('Telemetry Metrics - Error')
            return fig

    def _update_telemetry(self, telemetry_dict):
        """Update telemetry log with new values."""
        if not telemetry_dict:
            return

        step = len(self.telemetry_log["steps"])
        self.telemetry_log["steps"].append(step)

        # Extract metrics with defaults
        self.telemetry_log["negentropy"].append(
            float(telemetry_dict.get("negentropy", torch.tensor(0.0)).mean().item())
        )
        self.telemetry_log["lz_complexity"].append(
            float(telemetry_dict.get("lz_complexity_logits", torch.tensor(0.0)).mean().item())
        )
        self.telemetry_log["symbiosis_score"].append(
            float(telemetry_dict.get("symbiosis_score", torch.tensor(0.0)).mean().item())
        )

    def huggingface_upload(self, repo_id, hf_token=None):
        """Upload model to HuggingFace."""
        if self.model is None:
            return "❌ Model not initialized"

        try:
            if hf_token:
                hf_login(hf_token)

            save_checkpoint(self.model, repo_id, self.config)
            return f"✅ Model uploaded to {repo_id}"
        except Exception as e:
            return f"❌ HF upload failed: {str(e)}"

    def huggingface_download(self, repo_id, hf_token=None):
        """Download model from HuggingFace."""
        try:
            if hf_token:
                hf_login(hf_token)

            with self.lock:
                model, config = download_checkpoint(repo_id)
                self.model = model
                self.config = config

            return f"✅ Model downloaded from {repo_id}"
        except Exception as e:
            return f"❌ HF download failed: {str(e)}"

    def mcp_request(self, endpoint, data=None, method="POST"):
        """Make request to MCP server if available."""
        if not self.mcp_server_addr:
            return "❌ MCP server not configured"

        try:
            url = self.mcp_server_addr.rstrip("/") + endpoint
            if method == "POST":
                resp = requests.post(url, json=data, timeout=30)
            else:
                resp = requests.get(url, timeout=30)

            resp.raise_for_status()

            if resp.headers.get("Content-Type", "").startswith("image/"):
                return "✅ MCP request completed (binary data)"
            return f"✅ MCP request completed: {resp.json()}"
        except Exception as e:
            return f"❌ MCP request failed: {str(e)}"


# Global manager instance
manager = GradioModelManager()
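
# Illustrative (hypothetical) programmatic use of the manager, e.g. from a
# Python REPL rather than the web UI -- values are arbitrary examples:
#   manager.init_model({"d_model": 64, "nhead": 4, "num_layers": 2,
#                       "dim_feedforward": 256, "max_seq_len": 512})
#   manager.train_step("0 1 0 1 1 0 1 0", epochs=2)
#   manager.inference("0 1 0 1")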


def create_gradio_interface():
    """Create the main Gradio interface with all BitTransformerLM features."""

    # Helper functions for Gradio callbacks
    def init_model_callback(d_model, nhead, num_layers, dim_feedforward, max_seq_len,
                            chunk_size, overlap, reversible, use_checkpoint, act_threshold,
                            c_floor, s_floor):
        """Initialize model with form parameters."""
        config = {
            "d_model": d_model,
            "nhead": nhead,
            "num_layers": num_layers,
            "dim_feedforward": dim_feedforward,
            "max_seq_len": max_seq_len,
            "chunk_size": chunk_size if chunk_size > 0 else None,
            "overlap": overlap,
            "reversible": reversible,
            "use_checkpoint": use_checkpoint,
            "act_threshold": act_threshold
        }

        # Update safety floors
        manager.c_floor = c_floor
        manager.s_floor = s_floor

        result = manager.init_model(config)
        status = manager.get_model_status()
        plot = manager.get_telemetry_plot()

        return result, status, plot

    def train_callback(bits_input, epochs, file_input):
        """Training callback with file upload support."""
        if file_input is not None:
            # Process uploaded file
            try:
                if file_input.name.endswith(('.txt', '.md')):
                    with open(file_input.name, 'r') as f:
                        text = f.read()
                    bits = text_to_bits(text)
                else:
                    with open(file_input.name, 'rb') as f:
                        data = f.read()
                    # Convert bytes to bits (MSB first)
                    bits = []
                    for byte in data:
                        for i in range(8):
                            bits.append((byte >> (7 - i)) & 1)

                result, loss, ratio = manager.train_step(bits, epochs)
            except Exception as e:
                result = f"❌ File processing failed: {str(e)}"
                loss, ratio = None, None
        else:
            result, loss, ratio = manager.train_step(bits_input, epochs)

        status = manager.get_model_status()
        plot = manager.get_telemetry_plot()

        return result, status, plot, f"Compression Ratio: {ratio:.2f}" if ratio else ""

    def inference_callback(bits_input, file_input):
        """Standard inference callback."""
        if file_input is not None:
            # Process uploaded file similar to training
            try:
                if file_input.name.endswith(('.txt', '.md')):
                    with open(file_input.name, 'r') as f:
                        text = f.read()
                    bits = text_to_bits(text)
                else:
                    with open(file_input.name, 'rb') as f:
                        data = f.read()
                    bits = []
                    for byte in data:
                        for i in range(8):
                            bits.append((byte >> (7 - i)) & 1)

                result, output_bits = manager.inference(bits)
            except Exception as e:
                result = f"❌ File processing failed: {str(e)}"
                output_bits = None
        else:
            result, output_bits = manager.inference(bits_input)

        return result, str(output_bits) if output_bits else ""

    def long_inference_callback(bits_input, ctx_bits, overlap):
        """Long sequence inference callback."""
        result, output_bits = manager.inference(bits_input, long_inference=True,
                                                ctx_bits=ctx_bits, overlap=overlap)
        return result, str(output_bits) if output_bits else ""

    def text_inference_callback(text_input):
        """Text-to-text inference callback."""
        result, output_text = manager.text_inference(text_input)
        return result, output_text if output_text else ""

    # Create Gradio interface
    with gr.Blocks(title="BitTransformerLM Dashboard",
                   theme=gr.themes.Soft()) as interface:

        gr.Markdown("# 🤖 BitTransformerLM Interactive Dashboard")
        gr.Markdown("*Experimental bit-native transformer with comprehensive training and inference capabilities*")

        with gr.Tab("🏗️ Model Configuration"):
            gr.Markdown("## Initialize BitTransformerLM")

            with gr.Row():
                with gr.Column():
                    d_model = gr.Number(label="d_model", value=64, info="Model width")
                    nhead = gr.Number(label="nhead", value=4, info="Attention heads")
                    num_layers = gr.Number(label="num_layers", value=2, info="Transformer layers")
                    dim_feedforward = gr.Number(label="dim_feedforward", value=256, info="FFN dimension")

                with gr.Column():
                    max_seq_len = gr.Number(label="max_seq_len", value=512, info="Max sequence length")
                    chunk_size = gr.Number(label="chunk_size", value=0, info="Chunk size (0=auto)")
                    overlap = gr.Number(label="overlap", value=64, info="Sliding window overlap")
                    act_threshold = gr.Number(label="act_threshold", value=0.95, info="ACT halt threshold")

            with gr.Row():
                reversible = gr.Checkbox(label="Reversible Layers", value=False)
                use_checkpoint = gr.Checkbox(label="Gradient Checkpointing", value=True)

            with gr.Row():
                c_floor = gr.Number(label="c_floor", value=0.3, info="Complexity safety floor")
                s_floor = gr.Number(label="s_floor", value=0.5, info="Symbiosis safety floor")

            init_btn = gr.Button("🚀 Initialize Model", variant="primary")
            init_output = gr.Textbox(label="Initialization Result", interactive=False)

        with gr.Tab("🎯 Training"):
            gr.Markdown("## Train BitTransformerLM")

            with gr.Row():
                with gr.Column():
                    train_bits = gr.Textbox(
                        label="Bit Input",
                        placeholder="0 1 0 1 or [0,1,0,1] or upload file",
                        lines=3
                    )
                    train_file = gr.File(label="Upload Training File", file_types=[".txt", ".md", ".bin"])
                    train_epochs = gr.Number(label="Epochs", value=1, minimum=1)

                with gr.Column():
                    train_btn = gr.Button("🏃 Start Training", variant="primary")
                    train_output = gr.Textbox(label="Training Result", interactive=False)
                    compression_output = gr.Textbox(label="Compression Info", interactive=False)

        with gr.Tab("🧠 Inference"):
            with gr.Tab("Standard Inference"):
                gr.Markdown("## Standard Inference")

                with gr.Row():
                    with gr.Column():
                        infer_bits = gr.Textbox(
                            label="Bit Input",
                            placeholder="0 1 0 1 or [0,1,0,1]",
                            lines=3
                        )
                        infer_file = gr.File(label="Upload Inference File")

                    with gr.Column():
                        infer_btn = gr.Button("🎯 Run Inference", variant="primary")
                        infer_result = gr.Textbox(label="Result", interactive=False)
                        infer_output = gr.Textbox(label="Output Bits", lines=5, interactive=False)

            with gr.Tab("Long Sequence Inference"):
                gr.Markdown("## Long Sequence Inference")

                with gr.Row():
                    with gr.Column():
                        long_bits = gr.Textbox(
                            label="Long Bit Sequence",
                            lines=5,
                            placeholder="Long sequence of bits..."
                        )
                        long_ctx_bits = gr.Number(label="Context Bits", value=4096)
                        long_overlap = gr.Number(label="Overlap", value=256)

                    with gr.Column():
                        long_infer_btn = gr.Button("🔄 Run Long Inference", variant="primary")
                        long_result = gr.Textbox(label="Result", interactive=False)
                        long_output = gr.Textbox(label="Output Bits", lines=5, interactive=False)

            with gr.Tab("Text Inference"):
                gr.Markdown("## Text-to-Text Inference")

                with gr.Row():
                    with gr.Column():
                        text_input = gr.Textbox(
                            label="Input Text",
                            placeholder="Enter text to process...",
                            lines=3
                        )
                        text_infer_btn = gr.Button("📝 Process Text", variant="primary")

                    with gr.Column():
                        text_result = gr.Textbox(label="Result", interactive=False)
                        text_output = gr.Textbox(
                            label="Output Text",
                            lines=5,
                            interactive=False
                        )

        with gr.Tab("⚙️ Model Operations"):
            with gr.Tab("Scale Model"):
                gr.Markdown("## Scale Model Width")

                with gr.Row():
                    width_mult = gr.Number(label="Width Multiplier", value=1.5, step=0.1)
                    scale_btn = gr.Button("📈 Scale Model", variant="secondary")

                scale_output = gr.Textbox(label="Scaling Result", interactive=False)

            with gr.Tab("Collapse Model"):
                gr.Markdown("## Collapse Submodel")

                with gr.Row():
                    with gr.Column():
                        cluster_bits = gr.Textbox(
                            label="Cluster Bits (JSON)",
                            placeholder='[[0,1,0,1],[1,1,0,0]]',
                            lines=3
                        )
                        target_params = gr.Textbox(
                            label="Target Parameters (JSON)",
                            placeholder='{"d_model":32,"nhead":4,"num_layers":1}',
                            lines=3
                        )
                        width_scale = gr.Number(label="Width Scale", value=1.0, step=0.1)

                    with gr.Column():
                        collapse_btn = gr.Button("🗜️ Collapse Model", variant="secondary")
                        collapse_output = gr.Textbox(label="Collapse Result", interactive=False)

        with gr.Tab("📊 Monitoring"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("## Model Status")
                    status_output = gr.Code(label="Current Status", language="json")
                    refresh_btn = gr.Button("🔄 Refresh Status")

                with gr.Column():
                    gr.Markdown("## System Settings")

                    with gr.Row():
                        gpu_checkbox = gr.Checkbox(label="🔥 Enable GPU/FSDP", value=False)
                        qat_checkbox = gr.Checkbox(label="⚡ Enable 4-bit QAT", value=False)

                    with gr.Row():
                        compression_checkbox = gr.Checkbox(label="🗜️ Enable Compression", value=False)
                        diffusion_checkbox = gr.Checkbox(label="🌊 Enable Diffusion Mode", value=False)

            gr.Markdown("## 📈 Telemetry Metrics")
            telemetry_plot = gr.Plot(label="K/C/S Metrics Over Time")

        with gr.Tab("☁️ HuggingFace Integration"):
            gr.Markdown("## HuggingFace Model Hub")

            with gr.Row():
                with gr.Column():
                    hf_repo_id = gr.Textbox(label="Repository ID", placeholder="username/model-name")
                    hf_token = gr.Textbox(label="HF Token (optional)", type="password")

                with gr.Column():
                    with gr.Row():
                        hf_upload_btn = gr.Button("⬆️ Upload to HF", variant="secondary")
                        hf_download_btn = gr.Button("⬇️ Download from HF", variant="secondary")

            hf_result = gr.Textbox(label="HuggingFace Result", interactive=False)

        # Event handlers
        init_btn.click(
            init_model_callback,
            inputs=[d_model, nhead, num_layers, dim_feedforward, max_seq_len,
                    chunk_size, overlap, reversible, use_checkpoint, act_threshold,
                    c_floor, s_floor],
            outputs=[init_output, status_output, telemetry_plot]
        )

        train_btn.click(
            train_callback,
            inputs=[train_bits, train_epochs, train_file],
            outputs=[train_output, status_output, telemetry_plot, compression_output]
        )

        infer_btn.click(
            inference_callback,
            inputs=[infer_bits, infer_file],
            outputs=[infer_result, infer_output]
        )

        long_infer_btn.click(
            long_inference_callback,
            inputs=[long_bits, long_ctx_bits, long_overlap],
            outputs=[long_result, long_output]
        )

        text_infer_btn.click(
            text_inference_callback,
            inputs=[text_input],
            outputs=[text_result, text_output]
        )

        scale_btn.click(
            manager.scale_model,
            inputs=[width_mult],
            outputs=[scale_output]
        )

        collapse_btn.click(
            manager.collapse_model,
            inputs=[cluster_bits, target_params, width_scale],
            outputs=[collapse_output]
        )

        refresh_btn.click(
            manager.get_model_status,
            outputs=[status_output]
        )

        hf_upload_btn.click(
            manager.huggingface_upload,
            inputs=[hf_repo_id, hf_token],
            outputs=[hf_result]
        )

        hf_download_btn.click(
            manager.huggingface_download,
            inputs=[hf_repo_id, hf_token],
            outputs=[hf_result]
        )

        # System settings callbacks
        def update_gpu_setting(enabled):
            manager.gpu_enabled = enabled

        def update_qat_setting(enabled):
            manager.qat_enabled = enabled

        def update_compression_setting(enabled):
            manager.compression_enabled = enabled

        def update_diffusion_setting(enabled):
            manager.diffusion_enabled = enabled

        # Connect each checkbox so toggling it updates the manager flags.
        gpu_checkbox.change(update_gpu_setting, inputs=[gpu_checkbox], outputs=None)
        qat_checkbox.change(update_qat_setting, inputs=[qat_checkbox], outputs=None)
        compression_checkbox.change(update_compression_setting, inputs=[compression_checkbox], outputs=None)
        diffusion_checkbox.change(update_diffusion_setting, inputs=[diffusion_checkbox], outputs=None)

        # Auto-refresh telemetry every 10 seconds
        interface.load(
            manager.get_telemetry_plot,
            outputs=[telemetry_plot],
            every=10
        )

        # Load initial status
        interface.load(
            manager.get_model_status,
            outputs=[status_output]
        )

    return interface


def run_gradio_server(host="127.0.0.1", port=7860, share=False):
    """Run the Gradio server."""
    interface = create_gradio_interface()

    print("🚀 Starting BitTransformerLM Gradio Dashboard...")
    print(f"📍 Server will be available at: http://{host}:{port}")

    if os.getenv("MCP_SERVER_ADDR"):
        print(f"🔗 MCP Server configured at: {os.getenv('MCP_SERVER_ADDR')}")

    interface.launch(
        server_name=host,
        server_port=port,
        share=share,
        show_error=True,
        debug=True
    )


if __name__ == "__main__":
    # Support both local development and HF Spaces
    if os.getenv("SPACE_ID"):
        # Running on HuggingFace Spaces
        print("🤗 Running on HuggingFace Spaces")
        interface = create_gradio_interface()
        interface.launch()
    else:
        # Local development
        import argparse
        parser = argparse.ArgumentParser(description="BitTransformerLM Gradio Dashboard")
        parser.add_argument("--host", default="127.0.0.1", help="Host address")
        parser.add_argument("--port", type=int, default=7860, help="Port number")
        parser.add_argument("--share", action="store_true", help="Enable sharing")

        args = parser.parse_args()
        run_gradio_server(args.host, args.port, args.share)
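
For a quick local check, a minimal launch sketch follows, assuming the file above is importable as `gradio_dashboard` and Gradio is installed; the host and port values are illustrative defaults, not project requirements:

```python
# Minimal local launch; mirrors what `python gradio_dashboard.py` does.
from gradio_dashboard import create_gradio_interface, run_gradio_server

# Option 1: full server with explicit (illustrative) host/port.
run_gradio_server(host="127.0.0.1", port=7860, share=False)

# Option 2: embed the Blocks object in a larger app instead of launching here.
# interface = create_gradio_interface()
# interface.launch(server_name="0.0.0.0", server_port=7860)
```

Setting `MCP_SERVER_ADDR` before launch enables the proxy path in `mcp_request()`, and on HuggingFace Spaces the `SPACE_ID` branch launches with default settings.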