|
|
|
|
|
import numpy as np
|
|
import os
|
|
import ntpath
|
|
import time
|
|
import glob
|
|
from scipy.misc import imresize
|
|
import torchvision.utils as vutils
|
|
from operator import itemgetter
|
|
from tensorboardX import SummaryWriter
|
|
|
|
|
|
class Visualizer():
    """Helper that forwards training images, text and scalars to tensorboardX.

    A single ``SummaryWriter`` is opened on ``checkpoints_dir`` and every
    ``*_summary`` method writes through it. Call :meth:`close` when done.
    """

    def __init__(self, checkpoints_dir, name):
        """Open a fresh writer on ``checkpoints_dir``.

        Args:
            checkpoints_dir: Directory the ``SummaryWriter`` logs into. Any
                existing file in it whose name starts with ``events`` is
                deleted first.
            name: Stored on the instance as ``self.name``; not used by any
                method of this class.
        """
        # These attributes are not read by any method below; they are kept
        # because code outside this class may rely on them.
        self.win_size = 256
        self.name = name
        self.saved = False
        self.checkpoints_dir = checkpoints_dir
        self.ncols = 4

        # Drop event files already present so the new writer starts clean.
        for filename in glob.glob(os.path.join(self.checkpoints_dir, "events*")):
            os.remove(filename)

        self.writer = SummaryWriter(checkpoints_dir)

    def reset(self):
        """Clear the ``saved`` flag."""
        self.saved = False

    def image_summary(self, mode, epoch, images):
        """Log a batch of images as one normalized grid under ``<mode>/Image``.

        Args:
            mode: Tag prefix (e.g. a split name).
            epoch: Global step passed to the writer.
            images: Batch handed to ``torchvision.utils.make_grid``; each
                image is normalized independently (``scale_each=True``).
        """
        images = vutils.make_grid(images, normalize=True, scale_each=True)
        self.writer.add_image('{}/Image'.format(mode), images, epoch)

    def text_summary(self, mode, epoch, type, text, vocabulary, gt=True, max_length=20):
        """Log one text entry per element of ``text``, decoded via ``vocabulary``.

        Each entry is written under the tag
        ``<mode>/<type>_<i>_<gt|prediction>``. The entry is the decoded words
        joined by ``', '`` with ``'<pad>'`` removed, or a "too big" notice when
        more than ``max_length`` words were decoded (padding included in the
        count).

        Args:
            mode: Tag prefix.
            epoch: Global step passed to the writer.
            type: Label inserted in the tag (name kept for caller
                compatibility even though it shadows the builtin).
            text: Iterable of samples. With ``gt=True`` each sample is a
                sequence (or tensor) of vocabulary indices. With ``gt=False``
                each sample must support ``.nonzero()``.
                NOTE(review): assumes a 1-D torch tensor in that case - confirm.
            vocabulary: Object indexable by the decoded indices.
            gt: ``True`` for ground-truth index lists, ``False`` for
                predictions whose nonzero positions are decoded.
            max_length: Maximum number of decoded words that is still logged
                verbatim.
        """
        suffix = 'gt' if gt else 'prediction'

        for i, el in enumerate(text):
            if not gt:
                # reshape(-1) instead of squeeze(): a single nonzero position
                # must stay a 1-element sequence, not collapse to a scalar.
                # NOTE(review): the +1 presumably skips a reserved slot 0 in
                # the vocabulary - confirm against the caller.
                idx = el.nonzero().reshape(-1) + 1
            else:
                idx = el

            # Convert tensor/array indices to plain Python values so they work
            # as list indices and as dict keys alike.
            if hasattr(idx, 'tolist'):
                idx = idx.tolist()
                if not isinstance(idx, list):
                    # A 0-d tensor/array yields a bare scalar.
                    idx = [idx]

            # Always a list: itemgetter(*idx) returned a bare item for a
            # single index (so the word was later split into characters) and
            # raised TypeError for an empty index sequence.
            words_list = [vocabulary[j] for j in idx]

            tag = '{}/{}_{}_{}'.format(mode, type, i, suffix)
            if len(words_list) <= max_length:
                message = ', '.join(w for w in words_list if w != '<pad>')
            else:
                message = 'Number of sampled ingredients is too big: {}'.format(len(words_list))
            self.writer.add_text(tag, message, epoch)

    def scalar_summary(self, mode, epoch, **args):
        """Log every keyword argument as a scalar under ``<mode>/<key>``.

        After logging, the writer's scalars are exported to
        ``<checkpoints_dir>/tensorboard_all_scalars.json``.

        Args:
            mode: Tag prefix.
            epoch: Global step passed to the writer.
            **args: Scalar name -> value pairs.
        """
        for k, v in args.items():
            self.writer.add_scalar('{}/{}'.format(mode, k), v, epoch)

        self.writer.export_scalars_to_json("{}/tensorboard_all_scalars.json".format(self.checkpoints_dir))

    def histo_summary(self, model, step):
        """Log a histogram for every named parameter of ``model``.

        Args:
            model: Object exposing ``named_parameters()`` yielding
                ``(name, param)`` pairs.
            step: Global step passed to the writer.
        """
        for name, param in model.named_parameters():
            self.writer.add_histogram(name, param, step)

    def close(self):
        """Close the underlying writer."""
        self.writer.close()
|
|
|