Spaces:
Sleeping
Sleeping
| import torch | |
| from torchvision import transforms | |
| from torch.utils.data import DataLoader, Dataset | |
| from transformers import MaskFormerForInstanceSegmentation, MaskFormerImageProcessor | |
| from PIL import Image | |
| from pathlib import Path | |
# Pick the execution device once: the first CUDA GPU if torch can see one,
# otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned MaskFormer weights and the matching image processor from
# the Hugging Face Hub repo "elliemci/maskformer_tumor_segmentation".
# nn.Module.to() returns the module itself, so the move can be chained.
model = MaskFormerForInstanceSegmentation.from_pretrained(
    "elliemci/maskformer_tumor_segmentation"
).to(device)
image_processor = MaskFormerImageProcessor.from_pretrained(
    "elliemci/maskformer_tumor_segmentation"
)
# Define a custom dataset class to handle images
class ImageDataset(Dataset):
    """Map-style dataset that loads each entry of ``image_paths`` as RGB.

    Args:
        image_paths: sequence of paths (or file objects) accepted by
            ``PIL.Image.open``.
        transform: optional callable applied to each RGB ``PIL.Image``
            (e.g. ``transforms.ToTensor()``); when falsy the PIL image is
            returned unchanged.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Use a context manager so the source file is always closed; Pillow
        # may otherwise keep the handle open (e.g. for multi-frame formats).
        # convert() returns a converted copy, so it stays valid after closing.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image
def segment(image_files):
    """Run the MaskFormer model on each image and return its semantic mask.

    Args:
        image_files: sequence of image paths (or file objects) that
            ``PIL.Image.open`` accepts — with ``gr.Files(type="filepath")``
            these are file-path strings.

    Returns:
        A list with one entry per input, in input order, as produced by
        ``image_processor.post_process_semantic_segmentation``: a 2-D
        tensor of per-pixel class ids sized to that image's (height, width).
        Returns ``[]`` for an empty or ``None`` input.
    """
    if not image_files:
        # DataLoader rejects batch_size=0, so an empty upload would raise.
        return []

    dataset = ImageDataset(image_files, transform=transforms.ToTensor())
    # One image per forward pass: the default collate function can only stack
    # tensors of identical shape, so batching all uploads together fails as
    # soon as two images differ in size.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    predicted_masks = []
    with torch.no_grad():
        for batch in dataloader:
            pixel_values = batch.to(device, dtype=torch.float32)
            outputs = model(pixel_values=pixel_values)
            # Resize each prediction back to its input's spatial size.
            height, width = pixel_values.shape[-2:]
            target_sizes = [(height, width)] * pixel_values.shape[0]
            predicted_masks.extend(
                image_processor.post_process_semantic_segmentation(
                    outputs, target_sizes=target_sizes
                )
            )
    return predicted_masks
# Components for the Gradio interface
def update_gallery(images):
    """Build ``gr.Gallery`` entries pairing each upload with its predicted mask.

    Args:
        images: value of the ``gr.Files(type="filepath")`` component, i.e. a
            list of file paths; may be ``None`` or empty (e.g. after the
            upload is cleared — NOTE(review): confirm for the Gradio version).

    Returns:
        A flat list of ``(image, caption)`` tuples: for every input, the
        original file captioned "Original Image" followed by its mask as a
        PIL image captioned "Segmented Image". ``[]`` when nothing was uploaded.
    """
    # Guard first: the original indexed images[0] before checking, which
    # raised on None / empty input.
    if not images:
        return []
    print(f"Type in update_gallery: {type(images[0])}")

    segmented_images = segment(images)
    to_pil = transforms.ToPILImage()  # stateless; build once, not per image
    gallery_data = []
    for image, mask in zip(images, segmented_images):
        # ToPILImage converts through NumPy, so bring the mask to the CPU
        # rather than to the (possibly CUDA) inference device.
        mask_pil = to_pil(mask.to("cpu", dtype=torch.float32))
        gallery_data.extend([(image, "Original Image"), (mask_pil, "Segmented Image")])
    return gallery_data
# Gradio UI for MRI segmentation
import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>MRI Brain Tumor Segmentation App</h1>")
    with gr.Column():
        with gr.Column():
            image_files = gr.Files(label="Upload MRI files",
                                   file_count="multiple",
                                   type="filepath")
        with gr.Row():
            gallery = gr.Gallery(label="Brain Images and Tumor Segmentation")
    # Re-segment whenever the set of uploaded files changes.
    image_files.change(
        fn=update_gallery,
        inputs=[image_files],
        outputs=[gallery])

# Launch only when executed as a script, so importing this module (tests,
# tooling) builds `demo` without starting a web server.
if __name__ == "__main__":
    demo.launch()