import gradio as gr
import os
import time
import requests
""" | |
This app uses the Hugging Face Inference API to generate responses from the | |
Trinoid/Data_Management_Mistral model. | |
""" | |
# Get the token from the environment
HF_TOKEN = os.environ.get("HF_TOKEN")
print(f"HF_TOKEN is {'available' if HF_TOKEN else 'not available'}")

# Set up the Hugging Face Inference API endpoint
MODEL_ID = "Trinoid/Data_Management_Mistral"
API_URL = f"https://api-inference.huggingface.co/models/{MODEL_ID}"
headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
# Check if the model exists and is reachable
try:
    print(f"Checking if model {MODEL_ID} exists...")
    response = requests.get(API_URL, headers=headers)
    print(f"Status: {response.status_code}")
    if response.status_code == 200:
        print("Model exists and is accessible")
        print(f"Response: {response.text[:200]}...")
    else:
        print(f"Response: {response.text}")
except Exception as e:
    print(f"Error checking model: {str(e)}")
# Global variables to track model status
model_loaded = False
estimated_time = None
use_simple_format = True  # Toggle between the simple Q/A format and the Mistral chat format
def format_prompt(messages):
    """Format chat messages into a text prompt that Mistral models can understand."""
    if use_simple_format:
        # Simple format - just extract the system prompt and the latest user message
        system = next((m["content"] for m in messages if m["role"] == "system"), "")
        last_user_msg = next((m["content"] for m in reversed(messages) if m["role"] == "user"), "")
        if system:
            return f"{system}\n\nQuestion: {last_user_msg}\n\nAnswer:"
        else:
            return f"Question: {last_user_msg}\n\nAnswer:"
    else:
        # Chat format using Mistral's [INST] instruction tags
        formatted = ""
        for msg in messages:
            if msg["role"] == "system":
                formatted += f"<s>[INST] {msg['content']} [/INST]</s>\n"
            elif msg["role"] == "user":
                formatted += f"<s>[INST] {msg['content']} [/INST]"
            elif msg["role"] == "assistant":
                formatted += f" {msg['content']} </s>\n"
        return formatted
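# For illustration, with the default system message and a hypothetical user question
# such as "How do I set a retention policy?", the simple format above produces:
#   "You are a data management expert ...\n\nQuestion: How do I set a retention policy?\n\nAnswer:"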
def query_model_text_generation(prompt, parameters=None):
    """Query the model through the text generation API endpoint."""
    global estimated_time
    payload = {"inputs": prompt}
    if parameters:
        payload["parameters"] = parameters
    print("Sending text generation query to API...")
    print(f"Prompt: {prompt[:100]}...")
    try:
        # Use a long timeout so cold starts have time to finish
        response = requests.post(
            API_URL,
            headers=headers,
            json=payload,
            timeout=180,  # 3 minute timeout
        )
        print(f"API response status: {response.status_code}")
        # If successful, return the parsed response
        if response.status_code == 200:
            print(f"Success! Response: {response.text[:200]}...")
            return response.json()
        # If the model is still loading, record the estimated wait time
        try:
            body = response.json()
        except ValueError:
            body = {}
        if response.status_code == 503 and "estimated_time" in body:
            estimated_time = body["estimated_time"]
            print(f"Model is loading. Estimated time: {estimated_time:.2f} seconds")
            return None
        # For other errors
        print(f"API error: {response.text}")
        return None
    except Exception as e:
        print(f"Request exception: {str(e)}")
        return None
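# For reference, a successful text-generation call typically returns a list like
# [{"generated_text": "..."}]; respond() below handles that shape as well as a
# plain dict or string fallback.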
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Respond to user messages."""
    # Create the messages list from the system prompt and chat history
    messages = [{"role": "system", "content": system_message}]
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": message})

    # Format the prompt
    prompt = format_prompt(messages)

    # Set up the generation parameters
    parameters = {
        "max_new_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "do_sample": True,
        "return_full_text": False,  # Only return the generated text, not the prompt
    }
    # Initial message about model status
    global estimated_time
    if estimated_time:
        initial_msg = f"⏳ The model is loading... estimated time: {estimated_time:.0f} seconds. Please be patient."
    else:
        initial_msg = "⏳ Working on your request..."
    yield initial_msg
    # Try multiple times with increasing waits
    max_retries = 6
    for attempt in range(max_retries):
        # Wait before each retry, backing off up to 60 seconds
        if attempt > 0:
            wait_time = min(60, 10 * attempt)
            yield f"⏳ Still working on your request... (attempt {attempt+1}/{max_retries})"
            time.sleep(wait_time)
        try:
            # Query the model using text generation
            result = query_model_text_generation(prompt, parameters)
            if result:
                # Handle different response formats
                if isinstance(result, list) and len(result) > 0:
                    if isinstance(result[0], dict) and "generated_text" in result[0]:
                        yield result[0]["generated_text"]
                        return
                if isinstance(result, dict) and "generated_text" in result:
                    yield result["generated_text"]
                    return
                # String or other format
                yield str(result)
                return
            # If the model is still loading, refresh the latest estimate
            if estimated_time and attempt < max_retries - 1:
                try:
                    response = requests.get(API_URL, headers=headers)
                    if response.status_code == 503 and "estimated_time" in response.json():
                        estimated_time = response.json()["estimated_time"]
                        print(f"Updated loading time: {estimated_time:.0f} seconds")
                except Exception:
                    pass
        except Exception as e:
            print(f"Error in attempt {attempt+1}: {str(e)}")
            if attempt == max_retries - 1:
                yield f"""❌ Sorry, I couldn't generate a response after multiple attempts.
Error details: {str(e)}
Please try again later or contact support if this persists."""
                return
    # If all retries failed
    yield """❌ The model couldn't be accessed after multiple attempts.
This could be due to:
1. Heavy server load
2. The model being too large for the current hardware
3. Temporary service issues
Please try again later. For best results with large models like Mistral-7B, consider:
- Using a smaller model
- Creating a 4-bit quantized version
- Using Hugging Face Inference Endpoints instead of Spaces"""
""" | |
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface | |
""" | |
demo = gr.ChatInterface( | |
respond, | |
additional_inputs=[ | |
gr.Textbox(value="You are a data management expert specializing in Microsoft 365 services.", label="System message"), | |
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"), | |
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"), | |
gr.Slider( | |
minimum=0.1, | |
maximum=1.0, | |
value=0.95, | |
step=0.05, | |
label="Top-p (nucleus sampling)", | |
), | |
], | |
description="""This interface uses a fine-tuned Mistral model for Microsoft 365 data management. | |
First requests may take 2-3 minutes as the model loads.""" | |
) | |
if __name__ == "__main__": | |
demo.launch() |