|
import os |
|
import sys |
|
import numpy as np |
|
import torch |
|
import matplotlib.pyplot as plt |
|
from cityscapesscripts.helpers.labels import labels as cs_labels |
|
from datasets.cityscapes import get_cs_labeldata |
|
from datasets.cocostuff import get_coco_labeldata |
|
from datasets.potsdam import get_pd_labeldata |
|
|
|
sys.path.append(os.getcwd()) |
|
import modules.transforms as transforms |
|
|
|
|
|
# Fixed Cityscapes display palette (29 rows). Row 27 is black; ignore pixels
# (255) in the ground truth are remapped to it below.
_CS_VIS_COLORMAP = np.array([
    [128, 64, 128],
    [244, 35, 232],
    [250, 170, 160],
    [230, 150, 140],
    [70, 70, 70],
    [102, 102, 156],
    [190, 153, 153],
    [180, 165, 180],
    [150, 100, 100],
    [150, 120, 90],
    [153, 153, 153],
    [153, 153, 153],
    [250, 170, 30],
    [220, 220, 0],
    [107, 142, 35],
    [152, 251, 152],
    [70, 130, 180],
    [220, 20, 60],
    [255, 0, 0],
    [0, 0, 142],
    [0, 0, 70],
    [0, 60, 100],
    [0, 0, 90],
    [0, 0, 110],
    [0, 80, 100],
    [0, 0, 230],
    [119, 11, 32],
    [0, 0, 0],
    [220, 220, 220]])


def visualize_segmentation(img = None,
                           label = None,
                           linear = None,
                           mlp = None,
                           cluster = None,
                           dataset_name = None,
                           additional = None,
                           additional_name = None,
                           additional2 = None,
                           additional_name2 = None,
                           legend = None,
                           name = None):
    """Render an image, its ground truth and optional predictions side by side.

    Every prediction map that is not None (``linear``, ``mlp``, ``cluster``,
    ``additional``) is colourised with the dataset colormap and drawn in its
    own panel. ``additional2`` is passed to ``imshow`` unchanged. Only panels
    that are actually drawn get a slot in the figure.

    Args:
        img: Image tensor; after ``squeeze(0)`` it must be channel-first
            ``(C, H, W)``. It is min-max rescaled to [0, 1] for display.
        label: Ground-truth class-id tensor; after ``squeeze(0)`` it must be
            ``(1, H, W)``. Pixels equal to 255 are drawn with colormap row 27.
            The caller's tensor is not modified.
        linear, mlp, cluster, additional: Optional ``(1, H, W)`` class-id maps.
            Values are cast to uint8 and used as colormap indices.
        dataset_name: "cityscapes", "cocostuff" or "potsdam"; selects the colormap.
        additional_name: Panel title for ``additional``.
        additional2: Optional tensor shown as-is (for example a heat map).
        additional_name2: Panel title for ``additional2``.
        legend: Unused; kept for backward compatibility.
        name: If given, the figure is also saved to this path.

    Returns:
        np.ndarray: The rendered figure as a ``(height, width, 3)`` uint8 RGB array.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported datasets.
    """
    if dataset_name == "cityscapes":
        colormap = _CS_VIS_COLORMAP
    elif dataset_name == "cocostuff":
        colormap = get_coco_labeldata()[-1]
    elif dataset_name == "potsdam":
        colormap = get_pd_labeldata()[-1]
    else:
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

    def _colorize(class_map):
        # (1, H, W) class ids -> (H, W, 3) colours.
        ids = class_map.cpu().numpy().transpose(1, 2, 0).astype('uint8')
        return colormap[ids.flatten()].reshape(ids.shape[0], ids.shape[1], 3)

    orig_h, orig_w = label.cpu().shape[-2:]

    img = img.cpu().squeeze(0).numpy().transpose(1, 2, 0)
    img = img - img.min()
    if img.max() > 0:  # a constant image would otherwise divide by zero
        img = img / img.max()

    label = label.cpu().squeeze(0).numpy().transpose(1, 2, 0)
    # np.where builds a new array: assigning into the numpy view would write
    # through to the caller's CPU tensor.
    label = np.where(label == 255, 27, label)
    colored_label = colormap[label.flatten()].reshape(orig_h, orig_w, 3)

    panels = [('Image', img), ('Ground Truth', colored_label)]
    if linear is not None:
        panels.append(('Linear', _colorize(linear)))
    if mlp is not None:
        panels.append(('MLP', _colorize(mlp)))
    if cluster is not None:
        panels.append(('Cluster', _colorize(cluster)))
    if additional is not None:
        panels.append((additional_name, _colorize(additional)))
    if additional2 is not None:
        panels.append((additional_name2, additional2.cpu().numpy()))

    fig = plt.figure(figsize=(8, 2), dpi=200)
    for position, (title, picture) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, len(panels), position)
        ax.set_title(title)
        ax.imshow(picture)
        ax.axis("off")

    if name is not None:
        fig.savefig(name)

    fig.canvas.draw()
    # buffer_rgba replaces tostring_rgb (deprecated in Matplotlib 3.8, removed
    # in 3.10); drop the alpha channel to keep the RGB output.
    data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    plt.close(fig)
    return data
