Commit
·
3e275ef
1
Parent(s):
baa7e0f
Finalized model with better hyperparameters and also including Determined AI checkpoint.pt
Browse files- README.md +56 -22
- image.png → images/image.png +0 -0
- images/test_accuracy.png +0 -0
- images/test_auc.png +0 -0
- images/test_loss.png +0 -0
- model/Determined_AI_checkpoint.pt +3 -0
- model/Determined_AI_metadata.json +4 -0
- model/config.json +10 -0
- pytorch_model.bin → model/pytorch_model.bin +2 -2
README.md
CHANGED
@@ -8,7 +8,9 @@ tags:
|
|
8 |
datasets:
|
9 |
- medmnist
|
10 |
metrics:
|
|
|
11 |
- accuracy
|
|
|
12 |
---
|
13 |
|
14 |
# MedMNIST Active Learning Model
|
@@ -27,19 +29,65 @@ This model is designed for image classification tasks within the medical imaging
|
|
27 |
|
28 |
## Training Procedure
|
29 |
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
- Random resized cropping
|
33 |
- Horizontal flipping
|
34 |
- Random rotations
|
35 |
- Color jittering
|
36 |
- Gaussian blur
|
37 |
- RandAugment
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
## Usage
|
|
|
|
|
43 |
|
44 |
To utilize this model:
|
45 |
|
@@ -84,27 +132,13 @@ To utilize this model:
|
|
84 |
print(f"Predicted class: {prediction}")
|
85 |
```
|
86 |
|
87 |
-
## Evaluation
|
88 |
-
|
89 |
-
The model was evaluated on the validation set of PathMNIST. Key performance metrics include:
|
90 |
-
|
91 |
-
- **Accuracy:** 94%
|
92 |
-
- **Loss:** 0.1775
|
93 |
-
|
94 |
-
## Evaluation Metrics
|
95 |
-
|
96 |
-
The following plot illustrates the validation loss over training batches during the active learning process. The consistent decrease in validation loss demonstrates the effectiveness of the active learning strategy in improving model performance.
|
97 |
-
|
98 |
-

|
99 |
-
|
100 |
-
- **Validation Loss**: The graph shows a steady decline, indicating successful learning and convergence.
|
101 |
-
- **Batches**: Represents the number of iterations over the dataset.
|
102 |
-
|
103 |
## License
|
104 |
|
105 |
-
This project is licensed under the
|
106 |
|
107 |
## Acknowledgements
|
108 |
|
109 |
- [MedMNIST Dataset](https://medmnist.com/)
|
110 |
- [Determined AI](https://determined.ai/)
|
|
|
|
|
|
8 |
datasets:
|
9 |
- medmnist
|
10 |
metrics:
|
11 |
+
- loss
|
12 |
- accuracy
|
13 |
+
- area under the curve
|
14 |
---
|
15 |
|
16 |
# MedMNIST Active Learning Model
|
|
|
29 |
|
30 |
## Training Procedure
|
31 |
|
32 |
+
### Training Hyperparameters
|
33 |
+
|
34 |
+
| Hyperparameter | Value |
|
35 |
+
|------------------------|------------------------|
|
36 |
+
| Batch Size | 53 |
|
37 |
+
| Initial Labeled Size | 3559 |
|
38 |
+
| Learning Rate | 0.01332344940133225 |
|
39 |
+
| MC Dropout Passes | 6 |
|
40 |
+
| Samples to Label | 4430 |
|
41 |
+
| Weight Decay | 0.00021921795989143406 |
|
42 |
+
|
43 |
+
### Optimizer Settings
|
44 |
+
|
45 |
+
The optimizer used during training was Stochastic Gradient Descent(SDG), with the following settings and a Learning Rate Scheduler of ReduceLROnPlateau:
|
46 |
+
- `learning_rate = 0.01332344940133225`
|
47 |
+
- `momentum = 0.9`
|
48 |
+
- `weight_decay = 0.00021921795989143406`
|
49 |
+
|
50 |
+
The model was trained with float32 precision.
|
51 |
+
|
52 |
+
### Dataset
|
53 |
+
[PathMNIST](https://medmnist.com/)
|
54 |
+
|
55 |
+
### Data Augmentation
|
56 |
- Random resized cropping
|
57 |
- Horizontal flipping
|
58 |
- Random rotations
|
59 |
- Color jittering
|
60 |
- Gaussian blur
|
61 |
- RandAugment
|
62 |
+
|
63 |
+
### Active Learning Strategy
|
64 |
+
|
65 |
+
The active learning process was based on a mixed sampling strategy:
|
66 |
+
- **Uncertainty Sampling**: Monte Carlo (MC) dropout was used to estimate uncertainty.
|
67 |
+
- **Diversity Sampling**: K-means clustering was employed to ensure diverse samples.
|
68 |
+
|
69 |
+
## Evaluation
|
70 |
+
|
71 |
+
The model was evaluated on the validation set of PathMNIST. Key performance metrics include:
|
72 |
+
|
73 |
+
- **Accuracy:** 94.72%
|
74 |
+
- **Loss:** 0.2397
|
75 |
+
- **AUC:** 99.73%
|
76 |
+
|
77 |
+
## Graphs
|
78 |
+
|
79 |
+
The following plots illustrates the validation loss, validation accuracy, and validation auc over batches(number of iterations over the dataset) during the active learning process.
|
80 |
+
|
81 |
+
- **Validation Loss**
|
82 |
+

|
83 |
+
- **Validation Accuracy**
|
84 |
+

|
85 |
+
- **Validation AUC**
|
86 |
+

|
87 |
|
88 |
## Usage
|
89 |
+
All code for this model can be accessed in the following GitHub Repository:
|
90 |
+
[Allen Cheung Determined_AI_Hackathon](https://github.com/AllenCheung0213/Determined_AI_Hackathon)
|
91 |
|
92 |
To utilize this model:
|
93 |
|
|
|
132 |
print(f"Predicted class: {prediction}")
|
133 |
```
|
134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
## License
|
136 |
|
137 |
+
This project is licensed under the MIT License.
|
138 |
|
139 |
## Acknowledgements
|
140 |
|
141 |
- [MedMNIST Dataset](https://medmnist.com/)
|
142 |
- [Determined AI](https://determined.ai/)
|
143 |
+
- **Survey on Deep Active Learning**: Wang, H., Jin, Q., Li, S., Liu, S., Wang, M., & Song, Z. (2024). A comprehensive survey on deep active learning in medical image analysis. *Medical Image Analysis*, 95, 103201. [https://doi.org/10.1016/j.media.2024.103201](https://doi.org/10.1016/j.media.2024.103201)
|
144 |
+
|
image.png → images/image.png
RENAMED
File without changes
|
images/test_accuracy.png
ADDED
![]() |
images/test_auc.png
ADDED
![]() |
images/test_loss.png
ADDED
![]() |
model/Determined_AI_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:90601111b45e812a9aa039a02d9bb133468216b348f3ff067ac211a0787e2e97
|
3 |
+
size 188520735
|
model/Determined_AI_metadata.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"epochs_completed": 18,
|
3 |
+
"steps_completed": 54905
|
4 |
+
}
|
model/config.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"model_type": "resnet50",
|
3 |
+
"num_classes": 9,
|
4 |
+
"input_size": [
|
5 |
+
3,
|
6 |
+
28,
|
7 |
+
28
|
8 |
+
],
|
9 |
+
"architecture": "ResNet50"
|
10 |
+
}
|
pytorch_model.bin → model/pytorch_model.bin
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:db3f9ab941286c336727e5a0c4d4b35ff1b8db5b7f8519573600bd2ee0108ef7
|
3 |
+
size 94397514
|