ianpan committed
Commit 7bfd23a · verified · 1 Parent(s): 2e324e1

Upload MammoEnsemble

Files changed (5)
  1. README.md +199 -0
  2. config.json +37 -0
  3. configuration.py +32 -0
  4. model.safetensors +3 -0
  5. modeling.py +195 -0
README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
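The "How to Get Started with the Model" section above is still a placeholder. Below is a minimal loading sketch, assuming the `auto_map` in the `config.json` added further down and the `MammoEnsemble.forward` signature in `modeling.py`; `<repo-id>` is a placeholder for this model's actual Hub repository id.

```python
# Hedged sketch, not part of the uploaded files. Assumes trust_remote_code so the
# auto_map in config.json resolves AutoModel to MammoEnsemble; <repo-id> is a placeholder.
import numpy as np
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("<repo-id>", trust_remote_code=True).eval()

# MammoEnsemble.forward takes a dict (or list of dicts) of 2D uint8 arrays in [0, 255];
# the keys (e.g. "cc", "mlo") label the views of one exam but are otherwise arbitrary.
exam = {
    "cc": np.random.randint(0, 256, (2048, 1024), dtype=np.uint8),
    "mlo": np.random.randint(0, 256, (2048, 1024), dtype=np.uint8),
}
with torch.no_grad():
    out = model(exam, device="cpu")

print(out["cancer"])   # per-exam cancer probability (averaged over views and sub-models)
print(out["density"])  # per-exam distribution over the remaining 4 classes
```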
config.json ADDED
@@ -0,0 +1,37 @@
+ {
+   "architectures": [
+     "MammoEnsemble"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration.MammoConfig",
+     "AutoModel": "modeling.MammoEnsemble"
+   },
+   "backbone": "tf_efficientnetv2_s",
+   "dropout": 0.1,
+   "feature_dim": 1280,
+   "image_sizes": [
+     [
+       2048,
+       1024
+     ],
+     [
+       1920,
+       1280
+     ],
+     [
+       1536,
+       1536
+     ]
+   ],
+   "in_chans": 1,
+   "model_type": "mammo",
+   "num_classes": 5,
+   "num_models": 3,
+   "pad_to_aspect_ratio": [
+     true,
+     true,
+     false
+   ],
+   "torch_dtype": "float32",
+   "transformers_version": "4.47.0"
+ }
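The `auto_map` block above is what lets the generic `AutoConfig`/`AutoModel` entry points resolve to the custom classes in this repo when `trust_remote_code=True` is passed. A short sketch, with `<repo-id>` again a placeholder, showing how the per-member settings line up:

```python
# Sketch (not part of the commit): inspect the ensemble layout via AutoConfig.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("<repo-id>", trust_remote_code=True)
# image_sizes and pad_to_aspect_ratio are indexed in parallel, one entry per sub-model.
for idx, (size, pad) in enumerate(zip(cfg.image_sizes, cfg.pad_to_aspect_ratio)):
    print(f"net{idx}: input {size}, pad_to_aspect_ratio={pad}")
# net0: input [2048, 1024], pad_to_aspect_ratio=True
# net1: input [1920, 1280], pad_to_aspect_ratio=True
# net2: input [1536, 1536], pad_to_aspect_ratio=False
```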
configuration.py ADDED
@@ -0,0 +1,32 @@
+ from transformers import PretrainedConfig
+ from typing import List, Tuple
+
+
+ class MammoConfig(PretrainedConfig):
+     model_type = "mammo"
+
+     def __init__(
+         self,
+         backbone: str = "tf_efficientnetv2_s",
+         feature_dim: int = 1280,
+         dropout: float = 0.1,
+         num_classes: int = 5,
+         in_chans: int = 1,
+         num_models: int = 3,
+         image_sizes: List[Tuple[int, int]] = [(2048, 1024), (1920, 1280), (1536, 1536)],
+         pad_to_aspect_ratio: List[bool] = [True, True, False],
+         **kwargs,
+     ):
+         self.backbone = backbone
+         self.feature_dim = feature_dim
+         self.dropout = dropout
+         self.num_classes = num_classes
+         self.in_chans = in_chans
+         self.num_models = num_models
+         assert len(image_sizes) == len(pad_to_aspect_ratio) == num_models, (
+             f"length of `image_sizes` [{len(image_sizes)}] and `pad_to_aspect_ratio` "
+             f"[{len(pad_to_aspect_ratio)}] must be equal to `num_models` [{num_models}]."
+         )
+         self.image_sizes = image_sizes
+         self.pad_to_aspect_ratio = pad_to_aspect_ratio
+         super().__init__(**kwargs)
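A small sketch of constructing `MammoConfig` directly (e.g. outside the Hub download path), assuming `configuration.py` above is importable from the working directory; the assertion in `__init__` rejects mismatched list lengths:

```python
# Sketch, not part of the commit. Assumes configuration.py is on the import path.
from configuration import MammoConfig

config = MammoConfig(
    backbone="tf_efficientnetv2_s",
    num_models=3,
    image_sizes=[(2048, 1024), (1920, 1280), (1536, 1536)],
    pad_to_aspect_ratio=[True, True, False],
)
assert len(config.image_sizes) == len(config.pad_to_aspect_ratio) == config.num_models

# A mismatch trips the assertion, e.g.:
#   MammoConfig(num_models=2, image_sizes=[(2048, 1024)], pad_to_aspect_ratio=[True])
# raises AssertionError with the message defined above.
```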
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8ace0998e2d534b08eaa673fad37df16eaffef6dca3f3b5ac2196857a2949596
+ size 244305924
modeling.py ADDED
@@ -0,0 +1,195 @@
+ import albumentations as A
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ from transformers import PreTrainedModel
+ from timm import create_model
+ from typing import Mapping, Sequence, Tuple
+ from .configuration import MammoConfig
+
+
+ def _pad_to_aspect_ratio(img: np.ndarray, aspect_ratio: float) -> np.ndarray:
+     """
+     Pads to specified aspect ratio, only if current aspect ratio is
+     greater.
+     """
+     h, w = img.shape[:2]
+     if h / w > aspect_ratio:
+         new_w = round(h / aspect_ratio)
+         w_diff = new_w - w
+         left_pad = w_diff // 2
+         right_pad = w_diff - left_pad
+         padding = ((0, 0), (left_pad, right_pad))
+         if img.ndim == 3:
+             padding = padding + ((0, 0),)
+         img = np.pad(img, padding, mode="constant", constant_values=0)
+     return img
+
+
+ def _to_torch_tensor(x: np.ndarray, device: str) -> torch.Tensor:
+     if x.ndim == 2:
+         x = torch.from_numpy(x).unsqueeze(0)
+     elif x.ndim == 3:
+         x = torch.from_numpy(x)
+         if torch.tensor(x.size()).argmin().item() == 2:
+             # channels last -> first
+             x = x.permute(2, 0, 1)
+     else:
+         raise ValueError(f"Expected 2 or 3 dimensions, got {x.ndim}")
+     return x.float().to(device)
+
+
+ class MammoModel(nn.Module):
+     def __init__(
+         self,
+         backbone: str,
+         image_size: Tuple[int, int],
+         pad_to_aspect_ratio: bool,
+         feature_dim: int = 1280,
+         dropout: float = 0.1,
+         num_classes: int = 5,
+         in_chans: int = 1,
+     ):
+         super().__init__()
+         self.backbone = create_model(
+             model_name=backbone,
+             pretrained=False,
+             num_classes=0,
+             global_pool="",
+             features_only=False,
+             in_chans=in_chans,
+         )
+         self.pooling = nn.AdaptiveAvgPool2d(1)
+         self.dropout = nn.Dropout(p=dropout)
+         self.linear = nn.Linear(feature_dim, num_classes)
+
+         self.pad_to_aspect_ratio = pad_to_aspect_ratio
+         self.aspect_ratio = image_size[0] / image_size[1]
+         if self.pad_to_aspect_ratio:
+             self.resize = A.Resize(image_size[0], image_size[1], p=1)
+         else:
+             self.resize = A.Compose(
+                 [
+                     A.LongestMaxSize(image_size[0], p=1),
+                     A.PadIfNeeded(image_size[0], image_size[1], p=1),
+                 ],
+                 p=1,
+             )
+
+     def normalize(self, x: torch.Tensor) -> torch.Tensor:
+         # [0, 255] -> [-1, 1]
+         mini, maxi = 0.0, 255.0
+         x = (x - mini) / (maxi - mini)
+         x = (x - 0.5) * 2.0
+         return x
+
+     def preprocess(
+         self,
+         x: Mapping[str, np.ndarray] | Sequence[Mapping[str, np.ndarray]],
+         device: str,
+     ) -> Sequence[Mapping[str, torch.Tensor]]:
+         # x is a dict (or list of dicts) with keys "cc" and/or "mlo"
+         # though the actual keys do not matter
+         if not isinstance(x, Sequence):
+             assert isinstance(x, Mapping)
+             x = [x]
+         if self.pad_to_aspect_ratio:
+             x = [
+                 {
+                     k: _pad_to_aspect_ratio(v.copy(), self.aspect_ratio)
+                     for k, v in sample.items()
+                 }
+                 for sample in x
+             ]
+         x = [
+             {
+                 k: _to_torch_tensor(self.resize(image=v)["image"], device=device)
+                 for k, v in sample.items()
+             }
+             for sample in x
+         ]
+         return x
+
+     def forward(
+         self, x: Sequence[Mapping[str, torch.Tensor]]
+     ) -> Mapping[str, torch.Tensor]:
+         batch_tensor = []
+         batch_indices = []
+         for idx, sample in enumerate(x):
+             for k, v in sample.items():
+                 batch_tensor.append(v)
+                 batch_indices.append(idx)
+
+         batch_tensor = torch.stack(batch_tensor, dim=0)
+         batch_tensor = self.normalize(batch_tensor)
+         features = self.pooling(self.backbone(batch_tensor))
+         b, d = features.shape[:2]
+         features = features.reshape(b, d)
+         logits = self.linear(features)
+         # cancer
+         logits0 = logits[:, 0].sigmoid()
+         # density
+         logits1 = logits[:, 1:].softmax(dim=1)
+         # mean over views
+         batch_indices = torch.tensor(batch_indices)
+         logits0 = torch.stack(
+             [logits0[batch_indices == i].mean(dim=0) for i in batch_indices.unique()]
+         )
+         logits1 = torch.stack(
+             [logits1[batch_indices == i].mean(dim=0) for i in batch_indices.unique()]
+         )
+         return {"cancer": logits0, "density": logits1}
+
+
+ class MammoEnsemble(PreTrainedModel):
+     config_class = MammoConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_models = config.num_models
+         for i in range(self.num_models):
+             setattr(
+                 self,
+                 f"net{i}",
+                 MammoModel(
+                     config.backbone,
+                     config.image_sizes[i],
+                     config.pad_to_aspect_ratio[i],
+                     config.feature_dim,
+                     config.dropout,
+                     config.num_classes,
+                     config.in_chans,
+                 ),
+             )
+
+     @staticmethod
+     def load_image_from_dicom(path: str) -> np.ndarray | None:
+         try:
+             from pydicom import dcmread
+             from pydicom.pixels import apply_voi_lut
+         except ModuleNotFoundError:
+             print("`pydicom` is not installed, returning None ...")
+             return None
+         dicom = dcmread(path)
+         arr = apply_voi_lut(dicom.pixel_array, dicom)
+         if dicom.PhotometricInterpretation == "MONOCHROME1":
+             arr = arr.max() - arr
+
+         arr = arr - arr.min()
+         arr = arr / arr.max()
+         arr = (arr * 255).astype("uint8")
+         return arr
+
+     def forward(
+         self,
+         x: Mapping[str, np.ndarray] | Sequence[Mapping[str, np.ndarray]],
+         device: str = "cpu",
+     ) -> Mapping[str, torch.Tensor]:
+         out = []
+         for i in range(self.num_models):
+             model = getattr(self, f"net{i}")
+             x_pp = model.preprocess(x, device=device)
+             out.append(model(x_pp))
+         out = {k: torch.stack([o[k] for o in out]).mean(0) for k in out[0].keys()}
+         return out
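Putting the pieces of `modeling.py` together: `load_image_from_dicom` applies the VOI LUT, inverts MONOCHROME1 images, and rescales to uint8; each `MammoModel` pads/resizes per its configured input size, and `MammoEnsemble.forward` averages predictions over views and over the three sub-models. A hedged end-to-end sketch follows; the DICOM paths and `<repo-id>` are placeholders, and `pydicom` must be installed.

```python
# End-to-end inference sketch, not part of the commit. <repo-id> and the .dcm paths
# are placeholders; requires pydicom for load_image_from_dicom.
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("<repo-id>", trust_remote_code=True).eval()

# One exam: the keys just label views; values are 2D uint8 arrays.
exam = {
    "cc": model.load_image_from_dicom("cc_view.dcm"),
    "mlo": model.load_image_from_dicom("mlo_view.dcm"),
}

with torch.no_grad():
    out = model(exam, device="cpu")  # padding/resizing happens inside forward

print(float(out["cancer"][0]))     # ensemble- and view-averaged cancer probability
print(out["density"][0].tolist())  # distribution over the remaining 4 classes ("density")
```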