import numpy as np
import torchvision.utils as vutils
import torch, random
import torch.nn.functional as F
import torch.distributed as dist


# print arguments
def print_args(args):
    print("################################ args ################################")
    for k, v in args.__dict__.items():
        print("{0: <10}\t{1: <30}\t{2: <20}".format(k, str(v), str(type(v))))
    print("########################################################################")


# torch.no_grad wrapper for functions
def make_nograd_func(func):
    def wrapper(*f_args, **f_kwargs):
        with torch.no_grad():
            ret = func(*f_args, **f_kwargs)
        return ret

    return wrapper


# convert a function into recursive style to handle nested dict/list/tuple variables
def make_recursive_func(func):
    def wrapper(vars, device=None):
        if isinstance(vars, list):
            return [wrapper(x, device) for x in vars]
        elif isinstance(vars, tuple):
            return tuple(wrapper(x, device) for x in vars)
        elif isinstance(vars, dict):
            return {k: wrapper(v, device) for k, v in vars.items()}
        else:
            return func(vars, device)

    return wrapper


@make_recursive_func
def tensor2float(vars, device=None):
    # device is unused here; the parameter only exists because
    # make_recursive_func always forwards it to the wrapped function
    if isinstance(vars, float):
        return vars
    elif isinstance(vars, torch.Tensor):
        return vars.data.item()
    else:
        raise NotImplementedError("invalid input type {} for tensor2float".format(type(vars)))


@make_recursive_func
def tensor2numpy(vars, device=None):
    # device is unused here; see tensor2float
    if isinstance(vars, np.ndarray):
        return vars
    elif isinstance(vars, torch.Tensor):
        return vars.detach().cpu().numpy().copy()
    else:
        raise NotImplementedError("invalid input type {} for tensor2numpy".format(type(vars)))


@make_recursive_func
def numpy2tensor(vars, device='cpu'):
    # make_recursive_func forwards device=None when no device is given,
    # so fall back to the default explicitly
    if device is None:
        device = 'cpu'
    if isinstance(vars, torch.Tensor):
        return vars
    elif vars is None:
        return vars
    else:
        # numpy arrays, python scalars and nested sequences are handled by
        # torch.tensor, which raises for unsupported input types
        return torch.tensor(vars, device=device)


@make_recursive_func
def tocuda(vars, device='cuda'):
    # make_recursive_func forwards device=None when no device is given,
    # so fall back to the default explicitly
    if device is None:
        device = 'cuda'
    if isinstance(vars, torch.Tensor):
        return vars.to(device)
    elif isinstance(vars, str):
        return vars
    else:
        raise NotImplementedError("invalid input type {} for tocuda".format(type(vars)))


def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()


def get_world_size():
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def reduce_scalar_outputs(scalar_outputs):
    world_size = get_world_size()
    if world_size < 2:
        return scalar_outputs
    with torch.no_grad():
        names = []
        scalars = []
        for k in sorted(scalar_outputs.keys()):
            names.append(k)
            if isinstance(scalar_outputs[k], torch.Tensor):
                scalars.append(scalar_outputs[k])
            else:
                scalars.append(torch.tensor(scalar_outputs[k], device='cuda'))
        scalars = torch.stack(scalars, dim=0)
        dist.reduce(scalars, dst=0)
        if dist.get_rank() == 0:
            # only the main process receives the accumulated values,
            # so only divide by world_size on rank 0
            scalars /= world_size
        reduced_scalars = {k: v for k, v in zip(names, scalars)}

    return reduced_scalars
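

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the training pipeline): shows
# how the recursive converters and reduce_scalar_outputs are typically
# combined in a training loop. The nested `sample` dict and the scalar names
# ("loss", "abs_depth_error") are hypothetical placeholders, not values the
# rest of the codebase depends on.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    sample = {
        "imgs": np.random.rand(1, 3, 64, 64).astype(np.float32),
        "meta": {"depth_min": 425.0, "depth_max": 935.0},
    }

    # numpy -> torch, then move the tensors to the GPU if one is available
    sample = numpy2tensor(sample)
    if torch.cuda.is_available():
        sample = tocuda(sample)

    # pretend these came out of a loss / metric computation
    scalar_outputs = {"loss": torch.tensor(0.5), "abs_depth_error": torch.tensor(4.2)}

    # with a single process this is a no-op; under torch.distributed it
    # accumulates the scalars onto rank 0 and averages them there
    scalar_outputs = reduce_scalar_outputs(scalar_outputs)
    print(tensor2float(scalar_outputs))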