Dhruv-Ty commited on
Commit
e1ede20
·
1 Parent(s): 0ffa584

resolved the PermissionError

Browse files
Files changed (1) hide show
  1. medrax/tools/report_generation.py +9 -61
medrax/tools/report_generation.py CHANGED
@@ -2,6 +2,7 @@ from typing import Any, Dict, Optional, Tuple, Type
2
  from pydantic import BaseModel, Field
3
 
4
  import torch
 
5
 
6
  from langchain_core.callbacks import (
7
  AsyncCallbackManagerForToolRun,
@@ -28,18 +29,6 @@ class ChestXRayInput(BaseModel):
28
 
29
 
30
  class ChestXRayReportGeneratorTool(BaseTool):
31
- """Tool that generates comprehensive chest X-ray reports with both findings and impressions.
32
-
33
- This tool uses two Vision-Encoder-Decoder models (ViT-BERT) trained on CheXpert
34
- and MIMIC-CXR datasets to generate structured radiology reports. It automatically
35
- generates both detailed findings and impression summaries for each chest X-ray,
36
- following standard radiological reporting format.
37
-
38
- The tool uses:
39
- - Findings model: Generates detailed observations of all visible structures
40
- - Impression model: Provides concise clinical interpretation and key diagnoses
41
- """
42
-
43
  name: str = "chest_xray_report_generator"
44
  description: str = (
45
  "A tool that analyzes chest X-ray images and generates comprehensive radiology reports "
@@ -47,7 +36,7 @@ class ChestXRayReportGeneratorTool(BaseTool):
47
  "to a chest X-ray image file. Output is a structured report with both detailed "
48
  "observations and key clinical conclusions."
49
  )
50
- device: Optional[str] = "cpu" # Change the device to "cpu"
51
  args_schema: Type[BaseModel] = ChestXRayInput
52
  findings_model: VisionEncoderDecoderModel = None
53
  impression_model: VisionEncoderDecoderModel = None
@@ -57,12 +46,12 @@ class ChestXRayReportGeneratorTool(BaseTool):
57
  impression_processor: ViTImageProcessor = None
58
  generation_args: Dict[str, Any] = None
59
 
60
- def __init__(self, cache_dir: str = "/model-weights", device: Optional[str] = "cpu"):
61
- """Initialize the ChestXRayReportGeneratorTool with both findings and impression models."""
62
  super().__init__()
63
- self.device = torch.device(device) if device else torch.device("cpu") # Ensure CPU is used
 
64
 
65
- # Initialize findings model
66
  self.findings_model = VisionEncoderDecoderModel.from_pretrained(
67
  "IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir
68
  ).eval()
@@ -73,7 +62,7 @@ class ChestXRayReportGeneratorTool(BaseTool):
73
  "IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir
74
  )
75
 
76
- # Initialize impression model
77
  self.impression_model = VisionEncoderDecoderModel.from_pretrained(
78
  "IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
79
  ).eval()
@@ -84,11 +73,10 @@ class ChestXRayReportGeneratorTool(BaseTool):
84
  "IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
85
  )
86
 
87
- # Move models to device (CPU)
88
  self.findings_model = self.findings_model.to(self.device)
89
  self.impression_model = self.impression_model.to(self.device)
90
 
91
- # Default generation arguments
92
  self.generation_args = {
93
  "num_return_sequences": 1,
94
  "max_length": 128,
@@ -99,19 +87,8 @@ class ChestXRayReportGeneratorTool(BaseTool):
99
  def _process_image(
100
  self, image_path: str, processor: ViTImageProcessor, model: VisionEncoderDecoderModel
101
  ) -> torch.Tensor:
102
- """Process the input image for a specific model.
103
-
104
- Args:
105
- image_path (str): Path to the input image.
106
- processor: Image processor for the specific model.
107
- model: The model to process the image for.
108
-
109
- Returns:
110
- torch.Tensor: Processed image tensor ready for model input.
111
- """
112
  image = Image.open(image_path).convert("RGB")
113
  pixel_values = processor(image, return_tensors="pt").pixel_values
114
-
115
  expected_size = model.config.encoder.image_size
116
  actual_size = pixel_values.shape[-1]
117
 
@@ -123,23 +100,11 @@ class ChestXRayReportGeneratorTool(BaseTool):
123
  align_corners=False,
124
  )
