File size: 5,686 Bytes
1f95818 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
import gradio as gr
import os
import json
import re
from mistralai import Mistral
# Initialize the Mistral client once, at import time. When no API key is
# configured the client stays None, and callers report the missing
# configuration instead of attempting a request.
MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY")
client = None
if MISTRAL_API_KEY:
    client = Mistral(api_key=MISTRAL_API_KEY)
# Prompt sent to the model for every triage request. The only str.format
# placeholder is {symptoms}, which build_prompt() fills with the user's free
# text. Any literal braces added to this template later must be doubled
# ({{ }}) or str.format will treat them as placeholders.
PROMPT_TEMPLATE = """
You are a clinical triage assistant. A patient describes their symptoms as follows:
"{symptoms}"
Based on standard triage protocols, output ONLY the following fields in JSON:
1. urgency_level: [Low, Moderate, High, Emergency]
2. possible_condition: A concise possible diagnosis
3. recommended_action: Clear next steps for the patient
Respond only in valid JSON. Do not include triple backticks or markdown formatting.
"""
def build_prompt(symptom_description: str) -> str:
    """
    Fill the triage prompt template with the patient's symptom text.

    Args:
        symptom_description (str): Free-text description of the symptoms.

    Returns:
        str: The complete prompt, ready to send to the LLM.
    """
    substitutions = {"symptoms": symptom_description}
    return PROMPT_TEMPLATE.format_map(substitutions)
def extract_json(text: str) -> str:
    """
    Extract a JSON object from an LLM response string.

    The prompt forbids Markdown, but replies can still wrap the object in
    code fences or surround it with prose. The search works as follows:

    1. Take the span from the first ``{`` to the last ``}``.
    2. Starting at each ``{`` in that span, return the first prefix that
       decodes as a complete JSON value. When the whole span is valid JSON,
       this is the whole span. Stray braces in surrounding prose, or a
       second trailing object, no longer break parsing.
    3. If nothing decodes, return the span unchanged so the caller's
       ``json.loads`` reports the error. If there are no braces at all,
       return the stripped text.

    Args:
        text (str): Raw response from the LLM. ``None`` or ``""`` yields ``""``.

    Returns:
        str: Candidate JSON text for ``json.loads``.
    """
    if not text:
        return ""

    # Match {...} even inside markdown code blocks.
    greedy = re.search(r"\{.*\}", text, re.DOTALL)
    if greedy is None:
        return text.strip()
    candidate = greedy.group(0)

    decoder = json.JSONDecoder()
    for brace in re.finditer(r"\{", candidate):
        start = brace.start()
        try:
            # raw_decode parses one JSON value from the start of the string
            # and reports where it ended, ignoring any trailing text.
            _, length = decoder.raw_decode(candidate[start:])
        except json.JSONDecodeError:
            continue
        return candidate[start:start + length]

    return candidate
def triage_response(symptoms: str) -> dict:
    """
    Send a triage prompt to the Mistral API and parse the JSON response.

    Args:
        symptoms (str): Free-text description of the patient's symptoms.

    Returns:
        dict: On success, the parsed model output. It is guaranteed to
        contain ``urgency_level``, ``possible_condition`` and
        ``recommended_action``.

        On failure, a dict with an ``"error"`` message, plus
        ``raw_output`` / ``parsed_result`` / ``exception`` diagnostics
        where available. API failures are reported this way, not raised.
    """
    if not client:
        return {"error": "MISTRAL_API_KEY environment variable not set"}
    # Treat None (e.g. an omitted tool argument) like a blank string instead
    # of crashing on .strip().
    if not symptoms or not symptoms.strip():
        return {"error": "Please provide a description of symptoms"}

    prompt = build_prompt(symptoms)
    # Keep the try narrow: only the network call and the navigation into the
    # response belong here. The broad except is deliberate at this UI
    # boundary, so any SDK/network failure is shown to the user instead of
    # crashing the request.
    try:
        response = client.chat.complete(
            model="mistral-large-latest",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.3,
            max_tokens=500,
        )
        raw_output = response.choices[0].message.content
    except Exception as e:
        return {"error": f"Mistral API error: {str(e)}"}

    return _parse_triage_output(raw_output)


