theodore-ioann committed on
Commit
171a6dc
·
verified ·
1 Parent(s): 8bbbbb6

Update utils.py

Browse files

remove extra functions used only for training

Files changed (1) hide show
  1. utils.py +83 -204
utils.py CHANGED
@@ -1,205 +1,84 @@
1
- import cv2
2
- import torch
3
- import numpy as np
4
- from PIL import Image
5
- import matplotlib.pyplot as plt
6
- from supervised import UNet, Segformer, Inception
7
- from sklearn.cluster import KMeans
8
- from sklearn.mixture import GaussianMixture
9
- from torchvision import transforms
10
- from sklearn.metrics import accuracy_score, jaccard_score, f1_score, confusion_matrix, ConfusionMatrixDisplay
11
-
12
def postprocess(masks, mode="open", kernel_size=5, iters=1):
    """
    Apply one morphological operation to every mask in a collection.

    - masks: iterable of numpy arrays; each one is cast to uint8 before filtering
    - mode: "open", "close", "erosion" or "dilation"; any other value leaves
      the masks unmodified
    - kernel_size: side length of the square all-ones structuring element
    - iters: number of times the operation is applied

    Returns a list with one mask per input mask. For an unrecognised mode the
    original mask objects are returned untouched, still wrapped in a list so
    the return type does not depend on the mode.
    """
    kernel = np.ones((kernel_size, kernel_size), np.uint8)

    # Lambdas keep the cv2 lookups lazy and remove the four copy-pasted
    # list comprehensions the branches used to carry.
    operations = {
        "open": lambda m: cv2.morphologyEx(m, cv2.MORPH_OPEN, kernel, iterations=iters),
        "close": lambda m: cv2.morphologyEx(m, cv2.MORPH_CLOSE, kernel, iterations=iters),
        "erosion": lambda m: cv2.erode(m, kernel, iterations=iters),
        "dilation": lambda m: cv2.dilate(m, kernel, iterations=iters),
    }

    operation = operations.get(mode)
    if operation is None:
        return list(masks)
    return [operation(mask.astype(np.uint8)) for mask in masks]
25
-
26
def fix_labels(pred_masks, gt_masks, lesion_positive=True):
    """
    Align the 0/1 labelling of predicted masks with their ground truth.

    - pred_masks: iterable of predicted masks (cast to uint8)
    - gt_masks: iterable of ground-truth masks; any value > 0 counts as 1
    - lesion_positive: when True, a second check flips the prediction if the
      ground truth is mostly 1 while the prediction is mostly 0

    Each prediction is inverted when its inverse scores a higher Jaccard
    index against the ground truth. Returns the corrected masks as a list.
    """
    corrected = []

    for prediction, truth in zip(pred_masks, gt_masks):
        prediction = prediction.astype(np.uint8)
        truth = (truth > 0).astype(np.uint8)

        pred_vec = prediction.flatten()
        truth_vec = truth.flatten()

        # Score both possible label assignments against the ground truth.
        score_inverted = jaccard_score(truth_vec, (pred_vec == 0))
        score_as_is = jaccard_score(truth_vec, (pred_vec == 1))

        if score_inverted > score_as_is:
            prediction = 1 - prediction

        if lesion_positive:
            truth_fraction = np.sum(truth) / truth.size
            pred_fraction = np.sum(prediction) / prediction.size

            # Ground truth is majority-positive but the prediction is not.
            if truth_fraction > 0.5 and pred_fraction < 0.5:
                prediction = 1 - prediction

        corrected.append(prediction)

    return corrected
