import gradio as gr
import os
import time
import json
import requests
import threading
"""
This app uses the Hugging Face Inference API to generate responses from the
Trinoid/Data_Management_Mistral model.
"""
# Get token from environment
HF_TOKEN = os.environ.get("HF_TOKEN")
print(f"HF_TOKEN is {'available' if HF_TOKEN else 'not available'}")
# Setup API for the Hugging Face Inference API
MODEL_ID = "Trinoid/Data_Management_Mistral"
API_URL = f"https://api-inference.huggingface.co/models/{MODEL_ID}"
headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
# Check if model exists
try:
print(f"Checking if model {MODEL_ID} exists...")
response = requests.get(API_URL, headers=headers)
print(f"Status: {response.status_code}")
if response.status_code == 200:
print("Model exists and is accessible")
print(f"Response: {response.text[:200]}...")
else:
print(f"Response: {response.text}")
except Exception as e:
print(f"Error checking model: {str(e)}")
# Global variable to track model status
model_loaded = False
estimated_time = None
use_simple_format = True # Toggle to use simpler format instead of chat format
def format_prompt(messages):
"""Format chat messages into a text prompt that Mistral models can understand"""
if use_simple_format:
# Simple format - just extract the message content
system = next((m["content"] for m in messages if m["role"] == "system"), "")
last_user_msg = next((m["content"] for m in reversed(messages) if m["role"] == "user"), "")
if system:
return f"{system}\n\nQuestion: {last_user_msg}\n\nAnswer:"
else:
return f"Question: {last_user_msg}\n\nAnswer:"
else:
# Chat format for Mistral models
formatted = ""
for msg in messages:
if msg["role"] == "system":
formatted += f"<s>[INST] {msg['content']} [/INST]</s>\n"
elif msg["role"] == "user":
formatted += f"<s>[INST] {msg['content']} [/INST]"
elif msg["role"] == "assistant":
formatted += f" {msg['content']} </s>\n"
return formatted
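# Illustrative example (hypothetical message content, not executed): with
# use_simple_format=True, a history such as
#   [{"role": "system", "content": "You are a data management expert."},
#    {"role": "user", "content": "How do retention labels work?"}]
# is flattened by format_prompt() into:
#   "You are a data management expert.\n\nQuestion: How do retention labels work?\n\nAnswer:"
# With use_simple_format=False the same history is rendered with Mistral's
# <s>[INST] ... [/INST] chat markers instead.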
def query_model_text_generation(prompt, parameters=None):
"""Query the model using the text generation API endpoint"""
payload = {
"inputs": prompt,
}
if parameters:
payload["parameters"] = parameters
print(f"Sending text generation query to API...")
print(f"Prompt: {prompt[:100]}...")
try:
# Try with longer timeout
response = requests.post(
API_URL,
headers=headers,
json=payload,
timeout=180 # 3 minute timeout
)
print(f"API response status: {response.status_code}")
# If successful, return the response
if response.status_code == 200:
print(f"Success! Response: {str(response.text)[:200]}...")
return response.json()
# If model is loading, handle it
elif response.status_code == 503 and "estimated_time" in response.json():
est_time = response.json()["estimated_time"]
global estimated_time
estimated_time = est_time
print(f"Model is loading. Estimated time: {est_time:.2f} seconds")
return None
# For other errors
else:
print(f"API error: {response.text}")
return None
except Exception as e:
print(f"Request exception: {str(e)}")
return None
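# For reference: the Inference API text-generation endpoint typically returns either a
# list like [{"generated_text": "..."}] or a dict like {"generated_text": "..."}; while
# the model is still loading it responds with HTTP 503 and a JSON body of the form
# {"error": "...", "estimated_time": <seconds>}. respond() below handles all three cases.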
def respond(
message,
history: list[tuple[str, str]],
system_message,
max_tokens,
temperature,
top_p,
):
"""Respond to user messages"""
# Create the messages list
messages = [{"role": "system", "content": system_message}]
for val in history:
if val[0]:
messages.append({"role": "user", "content": val[0]})
if val[1]:
messages.append({"role": "assistant", "content": val[1]})
messages.append({"role": "user", "content": message})
# Format the prompt
prompt = format_prompt(messages)
# Set up the generation parameters
parameters = {
"max_new_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
"do_sample": True,
"return_full_text": False # Only return the generated text, not the prompt
}
# Initial message about model status
global estimated_time
if estimated_time:
initial_msg = f"βŒ› The model is loading... estimated time: {estimated_time:.0f} seconds. Please be patient."
else:
initial_msg = "βŒ› Working on your request..."
yield initial_msg
# Try multiple times with increasing waits
max_retries = 6
for attempt in range(max_retries):
# Check if this is a retry
if attempt > 0:
wait_time = min(60, 10 * attempt)
yield f"βŒ› Still working on your request... (attempt {attempt+1}/{max_retries})"
time.sleep(wait_time)
try:
# Query the model using text generation
result = query_model_text_generation(prompt, parameters)
if result:
# Handle different response formats
if isinstance(result, list) and len(result) > 0:
if isinstance(result[0], dict) and "generated_text" in result[0]:
yield result[0]["generated_text"]
return
if isinstance(result, dict) and "generated_text" in result:
yield result["generated_text"]
return
# String or other format
yield str(result)
return
# If model is still loading, get the latest estimate
if estimated_time and attempt < max_retries - 1:
try:
response = requests.get(API_URL, headers=headers)
if response.status_code == 503 and "estimated_time" in response.json():
estimated_time = response.json()["estimated_time"]
print(f"Updated loading time: {estimated_time:.0f} seconds")
                except Exception:
                    pass
except Exception as e:
print(f"Error in attempt {attempt+1}: {str(e)}")
if attempt == max_retries - 1:
yield f"""❌ Sorry, I couldn't generate a response after multiple attempts.
Error details: {str(e)}
Please try again later or contact support if this persists."""
# If all retries failed
yield """❌ The model couldn't be accessed after multiple attempts.
This could be due to:
1. Heavy server load
2. The model being too large for the current hardware
3. Temporary service issues
Please try again later. For best results with large models like Mistral-7B, consider:
- Using a smaller model
- Creating a 4-bit quantized version
- Using Hugging Face Inference Endpoints instead of Spaces"""
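# Retry schedule used above, derived from wait_time = min(60, 10 * attempt): attempts
# 2 through 6 wait 10, 20, 30, 40 and 50 seconds respectively, so a request that never
# succeeds gives up after roughly 2.5 minutes of waiting plus the per-request timeouts.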
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
respond,
additional_inputs=[
gr.Textbox(value="You are a data management expert specializing in Microsoft 365 services.", label="System message"),
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-p (nucleus sampling)",
),
],
description="""This interface uses a fine-tuned Mistral model for Microsoft 365 data management.
First requests may take 2-3 minutes as the model loads."""
)
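# Note: gr.ChatInterface calls respond(message, history, *additional_inputs), so the
# textbox and sliders above map positionally onto system_message, max_tokens,
# temperature and top_p. Because respond() is a generator, the intermediate
# "working on your request" yields are streamed to the chat window before the final answer.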
if __name__ == "__main__":
demo.launch()