|
|
|
|
|
|
|
def visualize_confusion_matrix(cls_names, meter, name=None):
    """Plot a column-normalised confusion matrix with per-cell percentages.

    Args:
        cls_names: Class names used as tick labels on both axes.
        meter: Object exposing ``histogram``, a square torch tensor of counts.
            Each column is divided by its sum over dim 0. Rows are labelled
            "Actuals" and columns "Predictions". Columns whose sum is zero
            become NaN.
        name: If given, the figure is also saved to this path.

    Returns:
        np.ndarray: The rendered figure as a ``(height, width, 3)`` uint8 RGB array.
    """
    histogram = meter.histogram
    conf_matrix = np.array((histogram / histogram.sum(dim=0)).cpu(), dtype=np.float16)

    # Cell shading ignores the diagonal so off-diagonal confusions stand out;
    # the printed percentages still include the diagonal values.
    shaded = conf_matrix.astype(np.float32)
    np.fill_diagonal(shaded, 0)

    fig, ax = plt.subplots(figsize=(15, 15))
    ax.matshow(shaded, cmap=plt.cm.Blues, alpha=0.8)
    for (row, col), value in np.ndenumerate(conf_matrix):
        ax.text(x=col, y=row, s=(value * 100).round(1), va='center', ha='center', size='large')

    ticks = list(range(len(cls_names)))
    ax.set_xticks(ticks)
    ax.set_xticklabels(cls_names, rotation=90, ha='center', fontsize=12)
    ax.set_yticks(ticks)
    ax.set_yticklabels(cls_names, fontsize=12)
    ax.set_xlabel('Predictions', fontsize=18)
    ax.set_ylabel('Actuals', fontsize=18)
    ax.set_title('Confusion Matrix', fontsize=18)

    if name is not None:
        fig.savefig(name)

    fig.canvas.draw()
    # buffer_rgba replaces the deprecated tostring_rgb; keep only RGB.
    data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    plt.close(fig)
    return data
|
|
|
|
|
|
|
|
|
|
|
def batch_visualize_segmentation(img = None,
                                 label = None,
                                 in1 = None,
                                 in2 = None,
                                 in3 = None,
                                 in4 = None,
                                 dataset_name = None):
    """Render one row per batch element and stack the rows vertically.

    Each row shows the image, its ground truth and every extra input that is
    not None. Titles are drawn only on the first row.

    Args:
        img: Batch of channel-first images; each element must be ``(C, H, W)``.
        label: Batch of class-id maps; each element must be ``(1, H, W)``.
            Ids above 27 are drawn with colormap row 27. The caller's tensor
            is not modified.
        in1, in2, in3, in4: Optional ``(title, batch)`` pairs. ``batch[idx]``
            is the class-id map for element ``idx``; its values are cast to
            uint8 and used as colormap indices.
        dataset_name: "cityscapes", "cocostuff" or "potsdam"; selects the colormap.

    Returns:
        np.ndarray: uint8 RGB array with the per-element renderings stacked
        along axis 0.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported datasets.
    """
    if dataset_name == "cityscapes":
        colormap = get_cs_labeldata()[-1]
    elif dataset_name == "cocostuff":
        colormap = get_coco_labeldata()[-1]
    elif dataset_name == "potsdam":
        colormap = get_pd_labeldata()[-1]
    else:
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

    extra_inputs = [entry for entry in (in1, in2, in3, in4) if entry is not None]
    num_subplots = len(extra_inputs) + 2

    def _colorize(class_map):
        # (1, H, W) class ids -> (H, W, 3) colours.
        ids = class_map.cpu().numpy().transpose(1, 2, 0).astype('uint8')
        return colormap[ids.flatten()].reshape(ids.shape[0], ids.shape[1], 3)

    def _vis_one_img(idx, image, gt, predictions):
        # Render a single batch element as one figure row and return its pixels.
        orig_h, orig_w = gt.cpu().shape[-2:]

        image = image.cpu().numpy().transpose(1, 2, 0)
        image = image - image.min()
        if image.max() > 0:  # a constant image would otherwise divide by zero
            image = image / image.max()

        # np.minimum returns a new array; in-place assignment into the numpy
        # view would modify the caller's CPU label tensor.
        gt = np.minimum(gt.cpu().numpy().transpose(1, 2, 0), 27)
        colored_gt = colormap[gt.flatten()].reshape(orig_h, orig_w, 3)

        panels = [('Image', image), ('Ground Truth', colored_gt)]
        panels += [(title, _colorize(pred)) for title, pred in predictions]

        fig = plt.figure(figsize=(10, 2), dpi=150)
        for position, (title, picture) in enumerate(panels, start=1):
            ax = fig.add_subplot(1, num_subplots, position)
            if idx == 0:
                ax.set_title(title)
            ax.imshow(picture)
            ax.axis("off")

        fig.canvas.draw()
        # Read the buffer before closing; buffer_rgba replaces the deprecated
        # tostring_rgb.
        one_vis = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
        plt.close(fig)
        return one_vis

    rows = [
        _vis_one_img(idx, image, gt,
                     [(entry[0], entry[1][idx].unsqueeze(0)) for entry in extra_inputs])
        for idx, (image, gt) in enumerate(zip(img, label))
    ]
    return np.vstack(rows)