def _parse_triage_output(raw_output) -> dict:
    """
    Turn the model's raw reply into a validated triage dict or an error dict.

    Args:
        raw_output: ``message.content`` from the chat completion. A string
            is expected; anything else (e.g. ``None``) is reported as
            unusable instead of being passed to the JSON parser.

    Returns:
        dict: The parsed JSON object when it has all required fields.
        Otherwise, an error dict carrying the raw output for debugging.
    """
    if not isinstance(raw_output, str) or not raw_output.strip():
        return {
            "error": "Empty or non-text response from Mistral API",
            "raw_output": raw_output,
        }

    try:
        result = json.loads(extract_json(raw_output))
    except json.JSONDecodeError as e:
        return {
            "error": "Invalid JSON format",
            "raw_output": raw_output,
            "exception": str(e),
        }

    # json.loads can also produce a string, list or number. The `in` checks
    # below would then test substrings or elements, and the formatter's
    # .get() would fail, so insist on a JSON object.
    if not isinstance(result, dict):
        return {
            "error": "Expected a JSON object in response",
            "raw_output": raw_output,
            "parsed_result": result,
        }

    required_fields = ("urgency_level", "possible_condition", "recommended_action")
    missing_fields = [field for field in required_fields if field not in result]
    if missing_fields:
        return {
            "error": "Missing required fields in response",
            "raw_output": raw_output,
            "parsed_result": result,
            "missing_fields": missing_fields,
        }
    return result
def format_triage_output(result: dict) -> str:
    """
    Format the triage result as Markdown for display in Gradio.

    Args:
        result (dict): Either an error dict (contains ``"error"``) or a
            triage dict with ``urgency_level``, ``possible_condition`` and
            ``recommended_action``.

    Returns:
        str: Markdown text. An error dict renders as a single error line.
    """
    if "error" in result:
        return f"❌ Error: {result['error']}"

    urgency_icons = {
        "Low": "🟢",
        "Moderate": "🟡",
        "High": "🟠",
        "Emergency": "🔴",
    }
    urgency = str(result.get("urgency_level", "Unknown")).strip()
    # The model does not always echo the prompt's capitalisation
    # ("emergency", "HIGH"), so match the level case-insensitively and show
    # the canonical label when it is recognised.
    canonical = {level.casefold(): level for level in urgency_icons}
    matched = canonical.get(urgency.casefold())
    if matched is not None:
        urgency = matched
    icon = urgency_icons.get(urgency, "⚪")

    condition = result.get("possible_condition", "Not specified")
    action = result.get("recommended_action", "Not specified")

    # Blank lines between sections keep each field its own Markdown paragraph.
    # They also stop "---" from turning the line above into a setext heading.
    sections = [
        f"{icon} **Urgency Level:** {urgency}",
        f"🩺 **Possible Condition:** {condition}",
        f"📋 **Recommended Action:** {action}",
        "---",
        "*This is for informational purposes only. Always consult with "
        "healthcare professionals for medical advice.*",
    ]
    return "\n\n".join(sections)
def gradio_triage_wrapper(symptoms: str) -> str:
    """
    Run a clinical triage assessment and return it as Markdown text.

    Args:
        symptoms (str): Free-text description of the patient's symptoms.

    Returns:
        str: Markdown-formatted triage assessment or error message.
    """
    return format_triage_output(triage_response(symptoms))
# Create the Gradio interface. The text of the mis-encoded UI strings is
# restored to proper UTF-8 characters (hospital emoji, degree sign).
demo = gr.Interface(
    fn=gradio_triage_wrapper,
    inputs=gr.Textbox(
        lines=4,
        label="Enter patient's symptoms",
        placeholder="Describe the symptoms in detail (e.g., 'Severe chest pain for 30 minutes, difficulty breathing, sweating')",
    ),
    outputs=gr.Markdown(label="Triage Assessment"),
    title="🏥 Clinical Triage Assistant",
    description="""
An AI-powered tool that provides preliminary triage assessment based on symptom descriptions.
**Disclaimer:** This tool is for educational and informational purposes only.
It should not replace professional medical advice, diagnosis, or treatment.
""",
    examples=[
        ["Severe chest pain radiating to left arm, difficulty breathing, sweating"],
        ["Mild headache and runny nose for 2 days"],
        ["High fever (102°F), severe abdominal pain, vomiting"],
        ["Sprained ankle, mild swelling, can walk with discomfort"],
    ],
    theme=gr.themes.Soft(),
)
if __name__ == "__main__":
    # Warn but still launch when the API key is missing. In that case the
    # client is None, and each triage request returns a configuration error
    # instead of calling Mistral.
    if not MISTRAL_API_KEY:
        print("⚠️ Warning: MISTRAL_API_KEY environment variable not set!")
        print("Please set your Mistral API key:")
        print("export MISTRAL_API_KEY='your-api-key-here'")
        print("\nAlso install the Mistral AI client:")
        print("pip install mistralai")
    demo.launch(mcp_server=True)