Update model card for progressive resizing experiment
Browse files
README.md
CHANGED
@@ -1,104 +1,82 @@
|
|
1 |
-
---
|
2 |
license: mit
|
3 |
-
language:
|
4 |
-
- en
|
5 |
tags:
|
6 |
-
- image-classification
|
7 |
-
- medical-imaging
|
8 |
-
- diabetic-retinopathy
|
9 |
-
- resnet
|
10 |
-
- sih-2025
|
11 |
-
---
|
12 |
|
13 |
-
|
14 |
-
This is a ResNet50 model fine-tuned for the task of Diabetic Retinopathy (DR) grading based on fundus images. The model classifies a given retina scan into one of five severity grades, following the International Clinical Diabetic Retinopathy scale.
|
15 |
|
16 |
-
|
17 |
|
18 |
-
|
19 |
-
Model Architecture: ResNet50
|
20 |
|
21 |
-
|
22 |
|
23 |
-
|
24 |
|
25 |
-
|
26 |
|
27 |
-
|
|
|
28 |
|
29 |
-
|
|
|
30 |
|
31 |
-
|
32 |
|
33 |
-
|
|
|
34 |
|
35 |
-
|
36 |
|
37 |
-
|
38 |
|
39 |
-
|
40 |
-
To use this model, you need to have torch and torchvision installed.
|
41 |
|
42 |
-
|
43 |
-
import torchvision
|
44 |
-
from torchvision import models, transforms
|
45 |
-
from PIL import Image
|
46 |
-
|
47 |
-
# 1. Define the model architecture
|
48 |
-
model = models.resnet50(weights=None)
|
49 |
-
num_ftrs = model.fc.in_features
|
50 |
-
model.fc = torch.nn.Linear(num_ftrs, 5) # 5 classes
|
51 |
-
|
52 |
-
# 2. Load the fine-tuned weights from the Hub
|
53 |
-
weights_path = hf_hub_download(repo_id="Arko007/Diabetic-Retinopathy", filename="resnet50_finetuned_retinopathy.pth")
|
54 |
-
model.load_state_dict(torch.load(weights_path, map_location='cpu'))
|
55 |
-
model.eval()
|
56 |
-
|
57 |
-
# 3. Create the same data transform used for validation/testing
|
58 |
-
transform = transforms.Compose([
|
59 |
-
transforms.Resize((224, 224)),
|
60 |
-
transforms.ToTensor(),
|
61 |
-
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
62 |
-
])
|
63 |
-
|
64 |
-
# 4. Load an image and make a prediction
|
65 |
-
# Make sure to replace 'path/to/your/image.jpg'
|
66 |
-
img = Image.open('path/to/your/image.jpg').convert('RGB')
|
67 |
-
img_t = transform(img)
|
68 |
-
batch_t = torch.unsqueeze(img_t, 0)
|
69 |
-
|
70 |
-
with torch.no_grad():
|
71 |
-
output = model(batch_t)
|
72 |
-
_, predicted_idx = torch.max(output, 1)
|
73 |
|
74 |
-
|
75 |
-
|
76 |
|
|
|
|
|
77 |
|
78 |
-
|
79 |
-
|
|
|
|
|
|
|
80 |
|
81 |
-
|
|
|
|
|
82 |
|
83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
-
|
|
|
86 |
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
The model was trained on an NVIDIA A100 GPU using a two-phase transfer learning strategy:
|
91 |
-
|
92 |
-
Head Training: The pre-trained ResNet50 backbone was frozen, and only the new classification head was trained for 15 epochs.
|
93 |
-
|
94 |
-
Fine-Tuning: The entire model was unfrozen and trained for an additional 30 epochs with a much smaller learning rate to fine-tune the deep features.
|
95 |
-
|
96 |
-
Key hyperparameters:
|
97 |
|
98 |
-
|
|
|
|
|
99 |
|
100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
|
102 |
-
|
|
|
103 |
|
104 |
-
|
|
|
|
|
1 |
license: mit
|
2 |
+
language: en
|
|
|
3 |
tags:
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
+
image-classification
|
|
|
6 |
|
7 |
+
medical-imaging
|
8 |
|
9 |
+
diabetic-retinopathy
|
|
|
10 |
|
11 |
+
resnet
|
12 |
|
13 |
+
fine-tuning
|
14 |
|
15 |
+
progressive-resizing
|
16 |
|
17 |
+
sih-2025
|
18 |
+
base_model: microsoft/resnet-50
|
19 |
|
20 |
+
Progressively Resized ResNet50 for Diabetic Retinopathy Grading
|
21 |
+
This repository contains a collection of ResNet50 models fine-tuned for classifying diabetic retinopathy severity. These models are the result of an advanced, multi-stage progressive resizing experiment.
|
22 |
|
23 |
+
The strategy involves starting with a fine-tuned model and continuing to train it on progressively higher image resolutions. This allows the model to first learn general features on smaller images and then refine its understanding by learning fine-grained details from larger, higher-quality images.
|
24 |
|
25 |
+
Model Versions
|
26 |
+
This repository contains several model checkpoints, each representing the best-performing model at a specific resolution stage. The final model from the highest resolution stage represents the culmination of this experiment.
|
27 |
|
28 |
+
best_model_384px.pth: Fine-tuned on 384x384 images.
|
29 |
|
30 |
+
best_model_512px.pth: Fine-tuned on 512x512 images.
|
31 |
|
32 |
+
best_model_768px.pth: Fine-tuned on 768x768 images.
|
|
|
33 |
|
34 |
+
best_model_1024px.pth: The final model, fine-tuned on 1024x1024 images.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
+
Performance (Final Model)
|
37 |
+
The final model's performance was evaluated on the official test set from the IDRiD dataset.
|
38 |
|
39 |
+
Classification Report
|
40 |
+
precision recall f1-score support
|
41 |
|
42 |
+
Grade 0 0.76 0.65 0.70 34
|
43 |
+
Grade 1 0.11 0.40 0.17 5
|
44 |
+
Grade 2 0.59 0.59 0.59 32
|
45 |
+
Grade 3 0.64 0.47 0.55 19
|
46 |
+
Grade 4 0.40 0.31 0.35 13
|
47 |
|
48 |
+
accuracy 0.54 103
|
49 |
+
macro avg 0.50 0.48 0.47 103
|
50 |
+
weighted avg 0.61 0.54 0.57 103
|
51 |
|
52 |
+
Confusion Matrix
|
53 |
+
Grade 0 Grade 1 Grade 2 Grade 3 Grade 4
|
54 |
+
Grade 0 22 10 2 0 0
|
55 |
+
Grade 1 2 2 1 0 0
|
56 |
+
Grade 2 4 4 19 3 2
|
57 |
+
Grade 3 0 2 4 9 4
|
58 |
+
Grade 4 1 0 6 2 4
|
59 |
|
60 |
+
How to Use a Specific Model
|
61 |
+
You can load any of the model versions using PyTorch. Make sure to use the correct filename.
|
62 |
|
63 |
+
import torch
|
64 |
+
from torchvision import models
|
65 |
+
from huggingface_hub import hf_hub_download
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
+
# 1. Define the model architecture
|
68 |
+
model = models.resnet50(weights=None)
|
69 |
+
model.fc = torch.nn.Linear(model.fc.in_features, 5) # 5 classes
|
70 |
|
71 |
+
# 2. Load the fine-tuned weights for the desired resolution
|
72 |
+
weights_path = hf_hub_download(
|
73 |
+
repo_id="Arko007/Diabetic-Retinopathy",
|
74 |
+
filename="best_model_1024px.pth" # Change this to load other versions
|
75 |
+
)
|
76 |
+
model.load_state_dict(torch.load(weights_path, map_location='cpu'))
|
77 |
+
model.eval()
|
78 |
|
79 |
+
# 3. Preprocess your image using the correct size for the model you loaded
|
80 |
+
# ...
|
81 |
|
82 |
+
Developed by: Arko007 for SIH 2025.
|