File size: 5,334 Bytes
391ecac ce54ee1 391ecac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import os
import time
########################
# MODEL DEFINITION
########################
class MelanomaModel(nn.Module):
def __init__(self, out_size, dropout_prob=0.5):
super(MelanomaModel, self).__init__()
from efficientnet_pytorch import EfficientNet
self.efficient_net = EfficientNet.from_pretrained('efficientnet-b0')
# Remove the original FC layer
self.efficient_net._fc = nn.Identity()
self.fc1 = nn.Linear(1280, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, out_size)
self.dropout = nn.Dropout(dropout_prob)
def forward(self, x):
x = self.efficient_net(x)
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = self.dropout(x)
x = F.relu(self.fc2(x))
x = self.dropout(x)
x = self.fc3(x)
return x
########################
# DIAGNOSIS MAP
########################
DIAGNOSIS_MAP = {
0: 'Melanoma',
1: 'Melanocytic nevus',
2: 'Basal cell carcinoma',
3: 'Actinic keratosis',
4: 'Benign keratosis',
5: 'Dermatofibroma',
6: 'Vascular lesion',
7: 'Squamous cell carcinoma',
8: 'Unknown'
}
########################
# LOAD MODEL FUNCTION
########################
@st.cache_resource
def load_model():
"""
Loads the model checkpoint.
Using weights_only=False (if you trust the .pth file).
If you prefer a more secure approach, re-save your checkpoint
to only contain raw state_dict and set weights_only=True.
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MelanomaModel(out_size=9)
# Path to your model file
# model_path = os.path.join("model", "multi_weight.pth")
model_path = "multi_weight.pth"
# If you trust the checkpoint file, set weights_only=False
checkpoint = torch.load(
model_path,
map_location=device,
weights_only=False # if you have a purely raw state_dict, you can use True
)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()
return model, device
########################
# IMAGE TRANSFORM
########################
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
########################
# PREDICTION UTILS
########################
def predict_skin_lesion(img: Image.Image, model: nn.Module, device: torch.device):
# Transform and move image to device
img_tensor = transform(img).unsqueeze(0).to(device)
with torch.no_grad():
outputs = model(img_tensor)
probs = F.softmax(outputs, dim=1)
top_probs, top_idxs = torch.topk(probs, 3, dim=1) # top 3 predictions
predictions = []
for prob, idx in zip(top_probs[0], top_idxs[0]):
label = DIAGNOSIS_MAP.get(idx.item(), "Unknown")
confidence = prob.item() * 100
predictions.append((label, confidence))
return predictions
########################
# PAGE CONFIG & STYLE
########################
st.set_page_config(
page_title="Skin Lesion Classifier",
page_icon=":microscope:",
layout="centered",
initial_sidebar_state="expanded"
)
def set_background_color():
st.markdown(
"""
<style>
.stApp {
background-color: #FDEAE0; /* A pale peach/light skin tone */
}
</style>
""",
unsafe_allow_html=True
)
set_background_color()
########################
# STREAMLIT APP
########################
def main():
st.title("Skin Lesion Classifier")
st.write("Upload an image of a skin lesion to see the top-3 predicted diagnoses.")
# Create a stylish sidebar
st.sidebar.title("Possible Diagnoses")
st.sidebar.markdown("Here are the categories the model can distinguish:")
for idx, diag in DIAGNOSIS_MAP.items():
st.sidebar.markdown(f"- **{diag}**")
# Add the names to the sidebar in a new section
st.sidebar.title("Team Members")
st.sidebar.markdown(
"""
- **PRATHUSH MON**
- **PRATIK J**
- **RAYAN NASAR**
- **R HARIMURALI**
- **WASEEM AHAMMED**
"""
)
# Load the model once (cached)
model, device = load_model()
# File uploader
uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
if uploaded_file is not None:
# Display the image
image = Image.open(uploaded_file)
st.image(image, caption="Uploaded Image", use_column_width=True)
# Predict on button click
if st.button("Classify"):
with st.spinner("Analyzing..."):
time.sleep(3) # 3-second spinner
results = predict_skin_lesion(image, model, device)
st.subheader("Top-3 Predictions")
for i, (diagnosis, confidence) in enumerate(results, start=1):
st.write(f"{i}. **{diagnosis}**: {confidence:.2f}%")
if __name__ == "__main__":
main()
|