{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import operator\n",
    "import warnings\n",
    "from typing import *\n",
    "import traceback\n",
    "\n",
    "import os\n",
    "import torch\n",
    "from dotenv import load_dotenv\n",
    "from IPython.display import Image\n",
    "from langgraph.checkpoint.memory import MemorySaver\n",
    "from langgraph.graph import END, StateGraph\n",
    "from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, ToolMessage\n",
    "from langchain_openai import ChatOpenAI\n",
    "from transformers import logging\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import re\n",
    "\n",
    "from medrax.agent import *\n",
    "from medrax.tools import *\n",
    "from medrax.utils import *\n",
    "\n",
    "import json\n",
    "import openai\n",
    "import os\n",
    "import glob\n",
    "import time\n",
    "import logging\n",
    "from datetime import datetime\n",
    "from tenacity import retry, wait_exponential, stop_after_attempt\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "_ = load_dotenv()\n",
    "\n",
    "\n",
    "# Setup directory paths\n",
    "ROOT = \"set this directory to where MedRAX is, .e.g /home/MedRAX\"\n",
    "PROMPT_FILE = f\"{ROOT}/medrax/docs/system_prompts.txt\"\n",
    "BENCHMARK_FILE = f\"{ROOT}/benchmark/questions\"\n",
    "MODEL_DIR = f\"set this to where the tool models are, e.g /home/models\"\n",
    "FIGURES_DIR = f\"{ROOT}/benchmark/figures\"\n",
    "\n",
    "model_name = \"medrax\"\n",
    "temperature = 0.2\n",
    "medrax_logs = f\"{ROOT}/experiments/medrax_logs\"\n",
    "log_filename = f\"{medrax_logs}/{model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json\"\n",
    "logging.basicConfig(filename=log_filename, level=logging.INFO, format=\"%(message)s\", force=True)\n",
    "device = \"cuda\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_tools():\n",
    "    report_tool = ChestXRayReportGeneratorTool(cache_dir=MODEL_DIR, device=device)\n",
    "    xray_classification_tool = ChestXRayClassifierTool(device=device)\n",
    "    segmentation_tool = ChestXRaySegmentationTool(device=device)\n",
    "    grounding_tool = XRayPhraseGroundingTool(\n",
    "        cache_dir=MODEL_DIR, temp_dir=\"temp\", device=device, load_in_8bit=True\n",
    "    )\n",
    "    xray_vqa_tool = XRayVQATool(cache_dir=MODEL_DIR, device=device)\n",
    "    llava_med_tool = LlavaMedTool(cache_dir=MODEL_DIR, device=device, load_in_8bit=True)\n",
    "\n",
    "    return [\n",
    "        report_tool,\n",
    "        xray_classification_tool,\n",
    "        segmentation_tool,\n",
    "        grounding_tool,\n",
    "        xray_vqa_tool,\n",
    "        llava_med_tool,\n",
    "    ]\n",
    "\n",
    "\n",
    "def get_agent(tools):\n",
    "    prompts = load_prompts_from_file(PROMPT_FILE)\n",
    "    prompt = prompts[\"MEDICAL_ASSISTANT\"]\n",
    "\n",
    "    checkpointer = MemorySaver()\n",
    "    model = ChatOpenAI(model=\"gpt-4o\", temperature=temperature, top_p=0.95)\n",
    "    agent = Agent(\n",
    "        model,\n",
    "        tools=tools,\n",
    "        log_tools=True,\n",
    "        log_dir=\"logs\",\n",
    "        system_prompt=prompt,\n",
    "        checkpointer=checkpointer,\n",
    "    )\n",
    "    thread = {\"configurable\": {\"thread_id\": \"1\"}}\n",
    "    return agent, thread\n",
    "\n",
    "\n",
    "def run_medrax(agent, thread, prompt, image_urls=[]):\n",
    "    messages = [\n",
    "        HumanMessage(\n",
    "            content=[\n",
    "                {\"type\": \"text\", \"text\": prompt},\n",
    "            ]\n",
    "            + [{\"type\": \"image_url\", \"image_url\": {\"url\": image_url}} for image_url in image_urls]\n",
    "        )\n",
    "    ]\n",
    "\n",
    "    final_response = None\n",
    "    for event in agent.workflow.stream({\"messages\": messages}, thread):\n",
    "        for v in event.values():\n",
    "            final_response = v\n",
    "\n",
    "    final_response = final_response[\"messages\"][-1].content.strip()\n",
    "    agent_state = agent.workflow.get_state(thread)\n",
    "\n",
    "    return final_response, str(agent_state)"
   ]
  },
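  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch, not wired into the benchmark loop below: the tenacity imports at the\n",
    "# top (retry, wait_exponential, stop_after_attempt) suggest retrying transient API\n",
    "# failures. This hypothetical wrapper shows one way to do that around run_medrax;\n",
    "# it could be swapped in inside create_multimodal_request if rate limits or timeouts occur.\n",
    "@retry(wait=wait_exponential(multiplier=1, min=2, max=30), stop=stop_after_attempt(3))\n",
    "def run_medrax_with_retry(agent, thread, prompt, image_urls=None):\n",
    "    \"\"\"Call run_medrax, retrying up to 3 times with exponential backoff.\"\"\"\n",
    "    return run_medrax(agent, thread, prompt, image_urls=image_urls)"
   ]
  },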
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_multimodal_request(question_data, case_details, case_id, question_id, agent, thread):\n",
    "    # Parse required figures\n",
    "    try:\n",
    "        # Try multiple ways of parsing figures\n",
    "        if isinstance(question_data[\"figures\"], str):\n",
    "            try:\n",
    "                required_figures = json.loads(question_data[\"figures\"])\n",
    "            except json.JSONDecodeError:\n",
    "                required_figures = [question_data[\"figures\"]]\n",
    "        elif isinstance(question_data[\"figures\"], list):\n",
    "            required_figures = question_data[\"figures\"]\n",
    "        else:\n",
    "            required_figures = [str(question_data[\"figures\"])]\n",
    "    except Exception as e:\n",
    "        print(f\"Error parsing figures: {e}\")\n",
    "        required_figures = []\n",
    "\n",
    "    # Ensure each figure starts with \"Figure \"\n",
    "    required_figures = [\n",
    "        fig if fig.startswith(\"Figure \") else f\"Figure {fig}\" for fig in required_figures\n",
    "    ]\n",
    "\n",
    "    subfigures = []\n",
    "    for figure in required_figures:\n",
    "        # Handle both regular figures and those with letter suffixes\n",
    "        base_figure_num = \"\".join(filter(str.isdigit, figure))\n",
    "        figure_letter = \"\".join(filter(str.isalpha, figure.split()[-1])) or None\n",
    "\n",
    "        # Find matching figures in case details\n",
    "        matching_figures = [\n",
    "            case_figure\n",
    "            for case_figure in case_details.get(\"figures\", [])\n",
    "            if case_figure[\"number\"] == f\"Figure {base_figure_num}\"\n",
    "        ]\n",
    "\n",
    "        if not matching_figures:\n",
    "            print(f\"No matching figure found for {figure} in case {case_id}\")\n",
    "            continue\n",
    "\n",
    "        for case_figure in matching_figures:\n",
    "            # If a specific letter is specified, filter subfigures\n",
    "            if figure_letter:\n",
    "                matching_subfigures = [\n",
    "                    subfig\n",
    "                    for subfig in case_figure.get(\"subfigures\", [])\n",
    "                    if subfig.get(\"number\", \"\").lower().endswith(figure_letter.lower())\n",
    "                    or subfig.get(\"label\", \"\").lower() == figure_letter.lower()\n",
    "                ]\n",
    "                subfigures.extend(matching_subfigures)\n",
    "            else:\n",
    "                # If no letter specified, add all subfigures\n",
    "                subfigures.extend(case_figure.get(\"subfigures\", []))\n",
    "\n",
    "    # Add images to content\n",
    "    figure_prompt = \"\"\n",
    "    image_urls = []\n",
    "\n",
    "    for subfig in subfigures:\n",
    "        if \"number\" in subfig:\n",
    "            subfig_number = subfig[\"number\"].lower().strip().replace(\" \", \"_\") + \".jpg\"\n",
    "            subfig_path = os.path.join(FIGURES_DIR, case_id, subfig_number)\n",
    "            figure_prompt += f\"{subfig_number} located at {subfig_path}\\n\"\n",
    "        if \"url\" in subfig:\n",
    "            image_urls.append(subfig[\"url\"])\n",
    "        else:\n",
    "            print(f\"Subfigure missing URL: {subfig}\")\n",
    "\n",
    "    prompt = (\n",
    "        f\"Answer this question correctly using chain of thought reasoning and \"\n",
    "        \"carefully evaluating choices. Solve using our own vision and reasoning and then\"\n",
    "        \"use tools to complement your reasoning. Trust your own judgement over any tools.\\n\"\n",
    "        f\"{question_data['question']}\\n{figure_prompt}\"\n",
    "    )\n",
    "\n",
    "    try:\n",
    "        start_time = time.time()\n",
    "\n",
    "        final_response, agent_state = run_medrax(\n",
    "            agent=agent, thread=thread, prompt=prompt, image_urls=image_urls\n",
    "        )\n",
    "        model_answer, agent_state = run_medrax(\n",
    "            agent=agent,\n",
    "            thread=thread,\n",
    "            prompt=\"If you had to choose the best option, only respond with the letter of choice (only one of A, B, C, D, E, F)\",\n",
    "        )\n",
    "        duration = time.time() - start_time\n",
    "\n",
    "        log_entry = {\n",
    "            \"case_id\": case_id,\n",
    "            \"question_id\": question_id,\n",
    "            \"timestamp\": datetime.now().isoformat(),\n",
    "            \"model\": model_name,\n",
    "            \"temperature\": temperature,\n",
    "            \"duration\": round(duration, 2),\n",
    "            \"usage\": \"\",\n",
    "            \"cost\": 0,\n",
    "            \"raw_response\": final_response,\n",
    "            \"model_answer\": model_answer.strip(),\n",
    "            \"correct_answer\": question_data[\"answer\"][0],\n",
    "            \"input\": {\n",
    "                \"messages\": prompt,\n",
    "                \"question_data\": {\n",
    "                    \"question\": question_data[\"question\"],\n",
    "                    \"explanation\": question_data[\"explanation\"],\n",
    "                    \"metadata\": question_data.get(\"metadata\", {}),\n",
    "                    \"figures\": question_data[\"figures\"],\n",
    "                },\n",
    "                \"image_urls\": [subfig[\"url\"] for subfig in subfigures if \"url\" in subfig],\n",
    "                \"image_captions\": [subfig.get(\"caption\", \"\") for subfig in subfigures],\n",
    "            },\n",
    "            \"agent_state\": agent_state,\n",
    "        }\n",
    "        logging.info(json.dumps(log_entry))\n",
    "        return final_response, model_answer.strip()\n",
    "\n",
    "    except Exception as e:\n",
    "        log_entry = {\n",
    "            \"case_id\": case_id,\n",
    "            \"question_id\": question_id,\n",
    "            \"timestamp\": datetime.now().isoformat(),\n",
    "            \"model\": model_name,\n",
    "            \"temperature\": temperature,\n",
    "            \"status\": \"error\",\n",
    "            \"error\": str(e),\n",
    "            \"cost\": 0,\n",
    "            \"input\": {\n",
    "                \"messages\": prompt,\n",
    "                \"question_data\": {\n",
    "                    \"question\": question_data[\"question\"],\n",
    "                    \"explanation\": question_data[\"explanation\"],\n",
    "                    \"metadata\": question_data.get(\"metadata\", {}),\n",
    "                    \"figures\": question_data[\"figures\"],\n",
    "                },\n",
    "                \"image_urls\": [subfig[\"url\"] for subfig in subfigures if \"url\" in subfig],\n",
    "                \"image_captions\": [subfig.get(\"caption\", \"\") for subfig in subfigures],\n",
    "            },\n",
    "        }\n",
    "        logging.info(json.dumps(log_entry))\n",
    "        print(f\"Error processing case {case_id}, question {question_id}: {str(e)}\")\n",
    "        return \"\", \"\"\n",
    "\n",
    "\n",
    "def load_benchmark_questions(case_id):\n",
    "    benchmark_dir = \"../benchmark/questions\"\n",
    "    return glob.glob(f\"{benchmark_dir}/{case_id}/{case_id}_*.json\")\n",
    "\n",
    "\n",
    "def count_total_questions():\n",
    "    total_cases = len(glob.glob(\"../benchmark/questions/*\"))\n",
    "    total_questions = sum(\n",
    "        len(glob.glob(f\"../benchmark/questions/{case_id}/*.json\"))\n",
    "        for case_id in os.listdir(\"../benchmark/questions\")\n",
    "    )\n",
    "    return total_cases, total_questions\n",
    "\n",
    "\n",
    "def main(tools):\n",
    "    with open(\"../data/eurorad_metadata.json\", \"r\") as file:\n",
    "        data = json.load(file)\n",
    "\n",
    "    total_cases, total_questions = count_total_questions()\n",
    "    cases_processed = 0\n",
    "    questions_processed = 0\n",
    "    skipped_questions = 0\n",
    "\n",
    "    print(f\"Beginning benchmark evaluation for model {model_name} with temperature {temperature}\\n\")\n",
    "\n",
    "    for case_id, case_details in data.items():\n",
    "        if int(case_details[\"case_id\"]) <= 17158:\n",
    "            continue\n",
    "\n",
    "        print(f\"----------------------------------------------------------------\")\n",
    "        agent, thread = get_agent(tools)\n",
    "\n",
    "        question_files = load_benchmark_questions(case_id)\n",
    "        if not question_files:\n",
    "            continue\n",
    "\n",
    "        cases_processed += 1\n",
    "        for question_file in question_files:\n",
    "            with open(question_file, \"r\") as file:\n",
    "                question_data = json.load(file)\n",
    "                question_id = os.path.basename(question_file).split(\".\")[0]\n",
    "\n",
    "            # agent, thread = get_agent(tools)\n",
    "            questions_processed += 1\n",
    "            final_response, model_answer = create_multimodal_request(\n",
    "                question_data, case_details, case_id, question_id, agent, thread\n",
    "            )\n",
    "\n",
    "            # Handle cases where response is None\n",
    "            if final_response is None:\n",
    "                skipped_questions += 1\n",
    "                print(f\"Skipped question: Case ID {case_id}, Question ID {question_id}\")\n",
    "                continue\n",
    "\n",
    "            print(\n",
    "                f\"Progress: Case {cases_processed}/{total_cases}, Question {questions_processed}/{total_questions}\"\n",
    "            )\n",
    "            print(f\"Case ID: {case_id}\")\n",
    "            print(f\"Question ID: {question_id}\")\n",
    "            print(f\"Final Response: {final_response}\")\n",
    "            print(f\"Model Answer: {model_answer}\")\n",
    "            print(f\"Correct Answer: {question_data['answer']}\")\n",
    "            print(f\"----------------------------------------------------------------\\n\")\n",
    "\n",
    "    print(f\"\\nBenchmark Summary:\")\n",
    "    print(f\"Total Cases Processed: {cases_processed}\")\n",
    "    print(f\"Total Questions Processed: {questions_processed}\")\n",
    "    print(f\"Total Questions Skipped: {skipped_questions}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tools = get_tools()\n",
    "main(tools)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "medmax",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}