import gradio as gr
import os
import json
import re
from mistralai import Mistral

# Initialize Mistral client. It is None when no API key is configured, so the
# UI can still start and report a helpful error instead of failing at import.
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
client = Mistral(api_key=MISTRAL_API_KEY) if MISTRAL_API_KEY else None

# Keys every successful triage result must contain.
REQUIRED_FIELDS = ("urgency_level", "possible_condition", "recommended_action")

PROMPT_TEMPLATE = """
You are a clinical triage assistant.

A patient describes their symptoms as follows:
"{symptoms}"

Based on standard triage protocols, output ONLY the following fields in JSON:
1. urgency_level: [Low, Moderate, High, Emergency]
2. possible_condition: A concise possible diagnosis
3. recommended_action: Clear next steps for the patient

Respond only in valid JSON. Do not include triple backticks or markdown formatting.
"""


def build_prompt(symptom_description: str) -> str:
    """
    Construct a natural language prompt based on symptom description.

    Args:
        symptom_description (str): Free-text input of patient symptoms

    Returns:
        str: LLM-ready prompt
    """
    # str.format only interprets braces in the template, so braces inside the
    # user's text are inserted literally and cannot break formatting.
    return PROMPT_TEMPLATE.format(symptoms=symptom_description)


def extract_json(text: str) -> str:
    """
    Extract the first JSON object from a response string.

    Works when the object is wrapped in a markdown code block or surrounded
    by prose. It stops at the end of the first complete object, so trailing
    text that happens to contain braces is not swallowed.

    Args:
        text (str): Raw response from LLM

    Returns:
        str: The first decodable "{...}" object. If none decodes, the widest
        "{...}" span, or the stripped text when there is no such span.
        Returns "" for empty/None input.
    """
    if not text:
        return ""

    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one JSON value from the start of the string and
            # reports how many characters it consumed, ignoring anything after.
            _, length = decoder.raw_decode(text[start:])
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
            continue
        return text[start:start + length]

    # Nothing decodes: fall back to the greedy match so the caller's
    # json.loads produces a meaningful parse error message.
    match = re.search(r"\{.*\}", text, re.DOTALL)
    return match.group(0) if match else text.strip()


def _message_text(content: object) -> str:
    """Return the text of a chat message whose content may be None, a str, or a list of chunks."""
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    parts = []
    for chunk in content:
        text = chunk.get("text") if isinstance(chunk, dict) else getattr(chunk, "text", None)
        if isinstance(text, str):
            parts.append(text)
    return "".join(parts)


def triage_response(symptoms: str) -> dict:
    """
    Send a triage prompt to the Mistral API and parse the JSON response.

    Args:
        symptoms (str): Description of patient symptoms

    Returns:
        dict: On success, the parsed model output containing at least
        urgency_level, possible_condition and recommended_action.
        On failure, a dict with an "error" message. When the model replied
        but the reply was unusable, it also carries "raw_output", plus either
        "exception" (JSON parse error) or "parsed_result" (wrong shape).
    """
    if not client:
        return {"error": "MISTRAL_API_KEY environment variable not set"}
    # Reject None / non-string input (possible from API or MCP callers) rather
    # than raising AttributeError on .strip().
    if not isinstance(symptoms, str) or not symptoms.strip():
        return {"error": "Please provide a description of symptoms"}

    prompt = build_prompt(symptoms.strip())

    # Only the network/SDK call is guarded here, so local parsing bugs are not
    # misreported as API failures.
    try:
        response = client.chat.complete(
            model="mistral-large-latest",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.3,
            max_tokens=500,
            # Mistral JSON mode: constrain the reply to a JSON object.
            response_format={"type": "json_object"},
        )
    except Exception as e:  # SDK and transport errors are shown to the user.
        return {"error": f"Mistral API error: {str(e)}"}

    choices = getattr(response, "choices", None)
    if not choices:
        return {"error": "Mistral API error: response contained no choices"}
    message = getattr(choices[0], "message", None)
    raw_output = _message_text(getattr(message, "content", None))

    cleaned_output = extract_json(raw_output)
    try:
        result = json.loads(cleaned_output)
    except json.JSONDecodeError as e:
        return {
            "error": "Invalid JSON format",
            "raw_output": raw_output,
            "exception": str(e),
        }

    if not isinstance(result, dict):
        return {
            "error": "Response JSON is not an object",
            "raw_output": raw_output,
            "parsed_result": result,
        }

    missing = [field for field in REQUIRED_FIELDS if field not in result]
    if missing:
        return {
            "error": f"Missing required fields in response: {', '.join(missing)}",
            "raw_output": raw_output,
            "parsed_result": result,
        }
    return result


def format_triage_output(result: dict) -> str:
    """
    Format the triage result as Markdown for display in Gradio.

    Args:
        result (dict): Output of triage_response.

    Returns:
        str: An error line if result contains "error", otherwise the urgency,
        condition and recommended action followed by a disclaimer.
    """
    if "error" in result:
        return f"❌ Error: {result['error']}"

    urgency_icons = {
        "Low": "🟢",
        "Moderate": "🟡",
        "High": "🟠",
        "Emergency": "🔴",
    }

    urgency = result.get("urgency_level", "Unknown")
    # Normalise case/whitespace so "high" or "EMERGENCY " still get an icon.
    icon = urgency_icons.get(str(urgency).strip().capitalize(), "⚪")
    condition = result.get("possible_condition", "Not specified")
    action = result.get("recommended_action", "Not specified")

    # Blank lines separate the Markdown paragraphs. The one before "---" keeps
    # the action line from being rendered as a setext heading.
    return f"""
{icon} **Urgency Level:** {urgency}

🩺 **Possible Condition:** {condition}

📋 **Recommended Action:** {action}

---

*This is for informational purposes only. Always consult with healthcare professionals for medical advice.*
""".strip()


def gradio_triage_wrapper(symptoms: str) -> str:
    """
    Run a preliminary clinical triage on a free-text symptom description.

    Args:
        symptoms (str): Description of the patient's symptoms, e.g.
            "Severe chest pain for 30 minutes, difficulty breathing".

    Returns:
        str: Markdown with urgency level, possible condition and recommended
        action, or an error message.
    """
    result = triage_response(symptoms)
    return format_triage_output(result)


# Create Gradio interface
demo = gr.Interface(
    fn=gradio_triage_wrapper,
    inputs=gr.Textbox(
        lines=4,
        label="Enter patient's symptoms",
        placeholder="Describe the symptoms in detail (e.g., 'Severe chest pain for 30 minutes, difficulty breathing, sweating')",
    ),
    outputs=gr.Markdown(label="Triage Assessment"),
    title="🏥 Clinical Triage Assistant",
    description="""
An AI-powered tool that provides preliminary triage assessment based on symptom descriptions.

**Disclaimer:** This tool is for educational and informational purposes only. It should not replace professional medical advice, diagnosis, or treatment.
""",
    examples=[
        ["Severe chest pain radiating to left arm, difficulty breathing, sweating"],
        ["Mild headache and runny nose for 2 days"],
        ["High fever (102°F), severe abdominal pain, vomiting"],
        ["Sprained ankle, mild swelling, can walk with discomfort"],
    ],
    theme=gr.themes.Soft(),
)

if __name__ == "__main__":
    # Check if API key is available
    if not MISTRAL_API_KEY:
        print("⚠️ Warning: MISTRAL_API_KEY environment variable not set!")
        print("Please set your Mistral API key:")
        print("export MISTRAL_API_KEY='your-api-key-here'")
        print("\nAlso install the Mistral AI client:")
        print("pip install mistralai")

    demo.launch(mcp_server=True)