|
|
|
|
|
|
|
def visualize_single_masks(img,
                           label,
                           data,
                           dataset_name = None):
    """Plot each mask candidate's processing stages in an 8-column grid.

    There is one row per entry of ``data['sim']``. The columns are Image, GT,
    1.Eig, 1.EigNN, +Thresh, +CRF, PAMR and Mask. Titles appear only on the
    first row.

    Args:
        img: Image tensor; after ``squeeze(0)`` it must be channel-first
            ``(C, H, W)``. It is min-max rescaled for display.
        label: Ground-truth class-id tensor; ``squeeze(0).squeeze(0)`` must
            leave a 2-D map.
        data: Dict whose keys 'sim', 'nnsim', 'nnsim_tresh', 'crf', 'pamr' and
            'outmask' hold per-row items; they are zipped, so iteration stops
            at the shortest. 'sim', 'nnsim' and 'pamr' items are tensors;
            'outmask' items are CPU tensors indexed as 2-D. The 'outmask'
            tensors are not modified.
        dataset_name: "cityscapes", "cocostuff" or "potsdam"; selects the colormap.

    Returns:
        np.ndarray: The rendered grid as a ``(height, width, 3)`` uint8 RGB array.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported datasets.
    """
    if dataset_name == "cityscapes":
        colormap = get_cs_labeldata()[-1]
    elif dataset_name == "cocostuff":
        colormap = get_coco_labeldata()[-1]
    elif dataset_name == "potsdam":
        colormap = get_pd_labeldata()[-1]
    else:
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

    rows = len(data['sim'])
    cols = 8

    # Row-invariant panels are computed once instead of once per row.
    shifted = img - img.min()
    span = shifted.max()
    shown_img = (shifted / span if span > 0 else shifted).squeeze(0).permute(1, 2, 0).cpu()
    plotlabel = colormap[label.squeeze(0).squeeze(0).int().cpu().numpy()]

    # NOTE(review): width scales with the number of rows while the height is
    # fixed, although the grid has 8 columns; confirm this is intended.
    fig = plt.figure(figsize=(rows * 2, 7 * 2), dpi=150)

    stages = zip(data['sim'], data['nnsim'], data['nnsim_tresh'],
                 data['crf'], data['pamr'], data['outmask'])
    for indx, (sim, nnsim, nnsim_thresh, crf, pamr, mask) in enumerate(stages):
        # Copy before editing so the caller's mask tensor stays untouched.
        mask = mask.clone()
        # Forces one zero pixel, presumably so the Greys colormap always
        # scales from 0 even for all-foreground masks.
        mask[0, 0] = 0

        panels = [
            ('Image', shown_img, {}),
            ('GT', plotlabel, {}),
            ('1.Eig', sim.cpu().numpy(), {}),
            ('1.EigNN', nnsim.cpu().numpy(), {}),
            ('+Thresh', nnsim_thresh, {}),
            ('+CRF', crf, {}),
            ('PAMR', pamr.squeeze().cpu().numpy(), {}),
            ('Mask', mask.numpy(), {'cmap': 'Greys'}),
        ]
        for column, (title, picture, imshow_kwargs) in enumerate(panels, start=1):
            ax = fig.add_subplot(rows, cols, column + indx * cols)
            if indx == 0:
                ax.set_title(title)
            ax.imshow(picture, **imshow_kwargs)
            ax.axis('off')

    fig.canvas.draw()
    # Read the buffer before closing; buffer_rgba replaces the deprecated
    # tostring_rgb.
    one_vis = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    plt.close(fig)
    return one_vis
