import gradio as gr
import os
import time
import json
import requests
import threading

"""
This app uses the Hugging Face Inference API to generate responses from the
Trinoid/Data_Management_Mistral model.
"""

# Get token from environment
HF_TOKEN = os.environ.get("HF_TOKEN")
print(f"HF_TOKEN is {'available' if HF_TOKEN else 'not available'}")

# Set up the Hugging Face Inference API endpoint
MODEL_ID = "Trinoid/Data_Management_Mistral"
API_URL = f"https://api-inference.huggingface.co/models/{MODEL_ID}"
headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}

# Check if model exists
try:
    print(f"Checking if model {MODEL_ID} exists...")
    response = requests.get(API_URL, headers=headers)
    print(f"Status: {response.status_code}")
    if response.status_code == 200:
        print("Model exists and is accessible")
        print(f"Response: {response.text[:200]}...")
    else:
        print(f"Response: {response.text}")
except Exception as e:
    print(f"Error checking model: {str(e)}")

# Global variables to track model status
model_loaded = False
estimated_time = None
use_simple_format = True  # Toggle to use a simpler format instead of the chat format


def format_prompt(messages):
    """Format chat messages into a text prompt that Mistral models can understand."""
    if use_simple_format:
        # Simple format - just extract the message content
        system = next((m["content"] for m in messages if m["role"] == "system"), "")
        last_user_msg = next((m["content"] for m in reversed(messages) if m["role"] == "user"), "")
        if system:
            return f"{system}\n\nQuestion: {last_user_msg}\n\nAnswer:"
        else:
            return f"Question: {last_user_msg}\n\nAnswer:"
    else:
        # Chat format for Mistral models
        formatted = ""
        for msg in messages:
            if msg["role"] == "system":
                formatted += f"[INST] {msg['content']} [/INST]\n"
            elif msg["role"] == "user":
                formatted += f"[INST] {msg['content']} [/INST]"
            elif msg["role"] == "assistant":
                formatted += f" {msg['content']} \n"
        return formatted


def query_model_text_generation(prompt, parameters=None):
    """Query the model using the text generation API endpoint."""
    payload = {"inputs": prompt}
    if parameters:
        payload["parameters"] = parameters

    print("Sending text generation query to API...")
    print(f"Prompt: {prompt[:100]}...")

    try:
        # Try with a longer timeout
        response = requests.post(
            API_URL,
            headers=headers,
            json=payload,
            timeout=180,  # 3 minute timeout
        )
        print(f"API response status: {response.status_code}")

        # If successful, return the response
        if response.status_code == 200:
            print(f"Success! Response: {str(response.text)[:200]}...")
            return response.json()
        # If model is loading, handle it
        elif response.status_code == 503 and "estimated_time" in response.json():
            global estimated_time
            est_time = response.json()["estimated_time"]
            estimated_time = est_time
            print(f"Model is loading. Estimated time: {est_time:.2f} seconds")
            return None
        # For other errors
        else:
            print(f"API error: {response.text}")
            return None
    except Exception as e:
        print(f"Request exception: {str(e)}")
        return None
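
# For reference, a sketch of the payloads this endpoint typically returns (standard
# Inference API text-generation behaviour; the exact shapes can vary by model):
#
#   200 OK : [{"generated_text": "..."}]
#   503    : {"error": "Model Trinoid/Data_Management_Mistral is currently loading",
#             "estimated_time": 120.0}
#
# query_model_text_generation() returns the parsed JSON on success and None otherwise.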


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Respond to user messages."""
    # Create the messages list
    messages = [{"role": "system", "content": system_message}]
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": message})

    # Format the prompt
    prompt = format_prompt(messages)

    # Set up the generation parameters
    parameters = {
        "max_new_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "do_sample": True,
        "return_full_text": False,  # Only return the generated text, not the prompt
    }

    # Initial message about model status
    global estimated_time
    if estimated_time:
        initial_msg = f"⌛ The model is loading... estimated time: {estimated_time:.0f} seconds. Please be patient."
    else:
        initial_msg = "⌛ Working on your request..."
    yield initial_msg

    # Try multiple times with increasing waits
    max_retries = 6
    for attempt in range(max_retries):
        # If this is a retry, wait before querying again
        if attempt > 0:
            wait_time = min(60, 10 * attempt)
            yield f"⌛ Still working on your request... (attempt {attempt + 1}/{max_retries})"
            time.sleep(wait_time)

        try:
            # Query the model using text generation
            result = query_model_text_generation(prompt, parameters)

            if result:
                # Handle different response formats
                if isinstance(result, list) and len(result) > 0:
                    if isinstance(result[0], dict) and "generated_text" in result[0]:
                        yield result[0]["generated_text"]
                        return
                if isinstance(result, dict) and "generated_text" in result:
                    yield result["generated_text"]
                    return
                # String or other format
                yield str(result)
                return

            # If the model is still loading, get the latest estimate
            if estimated_time and attempt < max_retries - 1:
                try:
                    response = requests.get(API_URL, headers=headers)
                    if response.status_code == 503 and "estimated_time" in response.json():
                        estimated_time = response.json()["estimated_time"]
                        print(f"Updated loading time: {estimated_time:.0f} seconds")
                except Exception:
                    pass
        except Exception as e:
            print(f"Error in attempt {attempt + 1}: {str(e)}")
            if attempt == max_retries - 1:
                yield f"""❌ Sorry, I couldn't generate a response after multiple attempts.

Error details: {str(e)}

Please try again later or contact support if this persists."""
                return

    # If all retries failed
    yield """❌ The model couldn't be accessed after multiple attempts. This could be due to:
1. Heavy server load
2. The model being too large for the current hardware
3. Temporary service issues

Please try again later.

For best results with large models like Mistral-7B, consider:
- Using a smaller model
- Creating a 4-bit quantized version
- Using Hugging Face Inference Endpoints instead of Spaces"""
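
# For reference, gr.ChatInterface streams whatever `respond` yields and calls it with the
# latest user message plus the running history as (user, assistant) tuples. Illustrative
# call only - the literal values below are made up:
#
#   respond(
#       "How do retention labels work?",
#       [("What is Microsoft Purview?", "Microsoft Purview is ...")],
#       "You are a data management expert specializing in Microsoft 365 services.",
#       512, 0.7, 0.95,  # max_tokens, temperature, top_p
#   )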


"""
For information on how to customize the ChatInterface, peruse the gradio docs:
https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a data management expert specializing in Microsoft 365 services.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    description="""This interface uses a fine-tuned Mistral model for Microsoft 365 data management.
First requests may take 2-3 minutes as the model loads.""",
)

if __name__ == "__main__":
    demo.launch()
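
# To run this locally (assuming the file is saved as app.py and an HF_TOKEN with
# Inference API access is available in the environment):
#
#   HF_TOKEN=hf_xxx python app.py
#
# On Hugging Face Spaces, HF_TOKEN can instead be provided as a repository secret.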