---
language: en
license: mit
tags:
- image-classification
- active-learning
- medical-imaging
datasets:
- medmnist
metrics:
- loss
- accuracy
- area under the curve
---
# MedMNIST Active Learning Model
## Overview
This model performs image classification in the medical imaging domain, specifically on the PathMNIST subset of the MedMNIST benchmark. It uses a ResNet-50 architecture adapted to 28x28 pixel images and applies active learning to improve performance when labeled data is limited.
## Model Architecture
- **Base Model:** ResNet-50
- **Modifications:**
  - Adjusted the initial convolution layer to accommodate 28x28 input images.
  - Removed the max pooling layer to preserve spatial dimensions.
  - Customized the fully connected layer to output predictions for 9 classes.
## Training Procedure
### Training Hyperparameters
| Hyperparameter | Value |
|------------------------|------------------------|
| Batch Size | 53 |
| Initial Labeled Size | 3559 |
| Learning Rate | 0.01332344940133225 |
| MC Dropout Passes | 6 |
| Samples to Label | 4430 |
| Weight Decay | 0.00021921795989143406 |
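The initial labeled size suggests that training starts from a small labeled subset of the training split, with the remaining images held back as an unlabeled pool. The sketch below assumes that subset is drawn uniformly at random (the card does not say how it was chosen); `split_initial_pool` is a hypothetical helper, and a synthetic `TensorDataset` stands in for PathMNIST.

```python
import torch
from torch.utils.data import Subset, TensorDataset

INITIAL_LABELED_SIZE = 3559  # value from the hyperparameter table


def split_initial_pool(dataset, initial_size: int, seed: int = 0):
    """Hypothetical helper: random split into a labeled set and an unlabeled pool."""
    generator = torch.Generator().manual_seed(seed)
    perm = torch.randperm(len(dataset), generator=generator).tolist()
    labeled_idx, unlabeled_idx = perm[:initial_size], perm[initial_size:]
    return Subset(dataset, labeled_idx), Subset(dataset, unlabeled_idx)


# Synthetic stand-in for the PathMNIST training split (3x28x28 images, 9 classes).
fake_train = TensorDataset(torch.randn(5_000, 3, 28, 28), torch.randint(0, 9, (5_000,)))
labeled, unlabeled = split_initial_pool(fake_train, INITIAL_LABELED_SIZE)
print(len(labeled), len(unlabeled))  # 3559 1441
```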
### Optimizer Settings
The model was trained with stochastic gradient descent (SGD) using the settings below, together with a `ReduceLROnPlateau` learning-rate scheduler:
- `learning_rate = 0.01332344940133225`
- `momentum = 0.9`
- `weight_decay = 0.00021921795989143406`
The model was trained with float32 precision.
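In PyTorch, the optimizer settings listed above correspond roughly to the following. The card does not report the scheduler's parameters, so `mode="min"` (tracking validation loss) is an assumption and `factor=0.1`, `patience=10` are simply PyTorch's defaults. A small placeholder module stands in for ResNet-50.

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 9)  # placeholder; the real run would pass the ResNet-50 parameters

optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.01332344940133225,
    momentum=0.9,
    weight_decay=0.00021921795989143406,
)
# Assumption: the scheduler monitors validation loss. factor and patience are
# PyTorch defaults, not values reported for this model.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=10)

# Simulate 12 validation passes in which the loss never improves.
for epoch in range(12):
    val_loss = 0.5
    scheduler.step(val_loss)

# After more than `patience` non-improving passes, the LR is multiplied by `factor`.
print(optimizer.param_groups[0]["lr"])
```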
### Dataset
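[PathMNIST](https://medmnist.com/): colon pathology images from the MedMNIST collection, stored as 28x28 RGB images and labeled with 9 tissue classes.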
### Data Augmentation
- Random resized cropping
- Horizontal flipping
- Random rotations
- Color jittering
- Gaussian blur
- RandAugment
### Active Learning Strategy
The active learning process was based on a mixed sampling strategy:
- **Uncertainty Sampling**: Monte Carlo (MC) dropout was used to estimate uncertainty.
- **Diversity Sampling**: K-means clustering was employed to ensure diverse samples.
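One common way to combine these two criteria is to shortlist the most uncertain unlabeled samples by MC-dropout predictive entropy, cluster the shortlist with k-means, and take the most uncertain sample from each cluster. This is a sketch of that general technique, not the repository's exact procedure: the helper names, the entropy measure, the `candidate_factor` shortlist size, and the use of raw pixels as clustering features are all assumptions. A stock ResNet-50 has no dropout layers, so this also assumes the real model adds dropout (e.g. before the classifier head). A toy model stands in for it here.

```python
import numpy as np
import torch
import torch.nn as nn
from sklearn.cluster import KMeans


def enable_mc_dropout(model: nn.Module) -> None:
    """Eval mode everywhere (fixed BatchNorm stats) except Dropout layers."""
    model.eval()
    for m in model.modules():
        if isinstance(m, nn.Dropout):
            m.train()


@torch.no_grad()
def mc_dropout_entropy(model: nn.Module, x: torch.Tensor, passes: int = 6) -> torch.Tensor:
    """Entropy of the mean softmax over `passes` stochastic forward passes."""
    enable_mc_dropout(model)
    probs = torch.stack([torch.softmax(model(x), dim=1) for _ in range(passes)]).mean(0)
    return -(probs * torch.log(probs + 1e-12)).sum(dim=1)


def select_mixed(entropy, features, n_select: int, candidate_factor: int = 3, seed: int = 0):
    """Shortlist the most uncertain samples, then pick one per k-means cluster."""
    entropy = np.asarray(entropy)
    n_candidates = min(len(entropy), candidate_factor * n_select)
    candidates = np.argsort(-entropy)[:n_candidates]
    km = KMeans(n_clusters=n_select, n_init=10, random_state=seed).fit(features[candidates])
    chosen = []
    for c in range(n_select):
        members = candidates[km.labels_ == c]
        if len(members) > 0:  # guard against a degenerate empty cluster
            chosen.append(members[np.argmax(entropy[members])])
    return np.array(chosen)


# Toy demo: a tiny dropout model and random images stand in for ResNet-50 / PathMNIST.
torch.manual_seed(0)
toy = nn.Sequential(nn.Flatten(), nn.Linear(3 * 28 * 28, 64), nn.ReLU(),
                    nn.Dropout(0.5), nn.Linear(64, 9))
pool = torch.randn(200, 3, 28, 28)
ent = mc_dropout_entropy(toy, pool, passes=6).numpy()
feats = pool.flatten(1).numpy()  # a real setup would more likely use learned embeddings
picked = select_mixed(ent, feats, n_select=10)
print(picked)
```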
## Evaluation
The model was evaluated on the validation set of PathMNIST. Key performance metrics include:
- **Accuracy:** 94.72%
- **Loss:** 0.2397
- **AUC:** 99.73%
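The three metrics can be computed from model logits as follows. This sketch assumes the loss is cross-entropy and the AUC is the macro-averaged one-vs-rest AUC over the 9 classes; the repository's exact computation may differ. `evaluate` is a hypothetical helper, and the logits are synthetic.

```python
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score


def evaluate(logits: torch.Tensor, targets: torch.Tensor) -> dict:
    """Cross-entropy loss, accuracy and macro one-vs-rest AUC from logits."""
    loss = F.cross_entropy(logits, targets).item()
    probs = torch.softmax(logits.double(), dim=1).numpy()  # float64 rows sum to 1
    y = targets.numpy()
    acc = float((probs.argmax(axis=1) == y).mean())
    auc = roc_auc_score(y, probs, multi_class="ovr", average="macro")
    return {"loss": loss, "accuracy": acc, "auc": float(auc)}


# Synthetic logits that favour the correct class, just to exercise the function.
torch.manual_seed(0)
targets = torch.arange(9).repeat(20)  # 180 samples, 20 per class
logits = torch.randn(180, 9) + 3.0 * F.one_hot(targets, 9)
print(evaluate(logits, targets))
```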
## Graphs
The following plots show validation loss, validation accuracy, and validation AUC against the number of training batches processed during the active learning process.
- **Validation Loss**
![Validation Loss](images/test_loss.png)
- **Validation Accuracy**
![Validation Accuracy](images/test_accuracy.png)
- **Validation AUC**
![Validation AUC](images/test_auc.png)
## Usage
All code for this model is available in the following GitHub repository:
[Allen Cheung Determined_AI_Hackathon](https://github.com/AllenCheung0213/Determined_AI_Hackathon)
To utilize this model:
1. **Install Dependencies:**
Ensure the following Python packages are installed:
- `torch`
- `torchvision`
- `medmnist`
- `scikit-learn`
- `determined`
Install them using pip:
```bash
pip install torch torchvision medmnist scikit-learn determined
```
2. **Load the Model:**
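```python
import torch
from model import ResNet50_28  # model.py from the GitHub repository above

model = ResNet50_28(num_classes=9)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine
state_dict = torch.load('pytorch_model.bin', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()  # disable dropout and use running BatchNorm statistics
```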
3. **Inference:**
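```python
import torch
from torchvision import transforms
from PIL import Image

transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    # a single mean/std value is broadcast across all three RGB channels
    transforms.Normalize(mean=[0.5], std=[0.5])
])

image = Image.open('path_to_image.jpg').convert('RGB')  # the model expects 3 channels
input_tensor = transform(image).unsqueeze(0)  # add a batch dimension: (1, 3, 28, 28)
with torch.no_grad():  # no gradients needed for inference
    output = model(input_tensor)
prediction = output.argmax(dim=1).item()
print(f"Predicted class: {prediction}")
```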
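To see how confident the model is rather than only its top prediction, the logits can be turned into softmax probabilities and the top-k classes reported. `predict_topk` is a hypothetical helper, and a placeholder linear classifier stands in for the loaded ResNet-50 so the sketch runs without the checkpoint.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def predict_topk(model: nn.Module, batch: torch.Tensor, k: int = 3):
    """Top-k class probabilities and indices for each image in a batch."""
    model.eval()
    probs = torch.softmax(model(batch), dim=1)
    return probs.topk(k, dim=1)


# Placeholder classifier so the sketch runs without the checkpoint.
dummy = nn.Sequential(nn.Flatten(), nn.Linear(3 * 28 * 28, 9))
values, indices = predict_topk(dummy, torch.randn(4, 3, 28, 28), k=3)
for p, i in zip(values.tolist(), indices.tolist()):
    print([f"class {c}: {q:.3f}" for c, q in zip(i, p)])
```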
## License
This project is licensed under the MIT License.
## Acknowledgements
- [MedMNIST Dataset](https://medmnist.com/)
- [Determined AI](https://determined.ai/)
- **Survey on Deep Active Learning**: Wang, H., Jin, Q., Li, S., Liu, S., Wang, M., & Song, Z. (2024). A comprehensive survey on deep active learning in medical image analysis. *Medical Image Analysis*, 95, 103201. [https://doi.org/10.1016/j.media.2024.103201](https://doi.org/10.1016/j.media.2024.103201)