|
|
|
|
|
|
|
|
|
|
|
def visualize_pseudo_paper(img,
                           label,
                           pseudo_gt,
                           pseudo_plain,
                           dataset_name = None,
                           save_name = None):
    """Save paper-style figures of an image, its GT and two pseudo-label maps.

    Writes ``<save_name>.pdf`` with four panels: the image, the GT, the
    colourised ``pseudo_gt`` and the randomly coloured ``pseudo_plain``. It
    also writes one PNG per panel into a ``singleimgs`` directory next to
    ``save_name``, named ``<basename>_{img,gt,pseudo,pseudoc}.png``.

    Args:
        img: Image tensor; after ``squeeze(0)`` it must be channel-first
            ``(C, H, W)``.
        label: Ground-truth class-id tensor indexed into the dataset colormap.
        pseudo_gt: Pseudo-label class-id tensor indexed into the dataset colormap.
        pseudo_plain: Segment-id tensor drawn with a fixed random palette.
            Ids must lie in [0, 400]; 255 is drawn black. The caller's tensor
            is not modified.
        dataset_name: "cityscapes", "cocostuff" or "potsdam"; selects the colormap.
        save_name: Output path prefix (required).

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported datasets.
    """
    if dataset_name == "cityscapes":
        colormap = get_cs_labeldata()[-1]
    elif dataset_name == "cocostuff":
        colormap = get_coco_labeldata()[-1]
    elif dataset_name == "potsdam":
        colormap = get_pd_labeldata()[-1]
    else:
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

    # 400 fixed random colours plus black at index 400 (used for ignore=255).
    # A local RandomState(0) produces the same stream as np.random.seed(0)
    # without resetting the global NumPy RNG.
    rng = np.random.RandomState(0)
    palette = np.vstack([rng.randint(0, 255, size=(400, 3)), [[0, 0, 0]]])
    plain_ids = pseudo_plain.int().cpu().numpy()
    plain_ids = np.where(plain_ids == 255, 400, plain_ids)  # new array, no aliasing
    plain_colored = palette[plain_ids].squeeze()

    shifted = img - img.min()
    span = shifted.max()
    shown_img = (shifted / span if span > 0 else shifted).squeeze(0).permute(1, 2, 0).cpu()
    plotlabel = colormap[label.squeeze(0).squeeze(0).int().cpu().numpy()]
    plotpseudo = colormap[pseudo_gt.squeeze(0).squeeze(0).int().cpu().numpy()]
    pictures = [shown_img, plotlabel, plotpseudo, plain_colored]

    fig = plt.figure(figsize=(8, 2), dpi=150)
    fig.subplots_adjust(left=0.1,
                        bottom=0.1,
                        right=0.5,
                        top=0.5,
                        wspace=0.05,
                        hspace=0.0)
    for position, picture in enumerate(pictures, start=1):
        ax = fig.add_subplot(1, 4, position)
        ax.imshow(picture)
        ax.axis('off')
    fig.savefig(save_name + '.pdf', bbox_inches='tight', pad_inches=0.0)
    plt.close(fig)

    single_dir = os.path.join(os.path.dirname(save_name), 'singleimgs')
    os.makedirs(single_dir, exist_ok=True)
    base_name = os.path.split(save_name)[-1]
    for picture, suffix in zip(pictures, ['img', 'gt', 'pseudo', 'pseudoc']):
        single_fig = plt.figure(figsize=(2, 2), dpi=300)
        plt.imshow(picture)
        plt.axis('off')
        single_fig.savefig(os.path.join(single_dir, base_name + '_' + suffix + '.png'),
                           bbox_inches='tight', pad_inches=0.0)
        # Close each figure; the original leaked five figures per call.
        plt.close(single_fig)
