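"""Flask backend that loads microsoft/bitnet-b1.58-2B-4T in a background thread,
streams chat replies over Server-Sent Events, and exposes an endpoint that saves
the model and tokenizer for upload to the Hugging Face Hub."""
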
import json
import os
import threading

import torch
from flask import Flask, request, Response, jsonify
from flask_cors import CORS
from huggingface_hub import HfApi, login

app = Flask(__name__)
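# Allow cross-origin requests so a browser frontend hosted elsewhere can call the API.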
CORS(app)

# Global state
tokenizer = None
model = None
model_loading = False
model_loaded = False
model_id = "microsoft/bitnet-b1.58-2B-4T"

# Load model in background
def load_model_thread():
    global tokenizer, model, model_loaded, model_loading
    try:
        model_loading = True
        from transformers import AutoTokenizer, AutoModelForCausalLM
        print("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        print("Loading model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
            device_map=None
        ).to("cpu")
        model_loaded = True
        print("✅ Model loaded successfully.")
    except Exception as e:
        print(f"❌ Error loading model: {e}")
    finally:
        model_loading = False

# Start background model load
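# The server starts accepting requests immediately; /api/health reports progress
# and /api/chat answers with 503 until model_loaded flips to True.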
threading.Thread(target=load_model_thread, daemon=True).start()

@app.route("/")
def home():
    return "🚀 Flask backend for BitNet is running!"

@app.route("/api/health", methods=["GET"])
def health():
    """Health check endpoint"""
    return {
        "status": "ok",
        "model_loaded": model_loaded,
        "model_loading": model_loading
    }
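# Example /api/health response while the model is still loading:
#   {"status": "ok", "model_loaded": false, "model_loading": true}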

@app.route("/api/chat", methods=["POST"])
def chat():
    """Chat endpoint with BitNet streaming response"""
    global model_loaded, model, tokenizer
    if not model_loaded:
        return {
            "status": "initializing",
            "message": "Model is still loading. Please try again shortly."
        }, 503
    try:
        from transformers import TextIteratorStreamer
        data = request.get_json()
        message = data.get("message", "")
        history = data.get("history", [])
        system_message = data.get("system_message", (
            "You are a helpful assistant. When generating code, always wrap it in markdown code blocks (```) "
            "with the appropriate language identifier (e.g., ```python, ```javascript). "
            "Ensure proper indentation and line breaks for readability."
        ))
        max_tokens = data.get("max_tokens", 512)
        temperature = data.get("temperature", 0.7)
        top_p = data.get("top_p", 0.95)
        messages = [{"role": "system", "content": system_message}]
        for user_msg, bot_msg in history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": bot_msg})
        messages.append({"role": "user", "content": message})
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generate_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
        )
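        # Run generation in a background thread; TextIteratorStreamer yields decoded
        # text chunks as they are produced, so the HTTP response can stream them.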
        thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
        thread.start()
        def generate():
            for new_text in streamer:
                yield f"data: {json.dumps({'response': new_text})}\n\n"
            yield "data: [DONE]\n\n"
        return Response(generate(), mimetype="text/event-stream")
    except Exception as e:
        print("Error during chat:", e)
        return {"error": str(e)}, 500

@app.route("/api/save_model", methods=["POST"])
def save_model():
    """Save model and tokenizer to Hugging Face Hub"""
    global model, tokenizer, model_loaded
    if not model_loaded:
        return {"error": "Model is still loading. Try again later."}, 503
    try:
        # Authenticate with Hugging Face
        token = request.json.get("token")
        if not token:
            return {"error": "Hugging Face token required"}, 400
        login(token=token)
        # Define repository
        repo_id = "mike23415/playwebit"
        save_directory = "/tmp/playwebit"
        # Create temporary directory
        os.makedirs(save_directory, exist_ok=True)
        # Save custom model class (replace with actual implementation)
        custom_model_code = """
from transformers import PreTrainedModel
from transformers.models.bitnet.configuration_bitnet import BitNetConfig

class BitNetForCausalLM(PreTrainedModel):
    config_class = BitNetConfig

    def __init__(self, config):
        super().__init__(config)
        # Placeholder: Copy implementation from fork's modeling_bitnet.py
        raise NotImplementedError("Replace with actual BitNetForCausalLM implementation")

    def forward(self, *args, **kwargs):
        # Placeholder: Copy forward pass from fork
        raise NotImplementedError("Replace with actual forward pass implementation")
"""
        with open(os.path.join(save_directory, "custom_bitnet.py"), "w") as f:
            f.write(custom_model_code)
        # Save configuration
        model.config.save_pretrained(save_directory)
        # Save model and tokenizer
        print("Saving model and tokenizer...")
        model.save_pretrained(save_directory, safe_serialization=True, max_shard_size="5GB")
        tokenizer.save_pretrained(save_directory)
        # Update config.json to reference custom class
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "r") as f:
            config_json = json.load(f)
        config_json["architectures"] = ["BitNetForCausalLM"]
        with open(config_path, "w") as f:
            json.dump(config_json, f, indent=2)
        # Try TensorFlow conversion
        try:
            from transformers import TFAutoModelForCausalLM
            print("Converting to TensorFlow weights...")
            tf_model = TFAutoModelForCausalLM.from_pretrained(save_directory, from_pt=True)
            tf_model.save_pretrained(save_directory)
            print("TensorFlow weights saved.")
        except Exception as e:
            print(f"Error converting to TensorFlow: {e}")