lukmanaj commited on
Commit
1f95818
·
verified ·
1 Parent(s): bf5f47a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -0
app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import json
4
+ import re
5
+ from mistralai import Mistral
6
+
7
+ # Initialize Mistral client
8
+ MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
9
+ client = Mistral(api_key=MISTRAL_API_KEY) if MISTRAL_API_KEY else None
10
+
11
+ PROMPT_TEMPLATE = """
12
+ You are a clinical triage assistant. A patient describes their symptoms as follows:
13
+ "{symptoms}"
14
+
15
+ Based on standard triage protocols, output ONLY the following fields in JSON:
16
+ 1. urgency_level: [Low, Moderate, High, Emergency]
17
+ 2. possible_condition: A concise possible diagnosis
18
+ 3. recommended_action: Clear next steps for the patient
19
+
20
+ Respond only in valid JSON. Do not include triple backticks or markdown formatting.
21
+ """
22
+
23
+ def build_prompt(symptom_description: str) -> str:
24
+ """
25
+ Construct a natural language prompt based on symptom description.
26
+ Args:
27
+ symptom_description (str): Free-text input of patient symptoms
28
+ Returns:
29
+ str: LLM-ready prompt
30
+ """
31
+ return PROMPT_TEMPLATE.format(symptoms=symptom_description)
32
+
33
+ def extract_json(text: str) -> str:
34
+ """
35
+ Extract JSON object from response string, including inside code blocks.
36
+ Args:
37
+ text (str): Raw response from LLM
38
+ Returns:
39
+ str: JSON-formatted string
40
+ """
41
+ # Match {...} pattern even inside markdown blocks
42
+ match = re.search(r"\{.*\}", text, re.DOTALL)
43
+ return match.group(0) if match else text.strip()
44
+
45
+ def triage_response(symptoms: str) -> dict:
46
+ """
47
+ Sends a triage prompt to Mistral API and parses the JSON response.
48
+ Args:
49
+ symptoms (str): Description of patient symptoms
50
+ Returns:
51
+ dict: Dictionary with urgency_level, possible_condition, and recommended_action
52
+ """
53
+ if not client:
54
+ return {"error": "MISTRAL_API_KEY environment variable not set"}
55
+
56
+ if not symptoms.strip():
57
+ return {"error": "Please provide a description of symptoms"}
58
+
59
+ prompt = build_prompt(symptoms)
60
+
61
+ try:
62
+ # Create chat completion using Mistral client
63
+ response = client.chat.complete(
64
+ model="mistral-large-latest",
65
+ messages=[
66
+ {
67
+ "role": "user",
68
+ "content": prompt
69
+ }
70
+ ],
71
+ temperature=0.3,
72
+ max_tokens=500
73
+ )
74
+
75
+ # Extract content from response
76
+ raw_output = response.choices[0].message.content
77
+ cleaned_output = extract_json(raw_output)
78
+
79
+ try:
80
+ result = json.loads(cleaned_output)
81
+ # Validate required fields
82
+ required_fields = ["urgency_level", "possible_condition", "recommended_action"]
83
+ if all(field in result for field in required_fields):
84
+ return result
85
+ else:
86
+ return {
87
+ "error": "Missing required fields in response",
88
+ "raw_output": raw_output,
89
+ "parsed_result": result
90
+ }
91
+ except json.JSONDecodeError as e:
92
+ return {
93
+ "error": "Invalid JSON format",
94
+ "raw_output": raw_output,
95
+ "exception": str(e)
96
+ }
97
+
98
+ except Exception as e:
99
+ return {"error": f"Mistral API error: {str(e)}"}
100
+
101
+ def format_triage_output(result: dict) -> str:
102
+ """
103
+ Format the triage result for better display in Gradio.
104
+ """
105
+ if "error" in result:
106
+ return f"❌ Error: {result['error']}"
107
+
108
+ urgency_icons = {
109
+ "Low": "🟢",
110
+ "Moderate": "🟡",
111
+ "High": "🟠",
112
+ "Emergency": "🔴"
113
+ }
114
+
115
+ urgency = result.get("urgency_level", "Unknown")
116
+ icon = urgency_icons.get(urgency, "⚪")
117
+
118
+ return f"""
119
+ {icon} **Urgency Level:** {urgency}
120
+
121
+ 🩺 **Possible Condition:** {result.get("possible_condition", "Not specified")}
122
+
123
+ 📋 **Recommended Action:** {result.get("recommended_action", "Not specified")}
124
+
125
+ ---
126
+ *This is for informational purposes only. Always consult with healthcare professionals for medical advice.*
127
+ """.strip()
128
+
129
+ def gradio_triage_wrapper(symptoms: str) -> str:
130
+ """
131
+ Wrapper function for Gradio that returns formatted text output.
132
+ """
133
+ result = triage_response(symptoms)
134
+ return format_triage_output(result)
135
+
136
+ # Create Gradio interface
137
+ demo = gr.Interface(
138
+ fn=gradio_triage_wrapper,
139
+ inputs=gr.Textbox(
140
+ lines=4,
141
+ label="Enter patient's symptoms",
142
+ placeholder="Describe the symptoms in detail (e.g., 'Severe chest pain for 30 minutes, difficulty breathing, sweating')"
143
+ ),
144
+ outputs=gr.Markdown(label="Triage Assessment"),
145
+ title="🏥 Clinical Triage Assistant",
146
+ description="""
147
+ An AI-powered tool that provides preliminary triage assessment based on symptom descriptions.
148
+
149
+ **Disclaimer:** This tool is for educational and informational purposes only.
150
+ It should not replace professional medical advice, diagnosis, or treatment.
151
+ """,
152
+ examples=[
153
+ ["Severe chest pain radiating to left arm, difficulty breathing, sweating"],
154
+ ["Mild headache and runny nose for 2 days"],
155
+ ["High fever (102°F), severe abdominal pain, vomiting"],
156
+ ["Sprained ankle, mild swelling, can walk with discomfort"]
157
+ ],
158
+ theme=gr.themes.Soft()
159
+ )
160
+
161
+ if __name__ == "__main__":
162
+ # Check if API key is available
163
+ if not MISTRAL_API_KEY:
164
+ print("⚠️ Warning: MISTRAL_API_KEY environment variable not set!")
165
+ print("Please set your Mistral API key:")
166
+ print("export MISTRAL_API_KEY='your-api-key-here'")
167
+ print("\nAlso install the Mistral AI client:")
168
+ print("pip install mistralai")
169
+
170
+ demo.launch(mcp_server=True)