|
|
import gradio as gr |
|
|
import os |
|
|
import json |
|
|
import re |
|
|
from mistralai import Mistral |
|
|
|
|
|
|
|
|
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")

# Only build a Mistral client when a non-empty key is configured; otherwise
# leave it as None so triage_response can report the missing key instead of
# attempting an API call.
if MISTRAL_API_KEY:
    client = Mistral(api_key=MISTRAL_API_KEY)
else:
    client = None
|
|
|
|
|
# Prompt sent to the model. The only format placeholder is {symptoms}; keep any
# other literal braces out of this text (or double them) because build_prompt
# fills it with str.format.
PROMPT_TEMPLATE = """
You are a clinical triage assistant. A patient describes their symptoms as follows:
"{symptoms}"

Based on standard triage protocols, output ONLY a single JSON object with exactly these keys:
1. urgency_level: exactly one of the strings "Low", "Moderate", "High" or "Emergency"
2. possible_condition: A concise possible diagnosis
3. recommended_action: Clear next steps for the patient (for "Emergency", tell the patient to contact emergency services immediately)

All values must be plain strings. Respond only in valid JSON. Do not include triple backticks or markdown formatting.
"""


def build_prompt(symptom_description: str) -> str:
    """
    Construct the LLM prompt by inserting the symptom description into PROMPT_TEMPLATE.

    The description is substituted verbatim. Braces inside it are harmless
    because str.format does not re-interpret substituted values.

    Args:
        symptom_description (str): Free-text input of patient symptoms.

    Returns:
        str: LLM-ready prompt.
    """
    return PROMPT_TEMPLATE.format(symptoms=symptom_description)
|
|
|
|
|
def extract_json(text: str) -> str:
    """
    Extract the first JSON object embedded in an LLM response.

    The model is asked for bare JSON, but replies are sometimes wrapped in a
    Markdown code fence or surrounded by prose. This function tries each "{"
    in order and returns the first span that decodes as a complete JSON
    object. As a result, trailing text that happens to contain braces is not
    swallowed into the result, which a greedy regex would do.

    Args:
        text (str): Raw response from the LLM. None or "" yields "".

    Returns:
        str: The JSON object text if one was found. Otherwise, ``text``
        stripped of surrounding whitespace, left for the caller's
        json.loads to reject.
    """
    if not text:
        return ""

    decoder = json.JSONDecoder()
    for brace in re.finditer(r"\{", text):
        candidate = text[brace.start():]
        try:
            # raw_decode stops at the end of the first complete JSON value and
            # ignores whatever follows it (closing fence, commentary, ...).
            _, consumed = decoder.raw_decode(candidate)
        except json.JSONDecodeError:
            continue
        return candidate[:consumed]

    return text.strip()
|
|
|
|
|
# Keys every successful triage result must contain.
_REQUIRED_TRIAGE_FIELDS = ("urgency_level", "possible_condition", "recommended_action")

# Canonical urgency labels the prompt asks the model to use.
_URGENCY_LEVELS = ("Low", "Moderate", "High", "Emergency")


def _message_text(content) -> str:
    """Return the text of a chat message's content, or "" when it has none."""
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    # NOTE(review): assumes non-string content is an iterable of chunk objects
    # (or dicts) exposing a "text" field -- confirm against the installed
    # mistralai SDK version.
    parts = []
    for chunk in content:
        piece = chunk.get("text") if isinstance(chunk, dict) else getattr(chunk, "text", None)
        if isinstance(piece, str):
            parts.append(piece)
    return "".join(parts)


def _parse_triage_output(raw_output: str) -> dict:
    """Parse and validate the model's reply into a triage dict or an error dict."""
    cleaned_output = extract_json(raw_output)

    try:
        result = json.loads(cleaned_output)
    except json.JSONDecodeError as e:
        return {
            "error": "Invalid JSON format",
            "raw_output": raw_output,
            "exception": str(e),
        }

    # A bare JSON string would "pass" the substring-style `in` check below,
    # and a number would raise TypeError, so require an object up front.
    if not isinstance(result, dict):
        return {
            "error": "Response is not a JSON object",
            "raw_output": raw_output,
            "parsed_result": result,
        }

    if not all(field in result for field in _REQUIRED_TRIAGE_FIELDS):
        return {
            "error": "Missing required fields in response",
            "raw_output": raw_output,
            "parsed_result": result,
        }

    # Normalise case ("high", "HIGH" -> "High") so downstream exact-label
    # lookups still match. Unrecognised values are left untouched.
    urgency = str(result["urgency_level"]).strip().casefold()
    for level in _URGENCY_LEVELS:
        if urgency == level.casefold():
            result["urgency_level"] = level
            break

    return result


