# Hosting-page header captured with this file (not code), kept as comments:
# lukmanaj's picture
# Create app.py
# 1f95818 verified
import gradio as gr
import os
import json
import re
from mistralai import Mistral
# Mistral client setup. The API key comes from the environment; when it is
# missing or empty the client is left as None, so callers must check for None
# before issuing requests.
MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY")
if MISTRAL_API_KEY:
    client = Mistral(api_key=MISTRAL_API_KEY)
else:
    client = None

# Prompt sent to the model. It is filled with str.format, and `{symptoms}` is
# the only placeholder: any literal braces added to this text must be doubled.
PROMPT_TEMPLATE = """
You are a clinical triage assistant. A patient describes their symptoms as follows:
"{symptoms}"
Based on standard triage protocols, output ONLY the following fields in JSON:
1. urgency_level: [Low, Moderate, High, Emergency]
2. possible_condition: A concise possible diagnosis
3. recommended_action: Clear next steps for the patient
Respond only in valid JSON. Do not include triple backticks or markdown formatting.
"""
def build_prompt(symptom_description: str) -> str:
    """
    Fill the triage prompt template with the patient's own words.

    Args:
        symptom_description (str): Free-text description of the symptoms,
            inserted verbatim into PROMPT_TEMPLATE.
    Returns:
        str: The complete prompt to send to the model.
    """
    filled_prompt = PROMPT_TEMPLATE.format(symptoms=symptom_description)
    return filled_prompt
def extract_json(text: str) -> str:
    """
    Extract the first JSON object from an LLM response string.

    Works when the object is wrapped in markdown code fences or surrounded by
    explanatory prose, including prose that itself contains braces.

    Args:
        text (str): Raw response from LLM. ``None`` or empty input yields "".
    Returns:
        str: The substring holding the first complete, parseable JSON object.
            If none parses, falls back to the widest ``{...}`` span (or the
            stripped input when there is no brace pair), so the caller's
            ``json.loads`` still reports a meaningful error.
    """
    if not text:
        return ""
    decoder = json.JSONDecoder()
    # Try each '{' in turn. raw_decode stops at the end of the first complete
    # value, so trailing text such as "Note: {...}" is not swallowed the way
    # a greedy regex would swallow it.
    start = text.find("{")
    while start != -1:
        try:
            _, end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
            continue
        return text[start:end]
    # Nothing parseable: keep the original greedy behaviour for diagnostics.
    match = re.search(r"\{.*\}", text, re.DOTALL)
    return match.group(0) if match else text.strip()
def _parse_triage_output(raw_output) -> dict:
    """
    Decode and validate the model's raw reply into a triage dict.

    Args:
        raw_output: ``message.content`` from the Mistral response. Presumably
            a string; it is type-checked because the SDK does not guarantee
            text content (NOTE(review): confirm against mistralai docs).
    Returns:
        dict: The parsed object when it is a JSON object containing all
            required fields; otherwise an error dict with diagnostics.
    """
    if not isinstance(raw_output, str):
        return {"error": "Model response contained no text content"}
    cleaned_output = extract_json(raw_output)
    try:
        result = json.loads(cleaned_output)
    except json.JSONDecodeError as e:
        return {
            "error": "Invalid JSON format",
            "raw_output": raw_output,
            "exception": str(e),
        }
    required_fields = ("urgency_level", "possible_condition", "recommended_action")
    # json.loads can return a list, string or number. `field in result` on a
    # string would be a substring test (and raises on numbers), so require a
    # real JSON object before checking keys.
    if isinstance(result, dict) and all(field in result for field in required_fields):
        return result
    return {
        "error": "Missing required fields in response",
        "raw_output": raw_output,
        "parsed_result": result,
    }


def triage_response(symptoms: str) -> dict:
    """
    Send a triage prompt to the Mistral API and parse the JSON response.

    Args:
        symptoms (str): Description of patient symptoms
    Returns:
        dict: On success, a dict with urgency_level, possible_condition and
            recommended_action. On failure, a dict with an "error" message
            and, where available, "raw_output", "parsed_result" or
            "exception" for diagnosis.
    """
    if not client:
        return {"error": "MISTRAL_API_KEY environment variable not set"}
    # Reject None as well as blank text (non-UI callers may pass None).
    if not symptoms or not symptoms.strip():
        return {"error": "Please provide a description of symptoms"}
    prompt = build_prompt(symptoms)
    # Only the network call and response navigation are attributed to the
    # API. Parsing problems get their own, more specific errors below.
    try:
        response = client.chat.complete(
            model="mistral-large-latest",
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            temperature=0.3,
            max_tokens=500,
        )
        raw_output = response.choices[0].message.content
    except Exception as e:
        return {"error": f"Mistral API error: {str(e)}"}
    return _parse_triage_output(raw_output)
