Frankie-walsh4 committed
Commit b07886a · Parent(s): a2cebb0
Files changed (1): app.py (+57 -31)
app.py CHANGED
@@ -25,6 +25,7 @@ try:
     print(f"Status: {response.status_code}")
     if response.status_code == 200:
         print("Model exists and is accessible")
+        print(f"Response: {response.text[:200]}...")
     else:
         print(f"Response: {response.text}")
 except Exception as e:
@@ -32,22 +33,46 @@ except Exception as e:
 
 # Global variable to track model status
 model_loaded = False
-model_loading = False
 estimated_time = None
+use_simple_format = True  # Toggle to use simpler format instead of chat format
 
-def query_model(messages, parameters=None):
-    """Query the model using the Inference API"""
+def format_prompt(messages):
+    """Format chat messages into a text prompt that Mistral models can understand"""
+    if use_simple_format:
+        # Simple format - just extract the message content
+        system = next((m["content"] for m in messages if m["role"] == "system"), "")
+        last_user_msg = next((m["content"] for m in reversed(messages) if m["role"] == "user"), "")
+
+        if system:
+            return f"{system}\n\nQuestion: {last_user_msg}\n\nAnswer:"
+        else:
+            return f"Question: {last_user_msg}\n\nAnswer:"
+    else:
+        # Chat format for Mistral models
+        formatted = ""
+        for msg in messages:
+            if msg["role"] == "system":
+                formatted += f"<s>[INST] {msg['content']} [/INST]</s>\n"
+            elif msg["role"] == "user":
+                formatted += f"<s>[INST] {msg['content']} [/INST]"
+            elif msg["role"] == "assistant":
+                formatted += f" {msg['content']} </s>\n"
+        return formatted
+
+def query_model_text_generation(prompt, parameters=None):
+    """Query the model using the text generation API endpoint"""
     payload = {
-        "inputs": messages,
+        "inputs": prompt,
     }
 
     if parameters:
         payload["parameters"] = parameters
 
-    print(f"Sending query to API...")
+    print(f"Sending text generation query to API...")
+    print(f"Prompt: {prompt[:100]}...")
 
     try:
-        # Single attempt with longer timeout
+        # Try with longer timeout
         response = requests.post(
             API_URL,
             headers=headers,
@@ -59,6 +84,7 @@ def query_model(messages, parameters=None):
 
         # If successful, return the response
         if response.status_code == 200:
+            print(f"Success! Response: {str(response.text)[:200]}...")
             return response.json()
 
         # If model is loading, handle it
@@ -88,7 +114,7 @@ def respond(
 ):
     """Respond to user messages"""
 
-    # Create the messages list in chat format
+    # Create the messages list
     messages = [{"role": "system", "content": system_message}]
 
     for val in history:
@@ -99,12 +125,16 @@ def respond(
 
     messages.append({"role": "user", "content": message})
 
-    # Set up the inference parameters
+    # Format the prompt
+    prompt = format_prompt(messages)
+
+    # Set up the generation parameters
     parameters = {
         "max_new_tokens": max_tokens,
         "temperature": temperature,
         "top_p": top_p,
-        "do_sample": True
+        "do_sample": True,
+        "return_full_text": False  # Only return the generated text, not the prompt
     }
 
     # Initial message about model status
@@ -126,43 +156,38 @@ def respond(
             time.sleep(wait_time)
 
         try:
-            # Query the model
-            result = query_model(messages, parameters)
+            # Query the model using text generation
+            result = query_model_text_generation(prompt, parameters)
 
             if result:
                 # Handle different response formats
-
-                # List format with generated_text
                 if isinstance(result, list) and len(result) > 0:
-                    if "generated_text" in result[0]:
+                    if isinstance(result[0], dict) and "generated_text" in result[0]:
                         yield result[0]["generated_text"]
                         return
-
-                # Direct message format
+
                 if isinstance(result, dict) and "generated_text" in result:
                     yield result["generated_text"]
                     return
-
-                # String format
-                if isinstance(result, str):
-                    yield result
-                    return
-
-                # Raw format as fallback
+
+                # String or other format
                 yield str(result)
                 return
 
             # If model is still loading, get the latest estimate
             if estimated_time and attempt < max_retries - 1:
-                response = requests.get(API_URL, headers=headers)
-                if response.status_code == 503 and "estimated_time" in response.json():
-                    estimated_time = response.json()["estimated_time"]
-                    print(f"Updated loading time: {estimated_time:.0f} seconds")
+                try:
+                    response = requests.get(API_URL, headers=headers)
+                    if response.status_code == 503 and "estimated_time" in response.json():
+                        estimated_time = response.json()["estimated_time"]
+                        print(f"Updated loading time: {estimated_time:.0f} seconds")
+                except:
+                    pass
 
         except Exception as e:
             print(f"Error in attempt {attempt+1}: {str(e)}")
             if attempt == max_retries - 1:
-                yield f"""❌ Sorry, I couldn't generate a response after several attempts.
+                yield f"""❌ Sorry, I couldn't generate a response after multiple attempts.
 
 Error details: {str(e)}
 
@@ -176,7 +201,10 @@ This could be due to:
 2. The model being too large for the current hardware
 3. Temporary service issues
 
-Please try again later."""
+Please try again later. For best results with large models like Mistral-7B, consider:
+- Using a smaller model
+- Creating a 4-bit quantized version
+- Using Hugging Face Inference Endpoints instead of Spaces"""
 
 
 """
@@ -197,11 +225,9 @@ demo = gr.ChatInterface(
         ),
     ],
     description="""This interface uses a fine-tuned Mistral model for Microsoft 365 data management.
-    The model is accessed via the Hugging Face Inference API.
     First requests may take 2-3 minutes as the model loads."""
 )
 
 
 if __name__ == "__main__":
-    # Launch the app
     demo.launch()
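For reference, here is a quick sketch of what the new format_prompt helper produces for a short conversation. The function and toggle are copied from the diff above; the sample messages are illustrative only.

# Harness is illustrative; format_prompt and the toggle come from the diff above.
use_simple_format = True  # Toggle to use simpler format instead of chat format

def format_prompt(messages):
    """Format chat messages into a text prompt that Mistral models can understand"""
    if use_simple_format:
        # Simple format - just extract the message content
        system = next((m["content"] for m in messages if m["role"] == "system"), "")
        last_user_msg = next((m["content"] for m in reversed(messages) if m["role"] == "user"), "")

        if system:
            return f"{system}\n\nQuestion: {last_user_msg}\n\nAnswer:"
        else:
            return f"Question: {last_user_msg}\n\nAnswer:"
    else:
        # Chat format for Mistral models
        formatted = ""
        for msg in messages:
            if msg["role"] == "system":
                formatted += f"<s>[INST] {msg['content']} [/INST]</s>\n"
            elif msg["role"] == "user":
                formatted += f"<s>[INST] {msg['content']} [/INST]"
            elif msg["role"] == "assistant":
                formatted += f" {msg['content']} </s>\n"
        return formatted

# Hypothetical sample conversation
sample = [
    {"role": "system", "content": "You are a Microsoft 365 data management assistant."},
    {"role": "user", "content": "How do I set a retention policy?"},
]

print(format_prompt(sample))
# You are a Microsoft 365 data management assistant.
#
# Question: How do I set a retention policy?
#
# Answer:

With use_simple_format = False, the same conversation is rendered in Mistral's [INST]-style chat markup instead; the simple Question/Answer form is the commit's workaround for sending a plain text-generation prompt rather than a chat message list.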
 
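Continuing the sketch, this is roughly the request the updated query path assembles. API_URL, headers, and the parameter values are placeholders (the diff does not show how they are defined); the 200/503 handling mirrors what the commit's retry loop expects from the serverless Inference API.

import requests

API_URL = "https://api-inference.huggingface.co/models/<model-id>"  # placeholder
headers = {"Authorization": "Bearer <HF_TOKEN>"}                    # placeholder

payload = {
    # Plain text prompt, e.g. the format_prompt output shown above
    "inputs": "You are a Microsoft 365 data management assistant.\n\nQuestion: How do I set a retention policy?\n\nAnswer:",
    "parameters": {
        "max_new_tokens": 512,      # placeholder values; the app passes the UI sliders through
        "temperature": 0.7,
        "top_p": 0.95,
        "do_sample": True,
        "return_full_text": False,  # return only the completion, not the echoed prompt
    },
}

response = requests.post(API_URL, headers=headers, json=payload, timeout=300)
if response.status_code == 200:
    print(response.json())  # typically [{"generated_text": "..."}]
elif response.status_code == 503:
    # Model still loading; the body usually carries an "estimated_time" hint,
    # which is what the retry loop in the diff polls for.
    print(response.json().get("estimated_time"))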