from typing import Any, Dict, Optional, Tuple, Type
from pathlib import Path

from pydantic import BaseModel, Field

import torch
from PIL import Image
from transformers import (
    AutoModelForImageTextToText,
    AutoProcessor,
)

from langchain_core.tools import BaseTool
from langchain_core.callbacks import (
    CallbackManagerForToolRun,
    AsyncCallbackManagerForToolRun,
)


class MedGemmaInput(BaseModel):
    """Input schema for MedGEMMA X-ray tool."""

    image_path: str = Field(..., description="Path to a chest X-ray image")
    prompt: str = Field(..., description="Question or instruction for the image")
    max_new_tokens: int = Field(
        300,
        description="Maximum number of tokens to generate in the answer",
    )


class MedGemmaXRayTool(BaseTool):
    """A tool that uses MedGemma to answer questions about chest X-ray images."""

    name: str = "medgemma_xray_expert"
    description: str = (
        "The first tool the agent should use to answer any question about X-ray images. "
        "The tool is specialized in multiple tasks, including visual question answering, "
        "report generation, abnormality detection, anatomical localization, clinical "
        "interpretation, comparative analysis, and identification and explanation of "
        "imaging signs. Input should be the path to an X-ray image and a natural-language "
        "prompt describing the task to be carried out."
    )
    args_schema: Type[BaseModel] = MedGemmaInput
    return_direct: bool = True
    # _run returns a (content, artifact) tuple, so expose the artifact to callers.
    response_format: str = "content_and_artifact"

    # Model and processor are loaded once in __init__ and reused across calls.
    model: Optional[AutoModelForImageTextToText] = None
    processor: Optional[AutoProcessor] = None

    model_name: str = "google/medgemma-4b-it"
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    dtype: torch.dtype = torch.bfloat16

    def __init__(
        self,
        model_name: str = "google/medgemma-4b-it",
        device: Optional[str] = None,
        dtype: torch.dtype = torch.bfloat16,
        cache_dir: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)

        self.model_name = model_name
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = dtype

        # Load the model and processor once at construction time;
        # device_map="auto" places the weights on a GPU when one is available.
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=dtype,
            trust_remote_code=True,
            cache_dir=cache_dir,
        )
        self.processor = AutoProcessor.from_pretrained(
            model_name, trust_remote_code=True, cache_dir=cache_dir
        )
        self.model.eval()

    def _generate(
        self,
        image_path: str,
        prompt: str,
        max_new_tokens: int,
    ) -> str:
        """Run MedGEMMA and return the decoded answer."""
        img = Image.open(image_path).convert("RGB")

        # Chat-style request: a radiologist system message plus the user's prompt and image.
        messages = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": "You are an expert radiologist. Provide a detailed response to the user's query.",
                    }
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image", "image": img},
                ],
            },
        ]

        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device, dtype=self.dtype)

        # Length of the prompt, so that only newly generated tokens are decoded.
        start_len = inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            gens = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )
        decoded = self.processor.decode(
            gens[0][start_len:], skip_special_tokens=True
        )
        return decoded.strip()

    def _run(
        self,
        image_path: str,
        prompt: str,
        max_new_tokens: int = 300,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Validate, invoke model, return output + metadata."""
        try:
            if not Path(image_path).is_file():
                raise FileNotFoundError(f"Image not found: {image_path}")

            answer = self._generate(image_path, prompt, max_new_tokens)

            return (
                {"response": answer},
                {
                    "image_path": image_path,
                    "prompt": prompt,
                    "max_new_tokens": max_new_tokens,
                    "status": "completed",
                },
            )

        except Exception as e:
            return (
                {"error": str(e)},
                {
                    "image_path": image_path,
                    "prompt": prompt,
                    "max_new_tokens": max_new_tokens,
                    "status": "failed",
                    "error": str(e),
                },
            )

    async def _arun(
        self,
        image_path: str,
        prompt: str,
        max_new_tokens: int = 300,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Asynchronous wrapper (delegates to sync)."""
        return self._run(image_path, prompt, max_new_tokens)
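

# Minimal usage sketch. The image path and prompt below are purely illustrative,
# and constructing the tool assumes the "google/medgemma-4b-it" weights are
# accessible (e.g. via the Hugging Face Hub) with enough GPU/CPU memory available.
if __name__ == "__main__":
    xray_tool = MedGemmaXRayTool()
    result = xray_tool.invoke(
        {
            "image_path": "sample_chest_xray.png",  # hypothetical local file
            "prompt": "Is there any evidence of pleural effusion?",
            "max_new_tokens": 200,
        }
    )
    print(result)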