|
|
|
""" |
|
BitTransformerLM Gradio Dashboard |
|
================================= |
|
|
|
Comprehensive Gradio interface for BitTransformerLM with full feature parity to the Flask dashboard. |
|
Supports both local deployment and HuggingFace Spaces integration while maintaining MCP server compatibility. |
|
""" |
|
|
|
import io |
|
import json |
|
import os |
|
import sys |
|
import traceback |
|
import warnings |
|
from typing import Any, Dict, List, Optional, Union, Tuple |
|
import matplotlib
matplotlib.use('Agg')  # select the non-interactive backend before pyplot is imported (headless rendering)
import matplotlib.pyplot as plt
|
import torch |
|
import torch.nn.functional as F |
|
import gradio as gr |
|
import numpy as np |
|
from pathlib import Path |
|
import threading |
|
import time |
|
import requests |
|
from concurrent.futures import ThreadPoolExecutor |
|
import uuid |
|
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent)) |
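# The path insert above keeps the local ``bit_transformer`` package importable when
# this file is run directly from a repository checkout, without installing the package.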
|
|
|
|
|
from bit_transformer.model import BitTransformerLM, infer_long_sequence |
|
from bit_transformer.optimization import configure_optimizer |
|
from bit_transformer.collapse import collapse_submodel |
|
from bit_transformer.dashboard import plot_telemetry |
|
from bit_transformer.scale import expand_model |
|
from bit_transformer.bit_io import text_to_bits, bits_to_text |
|
from bit_transformer.safety import hil_safe_inference |
|
from bit_transformer.compression import model_output_decompress, compress_bits |
|
from bit_transformer.distributed import wrap_fsdp |
|
from bit_transformer.training import train_loop |
|
from bit_transformer.telemetry import detect_metric_drift |
|
from bit_transformer.quantization import prepare_qat_fx, convert_qat_fx |
|
from bit_transformer.hf_checkpoint import hf_login, save_checkpoint, download_checkpoint |
|
from bit_transformer.dataset_builder import BitTransformerDatasetBuilder, create_bittransformerlm_dataset |
|
|
|
|
|
class GradioModelManager: |
|
"""Enhanced ModelManager for Gradio interface with thread safety.""" |
|
|
|
def __init__(self): |
|
self.model = None |
|
self.config = {} |
|
self.telemetry_log = { |
|
"negentropy": [], |
|
"lz_complexity": [], |
|
"symbiosis_score": [], |
|
"steps": [] |
|
} |
|
self.c_floor = 0.3 |
|
self.s_floor = 0.5 |
|
self.lambda_weights = {"K": 1.0, "C": 1.0, "S": 1.0} |
|
self.compression_enabled = False |
|
self.qat_enabled = False |
|
self.diffusion_enabled = False |
|
self.gpu_enabled = False |
|
|
|
|
|
self.executor = ThreadPoolExecutor(max_workers=4) |
|
self.jobs = {} |
|
self.mcp_server_addr = os.getenv("MCP_SERVER_ADDR") |
|
|
|
|
|
        self.lock = threading.Lock()  # guards model-mutating operations (init, scale, collapse, download)
|
|
|
def init_model(self, model_config: dict): |
|
"""Initialize BitTransformerLM model with given configuration.""" |
|
with self.lock: |
|
try: |
|
|
|
clean_config = {k: v for k, v in model_config.items() if v is not None and v != ""} |
|
|
|
self.model = BitTransformerLM(**clean_config) |
|
self.config = clean_config |
|
|
|
|
|
if self.qat_enabled: |
|
self.model = prepare_qat_fx(self.model) |
|
if self.gpu_enabled and torch.cuda.is_available(): |
|
self.model = self.model.cuda() |
|
|
|
return f"β
Model initialized with config: {clean_config}" |
|
except Exception as e: |
|
return f"β Model initialization failed: {str(e)}" |
|
|
|
def train_step(self, bits_input, epochs=1): |
|
"""Execute training step(s) with given bit input.""" |
|
if self.model is None: |
|
return "β Model not initialized", None, None |
|
|
|
try: |
|
|
|
if isinstance(bits_input, str): |
|
if bits_input.strip().startswith('['): |
|
|
|
bits = json.loads(bits_input) |
|
else: |
|
|
|
bits = [int(x) for x in bits_input.strip().split()] |
|
else: |
|
bits = bits_input |
|
|
|
tensor = torch.tensor(bits, dtype=torch.long) |
|
if self.gpu_enabled and torch.cuda.is_available(): |
|
tensor = tensor.cuda() |
|
|
|
|
|
            total_loss = 0.0
            compression_ratio = 1.0
            epochs = int(epochs)  # Gradio Number inputs arrive as floats

            # The original step only called backward(); a lazily created optimizer
            # (a plain AdamW stand-in here -- the repo's configure_optimizer helper
            # could be swapped in) is used so the weights are actually updated.
            if getattr(self, "optimizer", None) is None:
                self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-3)

            for epoch in range(epochs):
                self.model.train()

                if self.compression_enabled:
                    compressed_bits, ratio = compress_bits(bits)
                    tensor = torch.tensor(compressed_bits, dtype=torch.long)
                    if self.gpu_enabled and torch.cuda.is_available():
                        tensor = tensor.cuda()
                    compression_ratio = ratio

                output, telemetry = self.model(tensor.unsqueeze(0))

                if output.dim() == 3:
                    # Next-bit objective: logits at position t are scored against bit t+1,
                    # keeping logit and target lengths aligned.
                    loss = F.cross_entropy(
                        output[:, :-1].reshape(-1, output.size(-1)),
                        tensor[1:].unsqueeze(0).reshape(-1),
                        ignore_index=-1,
                    )
                else:
                    loss = F.cross_entropy(output, tensor.unsqueeze(0))

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                self._update_telemetry(telemetry)
                total_loss += loss.item()

            avg_loss = total_loss / max(epochs, 1)
|
return f"β
Training completed. Average Loss: {avg_loss:.4f}", avg_loss, compression_ratio |
|
|
|
except Exception as e: |
|
return f"β Training failed: {str(e)}", None, None |
|
|
|
def inference(self, bits_input, long_inference=False, ctx_bits=4096, overlap=256): |
|
"""Run inference on bit input.""" |
|
if self.model is None: |
|
return "β Model not initialized", None |
|
|
|
try: |
|
|
|
if isinstance(bits_input, str): |
|
if bits_input.strip().startswith('['): |
|
bits = json.loads(bits_input) |
|
else: |
|
bits = [int(x) for x in bits_input.strip().split()] |
|
else: |
|
bits = bits_input |
|
|
|
tensor = torch.tensor(bits, dtype=torch.long) |
|
if self.gpu_enabled and torch.cuda.is_available(): |
|
tensor = tensor.cuda() |
|
|
|
self.model.eval() |
|
|
|
with torch.inference_mode(): |
|
if long_inference or len(bits) > ctx_bits: |
|
|
|
output, telemetry = infer_long_sequence( |
|
self.model, tensor.unsqueeze(0), |
|
ctx_bits=ctx_bits, overlap=overlap |
|
) |
|
else: |
|
|
|
output, telemetry = hil_safe_inference( |
|
self.model, tensor.unsqueeze(0), |
|
c_floor=self.c_floor, s_floor=self.s_floor |
|
) |
|
|
|
|
|
self._update_telemetry(telemetry) |
|
|
|
output_bits = output.squeeze(0).cpu().tolist() |
|
return f"β
Inference completed. Output length: {len(output_bits)}", output_bits |
|
|
|
except Exception as e: |
|
return f"β Inference failed: {str(e)}", None |
|
|
|
def text_inference(self, text_input): |
|
"""Convert text to bits, run inference, convert back to text.""" |
|
try: |
|
|
|
bits = text_to_bits(text_input) |
|
|
|
|
|
result, output_bits = self.inference(bits) |
|
|
|
if output_bits is None: |
|
return result, None |
|
|
|
|
|
            try:
                output_text = bits_to_text(output_bits)
                return "✅ Text inference completed.", output_text
            except Exception as e:
                return f"⚠️ Inference completed, but text conversion failed: {str(e)}", str(output_bits)

        except Exception as e:
            return f"❌ Text inference failed: {str(e)}", None
|
|
|
def scale_model(self, width_multiplier): |
|
"""Scale up model width.""" |
|
if self.model is None: |
|
return "β Model not initialized" |
|
|
|
try: |
|
with self.lock: |
|
self.model = expand_model(self.model, width_multiplier) |
|
return f"β
Model scaled by factor {width_multiplier}" |
|
except Exception as e: |
|
return f"β Model scaling failed: {str(e)}" |
|
|
|
def collapse_model(self, cluster_bits, target_params, width_scale=1.0): |
|
"""Collapse model using cluster analysis.""" |
|
if self.model is None: |
|
return "β Model not initialized" |
|
|
|
try: |
|
|
|
if isinstance(cluster_bits, str): |
|
clusters = json.loads(cluster_bits) |
|
else: |
|
clusters = cluster_bits |
|
|
|
if isinstance(target_params, str): |
|
params = json.loads(target_params) |
|
else: |
|
params = target_params |
|
|
|
with self.lock: |
|
collapsed_model = collapse_submodel( |
|
self.model, clusters, params, width_scale |
|
) |
|
self.model = collapsed_model |
|
return f"β
Model collapsed successfully" |
|
except Exception as e: |
|
return f"β Model collapse failed: {str(e)}" |
|
|
|
def get_model_status(self): |
|
"""Get current model status and configuration.""" |
|
if self.model is None: |
|
return "β No model initialized" |
|
|
|
try: |
|
param_count = sum(p.numel() for p in self.model.parameters()) |
|
status = { |
|
"initialized": True, |
|
"parameters": param_count, |
|
"config": self.config, |
|
"gpu_enabled": self.gpu_enabled, |
|
"qat_enabled": self.qat_enabled, |
|
"compression_enabled": self.compression_enabled, |
|
"diffusion_enabled": self.diffusion_enabled, |
|
} |
|
return json.dumps(status, indent=2) |
|
except Exception as e: |
|
return f"β Status check failed: {str(e)}" |
|
|
|
def get_telemetry_plot(self): |
|
"""Generate telemetry plot.""" |
|
try: |
|
if not any(self.telemetry_log.values()): |
|
|
|
fig, ax = plt.subplots(figsize=(10, 6)) |
|
ax.text(0.5, 0.5, 'No telemetry data yet', ha='center', va='center', transform=ax.transAxes) |
|
ax.set_title('Telemetry Metrics') |
|
return fig |
|
|
|
fig, axes = plot_telemetry( |
|
self.telemetry_log, |
|
k_floor=0.5, |
|
c_floor=self.c_floor, |
|
s_floor=self.s_floor |
|
) |
|
return fig |
|
except Exception as e: |
|
|
|
fig, ax = plt.subplots(figsize=(10, 6)) |
|
ax.text(0.5, 0.5, f'Plot error: {str(e)}', ha='center', va='center', transform=ax.transAxes) |
|
ax.set_title('Telemetry Metrics - Error') |
|
return fig |
|
|
|
def _update_telemetry(self, telemetry_dict): |
|
"""Update telemetry log with new values.""" |
|
if not telemetry_dict: |
|
return |
|
|
|
step = len(self.telemetry_log["steps"]) |
|
self.telemetry_log["steps"].append(step) |
|
|
|
|
|
self.telemetry_log["negentropy"].append( |
|
float(telemetry_dict.get("negentropy", torch.tensor(0.0)).mean().item()) |
|
) |
|
self.telemetry_log["lz_complexity"].append( |
|
float(telemetry_dict.get("lz_complexity_logits", torch.tensor(0.0)).mean().item()) |
|
) |
|
self.telemetry_log["symbiosis_score"].append( |
|
float(telemetry_dict.get("symbiosis_score", torch.tensor(0.0)).mean().item()) |
|
) |
|
|
|
def huggingface_upload(self, repo_id, hf_token=None): |
|
"""Upload model to HuggingFace.""" |
|
if self.model is None: |
|
return "β Model not initialized" |
|
|
|
try: |
|
if hf_token: |
|
hf_login(hf_token) |
|
|
|
save_checkpoint(self.model, repo_id, self.config) |
|
return f"β
Model uploaded to {repo_id}" |
|
except Exception as e: |
|
return f"β HF upload failed: {str(e)}" |
|
|
|
def huggingface_download(self, repo_id, hf_token=None): |
|
"""Download model from HuggingFace.""" |
|
try: |
|
if hf_token: |
|
hf_login(hf_token) |
|
|
|
with self.lock: |
|
model, config = download_checkpoint(repo_id) |
|
self.model = model |
|
self.config = config |
|
|
|
return f"β
Model downloaded from {repo_id}" |
|
except Exception as e: |
|
return f"β HF download failed: {str(e)}" |
|
|
|
def mcp_request(self, endpoint, data=None, method="POST"): |
|
"""Make request to MCP server if available.""" |
|
if not self.mcp_server_addr: |
|
return "β MCP server not configured" |
|
|
|
try: |
|
url = self.mcp_server_addr.rstrip("/") + endpoint |
|
if method == "POST": |
|
resp = requests.post(url, json=data, timeout=30) |
|
else: |
|
resp = requests.get(url, timeout=30) |
|
|
|
resp.raise_for_status() |
|
|
|
if resp.headers.get("Content-Type", "").startswith("image/"): |
|
return "β
MCP request completed (binary data)" |
|
return f"β
MCP request completed: {resp.json()}" |
|
except Exception as e: |
|
return f"β MCP request failed: {str(e)}" |
|
|
|
|
|
manager = GradioModelManager()  # module-level manager shared by all Gradio callbacks
|
|
|
def create_gradio_interface(): |
|
"""Create the main Gradio interface with all BitTransformerLM features.""" |
|
|
|
|
|
def init_model_callback(d_model, nhead, num_layers, dim_feedforward, max_seq_len, |
|
chunk_size, overlap, reversible, use_checkpoint, act_threshold, |
|
c_floor, s_floor): |
|
"""Initialize model with form parameters.""" |
|
        config = {
            # Gradio Number components return floats, so integer hyperparameters are cast explicitly.
            "d_model": int(d_model),
            "nhead": int(nhead),
            "num_layers": int(num_layers),
            "dim_feedforward": int(dim_feedforward),
            "max_seq_len": int(max_seq_len),
            "chunk_size": int(chunk_size) if chunk_size and chunk_size > 0 else None,
            "overlap": int(overlap),
            "reversible": reversible,
            "use_checkpoint": use_checkpoint,
            "act_threshold": act_threshold,
        }
|
|
|
|
|
manager.c_floor = c_floor |
|
manager.s_floor = s_floor |
|
|
|
result = manager.init_model(config) |
|
status = manager.get_model_status() |
|
plot = manager.get_telemetry_plot() |
|
|
|
return result, status, plot |
|
|
|
def train_callback(bits_input, epochs, file_input): |
|
"""Training callback with file upload support.""" |
|
if file_input is not None: |
|
|
|
try: |
|
if file_input.name.endswith(('.txt', '.md')): |
|
with open(file_input.name, 'r') as f: |
|
text = f.read() |
|
bits = text_to_bits(text) |
|
else: |
|
with open(file_input.name, 'rb') as f: |
|
data = f.read() |
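                    # Unpack each byte MSB-first, e.g. 0xB0 (0b10110000) -> [1, 0, 1, 1, 0, 0, 0, 0].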
|
|
|
bits = [] |
|
for byte in data: |
|
for i in range(8): |
|
bits.append((byte >> (7-i)) & 1) |
|
|
|
result, loss, ratio = manager.train_step(bits, epochs) |
|
except Exception as e: |
|
result = f"β File processing failed: {str(e)}" |
|
loss, ratio = None, None |
|
else: |
|
result, loss, ratio = manager.train_step(bits_input, epochs) |
|
|
|
status = manager.get_model_status() |
|
plot = manager.get_telemetry_plot() |
|
|
|
return result, status, plot, f"Compression Ratio: {ratio:.2f}" if ratio else "" |
|
|
|
def inference_callback(bits_input, file_input): |
|
"""Standard inference callback.""" |
|
if file_input is not None: |
|
|
|
try: |
|
if file_input.name.endswith(('.txt', '.md')): |
|
with open(file_input.name, 'r') as f: |
|
text = f.read() |
|
bits = text_to_bits(text) |
|
else: |
|
with open(file_input.name, 'rb') as f: |
|
data = f.read() |
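                    # Same MSB-first byte unpacking as in train_callback.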
|
bits = [] |
|
for byte in data: |
|
for i in range(8): |
|
bits.append((byte >> (7-i)) & 1) |
|
|
|
result, output_bits = manager.inference(bits) |
|
except Exception as e: |
|
result = f"β File processing failed: {str(e)}" |
|
output_bits = None |
|
else: |
|
result, output_bits = manager.inference(bits_input) |
|
|
|
return result, str(output_bits) if output_bits else "" |
|
|
|
def long_inference_callback(bits_input, ctx_bits, overlap): |
|
"""Long sequence inference callback.""" |
|
result, output_bits = manager.inference(bits_input, long_inference=True, |
|
ctx_bits=ctx_bits, overlap=overlap) |
|
return result, str(output_bits) if output_bits else "" |
|
|
|
def text_inference_callback(text_input): |
|
"""Text-to-text inference callback.""" |
|
result, output_text = manager.text_inference(text_input) |
|
return result, output_text if output_text else "" |
|
|
|
|
|
with gr.Blocks(title="BitTransformerLM Dashboard", |
|
theme=gr.themes.Soft()) as interface: |
|
|
|
gr.Markdown("# π€ BitTransformerLM Interactive Dashboard") |
|
gr.Markdown("*Experimental bit-native transformer with comprehensive training and inference capabilities*") |
|
|
|
with gr.Tab("ποΈ Model Configuration"): |
|
gr.Markdown("## Initialize BitTransformerLM") |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
d_model = gr.Number(label="d_model", value=64, info="Model width") |
|
nhead = gr.Number(label="nhead", value=4, info="Attention heads") |
|
num_layers = gr.Number(label="num_layers", value=2, info="Transformer layers") |
|
dim_feedforward = gr.Number(label="dim_feedforward", value=256, info="FFN dimension") |
|
|
|
with gr.Column(): |
|
max_seq_len = gr.Number(label="max_seq_len", value=512, info="Max sequence length") |
|
chunk_size = gr.Number(label="chunk_size", value=0, info="Chunk size (0=auto)") |
|
overlap = gr.Number(label="overlap", value=64, info="Sliding window overlap") |
|
act_threshold = gr.Number(label="act_threshold", value=0.95, info="ACT halt threshold") |
|
|
|
with gr.Row(): |
|
reversible = gr.Checkbox(label="Reversible Layers", value=False) |
|
use_checkpoint = gr.Checkbox(label="Gradient Checkpointing", value=True) |
|
|
|
with gr.Row(): |
|
c_floor = gr.Number(label="c_floor", value=0.3, info="Complexity safety floor") |
|
s_floor = gr.Number(label="s_floor", value=0.5, info="Symbiosis safety floor") |
|
|
|
            init_btn = gr.Button("Initialize Model", variant="primary")
|
init_output = gr.Textbox(label="Initialization Result", interactive=False) |
|
|
|
with gr.Tab("π― Training"): |
|
gr.Markdown("## Train BitTransformerLM") |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
train_bits = gr.Textbox( |
|
label="Bit Input", |
|
placeholder="0 1 0 1 or [0,1,0,1] or upload file", |
|
lines=3 |
|
) |
|
train_file = gr.File(label="Upload Training File", file_types=[".txt", ".md", ".bin"]) |
|
train_epochs = gr.Number(label="Epochs", value=1, minimum=1) |
|
|
|
with gr.Column(): |
|
                    train_btn = gr.Button("Start Training", variant="primary")
|
train_output = gr.Textbox(label="Training Result", interactive=False) |
|
compression_output = gr.Textbox(label="Compression Info", interactive=False) |
|
|
|
with gr.Tab("π§ Inference"): |
|
with gr.Tab("Standard Inference"): |
|
gr.Markdown("## Standard Inference") |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
infer_bits = gr.Textbox( |
|
label="Bit Input", |
|
placeholder="0 1 0 1 or [0,1,0,1]", |
|
lines=3 |
|
) |
|
infer_file = gr.File(label="Upload Inference File") |
|
|
|
with gr.Column(): |
|
                        infer_btn = gr.Button("Run Inference", variant="primary")
|
infer_result = gr.Textbox(label="Result", interactive=False) |
|
infer_output = gr.Textbox(label="Output Bits", lines=5, interactive=False) |
|
|
|
with gr.Tab("Long Sequence Inference"): |
|
gr.Markdown("## Long Sequence Inference") |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
long_bits = gr.Textbox( |
|
label="Long Bit Sequence", |
|
lines=5, |
|
placeholder="Long sequence of bits..." |
|
) |
|
long_ctx_bits = gr.Number(label="Context Bits", value=4096) |
|
long_overlap = gr.Number(label="Overlap", value=256) |
|
|
|
with gr.Column(): |
|
                        long_infer_btn = gr.Button("Run Long Inference", variant="primary")
|
long_result = gr.Textbox(label="Result", interactive=False) |
|
long_output = gr.Textbox(label="Output Bits", lines=5, interactive=False) |
|
|
|
with gr.Tab("Text Inference"): |
|
gr.Markdown("## Text-to-Text Inference") |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
text_input = gr.Textbox( |
|
label="Input Text", |
|
placeholder="Enter text to process...", |
|
lines=3 |
|
) |
|
                        text_infer_btn = gr.Button("Process Text", variant="primary")
|
|
|
with gr.Column(): |
|
text_result = gr.Textbox(label="Result", interactive=False) |
|
text_output = gr.Textbox( |
|
label="Output Text", |
|
lines=5, |
|
interactive=False |
|
) |
|
|
|
with gr.Tab("βοΈ Model Operations"): |
|
with gr.Tab("Scale Model"): |
|
gr.Markdown("## Scale Model Width") |
|
|
|
with gr.Row(): |
|
width_mult = gr.Number(label="Width Multiplier", value=1.5, step=0.1) |
|
                    scale_btn = gr.Button("Scale Model", variant="secondary")
|
|
|
scale_output = gr.Textbox(label="Scaling Result", interactive=False) |
|
|
|
with gr.Tab("Collapse Model"): |
|
gr.Markdown("## Collapse Submodel") |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
cluster_bits = gr.Textbox( |
|
label="Cluster Bits (JSON)", |
|
placeholder='[[0,1,0,1],[1,1,0,0]]', |
|
lines=3 |
|
) |
|
target_params = gr.Textbox( |
|
label="Target Parameters (JSON)", |
|
placeholder='{"d_model":32,"nhead":4,"num_layers":1}', |
|
lines=3 |
|
) |
|
width_scale = gr.Number(label="Width Scale", value=1.0, step=0.1) |
|
|
|
with gr.Column(): |
|
                        collapse_btn = gr.Button("Collapse Model", variant="secondary")
|
collapse_output = gr.Textbox(label="Collapse Result", interactive=False) |
|
|
|
with gr.Tab("π Monitoring"): |
|
with gr.Row(): |
|
with gr.Column(): |
|
gr.Markdown("## Model Status") |
|
status_output = gr.Code(label="Current Status", language="json") |
|
                    refresh_btn = gr.Button("Refresh Status")
|
|
|
with gr.Column(): |
|
gr.Markdown("## System Settings") |
|
|
|
with gr.Row(): |
|
                        gpu_checkbox = gr.Checkbox(label="Enable GPU/FSDP", value=False)
                        qat_checkbox = gr.Checkbox(label="Enable 4-bit QAT", value=False)
|
|
|
with gr.Row(): |
|
                        compression_checkbox = gr.Checkbox(label="Enable Compression", value=False)
                        diffusion_checkbox = gr.Checkbox(label="Enable Diffusion Mode", value=False)
|
|
|
gr.Markdown("## π Telemetry Metrics") |
|
telemetry_plot = gr.Plot(label="K/C/S Metrics Over Time") |
|
|
|
with gr.Tab("βοΈ HuggingFace Integration"): |
|
gr.Markdown("## HuggingFace Model Hub") |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
hf_repo_id = gr.Textbox(label="Repository ID", placeholder="username/model-name") |
|
hf_token = gr.Textbox(label="HF Token (optional)", type="password") |
|
|
|
with gr.Column(): |
|
with gr.Row(): |
|
                        hf_upload_btn = gr.Button("Upload to HF", variant="secondary")
                        hf_download_btn = gr.Button("Download from HF", variant="secondary")
|
|
|
hf_result = gr.Textbox(label="HuggingFace Result", interactive=False) |
|
|
|
|
|
init_btn.click( |
|
init_model_callback, |
|
inputs=[d_model, nhead, num_layers, dim_feedforward, max_seq_len, |
|
chunk_size, overlap, reversible, use_checkpoint, act_threshold, |
|
c_floor, s_floor], |
|
outputs=[init_output, status_output, telemetry_plot] |
|
) |
|
|
|
train_btn.click( |
|
train_callback, |
|
inputs=[train_bits, train_epochs, train_file], |
|
outputs=[train_output, status_output, telemetry_plot, compression_output] |
|
) |
|
|
|
infer_btn.click( |
|
inference_callback, |
|
inputs=[infer_bits, infer_file], |
|
outputs=[infer_result, infer_output] |
|
) |
|
|
|
long_infer_btn.click( |
|
long_inference_callback, |
|
inputs=[long_bits, long_ctx_bits, long_overlap], |
|
outputs=[long_result, long_output] |
|
) |
|
|
|
text_infer_btn.click( |
|
text_inference_callback, |
|
inputs=[text_input], |
|
outputs=[text_result, text_output] |
|
) |
|
|
|
scale_btn.click( |
|
manager.scale_model, |
|
inputs=[width_mult], |
|
outputs=[scale_output] |
|
) |
|
|
|
collapse_btn.click( |
|
manager.collapse_model, |
|
inputs=[cluster_bits, target_params, width_scale], |
|
outputs=[collapse_output] |
|
) |
|
|
|
refresh_btn.click( |
|
manager.get_model_status, |
|
outputs=[status_output] |
|
) |
|
|
|
hf_upload_btn.click( |
|
manager.huggingface_upload, |
|
inputs=[hf_repo_id, hf_token], |
|
outputs=[hf_result] |
|
) |
|
|
|
hf_download_btn.click( |
|
manager.huggingface_download, |
|
inputs=[hf_repo_id, hf_token], |
|
outputs=[hf_result] |
|
) |
|
|
|
|
|
        def update_gpu_setting(enabled):
            manager.gpu_enabled = enabled

        def update_qat_setting(enabled):
            manager.qat_enabled = enabled

        def update_compression_setting(enabled):
            manager.compression_enabled = enabled

        def update_diffusion_setting(enabled):
            manager.diffusion_enabled = enabled

        # Wire the checkboxes to the manager flags so toggling them actually takes effect.
        gpu_checkbox.change(update_gpu_setting, inputs=[gpu_checkbox])
        qat_checkbox.change(update_qat_setting, inputs=[qat_checkbox])
        compression_checkbox.change(update_compression_setting, inputs=[compression_checkbox])
        diffusion_checkbox.change(update_diffusion_setting, inputs=[diffusion_checkbox])
|
|
|
|
|
interface.load( |
|
manager.get_telemetry_plot, |
|
outputs=[telemetry_plot], |
|
every=10 |
|
) |
|
|
|
|
|
interface.load( |
|
manager.get_model_status, |
|
outputs=[status_output] |
|
) |
|
|
|
return interface |
|
|
|
def run_gradio_server(host="127.0.0.1", port=7860, share=False): |
|
"""Run the Gradio server.""" |
|
interface = create_gradio_interface() |
|
|
|
print("π Starting BitTransformerLM Gradio Dashboard...") |
|
print(f"π Server will be available at: http://{host}:{port}") |
|
|
|
if os.getenv("MCP_SERVER_ADDR"): |
|
print(f"π MCP Server configured at: {os.getenv('MCP_SERVER_ADDR')}") |
|
|
|
interface.launch( |
|
server_name=host, |
|
server_port=port, |
|
share=share, |
|
show_error=True, |
|
debug=True |
|
) |
|
|
|
if __name__ == "__main__": |
|
|
|
if os.getenv("SPACE_ID"): |
|
|
|
print("π€ Running on HuggingFace Spaces") |
|
interface = create_gradio_interface() |
|
interface.launch() |
|
else: |
|
|
|
import argparse |
|
parser = argparse.ArgumentParser(description="BitTransformerLM Gradio Dashboard") |
|
parser.add_argument("--host", default="127.0.0.1", help="Host address") |
|
parser.add_argument("--port", type=int, default=7860, help="Port number") |
|
parser.add_argument("--share", action="store_true", help="Enable sharing") |
|
|
|
args = parser.parse_args() |
|
run_gradio_server(args.host, args.port, args.share) |