import io import os import gzip import uuid import traceback from concurrent.futures import ThreadPoolExecutor from flask import Flask, request, jsonify, send_file import matplotlib.pyplot as plt import torch from bit_transformer.dashboard_app import ModelManager from bit_transformer.dashboard import plot_telemetry from bit_transformer.hf_checkpoint import hf_login, save_checkpoint, download_checkpoint from bit_transformer.optimization import configure_optimizer from bit_transformer.bit_io import text_to_bits from bit_transformer.dataset_builder import BitTransformerDatasetBuilder, create_bittransformerlm_dataset app = Flask(__name__) manager = ModelManager() # background job management executor = ThreadPoolExecutor(max_workers=4) jobs: dict[str, dict] = {} def _submit_job(fn, *args, **kwargs) -> str: """Schedule a function for background execution and return a job id.""" job_id = str(uuid.uuid4()) jobs[job_id] = {"status": "queued", "result": None, "error": None, "logs": []} def wrapper(): jobs[job_id]["status"] = "running" try: jobs[job_id]["result"] = fn(*args, **kwargs) jobs[job_id]["status"] = "completed" except Exception as err: # pragma: no cover - captured for client jobs[job_id]["status"] = "error" jobs[job_id]["error"] = str(err) jobs[job_id]["trace"] = traceback.format_exc() executor.submit(wrapper) return job_id @app.errorhandler(Exception) def handle_exception(err): """Return JSON error responses with stack traces.""" return ( jsonify({"error": str(err), "trace": traceback.format_exc()}), getattr(err, "code", 500), ) @app.route("/init", methods=["POST"]) def init_model(): data = request.json or {} int_fields = { "d_model", "nhead", "num_layers", "dim_feedforward", "max_seq_len", "chunk_size", "overlap", } float_fields = {"act_threshold"} bool_fields = {"reversible", "use_checkpoint"} params = {} for k, v in data.items(): if v is None: params[k] = None elif k in int_fields: params[k] = int(v) elif k in float_fields: params[k] = float(v) elif k in bool_fields: params[k] = bool(v) else: params[k] = v manager.init_model(params) return jsonify({"status": "initialized", "params": params}) @app.route("/train", methods=["POST"]) def train_model(): bits = request.json["bits"] def task(): tensor = torch.tensor(bits, dtype=torch.long) loss, ratio = manager.train_step(tensor) return {"loss": loss, "ratio": ratio} job_id = _submit_job(task) return jsonify({"job_id": job_id}) @app.route("/train_epochs", methods=["POST"]) def train_epochs_route(): data = request.json bits = data["bits"] epochs = int(data.get("epochs", 1)) compress_prob = float(data.get("compress_prob", 0.5)) direct_prob = float(data.get("direct_prob", 0.0)) def task(): tensor = torch.tensor(bits, dtype=torch.long) metrics = manager.train_epochs( tensor, epochs=epochs, compress_prob=compress_prob, direct_prob=direct_prob, ) return {"metrics": metrics} job_id = _submit_job(task) return jsonify({"job_id": job_id}) @app.route("/scale_up", methods=["POST"]) def scale_up(): width_mult = float(request.json.get("width_mult", 1.0)) def task(): manager.scale_up(width_mult) return { "status": "scaled", "layers": manager.model.num_layers, "d_model": manager.model.d_model, } job_id = _submit_job(task) return jsonify({"job_id": job_id}) @app.route("/collapse", methods=["POST"]) def collapse_model(): cluster_bits = request.json["clusters"] params = {k: int(v) for k, v in request.json["params"].items()} width_scale = float(request.json.get("width_scale", 1.0)) def task(): manager.collapse(cluster_bits, params, width_scale) return {"status": "collapsed"} job_id = _submit_job(task) return jsonify({"job_id": job_id}) @app.route("/job/", methods=["GET"]) def get_job(job_id: str): job = jobs.get(job_id) if job is None: return jsonify({"error": "not found"}), 404 return jsonify(job) @app.route("/jobs", methods=["GET"]) def list_jobs(): return jsonify(jobs) @app.route("/lambdas", methods=["GET", "POST"]) def update_lambdas(): if request.method == "POST": data = request.json manager.set_lambdas(float(data["lambda_K"]), float(data["lambda_C"]), float(data["lambda_S"])) return jsonify({"status": "updated"}) else: return jsonify({ "lambda_K": manager.lambda_K, "lambda_C": manager.lambda_C, "lambda_S": manager.lambda_S, }) @app.route("/diffusion", methods=["GET", "POST"]) def update_diffusion(): if request.method == "POST": manager.set_diffusion(bool(request.json.get("diffusion", False))) return jsonify({"status": "updated"}) return jsonify({"diffusion": manager.diffusion}) @app.route("/qat", methods=["GET", "POST"]) def update_qat(): if request.method == "POST": manager.set_qat(bool(request.json.get("qat", False))) return jsonify({"status": "updated"}) return jsonify({"qat": manager.qat}) @app.route("/gpu", methods=["GET", "POST"]) def update_gpu(): if request.method == "POST": manager.set_gpu(bool(request.json.get("use_gpu", False))) return jsonify({"status": "updated"}) return jsonify({"use_gpu": manager.use_gpu}) @app.route("/infer", methods=["POST"]) def inference(): bits = torch.tensor(request.json["bits"], dtype=torch.long) result = manager.infer(bits) return jsonify(result) @app.route("/infer_long", methods=["POST"]) def inference_long(): bits = torch.tensor(request.json["bits"], dtype=torch.long) ctx = int(request.json.get("ctx_bits", 4096)) overlap = int(request.json.get("overlap", 256)) result = manager.infer_long(bits, ctx_bits=ctx, overlap=overlap) return jsonify(result) @app.route("/infer_text", methods=["POST"]) def inference_text(): text = request.json.get("text", "") result = manager.infer_text(text) return jsonify(result) @app.route("/status", methods=["GET"]) def status(): return jsonify(manager.get_status()) @app.route("/model_config", methods=["GET"]) def model_config(): return jsonify(manager.get_model_config()) @app.route("/metrics", methods=["GET"]) def metrics(): return jsonify(manager.get_metrics()) @app.route("/save_checkpoint", methods=["POST"]) def save_checkpoint_route(): repo_id = request.json.get("repo_id") token = request.json.get("token") or os.getenv("HF_TOKEN") if manager.model is None: return jsonify({"error": "model not initialized"}), 400 if token: hf_login(token=token) save_checkpoint(manager.model, repo_id=repo_id) return jsonify({"status": "saved"}) @app.route("/download_checkpoint", methods=["POST"]) def download_checkpoint_route(): repo_id = request.json.get("repo_id") token = request.json.get("token") or os.getenv("HF_TOKEN") if token: hf_login(token=token) dest = manager.weights_path + ".gz" ok = download_checkpoint(dest, repo_id=repo_id) if not ok: return jsonify({"status": "failed"}), 500 if manager.model is None: return jsonify({"status": "downloaded", "loaded": False}) with gzip.open(dest, "rb") as f: state = torch.load(f, map_location="cpu") manager.model.load_state_dict(state) manager.optimizer, manager.scheduler = configure_optimizer( manager.model, lr=1e-3, total_steps=manager.total_steps ) manager._apply_device() manager._save_state() return jsonify({"status": "downloaded", "loaded": True}) @app.route("/plot.png") def plot_png(): fig, _ = plot_telemetry(manager.metrics) buf = io.BytesIO() fig.savefig(buf, format="png") plt.close(fig) buf.seek(0) return send_file(buf, mimetype="image/png") @app.route("/text_to_bits", methods=["POST"]) def text_to_bits_route(): text = request.json.get("text", "") if len(text) > 100_000: return jsonify({"error": "text too large"}), 413 return jsonify({"bits": text_to_bits(text)}) @app.route("/dataset", methods=["GET"]) def dataset_route(): name = request.args.get("name", "") split = request.args.get("split", "train") size = int(request.args.get("size", 1)) seq_len = int(request.args.get("seq_len", 64)) if size * seq_len > 1_000_000: return jsonify({"error": "dataset too large"}), 413 if name == "wikitext2": try: from datasets import load_dataset ds = load_dataset("wikitext", "wikitext-2-raw-v1", split=split) lines = [t for t in ds["text"] if t.strip()][:size] except Exception: bits = torch.randint(0, 2, (size, seq_len), dtype=torch.long) return jsonify({"bits": bits.tolist()}) bits_list = [] for text in lines: b = text_to_bits(text)[:seq_len] if len(b) < seq_len: b.extend([0] * (seq_len - len(b))) bits_list.append(b) if len(bits_list) < size: pad = size - len(bits_list) bits_list.extend(torch.randint(0, 2, (pad, seq_len), dtype=torch.long).tolist()) return jsonify({"bits": bits_list}) return jsonify({"error": "unknown dataset"}), 400 # Dataset Management Endpoints @app.route("/dataset/create", methods=["POST"]) def create_dataset(): """Create and upload a new BitTransformerLM dataset.""" data = request.json or {} hf_token = data.get("hf_token") or os.getenv("HF_TOKEN") repo_id = data.get("repo_id", "BitTransformerLM") source_texts = data.get("source_texts", None) if not hf_token: return jsonify({"error": "HF token required"}), 400 def task(): try: dataset_url = create_bittransformerlm_dataset( hf_token=hf_token, repo_id=repo_id, source_texts=source_texts ) return { "status": "success", "dataset_url": dataset_url, "repo_id": repo_id } except Exception as e: return { "status": "error", "error": str(e) } job_id = _submit_job(task) return jsonify({"job_id": job_id, "message": "Dataset creation started"}) @app.route("/dataset/builder", methods=["POST"]) def create_dataset_builder(): """Initialize a dataset builder for custom dataset creation.""" data = request.json or {} hf_token = data.get("hf_token") or os.getenv("HF_TOKEN") repo_id = data.get("repo_id", "BitTransformerLM") if not hf_token: return jsonify({"error": "HF token required"}), 400 try: builder = BitTransformerDatasetBuilder(hf_token, repo_id) # Store builder configuration builder_info = { "repo_id": repo_id, "config": builder.config, "status": "ready" } return jsonify({ "status": "builder_created", "builder_info": builder_info }) except Exception as e: return jsonify({"error": str(e)}), 500 @app.route("/dataset/generate", methods=["POST"]) def generate_dataset_samples(): """Generate specific types of dataset samples.""" data = request.json or {} sample_type = data.get("type", "text_to_bits") # text_to_bits, synthetic, safety, compression count = int(data.get("count", 100)) max_len = int(data.get("max_len", 256)) texts = data.get("texts", None) if count > 5000: return jsonify({"error": "count too large, max 5000"}), 400 def task(): try: # Create temporary builder (no upload) builder = BitTransformerDatasetBuilder("dummy_token", "temp") if sample_type == "text_to_bits": if not texts: texts = builder._get_default_texts()[:count] samples = builder.generate_text_to_bits_data(texts[:count], max_len) elif sample_type == "synthetic": samples = builder.generate_synthetic_patterns(count, max_len) elif sample_type == "safety": samples = builder.generate_safety_benchmarks(count) elif sample_type == "compression": # Need base samples first base_texts = builder._get_default_texts()[:50] base_samples = builder.generate_text_to_bits_data(base_texts, max_len) samples = builder.generate_compression_variants(base_samples)[:count] else: return {"error": f"Unknown sample type: {sample_type}"} return { "status": "success", "samples": samples[:10], # Return first 10 for preview "total_generated": len(samples), "sample_type": sample_type } except Exception as e: return {"error": str(e)} job_id = _submit_job(task) return jsonify({"job_id": job_id, "message": f"Generating {sample_type} samples"}) @app.route("/dataset/info", methods=["GET"]) def dataset_info(): """Get information about available dataset generation options.""" return jsonify({ "sample_types": [ { "type": "text_to_bits", "description": "Convert text to parity-protected bit sequences", "parameters": ["texts", "max_len"] }, { "type": "synthetic", "description": "Generate synthetic bit patterns", "parameters": ["count", "max_len"], "patterns": ["alternating", "blocks", "fibonacci", "prime_based", "random_walk"] }, { "type": "safety", "description": "Generate safety benchmark sequences", "parameters": ["count"], "categories": ["low_entropy", "medium_entropy", "high_entropy", "edge_cases"] }, { "type": "compression", "description": "Generate compressed variants of base sequences", "parameters": ["count", "compression_ratios"] } ], "default_config": { "max_sequence_length": 512, "total_samples": 25000, "safety_thresholds": { "min_negentropy": 0.1, "max_lz_complexity": 0.9, "min_symbiosis": 0.3 } } }) @app.route("/health") def health_check(): return jsonify({"status": "ok"}) def run_mcp_server(host: str = "0.0.0.0", port: int = 7000) -> None: app.run(host=host, port=port, debug=True) if __name__ == "__main__": import torch run_mcp_server()