61
-
62
def evaluate_masks(pred_masks, gt_masks):
    """
    Evaluate predicted masks against ground truth.

    - pred_masks: iterable of predicted binary masks
    - gt_masks: iterable of ground-truth masks; any value > 0 counts as 1

    Prints the mean accuracy, IoU and F1, then shows an aggregated confusion
    matrix and one histogram per metric.

    Returns (mean_accuracy, mean_iou, mean_f1), as the original docstring
    promised but the function never did.
    """
    acc_list = []
    iou_list = []
    f1_list = []
    cm = np.zeros((2, 2), dtype=int)

    for pred, gt in zip(pred_masks, gt_masks):
        pred_flat = pred.flatten()
        gt_flat = (gt.flatten() > 0).astype(np.uint8)

        acc_list.append(accuracy_score(gt_flat, pred_flat))
        iou_list.append(jaccard_score(gt_flat, pred_flat))
        f1_list.append(f1_score(gt_flat, pred_flat))
        # Fixed label order keeps the matrix 2x2 even if a class is absent.
        cm += confusion_matrix(gt_flat, pred_flat, labels=[0, 1])

    mean_acc = np.mean(acc_list)
    mean_iou = np.mean(iou_list)
    mean_f1 = np.mean(f1_list)

    print(f"Mean Accuracy: {mean_acc:.4f}")
    print(f"Mean IoU (Jaccard): {mean_iou:.4f}")
    print(f"Mean F1 Score (Dice): {mean_f1:.4f}")

    disp = ConfusionMatrixDisplay(cm, display_labels=["Background", "Lesion"])
    disp.plot(cmap="Blues", values_format="d")
    plt.title("Confusion Matrix (Aggregated)")
    plt.show()

    # Per-image metric distributions, one panel per metric.
    histograms = [
        (acc_list, 'r', "Accuracy Distribution"),
        (iou_list, 'g', "IoU Distribution"),
        (f1_list, 'skyblue', "F1 Score Distribution"),
    ]
    plt.figure(figsize=(15, 4))
    for position, (values, color, title) in enumerate(histograms, start=1):
        plt.subplot(1, 3, position)
        plt.hist(values, bins=10, color=color, alpha=0.6, edgecolor='black')
        plt.title(title)

    plt.tight_layout()
    plt.show()

    return mean_acc, mean_iou, mean_f1
116
-
117
def overlay_mask(image, mask, color=(255, 0, 0), alpha=0.5):
    """
    Overlay a binary mask on top of an image.

    - image: (H, W, 3) numpy array, RGB
    - mask: (H, W) numpy array, 0/1 values or 0/255
    - color: RGB tuple for mask color (only the first three entries are used)
    - alpha: transparency factor (0=transparent, 1=opaque)

    Returns a new uint8 array with the image's shape; the input image is not
    modified. Pixels where the mask is non-zero are blended with the color,
    all other pixels are copied from the image.
    """
    image = np.asarray(image)
    mask = np.asarray(mask)

    # A 0/255 mask is thresholded to binary; the size check keeps an empty
    # mask from crashing in .max().
    if mask.size and mask.max() > 1:
        mask = mask > 127

    # Trailing axis lets the (H, W) mask broadcast over the colour channels.
    selected = mask.astype(bool)[:, :, np.newaxis]

    # The tint uses the image dtype, as the old zeros_like(image) fill did.
    tint = np.asarray(color[:3], dtype=image.dtype)
    blended = (1 - alpha) * image + alpha * tint

    overlay = np.where(selected, blended, image)
    return overlay.astype(np.uint8)
