import json
import os
import threading

import torch
from flask import Flask, request, Response, jsonify
from flask_cors import CORS
from huggingface_hub import HfApi, login

app = Flask(__name__)
CORS(app)
# Global state
tokenizer = None
model = None
model_loading = False
model_loaded = False
model_id = "microsoft/bitnet-b1.58-2B-4T"
# Load model in background
def load_model_thread():
    global tokenizer, model, model_loaded, model_loading
    try:
        model_loading = True
        from transformers import AutoTokenizer, AutoModelForCausalLM
        print("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        print("Loading model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
            device_map=None
        ).to("cpu")
        model_loaded = True
        print("✅ Model loaded successfully.")
    except Exception as e:
        print(f"❌ Error loading model: {e}")
    finally:
        model_loading = False

# Start background model load
threading.Thread(target=load_model_thread, daemon=True).start()
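
# Note: the model loads asynchronously, so clients should poll /api/health until
# "model_loaded" is true before calling /api/chat or /api/save_model (both return
# 503 while loading).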
@app.route("/")
def home():
return "🚀 Flask backend for BitNet is running!"
@app.route("/api/health", methods=["GET"])
def health():
"""Health check endpoint"""
return {
"status": "ok",
"model_loaded": model_loaded,
"model_loading": model_loading
}
@app.route("/api/chat", methods=["POST"])
def chat():
"""Chat endpoint with BitNet streaming response"""
global model_loaded, model, tokenizer
if not model_loaded:
return {
"status": "initializing",
"message": "Model is still loading. Please try again shortly."
}, 503
try:
from transformers import TextIteratorStreamer
data = request.get_json()
message = data.get("message", "")
history = data.get("history", [])
system_message = data.get("system_message", (
"You are a helpful assistant. When generating code, always wrap it in markdown code blocks (```) "
"with the appropriate language identifier (e.g., ```python, ```javascript). "
"Ensure proper indentation and line breaks for readability."
))
max_tokens = data.get("max_tokens", 512)
temperature = data.get("temperature", 0.7)
top_p = data.get("top_p", 0.95)
messages = [{"role": "system", "content": system_message}]
for user_msg, bot_msg in history:
messages.append({"role": "user", "content": user_msg})
messages.append({"role": "assistant", "content": bot_msg})
messages.append({"role": "user", "content": message})
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
streamer = TextIteratorStreamer(
tokenizer, skip_prompt=True, skip_special_tokens=True
)
generate_kwargs = dict(
**inputs,
streamer=streamer,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
do_sample=True,
)
thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()
def generate():
for new_text in streamer:
yield f"data: {json.dumps({'response': new_text})}\n\n"
yield "data: [DONE]\n\n"
return Response(generate(), mimetype="text/event-stream")
except Exception as e:
print("Error during chat:", e)
return {"error": str(e)}, 500
@app.route("/api/save_model", methods=["POST"])
def save_model():
"""Save model and tokenizer to Hugging Face Hub"""
global model, tokenizer, model_loaded
if not model_loaded:
return {"error": "Model is still loading. Try again later."}, 503
try:
# Authenticate with Hugging Face
token = request.json.get("token")
if not token:
return {"error": "Hugging Face token required"}, 400
login(token=token)
# Define repository
repo_id = "mike23415/playwebit"
save_directory = "/tmp/playwebit"
# Create temporary directory
os.makedirs(save_directory, exist_ok=True)
# Save custom model class (replace with actual implementation)
custom_model_code = """
from transformers import PreTrainedModel
from transformers.models.bitnet.configuration_bitnet import BitNetConfig
class BitNetForCausalLM(PreTrainedModel):
config_class = BitNetConfig
def __init__(self, config):
super().__init__(config)
# Placeholder: Copy implementation from fork's modeling_bitnet.py
raise NotImplementedError("Replace with actual BitNetForCausalLM implementation")
def forward(self, *args, **kwargs):
# Placeholder: Copy forward pass from fork
raise NotImplementedError("Replace with actual forward pass implementation")
"""
with open(os.path.join(save_directory, "custom_bitnet.py"), "w") as f:
f.write(custom_model_code)
# Save configuration
model.config.save_pretrained(save_directory)
# Save model and tokenizer
print("Saving model and tokenizer...")
model.save_pretrained(save_directory, safe_serialization=True, max_shard_size="5GB")
tokenizer.save_pretrained(save_directory)
# Update config.json to reference custom class
import json
config_path = os.path.join(save_directory, "config.json")
with open(config_path, "r") as f:
config_json = json.load(f)
config_json["architectures"] = ["BitNetForCausalLM"]
with open(config_path, "w") as f:
json.dump(config_json, f, indent=2)
# Try TensorFlow conversion
try:
from transformers import TFAutoModelForCausalLM
print("Converting to TensorFlow weights...")
tf_model = TFAutoModelForCausalLM.from_pretrained(save_directory, from_pt=True)
tf_model.save_pretrained(save_directory)
print("TensorFlow weights saved.")
except Exception as e:
print(f"Error converting to TensorFlow: {e}") |