PriMaPs / modules /visualization.py
Oliver Hahn
add demo
1dc26c7
raw
history blame contribute delete
16.1 kB
import os
import sys
import numpy as np
import torch
import matplotlib.pyplot as plt
from cityscapesscripts.helpers.labels import labels as cs_labels
from datasets.cityscapes import get_cs_labeldata
from datasets.cocostuff import get_coco_labeldata
from datasets.potsdam import get_pd_labeldata
sys.path.append(os.getcwd())
import modules.transforms as transforms
# Fixed Cityscapes palette used by ``visualize_segmentation``. Index 27 is
# black; the ground-truth ignore value 255 is remapped to it before lookup.
_CS_VIS_COLORMAP = np.array([
    [128, 64, 128],
    [244, 35, 232],
    [250, 170, 160],
    [230, 150, 140],
    [70, 70, 70],
    [102, 102, 156],
    [190, 153, 153],
    [180, 165, 180],
    [150, 100, 100],
    [150, 120, 90],
    [153, 153, 153],
    [153, 153, 153],
    [250, 170, 30],
    [220, 220, 0],
    [107, 142, 35],
    [152, 251, 152],
    [70, 130, 180],
    [220, 20, 60],
    [255, 0, 0],
    [0, 0, 142],
    [0, 0, 70],
    [0, 60, 100],
    [0, 0, 90],
    [0, 0, 110],
    [0, 80, 100],
    [0, 0, 230],
    [119, 11, 32],
    [0, 0, 0],
    [220, 220, 220]])


def _colorize_index_map(index_map, colormap):
    """Turn a (1, H, W) class-index tensor into an (H, W, 3) colour array.

    Values are cast to ``uint8`` before the palette lookup.
    """
    index_map = index_map.cpu().numpy().transpose(1, 2, 0).astype('uint8')
    height, width = index_map.shape[:2]
    return colormap[index_map.flatten()].reshape(height, width, 3)


