File size: 3,135 Bytes
0ebdb2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch

from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from transformers import MaskFormerForInstanceSegmentation, MaskFormerImageProcessor
from PIL import Image
from pathlib import Path

# Load the fine-tuned MaskFormer checkpoint and its matching image processor
# from the Hugging Face Hub repo "elliemci/maskformer_tumor_segmentation".
model = MaskFormerForInstanceSegmentation.from_pretrained("elliemci/maskformer_tumor_segmentation")
image_processor = MaskFormerImageProcessor.from_pretrained("elliemci/maskformer_tumor_segmentation")

# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)

# Define a custom dataset class to handle images
class ImageDataset(Dataset):
    """Map-style dataset that loads images from file paths and converts them to RGB.

    Args:
        image_paths: sequence of paths (anything ``PIL.Image.open`` accepts).
        transform: optional callable applied to each loaded RGB ``PIL.Image``
            (e.g. ``transforms.ToTensor()``). When ``None``, the PIL image is
            returned unchanged.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Open inside a context manager so the file handle is always released.
        # PIL keeps multi-frame formats (e.g. TIFF) open after loading otherwise.
        # convert() returns a fully loaded in-memory copy, so it outlives the file.
        with Image.open(self.image_paths[idx]) as source:
            image = source.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)
        return image

def segment(image_files):
    """Run the tumor-segmentation model on a list of image files.

    Args:
        image_files: list of image file paths (anything ``PIL.Image.open``
            accepts), such as the filepaths produced by the Gradio ``Files``
            input.

    Returns:
        A list with one entry per input, in input order: the mask produced by
        ``image_processor.post_process_semantic_segmentation`` at that image's
        original (height, width). An empty input yields an empty list.
    """
    # Nothing to segment; also avoids building an empty batch, which torch rejects.
    if not image_files:
        return []

    dataset = ImageDataset(image_files, transform=transforms.ToTensor())
    tensors = [dataset[i] for i in range(len(dataset))]

    # Tensors can only be stacked into one batch when their spatial sizes match.
    # Group input positions by (height, width), so same-size uploads still run
    # as a single batch and mixed sizes no longer crash the collate step.
    groups = {}
    for position, tensor in enumerate(tensors):
        groups.setdefault(tuple(tensor.shape[-2:]), []).append(position)

    masks = [None] * len(tensors)
    # NOTE(review): pixels reach the model as raw ToTensor() values in [0, 1],
    # without image_processor resizing/normalization. Confirm this matches the
    # preprocessing used when the checkpoint was trained.
    with torch.no_grad():
        for size, positions in groups.items():
            pixel_values = torch.stack([tensors[p] for p in positions]).to(device, dtype=torch.float32)
            outputs = model(pixel_values=pixel_values)

            # Resize each predicted mask back to its image's original size.
            group_masks = image_processor.post_process_semantic_segmentation(
                outputs, target_sizes=[size] * len(positions)
            )
            # Scatter this group's results back to their original input positions.
            for position, mask in zip(positions, group_masks):
                masks[position] = mask

    return masks

# components for Gradio interface
def update_gallery(images):
    """Build Gradio gallery entries pairing each upload with its segmentation.

    Args:
        images: list of uploaded image files from the ``gr.Files`` input, or
            ``None``/empty when the uploads are cleared.

    Returns:
        A flat list of ``(image, caption)`` tuples. For each input it holds the
        original image, then its segmentation mask converted to a PIL image.
        The list is empty when there are no uploads.
    """
    # Gradio can fire `.change` with None (or an empty list) when files are
    # cleared. Indexing images[0] before this guard used to crash in that case.
    if not images:
        return []

    print(f"Type in update_gallery: {type(images[0])}")
    segmented_images = segment(images)  # Segment every upload in one call

    to_pil = transforms.ToPILImage()
    gallery_data = []
    for image, mask in zip(images, segmented_images):
        # ToPILImage converts through NumPy on the CPU, so move the mask there
        # rather than to `device`. The float cast keeps the original pixel scaling.
        segmented_image_pil = to_pil(mask.to("cpu", dtype=torch.float32))
        gallery_data.extend([(image, "Original Image"), (segmented_image_pil, "Segmented Image")])

    return gallery_data

# Gradio UI for MRI segmentation
import gradio as gr

# Page layout: a title, a multi-file upload, and a gallery that is rebuilt
# (via update_gallery) every time the set of uploaded files changes.
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>MRI Brain Tumor Segmentation App</h1>")

    with gr.Column():
        with gr.Column():
            image_files = gr.Files(
                label="Upload MRI files",
                file_count="multiple",
                type="filepath",
            )

            with gr.Row():
                gallery = gr.Gallery(label="Brain Images and Tumor Segmentation")

                # Re-render the gallery whenever uploads are added or removed.
                image_files.change(
                    fn=update_gallery,
                    inputs=[image_files],
                    outputs=[gallery],
                )

demo.launch()