---
language: en
license: mit
tags:
- image-classification
- active-learning
- medical-imaging
datasets:
- medmnist
metrics:
- loss
- accuracy
- area under the curve
---

# MedMNIST Active Learning Model

## Overview

This model is designed for image classification in the medical imaging domain, specifically targeting the MedMNIST dataset. It uses a ResNet-50 architecture adapted to 28x28 pixel images and incorporates active learning strategies to improve performance when labeled data is limited.

## Model Architecture

- **Base Model:** ResNet-50
- **Modifications:**
  - Adjusted the initial convolution layer to accommodate 28x28 input images.
  - Removed the max pooling layer to preserve spatial dimensions.
  - Customized the fully connected layer to output predictions for 9 classes.

## Training Procedure

### Training Hyperparameters

| Hyperparameter       | Value                  |
|----------------------|------------------------|
| Batch Size           | 53                     |
| Initial Labeled Size | 3559                   |
| Learning Rate        | 0.01332344940133225    |
| MC Dropout Passes    | 6                      |
| Samples to Label     | 4430                   |
| Weight Decay         | 0.00021921795989143406 |
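
The **MC Dropout Passes** value sets how many stochastic forward passes are made per input to estimate uncertainty. Below is a minimal sketch of that sampling step, assuming the model contains `nn.Dropout` layers. The helper name `mc_dropout_probs` is hypothetical. The helper keeps batch normalization in evaluation mode and re-enables only the dropout layers, so the passes differ only through dropout.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def mc_dropout_probs(model: nn.Module, x: torch.Tensor, passes: int = 6) -> torch.Tensor:
    """Hypothetical helper: softmax outputs of shape (passes, batch, classes), dropout active."""
    model.eval()  # freeze batch-norm statistics
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()  # keep dropout stochastic at inference time
    probs = torch.stack([torch.softmax(model(x), dim=1) for _ in range(passes)])
    model.eval()  # restore deterministic inference mode
    return probs


# Toy model standing in for the classifier; its sizes are arbitrary.
toy = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(32, 9))
probs = mc_dropout_probs(toy, torch.randn(4, 16), passes=6)
print(probs.shape)  # torch.Size([6, 4, 9])
```

Averaging `probs` over its first dimension gives the predictive distribution that uncertainty scores are computed from.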

### Optimizer Settings

The model was trained with Stochastic Gradient Descent (SGD) using the following settings, together with a `ReduceLROnPlateau` learning rate scheduler:

- `learning_rate = 0.01332344940133225`
- `momentum = 0.9`
- `weight_decay = 0.00021921795989143406`

The model was trained with float32 precision.
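
In PyTorch, these settings map onto `torch.optim.SGD` and `torch.optim.lr_scheduler.ReduceLROnPlateau`. The sketch below makes two assumptions:

- The scheduler monitors validation loss.
- It keeps PyTorch's default `factor` and `patience`, since the card does not state them.

The small `nn.Linear` model is only a placeholder.

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 9)  # placeholder for the ResNet-50 classifier

optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.01332344940133225,
    momentum=0.9,
    weight_decay=0.00021921795989143406,
)
# Assumption: mode="min" because the monitored quantity is validation loss;
# factor and patience are left at PyTorch's defaults (0.1 and 10).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")

# After each validation pass, feed the monitored metric to the scheduler.
val_loss = 0.2397  # example value
scheduler.step(val_loss)
print(optimizer.param_groups[0]["lr"])
```

`ReduceLROnPlateau` needs the metric passed to `step()`, unlike epoch-based schedulers. The learning rate only drops after the loss stops improving for `patience` consecutive calls.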

### Dataset

[PathMNIST](https://medmnist.com/), the colon pathology subset of MedMNIST.

### Data Augmentation

- Random resized cropping
- Horizontal flipping
- Random rotations
- Color jittering
- Gaussian blur
- RandAugment

### Active Learning Strategy

The active learning process used a mixed sampling strategy:

- **Uncertainty Sampling:** Monte Carlo (MC) dropout was used to estimate prediction uncertainty.
- **Diversity Sampling:** K-means clustering was used to select a diverse set of samples.
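
The card names the two ingredients but not how they are combined. The sketch below is one common hybrid, built on these assumptions:

- Uncertainty is the predictive entropy of the averaged MC dropout probabilities.
- The most uncertain samples form a candidate pool. The pool multiplier of 5 is arbitrary.
- K-means runs on per-sample feature vectors (for example, penultimate-layer embeddings) and picks the candidate nearest each cluster center.

`select_for_labeling` is a hypothetical function, not the repository's code.

```python
import numpy as np
from sklearn.cluster import KMeans


def predictive_entropy(mc_probs: np.ndarray) -> np.ndarray:
    """Entropy of the mean prediction; mc_probs has shape (passes, n_samples, n_classes)."""
    mean_probs = mc_probs.mean(axis=0)
    return -(mean_probs * np.log(mean_probs + 1e-12)).sum(axis=1)


def select_for_labeling(mc_probs, features, n_select, pool_factor=5, seed=0):
    """Hypothetical hybrid: n_select diverse samples from the most uncertain candidates."""
    entropy = predictive_entropy(mc_probs)
    # 1) Uncertainty: keep the pool_factor * n_select highest-entropy samples.
    pool_size = min(len(entropy), pool_factor * n_select)
    candidates = np.argsort(entropy)[::-1][:pool_size]
    # 2) Diversity: cluster the candidates' features into n_select groups.
    kmeans = KMeans(n_clusters=n_select, n_init=10, random_state=seed)
    labels = kmeans.fit_predict(features[candidates])
    selected = []
    for cluster in range(n_select):
        members = candidates[labels == cluster]
        if len(members) == 0:
            continue
        # 3) Take the member closest to its cluster center.
        dists = np.linalg.norm(features[members] - kmeans.cluster_centers_[cluster], axis=1)
        selected.append(int(members[np.argmin(dists)]))
    return np.array(selected)


# Toy example: 6 MC dropout passes over 200 unlabeled images with 9 classes,
# plus random 16-dimensional features standing in for model embeddings.
rng = np.random.default_rng(0)
mc_probs = rng.dirichlet(np.ones(9), size=(6, 200))
features = rng.normal(size=(200, 16))
chosen = select_for_labeling(mc_probs, features, n_select=10)
print(chosen)
```

Filtering by uncertainty first keeps the K-means step cheap. Clustering then avoids spending the labeling budget on many near-duplicate uncertain images.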

## Evaluation

The model was evaluated on the PathMNIST validation set. Key performance metrics:

- **Accuracy:** 94.72%
- **Loss:** 0.2397
- **AUC:** 99.73%
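
The card does not say exactly how these metrics were computed. Below is a minimal sketch with scikit-learn. It assumes that loss is the mean cross-entropy and that AUC is the macro-averaged one-vs-rest ROC AUC over the classes. `evaluate` is a hypothetical helper, and the tiny arrays are made-up inputs, not the reported results.

```python
import numpy as np
from sklearn.metrics import accuracy_score, log_loss, roc_auc_score


def evaluate(y_true: np.ndarray, probs: np.ndarray) -> dict:
    """Hypothetical helper: accuracy, cross-entropy and macro one-vs-rest AUC."""
    return {
        "accuracy": accuracy_score(y_true, probs.argmax(axis=1)),
        "loss": log_loss(y_true, probs, labels=np.arange(probs.shape[1])),
        "auc": roc_auc_score(y_true, probs, multi_class="ovr", average="macro"),
    }


# Toy check with 3 classes: probability 0.5 on the correct class, 0.25 elsewhere.
y_true = np.array([0, 1, 2, 0, 1, 2])
probs = np.full((6, 3), 0.25)
probs[np.arange(6), y_true] = 0.5
print(evaluate(y_true, probs))
```

In the toy check every sample ranks its true class highest, so accuracy and AUC are both 1.0, while the loss is ln 2 ≈ 0.693. This illustrates how a model can have near-perfect AUC with a non-zero loss.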

## Graphs

The following plots illustrate the validation loss, validation accuracy, and validation AUC over batches (number of iterations over the dataset) during the active learning process.

- **Validation Loss**
- **Validation Accuracy**
- **Validation AUC**

## Usage

All code for this model is available in the following GitHub repository:

[Allen Cheung Determined_AI_Hackathon](https://github.com/AllenCheung0213/Determined_AI_Hackathon)

To use this model:

1. **Install Dependencies:**

   Ensure the following Python packages are installed:

   - `torch`
   - `torchvision`
   - `medmnist`
   - `scikit-learn`
   - `determined`

   Install them using pip:

   ```bash
   pip install torch torchvision medmnist scikit-learn determined
   ```
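
2. **Load the Model:**

   ```python
   import torch

   # Assumes model.py from the GitHub repository above is on the Python path
   from model import ResNet50_28

   model = ResNet50_28(num_classes=9)
   # map_location="cpu" lets the weights load on machines without a GPU
   model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
   model.eval()
   ```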

3. **Inference:**

   ```python
   import torch
   from PIL import Image
   from torchvision import transforms

   transform = transforms.Compose([
       transforms.Resize((28, 28)),
       transforms.ToTensor(),
       transforms.Normalize(mean=[0.5], std=[0.5]),
   ])

   # PathMNIST images are RGB, so force 3 channels (e.g. for grayscale or RGBA files)
   image = Image.open("path_to_image.jpg").convert("RGB")
   input_tensor = transform(image).unsqueeze(0)  # add a batch dimension: (1, 3, 28, 28)

   with torch.no_grad():  # no gradients are needed for inference
       output = model(input_tensor)
   prediction = output.argmax(dim=1).item()
   print(f"Predicted class: {prediction}")
   ```

## License

This project is licensed under the MIT License.

## Acknowledgements

- [MedMNIST Dataset](https://medmnist.com/)
- [Determined AI](https://determined.ai/)
- **Survey on Deep Active Learning:** Wang, H., Jin, Q., Li, S., Liu, S., Wang, M., & Song, Z. (2024). A comprehensive survey on deep active learning in medical image analysis. *Medical Image Analysis*, 95, 103201. [https://doi.org/10.1016/j.media.2024.103201](https://doi.org/10.1016/j.media.2024.103201)