File size: 5,686 Bytes
1f95818
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import gradio as gr
import os
import json
import re
from mistralai import Mistral

# Read the API key from the environment once, at import time. When it is
# missing or empty no client is constructed and ``client`` stays None.
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
client = None
if MISTRAL_API_KEY:
    client = Mistral(api_key=MISTRAL_API_KEY)

# Instruction text sent to the model as the single user message.
# ``{symptoms}`` is the only str.format placeholder (filled by build_prompt),
# so any literal braces added to this text later must be doubled ("{{" / "}}").
# NOTE(review): the patient's text is inserted verbatim between double quotes
# with no escaping, so quotes or instructions inside it reach the model as-is.
PROMPT_TEMPLATE = """
You are a clinical triage assistant. A patient describes their symptoms as follows:
"{symptoms}"

Based on standard triage protocols, output ONLY the following fields in JSON:
1. urgency_level: [Low, Moderate, High, Emergency]
2. possible_condition: A concise possible diagnosis
3. recommended_action: Clear next steps for the patient

Respond only in valid JSON. Do not include triple backticks or markdown formatting.
"""

def build_prompt(symptom_description: str) -> str:
    """
    Render PROMPT_TEMPLATE for a single patient.
    Args:
        symptom_description (str): Free-text symptoms, substituted verbatim
            for the template's ``{symptoms}`` placeholder
    Returns:
        str: The completed prompt, ready to send to the LLM
    """
    filled = PROMPT_TEMPLATE.format_map({"symptoms": symptom_description})
    return filled

def extract_json(text: str) -> str:
    """
    Extract JSON object from response string, including inside code blocks.

    Scans for the first ``{`` that begins a complete, valid JSON object and
    returns exactly that span. Surrounding prose, markdown fences, or later
    brace pairs therefore do not corrupt the result. If no valid object is
    found, it falls back to the widest ``{...}`` span (first ``{`` to last
    ``}``). If there is no such span, it returns the stripped input.
    Args:
        text (str): Raw response from LLM; ``None`` or empty input yields ""
    Returns:
        str: JSON-formatted string, or the best-effort candidate when the
            response contains no valid JSON object
    """
    if not text:
        return ""

    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one JSON value beginning at ``start`` and
            # reports where it ends, ignoring whatever text follows it.
            _, end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
            continue
        return text[start:end]

    # No valid object anywhere: keep the original greedy match so the caller's
    # json.loads reports a parse error against the most plausible span.
    match = re.search(r"\{.*\}", text, re.DOTALL)
    return match.group(0) if match else text.strip()

def triage_response(symptoms: str) -> dict:
    """
    Sends a triage prompt to Mistral API and parses the JSON response.
    Args:
        symptoms (str): Description of patient symptoms
    Returns:
        dict: On success, the parsed JSON object containing at least
            urgency_level, possible_condition and recommended_action.
            On failure, a dict with an "error" message plus, where available,
            "raw_output", "parsed_result", "missing_fields" or "exception".
    """
    if not client:
        return {"error": "MISTRAL_API_KEY environment variable not set"}

    # Treat None (possible from programmatic callers) like an empty textbox.
    if not symptoms or not symptoms.strip():
        return {"error": "Please provide a description of symptoms"}

    prompt = build_prompt(symptoms)

    # Only the remote call is wrapped broadly: this is the service boundary,
    # and any SDK/network failure should become a message in the UI.
    try:
        response = client.chat.complete(
            model="mistral-large-latest",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.3,
            max_tokens=500,
        )
        raw_output = response.choices[0].message.content
    except Exception as e:
        return {"error": f"Mistral API error: {str(e)}"}

    # The message content may be None or non-text. Report that explicitly
    # rather than letting it fail inside the JSON extraction.
    if not isinstance(raw_output, str):
        return {"error": "Empty or non-text response from model", "raw_output": raw_output}

    cleaned_output = extract_json(raw_output)
    try:
        result = json.loads(cleaned_output)
    except json.JSONDecodeError as e:
        return {
            "error": "Invalid JSON format",
            "raw_output": raw_output,
            "exception": str(e),
        }

    # A JSON array/string/number would slip past the field check below
    # (``in`` also works on lists and strings) and break the formatter.
    if not isinstance(result, dict):
        return {
            "error": "Response JSON is not an object",
            "raw_output": raw_output,
            "parsed_result": result,
        }

    required_fields = ("urgency_level", "possible_condition", "recommended_action")
    missing_fields = [field for field in required_fields if field not in result]
    if missing_fields:
        return {
            "error": "Missing required fields in response",
            "raw_output": raw_output,
            "parsed_result": result,
            "missing_fields": missing_fields,
        }
    return result

