"""Gradio app that segments brain tumors in uploaded MRI images using a
fine-tuned MaskFormer model."""

from pathlib import Path

import gradio as gr
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from transformers import MaskFormerForInstanceSegmentation, MaskFormerImageProcessor

# Model weights and the matching image-processor config are loaded from the
# Hugging Face Hub repo "elliemci/maskformer_tumor_segmentation".
model = MaskFormerForInstanceSegmentation.from_pretrained("elliemci/maskformer_tumor_segmentation")
image_processor = MaskFormerImageProcessor.from_pretrained("elliemci/maskformer_tumor_segmentation")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


class ImageDataset(Dataset):
    """Dataset that loads image files as RGB and optionally transforms them.

    Args:
        image_paths: sequence of paths accepted by ``PIL.Image.open``.
        transform: optional callable applied to each RGB ``PIL.Image``.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Close the file handle right away; convert() returns a new,
        # fully loaded image that no longer depends on the open file.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image


def segment(image_files):
    """Run the segmentation model on each image file.

    Args:
        image_files: sequence of image file paths (the Gradio Files input
            below is configured with ``type="filepath"``).

    Returns:
        List of per-image label maps from
        ``image_processor.post_process_semantic_segmentation``, in input
        order, each resized to its source image's (height, width).
        Returns an empty list when ``image_files`` is empty.
    """
    dataset = ImageDataset(image_files, transform=transforms.ToTensor())
    # One image per batch: the default collate_fn stacks tensors and fails
    # when uploads have different dimensions, and batch_size=len(...) raised
    # ValueError for an empty list.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    predicted_masks = []
    with torch.no_grad():
        for batch in dataloader:
            pixel_values = batch.to(device, dtype=torch.float32)
            outputs = model(pixel_values=pixel_values)
            # Resize predictions back to each original image's (H, W).
            target_sizes = [(image.shape[-2], image.shape[-1]) for image in batch]
            predicted_masks.extend(
                image_processor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes)
            )
    return predicted_masks


# Components for the Gradio interface
def update_gallery(images):
    """Build Gallery items pairing each uploaded image with its mask.

    Args:
        images: list of uploaded file paths; may be None or empty when the
            upload widget has no files.

    Returns:
        List of ``(image, caption)`` tuples: for every upload, the original
        followed by its segmentation mask. Empty list when nothing is uploaded.
    """
    if not images:
        # Guard before touching images[0]; the previous unconditional
        # lookup crashed on an empty/None upload.
        return []
    print(f"Type in update_gallery: {type(images[0])}")

    to_pil = transforms.ToPILImage()
    gallery_data = []
    segmented_images = segment(images)  # process all uploads in one call
    for image, mask in zip(images, segmented_images):
        # ToPILImage copies the tensor to CPU itself, so there is no reason to
        # send it to `device` first. A float tensor is scaled by 255, so label
        # 1 renders white (assumes binary background/tumor labels - TODO confirm).
        segmented_image_pil = to_pil(mask.to("cpu", dtype=torch.float32))
        gallery_data.extend([(image, "Original Image"), (segmented_image_pil, "Segmented Image")])
    return gallery_data


# Gradio UI for MRI segmentation
with gr.Blocks() as demo:
    gr.Markdown("# MRI Brain Tumor Segmentation App")
    with gr.Column():
        with gr.Column():
            image_files = gr.Files(label="Upload MRI files", file_count="multiple", type="filepath")
        with gr.Row():
            gallery = gr.Gallery(label="Brain Images and Tumor Segmentation")
    image_files.change(fn=update_gallery, inputs=[image_files], outputs=[gallery])

if __name__ == "__main__":
    demo.launch()