|
|
|
|
|
|
|
|
|
|
|
def logits_to_image(logits = None,
                    img = None,
                    label = None,
                    dataset_name = None,
                    save_path = None,
                    save_imggt = False):
    """Save a colourised class-id map, and optionally the image and GT, as PNGs.

    Always writes ``<save_path>_pred.png``. When ``save_imggt`` is true it also
    writes ``<save_path>_img.png`` and ``<save_path>_gt.png``.

    Args:
        logits: ``(1, H, W)`` map whose values are cast to uint8 and used
            directly as colormap indices (i.e. already arg-maxed class ids).
        img: Channel-first ``(C, H, W)`` image; used only when ``save_imggt``.
        label: ``(1, H, W)`` ground-truth ids; used only when ``save_imggt``.
            Ids above 27 are drawn with colormap row 27. The caller's tensor
            is not modified.
        dataset_name: "cityscapes", "cocostuff" or "potsdam"; selects the colormap.
        save_path: Output path prefix (required).
        save_imggt: Whether to also save the image and the ground truth.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported datasets.
    """
    if dataset_name == "cityscapes":
        colormap = get_cs_labeldata()[-1]
    elif dataset_name == "cocostuff":
        colormap = get_coco_labeldata()[-1]
    elif dataset_name == "potsdam":
        colormap = get_pd_labeldata()[-1]
    else:
        raise ValueError(f"Unsupported dataset_name: {dataset_name!r}")

    def _save_single(picture, suffix):
        # Save one borderless 2x2-inch, 400-dpi PNG and release the figure.
        fig = plt.figure(figsize=(2, 2), dpi=400)
        ax = fig.add_subplot(1, 1, 1)
        ax.imshow(picture)
        ax.axis("off")
        fig.savefig(save_path + suffix, bbox_inches='tight', pad_inches=0.0)
        plt.close(fig)

    vis = logits.cpu().numpy().transpose(1, 2, 0).astype('uint8')
    vis = colormap[vis.flatten()].reshape(vis.shape[0], vis.shape[1], 3)
    _save_single(vis, '_pred.png')

    if save_imggt:
        orig_h, orig_w = label.cpu().shape[-2:]

        image = img.cpu().numpy().transpose(1, 2, 0)
        image = image - image.min()
        if image.max() > 0:  # a constant image would otherwise divide by zero
            image = image / image.max()

        # np.minimum returns a new array so the caller's label is not mutated.
        gt = np.minimum(label.cpu().numpy().transpose(1, 2, 0), 27)
        colored_label = colormap[gt.flatten()].reshape(orig_h, orig_w, 3)

        _save_single(image, '_img.png')
        _save_single(colored_label, '_gt.png')
|
|
|
|
|
|
|
|
|
class Vis_Demo:
    """Colourise class-id maps with the COCO-Stuff colormap for demos."""

    def __init__(self):
        super().__init__()
        # The colormap is the last item returned by get_coco_labeldata().
        self.colormap = get_coco_labeldata()[-1]

    def apply_colors(self, logits):
        """Map a ``(1, H, W)`` class-id tensor to an ``(H, W, 3)`` colour array.

        Values are cast to uint8 (fractions truncated) and used as row indices
        into ``self.colormap``.
        """
        ids = logits.cpu().numpy().transpose(1, 2, 0).astype('uint8')
        height, width = ids.shape[0], ids.shape[1]
        colours = self.colormap[ids.ravel()]
        return colours.reshape(height, width, 3)
|
|
|
|
|
|
|
def visualize_demo(img, pseudo, alpha = 0.5):
    """Blend an image with a randomly coloured segment-id map.

    Args:
        img: Channel-first ``(C, H, W)`` image tensor. It is un-normalised with
            ``transforms.UnNormalize`` and scaled by 255.
        pseudo: Integer segment-id tensor whose ids must lie in [0, 400]; 255 is
            drawn black. The colour array is ``squeeze()``-d before blending.
            The caller's tensor is not modified.
        alpha: Weight of the image; ``1 - alpha`` goes to the colour map.

    Returns:
        np.ndarray: The blended image, cast (truncated) to uint8.
    """
    # 400 fixed random colours plus black at index 400 (used for ignore=255).
    # A local RandomState(0) produces the same colours as np.random.seed(0)
    # without resetting the global NumPy RNG on every call.
    rng = np.random.RandomState(0)
    palette = np.vstack([rng.randint(0, 255, size=(400, 3)), [[0, 0, 0]]])

    # .long().cpu().numpy() can alias the caller's tensor; np.where copies.
    ids = pseudo.long().cpu().numpy()
    ids = np.where(ids == 255, 400, ids)
    colored = palette[ids].squeeze()

    img = transforms.UnNormalize()(img) * 255
    img = img.permute(1, 2, 0).long().cpu().numpy()
    out = alpha * img + (1 - alpha) * colored

    return np.array(out, dtype=np.uint8)
|
|
|
|