# medgemma_tool.py
from typing import Any, Dict, Optional, Tuple, Type
from pathlib import Path

import torch
from PIL import Image
from pydantic import BaseModel, Field
from transformers import (
    AutoModelForImageTextToText,
    AutoProcessor,
)
from langchain_core.tools import BaseTool
from langchain_core.callbacks import (
    CallbackManagerForToolRun,
    AsyncCallbackManagerForToolRun,
)


class MedGemmaInput(BaseModel):
    """Input schema for the MedGemma X-ray tool."""

    image_path: str = Field(..., description="Path to a chest X-ray image")
    prompt: str = Field(..., description="Question or instruction for the image")
    max_new_tokens: int = Field(
        300,
        description="Maximum number of tokens to generate in the answer",
    )
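
# A quick illustration of the payload this schema validates (hypothetical
# values): the agent supplies image_path and prompt, while max_new_tokens
# falls back to its default of 300.
#   MedGemmaInput(image_path="cxr.png", prompt="Is there a pleural effusion?")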


class MedGemmaXRayTool(BaseTool):
    """A tool that uses MedGemma to answer questions about chest X-ray images."""

    name: str = "medgemma_xray_expert"
    description: str = (
        "The first tool the agent should use to answer any question about X-ray "
        "images. The tool is specialized in multiple tasks, including visual "
        "question answering, report generation, abnormality detection, anatomical "
        "localization, clinical interpretation, comparative analysis, and the "
        "identification and explanation of imaging signs. Input should be the "
        "path to an X-ray image and a natural language prompt describing the "
        "task to be carried out."
    )
    args_schema: Type[BaseModel] = MedGemmaInput
    return_direct: bool = True
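    # _run below returns a (content, artifact) tuple. Setting this tells
    # LangChain to split the pair instead of stringifying it (assumes a
    # langchain-core version that supports the "content_and_artifact" format).
    response_format: str = "content_and_artifact"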

    # model handles
    model: Optional[AutoModelForImageTextToText] = None
    processor: Optional[AutoProcessor] = None

    # config
    model_name: str = "google/medgemma-4b-it"
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    dtype: torch.dtype = torch.bfloat16

    def __init__(
        self,
        model_name: str = "google/medgemma-4b-it",
        device: Optional[str] = None,
        dtype: torch.dtype = torch.bfloat16,
        cache_dir: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.model_name = model_name
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = dtype
        # Load model & processor once at construction time
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=dtype,
            trust_remote_code=True,
            cache_dir=cache_dir,
        )
        self.processor = AutoProcessor.from_pretrained(
            model_name, trust_remote_code=True, cache_dir=cache_dir
        )
        self.model.eval()

    def _generate(
        self,
        image_path: str,
        prompt: str,
        max_new_tokens: int,
    ) -> str:
        """Run MedGemma and return the decoded answer."""
        img = Image.open(image_path).convert("RGB")
        messages = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": (
                            "You are an expert radiologist. Provide a detailed "
                            "response to the user's query."
                        ),
                    }
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image", "image": img},
                ],
            },
        ]
        # Tokenize the prompt and image with the chat template
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device, dtype=self.dtype)
        start_len = inputs["input_ids"].shape[-1]
        # Generate greedily, then decode only the newly generated tokens
        with torch.inference_mode():
            gens = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )
        decoded = self.processor.decode(
            gens[0][start_len:], skip_special_tokens=True
        )
        return decoded.strip()

    def _run(
        self,
        image_path: str,
        prompt: str,
        max_new_tokens: int = 300,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Validate, invoke model, return output + metadata."""
        try:
            if not Path(image_path).is_file():
                raise FileNotFoundError(f"Image not found: {image_path}")
            answer = self._generate(image_path, prompt, max_new_tokens)
            return (
                {"response": answer},
                {
                    "image_path": image_path,
                    "prompt": prompt,
                    "max_new_tokens": max_new_tokens,
                    "status": "completed",
                },
            )
        except Exception as e:
            return (
                {"error": str(e)},
                {
                    "image_path": image_path,
                    "prompt": prompt,
                    "max_new_tokens": max_new_tokens,
                    "status": "failed",
                    "error": str(e),
                },
            )

    async def _arun(
        self,
        image_path: str,
        prompt: str,
        max_new_tokens: int = 300,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Asynchronous wrapper (delegates to sync)."""
        return self._run(image_path, prompt, max_new_tokens)
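

# ---------------------------------------------------------------------------
# Minimal usage sketch. The image path below is hypothetical, and the
# google/medgemma-4b-it weights are gated on Hugging Face, so this assumes
# you have accepted the license and authenticated (e.g. via
# `huggingface-cli login`).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tool = MedGemmaXRayTool()
    result = tool.invoke(
        {
            "image_path": "sample_cxr.png",  # hypothetical local file
            "prompt": "Describe any abnormalities visible in this chest X-ray.",
        }
    )
    print(result)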