Spaces:
Sleeping
Sleeping
File size: 3,135 Bytes
0ebdb2c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from transformers import MaskFormerForInstanceSegmentation, MaskFormerImageProcessor
from PIL import Image
from pathlib import Path
# Fine-tuned MaskFormer weights and the matching image-processor config are
# loaded (and cached) from the Hugging Face Hub repository below.
_CHECKPOINT = "elliemci/maskformer_tumor_segmentation"
model = MaskFormerForInstanceSegmentation.from_pretrained(_CHECKPOINT)
image_processor = MaskFormerImageProcessor.from_pretrained(_CHECKPOINT)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# This app only runs inference; from_pretrained already returns the model in
# eval mode, and calling eval() here keeps that intent explicit.
model.eval()
# Map-style dataset turning a list of image file paths into RGB images.
class ImageDataset(Dataset):
    """Load images from file paths on demand.

    Args:
        image_paths: Sequence of paths accepted by ``PIL.Image.open``.
        transform: Optional callable applied to each RGB ``PIL.Image``
            (``segment`` passes ``transforms.ToTensor()``).
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        """Return the number of image paths."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return image ``idx`` as RGB, transformed if a transform was given."""
        # Open in a context manager so the file handle is released promptly;
        # convert() loads the pixel data into a new image before the file closes.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image
def segment(image_files):
    """Run the MaskFormer tumor model on a list of image files.

    Args:
        image_files: Sequence of image file paths (this app's ``gr.Files``
            component is configured with ``type="filepath"``).

    Returns:
        A list with one entry per input, in input order: the semantic
        segmentation map produced by
        ``image_processor.post_process_semantic_segmentation``, sized to the
        corresponding input image. An empty input yields an empty list.
    """
    if not image_files:
        # Nothing to segment; also avoids building a zero-sized batch.
        return []

    dataset = ImageDataset(image_files, transform=transforms.ToTensor())

    # Bucket images by tensor shape: torch.stack (used to form a batch)
    # requires every tensor in the batch to have the same size, and uploaded
    # scans may differ. Each bucket is then run as one batch.
    buckets = {}
    for index in range(len(dataset)):
        tensor = dataset[index]
        buckets.setdefault(tuple(tensor.shape), []).append((index, tensor))

    predicted_masks = [None] * len(dataset)
    with torch.no_grad():
        for members in buckets.values():
            indices, tensors = zip(*members)
            # NOTE(review): inputs are raw ToTensor() values in [0, 1]; the
            # image_processor's resize/normalize step is not applied here.
            # Confirm this matches how the checkpoint was fine-tuned.
            pixel_values = torch.stack(tensors).to(device, dtype=torch.float32)
            outputs = model(pixel_values=pixel_values)
            # Upsample each predicted map back to its source image's (H, W).
            target_sizes = [(t.shape[-2], t.shape[-1]) for t in tensors]
            masks = image_processor.post_process_semantic_segmentation(
                outputs, target_sizes=target_sizes
            )
            # Scatter results back to their original positions.
            for index, mask in zip(indices, masks):
                predicted_masks[index] = mask
    return predicted_masks
# Callback for the Gradio interface
def update_gallery(images):
    """Build Gallery items pairing each uploaded image with its segmentation.

    Args:
        images: List of uploaded image file paths from the ``gr.Files``
            component, or ``None``/empty when there is no upload.

    Returns:
        A flat list of ``(image, caption)`` tuples alternating
        ``"Original Image"`` and ``"Segmented Image"`` entries; an empty
        list when there is nothing to show.
    """
    # Clearing the upload can trigger .change with None or an empty list;
    # check before touching any element.
    if not images:
        return []
    segmented_images = segment(images)  # Process all images in one call
    to_pil = transforms.ToPILImage()
    gallery_data = []
    for image, mask in zip(images, segmented_images):
        # float32 input: ToPILImage scales floating-point values by 255 into
        # 8-bit. PIL needs host memory, so convert on the CPU.
        mask_pil = to_pil(mask.to("cpu", dtype=torch.float32))
        gallery_data.extend(
            [(image, "Original Image"), (mask_pil, "Segmented Image")]
        )
    return gallery_data
# Gradio UI for MRI segmentation
import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>MRI Brain Tumor Segmentation App</h1>")
    with gr.Column():
        # Upload area: several files at once, handed to callbacks as file paths.
        with gr.Column():
            image_files = gr.Files(
                label="Upload MRI files",
                file_count="multiple",
                type="filepath",
            )
        # Results area: update_gallery fills this with (image, caption) pairs.
        with gr.Row():
            gallery = gr.Gallery(label="Brain Images and Tumor Segmentation")
        # Re-run segmentation whenever the set of uploaded files changes.
        image_files.change(
            fn=update_gallery,
            inputs=[image_files],
            outputs=[gallery],
        )

demo.launch()