import gradio as gr
import os
import time
import json
import requests
import threading
"""
This app uses the Hugging Face Inference API to generate responses from the
Trinoid/Data_Management_Mistral model.
"""
# Get token from environment
HF_TOKEN = os.environ.get("HF_TOKEN")
print(f"HF_TOKEN is {'available' if HF_TOKEN else 'not available'}")
# Setup API for the Hugging Face Inference API
MODEL_ID = "Trinoid/Data_Management_Mistral"
API_URL = f"https://api-inference.huggingface.co/models/{MODEL_ID}"
headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
# Check if model exists
try:
print(f"Checking if model {MODEL_ID} exists...")
    response = requests.get(API_URL, headers=headers, timeout=30)
print(f"Status: {response.status_code}")
if response.status_code == 200:
print("Model exists and is accessible")
print(f"Response: {response.text[:200]}...")
else:
print(f"Response: {response.text}")
except Exception as e:
print(f"Error checking model: {str(e)}")
# Global variable to track model status
model_loaded = False
estimated_time = None
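# estimated_time is written by query_model_text_generation() while the model is
# warming up and read by respond() to show a loading message; model_loaded is
# not currently updated anywhere below.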
use_simple_format = True # Toggle to use simpler format instead of chat format
def format_prompt(messages):
"""Format chat messages into a text prompt that Mistral models can understand"""
if use_simple_format:
# Simple format - just extract the message content
system = next((m["content"] for m in messages if m["role"] == "system"), "")
last_user_msg = next((m["content"] for m in reversed(messages) if m["role"] == "user"), "")
if system:
return f"{system}\n\nQuestion: {last_user_msg}\n\nAnswer:"
else:
return f"Question: {last_user_msg}\n\nAnswer:"
else:
# Chat format for Mistral models
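        # Note: this is a simplified approximation of the Mistral instruct format; the
        # reference format wraps the system prompt inside the first [INST] block rather
        # than giving it its own [INST] pair.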
formatted = ""
for msg in messages:
if msg["role"] == "system":
formatted += f"[INST] {msg['content']} [/INST]\n"
elif msg["role"] == "user":
formatted += f"[INST] {msg['content']} [/INST]"
elif msg["role"] == "assistant":
formatted += f" {msg['content']} \n"
return formatted
def query_model_text_generation(prompt, parameters=None):
"""Query the model using the text generation API endpoint"""
payload = {
"inputs": prompt,
}
if parameters:
payload["parameters"] = parameters
print(f"Sending text generation query to API...")
print(f"Prompt: {prompt[:100]}...")
try:
# Try with longer timeout
response = requests.post(
API_URL,
headers=headers,
json=payload,
timeout=180 # 3 minute timeout
)
print(f"API response status: {response.status_code}")
# If successful, return the response
if response.status_code == 200:
print(f"Success! Response: {str(response.text)[:200]}...")
return response.json()
# If model is loading, handle it
        elif response.status_code == 503:
            body = response.json() if "application/json" in response.headers.get("content-type", "") else {}
            if "estimated_time" in body:
                global estimated_time
                estimated_time = body["estimated_time"]
                print(f"Model is loading. Estimated time: {estimated_time:.2f} seconds")
            else:
                print(f"Service unavailable: {response.text}")
            return None
# For other errors
else:
print(f"API error: {response.text}")
return None
except Exception as e:
print(f"Request exception: {str(e)}")
return None
def respond(
message,
history: list[tuple[str, str]],
system_message,
max_tokens,
temperature,
top_p,
):
"""Respond to user messages"""
# Create the messages list
messages = [{"role": "system", "content": system_message}]
for val in history:
if val[0]:
messages.append({"role": "user", "content": val[0]})
if val[1]:
messages.append({"role": "assistant", "content": val[1]})
messages.append({"role": "user", "content": message})
# Format the prompt
prompt = format_prompt(messages)
# Set up the generation parameters
parameters = {
"max_new_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
"do_sample": True,
"return_full_text": False # Only return the generated text, not the prompt
}
# Initial message about model status
global estimated_time
if estimated_time:
initial_msg = f"⌛ The model is loading... estimated time: {estimated_time:.0f} seconds. Please be patient."
else:
initial_msg = "⌛ Working on your request..."
yield initial_msg
# Try multiple times with increasing waits
max_retries = 6
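    # Waits between attempts grow as min(60, 10 * attempt), i.e. 10-50s, so the six
    # attempts span roughly 2-3 minutes on top of the request timeouts.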
for attempt in range(max_retries):
# Check if this is a retry
if attempt > 0:
wait_time = min(60, 10 * attempt)
yield f"⌛ Still working on your request... (attempt {attempt+1}/{max_retries})"
time.sleep(wait_time)
try:
# Query the model using text generation
result = query_model_text_generation(prompt, parameters)
if result:
# Handle different response formats
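                    # The text-generation endpoint usually returns a list like
                    # [{"generated_text": "..."}]; a bare dict or string is handled as a fallback.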
if isinstance(result, list) and len(result) > 0:
if isinstance(result[0], dict) and "generated_text" in result[0]:
yield result[0]["generated_text"]
return
if isinstance(result, dict) and "generated_text" in result:
yield result["generated_text"]
return
# String or other format
yield str(result)
return
# If model is still loading, get the latest estimate
if estimated_time and attempt < max_retries - 1:
try:
                    response = requests.get(API_URL, headers=headers, timeout=30)
                    body = response.json()
                    if response.status_code == 503 and "estimated_time" in body:
                        estimated_time = body["estimated_time"]
                        print(f"Updated loading time: {estimated_time:.0f} seconds")
                except Exception:
                    pass
except Exception as e:
print(f"Error in attempt {attempt+1}: {str(e)}")
            if attempt == max_retries - 1:
                yield f"""❌ Sorry, I couldn't generate a response after multiple attempts.
Error details: {str(e)}
Please try again later or contact support if this persists."""
                return
# If all retries failed
yield """❌ The model couldn't be accessed after multiple attempts.
This could be due to:
1. Heavy server load
2. The model being too large for the current hardware
3. Temporary service issues
Please try again later. For best results with large models like Mistral-7B, consider:
- Using a smaller model
- Creating a 4-bit quantized version
- Using Hugging Face Inference Endpoints instead of Spaces"""
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
respond,
additional_inputs=[
gr.Textbox(value="You are a data management expert specializing in Microsoft 365 services.", label="System message"),
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-p (nucleus sampling)",
),
],
description="""This interface uses a fine-tuned Mistral model for Microsoft 365 data management.
First requests may take 2-3 minutes as the model loads."""
)
if __name__ == "__main__":
demo.launch()