import io
import json
import os
import traceback
import inspect
from typing import Any, Dict, List, Optional, Tuple, Union

from flask import Flask, jsonify, request, render_template, send_file
import subprocess
import sys
import warnings
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import requests
import gzip

from .model import BitTransformerLM, infer_long_sequence
from .optimization import configure_optimizer
from .collapse import collapse_submodel
from .dashboard import plot_telemetry
from .scale import expand_model
from .bit_io import text_to_bits, bits_to_text
from .safety import hil_safe_inference
from .compression import model_output_decompress, compress_bits
from .distributed import wrap_fsdp
from .training import train_loop
from .telemetry import detect_metric_drift
from .quantization import prepare_qat_fx, convert_qat_fx
from torch.distributed.fsdp import FullyShardedDataParallel
from .hf_checkpoint import hf_login, save_checkpoint, download_checkpoint

app = Flask(__name__)
app.config["MAX_CONTENT_LENGTH"] = 1 * 1024 * 1024  # 1MB upload limit
MCP_SERVER_ADDR = os.getenv("MCP_SERVER_ADDR")


@app.errorhandler(Exception)
def handle_exception(err):
    """Return JSON error responses with stack traces."""
    return (
        jsonify({"error": str(err), "trace": traceback.format_exc()}),
        getattr(err, "code", 500),
    )


class MetricDriftWarning(UserWarning):
    """Issued when telemetry metrics drift beyond the configured threshold."""


def _switch_torch(use_gpu: bool) -> None:
    """Install the appropriate PyTorch wheel and restart the process."""
    have_cuda = torch.version.cuda is not None
    if use_gpu == have_cuda:
        return
    wheel = "torch==2.7.1+cu118" if use_gpu else "torch==2.7.1+cpu"
    url = (
        "https://download.pytorch.org/whl/cu118"
        if use_gpu
        else "https://download.pytorch.org/whl/cpu"
    )
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--extra-index-url", url, wheel],
        check=True,
    )
    os.execv(sys.executable, [sys.executable] + sys.argv)


def mcp_post(path: str, data=None):
    """Proxy a POST request to the configured MCP server, if any."""
    if not MCP_SERVER_ADDR:
        return None
    url = MCP_SERVER_ADDR.rstrip("/") + path
    resp = requests.post(url, json=data)
    resp.raise_for_status()
    if resp.headers.get("Content-Type", "").startswith("image/"):
        return resp.content
    return resp.json()


def mcp_get(path: str):
    """Proxy a GET request to the configured MCP server, if any."""
    if not MCP_SERVER_ADDR:
        return None
    url = MCP_SERVER_ADDR.rstrip("/") + path
    resp = requests.get(url)
    resp.raise_for_status()
    if resp.headers.get("Content-Type", "").startswith("image/"):
        return resp.content
    return resp.json()
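# Illustrative only: when MCP_SERVER_ADDR is set, every route below forwards to that
# server via mcp_post/mcp_get instead of touching the local ModelManager.  A
# hypothetical deployment might look like this (host names are examples, not defaults):
#
#   export MCP_SERVER_ADDR=http://mcp.internal:9000
#   curl -X POST http://localhost:5000/infer_text \
#        -H "Content-Type: application/json" -d '{"text": "hello"}'
#
# The dashboard relays whatever JSON (or image bytes) the MCP server returns.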
class ModelManager:
    """Manage model state and training utilities for the dashboard."""

    def __init__(
        self,
        snapshot_dir: Optional[str] = None,
        telemetry_log: Optional[str] = None,
        *,
        drift_window: int = 10,
        drift_threshold: float = 0.2,
    ) -> None:
        self.snapshot_dir = snapshot_dir or os.getenv("SNAPSHOT_DIR", "snapshots")
        self.telemetry_log = telemetry_log or os.getenv("TELEMETRY_LOG")
        if self.telemetry_log is None:
            self.telemetry_log = os.path.join(self.snapshot_dir, "metrics.json")
        os.makedirs(self.snapshot_dir, exist_ok=True)
        self.weights_path = os.path.join(self.snapshot_dir, "model.pt")
        self.model: Optional[BitTransformerLM] = None
        self.optimizer: Optional[torch.optim.Optimizer] = None
        self.scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None
        self.total_steps = 100
        self.metrics: Dict[str, List[float]] = {
            "negentropy_logits": [],
            "lz_complexity_logits": [],
            "symbiosis_score": [],
        }
        self.drift_window = drift_window
        self.drift_threshold = drift_threshold
        self.lambda_K = 1.0
        self.lambda_C = 1.0
        self.lambda_S = 1.0
        self.c_floor = 0.3
        self.s_floor = 0.5
        self.causal = True
        self.diffusion = False
        self.decompress_output = False
        self.use_compression = False
        self.use_gpu = False
        self.qat = False

        # Load any existing state
        if os.path.exists(self.telemetry_log):
            try:
                with open(self.telemetry_log) as f:
                    saved = json.load(f)
                for key in self.metrics:
                    self.metrics[key] = saved.get(key, [])
            except Exception:
                pass
        if os.path.exists(self.weights_path):
            try:
                self.model = torch.load(self.weights_path, map_location="cpu")
                self.optimizer, self.scheduler = configure_optimizer(
                    self.model, lr=1e-3, total_steps=self.total_steps
                )
                self._apply_device()
            except Exception:
                self.model = None
        config_path = os.getenv("MODEL_CONFIG", "/config/model_params.json")
        if self.model is None and os.path.exists(config_path):
            try:
                with open(config_path) as f:
                    params = json.load(f)
                self.init_model(params)
            except Exception:
                pass

    def init_model(self, params: Dict) -> None:
        int_fields = {
            "d_model",
            "nhead",
            "num_layers",
            "dim_feedforward",
            "max_seq_len",
            "chunk_size",
            "overlap",
        }
        float_fields = {"act_threshold"}
        bool_fields = {"reversible", "use_checkpoint"}
        clean: Dict[str, Any] = {}
        for k, v in params.items():
            if v is None:
                clean[k] = None
            elif k in int_fields:
                clean[k] = int(v)
            elif k in float_fields:
                clean[k] = float(v)
            elif k in bool_fields:
                clean[k] = bool(v)
            else:
                clean[k] = v
        self.model = BitTransformerLM(
            **clean,
            lambda_K=self.lambda_K,
            lambda_C=self.lambda_C,
            lambda_S=self.lambda_S,
        )
        self.optimizer, self.scheduler = configure_optimizer(
            self.model, lr=1e-3, total_steps=self.total_steps
        )
        self._apply_device()
        for key in self.metrics:
            self.metrics[key].clear()

    def set_lambdas(self, k: float, c: float, s: float) -> None:
        """Update λ weights and propagate to the model."""
        self.lambda_K = k
        self.lambda_C = c
        self.lambda_S = s
        if self.model is not None:
            self.model.set_lambdas(k, c, s)

    def set_floors(self, c_floor: float, s_floor: float) -> None:
        """Update safety floors for complexity (C) and symbiosis (S)."""
        self.c_floor = c_floor
        self.s_floor = s_floor

    def set_diffusion(self, flag: bool) -> None:
        """Toggle Diffusion LM mode, which disables causal masking and chunking."""
        self.diffusion = flag
        self.causal = not flag
        if self.model is not None and flag:
            self.model.chunk_size = None

    def set_decompress_output(self, flag: bool) -> None:
        """Enable or disable decompression of model outputs."""
        self.decompress_output = flag

    def set_compression(self, flag: bool) -> None:
        """Toggle automatic compression of inputs."""
        self.use_compression = flag

    def set_qat(self, flag: bool) -> None:
        """Enable or disable 4-bit quantization-aware training."""
        self.qat = flag
        if self.model is None:
            return
        if flag:
            self.model = prepare_qat_fx(self.model)
        else:
            self.model = convert_qat_fx(self.model)

    def set_gpu(self, flag: bool) -> None:
        """Toggle GPU acceleration and FSDP, reinstalling PyTorch if needed."""
        _switch_torch(flag)
        self.use_gpu = flag and torch.cuda.is_available()
        self._apply_device()

    def _apply_device(self) -> None:
        """Move the model to the selected device and wrap with FSDP if needed."""
        if self.model is None:
            return
        if self.use_gpu:
            device = torch.device("cuda")
            if isinstance(self.model, FullyShardedDataParallel):
                base = self.model.module
            else:
                base = self.model
            base = base.to(device)
            self.model = wrap_fsdp(base, device_id=device)
        else:
            device = torch.device("cpu")
            if isinstance(self.model, FullyShardedDataParallel):
                self.model = self.model.module
            self.model = self.model.to(device)
    def train_step(self, bits: torch.Tensor) -> Tuple[float, float]:
        """Run a single optimization step and return ``(loss, compression_ratio)``."""
        assert (
            self.model is not None
            and self.optimizer is not None
            and self.scheduler is not None
        )
        self.model.train()
        device = next(self.model.parameters()).device
        bits = bits.to(device)
        ratio = 1.0
        if self.use_compression:
            comps = [compress_bits(row.to(torch.uint8)) for row in bits]
            comp_len = sum(c.numel() for c in comps)
            ratio = min(comp_len / bits.numel(), 1.0)
            logits, telemetry = self.model.forward_compressed(comps, causal=self.causal)
        else:
            logits, telemetry = self.model(bits, causal=self.causal)
        pred = logits[:, :-1, :].reshape(-1, 2)
        target = bits[:, 1:].reshape(-1)
        loss = F.cross_entropy(pred, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad()
        self._log_metrics(telemetry)
        self._save_state()
        return loss.item(), ratio

    def train_epochs(
        self,
        bits: torch.Tensor,
        *,
        epochs: int = 1,
        compress_prob: float = 0.5,
        direct_prob: float = 0.0,
        batch_size: int = 8,
        num_workers: int = 0,
        accum_steps: int = 1,
        amp: bool = False,
        compile_model: bool = False,
    ) -> List[Dict[str, float]]:
        """Run ``train_loop`` on a batch tensor and persist the state."""
        assert self.model is not None
        device = next(self.model.parameters()).device
        bits = bits.to(device)
        import math

        steps_per_epoch = max(1, math.ceil(len(bits) / batch_size))
        self.total_steps = math.ceil(epochs * steps_per_epoch / accum_steps)
        self.optimizer, self.scheduler = configure_optimizer(
            self.model, lr=1e-3, total_steps=self.total_steps
        )
        metrics = train_loop(
            self.model,
            bits,
            epochs=epochs,
            compress_prob=compress_prob if self.use_compression else 0.0,
            direct_prob=direct_prob,
            batch_size=batch_size,
            num_workers=num_workers,
            accum_steps=accum_steps,
            amp=amp,
            compile_model=compile_model,
            forward_kwargs={"causal": self.causal},
            optimizer=self.optimizer,
            scheduler=self.scheduler,
        )
        self._save_state()
        return metrics

    def scale_up(self, width_mult: float = 1.0) -> None:
        """Double the layer count, rescale widths, and rebuild the optimizer."""
        assert self.model is not None
        params = dict(
            d_model=int(self.model.d_model * width_mult),
            nhead=self.model.layers[0].self_attn.num_heads,
            num_layers=self.model.num_layers * 2,
            dim_feedforward=int(self.model.layers[0].linear1.out_features * width_mult),
            max_seq_len=self.model.pos_enc.pe.size(0),
        )
        self.model = expand_model(
            self.model,
            {
                **params,
                "lambda_K": self.lambda_K,
                "lambda_C": self.lambda_C,
                "lambda_S": self.lambda_S,
            },
        )
        self.optimizer, self.scheduler = configure_optimizer(
            self.model, lr=1e-3, total_steps=self.total_steps
        )
        self._save_state()

    def collapse(
        self,
        cluster_bits: List[List[int]],
        target_params: Dict,
        width_scale: float = 1.0,
    ) -> None:
        """Replace the model with a collapsed submodel built from ``cluster_bits``."""
        self.model, _ = collapse_submodel(
            cluster_bits,
            target_params,
            width_scale=width_scale,
            forward_kwargs={"causal": self.causal},
        )
        self.model.set_lambdas(self.lambda_K, self.lambda_C, self.lambda_S)
        self.optimizer, self.scheduler = configure_optimizer(
            self.model, lr=1e-3, total_steps=self.total_steps
        )
        self._apply_device()
        for key in self.metrics:
            self.metrics[key].clear()
    def infer(self, bits: torch.Tensor) -> Dict:
        """Run a forward pass and return predictions, telemetry and compression ratio."""
        assert self.model is not None
        self.model.eval()
        device = next(self.model.parameters()).device
        bits = bits.to(device)
        ratio = 1.0
        with torch.no_grad():
            if self.use_compression:
                comps = [compress_bits(row.to(torch.uint8)) for row in bits]
                comp_len = sum(c.numel() for c in comps)
                ratio = min(comp_len / bits.numel(), 1.0)
                logits, telemetry = self.model.forward_compressed(comps, causal=self.causal)
            else:
                logits, telemetry = self.model(bits, causal=self.causal)
        self._log_metrics(telemetry)
        pred_bits = logits.argmax(-1)
        if self.decompress_output:
            try:
                pred_bits = model_output_decompress(pred_bits)
            except Exception as e:
                return {
                    "error": f"Decompression failed: {e}",
                    "suggestion": "Disable compression toggle.",
                }

        def _to_python(obj):
            if isinstance(obj, torch.Tensor):
                return obj.tolist()
            if isinstance(obj, list):
                return [_to_python(o) for o in obj]
            if isinstance(obj, dict):
                return {kk: _to_python(vv) for kk, vv in obj.items()}
            return obj

        tele = {k: _to_python(v) for k, v in telemetry.items()}
        return {
            "predicted": pred_bits.squeeze(0).tolist(),
            "telemetry": tele,
            "ratio": ratio,
        }

    def infer_long(
        self, bits: torch.Tensor, ctx_bits: int = 4096, overlap: int = 256
    ) -> Dict:
        """Run sliding-window inference on a long sequence."""
        assert self.model is not None
        device = next(self.model.parameters()).device
        bits = bits.to(device)
        preds, logs = infer_long_sequence(
            self.model, bits.squeeze(0), ctx_bits=ctx_bits, overlap=overlap
        )
        for tele in logs:
            self._log_metrics(tele)
        return {"predicted": preds.tolist(), "windows": len(logs)}

    def _log_metrics(self, telemetry: Dict) -> None:
        """Append telemetry means to the history and warn on metric drift."""
        for key in self.metrics:
            val = telemetry[key].mean().item()
            self.metrics[key].append(val)
        drift = detect_metric_drift(
            self.metrics, window=self.drift_window, threshold=self.drift_threshold
        )
        bad = [k for k, v in drift.items() if v]
        if bad:
            warnings.warn(
                f"Metric drift detected: {', '.join(bad)}",
                MetricDriftWarning,
            )

    def infer_text(self, text: str) -> Dict[str, Any]:
        """Run text through the model using the safety gate."""
        assert self.model is not None
        device = next(self.model.parameters()).device
        bits = torch.tensor(text_to_bits(text), dtype=torch.long).unsqueeze(0).to(device)
        out_bits, telemetry = hil_safe_inference(
            self.model, bits, c_floor=self.c_floor, s_floor=self.s_floor
        )
        self._log_metrics(telemetry)
        return {
            "output": bits_to_text(out_bits.squeeze(0).tolist()),
            "telemetry": telemetry,
        }

    def get_status(self) -> Dict[str, Any]:
        """Return runtime flags plus basic model shape information."""
        info: Dict[str, Any] = {
            "use_gpu": self.use_gpu,
            "diffusion": self.diffusion,
            "compression": self.use_compression,
            "lambda_K": self.lambda_K,
            "lambda_C": self.lambda_C,
            "lambda_S": self.lambda_S,
            "c_floor": self.c_floor,
            "s_floor": self.s_floor,
            "qat": self.qat,
        }
        if self.model is not None:
            info.update(
                {
                    "d_model": self.model.d_model,
                    "num_layers": self.model.num_layers,
                    "d_ff": self.model.layers[0].linear1.out_features,
                    "nhead": self.model.layers[0].self_attn.num_heads,
                    "max_seq_len": self.model.pos_enc.pe.size(0),
                }
            )
        else:
            info.update(
                {
                    "d_model": None,
                    "num_layers": 0,
                    "d_ff": None,
                    "nhead": None,
                    "max_seq_len": None,
                }
            )
        return info

    def get_model_config(self) -> Dict[str, Any]:
        """Return current model hyperparameters and safety settings."""
        cfg: Dict[str, Any] = {
            "lambda_K": self.lambda_K,
            "lambda_C": self.lambda_C,
            "lambda_S": self.lambda_S,
            "c_floor": self.c_floor,
            "s_floor": self.s_floor,
        }
        if self.model is not None:
            cfg.update(
                {
                    "d_model": self.model.d_model,
                    "nhead": self.model.layers[0].self_attn.num_heads,
                    "num_layers": self.model.num_layers,
                    "dim_feedforward": self.model.layers[0].linear1.out_features,
                    "max_seq_len": self.model.pos_enc.pe.size(0),
                    "chunk_size": self.model.chunk_size,
                    "reversible": self.model.reversible,
                    "use_checkpoint": self.model.use_checkpoint,
                }
            )
        else:
            cfg.update(
                {
                    "d_model": None,
                    "nhead": None,
                    "num_layers": 0,
                    "dim_feedforward": None,
                    "max_seq_len": None,
                    "chunk_size": None,
                    "reversible": None,
                    "use_checkpoint": None,
                }
            )
        return cfg

    def get_metrics(self) -> Dict[str, Any]:
        """Return logged telemetry metrics with summary statistics."""
        from statistics import mean, stdev

        data = {
            "negentropy": self.metrics["negentropy_logits"],
            "lz_complexity": self.metrics["lz_complexity_logits"],
            "symbiosis": self.metrics["symbiosis_score"],
        }
        summary: Dict[str, Dict[str, Optional[float]]] = {}
        for key, values in data.items():
            if values:
                m = mean(values)
                s = stdev(values) if len(values) > 1 else 0.0
                summary[key] = {"mean": m, "std": s}
            else:
                summary[key] = {"mean": None, "std": None}
        data["summary"] = summary
        return data

    def _save_state(self) -> None:
        """Persist model weights and telemetry history to the snapshot directory."""
        if self.model is None:
            return
        torch.save(self.model, self.weights_path)
        with open(self.telemetry_log, "w") as f:
            json.dump(self.metrics, f)
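# A minimal sketch (never called by the server) of driving ``ModelManager`` directly
# from Python instead of through the HTTP routes below.  The hyperparameter values
# and the snapshot path are illustrative, not project defaults.
def _example_manager_usage() -> None:  # pragma: no cover - illustrative only
    mgr = ModelManager(snapshot_dir="/tmp/bitlm_snapshots")
    mgr.init_model(
        {"d_model": 32, "nhead": 4, "num_layers": 1, "dim_feedforward": 64, "max_seq_len": 64}
    )
    bits = torch.randint(0, 2, (2, 64), dtype=torch.long)  # random bit batch
    loss, ratio = mgr.train_step(bits)  # one optimization step
    result = mgr.infer(bits[:1])  # predictions plus telemetry
    print(loss, ratio, result["ratio"])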
manager: Optional[ModelManager] = None


@app.route("/")
def index():
    return render_template(
        "dashboard.html",
        metrics=manager.metrics,
        lambdas={
            "lambda_K": manager.lambda_K,
            "lambda_C": manager.lambda_C,
            "lambda_S": manager.lambda_S,
        },
        diffusion=manager.diffusion,
        compression=manager.use_compression,
        defaults={
            k: v.default
            for k, v in inspect.signature(BitTransformerLM.__init__).parameters.items()
            if v.default is not inspect._empty
        },
        c_floor=manager.c_floor,
        s_floor=manager.s_floor,
        qat=manager.qat,
    )


@app.route("/status", methods=["GET"])
def status():
    if MCP_SERVER_ADDR:
        return jsonify(mcp_get("/status"))
    return jsonify(manager.get_status())


@app.route("/model_config", methods=["GET"])
def model_config():
    if MCP_SERVER_ADDR:
        return jsonify(mcp_get("/model_config"))
    return jsonify(manager.get_model_config())


@app.route("/metrics", methods=["GET"])
def metrics():
    if MCP_SERVER_ADDR:
        return jsonify(mcp_get("/metrics"))
    return jsonify(manager.get_metrics())


@app.route("/save_checkpoint", methods=["POST"])
def save_checkpoint_route():
    repo_id = request.json.get("repo_id")
    token = request.json.get("token") or os.getenv("HF_TOKEN")
    if MCP_SERVER_ADDR:
        return jsonify(mcp_post("/save_checkpoint", {"repo_id": repo_id, "token": token}))
    if manager.model is None:
        return jsonify({"error": "model not initialized"}), 400
    if token:
        hf_login(token=token)
    save_checkpoint(manager.model, repo_id=repo_id)
    return jsonify({"status": "saved"})


@app.route("/download_checkpoint", methods=["POST"])
def download_checkpoint_route():
    repo_id = request.json.get("repo_id")
    token = request.json.get("token") or os.getenv("HF_TOKEN")
    if MCP_SERVER_ADDR:
        return jsonify(mcp_post("/download_checkpoint", {"repo_id": repo_id, "token": token}))
    if token:
        hf_login(token=token)
    dest = manager.weights_path + ".gz"
    ok = download_checkpoint(dest, repo_id=repo_id)
    if not ok:
        return jsonify({"status": "failed"}), 500
    if manager.model is None:
        return jsonify({"status": "downloaded", "loaded": False})
    with gzip.open(dest, "rb") as f:
        state = torch.load(f, map_location="cpu")
    manager.model.load_state_dict(state)
    manager.optimizer, manager.scheduler = configure_optimizer(
        manager.model, lr=1e-3, total_steps=manager.total_steps
    )
    manager._apply_device()
    manager._save_state()
    return jsonify({"status": "downloaded", "loaded": True})


@app.route("/text_to_bits", methods=["POST"])
def text_to_bits_route():
    text = request.json.get("text", "")
    if len(text) > 100_000:
        return jsonify({"error": "text too large"}), 413
    return jsonify({"bits": text_to_bits(text)})
"wikitext-2-raw-v1", split=split) lines = [t for t in ds["text"] if t.strip()][:size] except Exception: bits = torch.randint(0, 2, (size, seq_len), dtype=torch.long) return jsonify({"bits": bits.tolist()}) bits_list = [] for text in lines: b = text_to_bits(text)[:seq_len] if len(b) < seq_len: b.extend([0] * (seq_len - len(b))) bits_list.append(b) if len(bits_list) < size: pad = size - len(bits_list) bits_list.extend(torch.randint(0, 2, (pad, seq_len), dtype=torch.long).tolist()) return jsonify({"bits": bits_list}) return jsonify({"error": "unknown dataset"}), 400 @app.route("/init", methods=["POST"]) def init_model(): data = request.json or {} int_fields = { "d_model", "nhead", "num_layers", "dim_feedforward", "max_seq_len", "chunk_size", "overlap", } float_fields = {"act_threshold"} bool_fields = {"reversible", "use_checkpoint"} params = {} for k, v in data.items(): if v is None: params[k] = None elif k in int_fields: params[k] = int(v) elif k in float_fields: params[k] = float(v) elif k in bool_fields: params[k] = bool(v) else: params[k] = v if MCP_SERVER_ADDR: data = mcp_post("/init", params) return jsonify(data) manager.init_model(params) return jsonify({"status": "initialized", "params": params}) @app.route("/train", methods=["POST"]) def train_model(): bits = torch.tensor(request.json["bits"], dtype=torch.long) if MCP_SERVER_ADDR: data = mcp_post("/train", {"bits": request.json["bits"]}) return jsonify(data) loss, ratio = manager.train_step(bits) return jsonify({"loss": loss, "ratio": ratio}) @app.route("/train_epochs", methods=["POST"]) def train_epochs_route(): bits = torch.tensor(request.json["bits"], dtype=torch.long) epochs = int(request.json.get("epochs", 1)) compress_prob = float(request.json.get("compress_prob", 0.5)) direct_prob = float(request.json.get("direct_prob", 0.0)) if MCP_SERVER_ADDR: data = mcp_post( "/train_epochs", { "bits": request.json["bits"], "epochs": epochs, "compress_prob": compress_prob, "direct_prob": direct_prob, }, ) return jsonify(data) metrics = manager.train_epochs( bits, epochs=epochs, compress_prob=compress_prob, direct_prob=direct_prob, ) return jsonify({"metrics": metrics}) @app.route("/scale_up", methods=["POST"]) def scale_up(): width_mult = float(request.json.get("width_mult", 1.0)) if MCP_SERVER_ADDR: data = mcp_post("/scale_up", {"width_mult": width_mult}) return jsonify(data) manager.scale_up(width_mult) return jsonify({ "status": "scaled", "layers": manager.model.num_layers, "d_model": manager.model.d_model, }) @app.route("/collapse", methods=["POST"]) def collapse_model(): cluster_bits = request.json["clusters"] params = {k: int(v) for k, v in request.json["params"].items()} width_scale = float(request.json.get("width_scale", 1.0)) if MCP_SERVER_ADDR: data = mcp_post( "/collapse", {"clusters": cluster_bits, "params": params, "width_scale": width_scale}, ) return jsonify(data) manager.collapse(cluster_bits, params, width_scale) return jsonify({"status": "collapsed"}) @app.route("/lambdas", methods=["GET", "POST"]) def update_lambdas(): if request.method == "POST": data = request.json if MCP_SERVER_ADDR: res = mcp_post("/lambdas", data) return jsonify(res) manager.set_lambdas( float(data["lambda_K"]), float(data["lambda_C"]), float(data["lambda_S"]) ) return jsonify({"status": "updated"}) else: if MCP_SERVER_ADDR: return jsonify(mcp_get("/lambdas")) return jsonify( { "lambda_K": manager.lambda_K, "lambda_C": manager.lambda_C, "lambda_S": manager.lambda_S, } ) @app.route("/config/telemetry", methods=["GET", "POST"]) def telemetry_config(): 
"""Get or update telemetry λ weights and safety floors.""" if request.method == "POST": data = request.json if MCP_SERVER_ADDR: res = mcp_post("/config/telemetry", data) return jsonify(res) manager.set_lambdas( float(data.get("lambda_K", manager.lambda_K)), float(data.get("lambda_C", manager.lambda_C)), float(data.get("lambda_S", manager.lambda_S)), ) manager.set_floors( float(data.get("c_floor", manager.c_floor)), float(data.get("s_floor", manager.s_floor)), ) return jsonify({"status": "updated"}) else: if MCP_SERVER_ADDR: return jsonify(mcp_get("/config/telemetry")) return jsonify( { "lambda_K": manager.lambda_K, "lambda_C": manager.lambda_C, "lambda_S": manager.lambda_S, "c_floor": manager.c_floor, "s_floor": manager.s_floor, } ) @app.route("/diffusion", methods=["GET", "POST"]) def update_diffusion(): if request.method == "POST": if MCP_SERVER_ADDR: return jsonify(mcp_post("/diffusion", request.json)) manager.set_diffusion(bool(request.json.get("diffusion", False))) return jsonify({"status": "updated"}) else: if MCP_SERVER_ADDR: return jsonify(mcp_get("/diffusion")) return jsonify({"diffusion": manager.diffusion}) @app.route("/gpu", methods=["GET", "POST"]) def update_gpu(): if request.method == "POST": if MCP_SERVER_ADDR: return jsonify(mcp_post("/gpu", request.json)) manager.set_gpu(bool(request.json.get("use_gpu", False))) return jsonify({"status": "updated"}) else: if MCP_SERVER_ADDR: return jsonify(mcp_get("/gpu")) return jsonify({"use_gpu": manager.use_gpu}) @app.route("/compression", methods=["GET", "POST"]) def update_compression(): if request.method == "POST": if MCP_SERVER_ADDR: return jsonify(mcp_post("/compression", request.json)) manager.set_compression(bool(request.json.get("compression", False))) return jsonify({"status": "updated"}) else: if MCP_SERVER_ADDR: return jsonify(mcp_get("/compression")) return jsonify({"compression": manager.use_compression}) @app.route("/qat", methods=["GET", "POST"]) def update_qat(): if request.method == "POST": if MCP_SERVER_ADDR: return jsonify(mcp_post("/qat", request.json)) manager.set_qat(bool(request.json.get("qat", False))) return jsonify({"status": "updated"}) else: if MCP_SERVER_ADDR: return jsonify(mcp_get("/qat")) return jsonify({"qat": manager.qat}) @app.route("/infer", methods=["POST"]) def inference(): bits = torch.tensor(request.json["bits"], dtype=torch.long) if MCP_SERVER_ADDR: data = mcp_post("/infer", {"bits": request.json["bits"]}) return jsonify(data) result = manager.infer(bits) return jsonify(result) @app.route("/infer_long", methods=["POST"]) def inference_long(): bits = torch.tensor(request.json["bits"], dtype=torch.long) ctx = int(request.json.get("ctx_bits", 4096)) overlap = int(request.json.get("overlap", 256)) if MCP_SERVER_ADDR: data = mcp_post( "/infer_long", {"bits": request.json["bits"], "ctx_bits": ctx, "overlap": overlap}, ) return jsonify(data) result = manager.infer_long(bits, ctx_bits=ctx, overlap=overlap) return jsonify(result) @app.route("/infer_text", methods=["POST"]) def inference_text(): text = request.json.get("text", "") if MCP_SERVER_ADDR: data = mcp_post("/infer_text", {"text": text}) return jsonify(data) result = manager.infer_text(text) return jsonify(result) @app.route("/plot.png") def plot_png(): if MCP_SERVER_ADDR: resp = requests.get(MCP_SERVER_ADDR.rstrip("/") + "/plot.png") resp.raise_for_status() return send_file(io.BytesIO(resp.content), mimetype="image/png") fig, _ = plot_telemetry(manager.metrics) buf = io.BytesIO() fig.savefig(buf, format="png") plt.close(fig) buf.seek(0) return 
def run_dashboard(
    host: Optional[str] = None,
    port: Optional[int] = None,
    snapshot_dir: Optional[str] = None,
    telemetry_log: Optional[str] = None,
) -> None:
    """Launch the Flask dashboard server."""
    env_host = os.getenv("HOST", "0.0.0.0")
    env_port = int(os.getenv("PORT", "5000"))
    host = host or env_host
    port = port or env_port
    global manager
    if manager is None:
        manager = ModelManager(snapshot_dir, telemetry_log)
    app.run(host=host, port=port, debug=True)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run dashboard server")
    parser.add_argument("--host", default=os.getenv("HOST", "0.0.0.0"))
    parser.add_argument("--port", type=int, default=int(os.getenv("PORT", "5000")))
    parser.add_argument("--snapshot-dir", default=os.getenv("SNAPSHOT_DIR", "snapshots"))
    parser.add_argument("--telemetry-log", default=os.getenv("TELEMETRY_LOG"))
    args = parser.parse_args()
    run_dashboard(args.host, args.port, args.snapshot_dir, args.telemetry_log)
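# A minimal launch sketch from another script (the module path is assumed; adjust it
# to wherever this file lives in the package):
#
#   from bit_transformer.dashboard_app import run_dashboard
#   run_dashboard(port=8080, snapshot_dir="/tmp/bitlm_snapshots")
#
# The HOST, PORT, SNAPSHOT_DIR, TELEMETRY_LOG and MCP_SERVER_ADDR environment
# variables provide the same configuration when using the CLI entry point above.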