Frankie-walsh4 committed
Commit fdf2b7f · Parent: ee12cf3
Files changed (1)
  1. app.py +133 -207
app.py CHANGED
@@ -1,153 +1,130 @@
  import gradio as gr
  import os
  import time
- import json
- import requests
  import threading
- from huggingface_hub.errors import HfHubHTTPError

- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- # Get token from environment (even though we might not need it)
- HF_TOKEN = os.environ.get("HF_TOKEN")
- print(f"HF_TOKEN is {'available' if HF_TOKEN else 'not available'}")
-
- # Setup API for the Hugging Face Inference API
- API_URL = "https://api-inference.huggingface.co/models/Trinoid/Data_Management_Mistral"
- headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}

- print("Trying to access model directly via API")
- response = requests.get(API_URL, headers=headers)
- print(f"Status: {response.status_code}")
- print(f"Response: {response.text[:200]}...")  # Print first 200 chars of response
-
- # Global variable to track if model is warmed up
- model_warmed_up = False
  model_loading = False
- estimated_time = None

- def query_model(inputs, parameters=None):
-     """Send a query to the model via the Inference API"""
-     payload = {
-         "inputs": inputs,
-     }

-     if parameters:
-         payload["parameters"] = parameters
-
-     print(f"Sending query to API: {json.dumps(payload, indent=2)[:200]}...")

-     # Try multiple times with backoff
-     max_attempts = 5
-     for attempt in range(max_attempts):
-         try:
-             response = requests.post(
-                 API_URL,
-                 headers=headers,
-                 json=payload,
-                 timeout=180  # 3 minute timeout
-             )
-
-             print(f"API response status: {response.status_code}")
-
-             # If successful, return the result
-             if response.status_code == 200:
-                 return response.json()
-
-             # If model is loading, handle the error
-             elif response.status_code == 503 and "estimated_time" in response.json():
-                 est_time = response.json()["estimated_time"]
-                 print(f"Model is loading. Estimated time: {est_time:.2f} seconds")
-
-                 # Wait a portion of the estimated time
-                 wait_time = min(30, max(10, est_time / 4))
-                 print(f"Waiting {wait_time:.2f} seconds before retry...")
-                 time.sleep(wait_time)
-
-             # For other errors, wait and retry
-             else:
-                 print(f"API error: {response.text}")
-                 wait_time = 10 * (attempt + 1)
-                 print(f"Waiting {wait_time} seconds before retry...")
-                 time.sleep(wait_time)
-
-         except Exception as e:
-             print(f"Request exception: {str(e)}")
-             wait_time = 15 * (attempt + 1)
-             print(f"Waiting {wait_time} seconds before retry...")
-             time.sleep(wait_time)

-     # If we've tried all attempts and still failed, return None
-     return None
-
- def is_model_loaded():
-     """Check if the model is loaded and ready for inference"""
      try:
-         # Send a simple query to check model status
-         response = requests.get(API_URL, headers=headers)

-         # If we get a 200, the model is ready
-         if response.status_code == 200:
-             return True

-         # If we get a 503 with estimated_time, it's loading
-         if response.status_code == 503 and "estimated_time" in response.json():
-             global estimated_time
-             estimated_time = response.json()["estimated_time"]
-             return False
-
-         # Other response indicates an issue
-         return False

      except Exception as e:
-         print(f"Error checking model status: {str(e)}")
-         return False

- def warm_up_model():
-     """Send a warmup request to get the model loaded"""
-     global model_warmed_up, model_loading

-     if model_loading:
-         return  # Already warming up
-
-     model_loading = True

-     # Check if model is already loaded
-     if is_model_loaded():
-         print("Model is already loaded!")
-         model_warmed_up = True
-         model_loading = False
-         return
-
-     print("Starting model warm-up with basic query...")

-     # Try to trigger model loading with a simple query
-     inputs = [
-         {"role": "system", "content": "You are a helpful assistant."},
-         {"role": "user", "content": "Hi"}
-     ]

-     parameters = {
-         "max_new_tokens": 5,
-         "temperature": 0.1,
-         "top_p": 0.95,
-         "do_sample": True
      }

-     # Send the query and check result
-     result = query_model(inputs, parameters)

-     if result:
-         print("Warmup successful! Model is ready.")
-         model_warmed_up = True
-     else:
-         print("Warmup failed. Will try again during first user query.")
-
-     model_loading = False
-
- # Start warmup in background thread
- threading.Thread(target=warm_up_model, daemon=True).start()

  def respond(
      message,
@@ -157,7 +134,20 @@ def respond(
      temperature,
      top_p,
  ):
-     global model_warmed_up, estimated_time

      # Create the messages list
      messages = [{"role": "system", "content": system_message}]
@@ -170,84 +160,19 @@ def respond(

      messages.append({"role": "user", "content": message})

-     # Check if the model is ready
-     if not model_warmed_up and not is_model_loaded():
-         if estimated_time:
-             yield f"⌛ Model is being loaded, estimated wait time: {estimated_time:.0f} seconds. Please be patient or try again later."
-         else:
-             yield "⌛ Model is being loaded. This may take some time on the first use."
-
-     # Set up parameters for the query
-     parameters = {
-         "max_new_tokens": max_tokens,
-         "temperature": temperature,
-         "top_p": top_p,
-         "do_sample": True
-     }
-
-     # Try multiple times if needed
-     max_retries = 5
-     for attempt in range(max_retries):
-         try:
-             print(f"Attempt {attempt + 1}/{max_retries} to query the model...")
-
-             # Make API request
-             result = query_model(messages, parameters)
-
-             if result:
-                 # Handle different response formats
-                 if isinstance(result, list) and len(result) > 0:
-                     if "generated_text" in result[0]:
-                         response = result[0]["generated_text"]
-                         model_warmed_up = True
-                         yield response
-                         return
-
-                 # Direct message response format
-                 if isinstance(result, dict) and "generated_text" in result:
-                     response = result["generated_text"]
-                     model_warmed_up = True
-                     yield response
-                     return
-
-                 # For completion format
-                 if isinstance(result, str):
-                     model_warmed_up = True
-                     yield result
-                     return
-
-                 # Unknown format, show raw result
-                 print(f"Unexpected response format: {json.dumps(result, indent=2)[:500]}...")
-                 model_warmed_up = True
-                 yield str(result)
-                 return
-
-             # If query_model returned None, it means all its retries failed
-             print(f"Query attempt {attempt + 1} failed completely")
-
-             if attempt < max_retries - 1:
-                 wait_time = 20 * (attempt + 1)
-                 yield f"⌛ Still trying to get a response (Attempt {attempt + 1}/{max_retries})..."
-                 time.sleep(wait_time)
-             else:
-                 yield """❌ The model couldn't be accessed after multiple attempts.
-
- If you're seeing this on the Nvidia L40 hardware, please try:
- 1. Restarting the space
- 2. Checking your model's size and format
- 3. Contacting Hugging Face support if the issue persists"""
-                 return
-
-         except Exception as e:
-             print(f"Unexpected error: {str(e)}")
-
-             if attempt < max_retries - 1:
-                 wait_time = 15
-                 yield f"⌛ An error occurred. Retrying (Attempt {attempt + 1}/{max_retries})..."
-                 time.sleep(wait_time)
-             else:
-                 yield f"❌ An error occurred after multiple attempts: {str(e)}"
-                 return

  """
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
@@ -266,8 +191,9 @@ demo = gr.ChatInterface(
              label="Top-p (nucleus sampling)",
          ),
      ],
-     description="""This interface uses a fine-tuned Mistral model for Microsoft 365 data management.
-     This model runs on Nvidia L40 GPU hardware for optimal performance."""
  )

  import gradio as gr
  import os
  import time
+ import torch
+ import traceback
  import threading
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TextIteratorStreamer
+ from peft import PeftModel

+ print("CUDA available:", torch.cuda.is_available())
+ if torch.cuda.is_available():
+     print(f"CUDA device count: {torch.cuda.device_count()}")
+     print(f"CUDA device: {torch.cuda.get_device_name(0)}")
+     print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

+ # Global variable to track model loading
+ model_loaded = False
  model_loading = False
+ loading_error = None
+ model = None
+ tokenizer = None
+ pipe = None

+ def load_model_in_thread():
+     """Load the model in a separate thread to avoid blocking the UI"""
+     global model_loaded, model_loading, loading_error, model, tokenizer, pipe

+     if model_loading:
+         return  # Already loading

+     model_loading = True
+     print("Starting model loading process...")

      try:
+         # Load base model
+         model_id = "mistralai/Mistral-7B-Instruct-v0.2"
+         adapter_id = "Trinoid/Data_Management_Mistral"
+         print(f"Loading base model {model_id}...")

+         # Initialize tokenizer
+         tokenizer = AutoTokenizer.from_pretrained(model_id)
+         print("Tokenizer loaded successfully")

+         # Load the base model in 4-bit
+         model = AutoModelForCausalLM.from_pretrained(
+             model_id,
+             torch_dtype=torch.float16,
+             device_map="auto",
+             load_in_4bit=True,
+         )
+         print("Base model loaded successfully")

+         # Load and apply the LoRA adapter
+         print(f"Loading adapter {adapter_id}...")
+         model = PeftModel.from_pretrained(model, adapter_id)
+         print("Adapter loaded and applied successfully")
+
+         # Set up pipeline
+         print("Creating text generation pipeline...")
+         pipe = pipeline(
+             "text-generation",
+             model=model,
+             tokenizer=tokenizer,
+             device_map="auto",
+         )
+         print("Pipeline created successfully")
+
+         model_loaded = True
+         print("Model loading complete! Ready for inference.")
+
      except Exception as e:
+         loading_error = str(e)
+         print(f"Error loading model: {str(e)}")
+         traceback.print_exc()
+
+     finally:
+         model_loading = False

+ # Start model loading in background thread
+ threading.Thread(target=load_model_in_thread, daemon=True).start()
+
+ def format_chat_prompt(messages):
+     """Format messages into a prompt that Mistral-7B-Instruct can understand"""
+     prompt = ""
+     for message in messages:
+         if message["role"] == "system":
+             prompt += f"<s>[INST] {message['content']} [/INST]</s>\n"
+         elif message["role"] == "user":
+             prompt += f"<s>[INST] {message['content']} [/INST]"
+         elif message["role"] == "assistant":
+             prompt += f" {message['content']} </s>\n"
+     return prompt
+
+ def generate_response(messages, max_new_tokens=512, temperature=0.7, top_p=0.95):
+     """Generate a response from the model"""
+     global model_loaded, loading_error, model, tokenizer, pipe

+     if not model_loaded:
+         if loading_error:
+             return f"Error loading model: {loading_error}"
+         return "Model is still loading. Please wait a moment and try again."

+     # Format the prompt for Mistral
+     prompt = format_chat_prompt(messages)

+     # Set up the streamer for incremental generation
+     streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

+     # Generate in a separate thread to enable streaming
+     generation_kwargs = {
+         "input_ids": tokenizer.encode(prompt, return_tensors="pt").to("cuda"),
+         "max_new_tokens": max_new_tokens,
+         "temperature": temperature,
+         "top_p": top_p,
+         "do_sample": True,
+         "streamer": streamer,
      }

+     # Start generation in a thread
+     thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()

+     # Stream the output
+     generated_text = ""
+     for new_text in streamer:
+         generated_text += new_text
+         yield generated_text

  def respond(
      message,

      temperature,
      top_p,
  ):
+     """Respond to user messages"""
+     global model_loaded, model_loading
+
+     # Check if model is loaded
+     if not model_loaded:
+         if model_loading:
+             yield "⌛ The model is still loading. This can take a few minutes on first startup. Please wait or try again later."
+             return
+         else:
+             # Try loading the model if it hasn't started yet
+             if not threading.active_count() > 1:  # No background thread running
+                 threading.Thread(target=load_model_in_thread, daemon=True).start()
+             yield "⌛ Starting model load now. Please wait a moment and try again."
+             return

      # Create the messages list
      messages = [{"role": "system", "content": system_message}]

      messages.append({"role": "user", "content": message})

+     # Generate and stream the response
+     try:
+         for response in generate_response(
+             messages,
+             max_new_tokens=max_tokens,
+             temperature=temperature,
+             top_p=top_p
+         ):
+             yield response
+     except Exception as e:
+         print(f"Error generating response: {str(e)}")
+         traceback.print_exc()
+         yield f"An error occurred while generating the response: {str(e)}"

  """
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface

              label="Top-p (nucleus sampling)",
          ),
      ],
+     description="""This interface uses a fine-tuned Mistral model for Microsoft 365 data management.
+     The model is loaded directly on the L40 GPU for optimal performance.
+     First-time loading may take a few minutes."""
  )
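
A note on the new loading path: the commit passes `load_in_4bit=True` straight to `from_pretrained`, while newer transformers releases route 4-bit options through a `BitsAndBytesConfig` object. The sketch below is an illustration of that variant, not part of the commit, and it assumes `bitsandbytes` is installed in the Space.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

# Hypothetical stand-alone loader mirroring load_model_in_thread(), using an
# explicit BitsAndBytesConfig rather than the bare load_in_4bit flag.
def load_quantized_model(
    model_id="mistralai/Mistral-7B-Instruct-v0.2",
    adapter_id="Trinoid/Data_Management_Mistral",
):
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,                     # quantize linear weights to 4-bit
        bnb_4bit_compute_dtype=torch.float16,  # do the matmuls in fp16
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quant_config,
        device_map="auto",  # place layers on the available GPU
    )

    # Attach the LoRA adapter on top of the quantized base model.
    model = PeftModel.from_pretrained(model, adapter_id)
    return model, tokenizer

The config form also makes the compute dtype explicit, which the bare flag leaves at its default.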
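
Similarly, `format_chat_prompt` builds the `[INST]` prompt by hand, and its treatment of the system turn differs from the tokenizer's bundled chat template. The sketch below (again an illustration, not part of the commit) uses `apply_chat_template` instead, folding the system message into the first user turn because the Mistral-7B-Instruct template defines no system role.

from transformers import AutoTokenizer

# Hypothetical alternative to format_chat_prompt(): let the tokenizer's bundled
# chat template produce the [INST] formatting instead of building it by hand.
def build_prompt_with_template(messages, tokenizer):
    # The Mistral-7B-Instruct template has no system role, so fold any system
    # message into the first user turn before applying the template.
    system_text = " ".join(m["content"] for m in messages if m["role"] == "system")
    chat = []
    for m in messages:
        if m["role"] == "system":
            continue
        if m["role"] == "user" and system_text and not chat:
            chat.append({"role": "user", "content": f"{system_text}\n\n{m['content']}"})
        else:
            chat.append(m)

    # tokenize=False returns the formatted prompt string; add_generation_prompt=True
    # asks the template to stop where the assistant's reply should begin.
    return tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
    demo_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hi"},
    ]
    print(build_prompt_with_template(demo_messages, tok))

The returned string can then be tokenized and passed to `model.generate` the same way the commit does with its hand-built prompt.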