|
|
import gradio as gr |
|
|
import requests |
|
|
import os |
|
|
import json |
|
|
import re |
|
|
|
|
|
# API credential read from the process environment at import time;
# os.environ.get yields None when the variable is not set.
MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY")
# Chat-completions endpoint that triage requests are POSTed to.
MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"
|
|
|
|
|
# Prompt sent to the Mistral chat endpoint as the single user message.
# build_prompt() renders it with str.format, filling the {symptoms} placeholder;
# any other literal braces added to this template must therefore be doubled
# ({{ and }}) or formatting will fail.
# NOTE: every line inside the triple-quoted string, including the blank ones,
# is part of the text the model receives.
# The model is asked not to emit markdown fences, but the response parser is
# still written to cope with fenced output.
PROMPT_TEMPLATE = """


You are a clinical triage assistant. A patient describes their symptoms as follows:





"{symptoms}"





Based on standard triage protocols, output ONLY the following fields in JSON:


1. urgency_level: [Low, Moderate, High, Emergency]


2. possible_condition: A concise possible diagnosis


3. recommended_action: Clear next steps for the patient





Respond only in valid JSON. Do not include triple backticks or markdown formatting.


"""
|
|
|
|
|
def build_prompt(symptom_description: str) -> str:
    """
    Render the triage prompt for one patient's symptom description.

    Args:
        symptom_description (str): Free-text input of patient symptoms

    Returns:
        str: PROMPT_TEMPLATE with the symptom text substituted for its
        ``{symptoms}`` placeholder, ready to send to the LLM
    """
    substitutions = {"symptoms": symptom_description}
    return PROMPT_TEMPLATE.format_map(substitutions)
|
|
|
|
|
def extract_json(text: str) -> str:
    """
    Extract the JSON object from an LLM response string.

    Works when the object is wrapped in markdown code fences or surrounded by
    prose. Decoding starts at the first ``{`` and stops at the end of the first
    complete JSON object. Trailing text that happens to contain ``}`` (or a
    second object) is therefore no longer swallowed into the result.

    Args:
        text (str): Raw response from LLM

    Returns:
        str: The substring holding the first complete JSON object, when one
        starts at the first ``{``. Otherwise there are two fallbacks, kept so the
        caller can still report the raw content in its parse error:

        - the span from the first ``{`` to the last ``}``;
        - the stripped text, when there are no braces at all.
    """
    start = text.find("{")
    if start != -1:
        try:
            # raw_decode parses one JSON value from the start of the string
            # and reports where it ended, ignoring whatever follows.
            _, length = json.JSONDecoder().raw_decode(text[start:])
        except json.JSONDecodeError:
            pass  # malformed object: fall through to the greedy span below
        else:
            return text[start:start + length]

    # Fallback: greedy first-'{'-to-last-'}' span (DOTALL so it crosses lines).
    match = re.search(r"\{.*\}", text, re.DOTALL)
    return match.group(0) if match else text.strip()
|
|
|
|
|
def triage_response(symptoms: str) -> dict:
    """
    Sends a triage prompt to Mistral API and parses the JSON response.

    Args:
        symptoms (str): Description of patient symptoms

    Returns:
        dict: Dictionary with urgency_level, possible_condition, and recommended_action.
        On failure it returns a dict with an "error" key instead. Failures covered:

        - blank input or a missing API key;
        - an HTTP, transport or response-shape problem;
        - model output that is not a JSON object. In this case "raw_output" and
          "exception" are also included.
    """
    # Reject blank input up front instead of spending an API call on an empty prompt.
    if not symptoms or not symptoms.strip():
        return {"error": "No symptoms provided"}

    # Without a key the request would go out as "Bearer None" and come back as a
    # less helpful HTTP error; report the configuration problem directly.
    if not MISTRAL_API_KEY:
        return {"error": "MISTRAL_API_KEY environment variable is not set"}

    prompt = build_prompt(symptoms)

    headers = {
        "Authorization": f"Bearer {MISTRAL_API_KEY}",
        "Content-Type": "application/json",
    }

    body = {
        "model": "mistral-medium",
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.3,
    }

    # requests applies no timeout by default; without one a stalled connection
    # would block this handler indefinitely. 30 s is this handler's chosen bound.
    request_timeout_seconds = 30

    try:
        response = requests.post(
            MISTRAL_API_URL, headers=headers, json=body, timeout=request_timeout_seconds
        )
        response.raise_for_status()
        raw_output = response.json()["choices"][0]["message"]["content"]
        cleaned_output = extract_json(raw_output)
    except Exception as e:
        # Handler boundary for the UI / MCP tool: surface any transport, HTTP or
        # response-shape failure as an error payload rather than raising.
        return {"error": str(e)}

    try:
        parsed = json.loads(cleaned_output)
    except json.JSONDecodeError as e:
        return {"error": "Invalid JSON format", "raw_output": raw_output, "exception": str(e)}

    # Valid JSON that is not an object (e.g. a bare number or list) does not
    # satisfy the dict contract callers rely on.
    if not isinstance(parsed, dict):
        return {
            "error": "Invalid JSON format",
            "raw_output": raw_output,
            "exception": f"Expected a JSON object, got {type(parsed).__name__}",
        }
    return parsed
|
|
|
|
|
|
|
|
# Gradio front end: a four-line free-text box feeds triage_response, and the
# returned dict is rendered with the "json" output component.
demo = gr.Interface(
    title="Clinical Triage Assistant",
    description="An MCP-compatible AI-powered tool that triages patient symptoms using Mistral API.",
    fn=triage_response,
    inputs=gr.Textbox(label="Enter patient's symptoms", lines=4),
    outputs="json",
)


if __name__ == "__main__":
    # Start the app with Gradio's MCP server option enabled.
    demo.launch(mcp_server=True)
|
|
|