elliemci commited on
Commit
0ebdb2c
·
verified ·
1 Parent(s): 3b4a878

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torchvision import transforms
4
+ from torch.utils.data import DataLoader, Dataset
5
+ from transformers import MaskFormerForInstanceSegmentation, MaskFormerImageProcessor
6
+ from PIL import Image
7
+ from pathlib import Path
8
+
9
+ # can upload from Huggingface Space "elliemci/maskformer_tumor_segmentation"
10
+ model = MaskFormerForInstanceSegmentation.from_pretrained("elliemci/maskformer_tumor_segmentation")
11
+ image_processor = MaskFormerImageProcessor.from_pretrained("elliemci/maskformer_tumor_segmentation")
12
+
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
+ model.to(device)
15
+
16
+ # Define a custom dataset class to handle images
17
+ class ImageDataset(Dataset):
18
+ def __init__(self, image_paths, transform=None):
19
+ self.image_paths = image_paths
20
+ self.transform = transform
21
+
22
+ def __len__(self):
23
+ return len(self.image_paths)
24
+
25
+ def __getitem__(self, idx):
26
+ image = Image.open(self.image_paths[idx]).convert('RGB')
27
+
28
+ if self.transform:
29
+ image = self.transform(image)
30
+ return image
31
+
32
+ def segment(image_files):
33
+ """Takes a list of UploadedFile objects and returns a list of segmented images."""
34
+
35
+ dataset = ImageDataset(image_files, transform=transforms.ToTensor())
36
+ dataloader = DataLoader(dataset, batch_size=len(image_files), shuffle=False) # Batch size is the number of images
37
+
38
+ # process a batch
39
+ with torch.no_grad():
40
+ for batch in dataloader: # Only one iteration since batch_size = len(image_files)
41
+ pixel_values = batch.to(device, dtype=torch.float32)
42
+ outputs = model(pixel_values=pixel_values)
43
+
44
+ # Post-processing
45
+ original_images = outputs.get("org_images", batch)
46
+ target_sizes = [(image.shape[-2], image.shape[-1]) for image in original_images]
47
+ predicted_masks = image_processor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes)
48
+
49
+ return predicted_masks # Return the list of segmented images
50
+
51
+ # components for Gradion interface
52
+ def update_gallery(images):
53
+ print(f"Type in update_gallery: {type(images[0])}")
54
+ gallery_data = []
55
+
56
+ if images:
57
+ segmented_images = segment(images) # Process all images in one batch
58
+
59
+ for i, image in enumerate(images):
60
+ segmented_image_pil = transforms.ToPILImage()(segmented_images[i].to(device, dtype=torch.float32))
61
+ gallery_data.extend([(image, "Original Image"), (segmented_image_pil, "Segmented Image")])
62
+
63
+ return gallery_data
64
+
65
+ # Gradio UI for MEI segmentation
66
+ import gradio as gr
67
+
68
+ with gr.Blocks() as demo:
69
+ gr.Markdown("<h1 style='text-align: center;'>MRI Brain Tumor Segmentation App</h1>")
70
+
71
+ with gr.Column():
72
+ with gr.Column():
73
+ image_files = gr.Files(label="Upload MRI files",
74
+ file_count="multiple",
75
+ type="filepath")
76
+ with gr.Row():
77
+ gallery = gr.Gallery(label="Brain Images and Tumor Segmentation")
78
+
79
+ image_files.change(
80
+ fn=update_gallery,
81
+ inputs=[image_files],
82
+ outputs=[gallery])
83
+
84
+ demo.launch()