|
import gradio as gr
|
|
import torch
|
|
from PIL import Image
|
|
from sklearn.cluster import KMeans
|
|
from sklearn.mixture import GaussianMixture
|
|
from utils import *
|
|
from supervised import *
|
|
|
|
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
# Every network below is moved to this device, and checkpoints are mapped
# onto it when they are loaded.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
|
# Registry of every segmentation backend the UI can select, keyed by the
# same strings offered in the "Model" radio group.
#
# The three neural networks are built for two output classes and moved to
# `device`. The two scikit-learn estimators are only constructed here; no
# fitting happens in this block.
# NOTE(review): presumably KMeans / GaussianMixture are fitted per image
# inside predict_and_visualize_single -- confirm against utils.
models = dict(
    unet=UNet(num_classes=2).to(device),
    segformer=Segformer(num_classes=2).to(device),
    inception=Inception(num_classes=2).to(device),
    kmeans=KMeans(n_clusters=2),
    gmm=GaussianMixture(n_components=2),
)
|
|
|
|
# Restore the trained weights of the three neural networks. Checkpoint paths
# are relative to the working directory, and `map_location=device` remaps the
# stored tensors onto whichever device was selected above.
for _net_key, _checkpoint_path in (
    ("unet", "unet.pt"),
    ("segformer", "segformer.pt"),
    ("inception", "inception.pt"),
):
    _state_dict = torch.load(_checkpoint_path, map_location=device)
    models[_net_key].load_state_dict(_state_dict)
|
|
|
|
# Switch the neural networks to evaluation mode. The scikit-learn estimators
# in the registry have no such mode, so they are skipped.
for model in models.values():
    if not isinstance(model, (UNet, Segformer, Inception)):
        continue
    model.eval()
|
|
|
|
|
|
def inference(image, model_name, postprocess_mode):
    """Segment one uploaded image with the selected model.

    This is the callback wired to the "Run Segmentation" button.

    Args:
        image: The value of the upload widget (configured with
            ``type='numpy'``), or ``None`` when the user has not supplied
            an image yet.
        model_name: Key into the module-level ``models`` registry.
        postprocess_mode: Forwarded unchanged to
            ``predict_and_visualize_single`` as ``postprocess_mode``.

    Returns:
        A ``(overlay, bw_mask, status_text)`` tuple, in the order the click
        handler's outputs are declared. When no image was supplied, both
        image slots are ``None`` and ``status_text`` asks for an upload.

    Raises:
        KeyError: If ``model_name`` is not a key of ``models``.
    """
    # Gradio hands the callback None when the image widget is empty. Return a
    # readable status instead of letting the segmentation helper fail on it.
    if image is None:
        return None, None, "⚠️ Please upload an image before running segmentation."

    model = models[model_name]
    status_text = f"✅ Inference with {model_name.upper()} and postprocessing mode: {postprocess_mode}"

    # The helper returns (mask, overlay); the UI expects the overlay first.
    bw_mask, overlay = predict_and_visualize_single(
        model, image, postprocess_mode=postprocess_mode
    )
    return overlay, bw_mask, status_text
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio front end
# ---------------------------------------------------------------------------
# NOTE(review): this file reached review with its indentation stripped. The
# nesting below (inputs column | outputs column, then examples, legend and the
# click wiring inside the Blocks context) is reconstructed from the blank-line
# grouping of the statements -- confirm against the intended layout.
with gr.Blocks(
    theme=gr.themes.Base(primary_hue="rose", secondary_hue="slate")
) as demo:
    gr.Markdown("## 🩺 Skin Lesion Segmentation")
    gr.Markdown("Upload a skin image, choose a model, and view segmentation results.")

    with gr.Row():
        # Left column: the image upload, the two option groups and the
        # button that triggers a run.
        with gr.Column(scale=1):
            image_input = gr.Image(label="📷 Upload Image", type="numpy")
            model_choice = gr.Radio(
                label="Model",
                value="unet",
                choices=["unet", "segformer", "inception", "kmeans", "gmm"],
            )
            post_choice = gr.Radio(
                label="Postprocessing",
                value="none",
                choices=["none", "open", "close", "erosion", "dilation"],
            )
            run_btn = gr.Button("▶ Run Segmentation")

        # Right column: the two result images side by side, with a read-only
        # status line underneath.
        with gr.Column(scale=2):
            with gr.Row():
                overlay_output = gr.Image(label="🎯 Overlay", type="numpy")
                mask_output = gr.Image(label="🖤 Predicted Mask", type="numpy")
            status = gr.Textbox(interactive=False, label="Status")

    # Clickable sample images that populate the upload widget.
    with gr.Row():
        gr.Examples(
            label="Use Example Images",
            inputs=[image_input],
            examples=["./examples/ISIC_0012880.jpg", "./examples/ISIC_0015972.jpg"],
        )

    with gr.Accordion("ℹ️ Legend", open=False):
        # NOTE(review): the second bullet pairs a black-circle glyph with the
        # word "White" -- confirm which one is intended.
        gr.Markdown("""
        - **🔴 Red**: Predicted lesion overlay
        - **⚫ White**: Binary mask
        - **Postprocessing**: Cleans up noisy segmentation
        """)

    # The output order must match the (overlay, mask, status) tuple that
    # `inference` returns.
    run_btn.click(
        inputs=[image_input, model_choice, post_choice],
        outputs=[overlay_output, mask_output, status],
        fn=inference,
    )

# share=True additionally asks Gradio for a publicly reachable link.
demo.launch(share=True)
|
|
|