def triage_response(symptoms: str) -> dict:
    """
    Send a triage prompt to the Mistral API and parse the JSON response.

    Args:
        symptoms (str): Description of patient symptoms.

    Returns:
        dict: On success, a dict containing urgency_level, possible_condition
        and recommended_action. On failure, a dict with an "error" message,
        plus "raw_output" and "parsed_result"/"exception" details when the
        model replied but the reply could not be used.
    """
    if not client:
        return {"error": "MISTRAL_API_KEY environment variable not set"}

    # `not symptoms` also covers None, which would otherwise raise on .strip().
    if not symptoms or not symptoms.strip():
        return {"error": "Please provide a description of symptoms"}

    prompt = build_prompt(symptoms)

    # Only the network call and response unpacking live in the try, so local
    # parsing problems are not misreported as API errors.
    try:
        response = client.chat.complete(
            model="mistral-large-latest",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.3,
            max_tokens=500,
            # JSON mode: ask the API to constrain the reply to a JSON object.
            response_format={"type": "json_object"},
        )
        choices = getattr(response, "choices", None)
        raw_output = _message_text(choices[0].message.content) if choices else ""
    except Exception as e:
        # Boundary with an external service: report the failure to the UI
        # instead of crashing the request handler.
        return {"error": f"Mistral API error: {str(e)}"}

    if not raw_output.strip():
        return {"error": "Empty response from Mistral API"}

    return _parse_triage_output(raw_output)
|
|
|
|
|
def format_triage_output(result: dict) -> str:
    """
    Format the triage result as Markdown for display in Gradio.

    Args:
        result (dict): Either an error dict (has an "error" key) or a triage
            dict with urgency_level, possible_condition and
            recommended_action. Missing fields are shown with defaults.

    Returns:
        str: Markdown text. Successful results end with an informational
        disclaimer.
    """
    if "error" in result:
        return f"❌ Error: {result['error']}"

    urgency_icons = {
        "Low": "🟢",
        "Moderate": "🟡",
        "High": "🟠",
        "Emergency": "🔴",
    }
    # Case-insensitive label lookup. str() also keeps non-string (possibly
    # unhashable) values from the model from breaking the dict lookup.
    labels_by_key = {label.casefold(): label for label in urgency_icons}

    urgency = result.get("urgency_level", "Unknown")
    label = labels_by_key.get(str(urgency).strip().casefold())
    if label is None:
        icon = "⚪"
        shown_urgency = urgency
    else:
        icon = urgency_icons[label]
        shown_urgency = label

    condition = result.get("possible_condition", "Not specified")
    action = result.get("recommended_action", "Not specified")

    # Blank lines separate the Markdown paragraphs. The blank line before
    # "---" makes it a horizontal rule rather than a setext heading underline.
    return "\n".join([
        f"{icon} **Urgency Level:** {shown_urgency}",
        "",
        f"🩺 **Possible Condition:** {condition}",
        "",
        f"📋 **Recommended Action:** {action}",
        "",
        "---",
        "*This is for informational purposes only. Always consult with healthcare professionals for medical advice.*",
    ])
|
|
|
|
|
def gradio_triage_wrapper(symptoms: str) -> str:
    """
    Run a triage assessment and return it as Markdown.

    This is the function wired into the Gradio interface. The app is launched
    with mcp_server=True, so this docstring also serves as its tool description.

    Args:
        symptoms (str): Free-text description of the patient's symptoms.

    Returns:
        str: Markdown-formatted triage assessment, or an error message.
    """
    return format_triage_output(triage_response(symptoms))
|
|
|
|
|
|
|
|
# Single-textbox Gradio UI. The wrapper returns Markdown for the output panel.
demo = gr.Interface(
    fn=gradio_triage_wrapper,
    inputs=gr.Textbox(
        lines=4,
        label="Enter patient's symptoms",
        placeholder="Describe the symptoms in detail (e.g., 'Severe chest pain for 30 minutes, difficulty breathing, sweating')",
    ),
    outputs=gr.Markdown(label="Triage Assessment"),
    title="🏥 Clinical Triage Assistant",
    description="""
An AI-powered tool that provides preliminary triage assessment based on symptom descriptions.

**Disclaimer:** This tool is for educational and informational purposes only.
It should not replace professional medical advice, diagnosis, or treatment.
""",
    examples=[
        ["Severe chest pain radiating to left arm, difficulty breathing, sweating"],
        ["Mild headache and runny nose for 2 days"],
        ["High fever (102°F), severe abdominal pain, vomiting"],
        ["Sprained ankle, mild swelling, can walk with discomfort"],
    ],
    theme=gr.themes.Soft(),
)
|
|
|
|
|
if __name__ == "__main__":
    # Warn (but still launch) when no API key is configured. Requests will
    # then return the "MISTRAL_API_KEY environment variable not set" error.
    if not MISTRAL_API_KEY:
        print("⚠️ Warning: MISTRAL_API_KEY environment variable not set!")
        print("Please set your Mistral API key:")
        print("export MISTRAL_API_KEY='your-api-key-here'")
        print("\nAlso install the Mistral AI client:")
        print("pip install mistralai")

    # mcp_server=True also exposes the interface function as an MCP tool.
    demo.launch(mcp_server=True)