---
language: en
license: mit
tags:
- image-classification
- active-learning
- medical-imaging
datasets:
- medmnist
metrics:
- loss
- accuracy
- area under the curve
---

# MedMNIST Active Learning Model

## Overview

This model performs image classification on medical images from MedMNIST, specifically the PathMNIST subset. It uses a ResNet-50 architecture adapted to 28x28 pixel inputs and applies active learning strategies to improve performance when labeled data is limited.

## Model Architecture

- **Base Model:** ResNet-50
- **Modifications:**
  - Adjusted initial convolution layer to accommodate 28x28 input images.
  - Removed max pooling layer to preserve spatial dimensions.
  - Customized fully connected layer to output predictions for 9 classes.
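
The effect of these modifications on feature-map size can be checked with the standard convolution output formula. The sketch below is illustrative only: it assumes the adjusted stem is a 3x3, stride-1, padding-1 convolution (the card does not state the exact kernel), and the helper names `conv_out` and `trace` are hypothetical.

```python
def conv_out(size, kernel, stride, padding):
    """Output width of a conv/pool layer: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def trace(size, stem_kernel, stem_stride, stem_padding, use_maxpool):
    """Follow the spatial size through the ResNet-50 stem and four stages."""
    sizes = {"input": size}
    size = conv_out(size, stem_kernel, stem_stride, stem_padding)
    sizes["stem"] = size
    if use_maxpool:
        size = conv_out(size, 3, 2, 1)  # 3x3 max pool, stride 2
        sizes["maxpool"] = size
    sizes["layer1"] = size  # the first stage keeps the resolution
    for name in ("layer2", "layer3", "layer4"):
        size = conv_out(size, 3, 2, 1)  # each later stage halves it
        sizes[name] = size
    return sizes


# Standard ImageNet stem: 7x7 conv, stride 2, padding 3, followed by max pooling
standard = trace(28, 7, 2, 3, use_maxpool=True)
# Assumed adapted stem: 3x3 conv, stride 1, padding 1, no max pooling
adapted = trace(28, 3, 1, 1, use_maxpool=False)

print(standard)
print(adapted)
```

Under these assumptions the unmodified network shrinks a 28x28 image to a single 1x1 location before global pooling, while the adapted stem leaves a 4x4 feature map for the final stage.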

## Training Procedure

### Training Hyperparameters

| Hyperparameter         | Value                  |
|------------------------|------------------------|
| Batch Size             | 53                    |
| Initial Labeled Size   | 3559                  |
| Learning Rate          | 0.01332344940133225    |
| MC Dropout Passes      | 6                     |
| Samples to Label       | 4430                  |
| Weight Decay           | 0.00021921795989143406 |
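
To give a feel for the labeling budget these values imply, here is a small arithmetic sketch. It assumes that "Samples to Label" is the number of samples queried per active learning round and that the pool is the full PathMNIST training split of 89,996 images (a figure from MedMNIST's published split sizes, not from this card). The number of rounds actually run is not stated here.

```python
import math

INITIAL_LABELED = 3559     # "Initial Labeled Size" from the table
SAMPLES_PER_ROUND = 4430   # "Samples to Label", assumed to be per round
POOL_SIZE = 89_996         # PathMNIST training split size (assumption)


def labeled_after(rounds):
    """Labeled-set size after a given number of query rounds."""
    return min(POOL_SIZE, INITIAL_LABELED + rounds * SAMPLES_PER_ROUND)


def rounds_to_exhaust():
    """Rounds needed before every pool sample has been labeled."""
    return math.ceil((POOL_SIZE - INITIAL_LABELED) / SAMPLES_PER_ROUND)


for r in range(4):
    print(f"round {r}: {labeled_after(r)} labeled samples")
print(f"pool exhausted after {rounds_to_exhaust()} rounds")
```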

### Optimizer Settings

The model was trained with stochastic gradient descent (SGD) and a `ReduceLROnPlateau` learning rate scheduler, using the following optimizer settings:
- `learning_rate = 0.01332344940133225`
- `momentum = 0.9`
- `weight_decay = 0.00021921795989143406`

The model was trained with float32 precision.
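
For readers unfamiliar with how momentum and weight decay interact, the following is a minimal sketch of one SGD update on a single scalar parameter, following the convention PyTorch's `torch.optim.SGD` uses with zero dampening and no Nesterov momentum. The function `sgd_step` is a hypothetical helper for illustration, not code from the repository.

```python
def sgd_step(param, grad, velocity,
             lr=0.01332344940133225,
             momentum=0.9,
             weight_decay=0.00021921795989143406):
    """Apply one SGD-with-momentum update to a scalar parameter."""
    g = grad + weight_decay * param      # weight decay adds an L2 term to the gradient
    velocity = momentum * velocity + g   # accumulate the momentum buffer
    param = param - lr * velocity        # step against the accumulated direction
    return param, velocity


# Two consecutive steps with a constant gradient, starting from zero velocity
p, v = 1.0, 0.0
for step in range(2):
    p, v = sgd_step(p, grad=0.5, velocity=v)
    print(f"step {step + 1}: param={p:.6f}, velocity={v:.6f}")
```

With a constant gradient the velocity grows from one step to the next, which is why momentum speeds up progress along consistent directions.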

### Dataset
[PathMNIST](https://medmnist.com/)

### Data Augmentation
  - Random resized cropping
  - Horizontal flipping
  - Random rotations
  - Color jittering
  - Gaussian blur
  - RandAugment

### Active Learning Strategy

The active learning process was based on a mixed sampling strategy:
- **Uncertainty Sampling**: Monte Carlo (MC) dropout was used to estimate uncertainty.
- **Diversity Sampling**: K-means clustering was employed to ensure diverse samples.
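
One way the two criteria could be combined is sketched below in plain Python. This is an assumption rather than the repository's implementation: uncertainty is measured as the entropy of the class probabilities averaged over the MC dropout passes, the most uncertain candidates are shortlisted, and k-means over their feature vectors keeps one sample per cluster. The function names and the `pool_factor` parameter are hypothetical.

```python
import math
import random


def predictive_entropy(mc_probs):
    """Entropy of the mean class distribution over T MC dropout passes."""
    n_passes = len(mc_probs)
    n_classes = len(mc_probs[0])
    mean = [sum(p[c] for p in mc_probs) / n_passes for c in range(n_classes)]
    return -sum(m * math.log(m) for m in mean if m > 0)


def kmeans(points, k, iters=20, seed=0):
    """Plain Lloyd's k-means; returns one cluster index per point."""
    rng = random.Random(seed)
    centers = [list(p) for p in rng.sample(points, k)]
    assign = [0] * len(points)
    for _ in range(iters):
        for i, p in enumerate(points):
            assign[i] = min(range(k), key=lambda j: math.dist(p, centers[j]))
        for j in range(k):
            members = [p for p, a in zip(points, assign) if a == j]
            if members:  # leave the center of an empty cluster unchanged
                centers[j] = [sum(col) / len(members) for col in zip(*members)]
    return assign


def select_batch(mc_probs_per_sample, features, n_select, pool_factor=3):
    """Shortlist by uncertainty, then keep one sample per k-means cluster."""
    scores = [predictive_entropy(p) for p in mc_probs_per_sample]
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    shortlist = ranked[: n_select * pool_factor]
    clusters = kmeans([features[i] for i in shortlist], k=n_select)
    chosen = {}
    for idx, c in zip(shortlist, clusters):
        # shortlist is sorted, so the first sample seen per cluster is its most uncertain
        chosen.setdefault(c, idx)
    return sorted(chosen.values())  # may hold fewer than n_select if a cluster is empty
```

Here `mc_probs_per_sample[i]` would hold the softmax outputs for sample `i` from each stochastic forward pass (6 passes with the hyperparameters above), and `features[i]` its embedding, for example the activations before the final fully connected layer.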

## Evaluation

The model was evaluated on the validation set of PathMNIST. Key performance metrics include:

- **Accuracy:** 94.72%
- **Loss:** 0.2397
- **AUC:** 99.73%

## Graphs

The following plots show the validation loss, validation accuracy, and validation AUC over batches (the number of iterations over the dataset) during the active learning process.

- **Validation Loss**
![Validation Loss](images/test_loss.png)
- **Validation Accuracy**
![Validation Accuracy](images/test_accuracy.png)
- **Validation AUC**
![Validation AUC](images/test_auc.png)

## Usage
All code for this model is available in the following GitHub repository:
[Allen Cheung Determined_AI_Hackathon](https://github.com/AllenCheung0213/Determined_AI_Hackathon)

To utilize this model:

1. **Install Dependencies:**
   Ensure the following Python packages are installed:
   - `torch`
   - `torchvision`
   - `medmnist`
   - `scikit-learn`
   - `determined`

   Install them using pip:
   ```bash
   pip install torch torchvision medmnist scikit-learn determined
   ```

2. **Load the Model:**
    ```python
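    import torch
    from model import ResNet50_28  # model definition from the project's code

    model = ResNet50_28(num_classes=9)
    # map_location='cpu' lets the checkpoint load on machines without a GPU
    state_dict = torch.load('pytorch_model.bin', map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()  # switch dropout and batch norm to inference behavior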
    ```
    
3. **Inference:**
    ```python
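    import torch
    from torchvision import transforms
    from PIL import Image

    # Resize to the 28x28 input size and scale pixel values to [-1, 1]
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    # PathMNIST images are RGB; convert so grayscale or RGBA files also work
    image = Image.open('path_to_image.jpg').convert('RGB')
    input_tensor = transform(image).unsqueeze(0)  # add a batch dimension
    with torch.no_grad():  # gradients are not needed for inference
        output = model(input_tensor)
    prediction = output.argmax(dim=1).item()
    print(f"Predicted class: {prediction}")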
    ```

## License

This project is licensed under the MIT License.

## Acknowledgements

- [MedMNIST Dataset](https://medmnist.com/)
- [Determined AI](https://determined.ai/)
- **Survey on Deep Active Learning**: Wang, H., Jin, Q., Li, S., Liu, S., Wang, M., & Song, Z. (2024). A comprehensive survey on deep active learning in medical image analysis. *Medical Image Analysis*, 95, 103201. [https://doi.org/10.1016/j.media.2024.103201](https://doi.org/10.1016/j.media.2024.103201)