Commit cc7a0e7
Parent(s): fc78132
Add validation loss plot
README.md
CHANGED
@@ -1,7 +1,110 @@
|
|
1 |
-
# MedMNIST Active Learning Model
|
2 |
-
|
3 |
-
Custom ResNet50 model trained using active learning strategies.
|
4 |
-
|
5 |
---
|
6 |
license: mit
|
7 |
---
|
1 |
---
|
2 |
+
language: en
|
3 |
license: mit
|
4 |
+
tags:
|
5 |
+
- image-classification
|
6 |
+
- active-learning
|
7 |
+
- medical-imaging
|
8 |
+
datasets:
|
9 |
+
- medmnist
|
10 |
+
metrics:
|
11 |
+
- accuracy
|
12 |
---
|
13 |
+
|
14 |
+
# MedMNIST Active Learning Model
|
15 |
+
|
16 |
+
## Overview
|
17 |
+
|
18 |
+
This model performs image classification in the medical imaging domain and is trained on PathMNIST, a subset of the MedMNIST collection. It uses a ResNet-50 architecture adapted to 28x28 pixel images and applies active learning strategies to improve performance when labeled data is limited.
|
19 |
+
|
20 |
+
## Model Architecture
|
21 |
+
|
22 |
+
- **Base Model:** ResNet-50
|
23 |
+
- **Modifications:**
|
24 |
+
- Adjusted initial convolution layer to accommodate 28x28 input images.
|
25 |
+
- Removed max pooling layer to preserve spatial dimensions.
|
26 |
+
- Customized fully connected layer to output predictions for 9 classes.
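
To see why the stem changes matter at this resolution, the arithmetic below traces feature-map widths through ResNet-50 for a 28x28 input. The standard stem (7x7 convolution with stride 2, then 3x3 max pooling with stride 2) is the usual ResNet-50 design. The adjusted stem shown here (3x3 convolution, stride 1, padding 1, no pooling) is an assumption, because the README does not state the exact kernel size or stride. `conv_out` and `trace` are hypothetical helper names used only for this sketch.

```python
def conv_out(size, kernel, stride, padding):
    """Output width of a convolution or pooling layer (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


# The four residual stages of ResNet-50 downsample with strides 1, 2, 2, 2.
STAGE_STRIDES = (1, 2, 2, 2)


def trace(size, stem):
    """Apply the stem layers, then return the feature-map width after each stage."""
    for kernel, stride, padding in stem:
        size = conv_out(size, kernel, stride, padding)
    sizes = []
    for stride in STAGE_STRIDES:
        size = conv_out(size, 3, stride, 1)  # strided 3x3 conv inside the stage
        sizes.append(size)
    return sizes


standard_stem = [(7, 2, 3), (3, 2, 1)]  # 7x7 conv s2, then 3x3 max pool s2
adjusted_stem = [(3, 1, 1)]             # assumed: 3x3 conv s1, max pool removed

print(trace(28, standard_stem))  # [7, 4, 2, 1]
print(trace(28, adjusted_stem))  # [28, 14, 7, 4]
```

With the standard stem the last stage collapses to a single spatial position, whereas the assumed adjusted stem still leaves a 4x4 map before global average pooling.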
|
27 |
+
|
28 |
+
## Training Procedure
|
29 |
+
|
30 |
+
- **Dataset:** [PathMNIST](https://medmnist.com/)
|
31 |
+
- **Data Augmentation:**
|
32 |
+
- Random resized cropping
|
33 |
+
- Horizontal flipping
|
34 |
+
- Random rotations
|
35 |
+
- Color jittering
|
36 |
+
- Gaussian blur
|
37 |
+
- RandAugment
|
38 |
+
- **Optimizer:** Stochastic Gradient Descent (SGD) with momentum
|
39 |
+
- **Learning Rate Scheduler:** ReduceLROnPlateau
|
40 |
+
- **Active Learning Strategy:** Mixed sampling that combines uncertainty sampling (Monte Carlo dropout) with diversity sampling (K-means clustering).
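
The README does not say how the uncertainty and diversity criteria are combined, so the following is only one plausible sketch, written in plain Python to stay self-contained. It assumes that each unlabeled image already has class probabilities from several stochastic forward passes (dropout left active) and a feature vector. It scores images by the entropy of the averaged prediction, shortlists the most uncertain ones, clusters the shortlist with K-means, and keeps the most uncertain image per cluster. All function names and the `pool_factor` parameter are hypothetical.

```python
import math
import random


def predictive_entropy(mc_probs):
    """Entropy of the mean class distribution over T stochastic passes."""
    n_classes = len(mc_probs[0])
    mean = [sum(p[c] for p in mc_probs) / len(mc_probs) for c in range(n_classes)]
    return -sum(q * math.log(q) for q in mean if q > 0)


def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans(points, k, iters=20, seed=0):
    """Plain Lloyd's algorithm; returns a cluster index for every point."""
    rng = random.Random(seed)
    centers = [list(p) for p in rng.sample(points, k)]
    assign = [0] * len(points)
    for _ in range(iters):
        # Assignment step: nearest center by squared Euclidean distance.
        assign = [min(range(k), key=lambda j: sq_dist(p, centers[j])) for p in points]
        # Update step: move each center to the mean of its members.
        for j in range(k):
            members = [p for p, a in zip(points, assign) if a == j]
            if members:
                centers[j] = [sum(col) / len(members) for col in zip(*members)]
    return assign


def mixed_sample(mc_probs_per_item, features, budget, pool_factor=3):
    """Shortlist the most uncertain items, then keep one per K-means cluster.

    May return fewer than `budget` items if a cluster ends up empty.
    """
    scores = [predictive_entropy(p) for p in mc_probs_per_item]
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    shortlist = ranked[: budget * pool_factor]
    k = min(budget, len(shortlist))
    clusters = kmeans([features[i] for i in shortlist], k)
    chosen = {}
    for idx, cluster in zip(shortlist, clusters):
        # The shortlist is sorted by uncertainty, so the first hit per cluster wins.
        chosen.setdefault(cluster, idx)
    return sorted(chosen.values())


# Toy pool: items 0-3 are uncertain, items 4-5 are confident.
toy_probs = [
    [[0.50, 0.50]] * 2,
    [[0.60, 0.40]] * 2,
    [[0.55, 0.45]] * 2,
    [[0.70, 0.30]] * 2,
    [[0.99, 0.01]] * 2,
    [[0.98, 0.02]] * 2,
]
toy_features = [(0, 0), (1, 0), (10, 10), (11, 10), (5, 5), (6, 5)]
picked = mixed_sample(toy_probs, toy_features, budget=2, pool_factor=2)
print(picked)
```

In a real pipeline the probabilities would come from repeated forward passes of the network and the features from an intermediate layer; a library implementation such as scikit-learn's `KMeans` could replace the hand-written clustering.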
|
41 |
+
|
42 |
+
## Usage
|
43 |
+
|
44 |
+
To use this model:
|
45 |
+
|
46 |
+
1. **Install Dependencies:**
|
47 |
+
Ensure the following Python packages are installed:
|
48 |
+
- `torch`
|
49 |
+
- `torchvision`
|
50 |
+
- `medmnist`
|
51 |
+
- `scikit-learn`
|
52 |
+
- `determined`
|
53 |
+
|
54 |
+
Install them using pip:
|
55 |
+
```bash
|
56 |
+
pip install torch torchvision medmnist scikit-learn determined
|
57 |
+
```
|
58 |
+
|
59 |
+
2. **Load the Model:**
|
60 |
+
```python
|
61 |
+
import torch
|
62 |
+
from model import ResNet50_28
|
63 |
+
|
64 |
+
model = ResNet50_28(num_classes=9)
|
65 |
+
model.load_state_dict(torch.load('path_to_checkpoint.pt', map_location='cpu')['model_state_dict'])  # map_location lets a GPU-saved checkpoint load on CPU
|
66 |
+
model.eval()
|
67 |
+
```
|
68 |
+
|
69 |
+
3. **Inference:**
|
70 |
+
```python
|
71 |
+
from torchvision import transforms
|
72 |
+
from PIL import Image
|
73 |
+
|
74 |
+
transform = transforms.Compose([
|
75 |
+
transforms.Resize((28, 28)),
|
76 |
+
transforms.ToTensor(),
|
77 |
+
transforms.Normalize(mean=[0.5], std=[0.5])
|
78 |
+
])
|
79 |
+
|
80 |
+
image = Image.open('path_to_image.jpg').convert('RGB')  # PathMNIST images are 3-channel RGB
|
81 |
+
input_tensor = transform(image).unsqueeze(0)
|
82 |
+
with torch.no_grad():  # gradients are not needed for inference
    output = model(input_tensor)
|
83 |
+
prediction = output.argmax(dim=1).item()
|
84 |
+
print(f"Predicted class: {prediction}")
|
85 |
+
```
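
The inference snippet above prints only a class index. As a small follow-up sketch, the code below converts raw logits into probabilities with a softmax and maps the winning index to a tissue name. The label order is an assumption and should be checked against `medmnist.INFO["pathmnist"]["label"]`; `softmax`, `describe`, and `example_logits` are hypothetical names, and the logits are made-up stand-ins for `output[0].tolist()`.

```python
import math

# Assumed PathMNIST label order; verify against medmnist.INFO["pathmnist"]["label"].
PATHMNIST_LABELS = [
    "adipose",
    "background",
    "debris",
    "lymphocytes",
    "mucus",
    "smooth muscle",
    "normal colon mucosa",
    "cancer-associated stroma",
    "colorectal adenocarcinoma epithelium",
]


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def describe(logits):
    """Return (class index, label name, probability) for the top prediction."""
    probs = softmax(logits)
    idx = max(range(len(probs)), key=probs.__getitem__)
    return idx, PATHMNIST_LABELS[idx], probs[idx]


# Made-up logits standing in for output[0].tolist() from the model.
example_logits = [0.1, -1.2, 0.3, 2.5, 0.0, -0.4, 0.9, 1.1, -2.0]
idx, name, confidence = describe(example_logits)
print(f"Predicted class {idx} ({name}), probability {confidence:.2f}")
```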
|
86 |
+
|
87 |
+
## Evaluation
|
88 |
+
|
89 |
+
The model was evaluated on the validation set of PathMNIST. Key performance metrics include:
|
90 |
+
|
91 |
+
- **Accuracy:** 94%
|
92 |
+
- **Loss:** 0.1775
|
93 |
+
|
94 |
+
## Evaluation Metrics
|
95 |
+
|
96 |
+
The following plot shows the validation loss over training batches during the active learning process. The steady decrease in validation loss is consistent with the active learning strategy improving model performance.
|
97 |
+
|
98 |
+

|
99 |
+
|
100 |
+
- **Validation Loss**: The curve declines steadily, indicating that the model keeps improving on held-out data as training proceeds.
|
101 |
+
- **Batches**: The x-axis counts the training batches processed so far.
|
102 |
+
|
103 |
+
## License
|
104 |
+
|
105 |
+
This project is licensed under the MIT License.
|
106 |
+
|
107 |
+
## Acknowledgements
|
108 |
+
|
109 |
+
- [MedMNIST Dataset](https://medmnist.com/)
|
110 |
+
- [Determined AI](https://determined.ai/)
|
image.png
ADDED