125
 
126
- pixel_values = pixel_values.to(self.device)
127
-
128
- return pixel_values
129
 
130
  def _generate_report_section(
131
  self, pixel_values: torch.Tensor, model: VisionEncoderDecoderModel, tokenizer: BertTokenizer
132
  ) -> str:
133
- """Generate a report section using the specified model.
134
-
135
- Args:
136
- pixel_values: Processed image tensor.
137
- model: The model to use for generation.
138
- tokenizer: The tokenizer for the model.
139
-
140
- Returns:
141
- str: Generated text for the report section.
142
- """
143
  generation_config = GenerationConfig(
144
  **{
145
  **self.generation_args,
@@ -149,9 +114,7 @@ class ChestXRayReportGeneratorTool(BaseTool):
149
  "decoder_start_token_id": tokenizer.cls_token_id,
150
  }
151
  )
152
-
153
  generated_ids = model.generate(pixel_values, generation_config=generation_config)
154
-
155
  return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
156
 
157
  def _run(
@@ -159,17 +122,7 @@ class ChestXRayReportGeneratorTool(BaseTool):
159
  image_path: str,
160
  run_manager: Optional[CallbackManagerForToolRun] = None,
161
  ) -> Tuple[str, Dict]:
162
- """Generate a comprehensive chest X-ray report containing both findings and impression.
163
-
164
- Args:
165
- image_path (str): The path to the chest X-ray image file.
166
- run_manager (Optional[CallbackManagerForToolRun]): The callback manager.
167
-
168
- Returns:
169
- Tuple[str, Dict]: A tuple containing the complete report and metadata.
170
- """
171
  try:
172
- # Process image for both models
173
  findings_pixels = self._process_image(
174
  image_path, self.findings_processor, self.findings_model
175
  )
@@ -177,7 +130,6 @@ class ChestXRayReportGeneratorTool(BaseTool):
177
  image_path, self.impression_processor, self.impression_model
178
  )
179
 
180
- # Generate both sections
181
  with torch.inference_mode():
182
  findings_text = self._generate_report_section(
183
  findings_pixels, self.findings_model, self.findings_tokenizer
@@ -186,19 +138,16 @@ class ChestXRayReportGeneratorTool(BaseTool):
186
  impression_pixels, self.impression_model, self.impression_tokenizer
187
  )
188
 
189
- # Combine into formatted report
190
  report = (
191
  "CHEST X-RAY REPORT\n\n"
192
  f"FINDINGS:\n{findings_text}\n\n"
193
  f"IMPRESSION:\n{impression_text}"
194
  )
195
-
196
  metadata = {
197
  "image_path": image_path,
198
  "analysis_status": "completed",
199
  "sections_generated": ["findings", "impression"],
200
  }
201
-
202
  return report, metadata
203
 
204
  except Exception as e:
@@ -213,5 +162,4 @@ class ChestXRayReportGeneratorTool(BaseTool):
213
  image_path: str,
214
  run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
215
  ) -> Tuple[str, Dict]:
216
- """Asynchronously generate a comprehensive chest X-ray report."""
217
  return self._run(image_path)
 
2
  from pydantic import BaseModel, Field
3
 
4
  import torch
5
+ import os # Added to create local cache dir
6
 
