"""Miscellaneous experiment helpers: config loading, logging, seeding and CLI parsing."""
import os
import time
import random
import logging
from logging import Logger

import torch
import numpy as np
import yaml
from easydict import EasyDict
from tqdm.auto import tqdm


class BlackHole(object):
    """Null object that silently absorbs any use.

    Attribute assignments are discarded. Both attribute reads (for names not
    found through normal lookup) and calls return this same instance. Chained
    expressions such as ``bh.add_scalar('loss', 1.0, step)`` or ``bh.a.b()``
    therefore evaluate without error and do nothing.

    NOTE(review): presumably used as a drop-in replacement for an optional
    logger/writer object; confirm against callers.
    """

    def __setattr__(self, name, value):
        # Swallow every attribute assignment.
        pass

    def __call__(self, *args, **kwargs):
        # Calling the object is a no-op that keeps the chain alive.
        return self

    def __getattr__(self, name):
        # Only reached when normal lookup fails, i.e. for any name not defined
        # on the class; returning self lets arbitrary attribute chains work.
        return self


def load_config(path):
    """Load a YAML file into an ``EasyDict`` (dict with attribute-style access).

    Args:
        path: Path to the YAML configuration file.

    Returns:
        An ``EasyDict`` built from the parsed YAML document.
    """
    # Explicit UTF-8 (YAML's default encoding) so parsing does not depend on
    # the platform's locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        # safe_load only builds plain YAML types, never arbitrary Python objects.
        return EasyDict(yaml.safe_load(f))


# Attribute set on handlers that get_logger attaches itself, so a repeated call
# can recognise them and avoid attaching duplicates.
_HANDLER_TAG_ATTR = '_get_logger_tag'


def _has_tagged_handler(logger, tag):
    """Return True if ``logger`` already holds a handler get_logger tagged with ``tag``."""
    return any(getattr(h, _HANDLER_TAG_ATTR, None) == tag for h in logger.handlers)


def get_logger(name, log_dir=None):
    """Return the logger ``name``, configured to emit DEBUG-and-above records.

    Records are written through a ``logging.StreamHandler`` and, when
    ``log_dir`` is given, also to ``<log_dir>/log.txt``. Both handlers use the
    format ``[asctime::name::levelname] message``.

    ``logging.getLogger`` returns the same object for the same name. To avoid
    duplicated output, handlers this function attached on an earlier call are
    left in place and not added again. A file handler for a new ``log_dir`` is
    still added.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        log_dir: Optional existing directory for ``log.txt``; no file output if None.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter('[%(asctime)s::%(name)s::%(levelname)s] %(message)s')

    if not _has_tagged_handler(logger, 'stream'):
        stream_handler = logging.StreamHandler()
        stream_handler.setLevel(logging.DEBUG)
        stream_handler.setFormatter(formatter)
        setattr(stream_handler, _HANDLER_TAG_ATTR, 'stream')
        logger.addHandler(stream_handler)

    if log_dir is not None:
        log_path = os.path.join(log_dir, 'log.txt')
        # Key on the absolute path so the same file reached via different
        # relative spellings is still recognised as already attached.
        file_tag = ('file', os.path.abspath(log_path))
        if not _has_tagged_handler(logger, file_tag):
            file_handler = logging.FileHandler(log_path)
            file_handler.setLevel(logging.DEBUG)
            file_handler.setFormatter(formatter)
            setattr(file_handler, _HANDLER_TAG_ATTR, file_tag)
            logger.addHandler(file_handler)
    return logger


def get_new_log_dir(root='./logs', prefix='', tag=''):
    """Create and return a new timestamped directory under ``root``.

    The directory name is ``[prefix_]YYYY_MM_DD__HH_MM_SS[_tag]`` in local time.

    Args:
        root: Parent directory; missing intermediate directories are created.
        prefix: Optional string prepended with an underscore separator.
        tag: Optional string appended with an underscore separator.

    Returns:
        Path of the created directory.

    Raises:
        FileExistsError: If the directory already exists (e.g. two calls within
            the same second with the same prefix/tag), because ``os.makedirs``
            is called without ``exist_ok``.
    """
    fn = time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime())
    if prefix != '':
        fn = prefix + '_' + fn
    if tag != '':
        fn = fn + '_' + tag
    log_dir = os.path.join(root, fn)
    os.makedirs(log_dir)
    return log_dir


def seed_all(seed):
    """Seed the torch, NumPy and Python ``random`` generators with ``seed``."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def log_hyperparams(writer, args):
    """Record ``args`` as TensorBoard hyperparameters on ``writer``.

    Args:
        writer: TensorBoard writer exposing a ``file_writer`` attribute
            (presumably ``torch.utils.tensorboard.SummaryWriter``; confirm).
        args: Object accepted by ``vars()`` (e.g. an argparse Namespace).
            Non-string values are stored as their ``repr``.
    """
    # Local import keeps TensorBoard an optional dependency of this module.
    from torch.utils.tensorboard.summary import hparams
    vars_args = {k: v if isinstance(v, str) else repr(v) for k, v in vars(args).items()}
    # No metrics are registered; only the hyperparameter summaries are written.
    exp, ssi, sei = hparams(vars_args, {})
    writer.file_writer.add_summary(exp)
    writer.file_writer.add_summary(ssi)
    writer.file_writer.add_summary(sei)


def int_tuple(argstr):
    """Parse a comma-separated string into a tuple of ints, e.g. ``'1,2' -> (1, 2)``."""
    return tuple(map(int, argstr.split(',')))


def str_tuple(argstr):
    """Split a comma-separated string into a tuple of strings, e.g. ``'a,b' -> ('a', 'b')``.

    Whitespace around items is not stripped.
    """
    return tuple(argstr.split(','))