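"""Flask-based MCP server for BitTransformerLM.

Exposes training, scaling, inference, telemetry, checkpointing, and
dataset-building endpoints backed by a shared :class:`ModelManager` instance.
Long-running operations are dispatched to a thread pool and tracked as jobs.
"""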
import io
import os
import gzip
import uuid
import traceback
from concurrent.futures import ThreadPoolExecutor

from flask import Flask, request, jsonify, send_file
import matplotlib.pyplot as plt
import torch

from bit_transformer.dashboard_app import ModelManager
from bit_transformer.dashboard import plot_telemetry
from bit_transformer.hf_checkpoint import hf_login, save_checkpoint, download_checkpoint
from bit_transformer.optimization import configure_optimizer
from bit_transformer.bit_io import text_to_bits
from bit_transformer.dataset_builder import BitTransformerDatasetBuilder, create_bittransformerlm_dataset

app = Flask(__name__)
manager = ModelManager()

executor = ThreadPoolExecutor(max_workers=4)
jobs: dict[str, dict] = {}

def _submit_job(fn, *args, **kwargs) -> str:
    """Schedule a function for background execution and return a job id."""
    job_id = str(uuid.uuid4())
    jobs[job_id] = {"status": "queued", "result": None, "error": None, "logs": []}

    def wrapper():
        jobs[job_id]["status"] = "running"
        try:
            jobs[job_id]["result"] = fn(*args, **kwargs)
            jobs[job_id]["status"] = "completed"
        except Exception as err:
            jobs[job_id]["status"] = "error"
            jobs[job_id]["error"] = str(err)
            jobs[job_id]["trace"] = traceback.format_exc()

    executor.submit(wrapper)
    return job_id
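
# Job records transition "queued" -> "running" -> "completed" or "error".
# Successful results land in "result"; failures populate "error" and "trace".
# The "logs" list is initialised for every job but is not written to by the
# handlers in this module.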

@app.errorhandler(Exception)
def handle_exception(err):
    """Return JSON error responses with stack traces."""
    return (
        jsonify({"error": str(err), "trace": traceback.format_exc()}),
        getattr(err, "code", 500),
    )


@app.route("/init", methods=["POST"])
def init_model():
    data = request.json or {}
    int_fields = {
        "d_model",
        "nhead",
        "num_layers",
        "dim_feedforward",
        "max_seq_len",
        "chunk_size",
        "overlap",
    }
    float_fields = {"act_threshold"}
    bool_fields = {"reversible", "use_checkpoint"}
    params = {}
    for k, v in data.items():
        if v is None:
            params[k] = None
        elif k in int_fields:
            params[k] = int(v)
        elif k in float_fields:
            params[k] = float(v)
        elif k in bool_fields:
            params[k] = bool(v)
        else:
            params[k] = v
    manager.init_model(params)
    return jsonify({"status": "initialized", "params": params})

@app.route("/train", methods=["POST"])
def train_model():
    bits = request.json["bits"]

    def task():
        tensor = torch.tensor(bits, dtype=torch.long)
        loss, ratio = manager.train_step(tensor)
        return {"loss": loss, "ratio": ratio}

    job_id = _submit_job(task)
    return jsonify({"job_id": job_id})


@app.route("/train_epochs", methods=["POST"])
def train_epochs_route():
    data = request.json
    bits = data["bits"]
    epochs = int(data.get("epochs", 1))
    compress_prob = float(data.get("compress_prob", 0.5))
    direct_prob = float(data.get("direct_prob", 0.0))

    def task():
        tensor = torch.tensor(bits, dtype=torch.long)
        metrics = manager.train_epochs(
            tensor,
            epochs=epochs,
            compress_prob=compress_prob,
            direct_prob=direct_prob,
        )
        return {"metrics": metrics}

    job_id = _submit_job(task)
    return jsonify({"job_id": job_id})


@app.route("/scale_up", methods=["POST"])
def scale_up():
    width_mult = float(request.json.get("width_mult", 1.0))

    def task():
        manager.scale_up(width_mult)
        return {
            "status": "scaled",
            "layers": manager.model.num_layers,
            "d_model": manager.model.d_model,
        }

    job_id = _submit_job(task)
    return jsonify({"job_id": job_id})


@app.route("/collapse", methods=["POST"])
def collapse_model():
    cluster_bits = request.json["clusters"]
    params = {k: int(v) for k, v in request.json["params"].items()}
    width_scale = float(request.json.get("width_scale", 1.0))

    def task():
        manager.collapse(cluster_bits, params, width_scale)
        return {"status": "collapsed"}

    job_id = _submit_job(task)
    return jsonify({"job_id": job_id})


@app.route("/job/<job_id>", methods=["GET"])
def get_job(job_id: str):
    job = jobs.get(job_id)
    if job is None:
        return jsonify({"error": "not found"}), 404
    return jsonify(job)


@app.route("/jobs", methods=["GET"])
def list_jobs():
    return jsonify(jobs)
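
# Example asynchronous flow (illustrative; bit payloads are assumed to be a
# batch of bit sequences accepted by ModelManager.train_step):
#   curl -X POST localhost:7000/train -H 'Content-Type: application/json' \
#        -d '{"bits": [[0, 1, 0, 1, 1, 0, 1, 0]]}'
#   -> {"job_id": "<uuid>"}
#   curl localhost:7000/job/<uuid>
#   -> {"status": "completed", "result": {"loss": ..., "ratio": ...}, ...}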

@app.route("/lambdas", methods=["GET", "POST"])
def update_lambdas():
    if request.method == "POST":
        data = request.json
        manager.set_lambdas(float(data["lambda_K"]), float(data["lambda_C"]), float(data["lambda_S"]))
        return jsonify({"status": "updated"})
    return jsonify({
        "lambda_K": manager.lambda_K,
        "lambda_C": manager.lambda_C,
        "lambda_S": manager.lambda_S,
    })


@app.route("/diffusion", methods=["GET", "POST"])
def update_diffusion():
    if request.method == "POST":
        manager.set_diffusion(bool(request.json.get("diffusion", False)))
        return jsonify({"status": "updated"})
    return jsonify({"diffusion": manager.diffusion})


@app.route("/qat", methods=["GET", "POST"])
def update_qat():
    if request.method == "POST":
        manager.set_qat(bool(request.json.get("qat", False)))
        return jsonify({"status": "updated"})
    return jsonify({"qat": manager.qat})


@app.route("/gpu", methods=["GET", "POST"])
def update_gpu():
    if request.method == "POST":
        manager.set_gpu(bool(request.json.get("use_gpu", False)))
        return jsonify({"status": "updated"})
    return jsonify({"use_gpu": manager.use_gpu})


@app.route("/infer", methods=["POST"])
def inference():
    bits = torch.tensor(request.json["bits"], dtype=torch.long)
    result = manager.infer(bits)
    return jsonify(result)


@app.route("/infer_long", methods=["POST"])
def inference_long():
    bits = torch.tensor(request.json["bits"], dtype=torch.long)
    ctx = int(request.json.get("ctx_bits", 4096))
    overlap = int(request.json.get("overlap", 256))
    result = manager.infer_long(bits, ctx_bits=ctx, overlap=overlap)
    return jsonify(result)


@app.route("/infer_text", methods=["POST"])
def inference_text():
    text = request.json.get("text", "")
    result = manager.infer_text(text)
    return jsonify(result)
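
# Example inference payloads (illustrative request bodies only; response
# shapes are determined by ModelManager):
#   POST /infer       {"bits": [[0, 1, 1, 0]]}
#   POST /infer_long  {"bits": [[0, 1, 1, 0]], "ctx_bits": 4096, "overlap": 256}
#   POST /infer_text  {"text": "hello bits"}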

@app.route("/status", methods=["GET"])
def status():
    return jsonify(manager.get_status())


@app.route("/model_config", methods=["GET"])
def model_config():
    return jsonify(manager.get_model_config())


@app.route("/metrics", methods=["GET"])
def metrics():
    return jsonify(manager.get_metrics())


@app.route("/save_checkpoint", methods=["POST"])
def save_checkpoint_route():
    repo_id = request.json.get("repo_id")
    token = request.json.get("token") or os.getenv("HF_TOKEN")
    if manager.model is None:
        return jsonify({"error": "model not initialized"}), 400
    if token:
        hf_login(token=token)
    save_checkpoint(manager.model, repo_id=repo_id)
    return jsonify({"status": "saved"})


@app.route("/download_checkpoint", methods=["POST"])
def download_checkpoint_route():
    repo_id = request.json.get("repo_id")
    token = request.json.get("token") or os.getenv("HF_TOKEN")
    if token:
        hf_login(token=token)
    dest = manager.weights_path + ".gz"
    ok = download_checkpoint(dest, repo_id=repo_id)
    if not ok:
        return jsonify({"status": "failed"}), 500
    if manager.model is None:
        return jsonify({"status": "downloaded", "loaded": False})
    with gzip.open(dest, "rb") as f:
        state = torch.load(f, map_location="cpu")
    manager.model.load_state_dict(state)
    manager.optimizer, manager.scheduler = configure_optimizer(
        manager.model, lr=1e-3, total_steps=manager.total_steps
    )
    manager._apply_device()
    manager._save_state()
    return jsonify({"status": "downloaded", "loaded": True})


@app.route("/plot.png")
def plot_png():
    fig, _ = plot_telemetry(manager.metrics)
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    return send_file(buf, mimetype="image/png")


@app.route("/text_to_bits", methods=["POST"])
def text_to_bits_route():
    text = request.json.get("text", "")
    if len(text) > 100_000:
        return jsonify({"error": "text too large"}), 413
    return jsonify({"bits": text_to_bits(text)})


@app.route("/dataset", methods=["GET"])
def dataset_route():
    name = request.args.get("name", "")
    split = request.args.get("split", "train")
    size = int(request.args.get("size", 1))
    seq_len = int(request.args.get("seq_len", 64))
    if size * seq_len > 1_000_000:
        return jsonify({"error": "dataset too large"}), 413
    if name == "wikitext2":
        try:
            from datasets import load_dataset

            ds = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
            lines = [t for t in ds["text"] if t.strip()][:size]
        except Exception:
            bits = torch.randint(0, 2, (size, seq_len), dtype=torch.long)
            return jsonify({"bits": bits.tolist()})
        bits_list = []
        for text in lines:
            b = text_to_bits(text)[:seq_len]
            if len(b) < seq_len:
                b.extend([0] * (seq_len - len(b)))
            bits_list.append(b)
        if len(bits_list) < size:
            pad = size - len(bits_list)
            bits_list.extend(torch.randint(0, 2, (pad, seq_len), dtype=torch.long).tolist())
        return jsonify({"bits": bits_list})
    return jsonify({"error": "unknown dataset"}), 400

@app.route("/dataset/create", methods=["POST"])
def create_dataset():
    """Create and upload a new BitTransformerLM dataset."""
    data = request.json or {}

    hf_token = data.get("hf_token") or os.getenv("HF_TOKEN")
    repo_id = data.get("repo_id", "BitTransformerLM")
    source_texts = data.get("source_texts")

    if not hf_token:
        return jsonify({"error": "HF token required"}), 400

    def task():
        try:
            dataset_url = create_bittransformerlm_dataset(
                hf_token=hf_token,
                repo_id=repo_id,
                source_texts=source_texts,
            )
            return {
                "status": "success",
                "dataset_url": dataset_url,
                "repo_id": repo_id,
            }
        except Exception as e:
            return {
                "status": "error",
                "error": str(e),
            }

    job_id = _submit_job(task)
    return jsonify({"job_id": job_id, "message": "Dataset creation started"})

@app.route("/dataset/builder", methods=["POST"])
def create_dataset_builder():
    """Initialize a dataset builder for custom dataset creation."""
    data = request.json or {}

    hf_token = data.get("hf_token") or os.getenv("HF_TOKEN")
    repo_id = data.get("repo_id", "BitTransformerLM")

    if not hf_token:
        return jsonify({"error": "HF token required"}), 400

    try:
        builder = BitTransformerDatasetBuilder(hf_token, repo_id)
        builder_info = {
            "repo_id": repo_id,
            "config": builder.config,
            "status": "ready",
        }
        return jsonify({
            "status": "builder_created",
            "builder_info": builder_info,
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500

@app.route("/dataset/generate", methods=["POST"])
def generate_dataset_samples():
    """Generate specific types of dataset samples."""
    data = request.json or {}

    sample_type = data.get("type", "text_to_bits")
    count = int(data.get("count", 100))
    max_len = int(data.get("max_len", 256))
    texts = data.get("texts")

    if count > 5000:
        return jsonify({"error": "count too large, max 5000"}), 400

    def task():
        try:
            # A placeholder token/repo suffices here; the generate_* helpers are
            # assumed not to need Hub access.
            builder = BitTransformerDatasetBuilder("dummy_token", "temp")

            if sample_type == "text_to_bits":
                # Bind to a local name so the closed-over ``texts`` is only read
                # (assigning to it inside this closure would raise UnboundLocalError).
                source = texts or builder._get_default_texts()[:count]
                samples = builder.generate_text_to_bits_data(source[:count], max_len)
            elif sample_type == "synthetic":
                samples = builder.generate_synthetic_patterns(count, max_len)
            elif sample_type == "safety":
                samples = builder.generate_safety_benchmarks(count)
            elif sample_type == "compression":
                base_texts = builder._get_default_texts()[:50]
                base_samples = builder.generate_text_to_bits_data(base_texts, max_len)
                samples = builder.generate_compression_variants(base_samples)[:count]
            else:
                return {"error": f"Unknown sample type: {sample_type}"}

            return {
                "status": "success",
                "samples": samples[:10],
                "total_generated": len(samples),
                "sample_type": sample_type,
            }
        except Exception as e:
            return {"error": str(e)}

    job_id = _submit_job(task)
    return jsonify({"job_id": job_id, "message": f"Generating {sample_type} samples"})

@app.route("/dataset/info", methods=["GET"])
def dataset_info():
    """Get information about available dataset generation options."""
    return jsonify({
        "sample_types": [
            {
                "type": "text_to_bits",
                "description": "Convert text to parity-protected bit sequences",
                "parameters": ["texts", "max_len"],
            },
            {
                "type": "synthetic",
                "description": "Generate synthetic bit patterns",
                "parameters": ["count", "max_len"],
                "patterns": ["alternating", "blocks", "fibonacci", "prime_based", "random_walk"],
            },
            {
                "type": "safety",
                "description": "Generate safety benchmark sequences",
                "parameters": ["count"],
                "categories": ["low_entropy", "medium_entropy", "high_entropy", "edge_cases"],
            },
            {
                "type": "compression",
                "description": "Generate compressed variants of base sequences",
                "parameters": ["count", "compression_ratios"],
            },
        ],
        "default_config": {
            "max_sequence_length": 512,
            "total_samples": 25000,
            "safety_thresholds": {
                "min_negentropy": 0.1,
                "max_lz_complexity": 0.9,
                "min_symbiosis": 0.3,
            },
        },
    })


@app.route("/health")
def health_check():
    return jsonify({"status": "ok"})


def run_mcp_server(host: str = "0.0.0.0", port: int = 7000) -> None:
    app.run(host=host, port=port, debug=True)
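
# Running the server directly (module filename assumed) and checking liveness:
#   python mcp_server.py            # serves on 0.0.0.0:7000 with Flask debug on
#   curl localhost:7000/health      # -> {"status": "ok"}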

if __name__ == "__main__":
    run_mcp_server()