7
  from langchain_core.callbacks import (
8
  AsyncCallbackManagerForToolRun,
 
29
 
30
 
31
  class ChestXRayReportGeneratorTool(BaseTool):
 
 
 
 
 
 
 
 
 
 
 
 
32
  name: str = "chest_xray_report_generator"
33
  description: str = (
34
  "A tool that analyzes chest X-ray images and generates comprehensive radiology reports "
 
36
  "to a chest X-ray image file. Output is a structured report with both detailed "
37
  "observations and key clinical conclusions."
38
  )
39
+ device: Optional[str] = "cpu"
40
  args_schema: Type[BaseModel] = ChestXRayInput
41
  findings_model: VisionEncoderDecoderModel = None
42
  impression_model: VisionEncoderDecoderModel = None
 
46
  impression_processor: ViTImageProcessor = None
47
  generation_args: Dict[str, Any] = None
48
 
49
+ def __init__(self, cache_dir: str = "./model_weights", device: Optional[str] = "cpu"):
 
50
  super().__init__()
51
+ os.makedirs(cache_dir, exist_ok=True) # Ensure local folder exists
52
+ self.device = torch.device(device) if device else torch.device("cpu")
53
 
54
+ # Load findings model
55
  self.findings_model = VisionEncoderDecoderModel.from_pretrained(
56
  "IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir
57
  ).eval()
 
62
  "IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir
63
  )
64
 
65
+ # Load impression model
66
  self.impression_model = VisionEncoderDecoderModel.from_pretrained(
67
  "IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
68
  ).eval()
 
73
  "IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
74
  )
75
 
76
+ # Move models to CPU
77
  self.findings_model = self.findings_model.to(self.device)
78
  self.impression_model = self.impression_model.to(self.device)
79
 
 
80
  self.generation_args = {
81
  "num_return_sequences": 1,
82
  "max_length": 128,
 
87
  def _process_image(
88
  self, image_path: str, processor: ViTImageProcessor, model: VisionEncoderDecoderModel
89
  ) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
90
  image = Image.open(image_path).convert("RGB")
91
  pixel_values = processor(image, return_tensors="pt").pixel_values
 
92
  expected_size = model.config.encoder.image_size
93
  actual_size = pixel_values.shape[-1]
94
 
 
100
  align_corners=False,
101
  )
102
 
103
+ return pixel_values.to(self.device)
 
 
104
 
105
  def _generate_report_section(
106
  self, pixel_values: torch.Tensor, model: VisionEncoderDecoderModel, tokenizer: BertTokenizer
107
  ) -> str:
 
 
 
 
 
 
 
 
 
 
108
  generation_config = GenerationConfig(
109
  **{
110
  **self.generation_args,
 
114
  "decoder_start_token_id": tokenizer.cls_token_id,
115
  }
116
  )
 
117
  generated_ids = model.generate(pixel_values, generation_config=generation_config)
 
118
  return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
119
 
120
  def _run(
 
122
  image_path: str,
123
  run_manager: Optional[CallbackManagerForToolRun] = None,
124
  ) -> Tuple[str, Dict]:
 
 
 
 
 
 
 
 
 
125
  try:
 
126
  findings_pixels = self._process_image(
127
  image_path, self.findings_processor, self.findings_model
128
  )
 
130
  image_path, self.impression_processor, self.impression_model
131
  )
132
 
 
133
  with torch.inference_mode():
134
  findings_text = self._generate_report_section(
135
  findings_pixels, self.findings_model, self.findings_tokenizer
 
138
  impression_pixels, self.impression_model, self.impression_tokenizer
139
  )
140
 
 
141
  report = (
142
  "CHEST X-RAY REPORT\n\n"
143
  f"FINDINGS:\n{findings_text}\n\n"
144
  f"IMPRESSION:\n{impression_text}"
145
  )
 
146
  metadata = {
147
  "image_path": image_path,
148
  "analysis_status": "completed",
149
  "sections_generated": ["findings", "impression"],
150
  }
 
151
  return report, metadata
152
 
153
  except Exception as e:
 
162
  image_path: str,
163
  run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
164
  ) -> Tuple[str, Dict]:
 
165
  return self._run(image_path)