def format_triage_output(result: dict) -> str:
    """
    Format the triage result as Markdown for display in Gradio.

    Args:
        result (dict): Either the triage fields (urgency_level,
            possible_condition, recommended_action) or a dict with an
            "error" key.
    Returns:
        str: A one-line error message, or a Markdown block of separate
            paragraphs followed by a rule and a disclaimer.
    """
    if "error" in result:
        return f"\N{CROSS MARK} Error: {result['error']}"
    # Emoji are written as \N{...} escapes so the source stays ASCII and
    # cannot be garbled by a wrong-encoding round trip.
    urgency_icons = {
        "Low": "\N{LARGE GREEN CIRCLE}",
        "Moderate": "\N{LARGE YELLOW CIRCLE}",
        "High": "\N{LARGE ORANGE CIRCLE}",
        "Emergency": "\N{LARGE RED CIRCLE}",
    }
    urgency = result.get("urgency_level", "Unknown")
    # The model may vary capitalisation ("HIGH", "emergency"). Normalise only
    # for the icon lookup; the value is displayed as returned.
    icon = urgency_icons.get(str(urgency).strip().capitalize(), "\N{MEDIUM WHITE CIRCLE}")
    condition = result.get("possible_condition", "Not specified")
    action = result.get("recommended_action", "Not specified")
    paragraphs = [
        f"{icon} **Urgency Level:** {urgency}",
        f"\N{STETHOSCOPE} **Possible Condition:** {condition}",
        f"\N{CLIPBOARD} **Recommended Action:** {action}",
        "---",
        "*This is for informational purposes only. Always consult with healthcare professionals for medical advice.*",
    ]
    # Blank lines are required. Single newlines would merge the fields into
    # one Markdown paragraph, and a "---" directly under a text line makes
    # that line a setext heading.
    return "\n\n".join(paragraphs)
def gradio_triage_wrapper(symptoms: str) -> str:
    """
    Run a clinical triage assessment and return it as Markdown.

    Args:
        symptoms (str): Free-text description of the patient's symptoms.
    Returns:
        str: Markdown-formatted triage assessment, or an error message.
    """
    return format_triage_output(triage_response(symptoms))
# Create Gradio interface.
# Non-ASCII characters use \N{...} escapes so the source file stays ASCII and
# the title/examples cannot be garbled by a wrong-encoding round trip.
demo = gr.Interface(
    fn=gradio_triage_wrapper,
    inputs=gr.Textbox(
        lines=4,
        label="Enter patient's symptoms",
        placeholder="Describe the symptoms in detail (e.g., 'Severe chest pain for 30 minutes, difficulty breathing, sweating')",
    ),
    outputs=gr.Markdown(label="Triage Assessment"),
    title="\N{HOSPITAL} Clinical Triage Assistant",
    # Blank line keeps the disclaimer as its own Markdown paragraph.
    description="""
An AI-powered tool that provides preliminary triage assessment based on symptom descriptions.

**Disclaimer:** This tool is for educational and informational purposes only.
It should not replace professional medical advice, diagnosis, or treatment.
""",
    examples=[
        ["Severe chest pain radiating to left arm, difficulty breathing, sweating"],
        ["Mild headache and runny nose for 2 days"],
        ["High fever (102\N{DEGREE SIGN}F), severe abdominal pain, vomiting"],
        ["Sprained ankle, mild swelling, can walk with discomfort"],
    ],
    theme=gr.themes.Soft(),
)
if __name__ == "__main__":
    # Warn, but still launch, when the key is missing. Without a key the
    # client is None, so each request gets an error message instead of
    # crashing the app at startup.
    if not MISTRAL_API_KEY:
        print("\N{WARNING SIGN}\N{VARIATION SELECTOR-16} Warning: MISTRAL_API_KEY environment variable not set!")
        print("Please set your Mistral API key:")
        print("export MISTRAL_API_KEY='your-api-key-here'")
        print("\nAlso install the Mistral AI client:")
        print("pip install mistralai")
    # mcp_server=True asks Gradio to also serve this interface over MCP
    # (NOTE(review): see Gradio MCP docs for exactly what is exposed).
    demo.launch(mcp_server=True)