mike23415 committed on
Commit 01046ae · verified · 1 Parent(s): ddb7bf1

Create app.py

Files changed (1)
  1. app.py +202 -0
app.py ADDED
@@ -0,0 +1,202 @@
+import threading
+import torch
+import os
+import json
+from flask import Flask, request, Response, jsonify
+from flask_cors import CORS
+from huggingface_hub import HfApi, login
+
+app = Flask(__name__)
+CORS(app)
+
+# Global state
+tokenizer = None
+model = None
+model_loading = False
+model_loaded = False
+model_id = "microsoft/bitnet-b1.58-2B-4T"
+
+# Load model in background
+def load_model_thread():
+    global tokenizer, model, model_loaded, model_loading
+    try:
+        model_loading = True
+        from transformers import AutoTokenizer, AutoModelForCausalLM
+        print("Loading tokenizer...")
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        print("Loading model...")
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            torch_dtype=torch.float32,
+            device_map=None
+        ).to("cpu")
+        model_loaded = True
+        print("✅ Model loaded successfully.")
+    except Exception as e:
+        print(f"❌ Error loading model: {e}")
+    finally:
+        model_loading = False
+
+# Start background model load
+threading.Thread(target=load_model_thread, daemon=True).start()
+
+@app.route("/")
+def home():
+    return "🚀 Flask backend for BitNet is running!"
+
+@app.route("/api/health", methods=["GET"])
+def health():
+    """Health check endpoint"""
+    return {
+        "status": "ok",
+        "model_loaded": model_loaded,
+        "model_loading": model_loading
+    }
+
+@app.route("/api/chat", methods=["POST"])
+def chat():
+    """Chat endpoint with BitNet streaming response"""
+    global model_loaded, model, tokenizer
+
+    if not model_loaded:
+        return {
+            "status": "initializing",
+            "message": "Model is still loading. Please try again shortly."
+        }, 503
+
+    try:
+        from transformers import TextIteratorStreamer
+        data = request.get_json()
+        message = data.get("message", "")
+        history = data.get("history", [])
+        system_message = data.get("system_message", (
+            "You are a helpful assistant. When generating code, always wrap it in markdown code blocks (```) "
+            "with the appropriate language identifier (e.g., ```python, ```javascript). "
+            "Ensure proper indentation and line breaks for readability."
+        ))
+        max_tokens = data.get("max_tokens", 512)
+        temperature = data.get("temperature", 0.7)
+        top_p = data.get("top_p", 0.95)
+
+        messages = [{"role": "system", "content": system_message}]
+        for user_msg, bot_msg in history:
+            messages.append({"role": "user", "content": user_msg})
+            messages.append({"role": "assistant", "content": bot_msg})
+        messages.append({"role": "user", "content": message})
+
+        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
+
+        streamer = TextIteratorStreamer(
+            tokenizer, skip_prompt=True, skip_special_tokens=True
+        )
+
+        generate_kwargs = dict(
+            **inputs,
+            streamer=streamer,
+            max_new_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            do_sample=True,
+        )
+
+        thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
+        thread.start()
+
+        def generate():
+            for new_text in streamer:
+                yield f"data: {json.dumps({'response': new_text})}\n\n"
+            yield "data: [DONE]\n\n"
+
+        return Response(generate(), mimetype="text/event-stream")
+
+    except Exception as e:
+        print("Error during chat:", e)
+        return {"error": str(e)}, 500
+
+@app.route("/api/save_model", methods=["POST"])
+def save_model():
+    """Save model and tokenizer to Hugging Face Hub"""
+    global model, tokenizer, model_loaded
+
+    if not model_loaded:
+        return {"error": "Model is still loading. Try again later."}, 503
+
+    try:
+        # Authenticate with Hugging Face
+        token = request.json.get("token")
+        if not token:
+            return {"error": "Hugging Face token required"}, 400
+        login(token=token)
+
+        # Define repository
+        repo_id = "priyanshu/playwebit"
+        save_directory = "/tmp/playwebit"
+
+        # Create temporary directory
+        os.makedirs(save_directory, exist_ok=True)
+
+        # Save custom model class (replace with actual implementation)
+        custom_model_code = """
+from transformers import PreTrainedModel
+from transformers.models.bitnet.configuration_bitnet import BitNetConfig
+
+class BitNetForCausalLM(PreTrainedModel):
+    config_class = BitNetConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        # Placeholder: Copy implementation from fork's modeling_bitnet.py
+        raise NotImplementedError("Replace with actual BitNetForCausalLM implementation")
+
+    def forward(self, *args, **kwargs):
+        # Placeholder: Copy forward pass from fork
+        raise NotImplementedError("Replace with actual forward pass implementation")
+"""
+        with open(os.path.join(save_directory, "custom_bitnet.py"), "w") as f:
+            f.write(custom_model_code)
+
+        # Save configuration
+        model.config.save_pretrained(save_directory)
+
+        # Save model and tokenizer
+        print("Saving model and tokenizer...")
+        model.save_pretrained(save_directory, safe_serialization=True, max_shard_size="5GB")
+        tokenizer.save_pretrained(save_directory)
+
+        # Update config.json to reference custom class
+        config_path = os.path.join(save_directory, "config.json")
+        with open(config_path, "r") as f:
+            config_json = json.load(f)
+        config_json["architectures"] = ["BitNetForCausalLM"]
+        with open(config_path, "w") as f:
+            json.dump(config_json, f, indent=2)
+
+        # Try TensorFlow conversion
+        try:
+            from transformers import TFAutoModelForCausalLM
+            print("Converting to TensorFlow weights...")
+            tf_model = TFAutoModelForCausalLM.from_pretrained(save_directory, from_pt=True)
+            tf_model.save_pretrained(save_directory)
+            print("TensorFlow weights saved.")
+        except Exception as e:
+            print(f"Error converting to TensorFlow: {e}")
+
+        # Upload to Hugging Face Hub
+        api = HfApi()
+        print(f"Uploading to {repo_id}...")
+        api.upload_folder(
+            folder_path=save_directory,
+            repo_id=repo_id,
+            repo_type="model",
+            commit_message="Upload PlayWeBit model, tokenizer, and custom class"
+        )
+
+        return {"message": f"Model uploaded to https://huggingface.co/{repo_id}"}
+
+    except Exception as e:
+        print("Error saving model:", e)
+        return {"error": str(e)}, 500
+
+if __name__ == "__main__":
+    app.run(host="0.0.0.0", port=7860)
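
The /api/chat route above returns a server-sent-event stream, so a client has to read the response line by line rather than as a single JSON body. Below is a minimal client sketch; it assumes the server from app.py is reachable at http://localhost:7860 and that the `requests` package is installed. The endpoint path and payload fields ("message", "history", "max_tokens") mirror those read in chat(); everything else (the example prompt, host, variable names) is illustrative only.

# sketch: consume the SSE stream produced by /api/chat
import json
import requests

resp = requests.post(
    "http://localhost:7860/api/chat",
    json={"message": "Write a haiku about CPUs.", "history": [], "max_tokens": 128},
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    # each event arrives as a line of the form "data: {...}"
    if not line or not line.startswith("data: "):
        continue
    payload = line[len("data: "):]
    if payload == "[DONE]":
        break
    print(json.loads(payload)["response"], end="", flush=True)
print()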