# ICLR_FLAG/utils/misc.py
# zaixizhang
# renew
# 10efe81
# raw
# history blame
# 1.97 kB
import os
import time
import random
import logging
import torch
import numpy as np
import yaml
from easydict import EasyDict
from logging import Logger
from tqdm.auto import tqdm
class BlackHole(object):
    """Object that absorbs every attribute access, assignment and call.

    Attribute assignments are discarded, and both calling the instance and
    reading any attribute that normal lookup cannot find return the instance
    itself, so arbitrarily long chains such as ``hole.a.b(1).c`` never fail.

    NOTE(review): presumably meant as a no-op stand-in for objects like a
    logger or summary writer -- confirm against the callers.
    """

    def __getattr__(self, name):
        # Only reached when regular lookup fails; because __setattr__ stores
        # nothing, that is every name not defined on the class itself.
        return self

    def __call__(self, *args, **kwargs):
        # Arguments are accepted and ignored so calls can be chained.
        return self

    def __setattr__(self, name, value):
        # Deliberately drop the assignment: nothing is ever stored.
        return None
def load_config(path):
    """Load a YAML configuration file and wrap it in an ``EasyDict``.

    Args:
        path: Path of the YAML file to read.

    Returns:
        ``EasyDict`` built from the document parsed by ``yaml.safe_load``.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    # Decode explicitly as UTF-8 instead of the platform's locale encoding so
    # that the same config file parses identically on every OS.
    with open(path, 'r', encoding='utf-8') as f:
        return EasyDict(yaml.safe_load(f))
def get_logger(name, log_dir=None):
    """Return a DEBUG-level logger that writes to the console and, optionally, a file.

    ``logging.getLogger`` returns the same object for the same ``name``, so
    this function may be called repeatedly. Handlers are attached only if an
    equivalent one is not already present; otherwise every call would add
    another handler and each message would be emitted multiple times.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        log_dir: Optional directory; when given, messages are also written to
            ``<log_dir>/log.txt``.

    Returns:
        The configured ``logging.Logger``.

    Raises:
        OSError: If ``log.txt`` cannot be opened inside ``log_dir``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter('[%(asctime)s::%(name)s::%(levelname)s] %(message)s')

    # Exact type check: FileHandler subclasses StreamHandler, so isinstance()
    # would mistake an existing file handler for the console handler.
    has_stream_handler = any(
        type(handler) is logging.StreamHandler for handler in logger.handlers
    )
    if not has_stream_handler:
        stream_handler = logging.StreamHandler()
        stream_handler.setLevel(logging.DEBUG)
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

    if log_dir is not None:
        # FileHandler stores its target as an absolute path in baseFilename.
        log_path = os.path.abspath(os.path.join(log_dir, 'log.txt'))
        has_file_handler = any(
            isinstance(handler, logging.FileHandler)
            and handler.baseFilename == log_path
            for handler in logger.handlers
        )
        if not has_file_handler:
            file_handler = logging.FileHandler(log_path)
            file_handler.setLevel(logging.DEBUG)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

    return logger
def get_new_log_dir(root='./logs', prefix='', tag=''):
    """Create and return a fresh, timestamp-named directory under ``root``.

    The directory name is ``[<prefix>_]<YYYY_MM_DD__HH_MM_SS>[_<tag>]`` using
    the current local time; empty ``prefix`` / ``tag`` values are left out.

    Args:
        root: Parent directory for the new log directory.
        prefix: Optional text placed before the timestamp.
        tag: Optional text placed after the timestamp.

    Returns:
        Path of the newly created directory.

    Raises:
        FileExistsError: If a directory with the same name already exists.
    """
    stamp = time.strftime('%Y_%m_%d__%H_%M_%S')
    # Keep only the non-empty pieces and glue them with underscores.
    pieces = [part for part in (prefix, stamp, tag) if part != '']
    target = os.path.join(root, '_'.join(pieces))
    os.makedirs(target)
    return target
def seed_all(seed):
    """Seed the torch, NumPy and stdlib ``random`` generators with ``seed``.

    Args:
        seed: Value handed unchanged to each library's seeding function.
    """
    # Order: torch first, then NumPy, then the standard library.
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)
def log_hyperparams(writer, args):
    """Write the attributes of ``args`` to ``writer`` as TensorBoard hparams.

    Every attribute returned by ``vars(args)`` is recorded; values that are
    not already strings are converted with ``repr``. No metrics are attached
    (an empty metric dict is passed to ``hparams``).

    Args:
        writer: Object exposing ``file_writer.add_summary``.
            NOTE(review): presumably a ``SummaryWriter`` -- confirm.
        args: Object supporting ``vars()``, e.g. an ``argparse.Namespace``.
    """
    from torch.utils.tensorboard.summary import hparams

    stringified = {}
    for key, value in vars(args).items():
        stringified[key] = value if isinstance(value, str) else repr(value)

    # hparams() yields three summaries (experiment, session start info,
    # session end info); each is written in that order.
    for summary in hparams(stringified, {}):
        writer.file_writer.add_summary(summary)
def int_tuple(argstr):
    """Parse a comma-separated string such as ``"1,2,3"`` into a tuple of ints.

    Raises:
        ValueError: If any comma-separated piece is not a valid integer.
    """
    return tuple(int(piece) for piece in argstr.split(','))
def str_tuple(argstr):
    """Split a comma-separated string into a tuple of its pieces.

    Pieces are returned as-is: surrounding whitespace is kept and empty
    pieces (including the single piece of an empty input) are preserved.
    """
    pieces = argstr.split(',')
    return (*pieces,)