def visualize_segmentation(img = None,
                           label = None,
                           linear = None,
                           mlp = None,
                           cluster = None,
                           dataset_name = None,
                           additional = None,
                           additional_name = None,
                           additional2 = None,
                           additional_name2 = None,
                           legend = None,
                           name = None):
    """Draw an image, its ground truth and optional predictions in one row.

    Args:
        img: Image tensor; after ``squeeze(0)`` it is transposed with
            ``(1, 2, 0)``, so presumably (1, C, H, W) — TODO confirm.
            It is min-max normalised for display.
        label: Ground-truth index tensor, presumably (1, 1, H, W); the
            ignore value 255 is drawn with palette entry 27. The caller's
            tensor is not modified.
        linear, mlp, cluster, additional: Optional (1, H, W) class-index
            predictions, colourised with the dataset palette and titled
            'Linear', 'MLP', 'Cluster' and ``additional_name``.
        dataset_name: "cityscapes" or "cocostuff".
        additional2: Optional tensor shown as-is (no palette), titled
            ``additional_name2``.
        legend: Unused; kept for backward compatibility.
        name: Optional path; if given the figure is also saved there.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` and shape (height, width, 3)
        holding the rendered figure.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    if dataset_name == "cityscapes":
        colormap = _CS_VIS_COLORMAP
    elif dataset_name == "cocostuff":
        colormap = get_coco_labeldata()[-1]
    else:
        raise ValueError(
            f"visualize_segmentation: unsupported dataset_name {dataset_name!r}")

    orig_h, orig_w = label.cpu().shape[-2:]

    img = img.cpu().squeeze(0).numpy().transpose(1, 2, 0)
    img = img - img.min()
    if img.max() > 0:  # a constant image would otherwise become 0/0 = NaN
        img = img / img.max()

    # ``.cpu().numpy()`` shares memory with a CPU tensor; copy so the in-place
    # ignore-label remap below does not modify the caller's ``label``.
    label = label.cpu().squeeze(0).numpy().transpose(1, 2, 0).copy()
    label[label == 255] = 27
    colored_label = colormap[label.flatten()].reshape(orig_h, orig_w, 3)

    # Only panels that are actually present get a subplot slot.
    panels = [('Image', img), ('Ground Truth', colored_label)]
    for title, prediction in (('Linear', linear),
                              ('MLP', mlp),
                              ('Cluster', cluster),
                              (additional_name, additional)):
        if prediction is not None:
            panels.append((title, _colorize_index_map(prediction, colormap)))
    if additional2 is not None:
        # This panel is drawn without the dataset palette.
        panels.append((additional_name2, additional2.cpu().numpy()))

    fig = plt.figure(figsize=(8, 2), dpi=200)
    for position, (title, panel) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, len(panels), position)
        ax.set_title(title)
        ax.imshow(panel)
        ax.axis('off')

    if name is not None:
        fig.savefig(name)
    fig.canvas.draw()
    # ``tostring_rgb`` was deprecated in Matplotlib 3.8 and removed in 3.10;
    # read the Agg RGBA buffer instead and drop the alpha channel.
    data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    plt.close(fig)  # close only this figure, not every open figure
    return data
def visualize_confusion_matrix(cls_names, meter, name=None):
    """Plot a column-normalised confusion matrix and return it as an image.

    Args:
        cls_names: Class names used as tick labels on both axes; its length
            should match the size of ``meter.histogram``.
        meter: Object exposing a square ``histogram`` tensor. In the plot,
            rows are labelled "Actuals" and columns "Predictions".
        name: Optional path; if given the figure is also saved there.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` and shape (height, width, 3)
        holding the rendered figure.
    """
    histogram = meter.histogram
    # Normalise each column by its total. Columns with no entries are divided
    # by 1 instead of 0, so they show 0.0 instead of NaN.
    col_totals = histogram.sum(dim=0, keepdim=True)
    col_totals = torch.where(col_totals > 0, col_totals, torch.ones_like(col_totals))
    conf_matrix = (histogram / col_totals).cpu().double().numpy()

    # The diagonal is zeroed only in the colour image; the text annotations
    # below still show the diagonal values.
    shaded = conf_matrix.copy()
    np.fill_diagonal(shaded, 0)

    fig, ax = plt.subplots(figsize=(15, 15))
    ax.matshow(shaded, cmap=plt.cm.Blues, alpha=0.8)
    for (row, col), value in np.ndenumerate(conf_matrix):
        ax.text(x=col, y=row, s=f"{value * 100:.1f}",
                va='center', ha='center', size='large')

    ticks = list(range(len(cls_names)))
    ax.set_xticks(ticks)
    ax.set_xticklabels(cls_names, rotation=90, ha='center', fontsize=12)
    ax.set_yticks(ticks)
    ax.set_yticklabels(cls_names, fontsize=12)
    ax.set_xlabel('Predictions', fontsize=18)
    ax.set_ylabel('Actuals', fontsize=18)
    ax.set_title('Confusion Matrix', fontsize=18)

    if name is not None:
        fig.savefig(name)
    fig.canvas.draw()
    # ``tostring_rgb`` was removed in Matplotlib 3.10; use the RGBA buffer.
    data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    plt.close(fig)  # close only this figure
    return data
def batch_visualize_segmentation(img = None,
                                 label = None,
                                 in1 = None,
                                 in2 = None,
                                 in3 = None,
                                 in4 = None,
                                 dataset_name = None):
    """Render one figure row per batch element and stack the rows vertically.

    Args:
        img: Batch of images; each element is transposed with ``(1, 2, 0)``,
            so presumably (B, C, H, W) — TODO confirm. Each image is min-max
            normalised for display.
        label: Batch of ground-truth index maps, presumably (B, 1, H, W).
            Values above 27 are drawn with palette entry 27. The caller's
            tensor is not modified.
        in1, in2, in3, in4: Optional ``(title, batch_prediction)`` pairs.
            ``batch_prediction[idx]`` is unsqueezed to (1, H, W) and
            colourised with the dataset palette.
        dataset_name: "cityscapes", "cocostuff" or "potsdam".

    Returns:
        ``np.ndarray`` (``uint8``, RGB) with the per-sample figures stacked
        vertically. Column titles are drawn on the first row only.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    if dataset_name == "cityscapes":
        colormap = get_cs_labeldata()[-1]
    elif dataset_name == "cocostuff":
        colormap = get_coco_labeldata()[-1]
    elif dataset_name == "potsdam":
        colormap = get_pd_labeldata()[-1]
    else:
        raise ValueError(
            f"batch_visualize_segmentation: unsupported dataset_name {dataset_name!r}")

    extra_inputs = [pair for pair in (in1, in2, in3, in4) if pair is not None]

    def _colorize(index_map):
        # (1, H, W) class indices -> (H, W, 3) palette colours.
        index_map = index_map.cpu().numpy().transpose(1, 2, 0).astype('uint8')
        return colormap[index_map.flatten()].reshape(index_map.shape[0], index_map.shape[1], 3)

    def _vis_one_img(idx, sample_img, sample_label, ins):
        # Render a single sample row and return it as an RGB array.
        orig_h, orig_w = sample_label.cpu().shape[-2:]
        sample_img = sample_img.cpu().numpy().transpose(1, 2, 0)
        sample_img = sample_img - sample_img.min()
        if sample_img.max() > 0:  # avoid 0/0 for a constant image
            sample_img = sample_img / sample_img.max()
        # ``.numpy()`` shares memory with a CPU tensor; copy before clamping
        # in place so the caller's label batch stays untouched.
        sample_label = sample_label.cpu().numpy().transpose(1, 2, 0).copy()
        sample_label[sample_label > 27] = 27
        colored_label = colormap[sample_label.flatten()].reshape(orig_h, orig_w, 3)

        panels = [('Image', sample_img), ('Ground Truth', colored_label)]
        panels += [(title, _colorize(prediction)) for title, prediction in ins]

        fig = plt.figure(figsize=(10, 2), dpi=150)
        for position, (title, panel) in enumerate(panels, start=1):
            ax = fig.add_subplot(1, len(panels), position)
            if idx == 0:
                ax.set_title(title)
            ax.imshow(panel)
            ax.axis("off")
        fig.canvas.draw()
        # ``tostring_rgb`` was removed in Matplotlib 3.10; use the RGBA buffer.
        one_vis = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
        plt.close(fig)
        return one_vis

    rows = []
    for idx, (sample_img, sample_label) in enumerate(zip(img, label)):
        ins = [(pair[0], pair[1][idx].unsqueeze(0)) for pair in extra_inputs]
        rows.append(_vis_one_img(idx, sample_img, sample_label, ins))
    return np.vstack(rows)
def visualize_single_masks(img,
                           label,
                           data,
                           dataset_name = None):
    """Plot the intermediate stages of mask generation, one row per mask.

    Each row shows: Image, GT, 1.Eig, 1.EigNN, +Thresh, +CRF, PAMR, Mask.

    Args:
        img: Image tensor, min-max normalised for display; after
            ``squeeze(0)`` it is permuted ``(1, 2, 0)``, so presumably
            (1, C, H, W) — TODO confirm.
        label: Ground-truth index tensor; ``squeeze(0)`` is applied twice
            before the palette lookup.
        data: Mapping with equally long sequences under 'sim', 'nnsim',
            'nnsim_tresh', 'crf', 'pamr' and 'outmask'. One row is drawn per
            entry. Entries of 'outmask' are no longer modified.
        dataset_name: "cityscapes", "cocostuff" or "potsdam".

    Returns:
        ``np.ndarray`` of dtype ``uint8`` and shape (height, width, 3)
        holding the rendered figure.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    if dataset_name == "cityscapes":
        colormap = get_cs_labeldata()[-1]
    elif dataset_name == "cocostuff":
        colormap = get_coco_labeldata()[-1]
    elif dataset_name == "potsdam":
        colormap = get_pd_labeldata()[-1]
    else:
        raise ValueError(
            f"visualize_single_masks: unsupported dataset_name {dataset_name!r}")

    rows = len(data['sim'])
    cols = 8
    # NOTE(review): width scales with the number of rows while the height is
    # fixed; the arguments may be swapped — confirm before changing.
    fig = plt.figure(figsize=(rows * 2, 7 * 2), dpi=150)

    # Row-independent panels, computed once instead of once per row.
    plotlabel = colormap[label.squeeze(0).squeeze(0).int().cpu()]
    img = img - img.min()
    if img.max() > 0:  # avoid 0/0 for a constant image
        img = img / img.max()
    img_display = img.squeeze(0).permute(1, 2, 0).cpu()

    stages = zip(data['sim'], data['nnsim'], data['nnsim_tresh'],
                 data['crf'], data['pamr'], data['outmask'])
    for row, (sim, nnsim, nnsim_thresh, crf, pamr, mask) in enumerate(stages):
        # Work on a copy so the caller's mask is left unmodified. Pixel
        # (0, 0) is forced to 0, presumably so the 'Greys' colour scale
        # always includes 0 — TODO confirm.
        mask_display = mask.numpy().copy()
        mask_display[0, 0] = 0

        panels = [
            ('Image', img_display, {}),
            ('GT', plotlabel, {}),
            ('1.Eig', sim.cpu().numpy(), {}),
            ('1.EigNN', nnsim.cpu().numpy(), {}),
            ('+Thresh', nnsim_thresh, {}),
            ('+CRF', crf, {}),
            ('PAMR', pamr.squeeze().cpu().numpy(), {}),
            ('Mask', mask_display, {'cmap': 'Greys'}),
        ]
        for col, (title, panel, imshow_kwargs) in enumerate(panels, start=1):
            ax = fig.add_subplot(rows, cols, col + row * cols)
            if row == 0:
                ax.set_title(title)
            ax.imshow(panel, **imshow_kwargs)
            ax.axis('off')

    fig.canvas.draw()
    # ``tostring_rgb`` was removed in Matplotlib 3.10; use the RGBA buffer.
    one_vis = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    plt.close(fig)
    return one_vis
def visualize_pseudo_paper(img,
                           label,
                           pseudo_gt,
                           pseudo_plain,
                           dataset_name = None,
                           save_name = None):
    """Save a four-panel paper figure plus each panel as its own PNG.

    Panels, in order:
    1. Image (min-max normalised).
    2. Ground truth (dataset palette).
    3. ``pseudo_gt`` (dataset palette).
    4. ``pseudo_plain`` coloured with a fixed pseudo-random 400-colour
       palette, where the value 255 is drawn black.

    Files written:
    - ``<save_name>.pdf``
    - ``<dirname(save_name)>/singleimgs/<basename>_{img,gt,pseudo,pseudoc}.png``

    Every figure this function opens is closed again. Neither the global
    NumPy RNG nor the caller's ``pseudo_plain`` tensor is modified.

    Raises:
        ValueError: If ``dataset_name`` is unsupported or ``save_name`` is
            None.
    """
    if dataset_name == "cityscapes":
        colormap = get_cs_labeldata()[-1]
    elif dataset_name == "cocostuff":
        colormap = get_coco_labeldata()[-1]
    elif dataset_name == "potsdam":
        colormap = get_pd_labeldata()[-1]
    else:
        raise ValueError(
            f"visualize_pseudo_paper: unsupported dataset_name {dataset_name!r}")
    if save_name is None:
        raise ValueError("visualize_pseudo_paper: save_name is required")

    # A private RandomState seeded with 0 draws the same colours as seeding
    # the global generator, without resetting the caller's global RNG.
    rng = np.random.RandomState(0)
    cb_colormap = np.array(
        [list(rng.randint(0, 255, size=(1, 3))[0]) for _ in range(400)] + [[0, 0, 0]])
    # ``.int().cpu()`` can return the caller's own tensor; copy before the
    # in-place 255 -> 400 (black) remap.
    plain_idx = pseudo_plain.int().cpu().numpy().copy()
    plain_idx[plain_idx == 255] = 400
    pseudo_plain = cb_colormap[plain_idx].squeeze()

    img = img - img.min()
    if img.max() > 0:  # avoid 0/0 for a constant image
        img = img / img.max()
    img = img.squeeze(0).permute(1, 2, 0).cpu()
    plotlabel = colormap[label.squeeze(0).squeeze(0).int().cpu()]
    plotpseudo = colormap[pseudo_gt.squeeze(0).squeeze(0).int().cpu()]
    panels = [img, plotlabel, plotpseudo, pseudo_plain]

    fig = plt.figure(figsize=(8, 2), dpi=150)
    fig.subplots_adjust(left=0.1,
                        bottom=0.1,
                        right=0.5,
                        top=0.5,
                        wspace=0.05,
                        hspace=0.0)
    for position, panel in enumerate(panels, start=1):
        ax = fig.add_subplot(1, 4, position)
        ax.imshow(panel)
        ax.axis('off')
    fig.savefig(save_name + '.pdf', bbox_inches='tight', pad_inches=0.0)
    plt.close(fig)

    single_dir = os.path.join(os.path.dirname(save_name), 'singleimgs/')
    os.makedirs(os.path.dirname(single_dir), exist_ok=True)
    base_name = os.path.split(save_name)[-1]
    for panel, suffix in zip(panels, ['img', 'gt', 'pseudo', 'pseudoc']):
        single_fig = plt.figure(figsize=(2, 2), dpi=300)
        ax = single_fig.add_subplot(1, 1, 1)
        ax.imshow(panel)
        ax.axis('off')
        single_fig.savefig(os.path.join(single_dir, base_name + '_' + suffix + '.png'),
                           bbox_inches='tight', pad_inches=0.0)
        plt.close(single_fig)  # the original leaked one figure per panel
def logits_to_image(logits = None,
                    img = None,
                    label = None,
                    dataset_name = None,
                    save_path = None,
                    save_imggt = False):
    """Save a colourised prediction and, optionally, the image and GT.

    Despite its name, ``logits`` is treated as a map of class indices: it is
    cast to ``uint8`` and used directly as a palette index.

    Args:
        logits: (1, H, W) class-index tensor.
        img: Image tensor, presumably (C, H, W) — TODO confirm. Used only
            when ``save_imggt`` is set; min-max normalised for display.
        label: Ground-truth tensor, presumably (1, H, W). Used only when
            ``save_imggt`` is set; values above 27 are drawn with palette
            entry 27. The caller's tensor is not modified.
        dataset_name: "cityscapes", "cocostuff" or "potsdam".
        save_path: Path prefix. Writes ``<save_path>_pred.png``, and with
            ``save_imggt`` also ``<save_path>_img.png`` and
            ``<save_path>_gt.png``.
        save_imggt: Whether to also save the image and ground truth.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    if dataset_name == "cityscapes":
        colormap = get_cs_labeldata()[-1]
    elif dataset_name == "cocostuff":
        colormap = get_coco_labeldata()[-1]
    elif dataset_name == "potsdam":
        colormap = get_pd_labeldata()[-1]
    else:
        raise ValueError(
            f"logits_to_image: unsupported dataset_name {dataset_name!r}")

    def _save_panel(panel, suffix):
        # Save ``panel`` as a borderless 2x2-inch PNG and close its figure.
        fig = plt.figure(figsize=(2, 2), dpi=400)
        ax = fig.add_subplot(1, 1, 1)
        ax.imshow(panel)
        ax.axis("off")
        fig.savefig(save_path + suffix, bbox_inches='tight', pad_inches=0.0)
        plt.close(fig)  # close only this figure, not every open one

    vis = logits.cpu().numpy().transpose(1, 2, 0).astype('uint8')
    vis = colormap[vis.flatten()].reshape(vis.shape[0], vis.shape[1], 3)
    _save_panel(vis, '_pred.png')

    if save_imggt:
        orig_h, orig_w = label.cpu().shape[-2:]
        img = img.cpu().numpy().transpose(1, 2, 0)
        img = img - img.min()
        if img.max() > 0:  # avoid 0/0 for a constant image
            img = img / img.max()
        # ``.numpy()`` shares memory with a CPU tensor; copy before the
        # in-place clamp so the caller's label is not modified.
        label = label.cpu().numpy().transpose(1, 2, 0).copy()
        label[label > 27] = 27
        colored_label = colormap[label.flatten()].reshape(orig_h, orig_w, 3)
        _save_panel(img, '_img.png')
        _save_panel(colored_label, '_gt.png')
class Vis_Demo():
    """Colourises class-index maps with the COCO-Stuff palette for the demo."""

    def __init__(self):
        super().__init__()
        # The last element returned by get_coco_labeldata() is the palette
        # that is indexed in apply_colors.
        self.colormap = get_coco_labeldata()[-1]

    def apply_colors(self, logits):
        """Map a (1, H, W) class-index tensor to an (H, W, 3) colour array.

        Values are cast to ``uint8`` before indexing ``self.colormap``.
        """
        index_map = logits.cpu().numpy().transpose(1, 2, 0).astype('uint8')
        height, width = index_map.shape[0], index_map.shape[1]
        colours = self.colormap[index_map.ravel()]
        return colours.reshape(height, width, 3)
def visualize_demo(img, pseudo, alpha = 0.5):
    """Alpha-blend an image with a randomly coloured pseudo-label map.

    Args:
        img: Normalised image tensor. It is un-normalised with
            ``transforms.UnNormalize()``, scaled by 255 and permuted
            ``(1, 2, 0)``, so presumably (C, H, W) — TODO confirm.
        pseudo: Tensor of segment ids. The value 255 is drawn black; all
            other values must be valid palette indices (0-400). The caller's
            tensor is not modified.
        alpha: Weight of the image in the blend. The pseudo-label colours
            get ``1 - alpha``.

    Returns:
        ``np.ndarray`` of dtype ``uint8``, clipped to [0, 255].
    """
    # Fixed palette: 400 pseudo-random colours plus black at index 400. A
    # private RandomState seeded with 0 reproduces the colours of seeding the
    # global generator without resetting the caller's global NumPy RNG.
    rng = np.random.RandomState(0)
    cb_colomap = np.array(
        [list(rng.randint(0, 255, size=(1, 3))[0]) for _ in range(400)] + [[0, 0, 0]])
    # ``.long()`` is a no-op on int64 tensors and ``.numpy()`` shares memory
    # with CPU tensors; copy before the in-place remap below.
    pseudo_plain = pseudo.long().cpu().numpy().copy()
    pseudo_plain[pseudo_plain == 255] = 400
    pseudo_plain = cb_colomap[pseudo_plain].squeeze()
    img = transforms.UnNormalize()(img) * 255
    img = img.permute(1, 2, 0).long().cpu().numpy()
    out = alpha * img + (1 - alpha) * pseudo_plain
    # Clip so out-of-range values saturate instead of wrapping in the cast.
    return np.array(np.clip(out, 0, 255), dtype=np.uint8)