|
import gzip
import inspect
import io
import json
import os
import subprocess
import sys
import traceback
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import requests
import torch
import torch.nn.functional as F
from flask import Flask, jsonify, request, render_template, send_file
from torch.distributed.fsdp import FullyShardedDataParallel

from .model import BitTransformerLM, infer_long_sequence
from .optimization import configure_optimizer
from .collapse import collapse_submodel
from .dashboard import plot_telemetry
from .scale import expand_model
from .bit_io import text_to_bits, bits_to_text
from .safety import hil_safe_inference
from .compression import model_output_decompress, compress_bits
from .distributed import wrap_fsdp
from .training import train_loop
from .telemetry import detect_metric_drift
from .quantization import prepare_qat_fx, convert_qat_fx
from .hf_checkpoint import hf_login, save_checkpoint, download_checkpoint


app = Flask(__name__)
app.config["MAX_CONTENT_LENGTH"] = 1 * 1024 * 1024

MCP_SERVER_ADDR = os.getenv("MCP_SERVER_ADDR")


@app.errorhandler(Exception)
def handle_exception(err):
    """Return JSON error responses with stack traces."""
    return (
        jsonify({"error": str(err), "trace": traceback.format_exc()}),
        getattr(err, "code", 500),
    )


class MetricDriftWarning(UserWarning):
    """Issued when telemetry metrics drift beyond the configured threshold."""


def _switch_torch(use_gpu: bool) -> None:
    """Install the appropriate PyTorch wheel and restart the process."""
    have_cuda = torch.version.cuda is not None
    if use_gpu == have_cuda:
        return
    wheel = "torch==2.7.1+cu118" if use_gpu else "torch==2.7.1+cpu"
    url = "https://download.pytorch.org/whl/cu118" if use_gpu else "https://download.pytorch.org/whl/cpu"
    subprocess.run([
        sys.executable,
        "-m",
        "pip",
        "install",
        "--extra-index-url",
        url,
        wheel,
    ], check=True)
    os.execv(sys.executable, [sys.executable] + sys.argv)


def mcp_post(path: str, data=None):
    """POST to the configured MCP server, returning JSON or raw image bytes."""
    if not MCP_SERVER_ADDR:
        return None
    url = MCP_SERVER_ADDR.rstrip("/") + path
    resp = requests.post(url, json=data)
    resp.raise_for_status()
    if resp.headers.get("Content-Type", "").startswith("image/"):
        return resp.content
    return resp.json()


def mcp_get(path: str):
    """GET from the configured MCP server, returning JSON or raw image bytes."""
    if not MCP_SERVER_ADDR:
        return None
    url = MCP_SERVER_ADDR.rstrip("/") + path
    resp = requests.get(url)
    resp.raise_for_status()
    if resp.headers.get("Content-Type", "").startswith("image/"):
        return resp.content
    return resp.json()
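# The two helpers above proxy dashboard requests to an external MCP server when
# MCP_SERVER_ADDR is set; otherwise every route falls back to the local
# ModelManager defined below.  A minimal sketch of the proxy behaviour
# (the address is illustrative, not part of this module):
#
#   os.environ["MCP_SERVER_ADDR"] = "http://localhost:8900"  # hypothetical address
#   mcp_get("/status")        # -> dict decoded from the server's JSON reply
#   mcp_post("/init", {...})  # -> dict, or raw bytes for image/* responses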
|
|
|
class ModelManager:
    """Manage model state and training utilities for the dashboard."""

    def __init__(
        self,
        snapshot_dir: Optional[str] = None,
        telemetry_log: Optional[str] = None,
        *,
        drift_window: int = 10,
        drift_threshold: float = 0.2,
    ) -> None:
        self.snapshot_dir = snapshot_dir or os.getenv("SNAPSHOT_DIR", "snapshots")
        self.telemetry_log = telemetry_log or os.getenv("TELEMETRY_LOG")
        if self.telemetry_log is None:
            self.telemetry_log = os.path.join(self.snapshot_dir, "metrics.json")
        os.makedirs(self.snapshot_dir, exist_ok=True)
        self.weights_path = os.path.join(self.snapshot_dir, "model.pt")

        self.model: Optional[BitTransformerLM] = None
        self.optimizer: Optional[torch.optim.Optimizer] = None
        self.scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None
        self.total_steps = 100
        self.metrics: Dict[str, List[float]] = {
            "negentropy_logits": [],
            "lz_complexity_logits": [],
            "symbiosis_score": [],
        }
        self.drift_window = drift_window
        self.drift_threshold = drift_threshold
        self.lambda_K = 1.0
        self.lambda_C = 1.0
        self.lambda_S = 1.0
        self.c_floor = 0.3
        self.s_floor = 0.5
        self.causal = True
        self.diffusion = False
        self.decompress_output = False
        self.use_compression = False
        self.use_gpu = False
        self.qat = False

        if os.path.exists(self.telemetry_log):
            try:
                with open(self.telemetry_log) as f:
                    saved = json.load(f)
                for key in self.metrics:
                    self.metrics[key] = saved.get(key, [])
            except Exception:
                pass
        if os.path.exists(self.weights_path):
            try:
                self.model = torch.load(self.weights_path, map_location="cpu")
                self.optimizer, self.scheduler = configure_optimizer(
                    self.model, lr=1e-3, total_steps=self.total_steps
                )
                self._apply_device()
            except Exception:
                self.model = None

        config_path = os.getenv("MODEL_CONFIG", "/config/model_params.json")
        if self.model is None and os.path.exists(config_path):
            try:
                with open(config_path) as f:
                    params = json.load(f)
                self.init_model(params)
            except Exception:
                pass

    def init_model(self, params: Dict) -> None:
        int_fields = {
            "d_model",
            "nhead",
            "num_layers",
            "dim_feedforward",
            "max_seq_len",
            "chunk_size",
            "overlap",
        }
        float_fields = {"act_threshold"}
        bool_fields = {"reversible", "use_checkpoint"}
        clean: Dict[str, Any] = {}
        for k, v in params.items():
            if v is None:
                clean[k] = None
            elif k in int_fields:
                clean[k] = int(v)
            elif k in float_fields:
                clean[k] = float(v)
            elif k in bool_fields:
                clean[k] = bool(v)
            else:
                clean[k] = v
        self.model = BitTransformerLM(
            **clean,
            lambda_K=self.lambda_K,
            lambda_C=self.lambda_C,
            lambda_S=self.lambda_S,
        )
        self.optimizer, self.scheduler = configure_optimizer(
            self.model, lr=1e-3, total_steps=self.total_steps
        )
        self._apply_device()
        for key in self.metrics:
            self.metrics[key].clear()
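    # Example usage (sketch; the hyperparameter values below are illustrative,
    # not defaults taken from BitTransformerLM):
    #
    #   mgr = ModelManager(snapshot_dir="snapshots")
    #   mgr.init_model({"d_model": 64, "nhead": 4, "num_layers": 2,
    #                   "dim_feedforward": 128, "max_seq_len": 256})
    #   mgr.set_lambdas(1.0, 1.0, 1.0)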
|
|
|
    def set_lambdas(self, k: float, c: float, s: float) -> None:
        """Update λ weights and propagate to the model."""
        self.lambda_K = k
        self.lambda_C = c
        self.lambda_S = s
        if self.model is not None:
            self.model.set_lambdas(k, c, s)

    def set_floors(self, c_floor: float, s_floor: float) -> None:
        """Update safety floors for complexity (C) and symbiosis (S)."""
        self.c_floor = c_floor
        self.s_floor = s_floor

    def set_diffusion(self, flag: bool) -> None:
        """Toggle diffusion LM mode, which disables causal masking and chunking."""
        self.diffusion = flag
        self.causal = not flag
        if self.model is not None and flag:
            self.model.chunk_size = None

    def set_decompress_output(self, flag: bool) -> None:
        """Enable or disable decompression of model outputs."""
        self.decompress_output = flag

    def set_compression(self, flag: bool) -> None:
        """Toggle automatic compression of inputs."""
        self.use_compression = flag

    def set_qat(self, flag: bool) -> None:
        """Enable or disable 4-bit quantization-aware training."""
        self.qat = flag
        if self.model is None:
            return
        if flag:
            self.model = prepare_qat_fx(self.model)
        else:
            self.model = convert_qat_fx(self.model)

    def set_gpu(self, flag: bool) -> None:
        """Toggle GPU acceleration and FSDP, reinstalling PyTorch if needed."""
        _switch_torch(flag)
        self.use_gpu = flag and torch.cuda.is_available()
        self._apply_device()

    def _apply_device(self) -> None:
        """Move the model to the selected device and wrap with FSDP if needed."""
        if self.model is None:
            return
        if self.use_gpu:
            device = torch.device("cuda")
            if isinstance(self.model, FullyShardedDataParallel):
                base = self.model.module
            else:
                base = self.model
            base = base.to(device)
            self.model = wrap_fsdp(base, device_id=device)
        else:
            device = torch.device("cpu")
            if isinstance(self.model, FullyShardedDataParallel):
                self.model = self.model.module
            self.model = self.model.to(device)

    def train_step(self, bits: torch.Tensor) -> Tuple[float, float]:
        """Run one optimization step and return the loss and compression ratio."""
        assert (
            self.model is not None
            and self.optimizer is not None
            and self.scheduler is not None
        )
        self.model.train()
        device = next(self.model.parameters()).device
        bits = bits.to(device)
        ratio = 1.0
        if self.use_compression:
            comps = [compress_bits(row.to(torch.uint8)) for row in bits]
            comp_len = sum(c.numel() for c in comps)
            ratio = min(comp_len / bits.numel(), 1.0)
            logits, telemetry = self.model.forward_compressed(comps, causal=self.causal)
        else:
            logits, telemetry = self.model(bits, causal=self.causal)
        pred = logits[:, :-1, :].reshape(-1, 2)
        target = bits[:, 1:].reshape(-1)
        loss = F.cross_entropy(pred, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad()
        self._log_metrics(telemetry)
        self._save_state()
        return loss.item(), ratio
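    # Example usage (sketch, continuing the ``mgr`` example above): a single
    # next-bit-prediction step on random bits; the sequence length is illustrative.
    #
    #   bits = torch.randint(0, 2, (1, 64), dtype=torch.long)
    #   loss, ratio = mgr.train_step(bits)  # ratio < 1.0 only when compression is on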
|
|
|
    def train_epochs(
        self,
        bits: torch.Tensor,
        *,
        epochs: int = 1,
        compress_prob: float = 0.5,
        direct_prob: float = 0.0,
        batch_size: int = 8,
        num_workers: int = 0,
        accum_steps: int = 1,
        amp: bool = False,
        compile_model: bool = False,
    ) -> List[Dict[str, float]]:
        """Run ``train_loop`` on a batch tensor and persist the state."""
        assert self.model is not None
        device = next(self.model.parameters()).device
        bits = bits.to(device)
        import math
        steps_per_epoch = max(1, math.ceil(len(bits) / batch_size))
        self.total_steps = math.ceil(epochs * steps_per_epoch / accum_steps)
        self.optimizer, self.scheduler = configure_optimizer(
            self.model, lr=1e-3, total_steps=self.total_steps
        )
        metrics = train_loop(
            self.model,
            bits,
            epochs=epochs,
            compress_prob=compress_prob if self.use_compression else 0.0,
            direct_prob=direct_prob,
            batch_size=batch_size,
            num_workers=num_workers,
            accum_steps=accum_steps,
            amp=amp,
            compile_model=compile_model,
            forward_kwargs={"causal": self.causal},
            optimizer=self.optimizer,
            scheduler=self.scheduler,
        )
        self._save_state()
        return metrics

    def scale_up(self, width_mult: float = 1.0) -> None:
        """Double the layer count and scale the width by ``width_mult``."""
        assert self.model is not None
        params = dict(
            d_model=int(self.model.d_model * width_mult),
            nhead=self.model.layers[0].self_attn.num_heads,
            num_layers=self.model.num_layers * 2,
            dim_feedforward=int(self.model.layers[0].linear1.out_features * width_mult),
            max_seq_len=self.model.pos_enc.pe.size(0),
        )
        self.model = expand_model(self.model, {
            **params,
            "lambda_K": self.lambda_K,
            "lambda_C": self.lambda_C,
            "lambda_S": self.lambda_S,
        })
        self.optimizer, self.scheduler = configure_optimizer(
            self.model, lr=1e-3, total_steps=self.total_steps
        )
        self._save_state()

    def collapse(self, cluster_bits: List[List[int]], target_params: Dict, width_scale: float = 1.0) -> None:
        """Replace the current model with a collapsed submodel built from ``cluster_bits``."""
        self.model, _ = collapse_submodel(
            cluster_bits,
            target_params,
            width_scale=width_scale,
            forward_kwargs={"causal": self.causal},
        )
        self.model.set_lambdas(self.lambda_K, self.lambda_C, self.lambda_S)
        self.optimizer, self.scheduler = configure_optimizer(
            self.model, lr=1e-3, total_steps=self.total_steps
        )
        self._apply_device()
        for key in self.metrics:
            self.metrics[key].clear()

    def infer(self, bits: torch.Tensor) -> Dict:
        """Run a forward pass and return predicted bits, telemetry, and compression ratio."""
        assert self.model is not None
        self.model.eval()
        device = next(self.model.parameters()).device
        bits = bits.to(device)
        ratio = 1.0
        with torch.no_grad():
            if self.use_compression:
                comps = [compress_bits(row.to(torch.uint8)) for row in bits]
                comp_len = sum(c.numel() for c in comps)
                ratio = min(comp_len / bits.numel(), 1.0)
                logits, telemetry = self.model.forward_compressed(comps, causal=self.causal)
            else:
                logits, telemetry = self.model(bits, causal=self.causal)
        self._log_metrics(telemetry)
        pred_bits = logits.argmax(-1)
        if self.decompress_output:
            try:
                pred_bits = model_output_decompress(pred_bits)
            except Exception as e:
                return {"error": f"Decompression failed: {e}", "suggestion": "Disable compression toggle."}

        def _to_python(obj):
            if isinstance(obj, torch.Tensor):
                return obj.tolist()
            if isinstance(obj, list):
                return [_to_python(o) for o in obj]
            if isinstance(obj, dict):
                return {kk: _to_python(vv) for kk, vv in obj.items()}
            return obj

        tele = {k: _to_python(v) for k, v in telemetry.items()}
        return {"predicted": pred_bits.squeeze(0).tolist(), "telemetry": tele, "ratio": ratio}
|
|
|
    def infer_long(self, bits: torch.Tensor, ctx_bits: int = 4096, overlap: int = 256) -> Dict:
        """Run sliding-window inference on a long sequence."""
        assert self.model is not None
        device = next(self.model.parameters()).device
        bits = bits.to(device)
        preds, logs = infer_long_sequence(self.model, bits.squeeze(0), ctx_bits=ctx_bits, overlap=overlap)
        for tele in logs:
            self._log_metrics(tele)
        return {"predicted": preds.tolist(), "windows": len(logs)}

    def _log_metrics(self, telemetry: Dict) -> None:
        """Append per-batch telemetry means and warn when metrics drift."""
        for key in self.metrics:
            val = telemetry[key].mean().item()
            self.metrics[key].append(val)
        drift = detect_metric_drift(
            self.metrics, window=self.drift_window, threshold=self.drift_threshold
        )
        bad = [k for k, v in drift.items() if v]
        if bad:
            warnings.warn(
                f"Metric drift detected: {', '.join(bad)}",
                MetricDriftWarning,
            )

    def infer_text(self, text: str) -> Dict[str, Any]:
        """Run text through the model using the safety gate."""
        assert self.model is not None
        device = next(self.model.parameters()).device
        bits = torch.tensor(text_to_bits(text), dtype=torch.long).unsqueeze(0).to(device)
        out_bits, telemetry = hil_safe_inference(
            self.model, bits, c_floor=self.c_floor, s_floor=self.s_floor
        )
        self._log_metrics(telemetry)
        # Convert tensor-valued telemetry to plain Python types so the result
        # can be serialized by jsonify.
        tele = {
            k: v.tolist() if isinstance(v, torch.Tensor) else v
            for k, v in telemetry.items()
        }
        return {
            "output": bits_to_text(out_bits.squeeze(0).tolist()),
            "telemetry": tele,
        }

    def get_status(self) -> Dict[str, Any]:
        """Return runtime flags and current model dimensions."""
        info: Dict[str, Any] = {
            "use_gpu": self.use_gpu,
            "diffusion": self.diffusion,
            "compression": self.use_compression,
            "lambda_K": self.lambda_K,
            "lambda_C": self.lambda_C,
            "lambda_S": self.lambda_S,
            "c_floor": self.c_floor,
            "s_floor": self.s_floor,
            "qat": self.qat,
        }
        if self.model is not None:
            info.update(
                {
                    "d_model": self.model.d_model,
                    "num_layers": self.model.num_layers,
                    "d_ff": self.model.layers[0].linear1.out_features,
                    "nhead": self.model.layers[0].self_attn.num_heads,
                    "max_seq_len": self.model.pos_enc.pe.size(0),
                }
            )
        else:
            info.update(
                {
                    "d_model": None,
                    "num_layers": 0,
                    "d_ff": None,
                    "nhead": None,
                    "max_seq_len": None,
                }
            )
        return info

    def get_model_config(self) -> Dict[str, Any]:
        """Return current model hyperparameters and safety settings."""
        cfg: Dict[str, Any] = {
            "lambda_K": self.lambda_K,
            "lambda_C": self.lambda_C,
            "lambda_S": self.lambda_S,
            "c_floor": self.c_floor,
            "s_floor": self.s_floor,
        }
        if self.model is not None:
            cfg.update(
                {
                    "d_model": self.model.d_model,
                    "nhead": self.model.layers[0].self_attn.num_heads,
                    "num_layers": self.model.num_layers,
                    "dim_feedforward": self.model.layers[0].linear1.out_features,
                    "max_seq_len": self.model.pos_enc.pe.size(0),
                    "chunk_size": self.model.chunk_size,
                    "reversible": self.model.reversible,
                    "use_checkpoint": self.model.use_checkpoint,
                }
            )
        else:
            cfg.update(
                {
                    "d_model": None,
                    "nhead": None,
                    "num_layers": 0,
                    "dim_feedforward": None,
                    "max_seq_len": None,
                    "chunk_size": None,
                    "reversible": None,
                    "use_checkpoint": None,
                }
            )
        return cfg

    def get_metrics(self) -> Dict[str, Any]:
        """Return logged telemetry metrics with summary statistics."""
        from statistics import mean, stdev

        data: Dict[str, Any] = {
            "negentropy": self.metrics["negentropy_logits"],
            "lz_complexity": self.metrics["lz_complexity_logits"],
            "symbiosis": self.metrics["symbiosis_score"],
        }
        summary: Dict[str, Dict[str, Optional[float]]] = {}
        for key, values in data.items():
            if values:
                m = mean(values)
                s = stdev(values) if len(values) > 1 else 0.0
                summary[key] = {"mean": m, "std": s}
            else:
                summary[key] = {"mean": None, "std": None}
        data["summary"] = summary
        return data

    def _save_state(self) -> None:
        """Persist the model weights and telemetry metrics to the snapshot directory."""
        if self.model is None:
            return
        torch.save(self.model, self.weights_path)
        with open(self.telemetry_log, "w") as f:
            json.dump(self.metrics, f)

|
manager: Optional[ModelManager] = None


@app.route("/")
def index():
    defaults = {
        k: v.default
        for k, v in inspect.signature(BitTransformerLM.__init__).parameters.items()
        if v.default is not inspect.Parameter.empty
    }
    return render_template(
        "dashboard.html",
        metrics=manager.metrics,
        lambdas={
            "lambda_K": manager.lambda_K,
            "lambda_C": manager.lambda_C,
            "lambda_S": manager.lambda_S,
        },
        diffusion=manager.diffusion,
        compression=manager.use_compression,
        defaults=defaults,
        c_floor=manager.c_floor,
        s_floor=manager.s_floor,
        qat=manager.qat,
    )


@app.route("/status", methods=["GET"])
def status():
    if MCP_SERVER_ADDR:
        return jsonify(mcp_get("/status"))
    return jsonify(manager.get_status())


@app.route("/model_config", methods=["GET"])
def model_config():
    if MCP_SERVER_ADDR:
        return jsonify(mcp_get("/model_config"))
    return jsonify(manager.get_model_config())


@app.route("/metrics", methods=["GET"])
def metrics():
    if MCP_SERVER_ADDR:
        return jsonify(mcp_get("/metrics"))
    return jsonify(manager.get_metrics())


@app.route("/save_checkpoint", methods=["POST"])
def save_checkpoint_route():
    repo_id = request.json.get("repo_id")
    token = request.json.get("token") or os.getenv("HF_TOKEN")
    if MCP_SERVER_ADDR:
        return jsonify(mcp_post("/save_checkpoint", {"repo_id": repo_id, "token": token}))
    if manager.model is None:
        return jsonify({"error": "model not initialized"}), 400
    if token:
        hf_login(token=token)
    save_checkpoint(manager.model, repo_id=repo_id)
    return jsonify({"status": "saved"})


@app.route("/download_checkpoint", methods=["POST"])
def download_checkpoint_route():
    repo_id = request.json.get("repo_id")
    token = request.json.get("token") or os.getenv("HF_TOKEN")
    if MCP_SERVER_ADDR:
        return jsonify(mcp_post("/download_checkpoint", {"repo_id": repo_id, "token": token}))
    if token:
        hf_login(token=token)
    dest = manager.weights_path + ".gz"
    ok = download_checkpoint(dest, repo_id=repo_id)
    if not ok:
        return jsonify({"status": "failed"}), 500
    if manager.model is None:
        return jsonify({"status": "downloaded", "loaded": False})
    with gzip.open(dest, "rb") as f:
        state = torch.load(f, map_location="cpu")
    manager.model.load_state_dict(state)
    manager.optimizer, manager.scheduler = configure_optimizer(
        manager.model, lr=1e-3, total_steps=manager.total_steps
    )
    manager._apply_device()
    manager._save_state()
    return jsonify({"status": "downloaded", "loaded": True})


@app.route("/text_to_bits", methods=["POST"])
def text_to_bits_route():
    text = request.json.get("text", "")
    if len(text) > 100_000:
        return jsonify({"error": "text too large"}), 413
    return jsonify({"bits": text_to_bits(text)})


@app.route("/dataset", methods=["GET"])
def dataset_route():
    name = request.args.get("name", "")
    split = request.args.get("split", "train")
    size = int(request.args.get("size", 1))
    seq_len = int(request.args.get("seq_len", 64))
    if size * seq_len > 1_000_000:
        return jsonify({"error": "dataset too large"}), 413
    if name == "wikitext2":
        try:
            from datasets import load_dataset

            ds = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
            lines = [t for t in ds["text"] if t.strip()][:size]
        except Exception:
            bits = torch.randint(0, 2, (size, seq_len), dtype=torch.long)
            return jsonify({"bits": bits.tolist()})
        bits_list = []
        for text in lines:
            b = text_to_bits(text)[:seq_len]
            if len(b) < seq_len:
                b.extend([0] * (seq_len - len(b)))
            bits_list.append(b)
        if len(bits_list) < size:
            pad = size - len(bits_list)
            bits_list.extend(torch.randint(0, 2, (pad, seq_len), dtype=torch.long).tolist())
        return jsonify({"bits": bits_list})
    return jsonify({"error": "unknown dataset"}), 400


@app.route("/init", methods=["POST"])
def init_model():
    data = request.json or {}
    int_fields = {
        "d_model",
        "nhead",
        "num_layers",
        "dim_feedforward",
        "max_seq_len",
        "chunk_size",
        "overlap",
    }
    float_fields = {"act_threshold"}
    bool_fields = {"reversible", "use_checkpoint"}
    params = {}
    for k, v in data.items():
        if v is None:
            params[k] = None
        elif k in int_fields:
            params[k] = int(v)
        elif k in float_fields:
            params[k] = float(v)
        elif k in bool_fields:
            params[k] = bool(v)
        else:
            params[k] = v
    if MCP_SERVER_ADDR:
        data = mcp_post("/init", params)
        return jsonify(data)
    manager.init_model(params)
    return jsonify({"status": "initialized", "params": params})
|
|
|
|
|
@app.route("/train", methods=["POST"]) |
|
def train_model(): |
|
bits = torch.tensor(request.json["bits"], dtype=torch.long) |
|
if MCP_SERVER_ADDR: |
|
data = mcp_post("/train", {"bits": request.json["bits"]}) |
|
return jsonify(data) |
|
loss, ratio = manager.train_step(bits) |
|
return jsonify({"loss": loss, "ratio": ratio}) |
|
|
|
|
|
@app.route("/train_epochs", methods=["POST"]) |
|
def train_epochs_route(): |
|
bits = torch.tensor(request.json["bits"], dtype=torch.long) |
|
epochs = int(request.json.get("epochs", 1)) |
|
compress_prob = float(request.json.get("compress_prob", 0.5)) |
|
direct_prob = float(request.json.get("direct_prob", 0.0)) |
|
if MCP_SERVER_ADDR: |
|
data = mcp_post( |
|
"/train_epochs", |
|
{ |
|
"bits": request.json["bits"], |
|
"epochs": epochs, |
|
"compress_prob": compress_prob, |
|
"direct_prob": direct_prob, |
|
}, |
|
) |
|
return jsonify(data) |
|
metrics = manager.train_epochs( |
|
bits, |
|
epochs=epochs, |
|
compress_prob=compress_prob, |
|
direct_prob=direct_prob, |
|
) |
|
return jsonify({"metrics": metrics}) |
|
|
|
|
|
@app.route("/scale_up", methods=["POST"]) |
|
def scale_up(): |
|
width_mult = float(request.json.get("width_mult", 1.0)) |
|
if MCP_SERVER_ADDR: |
|
data = mcp_post("/scale_up", {"width_mult": width_mult}) |
|
return jsonify(data) |
|
manager.scale_up(width_mult) |
|
return jsonify({ |
|
"status": "scaled", |
|
"layers": manager.model.num_layers, |
|
"d_model": manager.model.d_model, |
|
}) |
|
|
|
|
|
@app.route("/collapse", methods=["POST"]) |
|
def collapse_model(): |
|
cluster_bits = request.json["clusters"] |
|
params = {k: int(v) for k, v in request.json["params"].items()} |
|
width_scale = float(request.json.get("width_scale", 1.0)) |
|
if MCP_SERVER_ADDR: |
|
data = mcp_post( |
|
"/collapse", |
|
{"clusters": cluster_bits, "params": params, "width_scale": width_scale}, |
|
) |
|
return jsonify(data) |
|
manager.collapse(cluster_bits, params, width_scale) |
|
return jsonify({"status": "collapsed"}) |
|
|
|
|
|
@app.route("/lambdas", methods=["GET", "POST"]) |
|
def update_lambdas(): |
|
if request.method == "POST": |
|
data = request.json |
|
if MCP_SERVER_ADDR: |
|
res = mcp_post("/lambdas", data) |
|
return jsonify(res) |
|
manager.set_lambdas( |
|
float(data["lambda_K"]), float(data["lambda_C"]), float(data["lambda_S"]) |
|
) |
|
return jsonify({"status": "updated"}) |
|
else: |
|
if MCP_SERVER_ADDR: |
|
return jsonify(mcp_get("/lambdas")) |
|
return jsonify( |
|
{ |
|
"lambda_K": manager.lambda_K, |
|
"lambda_C": manager.lambda_C, |
|
"lambda_S": manager.lambda_S, |
|
} |
|
) |
|
|
|
|
|
@app.route("/config/telemetry", methods=["GET", "POST"]) |
|
def telemetry_config(): |
|
"""Get or update telemetry λ weights and safety floors.""" |
|
if request.method == "POST": |
|
data = request.json |
|
if MCP_SERVER_ADDR: |
|
res = mcp_post("/config/telemetry", data) |
|
return jsonify(res) |
|
manager.set_lambdas( |
|
float(data.get("lambda_K", manager.lambda_K)), |
|
float(data.get("lambda_C", manager.lambda_C)), |
|
float(data.get("lambda_S", manager.lambda_S)), |
|
) |
|
manager.set_floors( |
|
float(data.get("c_floor", manager.c_floor)), |
|
float(data.get("s_floor", manager.s_floor)), |
|
) |
|
return jsonify({"status": "updated"}) |
|
else: |
|
if MCP_SERVER_ADDR: |
|
return jsonify(mcp_get("/config/telemetry")) |
|
return jsonify( |
|
{ |
|
"lambda_K": manager.lambda_K, |
|
"lambda_C": manager.lambda_C, |
|
"lambda_S": manager.lambda_S, |
|
"c_floor": manager.c_floor, |
|
"s_floor": manager.s_floor, |
|
} |
|
) |
|
|
|
|
|
@app.route("/diffusion", methods=["GET", "POST"]) |
|
def update_diffusion(): |
|
if request.method == "POST": |
|
if MCP_SERVER_ADDR: |
|
return jsonify(mcp_post("/diffusion", request.json)) |
|
manager.set_diffusion(bool(request.json.get("diffusion", False))) |
|
return jsonify({"status": "updated"}) |
|
else: |
|
if MCP_SERVER_ADDR: |
|
return jsonify(mcp_get("/diffusion")) |
|
return jsonify({"diffusion": manager.diffusion}) |
|
|
|
|
|
@app.route("/gpu", methods=["GET", "POST"]) |
|
def update_gpu(): |
|
if request.method == "POST": |
|
if MCP_SERVER_ADDR: |
|
return jsonify(mcp_post("/gpu", request.json)) |
|
manager.set_gpu(bool(request.json.get("use_gpu", False))) |
|
return jsonify({"status": "updated"}) |
|
else: |
|
if MCP_SERVER_ADDR: |
|
return jsonify(mcp_get("/gpu")) |
|
return jsonify({"use_gpu": manager.use_gpu}) |
|
|
|
|
|
@app.route("/compression", methods=["GET", "POST"]) |
|
def update_compression(): |
|
if request.method == "POST": |
|
if MCP_SERVER_ADDR: |
|
return jsonify(mcp_post("/compression", request.json)) |
|
manager.set_compression(bool(request.json.get("compression", False))) |
|
return jsonify({"status": "updated"}) |
|
else: |
|
if MCP_SERVER_ADDR: |
|
return jsonify(mcp_get("/compression")) |
|
return jsonify({"compression": manager.use_compression}) |
|
|
|
|
|
@app.route("/qat", methods=["GET", "POST"]) |
|
def update_qat(): |
|
if request.method == "POST": |
|
if MCP_SERVER_ADDR: |
|
return jsonify(mcp_post("/qat", request.json)) |
|
manager.set_qat(bool(request.json.get("qat", False))) |
|
return jsonify({"status": "updated"}) |
|
else: |
|
if MCP_SERVER_ADDR: |
|
return jsonify(mcp_get("/qat")) |
|
return jsonify({"qat": manager.qat}) |
|
|
|
|
|
@app.route("/infer", methods=["POST"]) |
|
def inference(): |
|
bits = torch.tensor(request.json["bits"], dtype=torch.long) |
|
if MCP_SERVER_ADDR: |
|
data = mcp_post("/infer", {"bits": request.json["bits"]}) |
|
return jsonify(data) |
|
result = manager.infer(bits) |
|
return jsonify(result) |
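# Example request (sketch; the bit payload is illustrative):
#
#   curl -X POST http://localhost:5000/infer \
#        -H "Content-Type: application/json" \
#        -d '{"bits": [[0, 1, 1, 0, 1, 0, 0, 1]]}'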
|
|
|
|
|
@app.route("/infer_long", methods=["POST"]) |
|
def inference_long(): |
|
bits = torch.tensor(request.json["bits"], dtype=torch.long) |
|
ctx = int(request.json.get("ctx_bits", 4096)) |
|
overlap = int(request.json.get("overlap", 256)) |
|
if MCP_SERVER_ADDR: |
|
data = mcp_post( |
|
"/infer_long", |
|
{"bits": request.json["bits"], "ctx_bits": ctx, "overlap": overlap}, |
|
) |
|
return jsonify(data) |
|
result = manager.infer_long(bits, ctx_bits=ctx, overlap=overlap) |
|
return jsonify(result) |
|
|
|
|
|
@app.route("/infer_text", methods=["POST"]) |
|
def inference_text(): |
|
text = request.json.get("text", "") |
|
if MCP_SERVER_ADDR: |
|
data = mcp_post("/infer_text", {"text": text}) |
|
return jsonify(data) |
|
result = manager.infer_text(text) |
|
return jsonify(result) |
|
|
|
@app.route("/plot.png") |
|
def plot_png(): |
|
if MCP_SERVER_ADDR: |
|
resp = requests.get(MCP_SERVER_ADDR.rstrip("/") + "/plot.png") |
|
resp.raise_for_status() |
|
return send_file(io.BytesIO(resp.content), mimetype="image/png") |
|
fig, _ = plot_telemetry(manager.metrics) |
|
buf = io.BytesIO() |
|
fig.savefig(buf, format="png") |
|
plt.close(fig) |
|
buf.seek(0) |
|
return send_file(buf, mimetype="image/png") |
|
|
|
|
|
def run_dashboard(host: Optional[str] = None, port: Optional[int] = None, |
|
snapshot_dir: Optional[str] = None, telemetry_log: Optional[str] = None) -> None: |
|
"""Launch the Flask dashboard server.""" |
|
env_host = os.getenv("HOST", "0.0.0.0") |
|
env_port = int(os.getenv("PORT", "5000")) |
|
host = host or env_host |
|
port = port or env_port |
|
global manager |
|
if manager is None: |
|
manager = ModelManager(snapshot_dir, telemetry_log) |
|
app.run(host=host, port=port, debug=True) |
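# Example usage (sketch): launch the dashboard programmatically rather than
# through the CLI entry point below.
#
#   run_dashboard(host="127.0.0.1", port=5000, snapshot_dir="snapshots")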
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run dashboard server")
    parser.add_argument("--host", default=os.getenv("HOST", "0.0.0.0"))
    parser.add_argument("--port", type=int, default=int(os.getenv("PORT", "5000")))
    parser.add_argument("--snapshot-dir", default=os.getenv("SNAPSHOT_DIR", "snapshots"))
    parser.add_argument("--telemetry-log", default=os.getenv("TELEMETRY_LOG"))
    args = parser.parse_args()
    run_dashboard(args.host, args.port, args.snapshot_dir, args.telemetry_log)
|
|