import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable from torch.distributions import laplace from torch.distributions import uniform from torch.nn.modules.loss import _Loss from contextlib import contextmanager def replicate_input(x): """ Clone the input tensor x. """ return x.detach().clone() def replicate_input_withgrad(x): """ Clone the input tensor x and set requires_grad=True. """ return x.detach().clone().requires_grad_() def calc_l2distsq(x, y): """ Calculate L2 distance between tensors x and y. """ d = (x - y)**2 return d.view(d.shape[0], -1).sum(dim=1) def clamp(input, min=None, max=None): """ Clamp a tensor by its minimun and maximun values. """ ndim = input.ndimension() if min is None: pass elif isinstance(min, (float, int)): input = torch.clamp(input, min=min) elif isinstance(min, torch.Tensor): if min.ndimension() == ndim - 1 and min.shape == input.shape[1:]: input = torch.max(input, min.view(1, *min.shape)) else: assert min.shape == input.shape input = torch.max(input, min) else: raise ValueError("min can only be None | float | torch.Tensor") if max is None: pass elif isinstance(max, (float, int)): input = torch.clamp(input, max=max) elif isinstance(max, torch.Tensor): if max.ndimension() == ndim - 1 and max.shape == input.shape[1:]: input = torch.min(input, max.view(1, *max.shape)) else: assert max.shape == input.shape input = torch.min(input, max) else: raise ValueError("max can only be None | float | torch.Tensor") return input def _batch_multiply_tensor_by_vector(vector, batch_tensor): """Equivalent to the following. for ii in range(len(vector)): batch_tensor.data[ii] *= vector[ii] return batch_tensor """ return ( batch_tensor.transpose(0, -1) * vector).transpose(0, -1).contiguous() def _batch_clamp_tensor_by_vector(vector, batch_tensor): """Equivalent to the following. for ii in range(len(vector)): batch_tensor[ii] = clamp( batch_tensor[ii], -vector[ii], vector[ii]) """ return torch.min( torch.max(batch_tensor.transpose(0, -1), -vector), vector ).transpose(0, -1).contiguous() def batch_multiply(float_or_vector, tensor): """ Multpliy a batch of tensors with a float or vector. """ if isinstance(float_or_vector, torch.Tensor): assert len(float_or_vector) == len(tensor) tensor = _batch_multiply_tensor_by_vector(float_or_vector, tensor) elif isinstance(float_or_vector, float): tensor *= float_or_vector else: raise TypeError("Value has to be float or torch.Tensor") return tensor def batch_clamp(float_or_vector, tensor): """ Clamp a batch of tensors. """ if isinstance(float_or_vector, torch.Tensor): assert len(float_or_vector) == len(tensor) tensor = _batch_clamp_tensor_by_vector(float_or_vector, tensor) return tensor elif isinstance(float_or_vector, float): tensor = clamp(tensor, -float_or_vector, float_or_vector) else: raise TypeError("Value has to be float or torch.Tensor") return tensor def _get_norm_batch(x, p): """ Returns the Lp norm of batch x. """ batch_size = x.size(0) return x.abs().pow(p).view(batch_size, -1).sum(dim=1).pow(1. / p) def _thresh_by_magnitude(theta, x): """ Threshold by magnitude. """ return torch.relu(torch.abs(x) - theta) * x.sign() def clamp_by_pnorm(x, p, r): """ Clamp tensor by its norm. """ assert isinstance(p, float) or isinstance(p, int) norm = _get_norm_batch(x, p) if isinstance(r, torch.Tensor): assert norm.size() == r.size() else: assert isinstance(r, float) factor = torch.min(r / norm, torch.ones_like(norm)) return batch_multiply(factor, x) def is_float_or_torch_tensor(x): """ Return whether input x is a float or a torch.Tensor. """ return isinstance(x, torch.Tensor) or isinstance(x, float) def normalize_by_pnorm(x, p=2, small_constant=1e-6): """ Normalize gradients for gradient (not gradient sign) attacks. Arguments: x (torch.Tensor): tensor containing the gradients on the input. p (int): (optional) order of the norm for the normalization (1 or 2). small_constant (float): (optional) to avoid dividing by zero. Returns: normalized gradients. """ assert isinstance(p, float) or isinstance(p, int) norm = _get_norm_batch(x, p) norm = torch.max(norm, torch.ones_like(norm) * small_constant) return batch_multiply(1. / norm, x) def rand_init_delta(delta, x, ord, eps, clip_min, clip_max): """ Randomly initialize the perturbation. """ if isinstance(eps, torch.Tensor): assert len(eps) == len(delta) if ord == np.inf: delta.data.uniform_(-1, 1) delta.data = batch_multiply(eps, delta.data) elif ord == 2: delta.data.uniform_(clip_min, clip_max) delta.data = delta.data - x delta.data = clamp_by_pnorm(delta.data, ord, eps) elif ord == 1: ini = laplace.Laplace( loc=delta.new_tensor(0), scale=delta.new_tensor(1)) delta.data = ini.sample(delta.data.shape) delta.data = normalize_by_pnorm(delta.data, p=1) ray = uniform.Uniform(0, eps).sample() delta.data *= ray delta.data = clamp(x.data + delta.data, clip_min, clip_max) - x.data else: error = "Only ord = inf, ord = 1 and ord = 2 have been implemented" raise NotImplementedError(error) delta.data = clamp( x + delta.data, min=clip_min, max=clip_max) - x return delta.data def CWLoss(output, target, confidence=0): """ CW loss (Marging loss). """ num_classes = output.shape[-1] target = target.data target_onehot = torch.zeros(target.size() + (num_classes,)) target_onehot = target_onehot.cuda() target_onehot.scatter_(1, target.unsqueeze(1), 1.) target_var = Variable(target_onehot, requires_grad=False) real = (target_var * output).sum(1) other = ((1. - target_var) * output - target_var * 10000.).max(1)[0] loss = - torch.clamp(real - other + confidence, min=0.) loss = torch.sum(loss) return loss class ctx_noparamgrad(object): def __init__(self, module): self.prev_grad_state = get_param_grad_state(module) self.module = module set_param_grad_off(module) def __enter__(self): pass def __exit__(self, *args): set_param_grad_state(self.module, self.prev_grad_state) return False class ctx_eval(object): def __init__(self, module): self.prev_training_state = get_module_training_state(module) self.module = module set_module_training_off(module) def __enter__(self): pass def __exit__(self, *args): set_module_training_state(self.module, self.prev_training_state) return False @contextmanager def ctx_noparamgrad_and_eval(module): with ctx_noparamgrad(module) as a, ctx_eval(module) as b: yield (a, b) def get_module_training_state(module): return {mod: mod.training for mod in module.modules()} def set_module_training_state(module, training_state): for mod in module.modules(): mod.training = training_state[mod] def set_module_training_off(module): for mod in module.modules(): mod.training = False def get_param_grad_state(module): return {param: param.requires_grad for param in module.parameters()} def set_param_grad_state(module, grad_state): for param in module.parameters(): param.requires_grad = grad_state[param] def set_param_grad_off(module): for param in module.parameters(): param.requires_grad = False