Yuan Gao committed on
Commit 0f1419c · 1 Parent(s): 3b89e43

preprocessing code, more on GitHub

Files changed (3)
  1. .gitattributes +1 -1
  2. mixinhelpers.py +221 -0
  3. preprocessor.py +69 -0
.gitattributes CHANGED
@@ -33,4 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
- *.png filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
mixinhelpers.py ADDED
@@ -0,0 +1,221 @@
+ # Preprocessing mixins for ECHO, CXR, ECG, and text
+ import random
+
+ import cv2
+ import numpy as np
+ import torch
+ from PIL import Image
+ from torchvision import transforms
+ from transformers import BatchEncoding, PreTrainedTokenizer
+
+ """
+ Mixins for all modalities; each mixin has:
+ - a preprocess function that takes a path or raw data and returns a tensor
+ - a construct_input function that takes a tensor and returns a dict with a batch
+   dimension for model input
+ - a key string for the model input dict
+ """
+
+
+ class ECHO_Mixin:
+     LOWER_YELLOW: list[int] = [20, 50, 50]
+     UPPER_YELLOW: list[int] = [100, 255, 255]
+     IMAGE_SIZE: tuple[int, int] = (224, 224)
+     NORM_MEAN: tuple[float, float, float] = (0.48145466, 0.4578275, 0.40821073)
+     NORM_STD: tuple[float, float, float] = (0.26862954, 0.26130258, 0.27577711)
+
+     ECHO_TRANSFORMS = transforms.Compose(
+         [
+             transforms.ToTensor(),  # scales into [0, 1]
+             transforms.Resize(IMAGE_SIZE),
+             transforms.Normalize(
+                 mean=NORM_MEAN,
+                 std=NORM_STD,
+             ),
+         ]
+     )
+     ECHO_KEY: str = "echo"
+
+     def grabimage(self, split: str, data: dict[str, np.ndarray]) -> np.ndarray:
+         """Pick a frame: a random case and frame index for training, frame 0 otherwise."""
+         if split == "train":
+             caseofinterest = random.choice(list(data.keys()))
+             imageindice = random.choice(list(range(data[caseofinterest].shape[0])))
+         else:
+             caseofinterest = random.choice(list(data.keys()))
+             imageindice = 0
+         video = data[caseofinterest]
+         return self.extract_echoframe(imageindice, video)
+
+     def extract_echoframe(self, imageindice: int, video: np.ndarray) -> np.ndarray:
+         image = video[imageindice]
+         hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
+         lower_yellow = np.array(self.LOWER_YELLOW)  # lower bound of yellow hue
+         upper_yellow = np.array(self.UPPER_YELLOW)  # upper bound of yellow hue
+         mask = cv2.inRange(hsv_image, lower_yellow, upper_yellow)
+         image[mask > 0] = [0, 0, 0]  # black out yellow annotations
+         image = np.array(image, dtype=np.float32)
+         # min-max rescale to [0, 255]; assumes a non-constant frame (max > 0 after the shift)
+         image -= image.min()
+         image /= image.max()
+         image *= 255
+         image = image.astype(np.uint8)
+         return image
+
+     def preprocess_echoseries(
+         self, video_dict: dict[str, np.ndarray], split: str = "valid"
+     ) -> torch.Tensor:
+         """assumes inference mode"""
+         image = self.grabimage(split, video_dict)
+         if not isinstance(image, np.ndarray):
+             raise TypeError("Expected image to be a numpy ndarray")
+         pil_image = Image.fromarray(image)
+         transformed = self.ECHO_TRANSFORMS(pil_image)
+         if not isinstance(transformed, torch.Tensor):
+             transformed = transforms.ToTensor()(pil_image)
+         return transformed
+
+     def preprocess_single_echo(self, avi_path: str) -> torch.Tensor:
+         """assumes inference mode; opens an AVI file and processes the first frame
+         Output: image: torch.Tensor of shape (C, H, W)
+         """
+         cap = cv2.VideoCapture(avi_path)
+         success, frame = cap.read()
+         cap.release()
+         if not success or frame is None:
+             raise ValueError(f"Could not read frame from AVI file: {avi_path}")
+         image = self.extract_echoframe(0, np.array([frame]))  # process first frame
+         image = self.ECHO_TRANSFORMS(Image.fromarray(image))
+         if not isinstance(image, torch.Tensor):
+             image = transforms.ToTensor()(Image.fromarray(np.asarray(image)))
+         return image
+
+
+ # CXR
+ class CXR_Mixin:
+     RESIZE: tuple[int, int] = (256, 256)
+     IMAGE_SIZE: tuple[int, int] = (224, 224)
+     NORM_MEAN: list[float] = [0.5862785803043838]
+     NORM_STD: list[float] = [0.27950088968644304]
+     VISION_KEY: str = "vision"
+     CXR_TRANSFORMS = transforms.Compose(
+         [
+             transforms.ToTensor(),  # scales into [0, 1]
+             transforms.Resize(RESIZE),
+             transforms.CenterCrop(IMAGE_SIZE),
+             transforms.Normalize(
+                 mean=NORM_MEAN,
+                 std=NORM_STD,
+             ),
+         ]
+     )
+
+     @staticmethod
+     def remove_border(pixel_array: np.ndarray) -> np.ndarray:
+         # Find where the image is not just background (0s)
+         coords = np.column_stack(np.where(pixel_array > 0))
+         x_min, y_min = coords.min(axis=0)
+         x_max, y_max = coords.max(axis=0)
+         # Crop to the bounding box of the non-background pixels (inclusive)
+         cropped_image = pixel_array[x_min : x_max + 1, y_min : y_max + 1]
+         return cropped_image
+
+     def preprocess_loaded_cxr(self, img: np.ndarray) -> torch.Tensor:
+         cxr = self.remove_border(img)
+         # Convert grayscale image to 3-channel RGB
+         cxr = np.repeat(cxr[..., np.newaxis], 3, axis=-1)
+         cxr = Image.fromarray(cxr)
+         transformed = self.CXR_TRANSFORMS(cxr)
+         if not isinstance(transformed, torch.Tensor):
+             transformed = transforms.ToTensor()(cxr)
+         return transformed
+
+     def preprocess_single_cxr(self, image_path: str) -> torch.Tensor:
+         """assumes inference mode"""
+         with open(image_path, "rb") as fopen:
+             image = Image.open(fopen).convert("RGB")
+             image = np.array(image)[:, :, 0]  # take one channel as grayscale
+         cxr = self.preprocess_loaded_cxr(image)
+         return cxr
+
+
+ class ECG_Mixin:
+     LENGTH: int = 1000
+     FREQUENCY: int = 100  # we assume a 100 Hz sampling rate
+     CHANNELS: int = 12
+     NORM_MEAN: float = 0.02547506
+     NORM_SCALE: float = 0.16486814
+     NORM_VAR: float = 0.0271815
+     ECG_KEY: str = "ecg"
+
+     def manual_standardize(self, x: np.ndarray) -> torch.Tensor:
+         """
+         Apply manual standardization to ECG or other data.
+         Equivalent to sklearn's StandardScaler with the given constants.
+
+         Args:
+             x (np.ndarray): Input array of shape (12, 1000)
+         Returns:
+             torch.Tensor: Scaled array of the same shape
+         """
+         return torch.from_numpy((x - self.NORM_MEAN) / self.NORM_SCALE).float()
+
+     def check_ecg(self, ecg: np.ndarray) -> np.ndarray:
+         # Reject recordings with missing or non-finite samples
+         if np.isnan(ecg).any() or np.isinf(ecg).any():
+             raise ValueError("ECG contains NaN or Inf values")
+         return ecg[:, : self.LENGTH]  # truncate to the first 1000 samples (10 seconds at 100 Hz)
+
+     def preprocess_single_ecg(self, ecg_path: str) -> torch.Tensor:
+         """assumes inference mode"""
+         # ecg_path points to a saved numpy array; assumes 12 channels
+         ecg = np.load(ecg_path)
+         if ecg.ndim == 2 and ecg.shape[0] != self.CHANNELS:
+             raise ValueError(f"Expected ECG with {self.CHANNELS} channels, got {ecg.shape[0]}")
+         ecg = self.check_ecg(ecg)
+         transformed = self.manual_standardize(ecg)
+         return transformed
+
+
+ class Text_Mixin:
+     MODALITY_LIST: dict[str, str] = {"echo": "echocardiogram", "ecg": "ecg", "vision": "cxr"}
+     MAX_LENGTH: int = 120  # longer length to accommodate longer reports
+     TEXT_LENGTH: int = 100  # 100 words
+
+     def get_first_n_words(self, text: str, n: int = 100) -> str:
+         """the 97.5th percentile of report length is under 35 words"""
+         words = text.split()  # split the text into words
+         return " ".join(words[:n])  # join the first n words back into a string
+
+     def createCaption(self, caption: str, modality: str = "") -> str:
+         assert modality in set(self.MODALITY_LIST.keys()) or modality == "", (
+             f"modality should be in {self.MODALITY_LIST} or empty"
+         )
+         return f"text : {caption}, {modality} looks like : "
+
+     def createTokenizedCaption(self, caption: str, tokenizer: PreTrainedTokenizer) -> BatchEncoding:
+         encoding = tokenizer(
+             caption,
+             padding="max_length",
+             truncation=True,
+             max_length=self.MAX_LENGTH,
+             return_tensors="pt",
+         )
+         return encoding
+
+     def construct_caption(
+         self, caption: str, tokenizer: PreTrainedTokenizer, modality: str = ""
+     ) -> BatchEncoding:
+         """given a caption string, return the tokenized caption dict for model input
+         Output: dict with keys 'input_ids' and 'attention_mask', each of shape (1, L)
+         """
+         caption_str = self.createCaption(caption, modality)
+         tokenized = self.createTokenizedCaption(caption_str, tokenizer)
+         return tokenized
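
Since these mixins are stateless and meant to be composed (see preprocessor.py below), a quick way to sanity-check one in isolation is to mix it into a bare class. A minimal sketch, not part of this commit; the file name and random array are hypothetical stand-ins for a real 12-lead ECG:

import numpy as np
from mixinhelpers import ECG_Mixin

class _ECGOnly(ECG_Mixin):
    pass

# hypothetical example file: 12 channels, 1000 samples (10 s at 100 Hz)
np.save("example_ecg.npy", np.random.randn(12, 1000).astype(np.float32))
ecg = _ECGOnly().preprocess_single_ecg("example_ecg.npy")
print(ecg.shape)  # torch.Size([12, 1000]), standardized with NORM_MEAN / NORM_SCALE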
preprocessor.py ADDED
@@ -0,0 +1,69 @@
+ import torch
+ from transformers import AutoTokenizer, BatchEncoding
+
+ from mixinhelpers import CXR_Mixin, ECG_Mixin, ECHO_Mixin, Text_Mixin
+
+ """
+ Preprocessor classes for different modalities and their combinations.
+ You can combine different mixins to create preprocessors for multi-modal inputs.
+ Examples below are provided for ECHO+Text, ECG+Text, and CXR+Text.
+ """
+
+
+ class BasePreprocessor:
+     def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+
+ # dual-modality preprocessors
+ class ECHOText_Preprocessor(BasePreprocessor, ECHO_Mixin, Text_Mixin):
+     def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
+         super().__init__(model_name=model_name)
+
+     def preprocess_echo_text(self, echo_path: str, text: str) -> tuple[torch.Tensor, BatchEncoding]:
+         """Can be used in a dataloader to collate batches correctly; use the string
+         keys to identify the modalities.
+         echo_path: path to an echo AVI file
+         text: text report string
+         returns: (echo tensor, tokenized text dict)"""
+         echo = self.preprocess_single_echo(echo_path)  # (C, H, W)
+         text_inputs = self.construct_caption(
+             caption=text, tokenizer=self.tokenizer, modality=self.ECHO_KEY
+         )
+         return echo, text_inputs
+
+
+ class ECGText_Preprocessor(BasePreprocessor, ECG_Mixin, Text_Mixin):
+     def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
+         super().__init__(model_name=model_name)
+
+     def preprocess_ecg_text(self, ecg_path: str, text: str) -> tuple[torch.Tensor, BatchEncoding]:
+         """Can be used in a dataloader to collate batches correctly; use the string
+         keys to identify the modalities.
+         ecg_path: path to an ecg npy file
+         text: text report string
+         returns: (ecg tensor, tokenized text dict)"""
+         ecg = self.preprocess_single_ecg(ecg_path)  # (C, L)
+         text_inputs = self.construct_caption(
+             caption=text, tokenizer=self.tokenizer, modality=self.ECG_KEY
+         )
+         return ecg, text_inputs
+
+
+ class CXRText_Preprocessor(BasePreprocessor, CXR_Mixin, Text_Mixin):
+     def __init__(self, model_name: str = "dmis-lab/biobert-v1.1") -> None:
+         super().__init__(model_name=model_name)
+
+     def preprocess_cxr_text(self, cxr_path: str, text: str) -> tuple[torch.Tensor, BatchEncoding]:
+         """Can be used in a dataloader to collate batches correctly; use the string
+         keys to identify the modalities.
+         cxr_path: path to a cxr image file
+         text: text report string
+         returns: (cxr tensor, tokenized text dict)"""
+         cxr = self.preprocess_single_cxr(cxr_path)  # (C, H, W)
+         text_inputs = self.construct_caption(
+             caption=text, tokenizer=self.tokenizer, modality=self.VISION_KEY
+         )
+         return cxr, text_inputs
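
A minimal end-to-end sketch of how one of these preprocessors might be used (not part of this commit; the image path is hypothetical, and the downstream model's expected batch layout is not shown here, so the dict below only illustrates how the key constants are meant to be used):

from preprocessor import CXRText_Preprocessor

pre = CXRText_Preprocessor()  # downloads the BioBERT tokenizer on first use
cxr, text_inputs = pre.preprocess_cxr_text(
    "example_cxr.png",  # hypothetical path
    "No acute cardiopulmonary process.",
)
batch = {
    pre.VISION_KEY: cxr.unsqueeze(0),        # (1, 3, 224, 224)
    "input_ids": text_inputs["input_ids"],   # (1, 120)
    "attention_mask": text_inputs["attention_mask"],
}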