import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.distributions import laplace
from torch.distributions import uniform
from torch.nn.modules.loss import _Loss

from contextlib import contextmanager


def replicate_input(x):
    """
    Clone the input tensor x.
    """
    return x.detach().clone()


def replicate_input_withgrad(x):
    """
    Clone the input tensor x and set requires_grad=True.
    """
    return x.detach().clone().requires_grad_()


def calc_l2distsq(x, y):
    """
    Calculate the squared L2 distance between tensors x and y, per example.
    """
    d = (x - y) ** 2
    return d.view(d.shape[0], -1).sum(dim=1)


def clamp(input, min=None, max=None):
    """
    Clamp a tensor between the given minimum and maximum values.
    """
    ndim = input.ndimension()
    if min is None:
        pass
    elif isinstance(min, (float, int)):
        input = torch.clamp(input, min=min)
    elif isinstance(min, torch.Tensor):
        if min.ndimension() == ndim - 1 and min.shape == input.shape[1:]:
            input = torch.max(input, min.view(1, *min.shape))
        else:
            assert min.shape == input.shape
            input = torch.max(input, min)
    else:
        raise ValueError("min can only be None | float | torch.Tensor")

    if max is None:
        pass
    elif isinstance(max, (float, int)):
        input = torch.clamp(input, max=max)
    elif isinstance(max, torch.Tensor):
        if max.ndimension() == ndim - 1 and max.shape == input.shape[1:]:
            input = torch.min(input, max.view(1, *max.shape))
        else:
            assert max.shape == input.shape
            input = torch.min(input, max)
    else:
        raise ValueError("max can only be None | float | torch.Tensor")
    return input


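# Usage sketch for clamp() (illustrative only; shapes and values below are
# invented): bounds may be scalars, tensors shared across the batch
# dimension, or tensors with the same shape as the input.
def _example_clamp_usage():
    x = torch.randn(4, 3, 8, 8)
    y_scalar = clamp(x, min=0., max=1.)
    lo = torch.full((3, 8, 8), -0.5)   # one bound per element, shared over the batch
    hi = torch.full((3, 8, 8), 0.5)
    y_tensor = clamp(x, min=lo, max=hi)
    return y_scalar, y_tensor

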
def _batch_multiply_tensor_by_vector(vector, batch_tensor):
    """Equivalent to the following.
    for ii in range(len(vector)):
        batch_tensor.data[ii] *= vector[ii]
    return batch_tensor
    """
    return (
        batch_tensor.transpose(0, -1) * vector).transpose(0, -1).contiguous()


def _batch_clamp_tensor_by_vector(vector, batch_tensor):
    """Equivalent to the following.
    for ii in range(len(vector)):
        batch_tensor[ii] = clamp(
            batch_tensor[ii], -vector[ii], vector[ii])
    """
    return torch.min(
        torch.max(batch_tensor.transpose(0, -1), -vector), vector
    ).transpose(0, -1).contiguous()


def batch_multiply(float_or_vector, tensor):
    """
    Multiply a batch of tensors by a float or by a per-example vector.
    """
    if isinstance(float_or_vector, torch.Tensor):
        assert len(float_or_vector) == len(tensor)
        tensor = _batch_multiply_tensor_by_vector(float_or_vector, tensor)
    elif isinstance(float_or_vector, float):
        tensor *= float_or_vector
    else:
        raise TypeError("Value has to be float or torch.Tensor")
    return tensor


def batch_clamp(float_or_vector, tensor):
    """
    Clamp a batch of tensors.
    """
    if isinstance(float_or_vector, torch.Tensor):
        assert len(float_or_vector) == len(tensor)
        tensor = _batch_clamp_tensor_by_vector(float_or_vector, tensor)
    elif isinstance(float_or_vector, float):
        tensor = clamp(tensor, -float_or_vector, float_or_vector)
    else:
        raise TypeError("Value has to be float or torch.Tensor")
    return tensor


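# Usage sketch for batch_multiply() / batch_clamp() (illustrative only; the
# tensors below are invented): scale and box-constrain a batch of
# perturbations with one epsilon per example.
def _example_batch_eps_usage():
    delta = torch.randn(4, 3, 8, 8)
    eps = torch.tensor([0.1, 0.2, 0.3, 0.4])    # per-example epsilons
    scaled = batch_multiply(eps, delta)         # delta[i] * eps[i]
    boxed = batch_clamp(eps, delta)             # clamp delta[i] into [-eps[i], eps[i]]
    return scaled, boxed

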
def _get_norm_batch(x, p):
    """
    Return the Lp norm of each example in batch x.
    """
    batch_size = x.size(0)
    return x.abs().pow(p).view(batch_size, -1).sum(dim=1).pow(1. / p)


def _thresh_by_magnitude(theta, x):
    """
    Soft-threshold x by magnitude theta.
    """
    return torch.relu(torch.abs(x) - theta) * x.sign()


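# Worked sketch for the two helpers above (illustrative only): _get_norm_batch
# returns one p-norm per example, and _thresh_by_magnitude is the usual
# soft-thresholding operator.
def _example_norm_and_thresh():
    x = torch.tensor([[-3., 0.5, 2.]])
    norms = _get_norm_batch(x, 2)       # sqrt(9 + 0.25 + 4) for the single example
    soft = _thresh_by_magnitude(1., x)  # tensor([[-2., 0., 1.]])
    return norms, soft

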
def clamp_by_pnorm(x, p, r):
    """
    Clamp tensor by its norm.
    """
    assert isinstance(p, float) or isinstance(p, int)
    norm = _get_norm_batch(x, p)
    if isinstance(r, torch.Tensor):
        assert norm.size() == r.size()
    else:
        assert isinstance(r, float)
    factor = torch.min(r / norm, torch.ones_like(norm))
    return batch_multiply(factor, x)


def is_float_or_torch_tensor(x):
    """
    Return whether input x is a float or a torch.Tensor.
    """
    return isinstance(x, torch.Tensor) or isinstance(x, float)


def normalize_by_pnorm(x, p=2, small_constant=1e-6):
    """
    Normalize gradients for gradient (not gradient sign) attacks.

    Arguments:
        x (torch.Tensor): tensor containing the gradients on the input.
        p (int): (optional) order of the norm for the normalization (1 or 2).
        small_constant (float): (optional) to avoid dividing by zero.

    Returns:
        normalized gradients.
    """
    assert isinstance(p, float) or isinstance(p, int)
    norm = _get_norm_batch(x, p)
    norm = torch.max(norm, torch.ones_like(norm) * small_constant)
    return batch_multiply(1. / norm, x)


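# Usage sketch (illustrative only; `grad`, `delta`, `step_size`, and `eps` are
# invented names): one L2 step of a gradient (not gradient-sign) attack --
# normalize the gradient per example, take a step, then project the
# accumulated perturbation back onto the eps-ball.
def _example_l2_step(grad, delta, step_size=0.1, eps=1.0):
    direction = normalize_by_pnorm(grad, p=2)
    delta = delta + batch_multiply(step_size, direction)
    return clamp_by_pnorm(delta, 2, eps)

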
def rand_init_delta(delta, x, ord, eps, clip_min, clip_max):
    """
    Randomly initialize the perturbation.
    """
    if isinstance(eps, torch.Tensor):
        assert len(eps) == len(delta)

    if ord == np.inf:
        delta.data.uniform_(-1, 1)
        delta.data = batch_multiply(eps, delta.data)
    elif ord == 2:
        delta.data.uniform_(clip_min, clip_max)
        delta.data = delta.data - x
        delta.data = clamp_by_pnorm(delta.data, ord, eps)
    elif ord == 1:
        ini = laplace.Laplace(
            loc=delta.new_tensor(0), scale=delta.new_tensor(1))
        delta.data = ini.sample(delta.data.shape)
        delta.data = normalize_by_pnorm(delta.data, p=1)
        ray = uniform.Uniform(0, eps).sample()
        delta.data *= ray
        delta.data = clamp(x.data + delta.data, clip_min, clip_max) - x.data
    else:
        error = "Only ord = inf, ord = 1 and ord = 2 have been implemented"
        raise NotImplementedError(error)

    delta.data = clamp(
        x + delta.data, min=clip_min, max=clip_max) - x
    return delta.data


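# Usage sketch for rand_init_delta() (illustrative only; shapes, epsilon, and
# clip range are invented): random L-inf initialization for inputs in [0, 1],
# as a PGD-style attack might use.
def _example_rand_init():
    x = torch.rand(4, 3, 8, 8)
    delta = torch.zeros_like(x)
    rand_init_delta(delta, x, np.inf, 8. / 255., 0., 1.)
    return x + delta

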
def CWLoss(output, target, confidence=0):
    """
    CW loss (margin loss).
    """
    num_classes = output.shape[-1]
    target = target.data
    target_onehot = torch.zeros(
        target.size() + (num_classes,), device=output.device)
    target_onehot.scatter_(1, target.unsqueeze(1), 1.)
    real = (target_onehot * output).sum(1)
    other = ((1. - target_onehot) * output - target_onehot * 10000.).max(1)[0]
    loss = - torch.clamp(real - other + confidence, min=0.)
    loss = torch.sum(loss)
    return loss


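# Usage sketch for CWLoss() (illustrative only; the logits and labels are
# random): per example the loss is -max(real - other + confidence, 0), so it
# is zero once the true-class logit trails the best other logit by at least
# `confidence`, and negative otherwise.
def _example_cw_loss():
    logits = torch.randn(4, 10)            # batch of 4 examples, 10 classes
    labels = torch.randint(0, 10, (4,))
    return CWLoss(logits, labels, confidence=0)

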
class ctx_noparamgrad(object):
    def __init__(self, module):
        self.prev_grad_state = get_param_grad_state(module)
        self.module = module
        set_param_grad_off(module)

    def __enter__(self):
        pass

    def __exit__(self, *args):
        set_param_grad_state(self.module, self.prev_grad_state)
        return False


class ctx_eval(object):
    def __init__(self, module):
        self.prev_training_state = get_module_training_state(module)
        self.module = module
        set_module_training_off(module)

    def __enter__(self):
        pass

    def __exit__(self, *args):
        set_module_training_state(self.module, self.prev_training_state)
        return False


@contextmanager
def ctx_noparamgrad_and_eval(module):
    with ctx_noparamgrad(module) as a, ctx_eval(module) as b:
        yield (a, b)


def get_module_training_state(module):
    return {mod: mod.training for mod in module.modules()}


def set_module_training_state(module, training_state):
    for mod in module.modules():
        mod.training = training_state[mod]


def set_module_training_off(module):
    for mod in module.modules():
        mod.training = False


def get_param_grad_state(module):
    return {param: param.requires_grad for param in module.parameters()}


def set_param_grad_state(module, grad_state):
    for param in module.parameters():
        param.requires_grad = grad_state[param]


def set_param_grad_off(module):
    for param in module.parameters():
        param.requires_grad = False
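

# Usage sketch for ctx_noparamgrad_and_eval() (illustrative only; `model` and
# `x` are assumed to be an nn.Module and an input batch supplied by the
# caller): run a forward pass with parameter gradients switched off and the
# module put in eval mode; both states are restored when the block exits.
def _example_ctx_usage(model, x):
    with ctx_noparamgrad_and_eval(model):
        out = model(x)
    return out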