import random

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchvision.utils as vutils


def print_args(args):
    """Print every attribute of the parsed arguments in an aligned table."""
    print("################################ args ################################")
    for k, v in args.__dict__.items():
        print("{0: <10}\t{1: <30}\t{2: <20}".format(k, str(v), str(type(v))))
    print("########################################################################")


def make_nograd_func(func):
    """Return a version of `func` that always runs under torch.no_grad()."""
    def wrapper(*f_args, **f_kwargs):
        with torch.no_grad():
            ret = func(*f_args, **f_kwargs)
        return ret

    return wrapper
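
# Illustrative usage (an assumption, not part of the original utilities): wrapping a
# validation metric so it never builds an autograd graph.
#
#     @make_nograd_func
#     def abs_depth_error(pred, gt):
#         return torch.mean(torch.abs(pred - gt))

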
def make_recursive_func(func):
    """Apply `func` element-wise to arbitrarily nested lists, tuples and dicts."""
    def wrapper(vars, device=None):
        if isinstance(vars, list):
            return [wrapper(x, device) for x in vars]
        elif isinstance(vars, tuple):
            return tuple([wrapper(x, device) for x in vars])
        elif isinstance(vars, dict):
            return {k: wrapper(v, device) for k, v in vars.items()}
        elif device is None:
            # only forward `device` when the caller supplied one, so that each
            # leaf function's own default (e.g. 'cuda' in tocuda) still applies
            return func(vars)
        else:
            return func(vars, device)

    return wrapper
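
# Sketch of the dispatch (illustrative assumption, not from the original code): a
# decorated converter only ever sees the leaves of a nested structure, so a call such as
#
#     numpy2tensor({"K": np.eye(3), "ids": [1, 2]}, device='cpu')
#
# recurses through the dict and the list and converts each leaf individually,
# forwarding `device` to the leaf function only when one was supplied.

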
@make_recursive_func
def tensor2float(vars):
    if isinstance(vars, float):
        return vars
    elif isinstance(vars, torch.Tensor):
        return vars.data.item()
    else:
        raise NotImplementedError("invalid input type {} for tensor2float".format(type(vars)))
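
# Illustrative usage (assumption, not from the original code): converting a nested
# dict of 0-dim loss/metric tensors into plain floats before logging.
#
#     scalar_outputs = {"loss": torch.tensor(0.25), "errors": [torch.tensor(1.0), torch.tensor(2.0)]}
#     tensor2float(scalar_outputs)   # -> {"loss": 0.25, "errors": [1.0, 2.0]}

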
@make_recursive_func
def tensor2numpy(vars):
    if isinstance(vars, np.ndarray):
        return vars
    elif isinstance(vars, torch.Tensor):
        return vars.detach().cpu().numpy().copy()
    else:
        raise NotImplementedError("invalid input type {} for tensor2numpy".format(type(vars)))


@make_recursive_func
def numpy2tensor(vars, device='cpu'):
    if not isinstance(vars, torch.Tensor) and vars is not None:
        return torch.tensor(vars, device=device)
    elif isinstance(vars, torch.Tensor):
        return vars
    elif vars is None:
        return vars
    else:
        raise NotImplementedError("invalid input type {} for numpy2tensor".format(type(vars)))


@make_recursive_func
def tocuda(vars, device='cuda'):
    if isinstance(vars, torch.Tensor):
        return vars.to(device)
    elif isinstance(vars, str):
        # filenames and other metadata strings are passed through unchanged
        return vars
    else:
        raise NotImplementedError("invalid input type {} for tocuda".format(type(vars)))
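
# Illustrative usage (assumption, requires a CUDA device): moving a dataloader sample
# to the GPU while leaving metadata strings untouched.
#
#     sample = {"imgs": torch.randn(1, 3, 512, 640), "filename": "scan1/00000000"}
#     sample_cuda = tocuda(sample)

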
def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()
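
# Illustrative usage (assumption, not from the original code): make every rank wait
# until rank 0 has finished writing a checkpoint before training continues.
#
#     if dist.get_rank() == 0:
#         torch.save(model.state_dict(), "checkpoint.pth")
#     synchronize()

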
def get_world_size():
    """Return the number of distributed processes, or 1 when not running distributed."""
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def reduce_scalar_outputs(scalar_outputs):
    """Average a dict of scalar outputs across processes; the averaged values are valid on rank 0."""
    world_size = get_world_size()
    if world_size < 2:
        return scalar_outputs
    with torch.no_grad():
        names = []
        scalars = []
        for k in sorted(scalar_outputs.keys()):
            names.append(k)
            if isinstance(scalar_outputs[k], torch.Tensor):
                scalars.append(scalar_outputs[k])
            else:
                scalars.append(torch.tensor(scalar_outputs[k], device='cuda'))
        scalars = torch.stack(scalars, dim=0)
        dist.reduce(scalars, dst=0)
        if dist.get_rank() == 0:
            # only the destination rank accumulates the full sum, so only rank 0 divides by world_size
            scalars /= world_size
        reduced_scalars = {k: v for k, v in zip(names, scalars)}

    return reduced_scalars
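
# Illustrative usage inside a DDP training loop (an assumption, not from the original code):
#
#     scalar_outputs = {"loss": loss, "abs_depth_error": abs_err}
#     scalar_outputs = reduce_scalar_outputs(scalar_outputs)
#     if dist.get_rank() == 0:
#         print(tensor2float(scalar_outputs))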