---
language: en
license: mit
tags:
- image-classification
- active-learning
- medical-imaging
datasets:
- medmnist
metrics:
- loss
- accuracy
- area under the curve
---

# MedMNIST Active Learning Model

## Overview

This model performs image classification in the medical imaging domain, specifically on PathMNIST, a subset of the MedMNIST collection. It uses a ResNet-50 architecture adapted for 28x28 pixel images and incorporates active learning strategies to improve performance with limited labeled data.

## Model Architecture

- **Base Model:** ResNet-50
- **Modifications:**
  - Adjusted the initial convolution layer to accommodate 28x28 input images.
  - Removed the max pooling layer to preserve spatial dimensions.
  - Customized the fully connected layer to output predictions for 9 classes.

## Training Procedure

### Training Hyperparameters

| Hyperparameter         | Value                  |
|------------------------|------------------------|
| Batch Size             | 53                     |
| Initial Labeled Size   | 3559                   |
| Learning Rate          | 0.01332344940133225    |
| MC Dropout Passes      | 6                      |
| Samples to Label       | 4430                   |
| Weight Decay           | 0.00021921795989143406 |

### Optimizer Settings

The optimizer used during training was Stochastic Gradient Descent (SGD), paired with a ReduceLROnPlateau learning rate scheduler, with the following settings:

- `learning_rate = 0.01332344940133225`
- `momentum = 0.9`
- `weight_decay = 0.00021921795989143406`

The model was trained with float32 precision.

### Dataset

[PathMNIST](https://medmnist.com/)

### Data Augmentation

- Random resized cropping
- Horizontal flipping
- Random rotations
- Color jittering
- Gaussian blur
- RandAugment

### Active Learning Strategy

The active learning process was based on a mixed sampling strategy:

- **Uncertainty Sampling**: Monte Carlo (MC) dropout was used to estimate uncertainty.
- **Diversity Sampling**: K-means clustering was employed to ensure diverse samples.

## Evaluation

The model was evaluated on the validation set of PathMNIST.
Key performance metrics include:

- **Accuracy:** 94.72%
- **Loss:** 0.2397
- **AUC:** 99.73%

## Graphs

The following plots show the validation loss, validation accuracy, and validation AUC over batches (number of iterations over the dataset) during the active learning process.

- **Validation Loss**
  ![Validation Loss](images/test_loss.png)
- **Validation Accuracy**
  ![Validation Accuracy](images/test_accuracy.png)
- **Validation AUC**
  ![Validation AUC](images/test_auc.png)

## Usage

All code for this model can be accessed in the following GitHub repository: [Allen Cheung Determined_AI_Hackathon](https://github.com/AllenCheung0213/Determined_AI_Hackathon)

To use this model:

1. **Install Dependencies:**
   Ensure the following Python packages are installed:
   - `torch`
   - `torchvision`
   - `medmnist`
   - `scikit-learn`
   - `determined`

   Install them using pip:
   ```bash
   pip install torch torchvision medmnist scikit-learn determined
   ```

2. **Load the Model:**
   ```python
   import torch
   # Requires the module that defines ResNet50_28 to be on the import path
   from model import ResNet50_28

   model = ResNet50_28(num_classes=9)
   # map_location lets the weights load on machines without a GPU
   model.load_state_dict(torch.load('pytorch_model.bin', map_location='cpu'))
   model.eval()  # disable dropout and use running batch-norm statistics
   ```

3. **Inference:**
   ```python
   import torch
   from torchvision import transforms
   from PIL import Image

   transform = transforms.Compose([
       transforms.Resize((28, 28)),
       transforms.ToTensor(),
       transforms.Normalize(mean=[0.5], std=[0.5])
   ])

   # PathMNIST images are RGB; convert in case the file is grayscale or RGBA
   image = Image.open('path_to_image.jpg').convert('RGB')
   input_tensor = transform(image).unsqueeze(0)  # add a batch dimension

   with torch.no_grad():  # gradients are not needed for inference
       output = model(input_tensor)
   prediction = output.argmax(dim=1).item()
   print(f"Predicted class: {prediction}")
   ```

## License

This project is licensed under the MIT License.

## Acknowledgements

- [MedMNIST Dataset](https://medmnist.com/)
- [Determined AI](https://determined.ai/)
- **Survey on Deep Active Learning**: Wang, H., Jin, Q., Li, S., Liu, S., Wang, M., & Song, Z. (2024). A comprehensive survey on deep active learning in medical image analysis. *Medical Image Analysis*, 95, 103201.
  [https://doi.org/10.1016/j.media.2024.103201](https://doi.org/10.1016/j.media.2024.103201)
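
## Example: Mixed Sampling Sketch

The Active Learning Strategy section names two ingredients, MC dropout for uncertainty and k-means for diversity, but does not say how they are combined. The sketch below shows one plausible combination and is not necessarily what the repository does: score each unlabeled sample by the entropy of its mean MC dropout prediction, keep the most uncertain candidates, cluster those candidates, and pick the most uncertain sample from each cluster. All function names, the `pool_factor` parameter, and the small NumPy k-means (used here in place of a library implementation to keep the example self-contained) are assumptions made for illustration.

```python
import numpy as np


def predictive_entropy(mc_probs):
    """Entropy of the mean prediction across MC dropout passes.

    mc_probs: array of shape (passes, samples, classes) holding softmax outputs.
    """
    mean_probs = mc_probs.mean(axis=0)
    return -(mean_probs * np.log(mean_probs + 1e-12)).sum(axis=1)


def kmeans_labels(features, k, iters=20, seed=0):
    """Plain Lloyd's k-means; returns a cluster index for each row of features."""
    features = np.asarray(features, dtype=float)
    rng = np.random.default_rng(seed)
    # Initialize centers from k distinct data points (fancy indexing copies them)
    centers = features[rng.choice(len(features), size=k, replace=False)]
    for _ in range(iters):
        dists = ((features[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = dists.argmin(axis=1)
        for j in range(k):
            members = features[labels == j]
            if len(members):  # leave the center of an empty cluster unchanged
                centers[j] = members.mean(axis=0)
    return labels


def select_for_labeling(mc_probs, features, budget, pool_factor=3):
    """Pick `budget` pool indices that are both uncertain and spread out.

    features: array of shape (samples, dims), e.g. penultimate-layer embeddings.
    pool_factor: hypothetical knob; candidate pool size is budget * pool_factor.
    """
    entropy = predictive_entropy(mc_probs)
    # Stage 1 (uncertainty): keep the highest-entropy candidates
    n_candidates = min(len(entropy), budget * pool_factor)
    candidates = np.argsort(-entropy)[:n_candidates]
    # Stage 2 (diversity): one pick per cluster, the most uncertain member
    labels = kmeans_labels(features[candidates], k=budget)
    chosen = []
    for j in range(budget):
        members = candidates[labels == j]
        if len(members):
            chosen.append(members[np.argmax(entropy[members])])
    # Top up from the remaining most-uncertain candidates if a cluster was empty
    for idx in candidates:
        if len(chosen) >= budget:
            break
        if idx not in chosen:
            chosen.append(idx)
    return np.array(chosen)
```

With the hyperparameters listed above, `mc_probs` would have 6 passes and 9 classes, and `budget` would correspond to the number of samples to label per round. The function assumes the unlabeled pool holds at least `budget` samples.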