alcheung0213 commited on
Commit
cc7a0e7
·
1 Parent(s): fc78132

Add validation loss plot

Browse files
Files changed (2) hide show
  1. README.md +107 -4
  2. image.png +0 -0
README.md CHANGED
@@ -1,7 +1,110 @@
1
- # MedMNIST Active Learning Model
2
-
3
- Custom ResNet50 model trained using active learning strategies.
4
-
5
  ---
 
6
  license: mit
 
 
 
 
 
 
 
 
7
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ language: en
3
  license: mit
4
+ tags:
5
+ - image-classification
6
+ - active-learning
7
+ - medical-imaging
8
+ datasets:
9
+ - medmnist
10
+ metrics:
11
+ - accuracy
12
  ---
13
+
14
+ # MedMNIST Active Learning Model
15
+
16
+ ## Overview
17
+
18
+ This model is designed for image classification tasks within the medical imaging domain, specifically targeting the MedMNIST dataset. It employs a ResNet-50 architecture tailored for 28x28 pixel images and incorporates active learning strategies to enhance performance with limited labeled data.
19
+
20
+ ## Model Architecture
21
+
22
+ - **Base Model:** ResNet-50
23
+ - **Modifications:**
24
+ - Adjusted initial convolution layer to accommodate 28x28 input images.
25
+ - Removed max pooling layer to preserve spatial dimensions.
26
+ - Customized fully connected layer to output predictions for 9 classes.
27
+
28
+ ## Training Procedure
29
+
30
+ - **Dataset:** [PathMNIST](https://medmnist.com/)
31
+ - **Data Augmentation:**
32
+ - Random resized cropping
33
+ - Horizontal flipping
34
+ - Random rotations
35
+ - Color jittering
36
+ - Gaussian blur
37
+ - RandAugment
38
+ - **Optimizer:** Stochastic Gradient Descent (SGD) with momentum
39
+ - **Learning Rate Scheduler:** ReduceLROnPlateau
40
+ - **Active Learning Strategy:** Mixed sampling combining uncertainty sampling and diversity sampling using Monte Carlo dropout and K-means clustering.
41
+
42
+ ## Usage
43
+
44
+ To utilize this model:
45
+
46
+ 1. **Install Dependencies:**
47
+ Ensure the following Python packages are installed:
48
+ - `torch`
49
+ - `torchvision`
50
+ - `medmnist`
51
+ - `scikit-learn`
52
+ - `determined`
53
+
54
+ Install them using pip:
55
+ ```bash
56
+ pip install torch torchvision medmnist scikit-learn determined
57
+ ```
58
+
59
+ 2. **Load the Model:**
60
+ ```python
61
+ import torch
62
+ from model import ResNet50_28
63
+
64
+ model = ResNet50_28(num_classes=9)
65
+ model.load_state_dict(torch.load('path_to_checkpoint.pt')['model_state_dict'])
66
+ model.eval()
67
+ ```
68
+
69
+ 3. **Inference:**
70
+ ```python
71
+ from torchvision import transforms
72
+ from PIL import Image
73
+
74
+ transform = transforms.Compose([
75
+ transforms.Resize((28, 28)),
76
+ transforms.ToTensor(),
77
+ transforms.Normalize(mean=[0.5], std=[0.5])
78
+ ])
79
+
80
+ image = Image.open('path_to_image.jpg')
81
+ input_tensor = transform(image).unsqueeze(0)
82
+ output = model(input_tensor)
83
+ prediction = output.argmax(dim=1).item()
84
+ print(f"Predicted class: {prediction}")
85
+ ```
86
+
87
+ ## Evaluation
88
+
89
+ The model was evaluated on the validation set of PathMNIST. Key performance metrics include:
90
+
91
+ - **Accuracy:** 94%
92
+ - **Loss:** 0.1775
93
+
94
+ ## Evaluation Metrics
95
+
96
+ The following plot illustrates the validation loss over training batches during the active learning process. The consistent decrease in validation loss demonstrates the effectiveness of the active learning strategy in improving model performance.
97
+
98
+ ![Validation Loss](image.png)
99
+
100
+ - **Validation Loss**: The graph shows a steady decline, indicating successful learning and convergence.
101
+ - **Batches**: Represents the number of iterations over the dataset.
102
+
103
+ ## License
104
+
105
+ This project is licensed under the mit License.
106
+
107
+ ## Acknowledgements
108
+
109
+ - [MedMNIST Dataset](https://medmnist.com/)
110
+ - [Determined AI](https://determined.ai/)
image.png ADDED