def format_triage_output(result: dict) -> str:
    """
    Format the triage result for better display in Gradio.
    Args:
        result (dict): Output of triage_response, either the triage fields
            or a dict carrying an "error" message
    Returns:
        str: Markdown with an urgency icon, the possible condition, the
            recommended action and a disclaimer, or a one-line error message
    """
    # Guard against non-dict payloads so the ``.get`` calls below cannot raise.
    if not isinstance(result, dict):
        return "❌ Error: Unexpected triage result format"
    if "error" in result:
        return f"❌ Error: {result['error']}"

    # Previously these literals were mojibake (UTF-8 read as cp1253).
    urgency_icons = {
        "Low": "🟢",        # U+1F7E2 LARGE GREEN CIRCLE
        "Moderate": "🟡",   # U+1F7E1 LARGE YELLOW CIRCLE
        "High": "🟠",       # U+1F7E0 LARGE ORANGE CIRCLE
        "Emergency": "🔴",  # U+1F534 LARGE RED CIRCLE
    }

    urgency = result.get("urgency_level", "Unknown")
    # The model may not match the requested capitalisation ("high",
    # " EMERGENCY"), so normalise for the icon lookup only; the displayed
    # value stays exactly as returned.
    icon = urgency_icons.get(str(urgency).strip().capitalize(), "⚪")  # U+26AA

    return f"""
{icon} **Urgency Level:** {urgency}

🩺 **Possible Condition:** {result.get("possible_condition", "Not specified")}

📋 **Recommended Action:** {result.get("recommended_action", "Not specified")}

---
*This is for informational purposes only. Always consult with healthcare professionals for medical advice.*
    """.strip()

def gradio_triage_wrapper(symptoms: str) -> str:
    """
    Gradio callback: run the triage request and render it as Markdown text.
    Args:
        symptoms (str): Raw text from the symptoms textbox
    Returns:
        str: Markdown produced by format_triage_output, including any error message
    """
    return format_triage_output(triage_response(symptoms))

# Create Gradio interface: one symptoms textbox in, Markdown triage text out.
# The title and the fever example previously contained mojibake
# (UTF-8 read as cp1253); they now hold the intended characters.
demo = gr.Interface(
    fn=gradio_triage_wrapper,
    inputs=gr.Textbox(
        lines=4,
        label="Enter patient's symptoms",
        placeholder="Describe the symptoms in detail (e.g., 'Severe chest pain for 30 minutes, difficulty breathing, sweating')",
    ),
    outputs=gr.Markdown(label="Triage Assessment"),
    title="🏥 Clinical Triage Assistant",
    description="""
    An AI-powered tool that provides preliminary triage assessment based on symptom descriptions.
    
    **Disclaimer:** This tool is for educational and informational purposes only. 
    It should not replace professional medical advice, diagnosis, or treatment.
    """,
    examples=[
        ["Severe chest pain radiating to left arm, difficulty breathing, sweating"],
        ["Mild headache and runny nose for 2 days"],
        ["High fever (102°F), severe abdominal pain, vomiting"],
        ["Sprained ankle, mild swelling, can walk with discomfort"],
    ],
    theme=gr.themes.Soft(),
)

if __name__ == "__main__":
    # The app launches even without a key; in that case ``client`` is None and
    # each triage request returns the "environment variable not set" error.
    # NOTE(review): the install hint can only ever print after the top-level
    # ``from mistralai import Mistral`` has already succeeded.
    if not MISTRAL_API_KEY:
        setup_hints = (
            "⚠️  Warning: MISTRAL_API_KEY environment variable not set!",
            "Please set your Mistral API key:",
            "export MISTRAL_API_KEY='your-api-key-here'",
            "\nAlso install the Mistral AI client:",
            "pip install mistralai",
        )
        for hint in setup_hints:
            print(hint)

    demo.launch(mcp_server=True)