142
-
143
-
144
def visualize_overlay(image, gt_mask, pred_mask, post_mask=None, alpha=0.5):
    """
    Show the original image next to its ground-truth and prediction overlays.

    - image: image passed to overlay_mask and displayed as-is
    - gt_mask: ground-truth mask, drawn in green
    - pred_mask: predicted mask, drawn in red
    - post_mask: accepted but not used by this function
    - alpha: transparency forwarded to overlay_mask
    """
    panels = [
        (image, "Original Image"),
        (overlay_mask(image, gt_mask, color=(0, 255, 0), alpha=alpha), "Ground Truth Overlay (Green)"),
        (overlay_mask(image, pred_mask, color=(255, 0, 0), alpha=alpha), "Prediction Overlay (Red)"),
    ]

    plt.figure(figsize=(18, 6))
    for position, (picture, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(picture)
        plt.title(title)
        plt.axis("off")

    plt.tight_layout()
    plt.show()
172
-
173
def predict_and_visualize_single(model, image_path, postprocess_mode='none', alpha=0.5, device='cpu'):
    """
    Segment a single image and return a black-and-white mask plus an overlay.

    - model: a UNet / Segformer / Inception network (run under torch.no_grad)
      or a KMeans / GaussianMixture estimator (re-fitted on this image's pixels)
    - image_path: despite the name, an image array; it is handed to
      Image.fromarray rather than opened from disk
    - postprocess_mode: mode forwarded to postprocess(); 'none' skips it
    - alpha: overlay transparency forwarded to overlay_mask()
    - device: device the input tensor is moved to for network models

    Returns (bw_mask, overlay), both resized to 256x256. bw_mask is the
    predicted mask scaled by 255 as uint8; overlay is the 128x128 RGB input
    with the predicted mask tinted red.

    Raises TypeError when model is none of the supported types.
    """
    work_size = (128, 128)
    output_size = (256, 256)

    image = Image.fromarray(image_path).convert('RGB')
    original_np = np.array(image.resize(work_size))

    if isinstance(model, (UNet, Segformer, Inception)):
        transform = transforms.Compose([
            transforms.Resize(work_size),
            transforms.ToTensor()
        ])
        input_tensor = transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(input_tensor)
            if isinstance(output, dict):
                # Explicit None check: `a or b` would call bool() on a
                # multi-element tensor, which raises a RuntimeError.
                logits = output.get("logits")
                output = logits if logits is not None else output.get("out")
            pred_mask = torch.argmax(output.squeeze(), dim=0).cpu().numpy()
    elif isinstance(model, (KMeans, GaussianMixture)):
        pixels = original_np.reshape(-1, 3)
        model.fit(pixels)
        pred_mask = model.predict(pixels).reshape(original_np.shape[:2])
    else:
        # Previously this fell through and failed later on an unbound pred_mask.
        raise TypeError(f"Unsupported model type: {type(model).__name__}")

    if postprocess_mode != 'none':
        pred_mask = postprocess([pred_mask], mode=postprocess_mode)[0]

    # Resize outputs to 256x256; nearest-neighbour keeps the mask strictly 0/255.
    bw_mask = cv2.resize(pred_mask.astype(np.uint8) * 255, output_size, interpolation=cv2.INTER_NEAREST)
    overlay = cv2.resize(
        overlay_mask(original_np, pred_mask, color=(255, 0, 0), alpha=alpha),
        output_size,
        interpolation=cv2.INTER_LINEAR,
    )

    return bw_mask, overlay
 
1
+ import cv2
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import matplotlib.pyplot as plt
6
+ from supervised import UNet, Segformer, Inception
7
+ from sklearn.cluster import KMeans
8
+ from sklearn.mixture import GaussianMixture
9
+ from torchvision import transforms
10
+ from sklearn.metrics import accuracy_score, jaccard_score, f1_score, confusion_matrix, ConfusionMatrixDisplay
11
+
12
def postprocess(masks, mode="open", kernel_size=5, iters=1):
    """
    Apply one morphological operation to every mask in a collection.

    - masks: iterable of numpy arrays; each one is cast to uint8 before filtering
    - mode: "open", "close", "erosion" or "dilation"; any other value leaves
      the masks unmodified
    - kernel_size: side length of the square all-ones structuring element
    - iters: number of times the operation is applied

    Returns a list with one mask per input mask. For an unrecognised mode the
    original mask objects are returned untouched, still wrapped in a list so
    the return type does not depend on the mode.
    """
    kernel = np.ones((kernel_size, kernel_size), np.uint8)

    # Lambdas keep the cv2 lookups lazy and remove the four copy-pasted
    # list comprehensions the branches used to carry.
    operations = {
        "open": lambda m: cv2.morphologyEx(m, cv2.MORPH_OPEN, kernel, iterations=iters),
        "close": lambda m: cv2.morphologyEx(m, cv2.MORPH_CLOSE, kernel, iterations=iters),
        "erosion": lambda m: cv2.erode(m, kernel, iterations=iters),
        "dilation": lambda m: cv2.dilate(m, kernel, iterations=iters),
    }

    operation = operations.get(mode)
    if operation is None:
        return list(masks)
    return [operation(mask.astype(np.uint8)) for mask in masks]
25
+
26
def overlay_mask(image, mask, color=(255, 0, 0), alpha=0.5):
    """
    Overlay a binary mask on top of an image.

    - image: (H, W, 3) numpy array, RGB
    - mask: (H, W) numpy array, 0/1 values or 0/255
    - color: RGB tuple for mask color (only the first three entries are used)
    - alpha: transparency factor (0=transparent, 1=opaque)

    Returns a new uint8 array with the image's shape; the input image is not
    modified. Pixels where the mask is non-zero are blended with the color,
    all other pixels are copied from the image.
    """
    image = np.asarray(image)
    mask = np.asarray(mask)

    # A 0/255 mask is thresholded to binary; the size check keeps an empty
    # mask from crashing in .max().
    if mask.size and mask.max() > 1:
        mask = mask > 127

    # Trailing axis lets the (H, W) mask broadcast over the colour channels.
    selected = mask.astype(bool)[:, :, np.newaxis]

    # The tint uses the image dtype, as the old zeros_like(image) fill did.
    tint = np.asarray(color[:3], dtype=image.dtype)
    blended = (1 - alpha) * image + alpha * tint

    overlay = np.where(selected, blended, image)
    return overlay.astype(np.uint8)
51
+
52
def predict_and_visualize_single(model, image_path, postprocess_mode='none', alpha=0.5, device='cpu'):
    """
    Segment a single image and return a black-and-white mask plus an overlay.

    - model: a UNet / Segformer / Inception network (run under torch.no_grad)
      or a KMeans / GaussianMixture estimator (re-fitted on this image's pixels)
    - image_path: despite the name, an image array; it is handed to
      Image.fromarray rather than opened from disk
    - postprocess_mode: mode forwarded to postprocess(); 'none' skips it
    - alpha: overlay transparency forwarded to overlay_mask()
    - device: device the input tensor is moved to for network models

    Returns (bw_mask, overlay), both resized to 256x256. bw_mask is the
    predicted mask scaled by 255 as uint8; overlay is the 128x128 RGB input
    with the predicted mask tinted red.

    Raises TypeError when model is none of the supported types.
    """
    work_size = (128, 128)
    output_size = (256, 256)

    image = Image.fromarray(image_path).convert('RGB')
    original_np = np.array(image.resize(work_size))

    if isinstance(model, (UNet, Segformer, Inception)):
        transform = transforms.Compose([
            transforms.Resize(work_size),
            transforms.ToTensor()
        ])
        input_tensor = transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(input_tensor)
            if isinstance(output, dict):
                # Explicit None check: `a or b` would call bool() on a
                # multi-element tensor, which raises a RuntimeError.
                logits = output.get("logits")
                output = logits if logits is not None else output.get("out")
            pred_mask = torch.argmax(output.squeeze(), dim=0).cpu().numpy()
    elif isinstance(model, (KMeans, GaussianMixture)):
        pixels = original_np.reshape(-1, 3)
        model.fit(pixels)
        pred_mask = model.predict(pixels).reshape(original_np.shape[:2])
    else:
        # Previously this fell through and failed later on an unbound pred_mask.
        raise TypeError(f"Unsupported model type: {type(model).__name__}")

    if postprocess_mode != 'none':
        pred_mask = postprocess([pred_mask], mode=postprocess_mode)[0]

    # Resize outputs to 256x256; nearest-neighbour keeps the mask strictly 0/255.
    bw_mask = cv2.resize(pred_mask.astype(np.uint8) * 255, output_size, interpolation=cv2.INTER_NEAREST)
    overlay = cv2.resize(
        overlay_mask(original_np, pred_mask, color=(255, 0, 0), alpha=alpha),
        output_size,
        interpolation=cv2.INTER_LINEAR,
    )

    return bw_mask, overlay