maryann-gitonga committed on
Commit
8a142b1
·
1 Parent(s): 662aaa7

Delete 3D_Brain_Tumor_Segmentation_Attention_UNet.ipynb

3D_Brain_Tumor_Segmentation_Attention_UNet.ipynb DELETED
@@ -1,1105 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {
6
- "id": "TdEse3Kwq3JD"
7
- },
8
- "source": [
9
- "# Import Necessary Libraries"
10
- ]
11
- },
12
- {
13
- "cell_type": "code",
14
- "execution_count": null,
15
- "metadata": {
16
- "id": "WRKzuv_5owuz"
17
- },
18
- "outputs": [],
19
- "source": [
20
- "import numpy as np\n",
21
- "import nibabel as nib\n",
22
- "import glob\n",
23
- "from tensorflow.keras.utils import to_categorical # multiclass semantic segmentation, therefore the volumes to categorical\n",
24
- "import matplotlib.pyplot as plt\n",
25
- "from tifffile import imsave\n",
26
- "from sklearn.preprocessing import MinMaxScaler #scale values\n",
27
- "import tensorflow as tf\n",
28
- "import random\n",
29
- "import os.path\n",
30
- "!pip install split-folders\n",
31
- "!pip3 install -U segmentation-models-3D\n",
32
- "import splitfolders\n",
33
- "!pip install -q -U keras-tuner"
34
- ]
35
- },
36
- {
37
- "cell_type": "code",
38
- "execution_count": null,
39
- "metadata": {
40
- "id": "vEtRg2vutWru"
41
- },
42
- "outputs": [],
43
- "source": [
44
- "# To always ensure that the GPU is available\n",
45
- "import tensorflow as tf\n",
46
- "device_name = tf.test.gpu_device_name()\n",
47
- "if device_name != '/device:GPU:0':\n",
48
- " raise SystemError('GPU device not found')\n",
49
- "print('Found GPU at: {}'.format(device_name))"
50
- ]
51
- },
52
- {
53
- "cell_type": "markdown",
54
- "metadata": {
55
- "id": "L5yBxROtvDAI"
56
- },
57
- "source": [
58
- "# Define the MinMax Scaler + Mount Drive to access Dataset\n",
59
- "\n",
60
- "* The MinMax scaler is necessary for transforming the scans' features to a range between 0 and 1"
61
- ]
62
- },
63
- {
64
- "cell_type": "code",
65
- "execution_count": null,
66
- "metadata": {
67
- "id": "sqMRiba8q-30"
68
- },
69
- "outputs": [],
70
- "source": [
71
- "scaler = MinMaxScaler()\n",
72
- "\n",
73
- "from google.colab import drive\n",
74
- "drive.mount('/content/drive')"
75
- ]
76
- },
77
- {
78
- "cell_type": "markdown",
79
- "metadata": {
80
- "id": "XH4_Z5f2sfxZ"
81
- },
82
- "source": [
83
- "# Load sample images and visualize\n",
84
- "\n"
85
- ]
86
- },
87
- {
88
- "cell_type": "code",
89
- "execution_count": null,
90
- "metadata": {
91
- "id": "SvfI9iTrrZuN"
92
- },
93
- "outputs": [],
94
- "source": [
95
- "DATASET_PATH = ''\n",
96
- "\n",
97
- "test_image_flair = nib.load(DATASET_PATH + 'flair.nii').get_fdata()\n",
98
- "print(test_image_flair[156][98][78])\n",
99
- "test_image_flair = scaler.fit_transform(test_image_flair.reshape(-1, test_image_flair.shape[-1])).reshape(test_image_flair.shape)\n",
100
- "print(test_image_flair[156][98][78])\n",
101
- "\n",
102
- "test_image_t1 = nib.load(DATASET_PATH + 't1.nii').get_fdata()\n",
103
- "test_image_t1 = scaler.fit_transform(test_image_t1.reshape(-1, test_image_t1.shape[-1])).reshape(test_image_t1.shape)\n",
104
- "\n",
105
- "test_image_t1ce = nib.load(DATASET_PATH + 't1ce.nii').get_fdata()\n",
106
- "test_image_t1ce = scaler.fit_transform(test_image_t1ce.reshape(-1, test_image_t1ce.shape[-1])).reshape(test_image_t1ce.shape)\n",
107
- "\n",
108
- "test_image_t2 = nib.load(DATASET_PATH + 't2.nii').get_fdata()\n",
109
- "test_image_t2 = scaler.fit_transform(test_image_t2.reshape(-1, test_image_t2.shape[-1])).reshape(test_image_t2.shape)\n",
110
- "\n",
111
- "test_mask = nib.load(DATASET_PATH + 'seg.nii').get_fdata()\n",
112
- "test_mask = test_mask.astype(np.uint8)\n",
113
- "\n",
114
- "print(np.unique(test_mask))\n",
115
- "# Reassign label value 4 to 3\n",
116
- "test_mask[test_mask==4] = 3\n",
117
- "print(np.unique(test_mask))"
118
- ]
119
- },
120
- {
121
- "cell_type": "code",
122
- "execution_count": null,
123
- "metadata": {
124
- "id": "aTkjA-mgwecE"
125
- },
126
- "outputs": [],
127
- "source": [
128
- "n_slice = random.randint(0, test_mask.shape[2])\n",
129
- "\n",
130
- "plt.figure(figsize=(12,8))\n",
131
- "plt.subplot(231)\n",
132
- "plt.imshow(test_image_flair[:, :, n_slice], cmap='gray')\n",
133
- "plt.title('Flair Scan')\n",
134
- "\n",
135
- "plt.subplot(232)\n",
136
- "plt.imshow(test_image_t1[:, :, n_slice], cmap='gray')\n",
137
- "plt.title('T1 Scan')\n",
138
- "\n",
139
- "plt.subplot(233)\n",
140
- "plt.imshow(test_image_t1ce[:, :, n_slice], cmap='gray')\n",
141
- "plt.title('T1ce Scan')\n",
142
- "\n",
143
- "plt.subplot(234)\n",
144
- "plt.imshow(test_image_t2[:, :, n_slice], cmap='gray')\n",
145
- "plt.title('T2 Scan')\n",
146
- "\n",
147
- "plt.subplot(235)\n",
148
- "plt.imshow(test_mask[:, :, n_slice])\n",
149
- "plt.title('Mask')\n",
150
- "\n",
151
- "plt.show()\n",
152
- "\n"
153
- ]
154
- },
155
- {
156
- "cell_type": "markdown",
157
- "metadata": {
158
- "id": "EORoZoj7yPfW"
159
- },
160
- "source": [
161
- "# Data Processing: Combining the volumes of scans to one + Cropping the scans and masks\n",
162
- "\n",
163
- "* The numpy array is reshaped to 2D, the dimensions the scaler can take as input, the array is transformed and then reshaped back to 3D\n",
164
- "* Result: the feature at position [156][98][78] of the loaded FLAIR scan numpy array is transformed from 1920.0 to 0.7683...\n",
165
- "* The three scans to be used are stacked together to forme a combined scan.\n",
166
- "* Result: A FLAIR scan, a T1CE scan and a T2 scan, all of dimensions 255 x 255 x 155 are stacked to form a combined scan of dimensions 255 x 255 x 155 x 3\n",
167
- "* The combined scan is cropped to 128 x 128 x 128 x 3\n",
168
- "* Label 4 in the dataset is reassigned to label 3 resulting to a continuous list of labels: 0, 1, 2, 3"
169
- ]
170
- },
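The reshape-scale-reshape idiom from the first bullet, shown in isolation (a minimal sketch; `vol` is a hypothetical stand-in for one loaded modality, not a variable from the notebook):

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Hypothetical stand-in for a loaded scan; any 3D array works the same way.
vol = np.random.rand(240, 240, 155) * 2000.0

scaler = MinMaxScaler()
# Flatten to 2D (n_samples, n_features), scale to [0, 1], restore the 3D shape.
scaled = scaler.fit_transform(vol.reshape(-1, vol.shape[-1])).reshape(vol.shape)

print(vol.min(), vol.max())        # original intensity range
print(scaled.min(), scaled.max())  # approximately 0.0 and 1.0
```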
171
- {
172
- "cell_type": "code",
173
- "execution_count": null,
174
- "metadata": {
175
- "id": "-3u91yIqybn-"
176
- },
177
- "outputs": [],
178
- "source": [
179
- "combined_x = np.stack([test_image_flair, test_image_t1ce, test_image_t2], axis=3)\n",
180
- "combined_x = combined_x[56:184, 56:184, 13:141] #crop to 128 x 128 x 128 X 3\n",
181
- "\n",
182
- "test_mask = test_mask[56:184, 56:184, 13:141]\n",
183
- "n_slice = random.randint(0, test_mask.shape[1])\n",
184
- "plt.figure(figsize=(12, 8))\n",
185
- "\n",
186
- "plt.subplot(231)\n",
187
- "plt.imshow(combined_x[:, :, n_slice, 0], cmap='gray')\n",
188
- "plt.title('Flair Scan')\n",
189
- "\n",
190
- "plt.subplot(232)\n",
191
- "plt.imshow(combined_x[:, :, n_slice, 1], cmap='gray')\n",
192
- "plt.title('T1ce Scan')\n",
193
- "\n",
194
- "plt.subplot(233)\n",
195
- "plt.imshow(combined_x[:, :, n_slice, 2], cmap='gray')\n",
196
- "plt.title('T2 Scan')\n",
197
- "\n",
198
- "plt.subplot(234)\n",
199
- "plt.imshow(test_mask[:, :, n_slice])\n",
200
- "plt.title('Mask')\n",
201
- "\n",
202
- "plt.show()"
203
- ]
204
- },
205
- {
206
- "cell_type": "code",
207
- "execution_count": null,
208
- "metadata": {
209
- "id": "T8r7sy4QND41"
210
- },
211
- "outputs": [],
212
- "source": [
213
- "from tensorflow.keras import backend as K\n",
214
- "\n",
215
- "print(K.int_shape(test_image_flair))\n",
216
- "\n",
217
- "print(K.int_shape(combined_x))"
218
- ]
219
- },
220
- {
221
- "cell_type": "code",
222
- "execution_count": null,
223
- "metadata": {
224
- "id": "WeD_PqCv6Vww"
225
- },
226
- "outputs": [],
227
- "source": [
228
- "flair_list = sorted(glob.glob(DATASET_PATH + '*/flair.nii'))\n",
229
- "t1_list = sorted(glob.glob(DATASET_PATH + '*/t1.nii'))\n",
230
- "t1ce_list = sorted(glob.glob(DATASET_PATH + '*/t1ce.nii'))\n",
231
- "t2_list = sorted(glob.glob(DATASET_PATH + '*/t2.nii'))\n",
232
- "mask_list = sorted(glob.glob(DATASET_PATH + '*/seg.nii'))\n",
233
- "\n",
234
- "\n",
235
- "for img in range(len(flair_list)):\n",
236
- " print('Now processing image and masks no: ', img)\n",
237
- "\n",
238
- " temp_image_flair = nib.load(flair_list[img]).get_fdata()\n",
239
- " temp_image_flair = scaler.fit_transform(temp_image_flair.reshape(-1, temp_image_flair.shape[-1])).reshape(temp_image_flair.shape)\n",
240
- "\n",
241
- " temp_image_t1 = nib.load(t1_list[img]).get_fdata()\n",
242
- " temp_image_t1 = scaler.fit_transform(temp_image_t1.reshape(-1, temp_image_t1.shape[-1])).reshape(temp_image_t1.shape)\n",
243
- "\n",
244
- " temp_image_t1ce = nib.load(t1ce_list[img]).get_fdata()\n",
245
- " temp_image_t1ce = scaler.fit_transform(temp_image_t1ce.reshape(-1, temp_image_t1ce.shape[-1])).reshape(temp_image_t1ce.shape)\n",
246
- "\n",
247
- " temp_image_t2 = nib.load(t2_list[img]).get_fdata()\n",
248
- " temp_image_t2 = scaler.fit_transform(temp_image_t2.reshape(-1, temp_image_t2.shape[-1])).reshape(temp_image_t2.shape)\n",
249
- "\n",
250
- " temp_mask = nib.load(mask_list[img]).get_fdata()\n",
251
- " temp_mask = temp_mask.astype(np.uint8)\n",
252
- " temp_mask[temp_mask == 4] = 3\n",
253
- "\n",
254
- " temp_combined_images = np.stack([temp_image_flair, temp_image_t1, temp_image_t1ce, temp_image_t2], axis = 3)\n",
255
- " temp_combined_images = temp_combined_images[56:184, 56:184, 13:141]\n",
256
- " temp_mask = temp_mask[56:184, 56:184, 13:141]\n",
257
- "\n",
258
- " val, counts = np.unique(temp_mask, return_counts=True)\n",
259
- "\n",
260
- " if(1 - (counts[0]/counts.sum())) > 0.01:\n",
261
- " temp_mask = to_categorical(temp_mask, num_classes=4)\n",
262
- " np.save(DATASET_PATH + 'final_dataset/scans/image_' + str(img) + '.npy', temp_combined_images)\n",
263
- " np.save(DATASET_PATH + 'final_dataset/masks/image_' + str(img) + '.npy', temp_mask)\n",
264
- " print(\"Saved\")\n",
265
- " else:\n",
266
- " print(\"Not saved\")"
267
- ]
268
- },
269
- {
270
- "cell_type": "markdown",
271
- "metadata": {
272
- "id": "-wICUx56ugDz"
273
- },
274
- "source": [
275
- "# Dataset Splitting: 60:20:20 for train, val and test"
276
- ]
277
- },
278
- {
279
- "cell_type": "code",
280
- "execution_count": null,
281
- "metadata": {
282
- "id": "Oi_g5D01HSnq"
283
- },
284
- "outputs": [],
285
- "source": [
286
- "input_folder = DATASET_PATH + 'final_dataset/'\n",
287
- "output_folder = DATASET_PATH + 'split_dataset/'\n",
288
- "splitfolders.ratio(input_folder, output=output_folder, seed=42, ratio=(.6, .2, .2), group_prefix=None)"
289
- ]
290
- },
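Because `final_dataset/` holds the two class folders `scans/` and `masks/`, `splitfolders.ratio` should mirror that structure under each split, which is the layout the generator paths below assume:

```
split_dataset/
├── train/
│   ├── scans/
│   └── masks/
├── val/
│   ├── scans/
│   └── masks/
└── test/
    ├── scans/
    └── masks/
```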
291
- {
292
- "cell_type": "markdown",
293
- "metadata": {
294
- "id": "RtaRf0B4kPkM"
295
- },
296
- "source": [
297
- "# Data Generator\n",
298
- "\n",
299
- "\n"
300
- ]
301
- },
302
- {
303
- "cell_type": "code",
304
- "execution_count": null,
305
- "metadata": {
306
- "id": "UMfHysy2ixc8"
307
- },
308
- "outputs": [],
309
- "source": [
310
- "import os\n",
311
- "import numpy as np\n",
312
- "\n",
313
- "def load_img(img_dir, img_list):\n",
314
- " images=[]\n",
315
- " for i, image_name in enumerate(img_list):\n",
316
- " if(image_name.split('.')[1] == 'npy'):\n",
317
- " image = np.load(img_dir + image_name)\n",
318
- " images.append(image)\n",
319
- " images = np.array(images)\n",
320
- " return images\n",
321
- "\n",
322
- "def imageLoader(img_dir, img_list, mask_dir, mask_list, batch_size):\n",
323
- " L = len(img_list)\n",
324
- " # keras needs the generator infinite, so use while True\n",
325
- " while True:\n",
326
- " batch_start = 0\n",
327
- " batch_end = batch_size\n",
328
- "\n",
329
- " while batch_start < L:\n",
330
- " limit = min(batch_end, L)\n",
331
- " X = load_img(img_dir, img_list[batch_start:limit])\n",
332
- " Y = load_img(mask_dir, mask_list[batch_start:limit])\n",
333
- "\n",
334
- " yield(X, Y) # a tuple with two numpy arrays with batch_size samples\n",
335
- "\n",
336
- " batch_start += batch_size\n",
337
- " batch_end += batch_size\n",
338
- "\n",
339
- "\n",
340
- "# Test the generator\n",
341
- "TRAIN_DATASET_PATH = ''\n",
342
- "train_img_dir = TRAIN_DATASET_PATH + 'scans/'\n",
343
- "train_mask_dir = TRAIN_DATASET_PATH + 'masks/'\n",
344
- "\n",
345
- "train_img_list = os.listdir(train_img_dir)\n",
346
- "train_mask_list = os.listdir(train_mask_dir)\n",
347
- "\n",
348
- "batch_size = 2\n",
349
- "\n",
350
- "train_img_datagen = imageLoader(train_img_dir, train_img_list,\n",
351
- " train_mask_dir, train_mask_list, batch_size)\n",
352
- "\n",
353
- "# Verify generator - In python 3 next() is renamed as __next__()\n",
354
- "img, msk = train_img_datagen.__next__()\n",
355
- "\n",
356
- "img_num = random.randint(0, img.shape[0]-1)\n",
357
- "\n",
358
- "test_img = img[img_num]\n",
359
- "test_mask = msk[img_num]\n",
360
- "test_mask = np.argmax(test_mask, axis=3)\n",
361
- "\n",
362
- "n_slice = random.randint(0, test_mask.shape[2])\n",
363
- "plt.figure(figsize=(12,8))\n",
364
- "\n",
365
- "plt.subplot(221)\n",
366
- "plt.imshow(test_img[:, :, n_slice, 0], cmap='gray')\n",
367
- "plt.title('Flair Scan')\n",
368
- "\n",
369
- "plt.subplot(222)\n",
370
- "plt.imshow(test_img[:, :, n_slice, 1], cmap='gray')\n",
371
- "plt.title('T1ce Scan')\n",
372
- "\n",
373
- "plt.subplot(223)\n",
374
- "plt.imshow(test_img[:, :, n_slice, 2], cmap='gray')\n",
375
- "plt.title('T2 Scan')\n",
376
- "\n",
377
- "plt.subplot(224)\n",
378
- "plt.imshow(test_mask[:, :, n_slice])\n",
379
- "plt.title('Mask')\n",
380
- "\n",
381
- "plt.show()"
382
- ]
383
- },
384
- {
385
- "cell_type": "markdown",
386
- "metadata": {
387
- "id": "ReTmFPr0QV17"
388
- },
389
- "source": [
390
- "# Define image generators for training, validation and testing"
391
- ]
392
- },
393
- {
394
- "cell_type": "code",
395
- "execution_count": null,
396
- "metadata": {
397
- "id": "HS9Dihs_QbqU"
398
- },
399
- "outputs": [],
400
- "source": [
401
- "DATASET_PATH = ''\n",
402
- "train_img_dir = DATASET_PATH + 'train/scans/'\n",
403
- "train_mask_dir = DATASET_PATH + 'train/masks/'\n",
404
- "\n",
405
- "val_img_dir = DATASET_PATH + 'val/scans/'\n",
406
- "val_mask_dir = DATASET_PATH + 'val/masks/'\n",
407
- "\n",
408
- "test_img_dir = DATASET_PATH + 'test/scans/'\n",
409
- "test_mask_dir = DATASET_PATH + 'test/masks/'\n",
410
- "\n",
411
- "train_img_list = os.listdir(train_img_dir)\n",
412
- "train_mask_list = os.listdir(train_mask_dir)\n",
413
- "\n",
414
- "val_img_list = os.listdir(val_img_dir)\n",
415
- "val_mask_list = os.listdir(val_mask_dir)\n",
416
- "\n",
417
- "test_img_list = os.listdir(test_img_dir)\n",
418
- "test_mask_list = os.listdir(test_mask_dir)\n",
419
- "\n",
420
- "batch_size = 2\n",
421
- "train_img_datagen = imageLoader(train_img_dir, train_img_list,\n",
422
- " train_mask_dir, train_mask_list, batch_size)\n",
423
- "\n",
424
- "val_img_datagen = imageLoader(val_img_dir, val_img_list,\n",
425
- " val_mask_dir, val_mask_list, batch_size)\n",
426
- "\n",
427
- "test_img_datagen = imageLoader(test_img_dir, test_img_list,\n",
428
- " test_mask_dir, test_mask_list, batch_size)\n"
429
- ]
430
- },
431
- {
432
- "cell_type": "markdown",
433
- "metadata": {
434
- "id": "dBKMHMn96Z3c"
435
- },
436
- "source": [
437
- "# Losses and metrics\n",
438
- "* These losses and metrics best handle the problem of class imbalance\n",
439
- "* Used: dice_coef as a metric, tversky_loss as a loss"
440
- ]
441
- },
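For reference, the quantities implemented in the next cell, written out with smoothing term $s$ (the code uses $s = 1$ and $\alpha = 0.7$, which weights false negatives more heavily than false positives):

```latex
\mathrm{Dice}(y, \hat{y}) = \frac{2\sum_i y_i \hat{y}_i + s}{\sum_i y_i + \sum_i \hat{y}_i + s}

\mathrm{Tversky}(y, \hat{y}) = \frac{\mathrm{TP} + s}{\mathrm{TP} + \alpha\,\mathrm{FN} + (1-\alpha)\,\mathrm{FP} + s},
\qquad \mathrm{tversky\_loss} = 1 - \mathrm{Tversky}
```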
442
- {
443
- "cell_type": "code",
444
- "execution_count": null,
445
- "metadata": {
446
- "id": "pshixCsr6eyt"
447
- },
448
- "outputs": [],
449
- "source": [
450
- "import tensorflow.keras.backend as K\n",
451
- "\n",
452
- "\n",
453
- "def dice_coef(y_true, y_pred, smooth=1):\n",
454
- " y_true_f = K.flatten(y_true)\n",
455
- " y_pred_f = K.flatten(y_pred)\n",
456
- " intersection = K.sum(y_true_f * y_pred_f)\n",
457
- " return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) +\n",
458
- " smooth)\n",
459
- "\n",
460
- "\n",
461
- "def dice_coef_loss(y_true, y_pred):\n",
462
- " return 1 - dice_coef(y_true, y_pred)\n",
463
- "\n",
464
- "\n",
465
- "def tversky(y_true, y_pred, smooth=1, alpha=0.7):\n",
466
- " y_true_pos = K.flatten(y_true)\n",
467
- " y_pred_pos = K.flatten(y_pred)\n",
468
- " true_pos = K.sum(y_true_pos * y_pred_pos)\n",
469
- " false_neg = K.sum(y_true_pos * (1 - y_pred_pos))\n",
470
- " false_pos = K.sum((1 - y_true_pos) * y_pred_pos)\n",
471
- " return (true_pos + smooth) / (true_pos + alpha * false_neg +\n",
472
- " (1 - alpha) * false_pos + smooth)\n",
473
- "\n",
474
- "\n",
475
- "def tversky_loss(y_true, y_pred):\n",
476
- " return 1 - tversky(y_true, y_pred)\n",
477
- "\n",
478
- "\n",
479
- "def focal_tversky_loss(y_true, y_pred, gamma=0.75):\n",
480
- " tv = tversky(y_true, y_pred)\n",
481
- " return K.pow((1 - tv), gamma)"
482
- ]
483
- },
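A quick eager-mode smoke test of these functions on synthetic tensors (a hypothetical check, not part of the original pipeline; it assumes TensorFlow 2.x eager execution):

```python
import numpy as np
import tensorflow as tf

# Perfect prediction: dice should be ~1, tversky loss ~0.
y = tf.constant(np.random.randint(0, 2, size=(1, 8, 8, 8, 4)).astype('float32'))
print(float(dice_coef(y, y)))       # ~1.0
print(float(tversky_loss(y, y)))    # ~0.0

# Disjoint prediction: dice falls towards 0.
print(float(dice_coef(y, 1.0 - y)))
```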
484
- {
485
- "cell_type": "markdown",
486
- "metadata": {
487
- "id": "2o2WuIhaW5ff"
488
- },
489
- "source": [
490
- "# Define loss, metrics and optimizer to be used for training"
491
- ]
492
- },
493
- {
494
- "cell_type": "code",
495
- "execution_count": null,
496
- "metadata": {
497
- "id": "WxiJ1eUQXJ4I"
498
- },
499
- "outputs": [],
500
- "source": [
501
- "from keras.models import Model\n",
502
- "from keras.layers import Input, Conv3D, MaxPooling3D, Activation, add, concatenate, Conv3DTranspose, BatchNormalization, Dropout, UpSampling3D, multiply\n",
503
- "from tensorflow.keras.optimizers import Adam\n",
504
- "from keras import layers\n",
505
- "\n",
506
- "kernel_initializer = 'he_uniform'\n",
507
- "\n",
508
- "import segmentation_models_3D as sm\n",
509
- "\n",
510
- "metrics = [dice_coef]\n",
511
- "\n",
512
- "LR = 0.0001\n",
513
- "optim = Adam(LR)\n",
514
- "\n",
515
- "steps_per_epoch = len(train_img_list) // batch_size\n",
516
- "val_steps_per_epoch = len(val_img_list) // batch_size"
517
- ]
518
- },
519
- {
520
- "cell_type": "markdown",
521
- "metadata": {
522
- "id": "PR2Ugre0YP-v"
523
- },
524
- "source": [
525
- "# 3D UNet Model"
526
- ]
527
- },
528
- {
529
- "cell_type": "code",
530
- "execution_count": null,
531
- "metadata": {
532
- "id": "N0VyhdjCYVuZ"
533
- },
534
- "outputs": [],
535
- "source": [
536
- "def UNet(IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS, num_classes):\n",
537
- " inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS))\n",
538
- "\n",
539
- " # Downsampling\n",
540
- " c1 = Conv3D(filters = 16, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(inputs)\n",
541
- " c1 = Dropout(0.1)(c1)\n",
542
- " c1 = Conv3D(filters = 16, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(c1)\n",
543
- " p1 = MaxPooling3D((2, 2, 2))(c1)\n",
544
- "\n",
545
- " c2 = Conv3D(filters = 32, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(p1)\n",
546
- " c2 = Dropout(0.1)(c2)\n",
547
- " c2 = Conv3D(filters = 32, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(c2)\n",
548
- " p2 = MaxPooling3D((2, 2, 2))(c2)\n",
549
- "\n",
550
- " c3 = Conv3D(filters = 64, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(p2)\n",
551
- " c3 = Dropout(0.2)(c3)\n",
552
- " c3 = Conv3D(filters = 64, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(c3)\n",
553
- " p3 = MaxPooling3D((2, 2, 2))(c3)\n",
554
- "\n",
555
- " c4 = Conv3D(filters = 128, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(p3)\n",
556
- " c4 = Dropout(0.2)(c4)\n",
557
- " c4 = Conv3D(filters = 128, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(c4)\n",
558
- " p4 = MaxPooling3D((2, 2, 2))(c4)\n",
559
- "\n",
560
- " c5 = Conv3D(filters = 256, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(p4)\n",
561
- " c5 = Dropout(0.3)(c5)\n",
562
- " c5 = Conv3D(filters = 256, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(c5)\n",
563
- " \n",
564
- " # Upsampling part\n",
565
- " u6 = Conv3DTranspose(128, (2, 2, 2), strides=(2, 2, 2), padding='same')(c5)\n",
566
- " u6 = concatenate([u6, c4])\n",
567
- " c6 = Conv3D(filters = 128, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(u6)\n",
568
- " c6 = Dropout(0.2)(c6)\n",
569
- " c6 = Conv3D(filters = 128, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(c6) \n",
570
- " \n",
571
- " u7 = Conv3DTranspose(64, (2, 2, 2), strides=(2, 2, 2), padding='same')(c6)\n",
572
- " u7 = concatenate([u7, c3])\n",
573
- " c7 = Conv3D(filters = 64, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(u7)\n",
574
- " c7 = Dropout(0.2)(c7)\n",
575
- " c7 = Conv3D(filters = 64, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(c7) \n",
576
- " \n",
577
- " u8 = Conv3DTranspose(32, (2, 2, 2), strides=(2, 2, 2), padding='same')(c7)\n",
578
- " u8 = concatenate([u8, c2])\n",
579
- " c8 = Conv3D(filters = 32, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(u8)\n",
580
- " c8 = Dropout(0.1)(c8)\n",
581
- " c8 = Conv3D(filters = 32, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(c8) \n",
582
- "\n",
583
- " u9 = Conv3DTranspose(16, (2, 2, 2), strides=(2, 2, 2), padding='same')(c8)\n",
584
- " u9 = concatenate([u9, c1])\n",
585
- " c9 = Conv3D(filters = 16, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(u9)\n",
586
- " c9 = Dropout(0.1)(c9)\n",
587
- " c9 = Conv3D(filters = 16, kernel_size = 3, strides = 1, activation='relu', kernel_initializer=kernel_initializer, padding='same')(c9) \n",
588
- "\n",
589
- " outputs = Conv3D(num_classes, (1, 1, 1), activation='softmax')(c9)\n",
590
- "\n",
591
- " model = Model(inputs=[inputs], outputs=[outputs])\n",
592
- " model.summary()\n",
593
- "\n",
594
- " return model"
595
- ]
596
- },
597
- {
598
- "cell_type": "markdown",
599
- "metadata": {
600
- "id": "-Aw_Peb9iJYb"
601
- },
602
- "source": [
603
- "# Test the working of the 3D UNet model"
604
- ]
605
- },
606
- {
607
- "cell_type": "code",
608
- "execution_count": null,
609
- "metadata": {
610
- "id": "fjdzCTisiMLI"
611
- },
612
- "outputs": [],
613
- "source": [
614
- "steps_per_epoch = len(train_img_list)//batch_size\n",
615
- "val_steps_per_epoch = len(val_img_list)//batch_size\n",
616
- "\n",
617
- "model = UNet(IMG_HEIGHT = 128,\n",
618
- " IMG_WIDTH = 128,\n",
619
- " IMG_DEPTH = 128,\n",
620
- " IMG_CHANNELS = 3,\n",
621
- " num_classes = 4)\n",
622
- "\n",
623
- "model.compile(optimizer = optim, loss = tversky_loss, metrics = metrics)\n",
624
- "\n",
625
- "print(model.summary)\n",
626
- "\n",
627
- "print(model.input_shape)\n",
628
- "print(model.output_shape)"
629
- ]
630
- },
631
- {
632
- "cell_type": "markdown",
633
- "metadata": {
634
- "id": "e6Cvn6hWvars"
635
- },
636
- "source": [
637
- "# 3D Attention UNet Model"
638
- ]
639
- },
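The attention gate built in the next cell follows the additive attention formulation used in the Attention U-Net architecture: the skip-connection feature map $x$ is reweighted by coefficients computed from $x$ and a coarser gating signal $g$. Schematically:

```latex
\alpha = \sigma\big(\psi\,\mathrm{ReLU}(W_x x + W_g g)\big), \qquad \hat{x} = \alpha \cdot x
```

where $W_x$, $W_g$ and $\psi$ are 1x1x1 (or strided) convolutions, $\sigma$ is the sigmoid, and $\alpha$ is upsampled back to the resolution of $x$ before the product.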
640
- {
641
- "cell_type": "code",
642
- "execution_count": null,
643
- "metadata": {
644
- "id": "JBcFdz80v2mL"
645
- },
646
- "outputs": [],
647
- "source": [
648
- "from keras.layers.core.activation import Activation\n",
649
- "from tensorflow.keras import backend as K\n",
650
- "from keras.layers import LeakyReLU\n",
651
- "\n",
652
- "def repeat_elem(tensor, rep):\n",
653
- " # lambda function to repeat Repeats the elements of a tensor along an axis\n",
654
- " #by a factor of rep.\n",
655
- " # If tensor has shape (None, 128,128,3), lambda will return a tensor of shape \n",
656
- " #(None, 128,128,6), if specified axis=3 and rep=2.\n",
657
- "\n",
658
- " return layers.Lambda(lambda x, repnum: K.repeat_elements(x, repnum, axis=4),\n",
659
- " arguments={'repnum': rep})(tensor)\n",
660
- "\n",
661
- "def attention_block(x, gating, inter_shape):\n",
662
- " shape_x = K.int_shape(x)\n",
663
- " shape_g = K.int_shape(gating)\n",
664
- "\n",
665
- " # Getting the gating signal to the same number of filters as the inter_shape\n",
666
- " phi_g = Conv3D(filters=inter_shape, kernel_size=1, strides=1, padding='same')(gating)\n",
667
- "\n",
668
- " # Geting the x signal to the same shape as the gating signal\n",
669
- " theta_x = Conv3D(filters=inter_shape, kernel_size=3, strides=(\n",
670
- " shape_x[1] // shape_g[1],\n",
671
- " shape_x[2] // shape_g[2],\n",
672
- " shape_x[3] // shape_g[3]\n",
673
- " ), padding='same')(x)\n",
674
- " shape_theta_x = K.int_shape(theta_x)\n",
675
- "\n",
676
- " print(shape_theta_x, shape_g)\n",
677
- "\n",
678
- " # Elemet-wise addition of the gating and x signals\n",
679
- " xg_sum = add([phi_g, theta_x])\n",
680
- " xg_sum = Activation('relu')(xg_sum)\n",
681
- "\n",
682
- " # 1x1x1 convolution\n",
683
- " psi = Conv3D(filters=1, kernel_size=1, padding='same')(xg_sum)\n",
684
- " sigmoid_psi = Activation('sigmoid')(psi)\n",
685
- " shape_sigmoid = K.int_shape(sigmoid_psi)\n",
686
- "\n",
687
- " # Upsampling psi back to the original dimensions of x signal to enable \n",
688
- " # element-wise multiplication with the signal\n",
689
- "\n",
690
- " upsampled_sigmoid_psi = UpSampling3D(size=(\n",
691
- " shape_x[1] // shape_sigmoid[1], \n",
692
- " shape_x[2] // shape_sigmoid[2],\n",
693
- " shape_x[3] // shape_sigmoid[3]\n",
694
- " ))(sigmoid_psi)\n",
695
- "\n",
696
- " # Expand the filter axis to the number of filters in the original x signal\n",
697
- " upsampled_sigmoid_psi = repeat_elem(upsampled_sigmoid_psi, shape_x[4])\n",
698
- "\n",
699
- " # Element-wise multiplication of attention coefficients back onto original x signal\n",
700
- " attention_coeffs = multiply([upsampled_sigmoid_psi, x])\n",
701
- "\n",
702
- " # Final 1x1x1 convolution to consolidate attention signal to original x dimensions\n",
703
- " output = Conv3D(filters=shape_x[3], kernel_size=1, strides=1, padding='same')(attention_coeffs)\n",
704
- " output = BatchNormalization()(output)\n",
705
- " return output\n",
706
- "\n",
707
- "\n",
708
- "# Gating signal\n",
709
- "def gating_signal(input, output_size, batch_norm=False):\n",
710
- " # Resize the down layer feature map into the same dimensions as the up layer feature map using 1x1 conv\n",
711
- " # Return: the gating feature map with the same dimension of the up layer feature map\n",
712
- " x = Conv3D(output_size, (1, 1, 1), padding='same')(input)\n",
713
- " if batch_norm:\n",
714
- " x = BatchNormalization()(x)\n",
715
- " x = Activation('relu')(x)\n",
716
- " return x\n",
717
- "\n",
718
- "\n",
719
- "# Attention UNet\n",
720
- "def attention_unet(IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS, num_classes, batch_norm = True):\n",
721
- " inputs = Input((IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH, IMG_CHANNELS))\n",
722
- " FILTER_NUM = 64 #\n",
723
- " FILTER_SIZE = 3 #\n",
724
- " UP_SAMPLING_SIZE = 2 # \n",
725
- "\n",
726
- " c1 = Conv3D(filters = 16, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(inputs)\n",
727
- " c1 = Dropout(0.1)(c1)\n",
728
- " c1 = Conv3D(filters = 16, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(c1)\n",
729
- " p1 = MaxPooling3D((2, 2, 2))(c1)\n",
730
- "\n",
731
- " c2 = Conv3D(filters = 32, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(p1)\n",
732
- " c2 = Dropout(0.1)(c2)\n",
733
- " c2 = Conv3D(filters = 32, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(c2)\n",
734
- " p2 = MaxPooling3D((2, 2, 2))(c2)\n",
735
- "\n",
736
- " c3 = Conv3D(filters = 64, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(p2)\n",
737
- " c3 = Dropout(0.2)(c3)\n",
738
- " c3 = Conv3D(filters = 64, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(c3)\n",
739
- " p3 = MaxPooling3D((2, 2, 2))(c3)\n",
740
- "\n",
741
- " c4 = Conv3D(filters = 128, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(p3)\n",
742
- " c4 = Dropout(0.2)(c4)\n",
743
- " c4 = Conv3D(filters = 128, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(c4)\n",
744
- " p4 = MaxPooling3D((2, 2, 2))(c4)\n",
745
- "\n",
746
- " c5 = Conv3D(filters = 256, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(p4)\n",
747
- " c5 = Dropout(0.3)(c5)\n",
748
- " c5 = Conv3D(filters = 256, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(c5)\n",
749
- " \n",
750
- "\n",
751
- " gating_6 = gating_signal(c5, 128, batch_norm)\n",
752
- " att_6 = attention_block(c4, gating_6, 128)\n",
753
- " u6 = UpSampling3D((2, 2, 2), data_format='channels_last')(c5)\n",
754
- " u6 = concatenate([u6, att_6])\n",
755
- " c6 = Conv3D(filters = 128, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(u6)\n",
756
- " c6 = Dropout(0.2)(c6)\n",
757
- " c6 = Conv3D(filters = 128, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(c6) \n",
758
- " \n",
759
- " gating_7 = gating_signal(c6, 64, batch_norm)\n",
760
- " att_7 = attention_block(c3, gating_6, 64)\n",
761
- " u7 = UpSampling3D((2, 2, 2), data_format='channels_last')(c6)\n",
762
- " u7 = concatenate([u7, att_7])\n",
763
- " c7 = Conv3D(filters = 64, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(u7)\n",
764
- " c7 = Dropout(0.2)(c7)\n",
765
- " c7 = Conv3D(filters = 64, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(c7) \n",
766
- " \n",
767
- " gating_8 = gating_signal(c7, 64, batch_norm)\n",
768
- " att_8 = attention_block(c2, gating_6, 64)\n",
769
- " u8 = UpSampling3D((2, 2, 2), data_format='channels_last')(c7)\n",
770
- " u8 = concatenate([u8, att_8])\n",
771
- " c8 = Conv3D(filters = 32, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(u8)\n",
772
- " c8 = Dropout(0.1)(c8)\n",
773
- " c8 = Conv3D(filters = 32, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(c8) \n",
774
- "\n",
775
- " gating_9 = gating_signal(c8, 64, batch_norm)\n",
776
- " att_9 = attention_block(c1, gating_6, 64)\n",
777
- " u9 = UpSampling3D((2, 2, 2), data_format='channels_last')(c8)\n",
778
- " u9 = concatenate([u9, att_9])\n",
779
- " c9 = Conv3D(filters = 16, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(u9)\n",
780
- " c9 = Dropout(0.1)(c9)\n",
781
- " c9 = Conv3D(filters = 16, kernel_size = 3, strides = 1, activation=LeakyReLU(alpha=0.1), kernel_initializer=kernel_initializer, padding='same')(c9) \n",
782
- "\n",
783
- " outputs = Conv3D(num_classes, (1, 1, 1))(c9)\n",
784
- " outputs = BatchNormalization()(outputs)\n",
785
- " outputs = Activation('softmax')(outputs)\n",
786
- "\n",
787
- " model = Model(inputs=[inputs], outputs=[outputs], name=\"Attention_UNet\")\n",
788
- " model.summary()\n",
789
- "\n",
790
- " return model"
791
- ]
792
- },
793
- {
794
- "cell_type": "markdown",
795
- "metadata": {
796
- "id": "xndmsEwjVhn7"
797
- },
798
- "source": [
799
- "# Test the working of a 3D Attention UNet Model"
800
- ]
801
- },
802
- {
803
- "cell_type": "code",
804
- "execution_count": null,
805
- "metadata": {
806
- "id": "pBNjxGbjVn9U"
807
- },
808
- "outputs": [],
809
- "source": [
810
- "steps_per_epoch = len(train_img_list)//batch_size\n",
811
- "val_steps_per_epoch = len(val_img_list)//batch_size\n",
812
- "\n",
813
- "model = attention_unet(IMG_HEIGHT = 128,\n",
814
- " IMG_WIDTH = 128,\n",
815
- " IMG_DEPTH = 128,\n",
816
- " IMG_CHANNELS = 3,\n",
817
- " num_classes = 4)\n",
818
- "\n",
819
- "model.compile(optimizer = optim, loss = tversky_loss, metrics = metrics)\n",
820
- "\n",
821
- "print(model.summary)\n",
822
- "\n",
823
- "print(model.input_shape)\n",
824
- "print(model.output_shape)"
825
- ]
826
- },
827
- {
828
- "cell_type": "markdown",
829
- "metadata": {
830
- "id": "8qnlrlr1YXu4"
831
- },
832
- "source": [
833
- "# Fit the Model"
834
- ]
835
- },
836
- {
837
- "cell_type": "code",
838
- "execution_count": null,
839
- "metadata": {
840
- "id": "UXmCjFvjYaSG"
841
- },
842
- "outputs": [],
843
- "source": [
844
- "import tensorflow.keras as keras\n",
845
- "from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, CSVLogger, TerminateOnNaN\n",
846
- "\n",
847
- "checkpoint_path = ''\n",
848
- "log_path = ''\n",
849
- "\n",
850
- "callbacks = [\n",
851
- " EarlyStopping(monitor='val_loss', patience=4, verbose=1),\n",
852
- " ReduceLROnPlateau(factor=0.1,\n",
853
- " monitor='val_loss',\n",
854
- " patience=4,\n",
855
- " min_lr=0.0001,\n",
856
- " verbose=1,\n",
857
- " mode='min'),\n",
858
- " ModelCheckpoint(checkpoint_path,\n",
859
- " monitor='val_loss',\n",
860
- " mode='min',\n",
861
- " verbose=0,\n",
862
- " save_best_only=True),\n",
863
- " CSVLogger(log_path, separator=',', append=True),\n",
864
- " TerminateOnNaN()\n",
865
- "]\n",
866
- "\n",
867
- "history = model.fit(train_img_datagen,\n",
868
- " steps_per_epoch=steps_per_epoch,\n",
869
- " epochs=100,\n",
870
- " verbose=1,\n",
871
- " validation_data=val_img_datagen,\n",
872
- " validation_steps=val_steps_per_epoch,\n",
873
- " callbacks=callbacks\n",
874
- " )\n",
875
- "\n",
876
- "history_callback = np.save('', history.history)"
877
- ]
878
- },
879
- {
880
- "cell_type": "markdown",
881
- "metadata": {
882
- "id": "pfcKmJv4jP2J"
883
- },
884
- "source": [
885
- "# Load Model for more training"
886
- ]
887
- },
888
- {
889
- "cell_type": "code",
890
- "execution_count": null,
891
- "metadata": {
892
- "id": "7RXukeY_jiad"
893
- },
894
- "outputs": [],
895
- "source": [
896
- "import tensorflow.keras.models as load\n",
897
- "import keras\n",
898
- "model = load.load_model('', custom_objects={\n",
899
- " 'tversky_loss': tversky_loss,\n",
900
- " 'dice_coef': dice_coef\n",
901
- "})\n",
902
- "\n",
903
- "checkpoint_path = ''\n",
904
- "log_path = ''\n",
905
- "\n",
906
- "callbacks = [\n",
907
- " EarlyStopping(monitor='val_loss', patience=4, verbose=1),\n",
908
- " ReduceLROnPlateau(factor=0.1,\n",
909
- " monitor='val_loss',\n",
910
- " patience=4,\n",
911
- " min_lr=0.0001,\n",
912
- " verbose=1,\n",
913
- " mode='min'),\n",
914
- " ModelCheckpoint(checkpoint_path,\n",
915
- " monitor='val_loss',\n",
916
- " mode='min',\n",
917
- " verbose=0,\n",
918
- " save_best_only=True),\n",
919
- " CSVLogger(log_path, separator=',', append=True),\n",
920
- " TerminateOnNaN()\n",
921
- "]\n",
922
- "\n",
923
- "history = model.fit(train_img_datagen,\n",
924
- " steps_per_epoch=steps_per_epoch,\n",
925
- " epochs=100,\n",
926
- " verbose=1,\n",
927
- " validation_data=val_img_datagen,\n",
928
- " validation_steps=val_steps_per_epoch,\n",
929
- " callbacks=callbacks\n",
930
- " )\n",
931
- "\n",
932
- "history_callback = np.save('', history.history)"
933
- ]
934
- },
935
- {
936
- "cell_type": "markdown",
937
- "metadata": {
938
- "id": "SPBUC1HIfqDt"
939
- },
940
- "source": [
941
- "# Plot the training and validation loss (tversky) and dice coefficient (metric) at each epoch"
942
- ]
943
- },
944
- {
945
- "cell_type": "code",
946
- "execution_count": null,
947
- "metadata": {
948
- "id": "I7e4YkM5f1Jg"
949
- },
950
- "outputs": [],
951
- "source": [
952
- "history = np.load('',allow_pickle='TRUE').item()\n",
953
- "\n",
954
- "print(history)\n",
955
- "loss = history['loss']\n",
956
- "val_loss = history['val_loss']\n",
957
- "epochs = range(1, len(loss) + 1)\n",
958
- "plt.plot(epochs, loss, 'y', label='Training loss')\n",
959
- "plt.plot(epochs, val_loss, 'r', label='Validation loss')\n",
960
- "plt.title('Training and Validation Loss')\n",
961
- "plt.xlabel('Epochs')\n",
962
- "plt.ylabel('Loss')\n",
963
- "plt.legend()\n",
964
- "plt.show()\n",
965
- "\n",
966
- "acc = history['dice_coef']\n",
967
- "val_acc = history['val_dice_coef']\n",
968
- "\n",
969
- "plt.plot(epochs, acc, 'y', label='Training accuracy')\n",
970
- "plt.plot(epochs, val_acc, 'r', label='Validation accuracy')\n",
971
- "plt.title('Trainign and Validation Accuracy')\n",
972
- "plt.xlabel('Epochs')\n",
973
- "plt.ylabel('Accuracy')\n",
974
- "plt.legend()\n",
975
- "plt.show()"
976
- ]
977
- },
978
- {
979
- "cell_type": "markdown",
980
- "metadata": {
981
- "id": "XV8kjMkemQ-W"
982
- },
983
- "source": [
984
- "# Model Evaluation"
985
- ]
986
- },
987
- {
988
- "cell_type": "code",
989
- "execution_count": null,
990
- "metadata": {
991
- "id": "ChhYHB8PmTnK"
992
- },
993
- "outputs": [],
994
- "source": [
995
- "from tensorflow.keras.models import load_model\n",
996
- "my_model = load_model('', custom_objects={\n",
997
- " 'tversky_loss': tversky_loss,\n",
998
- " 'dice_coef': dice_coef},\n",
999
- " compile = True)\n",
1000
- "\n",
1001
- "# Verify IoU on a batch of images from the test dataset\n",
1002
- "batch_size = 8\n",
1003
- "test_img_datagen = imageLoader(val_img_dir, val_img_list,\n",
1004
- " val_mask_dir, val_mask_list, batch_size)\n",
1005
- "\n",
1006
- "test_image_batch, test_mask_batch = test_img_datagen.__next__()\n",
1007
- "\n",
1008
- "test_mask_batch_argmax = np.argmax(test_mask_batch, axis=4)\n",
1009
- "\n",
1010
- "results = my_model.evaluate(test_image_batch, test_mask_batch, batch_size=batch_size)\n",
1011
- "print(\"test acc, test loss:\", results)"
1012
- ]
1013
- },
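`model.evaluate` above reports only the compiled loss and Dice metric; a mean IoU over the same batch can be computed separately (a sketch using `tf.keras.metrics.MeanIoU`, reusing `test_mask_batch_argmax` from above and assuming the whole batch fits in memory):

```python
import numpy as np
from tensorflow.keras.metrics import MeanIoU

n_classes = 4
# Predict on the same batch and collapse the class probabilities to labels.
test_pred_batch = my_model.predict(test_image_batch)
test_pred_batch_argmax = np.argmax(test_pred_batch, axis=4)

iou_metric = MeanIoU(num_classes=n_classes)
iou_metric.update_state(test_mask_batch_argmax, test_pred_batch_argmax)
print("Mean IoU over the batch =", iou_metric.result().numpy())
```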
1014
- {
1015
- "cell_type": "markdown",
1016
- "metadata": {
1017
- "id": "xvEqiU6SqY2y"
1018
- },
1019
- "source": [
1020
- "# Predict on a test scan"
1021
- ]
1022
- },
1023
- {
1024
- "cell_type": "code",
1025
- "execution_count": null,
1026
- "metadata": {
1027
- "id": "8-MUQpCiqcxd"
1028
- },
1029
- "outputs": [],
1030
- "source": [
1031
- "from tensorflow.keras.models import load_model\n",
1032
- "my_model = load_model('', compile=False)\n",
1033
- "\n",
1034
- "img_num = 53\n",
1035
- "test_scan = np.load('' + str(img_num) + '.npy')\n",
1036
- "\n",
1037
- "test_mask = np.load('' + str(img_num) + '.npy')\n",
1038
- "test_mask_argmax = np.argmax(test_mask, axis = 3)\n",
1039
- "\n",
1040
- "test_scan_input = np.expand_dims(test_scan, axis = 0)\n",
1041
- "test_prediction = my_model.predict(test_scan_input)\n",
1042
- "test_prediction_argmax = np.argmax(test_prediction, axis = 4)[0, :, :, :]"
1043
- ]
1044
- },
1045
- {
1046
- "cell_type": "code",
1047
- "execution_count": null,
1048
- "metadata": {
1049
- "colab": {
1050
- "background_save": true
1051
- },
1052
- "id": "65FmAMNhmX8E"
1053
- },
1054
- "outputs": [],
1055
- "source": [
1056
- "# n_slice = 55\n",
1057
- "n_slice = random.randint(0, test_mask_argmax.shape[2])\n",
1058
- "\n",
1059
- "plt.figure(figsize=(12,8))\n",
1060
- "plt.subplot(231)\n",
1061
- "plt.imshow(test_scan[:, :, n_slice, 1], cmap='gray')\n",
1062
- "plt.title('Testing Scan')\n",
1063
- "\n",
1064
- "plt.subplot(232)\n",
1065
- "plt.imshow(test_mask_argmax[:, :, n_slice])\n",
1066
- "plt.title('Testing Label')\n",
1067
- "\n",
1068
- "plt.subplot(235)\n",
1069
- "plt.imshow(test_prediction_argmax[:, :, n_slice])\n",
1070
- "plt.title('Prediction on test image')\n",
1071
- "\n",
1072
- "plt.show()"
1073
- ]
1074
- }
1075
- ],
1076
- "metadata": {
1077
- "accelerator": "GPU",
1078
- "colab": {
1079
- "collapsed_sections": [
1080
- "TdEse3Kwq3JD",
1081
- "L5yBxROtvDAI",
1082
- "EORoZoj7yPfW",
1083
- "-wICUx56ugDz",
1084
- "nq3p80zN2ew2",
1085
- "dBKMHMn96Z3c",
1086
- "PR2Ugre0YP-v",
1087
- "-Aw_Peb9iJYb",
1088
- "e6Cvn6hWvars",
1089
- "xndmsEwjVhn7",
1090
- "pfcKmJv4jP2J"
1091
- ],
1092
- "provenance": []
1093
- },
1094
- "gpuClass": "standard",
1095
- "kernelspec": {
1096
- "display_name": "Python 3",
1097
- "name": "python3"
1098
- },
1099
- "language_info": {
1100
- "name": "python"
1101
- }
1102
- },
1103
- "nbformat": 4,
1104
- "nbformat_minor": 0
1105
- }