metadata
language: en
license: mit
tags:
  - image-classification
  - active-learning
  - medical-imaging
datasets:
  - medmnist
metrics:
  - accuracy

MedMNIST Active Learning Model

Overview

This model classifies medical images from PathMNIST, one of the datasets in the MedMNIST collection. It uses a ResNet-50 architecture adapted to 28x28 pixel inputs and is trained with active learning to improve performance when labeled data is limited.

Model Architecture

  • Base Model: ResNet-50
  • Modifications:
    • Adjusted initial convolution layer to accommodate 28x28 input images.
    • Removed max pooling layer to preserve spatial dimensions.
    • Customized fully connected layer to output predictions for 9 classes.

Training Procedure

  • Dataset: PathMNIST
  • Data Augmentation:
    • Random resized cropping
    • Horizontal flipping
    • Random rotations
    • Color jittering
    • Gaussian blur
    • RandAugment
  • Optimizer: Stochastic Gradient Descent (SGD) with momentum
  • Learning Rate Scheduler: ReduceLROnPlateau
  • Active Learning Strategy: Mixed sampling combining uncertainty sampling and diversity sampling using Monte Carlo dropout and K-means clustering.
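
One way the mixed sampling strategy could work is sketched below. The card does not specify the details, so everything here is an assumption: the helper names (`mc_dropout_entropy`, `kmeans_diverse_indices`, `mixed_select`), the use of predictive entropy as the uncertainty score, the 50/50 split between uncertain and diverse samples, and the toy model used for the demonstration.

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.cluster import KMeans


def mc_dropout_entropy(model, x, n_passes=10):
    """Predictive entropy per sample from several stochastic forward passes."""
    model.eval()
    for m in model.modules():
        if isinstance(m, nn.Dropout):
            m.train()  # keep dropout active; other layers stay in eval mode
    with torch.no_grad():
        probs = torch.stack([F.softmax(model(x), dim=1) for _ in range(n_passes)])
    mean_probs = probs.mean(dim=0)
    entropy = -(mean_probs * torch.log(mean_probs + 1e-12)).sum(dim=1)
    model.eval()  # restore deterministic behaviour
    return entropy.numpy()


def kmeans_diverse_indices(features, n_select, seed=0):
    """Indices of the samples closest to each K-means centroid."""
    km = KMeans(n_clusters=n_select, n_init=10, random_state=seed).fit(features)
    dists = km.transform(features)  # shape: (n_samples, n_clusters)
    # np.unique guards against two centroids sharing the same nearest sample,
    # so slightly fewer than n_select indices may be returned.
    return np.unique(dists.argmin(axis=0))


def mixed_select(model, x, features, budget, uncertain_frac=0.5):
    """Pick `budget` pool indices: part most uncertain, part most diverse."""
    n_unc = int(budget * uncertain_frac)
    entropy = mc_dropout_entropy(model, x)
    uncertain = np.argsort(-entropy)[:n_unc]  # highest entropy first
    remaining = np.setdiff1d(np.arange(len(x)), uncertain)
    diverse = remaining[kmeans_diverse_indices(features[remaining], budget - n_unc)]
    return np.concatenate([uncertain, diverse])


# Toy demonstration on random data (not PathMNIST).
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 28 * 28, 32), nn.ReLU(),
                          nn.Dropout(0.5), nn.Linear(32, 9))
pool = torch.rand(200, 3, 28, 28)   # stand-in for the unlabeled pool
feats = pool.flatten(1).numpy()     # raw pixels as stand-in features
selected = mixed_select(toy_model, pool, feats, budget=10)
print(len(selected))
```

In practice the features passed to K-means would more likely be penultimate-layer embeddings than raw pixels, and the selected indices would be sent for labeling before the next training round.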

Usage

To utilize this model:

  1. Install Dependencies: Ensure the following Python packages are installed:

    • torch
    • torchvision
    • medmnist
    • scikit-learn
    • determined

    Install them using pip:

    pip install torch torchvision medmnist scikit-learn determined
    
  2. Load the Model:

    import torch
    from model import ResNet50_28
    
    model = ResNet50_28(num_classes=9)
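    # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only machine
    checkpoint = torch.load('path_to_checkpoint.pt', map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])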
    model.eval()
    
  3. Inference:

    from torchvision import transforms
    from PIL import Image
    
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    
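    # PathMNIST images are RGB, so force three channels regardless of file format
    image = Image.open('path_to_image.jpg').convert('RGB')
    input_tensor = transform(image).unsqueeze(0)  # add a batch dimension
    with torch.no_grad():  # gradients are not needed for inference
        output = model(input_tensor)
    prediction = output.argmax(dim=1).item()
    print(f"Predicted class: {prediction}")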
    

Evaluation

The model was evaluated on the validation set of PathMNIST. Key performance metrics include:

  • Accuracy: 94%
  • Loss: 0.1775
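
Metrics like these can be computed with a loop along the following lines. This is a sketch under assumptions: the helper name `evaluate` is hypothetical, the loss is assumed to be cross-entropy (the card does not name it), and the demonstration uses a toy linear model on random tensors rather than the PathMNIST validation split.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, loader, device="cpu"):
    """Return (mean loss, accuracy) over every sample in the loader."""
    model.eval()
    total_loss, correct, n = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            # Flatten labels of shape (B, 1) to (B,) as cross_entropy expects.
            labels = labels.to(device).view(-1).long()
            logits = model(images)
            # Sum per-sample losses so the final mean is exact for uneven batches.
            total_loss += F.cross_entropy(logits, labels, reduction="sum").item()
            correct += (logits.argmax(dim=1) == labels).sum().item()
            n += labels.numel()
    return total_loss / n, correct / n


# Toy demonstration on random data.
torch.manual_seed(0)
dummy = TensorDataset(torch.rand(16, 3, 28, 28), torch.randint(0, 9, (16, 1)))
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 28 * 28, 9))
val_loss, val_acc = evaluate(toy_model, DataLoader(dummy, batch_size=8))
print(f"loss={val_loss:.4f} accuracy={val_acc:.2%}")
```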

Evaluation Metrics

The following plot shows the validation loss over training batches during the active learning process. The loss decreases consistently, which suggests that the active learning strategy is improving model performance.

[Figure: Validation Loss]

  • Validation Loss: The curve declines steadily, indicating that the model is learning and converging.
  • Batches: The number of training batches (iterations) processed.

License

This project is licensed under the MIT License.

Acknowledgements