Frankie-walsh4 committed on
Commit 387c509 · 1 Parent(s): 8c02af0
Files changed (1)
  1. app.py +76 -9
app.py CHANGED
@@ -4,6 +4,7 @@ import os
 import time
 import json
 import requests
+import threading
 from huggingface_hub.errors import HfHubHTTPError

 """
@@ -25,6 +26,55 @@ else:
 API_URL = "https://api-inference.huggingface.co/models/Trinoid/Data_Management_Mistral"
 headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}

+# Global variable to track if model is warmed up
+model_warmed_up = False
+warming_up = False
+
+def warm_up_model():
+    """Send a warmup request to get the model loaded before user interaction"""
+    global warming_up, model_warmed_up
+
+    if warming_up:
+        return  # Already warming up
+
+    warming_up = True
+    print("Starting model warm-up...")
+
+    # Simple warmup message
+    warmup_messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Hello"}
+    ]
+
+    # Try direct API approach first
+    try:
+        payload = {
+            "inputs": warmup_messages,
+            "parameters": {
+                "max_new_tokens": 5,  # Just need a short response
+                "temperature": 0.1,
+                "top_p": 0.95,
+            },
+            "stream": False,
+        }
+
+        print("Sending warmup request...")
+        response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
+
+        if response.status_code == 200:
+            print("Warmup successful!")
+            model_warmed_up = True
+        else:
+            print(f"Warmup API call failed with status {response.status_code}")
+            print(f"Response: {response.text}")
+    except Exception as e:
+        print(f"Warmup exception: {str(e)}")
+
+    # Even if it failed, mark as no longer warming up
+    warming_up = False
+
+# Start warmup in background thread
+threading.Thread(target=warm_up_model, daemon=True).start()

 def respond(
     message,
@@ -34,6 +84,12 @@ def respond(
     temperature,
     top_p,
 ):
+    global model_warmed_up
+
+    # If model isn't warmed up yet, give a message
+    if not model_warmed_up:
+        yield "⌛ Model is being loaded for the first time, this may take up to a minute. Please be patient..."
+
     messages = [{"role": "system", "content": system_message}]

     for val in history:
@@ -50,7 +106,7 @@ def respond(
     print(f"Sending messages: {json.dumps(messages, indent=2)}")

     # Try to initialize the model with retries
-    max_retries = 3
+    max_retries = 5  # Increased from 3 to 5
     retry_count = 0

     # Try both methods: InferenceClient and direct API call
@@ -68,12 +124,15 @@ def respond(
                     stream=True,
                     temperature=temperature,
                     top_p=top_p,
+                    timeout=30,  # Increased timeout
                 ):
                     token = message.choices[0].delta.content
                     if token:
                         response += token
                         yield response
+
                 # If we got here, we were successful
+                model_warmed_up = True
                 break
             else:
                 # Method 2: Direct API call
@@ -88,7 +147,7 @@ def respond(
                 }

                 print(f"Making direct API call to {API_URL}")
-                api_response = requests.post(API_URL, headers=headers, json=payload)
+                api_response = requests.post(API_URL, headers=headers, json=payload, timeout=60)  # Increased timeout
                 print(f"API response status: {api_response.status_code}")

                 if api_response.status_code == 200:
@@ -97,6 +156,7 @@ def respond(
                     if isinstance(result, list) and len(result) > 0 and "generated_text" in result[0]:
                         response = result[0]["generated_text"]
                         yield response
+                        model_warmed_up = True
                         break
                     else:
                         print(f"Unexpected API response format: {result}")
@@ -105,8 +165,9 @@ def respond(
                     print(f"API error: {api_response.text}")
                     if api_response.status_code == 504 and retry_count < max_retries - 1:
                         retry_count += 1
+                        wait_time = 15  # Increased wait time
                         yield f"⌛ Model is warming up, please wait... (Attempt {retry_count}/{max_retries})"
-                        time.sleep(10)
+                        time.sleep(wait_time)
                     else:
                         yield f"❌ API error: {api_response.status_code} - {api_response.text}"
                         break
@@ -118,15 +179,16 @@ def respond(

             if "504 Server Error: Gateway Timeout" in error_message:
                 if retry_count < max_retries - 1:
-                    wait_time = 10  # seconds
+                    wait_time = 15  # Increased wait time
                     print(f"Model timed out. Waiting {wait_time} seconds before retry {retry_count}/{max_retries}...")
                     yield f"⌛ Model is warming up, please wait... (Attempt {retry_count}/{max_retries})"
                     time.sleep(wait_time)
-                    # Try direct API on next attempt
-                    use_direct_api = True
+                    # Try direct API on next attempt if we've tried InferenceClient twice
+                    if retry_count >= 2:
+                        use_direct_api = True
                 else:
                     print("All retries failed.")
-                    yield "❌ The model timed out after multiple attempts. Try again in a few minutes."
+                    yield "❌ The model timed out after multiple attempts. Your model is probably too large for the free tier. Try again in a few minutes or consider using a smaller model."
                     break
             else:
                 print(f"Non-timeout error: {error_message}")
@@ -146,7 +208,7 @@ For information on how to customize the ChatInterface, peruse the gradio docs: h
 demo = gr.ChatInterface(
     respond,
     additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
+        gr.Textbox(value="You are a data management expert specializing in Microsoft 365 services.", label="System message"),
         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
         gr.Slider(
@@ -157,9 +219,14 @@ demo = gr.ChatInterface(
             label="Top-p (nucleus sampling)",
         ),
     ],
-    description="This interface uses your fine-tuned Mistral model for Microsoft 365 data management. The first request may take some time as the model loads."
+    description="""This interface uses a fine-tuned Mistral model for Microsoft 365 data management.
+    ⚠️ Note: This model needs time to load when first used. You may experience a delay of up to 60 seconds on your first message."""
 )


 if __name__ == "__main__":
+    # Start model warmup
+    warm_up_model()
+
+    # Launch the app
     demo.launch()
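
The heart of this change is the cold-start mitigation: a daemon thread fires one tiny request at the serverless Inference API on startup, so the model weights begin loading before the first user message arrives. Below is a minimal standalone sketch of that pattern; MODEL_URL, HEADERS, and the helper name warm_up are illustrative placeholders, not the app's actual configuration.

import os
import threading
import requests

# Illustrative placeholders; the real app builds these from its own config.
MODEL_URL = "https://api-inference.huggingface.co/models/<model-id>"
HEADERS = {"Authorization": f"Bearer {os.environ.get('HF_TOKEN', '')}"}

model_warmed_up = False

def warm_up():
    """Send one tiny request so the serverless API starts loading the weights."""
    global model_warmed_up
    try:
        resp = requests.post(
            MODEL_URL,
            headers=HEADERS,
            json={"inputs": "Hello", "parameters": {"max_new_tokens": 5}},
            timeout=60,
        )
        # A 503/504 here just means the model is still loading; leave the flag
        # False and let the request handler's retry loop pick it up later.
        model_warmed_up = resp.status_code == 200
    except requests.RequestException as exc:
        print(f"Warmup request failed: {exc}")

# daemon=True keeps the thread from blocking interpreter shutdown,
# so a slow cold start cannot hang app exit.
threading.Thread(target=warm_up, daemon=True).start()

Note that the flag is only ever flipped to True. The worst case for a failed warm-up is that the first real request pays the cold-start cost itself, which is exactly what the retry loop in respond() already handles.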