---
language: en
license: mit
tags:
- image-classification
- active-learning
- medical-imaging
datasets:
- medmnist
metrics:
- accuracy
---
MedMNIST Active Learning Model
Overview
This model performs image classification in the medical imaging domain and targets the MedMNIST benchmark; it was trained on the PathMNIST subset. It uses a ResNet-50 architecture adapted for 28x28 pixel images and an active learning strategy to improve performance when labeled data is limited.
Model Architecture
- Base Model: ResNet-50
- Modifications:
  - Adjusted the initial convolution layer to accommodate 28x28 input images.
  - Removed the max pooling layer to preserve spatial dimensions.
  - Replaced the fully connected layer with one that outputs predictions for 9 classes.
Training Procedure
- Dataset: PathMNIST
- Data Augmentation:
  - Random resized cropping
  - Horizontal flipping
  - Random rotations
  - Color jittering
  - Gaussian blur
  - RandAugment
- Optimizer: Stochastic Gradient Descent (SGD) with momentum
- Learning Rate Scheduler: ReduceLROnPlateau
- Active Learning Strategy: Mixed sampling that combines uncertainty sampling (estimated with Monte Carlo dropout) and diversity sampling (K-means clustering).
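One plausible reading of the mixed sampling strategy is sketched below: Monte Carlo dropout estimates predictive entropy, the most uncertain pool samples become candidates, and K-means over those candidates picks one representative per cluster. This is a sketch under stated assumptions, not the authors' implementation. A stock torchvision ResNet-50 has no dropout layers, so the sketch assumes the model contains at least one `nn.Dropout`; it clusters on MC-averaged softmax probabilities, whereas the original may use penultimate-layer features; and the names `mixed_sampling`, `mc_dropout_predict`, and `candidate_factor` are hypothetical.

```python
import numpy as np
import torch
import torch.nn as nn
from sklearn.cluster import KMeans


def enable_mc_dropout(model: nn.Module) -> None:
    """Eval mode everywhere (frozen BatchNorm) except Dropout, which stays stochastic."""
    model.eval()
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()


@torch.no_grad()
def mc_dropout_predict(model: nn.Module, x: torch.Tensor, n_samples: int = 10):
    """Average softmax over n_samples stochastic passes; return entropy and mean probs."""
    enable_mc_dropout(model)
    probs = torch.stack([torch.softmax(model(x), dim=1) for _ in range(n_samples)])
    model.eval()  # restore deterministic inference mode
    mean_probs = probs.mean(dim=0)  # (N, C)
    entropy = -(mean_probs * torch.log(mean_probs + 1e-12)).sum(dim=1)  # (N,)
    return entropy, mean_probs


def mixed_sampling(model, x_pool, budget, candidate_factor=5, n_samples=10, seed=0):
    """Uncertainty pre-filter (MC dropout entropy) followed by K-means diversity."""
    entropy, mean_probs = mc_dropout_predict(model, x_pool, n_samples)

    # 1) Uncertainty: keep the most uncertain pool samples as candidates.
    n_candidates = min(len(x_pool), budget * candidate_factor)
    cand_idx = torch.topk(entropy, n_candidates).indices.numpy()
    feats = mean_probs.numpy()[cand_idx]

    # 2) Diversity: cluster the candidates and take the one nearest each centroid.
    kmeans = KMeans(n_clusters=budget, n_init=10, random_state=seed).fit(feats)
    selected = []
    for c in range(budget):
        members = np.flatnonzero(kmeans.labels_ == c)
        if members.size == 0:  # guard against degenerate (empty) clusters
            continue
        dists = np.linalg.norm(feats[members] - kmeans.cluster_centers_[c], axis=1)
        selected.append(int(cand_idx[members[np.argmin(dists)]]))
    return selected  # indices into x_pool to send for labeling
```

Clustering only the uncertain candidates keeps the K-means step cheap and avoids spending the labeling budget on several near-duplicate uncertain images.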
Usage
To utilize this model:
Install Dependencies: Ensure the following Python packages are installed:
- torch
- torchvision
- medmnist
- scikit-learn
- determined

Install them using pip:
```bash
pip install torch torchvision medmnist scikit-learn determined
```
Load the Model:
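```python
import torch
from model import ResNet50_28  # assumes model.py (defining ResNet50_28) is importable

model = ResNet50_28(num_classes=9)  # 9 output classes, matching PathMNIST
# map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only machine
checkpoint = torch.load('path_to_checkpoint.pt', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # inference mode: disables dropout, uses running BatchNorm statistics
```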
Inference:
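```python
import torch
from torchvision import transforms
from PIL import Image

# Resize to the 28x28 input size and scale pixels from [0, 1] to [-1, 1]
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# PathMNIST images are RGB; convert() also handles grayscale or RGBA files
image = Image.open('path_to_image.jpg').convert('RGB')
input_tensor = transform(image).unsqueeze(0)  # add a batch dimension: (1, 3, 28, 28)

with torch.no_grad():  # gradients are not needed for inference
    output = model(input_tensor)
prediction = output.argmax(dim=1).item()
print(f"Predicted class: {prediction}")
```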
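The loading step expects a checkpoint dictionary with a `model_state_dict` key. A minimal sketch of writing and reading a checkpoint in that layout follows; the helper names `save_checkpoint` and `load_weights` and the extra `optimizer_state_dict` and `epoch` keys are assumptions, not necessarily what the training code stores.

```python
import torch
import torch.nn as nn


def save_checkpoint(model: nn.Module, optimizer, path: str, epoch: int) -> None:
    """Hypothetical helper: store weights under the 'model_state_dict' key."""
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),  # assumed extra key
            "epoch": epoch,  # assumed extra key
        },
        path,
    )


def load_weights(model: nn.Module, path: str) -> nn.Module:
    """Load only the model weights onto the CPU and switch to eval mode."""
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    return model.eval()
```

Keeping the optimizer state alongside the weights would allow resuming training between active learning rounds, while inference only needs `model_state_dict`.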
Evaluation
The model was evaluated on the validation set of PathMNIST. Key performance metrics include:
- Accuracy: 94%
- Loss: 0.1775
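Accuracy and loss on the validation split could be computed with a loop like the one below. It is a sketch assuming a standard PyTorch `DataLoader` and cross-entropy loss (the card does not name the loss function); the helper `evaluate` is hypothetical. Labels are flattened because MedMNIST datasets return each label as an array of shape `(1,)`, so batches arrive as `(B, 1)`.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: str = "cpu"):
    """Return (accuracy, mean cross-entropy loss) over every sample in `loader`."""
    model.eval()
    criterion = nn.CrossEntropyLoss(reduction="sum")  # sum per batch, divide by N at the end
    total_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.view(-1).long().to(device)  # (B, 1) -> (B,)
        logits = model(images)
        total_loss += criterion(logits, labels).item()
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    return correct / total, total_loss / total
```

Summing the loss and dividing by the sample count keeps the average exact even when the last batch is smaller than the others.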
Evaluation Metrics
The following plot illustrates the validation loss over training batches during the active learning process. The consistent decrease in validation loss suggests that the active learning loop steadily improves the model.
- Validation Loss (y-axis): declines steadily, indicating successful learning and convergence.
- Batches (x-axis): the number of training batches processed; each batch is one optimizer update, not a full pass over the dataset.
License
This project is licensed under the MIT License.