import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from sklearn.cluster import KMeans
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    accuracy_score,
    confusion_matrix,
    f1_score,
    jaccard_score,
)
from sklearn.mixture import GaussianMixture
from torchvision import transforms

from supervised import Inception, Segformer, UNet

# Side length (square) the image is resized to before being fed to a model.
_INPUT_SIZE = (128, 128)
# Side length (square) of the mask / overlay returned to the caller.
_OUTPUT_SIZE = (256, 256)

# Maps the public ``mode`` names of ``postprocess`` to OpenCV morphology ops.
_MORPH_OPS = {
    "open": cv2.MORPH_OPEN,
    "close": cv2.MORPH_CLOSE,
    "erosion": cv2.MORPH_ERODE,
    "dilation": cv2.MORPH_DILATE,
}


def postprocess(masks, mode="open", kernel_size=5, iters=1):
    """Apply a morphological operation to every mask in *masks*.

    Args:
        masks: Iterable of 2-D numpy arrays; each is cast to ``uint8`` before
            the operation is applied.
        mode: One of ``"open"``, ``"close"``, ``"erosion"``, ``"dilation"``.
            Any other value disables post-processing.
        kernel_size: Side length of the square all-ones structuring element.
        iters: Number of times the operation is applied.

    Returns:
        A new list of ``uint8`` masks for a recognised *mode*; otherwise the
        *masks* argument itself, untouched.
    """
    op = _MORPH_OPS.get(mode)
    if op is None:
        # Unknown mode: deliberately a no-op rather than an error.
        return masks

    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    # morphologyEx with MORPH_ERODE / MORPH_DILATE is equivalent to calling
    # cv2.erode / cv2.dilate directly, so one call covers all four modes.
    return [
        cv2.morphologyEx(mask.astype(np.uint8), op, kernel, iterations=iters)
        for mask in masks
    ]


def overlay_mask(image, mask, color=(255, 0, 0), alpha=0.5):
    """Overlay a binary mask on top of an image.

    Args:
        image: ``(H, W, 3)`` numpy array, RGB.
        mask: ``(H, W)`` numpy array holding 0/1 or 0/255 values. If its
            maximum exceeds 1 it is binarised with a ``> 127`` threshold, so
            small integer labels such as 2 or 3 end up as background.
        color: RGB tuple used to tint the masked pixels.
        alpha: Blend factor (0 = image only, 1 = solid *color*).

    Returns:
        ``(H, W, 3)`` ``uint8`` array; pixels where *mask* is non-zero are
        blended with *color*, the rest are copied from *image*.
    """
    # ``mask.size`` guard: max() of an empty array raises ValueError.
    if mask.size and mask.max() > 1:
        mask = (mask > 127).astype(np.uint8)

    # Cast the colour to the image dtype so the blend matches a colour plane
    # built with ``np.zeros_like(image)``.
    tint = np.asarray(color, dtype=image.dtype)
    blended = (1 - alpha) * image + alpha * tint

    # The trailing axis lets the (H, W) mask broadcast across the 3 channels.
    overlay = np.where(mask[:, :, np.newaxis], blended, image)
    return overlay.astype(np.uint8)


def _load_rgb_image(source):
    """Return *source* (path, PIL image, or array) as an RGB PIL image."""
    if isinstance(source, Image.Image):
        return source.convert('RGB')
    if isinstance(source, (str, os.PathLike)):
        # convert() loads the pixel data, so the file can be closed right away.
        with Image.open(source) as img:
            return img.convert('RGB')
    return Image.fromarray(source).convert('RGB')


def predict_and_visualize_single(model, image_path, postprocess_mode='none', alpha=0.5, device='cpu'):
    """Segment one image and return a black/white mask plus a red overlay.

    Args:
        model: Either a torch model (``UNet``, ``Segformer``, ``Inception``)
            or a scikit-learn ``KMeans`` / ``GaussianMixture`` estimator.
            Clustering estimators are re-fitted in place on the pixels of
            this image. Torch models are called as-is; only the input tensor
            is moved to *device*.
        image_path: The image to segment. Despite the name, a numpy array is
            accepted (the original behaviour), as well as a PIL image or a
            filesystem path.
        postprocess_mode: ``'none'`` or a ``mode`` understood by
            ``postprocess``.
        alpha: Blend factor forwarded to ``overlay_mask``.
        device: Torch device the input tensor is moved to.

    Returns:
        ``(bw_mask, overlay)``: a 256x256 ``uint8`` mask (prediction * 255)
        and a 256x256x3 ``uint8`` RGB overlay.

    Raises:
        TypeError: If *model* is not one of the supported types.
        ValueError: If a torch model returns a dict with neither a
            ``"logits"`` nor an ``"out"`` entry.
    """
    image = _load_rgb_image(image_path)
    original_np = np.array(image.resize(_INPUT_SIZE))

    if isinstance(model, (UNet, Segformer, Inception)):
        transform = transforms.Compose([
            transforms.Resize(_INPUT_SIZE),
            transforms.ToTensor()
        ])
        input_tensor = transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            output = model(input_tensor)
        if isinstance(output, dict):
            # Do not use ``a or b`` here: the truth value of a multi-element
            # tensor is ambiguous and raises RuntimeError.
            logits = output.get("logits")
            output = logits if logits is not None else output.get("out")
            if output is None:
                raise ValueError(
                    "Model output dict has neither 'logits' nor 'out'"
                )
        # Class index with the highest score per pixel.
        pred_mask = torch.argmax(output.squeeze(), dim=0).cpu().numpy()
    elif isinstance(model, (KMeans, GaussianMixture)):
        pixels = original_np.reshape(-1, 3)
        model.fit(pixels)
        pred_mask = model.predict(pixels).reshape(original_np.shape[:2])
    else:
        raise TypeError(
            f"Unsupported model type: {type(model).__name__}"
        )

    if postprocess_mode != 'none':
        pred_mask = postprocess([pred_mask], mode=postprocess_mode)[0]

    # Resize outputs to 256x256. Nearest-neighbour keeps the mask strictly
    # two-valued; the overlay is a regular image so linear is fine.
    bw_mask = cv2.resize(
        pred_mask.astype(np.uint8) * 255,
        _OUTPUT_SIZE,
        interpolation=cv2.INTER_NEAREST,
    )
    overlay = cv2.resize(
        overlay_mask(original_np, pred_mask, color=(255, 0, 0), alpha=alpha),
        _OUTPUT_SIZE,
        interpolation=cv2.INTER_LINEAR,
    )
    return bw_mask, overlay