lukmanaj's picture
Update app2.py
d551f05 verified
import gradio as gr
from mistralai import Mistral
import os
import json
import re
# The API key is read once, at import time. When it is absent no client is
# built, and triage_response returns an error instead of calling the API.
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
if MISTRAL_API_KEY:
    client = Mistral(api_key=MISTRAL_API_KEY)
else:
    client = None

# Prompt sent to the model. The single `{symptoms}` placeholder is filled in
# by build_prompt; the user's text is inserted verbatim between double quotes.
PROMPT_TEMPLATE = """
You are a clinical triage assistant. A patient describes their symptoms as follows:
"{symptoms}"
Based on standard triage protocols, output ONLY the following fields in JSON:
1. urgency_level: [Low, Moderate, High, Emergency]
2. possible_condition: A concise possible diagnosis
3. recommended_action: Clear next steps for the patient
Respond only in valid JSON. Do not include triple backticks or markdown formatting.
"""
def build_prompt(symptom_description: str) -> str:
    """
    Render PROMPT_TEMPLATE with the patient's free-text symptoms.

    Args:
        symptom_description (str): Patient symptoms as typed by the user; the
            text is substituted verbatim for the `{symptoms}` placeholder.

    Returns:
        str: The filled-in prompt, ready to send to the LLM.
    """
    substitutions = {"symptoms": symptom_description}
    return PROMPT_TEMPLATE.format_map(substitutions)
def extract_json(text: str) -> str:
    """
    Extract the first JSON object from an LLM response string.

    Models sometimes wrap the object in markdown code fences or surround it
    with prose. This scans for the first ``{`` at which a complete JSON object
    can be decoded and returns exactly that object's text. Trailing prose,
    even prose containing braces, is therefore not included.

    Args:
        text (str): Raw response from the LLM. ``None`` or empty input yields "".

    Returns:
        str: The JSON object's source text. If no decodable object is found,
        falls back to the span from the first ``{`` to the last ``}`` (or the
        stripped input when there are no braces), so the caller's
        ``json.loads`` reports a meaningful error.
    """
    if not text:
        return ""

    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one JSON value beginning at `start` and returns
            # the index just past it, ignoring whatever text follows.
            _, end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
        else:
            return text[start:end]

    # No decodable object anywhere: keep the original greedy-match behavior.
    match = re.search(r"\{.*\}", text, re.DOTALL)
    return match.group(0) if match else text.strip()
def triage_response(symptoms: str) -> dict:
    """
    Send a triage prompt to the Mistral API and parse the JSON reply.

    Args:
        symptoms (str): Free-text description of the patient's symptoms.
            ``None`` is treated like an empty description.

    Returns:
        dict: On success, the parsed model reply. It is a JSON object that
        contains ``urgency_level``, ``possible_condition`` and
        ``recommended_action``.
        On failure, a dict with an ``"error"`` message plus, where available,
        diagnostic keys (``raw_output``, ``exception``, ``parsed_result``,
        ``missing_fields``). API and JSON-parsing failures are reported this
        way rather than raised, so the UI can display them.
    """
    if not client:
        return {"error": "MISTRAL_API_KEY environment variable not set"}
    # `not symptoms` covers None (e.g. a tool call that omits the argument)
    # before .strip() is attempted.
    if not symptoms or not symptoms.strip():
        return {"error": "Please provide a description of symptoms"}

    prompt = build_prompt(symptoms)
    try:
        response = client.chat.complete(
            model="mistral-large-latest",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.3,
            max_tokens=500,
        )
        raw_output = response.choices[0].message.content
    except Exception as e:
        # UI boundary: report any SDK, network, or response-shape failure
        # instead of letting it propagate to Gradio.
        return {"error": f"Mistral API error: {e}"}

    if not isinstance(raw_output, str):
        # NOTE(review): the SDK may return None or non-text content here; only
        # plain-text replies are handled — confirm against the mistralai docs.
        return {"error": "Mistral API error: response contained no text content"}

    cleaned_output = extract_json(raw_output)
    try:
        result = json.loads(cleaned_output)
    except json.JSONDecodeError as e:
        return {
            "error": "Invalid JSON format",
            "raw_output": raw_output,
            "exception": str(e),
        }

    required_fields = ("urgency_level", "possible_condition", "recommended_action")
    # A bare JSON string or list could satisfy an `in` check via substring or
    # element matching, so require an object before checking its keys.
    if isinstance(result, dict):
        missing_fields = [field for field in required_fields if field not in result]
    else:
        missing_fields = list(required_fields)

    if missing_fields:
        return {
            "error": "Missing required fields in response",
            "raw_output": raw_output,
            "parsed_result": result,
            "missing_fields": missing_fields,
        }
    return result
# Gradio front end wrapping triage_response: one multi-line textbox in, the
# returned dict rendered as JSON out. Running the file as a script launches
# the UI with mcp_server=True (Gradio's MCP server support).
demo = gr.Interface(
    title="Clinical Triage Assistant",
    description="An MCP-compatible AI-powered tool that triages patient symptoms using Mistral API.",
    fn=triage_response,
    inputs=gr.Textbox(label="Enter patient's symptoms", lines=4),
    outputs="json",
)

if __name__ == "__main__":
    demo.launch(mcp_server=True)