Commit: a0ee3bd ("fixes")
Parent(s): fdf2b7f

app.py CHANGED
@@ -1,8 +1,8 @@
 import gradio as gr
 import os
 import time
-import
-import
 import threading
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextIteratorStreamer
 from peft import PeftModel
@@ -13,118 +13,74 @@ if torch.cuda.is_available():
     print(f"CUDA device: {torch.cuda.get_device_name(0)}")
     print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

-#
 model_loaded = False
 model_loading = False
-
-
-
-
-
-
-
-    global model_loaded, model_loading, loading_error, model, tokenizer, pipe

-    if
-

-
-    print("Starting model loading process...")

     try:
-        #
-
-
-
-
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        print("Tokenizer loaded successfully")
-
-        # Load the base model in 4-bit
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            torch_dtype=torch.float16,
-            device_map="auto",
-            load_in_4bit=True,
         )
-        print("Base model loaded successfully")

-
-        print(f"Loading adapter {adapter_id}...")
-        model = PeftModel.from_pretrained(model, adapter_id)
-        print("Adapter loaded and applied successfully")
-
-        # Set up pipeline
-        print("Creating text generation pipeline...")
-        pipe = pipeline(
-            "text-generation",
-            model=model,
-            tokenizer=tokenizer,
-            device_map="auto",
-        )
-        print("Pipeline created successfully")

-
-
-
     except Exception as e:
-
-
-        traceback.print_exc()
-
-    finally:
-        model_loading = False
-
-# Start model loading in background thread
-threading.Thread(target=load_model_in_thread, daemon=True).start()
-
-def format_chat_prompt(messages):
-    """Format messages into a prompt that Mistral-7B-Instruct can understand"""
-    prompt = ""
-    for message in messages:
-        if message["role"] == "system":
-            prompt += f"<s>[INST] {message['content']} [/INST]</s>\n"
-        elif message["role"] == "user":
-            prompt += f"<s>[INST] {message['content']} [/INST]"
-        elif message["role"] == "assistant":
-            prompt += f" {message['content']} </s>\n"
-    return prompt
-
-def generate_response(messages, max_new_tokens=512, temperature=0.7, top_p=0.95):
-    """Generate a response from the model"""
-    global model_loaded, loading_error, model, tokenizer, pipe
-
-    if not model_loaded:
-        if loading_error:
-            return f"Error loading model: {loading_error}"
-        return "Model is still loading. Please wait a moment and try again."
-
-    # Format the prompt for Mistral
-    prompt = format_chat_prompt(messages)
-
-    # Set up the streamer for incremental generation
-    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
-
-    # Generate in a separate thread to enable streaming
-    generation_kwargs = {
-        "input_ids": tokenizer.encode(prompt, return_tensors="pt").to("cuda"),
-        "max_new_tokens": max_new_tokens,
-        "temperature": temperature,
-        "top_p": top_p,
-        "do_sample": True,
-        "streamer": streamer,
-    }
-
-    # Start generation in a thread
-    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
-    thread.start()
-
-    # Stream the output
-    generated_text = ""
-    for new_text in streamer:
-        generated_text += new_text
-        yield generated_text

 def respond(
     message,
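For reference, the removed format_chat_prompt helper above flattened the chat history into Mistral's [INST] instruction format. A small worked example of the string it would return; the conversation below is hypothetical and not part of the commit:

# Hypothetical conversation, only to illustrate the string the removed
# format_chat_prompt() built for Mistral-7B-Instruct.
messages = [
    {"role": "system", "content": "You are a Microsoft 365 data management assistant."},
    {"role": "user", "content": "How do I find inactive SharePoint sites?"},
]

# What format_chat_prompt(messages) would return:
expected_prompt = (
    "<s>[INST] You are a Microsoft 365 data management assistant. [/INST]</s>\n"
    "<s>[INST] How do I find inactive SharePoint sites? [/INST]"
)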
@@ -135,21 +91,8 @@ def respond(
     top_p,
 ):
     """Respond to user messages"""
-    global model_loaded, model_loading
-
-    # Check if model is loaded
-    if not model_loaded:
-        if model_loading:
-            yield "β The model is still loading. This can take a few minutes on first startup. Please wait or try again later."
-            return
-        else:
-            # Try loading the model if it hasn't started yet
-            if not threading.active_count() > 1:  # No background thread running
-                threading.Thread(target=load_model_in_thread, daemon=True).start()
-            yield "β Starting model load now. Please wait a moment and try again."
-            return

-    # Create the messages list
     messages = [{"role": "system", "content": system_message}]

     for val in history:
@@ -160,19 +103,85 @@ def respond(

     messages.append({"role": "user", "content": message})

-    #
-
-
-
-
-
-
-
-
-
-
-

 """
 For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
@@ -192,8 +201,8 @@ demo = gr.ChatInterface(
         ),
     ],
     description="""This interface uses a fine-tuned Mistral model for Microsoft 365 data management.
-    The model is
-    First
 )

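The removed code above loaded the model in-process and streamed tokens by running model.generate() in a worker thread while the caller iterated a TextIteratorStreamer. A minimal, self-contained sketch of that pattern; it is not part of the commit, and it uses a tiny stand-in checkpoint (sshleifer/tiny-gpt2) instead of the Space's Mistral base model and PEFT adapter:

import threading
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "sshleifer/tiny-gpt2"  # stand-in only; the Space used a Mistral base plus an adapter
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
inputs = tokenizer("Hello", return_tensors="pt")

# generate() blocks, so it runs in a thread while the main thread consumes the streamer
thread = threading.Thread(
    target=model.generate,
    kwargs={**inputs, "max_new_tokens": 20, "streamer": streamer},
)
thread.start()

text = ""
for chunk in streamer:
    text += chunk
    print(text)
thread.join()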
@@ -1,8 +1,8 @@
 import gradio as gr
 import os
 import time
+import json
+import requests
 import threading
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextIteratorStreamer
 from peft import PeftModel
@@ -13,118 +13,74 @@ if torch.cuda.is_available():
     print(f"CUDA device: {torch.cuda.get_device_name(0)}")
     print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

+# Get token from environment
+HF_TOKEN = os.environ.get("HF_TOKEN")
+print(f"HF_TOKEN is {'available' if HF_TOKEN else 'not available'}")
+
+# Setup API for the Hugging Face Inference API
+MODEL_ID = "Trinoid/Data_Management_Mistral"
+API_URL = f"https://api-inference.huggingface.co/models/{MODEL_ID}"
+headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
+
+# Check if model exists
+try:
+    print(f"Checking if model {MODEL_ID} exists...")
+    response = requests.get(API_URL, headers=headers)
+    print(f"Status: {response.status_code}")
+    if response.status_code == 200:
+        print("Model exists and is accessible")
+    else:
+        print(f"Response: {response.text}")
+except Exception as e:
+    print(f"Error checking model: {str(e)}")
+
+# Global variable to track model status
 model_loaded = False
 model_loading = False
+estimated_time = None
+
+def query_model(messages, parameters=None):
+    """Query the model using the Inference API"""
+    payload = {
+        "inputs": messages,
+    }

+    if parameters:
+        payload["parameters"] = parameters

+    print(f"Sending query to API...")

     try:
+        # Single attempt with longer timeout
+        response = requests.post(
+            API_URL,
+            headers=headers,
+            json=payload,
+            timeout=180  # 3 minute timeout
         )

+        print(f"API response status: {response.status_code}")

+        # If successful, return the response
+        if response.status_code == 200:
+            return response.json()
+
+        # If model is loading, handle it
+        elif response.status_code == 503 and "estimated_time" in response.json():
+            est_time = response.json()["estimated_time"]
+            global estimated_time
+            estimated_time = est_time
+            print(f"Model is loading. Estimated time: {est_time:.2f} seconds")
+            return None
+
+        # For other errors
+        else:
+            print(f"API error: {response.text}")
+            return None
+
     except Exception as e:
+        print(f"Request exception: {str(e)}")
+        return None

 def respond(
     message,
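query_model() above wraps a single POST to the hosted Inference API. A minimal sketch of that request and response handling, reusing the MODEL_ID, header construction, and payload shape from this commit; the example question is hypothetical, and whether the endpoint accepts a chat-style message list as "inputs" (rather than a single prompt string) depends on how the hosted model's task is configured:

import os
import requests

API_URL = "https://api-inference.huggingface.co/models/Trinoid/Data_Management_Mistral"
token = os.environ.get("HF_TOKEN")
headers = {"Authorization": f"Bearer {token}"} if token else {}

payload = {
    "inputs": [{"role": "user", "content": "How do I audit SharePoint permissions?"}],
    "parameters": {"max_new_tokens": 256, "temperature": 0.7, "top_p": 0.95, "do_sample": True},
}

resp = requests.post(API_URL, headers=headers, json=payload, timeout=180)
if resp.status_code == 200:
    print(resp.json())  # text-generation endpoints usually return [{"generated_text": "..."}]
elif resp.status_code == 503:
    # Model not yet loaded on the inference backend; the body typically carries an ETA
    print("Model loading, estimated_time:", resp.json().get("estimated_time"), "seconds")
else:
    print("Error:", resp.status_code, resp.text)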
@@ -135,21 +91,8 @@ def respond(
     top_p,
 ):
     """Respond to user messages"""

+    # Create the messages list in chat format
     messages = [{"role": "system", "content": system_message}]

     for val in history:
@@ -160,19 +103,85 @@ def respond(

     messages.append({"role": "user", "content": message})

+    # Set up the inference parameters
+    parameters = {
+        "max_new_tokens": max_tokens,
+        "temperature": temperature,
+        "top_p": top_p,
+        "do_sample": True
+    }
+
+    # Initial message about model status
+    global estimated_time
+    if estimated_time:
+        initial_msg = f"β The model is loading... estimated time: {estimated_time:.0f} seconds. Please be patient."
+    else:
+        initial_msg = "β Working on your request..."
+
+    yield initial_msg
+
+    # Try multiple times with increasing waits
+    max_retries = 6
+    for attempt in range(max_retries):
+        # Check if this is a retry
+        if attempt > 0:
+            wait_time = min(60, 10 * attempt)
+            yield f"β Still working on your request... (attempt {attempt+1}/{max_retries})"
+            time.sleep(wait_time)
+
+        try:
+            # Query the model
+            result = query_model(messages, parameters)
+
+            if result:
+                # Handle different response formats
+
+                # List format with generated_text
+                if isinstance(result, list) and len(result) > 0:
+                    if "generated_text" in result[0]:
+                        yield result[0]["generated_text"]
+                        return
+
+                # Direct message format
+                if isinstance(result, dict) and "generated_text" in result:
+                    yield result["generated_text"]
+                    return
+
+                # String format
+                if isinstance(result, str):
+                    yield result
+                    return
+
+                # Raw format as fallback
+                yield str(result)
+                return
+
+            # If model is still loading, get the latest estimate
+            if estimated_time and attempt < max_retries - 1:
+                response = requests.get(API_URL, headers=headers)
+                if response.status_code == 503 and "estimated_time" in response.json():
+                    estimated_time = response.json()["estimated_time"]
+                    print(f"Updated loading time: {estimated_time:.0f} seconds")
+
+        except Exception as e:
+            print(f"Error in attempt {attempt+1}: {str(e)}")
+            if attempt == max_retries - 1:
+                yield f"""β Sorry, I couldn't generate a response after several attempts.
+
+Error details: {str(e)}
+
+Please try again later or contact support if this persists."""
+
+    # If all retries failed
+    yield """β The model couldn't be accessed after multiple attempts.
+
+This could be due to:
+1. Heavy server load
+2. The model being too large for the current hardware
+3. Temporary service issues
+
+Please try again later."""
+

 """
 For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
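The retry loop above issues the first request immediately and sleeps min(60, 10 * attempt) seconds before each retry, so with max_retries = 6 the waits between attempts work out as follows (a quick check, not part of the commit):

max_retries = 6
waits = [min(60, 10 * attempt) for attempt in range(1, max_retries)]
print(waits)       # [10, 20, 30, 40, 50]
print(sum(waits))  # 150 seconds of sleeping, on top of up to 180 s per request timeout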
@@ -192,8 +201,8 @@ demo = gr.ChatInterface(
         ),
     ],
     description="""This interface uses a fine-tuned Mistral model for Microsoft 365 data management.
+    The model is accessed via the Hugging Face Inference API.
+    First requests may take 2-3 minutes as the model loads."""
 )

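Only the tail of the gr.ChatInterface call is visible in the hunks above. For orientation, a hedged sketch of how a streaming respond() generator is typically wired into gr.ChatInterface with additional_inputs matching respond's (message, history, system_message, max_tokens, temperature, top_p) signature; the widget choices and default values are illustrative assumptions, not the Space's actual ones:

import gradio as gr

def respond(message, history, system_message, max_tokens, temperature, top_p):
    # Stand-in for the Space's respond() generator shown in the diff above.
    yield f"(echo) {message}"

demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a Microsoft 365 data management assistant.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    description="""This interface uses a fine-tuned Mistral model for Microsoft 365 data management.
    The model is accessed via the Hugging Face Inference API.
    First requests may take 2-3 minutes as the model loads.""",
)

if __name__ == "__main__":
    demo.launch()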