# Source: uploaded by theodore-ioann ("Upload 15 files", commit 5b303e8, verified).
import gradio as gr
import torch
from PIL import Image
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from utils import *
from supervised import *
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load Models
# Registry of every segmenter selectable in the UI, keyed by the name shown in
# the model radio. The neural networks are moved to `device` and restored from
# checkpoints below; KMeans / GaussianMixture are created unfitted here.
# NOTE(review): presumably predict_and_visualize_single fits the clustering
# baselines per image — confirm in utils.
models = {
    "unet": UNet(num_classes=2).to(device),
    "segformer": Segformer(num_classes=2).to(device),
    "inception": Inception(num_classes=2).to(device),
    "kmeans": KMeans(n_clusters=2),
    "gmm": GaussianMixture(n_components=2),
}

# Checkpoint (state_dict) file for each neural model, relative to the working
# directory. Models not listed here need no weights.
_CHECKPOINTS = {
    "unet": "unet.pt",
    "segformer": "segformer.pt",
    "inception": "inception.pt",
}

for _name, _path in _CHECKPOINTS.items():
    # weights_only=True limits unpickling to tensors and plain containers,
    # which is all a state_dict contains, so a tampered .pt file cannot run
    # arbitrary code on load.
    _state = torch.load(_path, map_location=device, weights_only=True)
    models[_name].load_state_dict(_state)
    # Switch to inference mode (disables dropout / uses running BN stats).
    models[_name].eval()
# Inference function
def inference(image, model_name, postprocess_mode):
    """Segment one image with the selected model for the Gradio UI.

    Args:
        image: Image from the ``gr.Image(type='numpy')`` input, or ``None``
            when Run is pressed without an image.
        model_name: Key into the module-level ``models`` registry
            (e.g. ``"unet"``, ``"kmeans"``).
        postprocess_mode: Post-processing mode forwarded unchanged to
            ``predict_and_visualize_single`` (e.g. ``"none"``, ``"open"``).

    Returns:
        ``(overlay, bw_mask, status_text)``, matching the order of the Run
        button's outputs. When no image was supplied, both images are
        ``None`` and the status tells the user what to do.
    """
    # Gradio passes None for an empty image component; return a readable
    # message instead of handing None on to the model code.
    if image is None:
        return None, None, "⚠️ Please upload an image or pick an example first."
    model = models[model_name]
    bw_mask, overlay = predict_and_visualize_single(model, image, postprocess_mode=postprocess_mode)
    # Report success only after inference has actually completed.
    status_text = f"✅ Inference with {model_name.upper()} and postprocessing mode: {postprocess_mode}"
    return overlay, bw_mask, status_text
# Gradio Interface
# Layout: input controls in a narrow left column; overlay, mask and a status
# line in a wider right column; example images and a legend underneath.
with gr.Blocks(theme=gr.themes.Base(primary_hue="rose", secondary_hue="slate")) as demo:
    gr.Markdown("## 🩺 Skin Lesion Segmentation")
    gr.Markdown("Upload a skin image, choose a model, and view segmentation results.")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type='numpy', label="📷 Upload Image")
            # Offer exactly the models registered in `models` (dict order is
            # preserved), so the radio can never name a missing key.
            model_choice = gr.Radio(
                choices=list(models),
                label="Model",
                value="unet",
            )
            post_choice = gr.Radio(
                choices=["none", "open", "close", "erosion", "dilation"],
                label="Postprocessing",
                value="none",
            )
            run_btn = gr.Button("▶ Run Segmentation")

        with gr.Column(scale=2):
            with gr.Row():
                overlay_output = gr.Image(type='numpy', label="🎯 Overlay")
                mask_output = gr.Image(type='numpy', label="🖤 Predicted Mask")
            status = gr.Textbox(label="Status", interactive=False)

    with gr.Row():
        gr.Examples(
            examples=["./examples/ISIC_0012880.jpg", "./examples/ISIC_0015972.jpg"],
            inputs=[image_input],
            label="Use Example Images",
        )

    with gr.Accordion("ℹ️ Legend", open=False):
        gr.Markdown("""
- **🔴 Red**: Predicted lesion overlay
- **⚪ White**: Binary mask
- **Postprocessing**: Cleans up noisy segmentation
""")

    # Output order must match inference()'s return tuple:
    # (overlay, mask, status text).
    run_btn.click(
        fn=inference,
        inputs=[image_input, model_choice, post_choice],
        outputs=[overlay_output, mask_output, status],
    )

# Launch only when run as a script, so importing this module (e.g. Gradio
# reload mode or tests) does not start a server. share=True additionally
# requests a public Gradio share link.
if __name__ == "__main__":
    demo.launch(share=True)