import threading
import torch
import os
import json  # needed for json.dumps in the streaming chat response

from flask import Flask, request, Response, jsonify
from flask_cors import CORS
from huggingface_hub import HfApi, login

app = Flask(__name__)
CORS(app)

# Global state
tokenizer = None
model = None
model_loading = False
model_loaded = False
model_id = "microsoft/bitnet-b1.58-2B-4T"


# Load model in background
def load_model_thread():
    global tokenizer, model, model_loaded, model_loading
    try:
        model_loading = True
        from transformers import AutoTokenizer, AutoModelForCausalLM

        print("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        print("Loading model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
            device_map=None
        ).to("cpu")

        model_loaded = True
        print("✅ Model loaded successfully.")
    except Exception as e:
        print(f"❌ Error loading model: {e}")
    finally:
        model_loading = False


# Start background model load
threading.Thread(target=load_model_thread, daemon=True).start()


@app.route("/")
def home():
    return "🚀 Flask backend for BitNet is running!"


@app.route("/api/health", methods=["GET"])
def health():
    """Health check endpoint"""
    return {
        "status": "ok",
        "model_loaded": model_loaded,
        "model_loading": model_loading
    }


@app.route("/api/chat", methods=["POST"])
def chat():
    """Chat endpoint with BitNet streaming response"""
    global model_loaded, model, tokenizer

    if not model_loaded:
        return {
            "status": "initializing",
            "message": "Model is still loading. Please try again shortly."
        }, 503

    try:
        from transformers import TextIteratorStreamer

        data = request.get_json()
        message = data.get("message", "")
        history = data.get("history", [])
        system_message = data.get("system_message", (
            "You are a helpful assistant. When generating code, always wrap it in markdown code blocks (```) "
            "with the appropriate language identifier (e.g., ```python, ```javascript). "
            "Ensure proper indentation and line breaks for readability."
        ))
        max_tokens = data.get("max_tokens", 512)
        temperature = data.get("temperature", 0.7)
        top_p = data.get("top_p", 0.95)

        # Rebuild the conversation in chat format: system prompt, prior turns, new message
        messages = [{"role": "system", "content": system_message}]
        for user_msg, bot_msg in history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": bot_msg})
        messages.append({"role": "user", "content": message})

        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(prompt, return_tensors="pt").to("cpu")

        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=True
        )

        generate_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
        )

        # Run generation in a background thread so the streamer can be consumed here
        thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
        thread.start()

        def generate():
            for new_text in streamer:
                yield f"data: {json.dumps({'response': new_text})}\n\n"
            yield "data: [DONE]\n\n"

        return Response(generate(), mimetype="text/event-stream")

    except Exception as e:
        print("Error during chat:", e)
        return {"error": str(e)}, 500
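
# Example client for the streaming /api/chat endpoint. This is an illustrative
# sketch and is not used by the server itself: the localhost URL and the
# `requests` dependency are assumptions, not part of this backend. Each SSE
# chunk arrives as a "data: {...}" line and the stream ends with "data: [DONE]";
# the helper yields the text of each streamed chunk.
def example_stream_chat(message, history=None, base_url="http://localhost:5000"):
    import requests  # assumed to be available in the client environment

    resp = requests.post(
        f"{base_url}/api/chat",
        json={"message": message, "history": history or []},
        stream=True,
    )
    for line in resp.iter_lines(decode_unicode=True):
        # Skip blank keep-alive lines and anything that is not an SSE data line
        if not line or not line.startswith("data: "):
            continue
        payload = line[len("data: "):]
        if payload == "[DONE]":
            break
        # Each payload is a JSON object of the form {"response": "<chunk text>"}
        yield json.loads(payload)["response"]
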
@app.route("/api/save_model", methods=["POST"])
def save_model():
    """Save model and tokenizer to Hugging Face Hub"""
    global model, tokenizer, model_loaded

    if not model_loaded:
        return {"error": "Model is still loading. Try again later."}, 503

    try:
        # Authenticate with Hugging Face
        token = request.json.get("token")
        if not token:
            return {"error": "Hugging Face token required"}, 400
        login(token=token)

        # Define repository
        repo_id = "mike23415/playwebit"
        save_directory = "/tmp/playwebit"

        # Create temporary directory
        os.makedirs(save_directory, exist_ok=True)

        # Save custom model class (replace with actual implementation)
        custom_model_code = """
from transformers import PreTrainedModel
from transformers.models.bitnet.configuration_bitnet import BitNetConfig

class BitNetForCausalLM(PreTrainedModel):
    config_class = BitNetConfig

    def __init__(self, config):
        super().__init__(config)
        # Placeholder: Copy implementation from fork's modeling_bitnet.py
        raise NotImplementedError("Replace with actual BitNetForCausalLM implementation")

    def forward(self, *args, **kwargs):
        # Placeholder: Copy forward pass from fork
        raise NotImplementedError("Replace with actual forward pass implementation")
"""
        with open(os.path.join(save_directory, "custom_bitnet.py"), "w") as f:
            f.write(custom_model_code)

        # Save configuration
        model.config.save_pretrained(save_directory)

        # Save model and tokenizer
        print("Saving model and tokenizer...")
        model.save_pretrained(save_directory, safe_serialization=True, max_shard_size="5GB")
        tokenizer.save_pretrained(save_directory)

        # Update config.json to reference custom class
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "r") as f:
            config_json = json.load(f)
        config_json["architectures"] = ["BitNetForCausalLM"]
        with open(config_path, "w") as f:
            json.dump(config_json, f, indent=2)

        # Try TensorFlow conversion
        try:
            from transformers import TFAutoModelForCausalLM
            print("Converting to TensorFlow weights...")
            tf_model = TFAutoModelForCausalLM.from_pretrained(save_directory, from_pt=True)
            tf_model.save_pretrained(save_directory)
            print("TensorFlow weights saved.")
        except Exception as e:
            print(f"Error converting to TensorFlow: {e}")
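        # --- Sketch of the remaining steps: the original file is truncated at this
        # point. Pushing the saved directory to the Hub via HfApi.upload_folder,
        # the JSON response, the error handler, and the app.run entry point below
        # are assumptions about how the endpoint was meant to finish, not code
        # recovered from the source.
        api = HfApi()
        api.create_repo(repo_id=repo_id, exist_ok=True)
        api.upload_folder(folder_path=save_directory, repo_id=repo_id, repo_type="model")
        print(f"Model pushed to https://huggingface.co/{repo_id}")

        return jsonify({"status": "ok", "repo_id": repo_id})

    except Exception as e:
        print("Error during save_model:", e)
        return {"error": str(e)}, 500


if __name__ == "__main__":
    # Port 5000 is Flask's default; adjust host/port for your deployment.
    app.run(host="0.0.0.0", port=5000)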