# medgemma_tool.py
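"""LangChain tool that wraps Google's MedGemma image-text model to answer
natural-language questions about chest X-ray images."""
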
import asyncio
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Type

from pydantic import BaseModel, Field

import torch
from PIL import Image
from transformers import (
    AutoModelForImageTextToText,
    AutoProcessor,
)

from langchain_core.tools import BaseTool
from langchain_core.callbacks import (
    CallbackManagerForToolRun,
    AsyncCallbackManagerForToolRun,
)

class MedGemmaInput(BaseModel):
    """Input schema for MedGEMMA X-ray tool."""
    image_path: str = Field(..., description="Path to a chest X-ray image")
    prompt: str = Field(..., description="Question or instruction for the image")
    max_new_tokens: int = Field(
        300,
        description="Maximum number of tokens to generate in the answer",
    )


class MedGemmaXRayTool(BaseTool):
    """A tool that uses medgemma to answer questions about chest X-ray images."""

    name: str = "medgemma_xray_expert"
    description: str = (
        "The first tool the agent should use for any question about chest "
        "X-ray images. It is specialized in visual question answering, "
        "report generation, abnormality detection, anatomical localization, "
        "clinical interpretation, comparative analysis, and identification "
        "and explanation of imaging signs. Input is the path to an X-ray "
        "image and a natural-language prompt describing the task."
    )
    args_schema: Type[BaseModel] = MedGemmaInput
    return_direct: bool = True
    # _run returns a (content, artifact) tuple; declare that to LangChain
    response_format: str = "content_and_artifact"

    # model handles
    model: Optional[AutoModelForImageTextToText] = None
    processor: Optional[AutoProcessor] = None

    # config
    model_name: str = "google/medgemma-4b-it"
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    dtype: torch.dtype = torch.bfloat16

    def __init__(
        self,
        model_name: str = "google/medgemma-4b-it",
        device: Optional[str] = None,
        dtype: torch.dtype = torch.bfloat16,
        cache_dir: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)

        self.model_name = model_name
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.dtype = dtype

        # Load model & processor
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=dtype,
            trust_remote_code=True,
            cache_dir=cache_dir,
        )
        self.processor = AutoProcessor.from_pretrained(
            model_name, trust_remote_code=True, cache_dir=cache_dir
        )
        self.model.eval()

    def _generate(
        self,
        image_path: str,
        prompt: str,
        max_new_tokens: int,
    ) -> str:
        """Run MedGEMMA and return decoded answer."""
        img = Image.open(image_path).convert("RGB")

        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are an expert radiologist. Provide a detailed response to user's query."}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image", "image": img},
                ],
            },
        ]

        # Tokenise the conversation with the chat template. BatchFeature.to()
        # moves everything to the model's device and casts floating-point
        # tensors (the pixel values) to the working dtype; token ids stay int.
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device, dtype=self.dtype)

        # Prompt length, so decoding below skips the input tokens
        start_len = inputs["input_ids"].shape[-1]

        # Generate (greedy decoding for deterministic answers)
        with torch.inference_mode():
            gens = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )
        decoded = self.processor.decode(
            gens[0][start_len:], skip_special_tokens=True
        )
        return decoded.strip()

    def _run(
        self,
        image_path: str,
        prompt: str,
        max_new_tokens: int = 300,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Validate, invoke model, return output + metadata."""
        try:
            if not Path(image_path).is_file():
                raise FileNotFoundError(f"Image not found: {image_path}")

            answer = self._generate(image_path, prompt, max_new_tokens)

            return (
                {"response": answer},
                {
                    "image_path": image_path,
                    "prompt": prompt,
                    "max_new_tokens": max_new_tokens,
                    "status": "completed",
                },
            )
        
        except Exception as e:
            return (
                {"error": str(e)},
                {
                    "image_path": image_path,
                    "prompt": prompt,
                    "max_new_tokens": max_new_tokens,
                    "status": "failed",
                    "error": str(e),
                },
            )

    async def _arun(
        self,
        image_path: str,
        prompt: str,
        max_new_tokens: int = 300,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Asynchronous wrapper (delegates to sync)."""
        return self._run(image_path, prompt, max_new_tokens)
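

if __name__ == "__main__":
    # Minimal usage sketch, not part of the tool itself. "chest_xray.png" is
    # a hypothetical path; substitute a real chest X-ray image. Loading
    # google/medgemma-4b-it needs a GPU (or substantial RAM) and may require
    # accepting the model's terms on Hugging Face first.
    tool = MedGemmaXRayTool()
    result = tool.invoke(
        {
            "image_path": "chest_xray.png",  # hypothetical example path
            "prompt": "Is there evidence of pleural effusion?",
        }
    )
    print(result)  # {"response": "..."} on success, {"error": "..."} on failure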