import random

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchvision.utils as vutils

# print arguments
def print_args(args):
    print("################################ args ################################")
    for k, v in args.__dict__.items():
        print("{0: <10}\t{1: <30}\t{2: <20}".format(k, str(v), str(type(v))))
    print("########################################################################")

# torch.no_grad wrapper for functions
def make_nograd_func(func):
    def wrapper(*f_args, **f_kwargs):
        with torch.no_grad():
            ret = func(*f_args, **f_kwargs)
        return ret

    return wrapper

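# Example (sketch): decorating an evaluation step so autograd does not track it;
# `model` and `sample` are hypothetical names, not defined in this file.
#   @make_nograd_func
#   def eval_step(model, sample):
#       return model(sample)
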
# convert a function into recursive style to handle nested dict/list/tuple variables
def make_recursive_func(func):
    def wrapper(vars, device=None):
        if isinstance(vars, list):
            return [wrapper(x, device) for x in vars]
        elif isinstance(vars, tuple):
            return tuple(wrapper(x, device) for x in vars)
        elif isinstance(vars, dict):
            return {k: wrapper(v, device) for k, v in vars.items()}
        elif device is None:
            # only forward `device` when it was given explicitly, so wrapped
            # single-argument functions (e.g. tensor2float) still work and
            # functions with their own device default (e.g. tocuda) keep it
            return func(vars)
        else:
            return func(vars, device)

    return wrapper

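# Example (sketch): a decorated function is applied to every leaf of a nested
# structure; the values below are illustrative only.
#   @make_recursive_func
#   def double(x):
#       return x * 2
#   double({"a": [1, 2], "b": (3,)})   # -> {"a": [2, 4], "b": (6,)}
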
@make_recursive_func
def tensor2float(vars):
    if isinstance(vars, float):
        return vars
    elif isinstance(vars, torch.Tensor):
        return vars.data.item()
    else:
        raise NotImplementedError("invalid input type {} for tensor2float".format(type(vars)))

@make_recursive_func
def tensor2numpy(vars):
    if isinstance(vars, np.ndarray):
        return vars
    elif isinstance(vars, torch.Tensor):
        return vars.detach().cpu().numpy().copy()
    else:
        raise NotImplementedError("invalid input type {} for tensor2numpy".format(type(vars)))

@make_recursive_func
def numpy2tensor(vars, device='cpu'):
    if vars is None:
        return vars
    elif isinstance(vars, torch.Tensor):
        return vars
    else:
        # numpy arrays and python scalars; torch.tensor itself rejects unsupported types
        return torch.tensor(vars, device=device)

@make_recursive_func
def tocuda(vars, device='cuda'):
    if isinstance(vars, torch.Tensor):
        return vars.to(device)
    elif isinstance(vars, str):
        # non-tensor metadata (e.g. filenames) is passed through unchanged
        return vars
    else:
        raise NotImplementedError("invalid input type {} for tocuda".format(type(vars)))

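# Example (sketch): moving a nested sample to the GPU and converting results back;
# the `sample` contents are illustrative and tocuda() requires a CUDA device.
#   sample = {"img": torch.rand(1, 3, 4, 4), "name": "0001.png"}
#   cuda_sample = tocuda(sample)                          # tensors -> cuda, str untouched
#   scalars = tensor2float({"loss": torch.tensor(0.5)})   # -> {"loss": 0.5}
#   arrays = tensor2numpy(cuda_sample["img"])             # back to a numpy array
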
def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()

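# Example (sketch): making non-master ranks wait until rank 0 has finished a side
# effect such as writing a checkpoint; `state_dict` and the path are illustrative.
#   if dist.get_rank() == 0:
#       torch.save(state_dict, "checkpoint.pth")
#   synchronize()   # all ranks resume only after rank 0 is done
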
def get_world_size():
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()

# average a dict of scalar values across processes; only rank 0 receives the averaged result
def reduce_scalar_outputs(scalar_outputs):
    world_size = get_world_size()
    if world_size < 2:
        return scalar_outputs
    with torch.no_grad():
        names = []
        scalars = []
        for k in sorted(scalar_outputs.keys()):
            names.append(k)
            if isinstance(scalar_outputs[k], torch.Tensor):
                scalars.append(scalar_outputs[k])
            else:
                scalars.append(torch.tensor(scalar_outputs[k], device='cuda'))
        scalars = torch.stack(scalars, dim=0)
        dist.reduce(scalars, dst=0)
        if dist.get_rank() == 0:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            scalars /= world_size
        reduced_scalars = {k: v for k, v in zip(names, scalars)}
    return reduced_scalars
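
# Example (sketch): averaging per-iteration metrics across ranks before logging on
# the master process; `loss` and "abs_err" are hypothetical metric names.
#   scalar_outputs = {"loss": loss.detach(), "abs_err": torch.tensor(1.2, device='cuda')}
#   scalar_outputs = reduce_scalar_outputs(scalar_outputs)
#   if dist.get_rank() == 0:
#       print(tensor2float(scalar_outputs))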