#!/usr/bin/env python3
"""
BitTransformerLM Gradio Dashboard
=================================

Comprehensive Gradio interface for BitTransformerLM with full feature parity to the Flask dashboard.
Supports both local deployment and HuggingFace Spaces integration while maintaining MCP server compatibility.
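
Typical local usage (the script name below is assumed; adjust it to wherever this
file is saved)::

    python gradio_dashboard.py --host 127.0.0.1 --port 7860 [--share]

Set the MCP_SERVER_ADDR environment variable to let the dashboard forward requests
to a running MCP server; on HuggingFace Spaces the SPACE_ID variable is detected
automatically and the interface launches with default settings.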
"""

import io
import json
import os
import sys
import traceback
import warnings
from typing import Any, Dict, List, Optional, Union, Tuple
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend
import torch
import torch.nn.functional as F
import gradio as gr
import numpy as np
from pathlib import Path
import threading
import time
import requests
from concurrent.futures import ThreadPoolExecutor
import uuid

# Add BitTransformerLM to path
sys.path.insert(0, str(Path(__file__).parent))

# BitTransformerLM imports
from bit_transformer.model import BitTransformerLM, infer_long_sequence
from bit_transformer.optimization import configure_optimizer
from bit_transformer.collapse import collapse_submodel
from bit_transformer.dashboard import plot_telemetry
from bit_transformer.scale import expand_model
from bit_transformer.bit_io import text_to_bits, bits_to_text
from bit_transformer.safety import hil_safe_inference
from bit_transformer.compression import model_output_decompress, compress_bits
from bit_transformer.distributed import wrap_fsdp
from bit_transformer.training import train_loop
from bit_transformer.telemetry import detect_metric_drift
from bit_transformer.quantization import prepare_qat_fx, convert_qat_fx
from bit_transformer.hf_checkpoint import hf_login, save_checkpoint, download_checkpoint
from bit_transformer.dataset_builder import BitTransformerDatasetBuilder, create_bittransformerlm_dataset

# Global state management
class GradioModelManager:
    """Enhanced ModelManager for Gradio interface with thread safety."""
    
    def __init__(self):
        self.model = None
        self.config = {}
        self.telemetry_log = {
            "negentropy": [],
            "lz_complexity": [],
            "symbiosis_score": [],
            "steps": []
        }
        self.c_floor = 0.3
        self.s_floor = 0.5
        self.lambda_weights = {"K": 1.0, "C": 1.0, "S": 1.0}
        self.compression_enabled = False
        self.qat_enabled = False
        self.diffusion_enabled = False
        self.gpu_enabled = False
        self.optimizer = None  # created in init_model so train_step can update weights
        
        # Background job management
        self.executor = ThreadPoolExecutor(max_workers=4)
        self.jobs = {}
        self.mcp_server_addr = os.getenv("MCP_SERVER_ADDR")
        
        # Thread safety
        self.lock = threading.Lock()

    def init_model(self, model_config: dict):
        """Initialize BitTransformerLM model with given configuration."""
        with self.lock:
            try:
                # Clean config - remove None values
                clean_config = {k: v for k, v in model_config.items() if v is not None and v != ""}
                
                self.model = BitTransformerLM(**clean_config)
                self.config = clean_config
                
                # Apply transformations
                if self.qat_enabled:
                    self.model = prepare_qat_fx(self.model)
                if self.gpu_enabled and torch.cuda.is_available():
                    self.model = self.model.cuda()

                # Build an optimizer so training steps actually update parameters.
                # AdamW is a generic default; the repo's configure_optimizer could be
                # substituted here if its signature fits.
                self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-3)

                return f"βœ… Model initialized with config: {clean_config}"
            except Exception as e:
                return f"❌ Model initialization failed: {str(e)}"

    def train_step(self, bits_input, epochs=1):
        """Execute training step(s) with given bit input."""
        if self.model is None:
            return "❌ Model not initialized", None, None
            
        try:
            # Parse bits input
            if isinstance(bits_input, str):
                if bits_input.strip().startswith('['):
                    # JSON format
                    bits = json.loads(bits_input)
                else:
                    # Space-separated format
                    bits = [int(x) for x in bits_input.strip().split()]
            else:
                bits = bits_input
                
            tensor = torch.tensor(bits, dtype=torch.long)
            if self.gpu_enabled and torch.cuda.is_available():
                tensor = tensor.cuda()
                
            # Training loop
            total_loss = 0
            compression_ratio = 1.0
            
            # gr.Number delivers floats; make sure we iterate an integer number of epochs
            epochs = max(1, int(epochs))

            for epoch in range(epochs):
                self.model.train()

                # Forward pass with telemetry (optionally on the compressed stream)
                if self.compression_enabled:
                    compressed_bits, ratio = compress_bits(bits)
                    tensor = torch.tensor(compressed_bits, dtype=torch.long)
                    if self.gpu_enabled and torch.cuda.is_available():
                        tensor = tensor.cuda()
                    compression_ratio = ratio

                output, telemetry = self.model(tensor.unsqueeze(0))

                # Next-bit prediction loss: logits at position i predict the bit at i+1
                # (assumes the model emits one logit vector per input position)
                if output.dim() == 3:
                    loss = F.cross_entropy(
                        output[:, :-1, :].contiguous().view(-1, output.size(-1)),
                        tensor[1:].unsqueeze(0).contiguous().view(-1),
                        ignore_index=-1
                    )
                else:
                    loss = F.cross_entropy(output, tensor.unsqueeze(0))

                # Backward pass and parameter update
                if self.optimizer is not None:
                    self.optimizer.zero_grad()
                loss.backward()
                if self.optimizer is not None:
                    self.optimizer.step()

                # Update telemetry
                self._update_telemetry(telemetry)
                total_loss += loss.item()

            avg_loss = total_loss / epochs
            return f"βœ… Training completed. Average Loss: {avg_loss:.4f}", avg_loss, compression_ratio
            
        except Exception as e:
            return f"❌ Training failed: {str(e)}", None, None

    def inference(self, bits_input, long_inference=False, ctx_bits=4096, overlap=256):
        """Run inference on bit input."""
        if self.model is None:
            return "❌ Model not initialized", None
            
        try:
            # Parse bits input
            if isinstance(bits_input, str):
                if bits_input.strip().startswith('['):
                    bits = json.loads(bits_input)
                else:
                    bits = [int(x) for x in bits_input.strip().split()]
            else:
                bits = bits_input
                
            tensor = torch.tensor(bits, dtype=torch.long)
            if self.gpu_enabled and torch.cuda.is_available():
                tensor = tensor.cuda()
            
            self.model.eval()
            
            with torch.inference_mode():
                if long_inference or len(bits) > ctx_bits:
                    # Long sequence inference
                    output, telemetry = infer_long_sequence(
                        self.model, tensor.unsqueeze(0),
                        ctx_bits=ctx_bits, overlap=overlap
                    )
                else:
                    # Standard inference with safety gates
                    output, telemetry = hil_safe_inference(
                        self.model, tensor.unsqueeze(0),
                        c_floor=self.c_floor, s_floor=self.s_floor
                    )
                
                # Update telemetry
                self._update_telemetry(telemetry)
                
                output_bits = output.squeeze(0).cpu().tolist()
                return f"βœ… Inference completed. Output length: {len(output_bits)}", output_bits
                
        except Exception as e:
            return f"❌ Inference failed: {str(e)}", None

    def text_inference(self, text_input):
        """Convert text to bits, run inference, convert back to text."""
        try:
            # Text to bits
            bits = text_to_bits(text_input)
            
            # Run inference
            result, output_bits = self.inference(bits)
            
            if output_bits is None:
                return result, None
                
            # Convert back to text
            try:
                output_text = bits_to_text(output_bits)
                return f"βœ… Text inference completed.", output_text
            except Exception as e:
                return f"βœ… Inference completed, but text conversion failed: {str(e)}", str(output_bits)
                
        except Exception as e:
            return f"❌ Text inference failed: {str(e)}", None

    def scale_model(self, width_multiplier):
        """Scale up model width."""
        if self.model is None:
            return "❌ Model not initialized"
            
        try:
            with self.lock:
                self.model = expand_model(self.model, width_multiplier)
                return f"βœ… Model scaled by factor {width_multiplier}"
        except Exception as e:
            return f"❌ Model scaling failed: {str(e)}"

    def collapse_model(self, cluster_bits, target_params, width_scale=1.0):
        """Collapse model using cluster analysis."""
        if self.model is None:
            return "❌ Model not initialized"
            
        try:
            # Parse inputs
            if isinstance(cluster_bits, str):
                clusters = json.loads(cluster_bits)
            else:
                clusters = cluster_bits
                
            if isinstance(target_params, str):
                params = json.loads(target_params)
            else:
                params = target_params
            
            with self.lock:
                collapsed_model = collapse_submodel(
                    self.model, clusters, params, width_scale
                )
                self.model = collapsed_model
                return f"βœ… Model collapsed successfully"
        except Exception as e:
            return f"❌ Model collapse failed: {str(e)}"

    def get_model_status(self):
        """Get current model status and configuration."""
        if self.model is None:
            return "❌ No model initialized"
            
        try:
            param_count = sum(p.numel() for p in self.model.parameters())
            status = {
                "initialized": True,
                "parameters": param_count,
                "config": self.config,
                "gpu_enabled": self.gpu_enabled,
                "qat_enabled": self.qat_enabled,
                "compression_enabled": self.compression_enabled,
                "diffusion_enabled": self.diffusion_enabled,
            }
            return json.dumps(status, indent=2)
        except Exception as e:
            return f"❌ Status check failed: {str(e)}"

    def get_telemetry_plot(self):
        """Generate telemetry plot."""
        try:
            if not any(self.telemetry_log.values()):
                # Return empty plot
                fig, ax = plt.subplots(figsize=(10, 6))
                ax.text(0.5, 0.5, 'No telemetry data yet', ha='center', va='center', transform=ax.transAxes)
                ax.set_title('Telemetry Metrics')
                return fig
            
            fig, axes = plot_telemetry(
                self.telemetry_log,
                k_floor=0.5,  # Negentropy floor
                c_floor=self.c_floor,
                s_floor=self.s_floor
            )
            return fig
        except Exception as e:
            # Return error plot
            fig, ax = plt.subplots(figsize=(10, 6))
            ax.text(0.5, 0.5, f'Plot error: {str(e)}', ha='center', va='center', transform=ax.transAxes)
            ax.set_title('Telemetry Metrics - Error')
            return fig

    def _update_telemetry(self, telemetry_dict):
        """Update telemetry log with new values."""
        if not telemetry_dict:
            return
            
        step = len(self.telemetry_log["steps"])
        self.telemetry_log["steps"].append(step)
        
        # Extract metrics with defaults
        self.telemetry_log["negentropy"].append(
            float(telemetry_dict.get("negentropy", torch.tensor(0.0)).mean().item())
        )
        self.telemetry_log["lz_complexity"].append(
            float(telemetry_dict.get("lz_complexity_logits", torch.tensor(0.0)).mean().item())
        )
        self.telemetry_log["symbiosis_score"].append(
            float(telemetry_dict.get("symbiosis_score", torch.tensor(0.0)).mean().item())
        )

    def huggingface_upload(self, repo_id, hf_token=None):
        """Upload model to HuggingFace."""
        if self.model is None:
            return "❌ Model not initialized"
            
        try:
            if hf_token:
                hf_login(hf_token)
            
            save_checkpoint(self.model, repo_id, self.config)
            return f"βœ… Model uploaded to {repo_id}"
        except Exception as e:
            return f"❌ HF upload failed: {str(e)}"

    def huggingface_download(self, repo_id, hf_token=None):
        """Download model from HuggingFace."""
        try:
            if hf_token:
                hf_login(hf_token)
            
            with self.lock:
                model, config = download_checkpoint(repo_id)
                self.model = model
                self.config = config
            
            return f"βœ… Model downloaded from {repo_id}"
        except Exception as e:
            return f"❌ HF download failed: {str(e)}"

    def mcp_request(self, endpoint, data=None, method="POST"):
        """Make request to MCP server if available."""
        if not self.mcp_server_addr:
            return "❌ MCP server not configured"
            
        try:
            url = self.mcp_server_addr.rstrip("/") + endpoint
            if method == "POST":
                resp = requests.post(url, json=data, timeout=30)
            else:
                resp = requests.get(url, timeout=30)
            
            resp.raise_for_status()
            
            if resp.headers.get("Content-Type", "").startswith("image/"):
                return "βœ… MCP request completed (binary data)"
            return f"βœ… MCP request completed: {resp.json()}"
        except Exception as e:
            return f"❌ MCP request failed: {str(e)}"

# Global manager instance
manager = GradioModelManager()

def create_gradio_interface():
    """Create the main Gradio interface with all BitTransformerLM features."""
    
    # Helper functions for Gradio callbacks
    def init_model_callback(d_model, nhead, num_layers, dim_feedforward, max_seq_len, 
                           chunk_size, overlap, reversible, use_checkpoint, act_threshold,
                           c_floor, s_floor):
        """Initialize model with form parameters."""
        # gr.Number returns floats, so cast integer-valued fields before building the model
        config = {
            "d_model": int(d_model),
            "nhead": int(nhead),
            "num_layers": int(num_layers),
            "dim_feedforward": int(dim_feedforward),
            "max_seq_len": int(max_seq_len),
            "chunk_size": int(chunk_size) if chunk_size and chunk_size > 0 else None,
            "overlap": int(overlap),
            "reversible": bool(reversible),
            "use_checkpoint": bool(use_checkpoint),
            "act_threshold": float(act_threshold)
        }
        
        # Update safety floors
        manager.c_floor = c_floor
        manager.s_floor = s_floor
        
        result = manager.init_model(config)
        status = manager.get_model_status()
        plot = manager.get_telemetry_plot()
        
        return result, status, plot

    def train_callback(bits_input, epochs, file_input):
        """Training callback with file upload support."""
        if file_input is not None:
            # Process uploaded file
            try:
                if file_input.name.endswith(('.txt', '.md')):
                    with open(file_input.name, 'r') as f:
                        text = f.read()
                    bits = text_to_bits(text)
                else:
                    with open(file_input.name, 'rb') as f:
                        data = f.read()
                    # Convert bytes to bits
                    bits = []
                    for byte in data:
                        for i in range(8):
                            bits.append((byte >> (7-i)) & 1)
                
                result, loss, ratio = manager.train_step(bits, epochs)
            except Exception as e:
                result = f"❌ File processing failed: {str(e)}"
                loss, ratio = None, None
        else:
            result, loss, ratio = manager.train_step(bits_input, epochs)
        
        status = manager.get_model_status()
        plot = manager.get_telemetry_plot()
        
        return result, status, plot, f"Compression Ratio: {ratio:.2f}" if ratio else ""

    def inference_callback(bits_input, file_input):
        """Standard inference callback."""
        if file_input is not None:
            # Process uploaded file similar to training
            try:
                if file_input.name.endswith(('.txt', '.md')):
                    with open(file_input.name, 'r') as f:
                        text = f.read()
                    bits = text_to_bits(text)
                else:
                    with open(file_input.name, 'rb') as f:
                        data = f.read()
                    bits = []
                    for byte in data:
                        for i in range(8):
                            bits.append((byte >> (7-i)) & 1)
                
                result, output_bits = manager.inference(bits)
            except Exception as e:
                result = f"❌ File processing failed: {str(e)}"
                output_bits = None
        else:
            result, output_bits = manager.inference(bits_input)
            
        return result, str(output_bits) if output_bits else ""

    def long_inference_callback(bits_input, ctx_bits, overlap):
        """Long sequence inference callback (casts numeric fields from gr.Number floats)."""
        result, output_bits = manager.inference(
            bits_input, long_inference=True,
            ctx_bits=int(ctx_bits), overlap=int(overlap)
        )
        return result, str(output_bits) if output_bits else ""

    def text_inference_callback(text_input):
        """Text-to-text inference callback."""
        result, output_text = manager.text_inference(text_input)
        return result, output_text if output_text else ""

    # Create Gradio interface
    with gr.Blocks(title="BitTransformerLM Dashboard", 
                   theme=gr.themes.Soft()) as interface:
        
        gr.Markdown("# πŸ€– BitTransformerLM Interactive Dashboard")
        gr.Markdown("*Experimental bit-native transformer with comprehensive training and inference capabilities*")
        
        with gr.Tab("πŸ—οΈ Model Configuration"):
            gr.Markdown("## Initialize BitTransformerLM")
            
            with gr.Row():
                with gr.Column():
                    d_model = gr.Number(label="d_model", value=64, info="Model width")
                    nhead = gr.Number(label="nhead", value=4, info="Attention heads")
                    num_layers = gr.Number(label="num_layers", value=2, info="Transformer layers")
                    dim_feedforward = gr.Number(label="dim_feedforward", value=256, info="FFN dimension")
                
                with gr.Column():
                    max_seq_len = gr.Number(label="max_seq_len", value=512, info="Max sequence length")
                    chunk_size = gr.Number(label="chunk_size", value=0, info="Chunk size (0=auto)")
                    overlap = gr.Number(label="overlap", value=64, info="Sliding window overlap")
                    act_threshold = gr.Number(label="act_threshold", value=0.95, info="ACT halt threshold")
            
            with gr.Row():
                reversible = gr.Checkbox(label="Reversible Layers", value=False)
                use_checkpoint = gr.Checkbox(label="Gradient Checkpointing", value=True)
            
            with gr.Row():
                c_floor = gr.Number(label="c_floor", value=0.3, info="Complexity safety floor")
                s_floor = gr.Number(label="s_floor", value=0.5, info="Symbiosis safety floor")
            
            init_btn = gr.Button("πŸš€ Initialize Model", variant="primary")
            init_output = gr.Textbox(label="Initialization Result", interactive=False)
            
        with gr.Tab("🎯 Training"):
            gr.Markdown("## Train BitTransformerLM")
            
            with gr.Row():
                with gr.Column():
                    train_bits = gr.Textbox(
                        label="Bit Input",
                        placeholder="0 1 0 1 or [0,1,0,1] or upload file",
                        lines=3
                    )
                    train_file = gr.File(label="Upload Training File", file_types=[".txt", ".md", ".bin"])
                    train_epochs = gr.Number(label="Epochs", value=1, minimum=1)
                    
                with gr.Column():
                    train_btn = gr.Button("πŸƒ Start Training", variant="primary")
                    train_output = gr.Textbox(label="Training Result", interactive=False)
                    compression_output = gr.Textbox(label="Compression Info", interactive=False)
                    
        with gr.Tab("🧠 Inference"):
            with gr.Tab("Standard Inference"):
                gr.Markdown("## Standard Inference")
                
                with gr.Row():
                    with gr.Column():
                        infer_bits = gr.Textbox(
                            label="Bit Input",
                            placeholder="0 1 0 1 or [0,1,0,1]",
                            lines=3
                        )
                        infer_file = gr.File(label="Upload Inference File")
                        
                    with gr.Column():
                        infer_btn = gr.Button("🎯 Run Inference", variant="primary")
                        infer_result = gr.Textbox(label="Result", interactive=False)
                        infer_output = gr.Textbox(label="Output Bits", lines=5, interactive=False)
                        
            with gr.Tab("Long Sequence Inference"):
                gr.Markdown("## Long Sequence Inference")
                
                with gr.Row():
                    with gr.Column():
                        long_bits = gr.Textbox(
                            label="Long Bit Sequence",
                            lines=5,
                            placeholder="Long sequence of bits..."
                        )
                        long_ctx_bits = gr.Number(label="Context Bits", value=4096)
                        long_overlap = gr.Number(label="Overlap", value=256)
                        
                    with gr.Column():
                        long_infer_btn = gr.Button("πŸ”„ Run Long Inference", variant="primary")
                        long_result = gr.Textbox(label="Result", interactive=False)
                        long_output = gr.Textbox(label="Output Bits", lines=5, interactive=False)
                        
            with gr.Tab("Text Inference"):
                gr.Markdown("## Text-to-Text Inference")
                
                with gr.Row():
                    with gr.Column():
                        text_input = gr.Textbox(
                            label="Input Text",
                            placeholder="Enter text to process...",
                            lines=3
                        )
                        text_infer_btn = gr.Button("πŸ“ Process Text", variant="primary")
                        
                    with gr.Column():
                        text_result = gr.Textbox(label="Result", interactive=False)
                        text_output = gr.Textbox(
                            label="Output Text",
                            lines=5,
                            interactive=False
                        )
                        
        with gr.Tab("βš™οΈ Model Operations"):
            with gr.Tab("Scale Model"):
                gr.Markdown("## Scale Model Width")
                
                with gr.Row():
                    width_mult = gr.Number(label="Width Multiplier", value=1.5, step=0.1)
                    scale_btn = gr.Button("πŸ“ˆ Scale Model", variant="secondary")
                    
                scale_output = gr.Textbox(label="Scaling Result", interactive=False)
                
            with gr.Tab("Collapse Model"):
                gr.Markdown("## Collapse Submodel")
                
                with gr.Row():
                    with gr.Column():
                        cluster_bits = gr.Textbox(
                            label="Cluster Bits (JSON)",
                            placeholder='[[0,1,0,1],[1,1,0,0]]',
                            lines=3
                        )
                        target_params = gr.Textbox(
                            label="Target Parameters (JSON)",
                            placeholder='{"d_model":32,"nhead":4,"num_layers":1}',
                            lines=3
                        )
                        width_scale = gr.Number(label="Width Scale", value=1.0, step=0.1)
                        
                    with gr.Column():
                        collapse_btn = gr.Button("πŸ—œοΈ Collapse Model", variant="secondary")
                        collapse_output = gr.Textbox(label="Collapse Result", interactive=False)
                        
        with gr.Tab("πŸ“Š Monitoring"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("## Model Status")
                    status_output = gr.Code(label="Current Status", language="json")
                    refresh_btn = gr.Button("πŸ”„ Refresh Status")
                    
                with gr.Column():
                    gr.Markdown("## System Settings")
                    
                    with gr.Row():
                        gpu_checkbox = gr.Checkbox(label="πŸ”₯ Enable GPU/FSDP", value=False)
                        qat_checkbox = gr.Checkbox(label="⚑ Enable 4-bit QAT", value=False)
                    
                    with gr.Row():
                        compression_checkbox = gr.Checkbox(label="πŸ—œοΈ Enable Compression", value=False)
                        diffusion_checkbox = gr.Checkbox(label="🌊 Enable Diffusion Mode", value=False)
                        
            gr.Markdown("## πŸ“ˆ Telemetry Metrics")
            telemetry_plot = gr.Plot(label="K/C/S Metrics Over Time")
            
        with gr.Tab("☁️ HuggingFace Integration"):
            gr.Markdown("## HuggingFace Model Hub")
            
            with gr.Row():
                with gr.Column():
                    hf_repo_id = gr.Textbox(label="Repository ID", placeholder="username/model-name")
                    hf_token = gr.Textbox(label="HF Token (optional)", type="password")
                    
                with gr.Column():
                    with gr.Row():
                        hf_upload_btn = gr.Button("⬆️ Upload to HF", variant="secondary")
                        hf_download_btn = gr.Button("⬇️ Download from HF", variant="secondary")
                        
            hf_result = gr.Textbox(label="HuggingFace Result", interactive=False)
            
        # Event handlers
        init_btn.click(
            init_model_callback,
            inputs=[d_model, nhead, num_layers, dim_feedforward, max_seq_len,
                   chunk_size, overlap, reversible, use_checkpoint, act_threshold,
                   c_floor, s_floor],
            outputs=[init_output, status_output, telemetry_plot]
        )
        
        train_btn.click(
            train_callback,
            inputs=[train_bits, train_epochs, train_file],
            outputs=[train_output, status_output, telemetry_plot, compression_output]
        )
        
        infer_btn.click(
            inference_callback,
            inputs=[infer_bits, infer_file],
            outputs=[infer_result, infer_output]
        )
        
        long_infer_btn.click(
            long_inference_callback,
            inputs=[long_bits, long_ctx_bits, long_overlap],
            outputs=[long_result, long_output]
        )
        
        text_infer_btn.click(
            text_inference_callback,
            inputs=[text_input],
            outputs=[text_result, text_output]
        )
        
        scale_btn.click(
            manager.scale_model,
            inputs=[width_mult],
            outputs=[scale_output]
        )
        
        collapse_btn.click(
            manager.collapse_model,
            inputs=[cluster_bits, target_params, width_scale],
            outputs=[collapse_output]
        )
        
        refresh_btn.click(
            manager.get_model_status,
            outputs=[status_output]
        )
        
        hf_upload_btn.click(
            manager.huggingface_upload,
            inputs=[hf_repo_id, hf_token],
            outputs=[hf_result]
        )
        
        hf_download_btn.click(
            manager.huggingface_download,
            inputs=[hf_repo_id, hf_token],
            outputs=[hf_result]
        )
        
        # System settings callbacks, wired to their checkboxes so the toggles take
        # effect (GPU and QAT apply the next time the model is initialized).
        def update_gpu_setting(enabled):
            manager.gpu_enabled = bool(enabled)

        def update_qat_setting(enabled):
            manager.qat_enabled = bool(enabled)

        def update_compression_setting(enabled):
            manager.compression_enabled = bool(enabled)

        def update_diffusion_setting(enabled):
            manager.diffusion_enabled = bool(enabled)

        gpu_checkbox.change(update_gpu_setting, inputs=[gpu_checkbox])
        qat_checkbox.change(update_qat_setting, inputs=[qat_checkbox])
        compression_checkbox.change(update_compression_setting, inputs=[compression_checkbox])
        diffusion_checkbox.change(update_diffusion_setting, inputs=[diffusion_checkbox])
        
        # Auto-refresh telemetry every 10 seconds
        interface.load(
            manager.get_telemetry_plot,
            outputs=[telemetry_plot],
            every=10
        )
        
        # Load initial status
        interface.load(
            manager.get_model_status,
            outputs=[status_output]
        )
        
    return interface

def run_gradio_server(host="127.0.0.1", port=7860, share=False):
    """Run the Gradio server."""
    interface = create_gradio_interface()
    
    print("πŸš€ Starting BitTransformerLM Gradio Dashboard...")
    print(f"πŸ“ Server will be available at: http://{host}:{port}")
    
    if os.getenv("MCP_SERVER_ADDR"):
        print(f"πŸ”— MCP Server configured at: {os.getenv('MCP_SERVER_ADDR')}")
    
    interface.launch(
        server_name=host,
        server_port=port,
        share=share,
        show_error=True,
        debug=True
    )

if __name__ == "__main__":
    # Support both local development and HF Spaces
    if os.getenv("SPACE_ID"):
        # Running on HuggingFace Spaces
        print("πŸ€— Running on HuggingFace Spaces")
        interface = create_gradio_interface()
        interface.launch()
    else:
        # Local development
        import argparse
        parser = argparse.ArgumentParser(description="BitTransformerLM Gradio Dashboard")
        parser.add_argument("--host", default="127.0.0.1", help="Host address")
        parser.add_argument("--port", type=int, default=7860, help="Port number")
        parser.add_argument("--share", action="store_true", help="Enable sharing")
        
        args = parser.parse_args()
        run_gradio_server(args.host, args.port, args.share)