# medgemma_tool.py
from typing import Any, Dict, Optional, Tuple, Type
from pathlib import Path

from pydantic import BaseModel, Field
import torch
from PIL import Image
from transformers import (
    AutoModelForImageTextToText,
    AutoProcessor,
)
from langchain_core.tools import BaseTool
from langchain_core.callbacks import (
    CallbackManagerForToolRun,
    AsyncCallbackManagerForToolRun,
)

class MedGemmaInput(BaseModel):
    """Input schema for the MedGEMMA X-ray tool."""

    image_path: str = Field(..., description="Path to a chest X-ray image")
    prompt: str = Field(..., description="Question or instruction for the image")
    max_new_tokens: int = Field(
        300,
        description="Maximum number of tokens to generate in the answer",
    )

class MedGemmaXRayTool(BaseTool):
    """A tool that uses MedGEMMA to answer questions about chest X-ray images."""

    name: str = "medgemma_xray_expert"
    description: str = (
        "The first tool to be used by the agent to answer any questions related to X-ray images. "
        "The tool is specialized in performing multiple tasks including Visual Question Answering, "
        "Report generation, Abnormality detection, Anatomical localization, Clinical interpretation, "
        "Comparative analysis, and Identification and explanation of imaging signs. Input should be "
        "the path to an X-ray image and a natural language prompt describing the task to be carried out."
    )
    args_schema: Type[BaseModel] = MedGemmaInput
    return_direct: bool = True

    # model handles
    model: Optional[AutoModelForImageTextToText] = None
    processor: Optional[AutoProcessor] = None

    # config
    model_name: str = "google/medgemma-4b-it"
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    dtype: torch.dtype = torch.bfloat16

    def __init__(
        self,
        model_name: str = "google/medgemma-4b-it",
        device: Optional[str] = None,
        dtype: torch.dtype = torch.bfloat16,
        cache_dir: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.model_name = model_name
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = dtype

        # Load model & processor
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=dtype,
            trust_remote_code=True,
            cache_dir=cache_dir,
        )
        self.processor = AutoProcessor.from_pretrained(
            model_name, trust_remote_code=True, cache_dir=cache_dir
        )
        self.model.eval()

    def _generate(
        self,
        image_path: str,
        prompt: str,
        max_new_tokens: int,
    ) -> str:
        """Run MedGEMMA and return the decoded answer."""
        img = Image.open(image_path).convert("RGB")

        # Build a chat-style message list with a radiologist system prompt
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are an expert radiologist. Provide a detailed response to user's query."}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image", "image": img},
                ],
            },
        ]

        # Tokenise with the chat template
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device, dtype=self.dtype)
        start_len = inputs["input_ids"].shape[-1]

        # Generate, then decode only the newly generated tokens
        with torch.inference_mode():
            gens = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )
        decoded = self.processor.decode(
            gens[0][start_len:], skip_special_tokens=True
        )
        return decoded.strip()

    def _run(
        self,
        image_path: str,
        prompt: str,
        max_new_tokens: int = 300,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Validate, invoke model, return output + metadata."""
        try:
            if not Path(image_path).is_file():
                raise FileNotFoundError(f"Image not found: {image_path}")
            answer = self._generate(image_path, prompt, max_new_tokens)
            return (
                {"response": answer},
                {
                    "image_path": image_path,
                    "prompt": prompt,
                    "max_new_tokens": max_new_tokens,
                    "status": "completed",
                },
            )
        except Exception as e:
            return (
                {"error": str(e)},
                {
                    "image_path": image_path,
                    "prompt": prompt,
                    "max_new_tokens": max_new_tokens,
                    "status": "failed",
                    "error": str(e),
                },
            )

    async def _arun(
        self,
        image_path: str,
        prompt: str,
        max_new_tokens: int = 300,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Asynchronous wrapper (delegates to sync)."""
        return self._run(image_path, prompt, max_new_tokens)
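

# Minimal usage sketch (illustrative only, not part of the tool itself): the image
# path below is a hypothetical placeholder, and downloading google/medgemma-4b-it
# may require accepting the model license on Hugging Face. The tool is invoked via
# the standard LangChain Runnable interface (BaseTool.invoke with a dict of args).
if __name__ == "__main__":
    tool = MedGemmaXRayTool()  # loads the model onto GPU if available, else CPU
    output = tool.invoke(
        {
            "image_path": "sample_chest_xray.png",  # hypothetical path; replace with a real image
            "prompt": "Describe any abnormalities visible in this chest X-ray.",
        }
    )
    print(output)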