# mri_segment / app.py
# Hugging Face file view: elliemci — commit 0ebdb2c "Create app.py" (verified), 3.14 kB
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from transformers import MaskFormerForInstanceSegmentation, MaskFormerImageProcessor
from PIL import Image
from pathlib import Path
# Load the fine-tuned MaskFormer checkpoint and its matching image processor.
# The identifier is a Hugging Face Hub model repo (or a local directory with
# the same name); both objects come from the same checkpoint.
_CHECKPOINT = "elliemci/maskformer_tumor_segmentation"
model = MaskFormerForInstanceSegmentation.from_pretrained(_CHECKPOINT)
image_processor = MaskFormerImageProcessor.from_pretrained(_CHECKPOINT)
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)  # nn.Module.to moves parameters in place and returns self
# Define a custom dataset class to handle images
class ImageDataset(Dataset):
def __init__(self, image_paths, transform=None):
self.image_paths = image_paths
self.transform = transform
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
image = Image.open(self.image_paths[idx]).convert('RGB')
if self.transform:
image = self.transform(image)
return image
def segment(image_files):
    """Run the segmentation model on each image and return its label masks.

    Args:
        image_files: sequence of image paths (or file objects) accepted by
            ``PIL.Image.open``. The Gradio UI passes file paths.

    Returns:
        A list with one entry per input image: the semantic segmentation
        produced by ``image_processor.post_process_semantic_segmentation``,
        resized to that image's original (height, width). Empty input
        returns an empty list.
    """
    if not image_files:
        # DataLoader rejects batch_size=0, so short-circuit empty uploads.
        return []
    dataset = ImageDataset(image_files, transform=transforms.ToTensor())
    # One image per step: uploads may have different sizes, and the default
    # collate function can only stack equally sized tensors into a batch.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
    predicted_masks = []
    with torch.no_grad():
        for batch in dataloader:
            pixel_values = batch.to(device, dtype=torch.float32)
            outputs = model(pixel_values=pixel_values)
            # Resize each predicted mask back to the size of its input image.
            target_sizes = [tuple(image.shape[-2:]) for image in batch]
            predicted_masks.extend(
                image_processor.post_process_semantic_segmentation(
                    outputs, target_sizes=target_sizes
                )
            )
    return predicted_masks
# Components for the Gradio interface
def update_gallery(images):
    """Build gallery entries that pair each uploaded image with its mask.

    Args:
        images: the uploaded files from the ``gr.Files`` input (file paths
            in this app), or ``None``/empty when the upload is cleared.

    Returns:
        A flat list of ``(image, caption)`` tuples: for each input, the
        original image followed by its segmentation rendered as a PIL image.
        Returns an empty list when there are no images.
    """
    if not images:
        # A cleared upload can arrive as None or []; indexing images[0]
        # before this guard would crash the event handler.
        return []
    print(f"Type in update_gallery: {type(images[0])}")
    segmented_images = segment(images)  # one mask per input, same order
    to_pil = transforms.ToPILImage()
    gallery_data = []
    for image, mask in zip(images, segmented_images):
        # ToPILImage needs host memory, so convert on the CPU rather than
        # copying the mask to `device` first. Float input is scaled by 255,
        # so label 1 renders white.
        # NOTE(review): labels above 1 would overflow uint8 here; confirm the
        # model predicts binary masks.
        mask_pil = to_pil(mask.to("cpu", dtype=torch.float32))
        gallery_data.extend([(image, "Original Image"), (mask_pil, "Segmented Image")])
    return gallery_data
# Gradio UI for MRI segmentation: upload one or more MRI images and view each
# original next to its predicted tumor mask in a gallery.
import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>MRI Brain Tumor Segmentation App</h1>")
    with gr.Column():
        with gr.Column():
            # With type="filepath", Gradio hands the handler file paths.
            image_files = gr.Files(
                label="Upload MRI files",
                file_count="multiple",
                type="filepath",
            )
        with gr.Row():
            gallery = gr.Gallery(label="Brain Images and Tumor Segmentation")
    # Recompute the gallery whenever the uploaded file set changes.
    image_files.change(update_gallery, inputs=image_files, outputs=gallery)

demo.launch()