Delete folder smi08 with huggingface_hub
Browse files- smi08/ProArd/attacks/CW.py +0 -0
- smi08/ProArd/attacks/LBFGS.py +0 -0
- smi08/ProArd/attacks/__init__.py +0 -71
- smi08/ProArd/attacks/apgd_ce.py +0 -102
- smi08/ProArd/attacks/autoattack.py +0 -92
- smi08/ProArd/attacks/base.py +0 -63
- smi08/ProArd/attacks/deepfool.py +0 -253
- smi08/ProArd/attacks/fgsm.py +0 -171
- smi08/ProArd/attacks/local_lip.py +0 -29
- smi08/ProArd/attacks/pgd.py +0 -248
- smi08/ProArd/attacks/squred.py +0 -86
- smi08/ProArd/attacks/utils.py +0 -279
smi08/ProArd/attacks/CW.py
DELETED
|
File without changes
|
smi08/ProArd/attacks/LBFGS.py
DELETED
|
File without changes
|
smi08/ProArd/attacks/__init__.py
DELETED
|
@@ -1,71 +0,0 @@
|
|
| 1 |
-
from .base import Attack
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
from .fgsm import FGMAttack
|
| 5 |
-
from .fgsm import FGSMAttack
|
| 6 |
-
from .fgsm import L2FastGradientAttack
|
| 7 |
-
from .fgsm import LinfFastGradientAttack
|
| 8 |
-
|
| 9 |
-
from .pgd import PGDAttack
|
| 10 |
-
from .pgd import L2PGDAttack
|
| 11 |
-
from .pgd import LinfPGDAttack
|
| 12 |
-
|
| 13 |
-
from .deepfool import DeepFoolAttack
|
| 14 |
-
from .deepfool import LinfDeepFoolAttack
|
| 15 |
-
from .deepfool import L2DeepFoolAttack
|
| 16 |
-
|
| 17 |
-
from .utils import CWLoss
|
| 18 |
-
from .autoattack import AutoAttacks
|
| 19 |
-
from .apgd_ce import Autoattack_apgd_ce
|
| 20 |
-
from .squred import Squre_Attack
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
ATTACKS = ['fgsm', 'linf-pgd', 'fgm', 'l2-pgd', 'linf-df', 'l2-df', 'linf-apgd', 'l2-apgd','squar_attack','autoattack','apgd_ce']
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def create_attack(model, criterion, attack_type, attack_eps, attack_iter, attack_step, rand_init_type='uniform',
|
| 27 |
-
clip_min=0., clip_max=1.):
|
| 28 |
-
"""
|
| 29 |
-
Initialize adversary.
|
| 30 |
-
Arguments:
|
| 31 |
-
model (nn.Module): forward pass function.
|
| 32 |
-
criterion (nn.Module): loss function.
|
| 33 |
-
attack_type (str): name of the attack.
|
| 34 |
-
attack_eps (float): attack radius.
|
| 35 |
-
attack_iter (int): number of attack iterations.
|
| 36 |
-
attack_step (float): step size for the attack.
|
| 37 |
-
rand_init_type (str): random initialization type for PGD (default: uniform).
|
| 38 |
-
clip_min (float): mininum value per input dimension.
|
| 39 |
-
clip_max (float): maximum value per input dimension.
|
| 40 |
-
Returns:
|
| 41 |
-
Attack
|
| 42 |
-
"""
|
| 43 |
-
|
| 44 |
-
if attack_type == 'fgsm':
|
| 45 |
-
attack = FGSMAttack(model, criterion, eps=attack_eps, clip_min=clip_min, clip_max=clip_max)
|
| 46 |
-
elif attack_type == 'fgm':
|
| 47 |
-
attack = FGMAttack(model, criterion, eps=attack_eps, clip_min=clip_min, clip_max=clip_max)
|
| 48 |
-
elif attack_type == 'linf-pgd':
|
| 49 |
-
attack = LinfPGDAttack(model, criterion, eps=attack_eps, nb_iter=attack_iter, eps_iter=attack_step,
|
| 50 |
-
rand_init_type=rand_init_type, clip_min=clip_min, clip_max=clip_max)
|
| 51 |
-
elif attack_type == 'l2-pgd':
|
| 52 |
-
attack = L2PGDAttack(model, criterion, eps=attack_eps, nb_iter=attack_iter, eps_iter=attack_step,
|
| 53 |
-
rand_init_type=rand_init_type, clip_min=clip_min, clip_max=clip_max)
|
| 54 |
-
elif attack_type == 'linf-df':
|
| 55 |
-
attack = LinfDeepFoolAttack(model, overshoot=0.02, nb_iter=attack_iter, search_iter=0, clip_min=clip_min,
|
| 56 |
-
clip_max=clip_max)
|
| 57 |
-
elif attack_type == 'l2-df':
|
| 58 |
-
attack = L2DeepFoolAttack(model, overshoot=0.02, nb_iter=attack_iter, search_iter=0, clip_min=clip_min,
|
| 59 |
-
clip_max=clip_max)
|
| 60 |
-
elif attack_type == 'squar_attack':
|
| 61 |
-
attack = Squre_Attack(model, criterion, nb_iter=attack_iter, eps_iter=attack_step,
|
| 62 |
-
rand_init_type=rand_init_type, clip_min=clip_min, clip_max=clip_max)
|
| 63 |
-
elif attack_type == "autoattack":
|
| 64 |
-
attack = AutoAttacks(model, nb_iter=attack_iter, eps=attack_eps, eps_iter=attack_step,
|
| 65 |
-
rand_init_type=rand_init_type, clip_min=clip_min, clip_max=clip_max)
|
| 66 |
-
elif attack_type == "apgd_ce":
|
| 67 |
-
attack = Autoattack_apgd_ce (model, nb_iter=attack_iter, eps_iter=attack_step,
|
| 68 |
-
rand_init_type=rand_init_type, clip_min=clip_min, clip_max=clip_max)
|
| 69 |
-
else:
|
| 70 |
-
raise NotImplementedError('{} is not yet implemented!'.format(attack_type))
|
| 71 |
-
return attack
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
smi08/ProArd/attacks/apgd_ce.py
DELETED
|
@@ -1,102 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch.nn as nn
|
| 3 |
-
import torch.nn.functional as F
|
| 4 |
-
from torch.autograd import Variable
|
| 5 |
-
from adv_lib.attacks import carlini_wagner_linf
|
| 6 |
-
import torch.optim as optim
|
| 7 |
-
from autoattack import AutoAttack
|
| 8 |
-
import numpy as np
|
| 9 |
-
import logging
|
| 10 |
-
from .base import Attack,LabelMixin
|
| 11 |
-
from typing import List, Union,Dict
|
| 12 |
-
|
| 13 |
-
import torch
|
| 14 |
-
import torch.nn as nn
|
| 15 |
-
from typing import Dict
|
| 16 |
-
from .utils import ctx_noparamgrad_and_eval
|
| 17 |
-
from utils.distributed import DistributedMetric
|
| 18 |
-
from tqdm import tqdm
|
| 19 |
-
from torchpack import distributed as dist
|
| 20 |
-
from utils import accuracy
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
class Autoattack_apgd_ce(Attack, LabelMixin):
|
| 25 |
-
|
| 26 |
-
def __init__(
|
| 27 |
-
self, predict, loss_fn=None, eps=0.3, nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0., clip_max=1.,
|
| 28 |
-
ord=np.inf, targeted=False, rand_init_type='uniform'):
|
| 29 |
-
super(Autoattack_apgd_ce, self).__init__(predict, loss_fn, clip_min, clip_max)
|
| 30 |
-
self.eps = eps
|
| 31 |
-
self.nb_iter = nb_iter
|
| 32 |
-
self.eps_iter = eps_iter
|
| 33 |
-
self.rand_init = rand_init
|
| 34 |
-
self.rand_init_type = rand_init_type
|
| 35 |
-
self.ord = ord
|
| 36 |
-
self.targeted = targeted
|
| 37 |
-
if self.loss_fn is None:
|
| 38 |
-
self.loss_fn = nn.CrossEntropyLoss(reduction="sum")
|
| 39 |
-
self.adversary = AutoAttack(predict, norm='Linf', eps=self.eps, version='standard')
|
| 40 |
-
def perturb(self, x, y=None):
|
| 41 |
-
self.adversary.attacks_to_run=['apgd-ce']
|
| 42 |
-
adversarial_examples = self.adversary.run_standard_evaluation(x, y, bs=100)
|
| 43 |
-
return adversarial_examples,adversarial_examples
|
| 44 |
-
|
| 45 |
-
def eval_AutoAttack_apgd_ce(self,data_loader_dict: Dict)-> Dict:
|
| 46 |
-
|
| 47 |
-
test_criterion = nn.CrossEntropyLoss().cuda()
|
| 48 |
-
val_loss = DistributedMetric()
|
| 49 |
-
val_top1 = DistributedMetric()
|
| 50 |
-
val_top5 = DistributedMetric()
|
| 51 |
-
val_advloss = DistributedMetric()
|
| 52 |
-
val_advtop1 = DistributedMetric()
|
| 53 |
-
val_advtop5 = DistributedMetric()
|
| 54 |
-
self.predict.eval()
|
| 55 |
-
with tqdm(
|
| 56 |
-
total=len(data_loader_dict["val"]),
|
| 57 |
-
desc="Eval",
|
| 58 |
-
disable=not dist.is_master(),
|
| 59 |
-
) as t:
|
| 60 |
-
for images, labels in data_loader_dict["val"]:
|
| 61 |
-
images, labels = images.cuda(), labels.cuda()
|
| 62 |
-
# compute output
|
| 63 |
-
output = self.predict(images)
|
| 64 |
-
loss = test_criterion(output, labels)
|
| 65 |
-
val_loss.update(loss, images.shape[0])
|
| 66 |
-
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
|
| 67 |
-
val_top5.update(acc5[0], images.shape[0])
|
| 68 |
-
val_top1.update(acc1[0], images.shape[0])
|
| 69 |
-
with ctx_noparamgrad_and_eval(self.predict):
|
| 70 |
-
images_adv,_ = self.perturb(images, labels)
|
| 71 |
-
output_adv = self.predict(images_adv)
|
| 72 |
-
loss_adv = test_criterion(output_adv,labels)
|
| 73 |
-
val_advloss.update(loss_adv, images.shape[0])
|
| 74 |
-
acc1_adv, acc5_adv = accuracy(output_adv, labels, topk=(1, 5))
|
| 75 |
-
val_advtop1.update(acc1_adv[0], images.shape[0])
|
| 76 |
-
val_advtop5.update(acc5_adv[0], images.shape[0])
|
| 77 |
-
t.set_postfix(
|
| 78 |
-
{
|
| 79 |
-
"loss": val_loss.avg.item(),
|
| 80 |
-
"top1": val_top1.avg.item(),
|
| 81 |
-
"top5": val_top5.avg.item(),
|
| 82 |
-
"adv_loss": val_advloss.avg.item(),
|
| 83 |
-
"adv_top1": val_advtop1.avg.item(),
|
| 84 |
-
"adv_top5": val_advtop5.avg.item(),
|
| 85 |
-
"#samples": val_top1.count.item(),
|
| 86 |
-
"batch_size": images.shape[0],
|
| 87 |
-
"img_size": images.shape[2],
|
| 88 |
-
}
|
| 89 |
-
)
|
| 90 |
-
t.update()
|
| 91 |
-
|
| 92 |
-
val_results = {
|
| 93 |
-
"val_top1": val_top1.avg.item(),
|
| 94 |
-
"val_top5": val_top5.avg.item(),
|
| 95 |
-
"val_loss": val_loss.avg.item(),
|
| 96 |
-
"val_advtop1": val_advtop1.avg.item(),
|
| 97 |
-
"val_advtop5": val_advtop5.avg.item(),
|
| 98 |
-
"val_advloss": val_advloss.avg.item(),
|
| 99 |
-
}
|
| 100 |
-
return val_results
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
smi08/ProArd/attacks/autoattack.py
DELETED
|
@@ -1,92 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch.nn as nn
|
| 3 |
-
import torch.nn.functional as F
|
| 4 |
-
from autoattack import AutoAttack
|
| 5 |
-
import numpy as np
|
| 6 |
-
import logging
|
| 7 |
-
from .base import Attack,LabelMixin
|
| 8 |
-
import torch
|
| 9 |
-
import torch.nn as nn
|
| 10 |
-
from typing import Dict
|
| 11 |
-
from .utils import ctx_noparamgrad_and_eval
|
| 12 |
-
from utils.distributed import DistributedMetric
|
| 13 |
-
from tqdm import tqdm
|
| 14 |
-
from torchpack import distributed as dist
|
| 15 |
-
from utils import accuracy
|
| 16 |
-
class AutoAttacks(Attack, LabelMixin):
|
| 17 |
-
|
| 18 |
-
def __init__(
|
| 19 |
-
self, predict, loss_fn=None, eps=0.3, nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0., clip_max=1.,
|
| 20 |
-
ord=np.inf, targeted=False, rand_init_type='uniform'):
|
| 21 |
-
super(AutoAttacks, self).__init__(predict, loss_fn, clip_min, clip_max)
|
| 22 |
-
self.eps = eps
|
| 23 |
-
self.nb_iter = nb_iter
|
| 24 |
-
self.eps_iter = eps_iter
|
| 25 |
-
self.rand_init = rand_init
|
| 26 |
-
self.rand_init_type = rand_init_type
|
| 27 |
-
self.ord = ord
|
| 28 |
-
self.targeted = targeted
|
| 29 |
-
if self.loss_fn is None:
|
| 30 |
-
self.loss_fn = nn.CrossEntropyLoss(reduction="sum")
|
| 31 |
-
self.adversary = AutoAttack(predict, norm='Linf', eps=self.eps, version='standard')
|
| 32 |
-
def perturb(self, x, y=None):
|
| 33 |
-
adversarial_examples = self.adversary.run_standard_evaluation(x, y, bs=100)
|
| 34 |
-
return adversarial_examples,adversarial_examples
|
| 35 |
-
def eval_AutoAttack(self,data_loader_dict: Dict)-> Dict:
|
| 36 |
-
|
| 37 |
-
test_criterion = nn.CrossEntropyLoss().cuda()
|
| 38 |
-
val_loss = DistributedMetric()
|
| 39 |
-
val_top1 = DistributedMetric()
|
| 40 |
-
val_top5 = DistributedMetric()
|
| 41 |
-
val_advloss = DistributedMetric()
|
| 42 |
-
val_advtop1 = DistributedMetric()
|
| 43 |
-
val_advtop5 = DistributedMetric()
|
| 44 |
-
self.predict.eval()
|
| 45 |
-
with tqdm(
|
| 46 |
-
total=len(data_loader_dict["val"]),
|
| 47 |
-
desc="Eval",
|
| 48 |
-
disable=not dist.is_master(),
|
| 49 |
-
) as t:
|
| 50 |
-
for images, labels in data_loader_dict["val"]:
|
| 51 |
-
images, labels = images.cuda(), labels.cuda()
|
| 52 |
-
# compute output
|
| 53 |
-
output = self.predict(images)
|
| 54 |
-
loss = test_criterion(output, labels)
|
| 55 |
-
val_loss.update(loss, images.shape[0])
|
| 56 |
-
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
|
| 57 |
-
val_top5.update(acc5[0], images.shape[0])
|
| 58 |
-
val_top1.update(acc1[0], images.shape[0])
|
| 59 |
-
with ctx_noparamgrad_and_eval(self.predict):
|
| 60 |
-
images_adv,_ = self.perturb(images, labels)
|
| 61 |
-
output_adv = self.predict(images_adv)
|
| 62 |
-
loss_adv = test_criterion(output_adv,labels)
|
| 63 |
-
val_advloss.update(loss_adv, images.shape[0])
|
| 64 |
-
acc1_adv, acc5_adv = accuracy(output_adv, labels, topk=(1, 5))
|
| 65 |
-
val_advtop1.update(acc1_adv[0], images.shape[0])
|
| 66 |
-
val_advtop5.update(acc5_adv[0], images.shape[0])
|
| 67 |
-
t.set_postfix(
|
| 68 |
-
{
|
| 69 |
-
"loss": val_loss.avg.item(),
|
| 70 |
-
"top1": val_top1.avg.item(),
|
| 71 |
-
"top5": val_top5.avg.item(),
|
| 72 |
-
"adv_loss": val_advloss.avg.item(),
|
| 73 |
-
"adv_top1": val_advtop1.avg.item(),
|
| 74 |
-
"adv_top5": val_advtop5.avg.item(),
|
| 75 |
-
"#samples": val_top1.count.item(),
|
| 76 |
-
"batch_size": images.shape[0],
|
| 77 |
-
"img_size": images.shape[2],
|
| 78 |
-
}
|
| 79 |
-
)
|
| 80 |
-
t.update()
|
| 81 |
-
|
| 82 |
-
val_results = {
|
| 83 |
-
"val_top1": val_top1.avg.item(),
|
| 84 |
-
"val_top5": val_top5.avg.item(),
|
| 85 |
-
"val_loss": val_loss.avg.item(),
|
| 86 |
-
"val_advtop1": val_advtop1.avg.item(),
|
| 87 |
-
"val_advtop5": val_advtop5.avg.item(),
|
| 88 |
-
"val_advloss": val_advloss.avg.item(),
|
| 89 |
-
}
|
| 90 |
-
return val_results
|
| 91 |
-
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
smi08/ProArd/attacks/base.py
DELETED
|
@@ -1,63 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch.nn as nn
|
| 3 |
-
|
| 4 |
-
from .utils import replicate_input
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
class Attack(object):
|
| 8 |
-
"""
|
| 9 |
-
Abstract base class for all attack classes.
|
| 10 |
-
Arguments:
|
| 11 |
-
predict (nn.Module): forward pass function.
|
| 12 |
-
loss_fn (nn.Module): loss function.
|
| 13 |
-
clip_min (float): mininum value per input dimension.
|
| 14 |
-
clip_max (float): maximum value per input dimension.
|
| 15 |
-
"""
|
| 16 |
-
|
| 17 |
-
def __init__(self, predict, loss_fn, clip_min, clip_max):
|
| 18 |
-
self.predict = predict
|
| 19 |
-
self.loss_fn = loss_fn
|
| 20 |
-
self.clip_min = clip_min
|
| 21 |
-
self.clip_max = clip_max
|
| 22 |
-
|
| 23 |
-
def perturb(self, x, **kwargs):
|
| 24 |
-
"""
|
| 25 |
-
Virtual method for generating the adversarial examples.
|
| 26 |
-
Arguments:
|
| 27 |
-
x (torch.Tensor): the model's input tensor.
|
| 28 |
-
**kwargs: optional parameters used by child classes.
|
| 29 |
-
Returns:
|
| 30 |
-
adversarial examples.
|
| 31 |
-
"""
|
| 32 |
-
error = "Sub-classes must implement perturb."
|
| 33 |
-
raise NotImplementedError(error)
|
| 34 |
-
|
| 35 |
-
def __call__(self, *args, **kwargs):
|
| 36 |
-
return self.perturb(*args, **kwargs)
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
class LabelMixin(object):
|
| 40 |
-
def _get_predicted_label(self, x):
|
| 41 |
-
"""
|
| 42 |
-
Compute predicted labels given x. Used to prevent label leaking during adversarial training.
|
| 43 |
-
Arguments:
|
| 44 |
-
x (torch.Tensor): the model's input tensor.
|
| 45 |
-
Returns:
|
| 46 |
-
torch.Tensor containing predicted labels.
|
| 47 |
-
"""
|
| 48 |
-
with torch.no_grad():
|
| 49 |
-
outputs = self.predict(x)
|
| 50 |
-
_, y = torch.max(outputs, dim=1)
|
| 51 |
-
return y
|
| 52 |
-
|
| 53 |
-
def _verify_and_process_inputs(self, x, y):
|
| 54 |
-
if self.targeted:
|
| 55 |
-
assert y is not None
|
| 56 |
-
|
| 57 |
-
if not self.targeted:
|
| 58 |
-
if y is None:
|
| 59 |
-
y = self._get_predicted_label(x)
|
| 60 |
-
|
| 61 |
-
x = replicate_input(x)
|
| 62 |
-
y = replicate_input(y)
|
| 63 |
-
return x,y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
smi08/ProArd/attacks/deepfool.py
DELETED
|
@@ -1,253 +0,0 @@
|
|
| 1 |
-
import copy
|
| 2 |
-
import numpy as np
|
| 3 |
-
import torch
|
| 4 |
-
import torch.nn as nn
|
| 5 |
-
from torch.autograd import Variable
|
| 6 |
-
from .utils import ctx_noparamgrad_and_eval
|
| 7 |
-
from .base import Attack, LabelMixin
|
| 8 |
-
from typing import Dict
|
| 9 |
-
from .utils import batch_multiply
|
| 10 |
-
from .utils import clamp
|
| 11 |
-
from .utils import is_float_or_torch_tensor
|
| 12 |
-
from utils.distributed import DistributedMetric
|
| 13 |
-
from tqdm import tqdm
|
| 14 |
-
from torchpack import distributed as dist
|
| 15 |
-
from utils import accuracy
|
| 16 |
-
|
| 17 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def perturb_deepfool(xvar, yvar, predict, nb_iter=50, overshoot=0.02, ord=np.inf, clip_min=0.0, clip_max=1.0,
|
| 21 |
-
search_iter=0, device=None):
|
| 22 |
-
"""
|
| 23 |
-
Compute DeepFool perturbations (Moosavi-Dezfooli et al, 2016).
|
| 24 |
-
Arguments:
|
| 25 |
-
xvar (torch.Tensor): input images.
|
| 26 |
-
yvar (torch.Tensor): predictions.
|
| 27 |
-
predict (nn.Module): forward pass function.
|
| 28 |
-
nb_iter (int): number of iterations.
|
| 29 |
-
overshoot (float): how much to overshoot the boundary.
|
| 30 |
-
ord (int): (optional) the order of maximum distortion (inf or 2).
|
| 31 |
-
clip_min (float): mininum value per input dimension.
|
| 32 |
-
clip_max (float): maximum value per input dimension.
|
| 33 |
-
search_iter (int): no of search iterations.
|
| 34 |
-
device (torch.device): device to work on.
|
| 35 |
-
Returns:
|
| 36 |
-
torch.Tensor containing the perturbed input,
|
| 37 |
-
torch.Tensor containing the perturbation
|
| 38 |
-
"""
|
| 39 |
-
|
| 40 |
-
x_orig = xvar
|
| 41 |
-
x = torch.empty_like(xvar).copy_(xvar)
|
| 42 |
-
x.requires_grad_(True)
|
| 43 |
-
|
| 44 |
-
batch_i = torch.arange(x.shape[0])
|
| 45 |
-
r_tot = torch.zeros_like(x.data)
|
| 46 |
-
for i in range(nb_iter):
|
| 47 |
-
if x.grad is not None:
|
| 48 |
-
x.grad.zero_()
|
| 49 |
-
|
| 50 |
-
logits = predict(x)
|
| 51 |
-
df_inds = np.argsort(logits.detach().cpu().numpy(), axis=-1)
|
| 52 |
-
df_inds_other, df_inds_orig = df_inds[:, :-1], df_inds[:, -1]
|
| 53 |
-
df_inds_orig = torch.from_numpy(df_inds_orig)
|
| 54 |
-
df_inds_orig = df_inds_orig.to(device)
|
| 55 |
-
not_done_inds = df_inds_orig == yvar
|
| 56 |
-
if not_done_inds.sum() == 0:
|
| 57 |
-
break
|
| 58 |
-
|
| 59 |
-
logits[batch_i, df_inds_orig].sum().backward(retain_graph=True)
|
| 60 |
-
grad_orig = x.grad.data.clone().detach()
|
| 61 |
-
pert = x.data.new_ones(x.shape[0]) * np.inf
|
| 62 |
-
w = torch.zeros_like(x.data)
|
| 63 |
-
|
| 64 |
-
for inds in df_inds_other.T:
|
| 65 |
-
x.grad.zero_()
|
| 66 |
-
logits[batch_i, inds].sum().backward(retain_graph=True)
|
| 67 |
-
grad_cur = x.grad.data.clone().detach()
|
| 68 |
-
with torch.no_grad():
|
| 69 |
-
w_k = grad_cur - grad_orig
|
| 70 |
-
f_k = logits[batch_i, inds] - logits[batch_i, df_inds_orig]
|
| 71 |
-
if ord == 2:
|
| 72 |
-
pert_k = torch.abs(f_k) / torch.norm(w_k.flatten(1), 2, -1)
|
| 73 |
-
elif ord == np.inf:
|
| 74 |
-
pert_k = torch.abs(f_k) / torch.norm(w_k.flatten(1), 1, -1)
|
| 75 |
-
else:
|
| 76 |
-
raise NotImplementedError("Only ord=inf and ord=2 have been implemented")
|
| 77 |
-
swi = pert_k < pert
|
| 78 |
-
if swi.sum() > 0:
|
| 79 |
-
pert[swi] = pert_k[swi]
|
| 80 |
-
w[swi] = w_k[swi]
|
| 81 |
-
|
| 82 |
-
if ord == 2:
|
| 83 |
-
r_i = (pert + 1e-6)[:, None, None, None] * w / torch.norm(w.flatten(1), 2, -1)[:, None, None, None]
|
| 84 |
-
elif ord == np.inf:
|
| 85 |
-
r_i = (pert + 1e-6)[:, None, None, None] * w.sign()
|
| 86 |
-
|
| 87 |
-
r_tot += r_i * not_done_inds[:, None, None, None].float()
|
| 88 |
-
x.data = x_orig + (1. + overshoot) * r_tot
|
| 89 |
-
x.data = torch.clamp(x.data, clip_min, clip_max)
|
| 90 |
-
|
| 91 |
-
x = x.detach()
|
| 92 |
-
if search_iter > 0:
|
| 93 |
-
dx = x - x_orig
|
| 94 |
-
dx_l_low, dx_l_high = torch.zeros_like(dx), torch.ones_like(dx)
|
| 95 |
-
for i in range(search_iter):
|
| 96 |
-
dx_l = (dx_l_low + dx_l_high) / 2.
|
| 97 |
-
dx_x = x_orig + dx_l * dx
|
| 98 |
-
dx_y = predict(dx_x).argmax(-1)
|
| 99 |
-
label_stay = dx_y == yvar
|
| 100 |
-
label_change = dx_y != yvar
|
| 101 |
-
dx_l_low[label_stay] = dx_l[label_stay]
|
| 102 |
-
dx_l_high[label_change] = dx_l[label_change]
|
| 103 |
-
x = dx_x
|
| 104 |
-
|
| 105 |
-
# x.data = torch.clamp(x.data, clip_min, clip_max)
|
| 106 |
-
r_tot = x.data - x_orig
|
| 107 |
-
return x, r_tot
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
class DeepFoolAttack(Attack, LabelMixin):
|
| 112 |
-
"""
|
| 113 |
-
DeepFool attack.
|
| 114 |
-
[Seyed-Mohsen Moosavi-Dezfooli, Alhussein Fawzi, Pascal Frossard,
|
| 115 |
-
"DeepFool: a simple and accurate method to fool deep neural networks"]
|
| 116 |
-
Arguments:
|
| 117 |
-
predict (nn.Module): forward pass function.
|
| 118 |
-
overshoot (float): how much to overshoot the boundary.
|
| 119 |
-
nb_iter (int): number of iterations.
|
| 120 |
-
search_iter (int): no of search iterations.
|
| 121 |
-
clip_min (float): mininum value per input dimension.
|
| 122 |
-
clip_max (float): maximum value per input dimension.
|
| 123 |
-
ord (int): (optional) the order of maximum distortion (inf or 2).
|
| 124 |
-
"""
|
| 125 |
-
|
| 126 |
-
def __init__(
|
| 127 |
-
self, predict, overshoot=0.02, nb_iter=50, search_iter=50, clip_min=0., clip_max=1., ord=np.inf):
|
| 128 |
-
super(DeepFoolAttack, self).__init__(predict, None, clip_min, clip_max)
|
| 129 |
-
self.overshoot = overshoot
|
| 130 |
-
self.nb_iter = nb_iter
|
| 131 |
-
self.search_iter = search_iter
|
| 132 |
-
self.targeted = False
|
| 133 |
-
|
| 134 |
-
self.ord = ord
|
| 135 |
-
assert is_float_or_torch_tensor(self.overshoot)
|
| 136 |
-
|
| 137 |
-
def perturb(self, x, y=None):
|
| 138 |
-
"""
|
| 139 |
-
Given examples x, returns their adversarial counterparts.
|
| 140 |
-
Arguments:
|
| 141 |
-
x (torch.Tensor): input tensor.
|
| 142 |
-
y (torch.Tensor): label tensor.
|
| 143 |
-
- if None and self.targeted=False, compute y as predicted labels.
|
| 144 |
-
Returns:
|
| 145 |
-
torch.Tensor containing perturbed inputs,
|
| 146 |
-
torch.Tensor containing the perturbation
|
| 147 |
-
"""
|
| 148 |
-
|
| 149 |
-
x, y = self._verify_and_process_inputs(x, None)
|
| 150 |
-
x_adv, r_adv = perturb_deepfool(x, y, self.predict, self.nb_iter, self.overshoot, ord=self.ord,
|
| 151 |
-
clip_min=self.clip_min, clip_max=self.clip_max, search_iter=self.search_iter,
|
| 152 |
-
device=device)
|
| 153 |
-
return x_adv, r_adv
|
| 154 |
-
def eval_deepfool(self,data_loader_dict: Dict)-> Dict:
|
| 155 |
-
|
| 156 |
-
test_criterion = nn.CrossEntropyLoss().cuda()
|
| 157 |
-
val_loss = DistributedMetric()
|
| 158 |
-
val_top1 = DistributedMetric()
|
| 159 |
-
val_top5 = DistributedMetric()
|
| 160 |
-
val_advloss = DistributedMetric()
|
| 161 |
-
val_advtop1 = DistributedMetric()
|
| 162 |
-
val_advtop5 = DistributedMetric()
|
| 163 |
-
self.predict.eval()
|
| 164 |
-
with tqdm(
|
| 165 |
-
total=len(data_loader_dict["val"]),
|
| 166 |
-
desc="Eval",
|
| 167 |
-
disable=not dist.is_master(),
|
| 168 |
-
) as t:
|
| 169 |
-
for images, labels in data_loader_dict["val"]:
|
| 170 |
-
images, labels = images.cuda(), labels.cuda()
|
| 171 |
-
# compute output
|
| 172 |
-
output = self.predict(images)
|
| 173 |
-
loss = test_criterion(output, labels)
|
| 174 |
-
val_loss.update(loss, images.shape[0])
|
| 175 |
-
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
|
| 176 |
-
val_top5.update(acc5[0], images.shape[0])
|
| 177 |
-
val_top1.update(acc1[0], images.shape[0])
|
| 178 |
-
with ctx_noparamgrad_and_eval(self.predict):
|
| 179 |
-
images_adv,_ = self.perturb(images, labels)
|
| 180 |
-
output_adv = self.predict(images_adv)
|
| 181 |
-
loss_adv = test_criterion(output_adv,labels)
|
| 182 |
-
val_advloss.update(loss_adv, images.shape[0])
|
| 183 |
-
acc1_adv, acc5_adv = accuracy(output_adv, labels, topk=(1, 5))
|
| 184 |
-
val_advtop1.update(acc1_adv[0], images.shape[0])
|
| 185 |
-
val_advtop5.update(acc5_adv[0], images.shape[0])
|
| 186 |
-
t.set_postfix(
|
| 187 |
-
{
|
| 188 |
-
"loss": val_loss.avg.item(),
|
| 189 |
-
"top1": val_top1.avg.item(),
|
| 190 |
-
"top5": val_top5.avg.item(),
|
| 191 |
-
"adv_loss": val_advloss.avg.item(),
|
| 192 |
-
"adv_top1": val_advtop1.avg.item(),
|
| 193 |
-
"adv_top5": val_advtop5.avg.item(),
|
| 194 |
-
"#samples": val_top1.count.item(),
|
| 195 |
-
"batch_size": images.shape[0],
|
| 196 |
-
"img_size": images.shape[2],
|
| 197 |
-
}
|
| 198 |
-
)
|
| 199 |
-
t.update()
|
| 200 |
-
|
| 201 |
-
val_results = {
|
| 202 |
-
"val_top1": val_top1.avg.item(),
|
| 203 |
-
"val_top5": val_top5.avg.item(),
|
| 204 |
-
"val_loss": val_loss.avg.item(),
|
| 205 |
-
"val_advtop1": val_advtop1.avg.item(),
|
| 206 |
-
"val_advtop5": val_advtop5.avg.item(),
|
| 207 |
-
"val_advloss": val_advloss.avg.item(),
|
| 208 |
-
}
|
| 209 |
-
return val_results
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
class LinfDeepFoolAttack(DeepFoolAttack):
|
| 213 |
-
"""
|
| 214 |
-
DeepFool Attack with order=Linf.
|
| 215 |
-
Arguments:
|
| 216 |
-
Arguments:
|
| 217 |
-
predict (nn.Module): forward pass function.
|
| 218 |
-
overshoot (float): how much to overshoot the boundary.
|
| 219 |
-
nb_iter (int): number of iterations.
|
| 220 |
-
search_iter (int): no of search iterations.
|
| 221 |
-
clip_min (float): mininum value per input dimension.
|
| 222 |
-
clip_max (float): maximum value per input dimension.
|
| 223 |
-
"""
|
| 224 |
-
|
| 225 |
-
def __init__(
|
| 226 |
-
self, predict, overshoot=0.02, nb_iter=50, search_iter=50, clip_min=0., clip_max=1.):
|
| 227 |
-
|
| 228 |
-
ord = np.inf
|
| 229 |
-
super(LinfDeepFoolAttack, self).__init__(
|
| 230 |
-
predict=predict, overshoot=overshoot, nb_iter=nb_iter, search_iter=search_iter, clip_min=clip_min,
|
| 231 |
-
clip_max=clip_max, ord=ord)
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
class L2DeepFoolAttack(DeepFoolAttack):
|
| 236 |
-
"""
|
| 237 |
-
DeepFool Attack with order=L2.
|
| 238 |
-
Arguments:
|
| 239 |
-
predict (nn.Module): forward pass function.
|
| 240 |
-
overshoot (float): how much to overshoot the boundary.
|
| 241 |
-
nb_iter (int): number of iterations.
|
| 242 |
-
search_iter (int): no of search iterations.
|
| 243 |
-
clip_min (float): mininum value per input dimension.
|
| 244 |
-
clip_max (float): maximum value per input dimension.
|
| 245 |
-
"""
|
| 246 |
-
|
| 247 |
-
def __init__(
|
| 248 |
-
self, predict, overshoot=0.02, nb_iter=50, search_iter=50, clip_min=0., clip_max=1.):
|
| 249 |
-
|
| 250 |
-
ord = 2
|
| 251 |
-
super(L2DeepFoolAttack, self).__init__(
|
| 252 |
-
predict=predict, overshoot=overshoot, nb_iter=nb_iter, search_iter=search_iter, clip_min=clip_min,
|
| 253 |
-
clip_max=clip_max, ord=ord)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
smi08/ProArd/attacks/fgsm.py
DELETED
|
@@ -1,171 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch.nn as nn
|
| 3 |
-
|
| 4 |
-
from .base import Attack, LabelMixin
|
| 5 |
-
from .utils import ctx_noparamgrad_and_eval
|
| 6 |
-
from .utils import batch_multiply
|
| 7 |
-
from .utils import clamp ,normalize_by_pnorm
|
| 8 |
-
from utils.distributed import DistributedMetric
|
| 9 |
-
from tqdm import tqdm
|
| 10 |
-
from torchpack import distributed as dist
|
| 11 |
-
from utils import accuracy
|
| 12 |
-
from typing import Dict
|
| 13 |
-
class FGSMAttack(Attack, LabelMixin):
|
| 14 |
-
"""
|
| 15 |
-
One step fast gradient sign method (Goodfellow et al, 2014).
|
| 16 |
-
Arguments:
|
| 17 |
-
predict (nn.Module): forward pass function.
|
| 18 |
-
loss_fn (nn.Module): loss function.
|
| 19 |
-
eps (float): attack step size.
|
| 20 |
-
clip_min (float): mininum value per input dimension.
|
| 21 |
-
clip_max (float): maximum value per input dimension.
|
| 22 |
-
targeted (bool): indicate if this is a targeted attack.
|
| 23 |
-
"""
|
| 24 |
-
|
| 25 |
-
def __init__(self, predict, loss_fn=None, eps=0.3, clip_min=0., clip_max=1., targeted=False):
|
| 26 |
-
super(FGSMAttack, self).__init__(predict, loss_fn, clip_min, clip_max)
|
| 27 |
-
|
| 28 |
-
self.eps = eps
|
| 29 |
-
self.targeted = targeted
|
| 30 |
-
if self.loss_fn is None:
|
| 31 |
-
self.loss_fn = nn.CrossEntropyLoss(reduction="sum")
|
| 32 |
-
|
| 33 |
-
def perturb(self, x, y=None):
|
| 34 |
-
"""
|
| 35 |
-
Given examples (x, y), returns their adversarial counterparts with an attack length of eps.
|
| 36 |
-
Arguments:
|
| 37 |
-
x (torch.Tensor): input tensor.
|
| 38 |
-
y (torch.Tensor): label tensor.
|
| 39 |
-
- if None and self.targeted=False, compute y as predicted labels.
|
| 40 |
-
- if self.targeted=True, then y must be the targeted labels.
|
| 41 |
-
Returns:
|
| 42 |
-
torch.Tensor containing perturbed inputs.
|
| 43 |
-
torch.Tensor containing the perturbation.
|
| 44 |
-
"""
|
| 45 |
-
|
| 46 |
-
x, y = self._verify_and_process_inputs(x, y)
|
| 47 |
-
|
| 48 |
-
xadv = x.requires_grad_()
|
| 49 |
-
outputs = self.predict(xadv)
|
| 50 |
-
|
| 51 |
-
loss = self.loss_fn(outputs, y)
|
| 52 |
-
if self.targeted:
|
| 53 |
-
loss = -loss
|
| 54 |
-
loss.backward()
|
| 55 |
-
grad_sign = xadv.grad.detach().sign()
|
| 56 |
-
|
| 57 |
-
xadv = xadv + batch_multiply(self.eps, grad_sign)
|
| 58 |
-
xadv = clamp(xadv, self.clip_min, self.clip_max)
|
| 59 |
-
radv = xadv - x
|
| 60 |
-
return xadv.detach(), radv.detach()
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
LinfFastGradientAttack = FGSMAttack
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
class FGMAttack(Attack, LabelMixin):
|
| 67 |
-
"""
|
| 68 |
-
One step fast gradient method. Perturbs the input with gradient (not gradient sign) of the loss wrt the input.
|
| 69 |
-
Arguments:
|
| 70 |
-
predict (nn.Module): forward pass function.
|
| 71 |
-
loss_fn (nn.Module): loss function.
|
| 72 |
-
eps (float): attack step size.
|
| 73 |
-
clip_min (float): mininum value per input dimension.
|
| 74 |
-
clip_max (float): maximum value per input dimension.
|
| 75 |
-
targeted (bool): indicate if this is a targeted attack.
|
| 76 |
-
"""
|
| 77 |
-
|
| 78 |
-
def __init__(self, predict, loss_fn=None, eps=0.3, clip_min=0., clip_max=1., targeted=False):
|
| 79 |
-
super(FGMAttack, self).__init__(
|
| 80 |
-
predict, loss_fn, clip_min, clip_max)
|
| 81 |
-
|
| 82 |
-
self.eps = eps
|
| 83 |
-
self.targeted = targeted
|
| 84 |
-
if self.loss_fn is None:
|
| 85 |
-
self.loss_fn = nn.CrossEntropyLoss(reduction="sum")
|
| 86 |
-
|
| 87 |
-
def perturb(self, x, y=None):
|
| 88 |
-
"""
|
| 89 |
-
Given examples (x, y), returns their adversarial counterparts with an attack length of eps.
|
| 90 |
-
Arguments:
|
| 91 |
-
x (torch.Tensor): input tensor.
|
| 92 |
-
y (torch.Tensor): label tensor.
|
| 93 |
-
- if None and self.targeted=False, compute y as predicted labels.
|
| 94 |
-
- if self.targeted=True, then y must be the targeted labels.
|
| 95 |
-
Returns:
|
| 96 |
-
torch.Tensor containing perturbed inputs.
|
| 97 |
-
torch.Tensor containing the perturbation.
|
| 98 |
-
"""
|
| 99 |
-
|
| 100 |
-
x, y = self._verify_and_process_inputs(x, y)
|
| 101 |
-
xadv = x.requires_grad_()
|
| 102 |
-
outputs = self.predict(xadv)
|
| 103 |
-
|
| 104 |
-
loss = self.loss_fn(outputs, y)
|
| 105 |
-
if self.targeted:
|
| 106 |
-
loss = -loss
|
| 107 |
-
loss.backward()
|
| 108 |
-
grad = normalize_by_pnorm(xadv.grad)
|
| 109 |
-
xadv = xadv + batch_multiply(self.eps, grad)
|
| 110 |
-
xadv = clamp(xadv, self.clip_min, self.clip_max)
|
| 111 |
-
radv = xadv - x
|
| 112 |
-
|
| 113 |
-
return xadv.detach(), radv.detach()
|
| 114 |
-
def eval_fgsm(self,data_loader_dict: Dict)-> Dict:
|
| 115 |
-
|
| 116 |
-
test_criterion = nn.CrossEntropyLoss().cuda()
|
| 117 |
-
val_loss = DistributedMetric()
|
| 118 |
-
val_top1 = DistributedMetric()
|
| 119 |
-
val_top5 = DistributedMetric()
|
| 120 |
-
val_advloss = DistributedMetric()
|
| 121 |
-
val_advtop1 = DistributedMetric()
|
| 122 |
-
val_advtop5 = DistributedMetric()
|
| 123 |
-
self.predict.eval()
|
| 124 |
-
with tqdm(
|
| 125 |
-
total=len(data_loader_dict["val"]),
|
| 126 |
-
desc="Eval",
|
| 127 |
-
disable=not dist.is_master(),
|
| 128 |
-
) as t:
|
| 129 |
-
for images, labels in data_loader_dict["val"]:
|
| 130 |
-
images, labels = images.cuda(), labels.cuda()
|
| 131 |
-
# compute output
|
| 132 |
-
output = self.predict(images)
|
| 133 |
-
loss = test_criterion(output, labels)
|
| 134 |
-
val_loss.update(loss, images.shape[0])
|
| 135 |
-
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
|
| 136 |
-
val_top5.update(acc5[0], images.shape[0])
|
| 137 |
-
val_top1.update(acc1[0], images.shape[0])
|
| 138 |
-
with ctx_noparamgrad_and_eval(self.predict):
|
| 139 |
-
images_adv,_ = self.perturb(images, labels)
|
| 140 |
-
output_adv = self.predict(images_adv)
|
| 141 |
-
loss_adv = test_criterion(output_adv,labels)
|
| 142 |
-
val_advloss.update(loss_adv, images.shape[0])
|
| 143 |
-
acc1_adv, acc5_adv = accuracy(output_adv, labels, topk=(1, 5))
|
| 144 |
-
val_advtop1.update(acc1_adv[0], images.shape[0])
|
| 145 |
-
val_advtop5.update(acc5_adv[0], images.shape[0])
|
| 146 |
-
t.set_postfix(
|
| 147 |
-
{
|
| 148 |
-
"loss": val_loss.avg.item(),
|
| 149 |
-
"top1": val_top1.avg.item(),
|
| 150 |
-
"top5": val_top5.avg.item(),
|
| 151 |
-
"adv_loss": val_advloss.avg.item(),
|
| 152 |
-
"adv_top1": val_advtop1.avg.item(),
|
| 153 |
-
"adv_top5": val_advtop5.avg.item(),
|
| 154 |
-
"#samples": val_top1.count.item(),
|
| 155 |
-
"batch_size": images.shape[0],
|
| 156 |
-
"img_size": images.shape[2],
|
| 157 |
-
}
|
| 158 |
-
)
|
| 159 |
-
t.update()
|
| 160 |
-
|
| 161 |
-
val_results = {
|
| 162 |
-
"val_top1": val_top1.avg.item(),
|
| 163 |
-
"val_top5": val_top5.avg.item(),
|
| 164 |
-
"val_loss": val_loss.avg.item(),
|
| 165 |
-
"val_advtop1": val_advtop1.avg.item(),
|
| 166 |
-
"val_advtop5": val_advtop5.avg.item(),
|
| 167 |
-
"val_advloss": val_advloss.avg.item(),
|
| 168 |
-
}
|
| 169 |
-
return val_results
|
| 170 |
-
|
| 171 |
-
L2FastGradientAttack = FGMAttack
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
smi08/ProArd/attacks/local_lip.py
DELETED
|
@@ -1,29 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch.nn as nn
|
| 3 |
-
from typing import Dict
|
| 4 |
-
from utils.distributed import DistributedMetric
|
| 5 |
-
from tqdm import tqdm
|
| 6 |
-
from torchpack import distributed as dist
|
| 7 |
-
from utils import accuracy
|
| 8 |
-
import copy
|
| 9 |
-
import torch.nn.functional as F
|
| 10 |
-
import numpy as np
|
| 11 |
-
def eval_local_lip(model, x, xp, top_norm=1, btm_norm=float('inf'), reduction='mean'):
|
| 12 |
-
model.eval()
|
| 13 |
-
down = torch.flatten(x - xp, start_dim=1)
|
| 14 |
-
with torch.no_grad():
|
| 15 |
-
if top_norm == "kl":
|
| 16 |
-
criterion_kl = nn.KLDivLoss(reduction='none')
|
| 17 |
-
top = criterion_kl(F.log_softmax(model(xp), dim=1),
|
| 18 |
-
F.softmax(model(x), dim=1))
|
| 19 |
-
ret = torch.sum(top, dim=1) / torch.norm(down + 1e-6, dim=1, p=btm_norm)
|
| 20 |
-
else:
|
| 21 |
-
top = torch.flatten(model(x), start_dim=1) - torch.flatten(model(xp), start_dim=1)
|
| 22 |
-
ret = torch.norm(top, dim=1, p=top_norm) / torch.norm(down + 1e-6, dim=1, p=btm_norm)
|
| 23 |
-
|
| 24 |
-
if reduction == 'mean':
|
| 25 |
-
return torch.mean(ret)
|
| 26 |
-
elif reduction == 'sum':
|
| 27 |
-
return torch.sum(ret)
|
| 28 |
-
else:
|
| 29 |
-
raise ValueError("Not supported reduction")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
smi08/ProArd/attacks/pgd.py
DELETED
|
@@ -1,248 +0,0 @@
|
|
| 1 |
-
import numpy as np
|
| 2 |
-
import torch
|
| 3 |
-
import torch.nn as nn
|
| 4 |
-
from .utils import ctx_noparamgrad_and_eval
|
| 5 |
-
from .base import Attack, LabelMixin
|
| 6 |
-
from typing import Dict
|
| 7 |
-
from .utils import batch_clamp
|
| 8 |
-
from .utils import batch_multiply
|
| 9 |
-
from .utils import clamp
|
| 10 |
-
from .utils import clamp_by_pnorm
|
| 11 |
-
from .utils import is_float_or_torch_tensor
|
| 12 |
-
from .utils import normalize_by_pnorm
|
| 13 |
-
from .utils import rand_init_delta
|
| 14 |
-
from .utils import replicate_input
|
| 15 |
-
from utils.distributed import DistributedMetric
|
| 16 |
-
from tqdm import tqdm
|
| 17 |
-
from torchpack import distributed as dist
|
| 18 |
-
from utils import accuracy
|
| 19 |
-
|
| 20 |
-
def perturb_iterative(xvar, yvar, predict, nb_iter, eps, eps_iter, loss_fn, delta_init=None, minimize=False, ord=np.inf,
|
| 21 |
-
clip_min=0.0, clip_max=1.0):
|
| 22 |
-
"""
|
| 23 |
-
Iteratively maximize the loss over the input. It is a shared method for iterative attacks.
|
| 24 |
-
Arguments:
|
| 25 |
-
xvar (torch.Tensor): input data.
|
| 26 |
-
yvar (torch.Tensor): input labels.
|
| 27 |
-
predict (nn.Module): forward pass function.
|
| 28 |
-
nb_iter (int): number of iterations.
|
| 29 |
-
eps (float): maximum distortion.
|
| 30 |
-
eps_iter (float): attack step size.
|
| 31 |
-
loss_fn (nn.Module): loss function.
|
| 32 |
-
delta_init (torch.Tensor): (optional) tensor contains the random initialization.
|
| 33 |
-
minimize (bool): (optional) whether to minimize or maximize the loss.
|
| 34 |
-
ord (int): (optional) the order of maximum distortion (inf or 2).
|
| 35 |
-
clip_min (float): mininum value per input dimension.
|
| 36 |
-
clip_max (float): maximum value per input dimension.
|
| 37 |
-
Returns:
|
| 38 |
-
torch.Tensor containing the perturbed input,
|
| 39 |
-
torch.Tensor containing the perturbation
|
| 40 |
-
"""
|
| 41 |
-
if delta_init is not None:
|
| 42 |
-
delta = delta_init
|
| 43 |
-
else:
|
| 44 |
-
delta = torch.zeros_like(xvar)
|
| 45 |
-
|
| 46 |
-
delta.requires_grad_()
|
| 47 |
-
for ii in range(nb_iter):
|
| 48 |
-
outputs = predict(xvar + delta)
|
| 49 |
-
loss = loss_fn(outputs, yvar)
|
| 50 |
-
if minimize:
|
| 51 |
-
loss = -loss
|
| 52 |
-
|
| 53 |
-
loss.backward()
|
| 54 |
-
if ord == np.inf:
|
| 55 |
-
grad_sign = delta.grad.data.sign()
|
| 56 |
-
delta.data = delta.data + batch_multiply(eps_iter, grad_sign)
|
| 57 |
-
delta.data = batch_clamp(eps, delta.data)
|
| 58 |
-
delta.data = clamp(xvar.data + delta.data, clip_min, clip_max) - xvar.data
|
| 59 |
-
elif ord == 2:
|
| 60 |
-
grad = delta.grad.data
|
| 61 |
-
grad = normalize_by_pnorm(grad)
|
| 62 |
-
delta.data = delta.data + batch_multiply(eps_iter, grad)
|
| 63 |
-
delta.data = clamp(xvar.data + delta.data, clip_min, clip_max) - xvar.data
|
| 64 |
-
if eps is not None:
|
| 65 |
-
delta.data = clamp_by_pnorm(delta.data, ord, eps)
|
| 66 |
-
else:
|
| 67 |
-
error = "Only ord=inf and ord=2 have been implemented"
|
| 68 |
-
raise NotImplementedError(error)
|
| 69 |
-
delta.grad.data.zero_()
|
| 70 |
-
|
| 71 |
-
x_adv = clamp(xvar + delta, clip_min, clip_max)
|
| 72 |
-
r_adv = x_adv - xvar
|
| 73 |
-
return x_adv, r_adv
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
class PGDAttack(Attack, LabelMixin):
|
| 77 |
-
"""
|
| 78 |
-
The projected gradient descent attack (Madry et al, 2017).
|
| 79 |
-
The attack performs nb_iter steps of size eps_iter, while always staying within eps from the initial point.
|
| 80 |
-
Arguments:
|
| 81 |
-
predict (nn.Module): forward pass function.
|
| 82 |
-
loss_fn (nn.Module): loss function.
|
| 83 |
-
eps (float): maximum distortion.
|
| 84 |
-
nb_iter (int): number of iterations.
|
| 85 |
-
eps_iter (float): attack step size.
|
| 86 |
-
rand_init (bool): (optional) random initialization.
|
| 87 |
-
clip_min (float): mininum value per input dimension.
|
| 88 |
-
clip_max (float): maximum value per input dimension.
|
| 89 |
-
ord (int): (optional) the order of maximum distortion (inf or 2).
|
| 90 |
-
targeted (bool): if the attack is targeted.
|
| 91 |
-
rand_init_type (str): (optional) random initialization type.
|
| 92 |
-
"""
|
| 93 |
-
|
| 94 |
-
def __init__(
|
| 95 |
-
self, predict, loss_fn=None, eps=0.3, nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0., clip_max=1.,
|
| 96 |
-
ord=np.inf, targeted=False, rand_init_type='uniform'):
|
| 97 |
-
super(PGDAttack, self).__init__(predict, loss_fn, clip_min, clip_max)
|
| 98 |
-
self.eps = eps
|
| 99 |
-
self.nb_iter = nb_iter
|
| 100 |
-
self.eps_iter = eps_iter
|
| 101 |
-
self.rand_init = rand_init
|
| 102 |
-
self.rand_init_type = rand_init_type
|
| 103 |
-
self.ord = ord
|
| 104 |
-
self.targeted = targeted
|
| 105 |
-
if self.loss_fn is None:
|
| 106 |
-
self.loss_fn = nn.CrossEntropyLoss(reduction="sum")
|
| 107 |
-
assert is_float_or_torch_tensor(self.eps_iter)
|
| 108 |
-
assert is_float_or_torch_tensor(self.eps)
|
| 109 |
-
|
| 110 |
-
def perturb(self, x, y=None):
|
| 111 |
-
"""
|
| 112 |
-
Given examples (x, y), returns their adversarial counterparts with an attack length of eps.
|
| 113 |
-
Arguments:
|
| 114 |
-
x (torch.Tensor): input tensor.
|
| 115 |
-
y (torch.Tensor): label tensor.
|
| 116 |
-
- if None and self.targeted=False, compute y as predicted
|
| 117 |
-
labels.
|
| 118 |
-
- if self.targeted=True, then y must be the targeted labels.
|
| 119 |
-
Returns:
|
| 120 |
-
torch.Tensor containing perturbed inputs,
|
| 121 |
-
torch.Tensor containing the perturbation
|
| 122 |
-
"""
|
| 123 |
-
x, y = self._verify_and_process_inputs(x, y)
|
| 124 |
-
|
| 125 |
-
delta = torch.zeros_like(x)
|
| 126 |
-
delta = nn.Parameter(delta)
|
| 127 |
-
if self.rand_init:
|
| 128 |
-
if self.rand_init_type == 'uniform':
|
| 129 |
-
rand_init_delta(
|
| 130 |
-
delta, x, self.ord, self.eps, self.clip_min, self.clip_max)
|
| 131 |
-
delta.data = clamp(
|
| 132 |
-
x + delta.data, min=self.clip_min, max=self.clip_max) - x
|
| 133 |
-
elif self.rand_init_type == 'normal':
|
| 134 |
-
delta.data = 0.001 * torch.randn_like(x) # initialize as in TRADES
|
| 135 |
-
else:
|
| 136 |
-
raise NotImplementedError('Only rand_init_type=normal and rand_init_type=uniform have been implemented.')
|
| 137 |
-
|
| 138 |
-
x_adv, r_adv = perturb_iterative(
|
| 139 |
-
x, y, self.predict, nb_iter=self.nb_iter, eps=self.eps, eps_iter=self.eps_iter, loss_fn=self.loss_fn,
|
| 140 |
-
minimize=self.targeted, ord=self.ord, clip_min=self.clip_min, clip_max=self.clip_max, delta_init=delta
|
| 141 |
-
)
|
| 142 |
-
|
| 143 |
-
return x_adv.data, r_adv.data
|
| 144 |
-
|
| 145 |
-
def eval_pgd(self,data_loader_dict: Dict)-> Dict:
|
| 146 |
-
|
| 147 |
-
test_criterion = nn.CrossEntropyLoss().cuda()
|
| 148 |
-
val_loss = DistributedMetric()
|
| 149 |
-
val_top1 = DistributedMetric()
|
| 150 |
-
val_top5 = DistributedMetric()
|
| 151 |
-
val_advloss = DistributedMetric()
|
| 152 |
-
val_advtop1 = DistributedMetric()
|
| 153 |
-
val_advtop5 = DistributedMetric()
|
| 154 |
-
self.predict.eval()
|
| 155 |
-
with tqdm(
|
| 156 |
-
total=len(data_loader_dict["val"]),
|
| 157 |
-
desc="Eval",
|
| 158 |
-
disable=not dist.is_master(),
|
| 159 |
-
) as t:
|
| 160 |
-
for images, labels in data_loader_dict["val"]:
|
| 161 |
-
images, labels = images.cuda(), labels.cuda()
|
| 162 |
-
# compute output
|
| 163 |
-
output = self.predict(images)
|
| 164 |
-
loss = test_criterion(output, labels)
|
| 165 |
-
val_loss.update(loss, images.shape[0])
|
| 166 |
-
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
|
| 167 |
-
val_top5.update(acc5[0], images.shape[0])
|
| 168 |
-
val_top1.update(acc1[0], images.shape[0])
|
| 169 |
-
with ctx_noparamgrad_and_eval(self.predict):
|
| 170 |
-
images_adv,_ = self.perturb(images, labels)
|
| 171 |
-
output_adv = self.predict(images_adv)
|
| 172 |
-
loss_adv = test_criterion(output_adv,labels)
|
| 173 |
-
val_advloss.update(loss_adv, images.shape[0])
|
| 174 |
-
acc1_adv, acc5_adv = accuracy(output_adv, labels, topk=(1, 5))
|
| 175 |
-
val_advtop1.update(acc1_adv[0], images.shape[0])
|
| 176 |
-
val_advtop5.update(acc5_adv[0], images.shape[0])
|
| 177 |
-
t.set_postfix(
|
| 178 |
-
{
|
| 179 |
-
"loss": val_loss.avg.item(),
|
| 180 |
-
"top1": val_top1.avg.item(),
|
| 181 |
-
"top5": val_top5.avg.item(),
|
| 182 |
-
"adv_loss": val_advloss.avg.item(),
|
| 183 |
-
"adv_top1": val_advtop1.avg.item(),
|
| 184 |
-
"adv_top5": val_advtop5.avg.item(),
|
| 185 |
-
"#samples": val_top1.count.item(),
|
| 186 |
-
"batch_size": images.shape[0],
|
| 187 |
-
"img_size": images.shape[2],
|
| 188 |
-
}
|
| 189 |
-
)
|
| 190 |
-
t.update()
|
| 191 |
-
|
| 192 |
-
val_results = {
|
| 193 |
-
"val_top1": val_top1.avg.item(),
|
| 194 |
-
"val_top5": val_top5.avg.item(),
|
| 195 |
-
"val_loss": val_loss.avg.item(),
|
| 196 |
-
"val_advtop1": val_advtop1.avg.item(),
|
| 197 |
-
"val_advtop5": val_advtop5.avg.item(),
|
| 198 |
-
"val_advloss": val_advloss.avg.item(),
|
| 199 |
-
}
|
| 200 |
-
return val_results
|
| 201 |
-
class LinfPGDAttack(PGDAttack):
|
| 202 |
-
"""
|
| 203 |
-
PGD Attack with order=Linf
|
| 204 |
-
Arguments:
|
| 205 |
-
predict (nn.Module): forward pass function.
|
| 206 |
-
loss_fn (nn.Module): loss function.
|
| 207 |
-
eps (float): maximum distortion.
|
| 208 |
-
nb_iter (int): number of iterations.
|
| 209 |
-
eps_iter (float): attack step size.
|
| 210 |
-
rand_init (bool): (optional) random initialization.
|
| 211 |
-
clip_min (float): mininum value per input dimension.
|
| 212 |
-
clip_max (float): maximum value per input dimension.
|
| 213 |
-
targeted (bool): if the attack is targeted.
|
| 214 |
-
rand_init_type (str): (optional) random initialization type.
|
| 215 |
-
"""
|
| 216 |
-
|
| 217 |
-
def __init__(
|
| 218 |
-
self, predict, loss_fn=None, eps=0.3, nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0., clip_max=1.,
|
| 219 |
-
targeted=False, rand_init_type='uniform'):
|
| 220 |
-
ord = np.inf
|
| 221 |
-
super(LinfPGDAttack, self).__init__(
|
| 222 |
-
predict=predict, loss_fn=loss_fn, eps=eps, nb_iter=nb_iter, eps_iter=eps_iter, rand_init=rand_init,
|
| 223 |
-
clip_min=clip_min, clip_max=clip_max, targeted=targeted, ord=ord, rand_init_type=rand_init_type)
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
class L2PGDAttack(PGDAttack):
|
| 227 |
-
"""
|
| 228 |
-
PGD Attack with order=L2
|
| 229 |
-
Arguments:
|
| 230 |
-
predict (nn.Module): forward pass function.
|
| 231 |
-
loss_fn (nn.Module): loss function.
|
| 232 |
-
eps (float): maximum distortion.
|
| 233 |
-
nb_iter (int): number of iterations.
|
| 234 |
-
eps_iter (float): attack step size.
|
| 235 |
-
rand_init (bool): (optional) random initialization.
|
| 236 |
-
clip_min (float): mininum value per input dimension.
|
| 237 |
-
clip_max (float): maximum value per input dimension.
|
| 238 |
-
targeted (bool): if the attack is targeted.
|
| 239 |
-
rand_init_type (str): (optional) random initialization type.
|
| 240 |
-
"""
|
| 241 |
-
|
| 242 |
-
def __init__(
|
| 243 |
-
self, predict, loss_fn=None, eps=0.3, nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0., clip_max=1.,
|
| 244 |
-
targeted=False, rand_init_type='uniform'):
|
| 245 |
-
ord = 2
|
| 246 |
-
super(L2PGDAttack, self).__init__(
|
| 247 |
-
predict=predict, loss_fn=loss_fn, eps=eps, nb_iter=nb_iter, eps_iter=eps_iter, rand_init=rand_init,
|
| 248 |
-
clip_min=clip_min, clip_max=clip_max, targeted=targeted, ord=ord, rand_init_type=rand_init_type)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
smi08/ProArd/attacks/squred.py
DELETED
|
@@ -1,86 +0,0 @@
|
|
| 1 |
-
from autoattack import AutoAttack
|
| 2 |
-
import numpy as np
|
| 3 |
-
from .base import Attack,LabelMixin
|
| 4 |
-
import torch.nn as nn
|
| 5 |
-
from utils.distributed import DistributedMetric
|
| 6 |
-
from tqdm import tqdm
|
| 7 |
-
from torchpack import distributed as dist
|
| 8 |
-
from utils import accuracy
|
| 9 |
-
from typing import Dict
|
| 10 |
-
from .utils import ctx_noparamgrad_and_eval
|
| 11 |
-
class Squre_Attack(Attack, LabelMixin):
|
| 12 |
-
|
| 13 |
-
def __init__(
|
| 14 |
-
self, predict, loss_fn=None, eps=0.3, nb_iter=40, eps_iter=0.01, rand_init=True, clip_min=0., clip_max=1.,
|
| 15 |
-
ord=np.inf, targeted=False, rand_init_type='uniform'):
|
| 16 |
-
super(Squre_Attack, self).__init__(predict, loss_fn, clip_min, clip_max)
|
| 17 |
-
self.eps = eps
|
| 18 |
-
self.nb_iter = nb_iter
|
| 19 |
-
self.eps_iter = eps_iter
|
| 20 |
-
self.rand_init = rand_init
|
| 21 |
-
self.rand_init_type = rand_init_type
|
| 22 |
-
self.ord = ord
|
| 23 |
-
self.targeted = targeted
|
| 24 |
-
if self.loss_fn is None:
|
| 25 |
-
self.loss_fn = nn.CrossEntropyLoss(reduction="sum")
|
| 26 |
-
self.adversary = AutoAttack(predict, norm='Linf', eps=self.eps, version='standard')
|
| 27 |
-
def perturb(self, x, y=None):
|
| 28 |
-
self.adversary.attacks_to_run=['square']
|
| 29 |
-
adversarial_examples = self.adversary.run_standard_evaluation(x, y, bs=100)
|
| 30 |
-
return adversarial_examples,adversarial_examples
|
| 31 |
-
def eval_squred(self,data_loader_dict: Dict)-> Dict:
|
| 32 |
-
|
| 33 |
-
test_criterion = nn.CrossEntropyLoss().cuda()
|
| 34 |
-
val_loss = DistributedMetric()
|
| 35 |
-
val_top1 = DistributedMetric()
|
| 36 |
-
val_top5 = DistributedMetric()
|
| 37 |
-
val_advloss = DistributedMetric()
|
| 38 |
-
val_advtop1 = DistributedMetric()
|
| 39 |
-
val_advtop5 = DistributedMetric()
|
| 40 |
-
self.predict.eval()
|
| 41 |
-
with tqdm(
|
| 42 |
-
total=len(data_loader_dict["val"]),
|
| 43 |
-
desc="Eval",
|
| 44 |
-
disable=not dist.is_master(),
|
| 45 |
-
) as t:
|
| 46 |
-
for images, labels in data_loader_dict["val"]:
|
| 47 |
-
images, labels = images.cuda(), labels.cuda()
|
| 48 |
-
# compute output
|
| 49 |
-
output = self.predict(images)
|
| 50 |
-
loss = test_criterion(output, labels)
|
| 51 |
-
val_loss.update(loss, images.shape[0])
|
| 52 |
-
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
|
| 53 |
-
val_top5.update(acc5[0], images.shape[0])
|
| 54 |
-
val_top1.update(acc1[0], images.shape[0])
|
| 55 |
-
with ctx_noparamgrad_and_eval(self.predict):
|
| 56 |
-
images_adv,_ = self.perturb(images, labels)
|
| 57 |
-
output_adv = self.predict(images_adv)
|
| 58 |
-
loss_adv = test_criterion(output_adv,labels)
|
| 59 |
-
val_advloss.update(loss_adv, images.shape[0])
|
| 60 |
-
acc1_adv, acc5_adv = accuracy(output_adv, labels, topk=(1, 5))
|
| 61 |
-
val_advtop1.update(acc1_adv[0], images.shape[0])
|
| 62 |
-
val_advtop5.update(acc5_adv[0], images.shape[0])
|
| 63 |
-
t.set_postfix(
|
| 64 |
-
{
|
| 65 |
-
"loss": val_loss.avg.item(),
|
| 66 |
-
"top1": val_top1.avg.item(),
|
| 67 |
-
"top5": val_top5.avg.item(),
|
| 68 |
-
"adv_loss": val_advloss.avg.item(),
|
| 69 |
-
"adv_top1": val_advtop1.avg.item(),
|
| 70 |
-
"adv_top5": val_advtop5.avg.item(),
|
| 71 |
-
"#samples": val_top1.count.item(),
|
| 72 |
-
"batch_size": images.shape[0],
|
| 73 |
-
"img_size": images.shape[2],
|
| 74 |
-
}
|
| 75 |
-
)
|
| 76 |
-
t.update()
|
| 77 |
-
|
| 78 |
-
val_results = {
|
| 79 |
-
"val_top1": val_top1.avg.item(),
|
| 80 |
-
"val_top5": val_top5.avg.item(),
|
| 81 |
-
"val_loss": val_loss.avg.item(),
|
| 82 |
-
"val_advtop1": val_advtop1.avg.item(),
|
| 83 |
-
"val_advtop5": val_advtop5.avg.item(),
|
| 84 |
-
"val_advloss": val_advloss.avg.item(),
|
| 85 |
-
}
|
| 86 |
-
return val_results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
smi08/ProArd/attacks/utils.py
DELETED
|
@@ -1,279 +0,0 @@
|
|
| 1 |
-
import numpy as np
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
import torch.nn as nn
|
| 5 |
-
import torch.nn.functional as F
|
| 6 |
-
from torch.autograd import Variable
|
| 7 |
-
|
| 8 |
-
from torch.distributions import laplace
|
| 9 |
-
from torch.distributions import uniform
|
| 10 |
-
from torch.nn.modules.loss import _Loss
|
| 11 |
-
from contextlib import contextmanager
|
| 12 |
-
|
| 13 |
-
def replicate_input(x):
|
| 14 |
-
"""
|
| 15 |
-
Clone the input tensor x.
|
| 16 |
-
"""
|
| 17 |
-
return x.detach().clone()
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
def replicate_input_withgrad(x):
|
| 21 |
-
"""
|
| 22 |
-
Clone the input tensor x and set requires_grad=True.
|
| 23 |
-
"""
|
| 24 |
-
return x.detach().clone().requires_grad_()
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def calc_l2distsq(x, y):
|
| 28 |
-
"""
|
| 29 |
-
Calculate L2 distance between tensors x and y.
|
| 30 |
-
"""
|
| 31 |
-
d = (x - y)**2
|
| 32 |
-
return d.view(d.shape[0], -1).sum(dim=1)
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
def clamp(input, min=None, max=None):
|
| 36 |
-
"""
|
| 37 |
-
Clamp a tensor by its minimun and maximun values.
|
| 38 |
-
"""
|
| 39 |
-
ndim = input.ndimension()
|
| 40 |
-
if min is None:
|
| 41 |
-
pass
|
| 42 |
-
elif isinstance(min, (float, int)):
|
| 43 |
-
input = torch.clamp(input, min=min)
|
| 44 |
-
elif isinstance(min, torch.Tensor):
|
| 45 |
-
if min.ndimension() == ndim - 1 and min.shape == input.shape[1:]:
|
| 46 |
-
input = torch.max(input, min.view(1, *min.shape))
|
| 47 |
-
else:
|
| 48 |
-
assert min.shape == input.shape
|
| 49 |
-
input = torch.max(input, min)
|
| 50 |
-
else:
|
| 51 |
-
raise ValueError("min can only be None | float | torch.Tensor")
|
| 52 |
-
|
| 53 |
-
if max is None:
|
| 54 |
-
pass
|
| 55 |
-
elif isinstance(max, (float, int)):
|
| 56 |
-
input = torch.clamp(input, max=max)
|
| 57 |
-
elif isinstance(max, torch.Tensor):
|
| 58 |
-
if max.ndimension() == ndim - 1 and max.shape == input.shape[1:]:
|
| 59 |
-
input = torch.min(input, max.view(1, *max.shape))
|
| 60 |
-
else:
|
| 61 |
-
assert max.shape == input.shape
|
| 62 |
-
input = torch.min(input, max)
|
| 63 |
-
else:
|
| 64 |
-
raise ValueError("max can only be None | float | torch.Tensor")
|
| 65 |
-
return input
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
def _batch_multiply_tensor_by_vector(vector, batch_tensor):
|
| 69 |
-
"""Equivalent to the following.
|
| 70 |
-
for ii in range(len(vector)):
|
| 71 |
-
batch_tensor.data[ii] *= vector[ii]
|
| 72 |
-
return batch_tensor
|
| 73 |
-
"""
|
| 74 |
-
return (
|
| 75 |
-
batch_tensor.transpose(0, -1) * vector).transpose(0, -1).contiguous()
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
def _batch_clamp_tensor_by_vector(vector, batch_tensor):
|
| 79 |
-
"""Equivalent to the following.
|
| 80 |
-
for ii in range(len(vector)):
|
| 81 |
-
batch_tensor[ii] = clamp(
|
| 82 |
-
batch_tensor[ii], -vector[ii], vector[ii])
|
| 83 |
-
"""
|
| 84 |
-
return torch.min(
|
| 85 |
-
torch.max(batch_tensor.transpose(0, -1), -vector), vector
|
| 86 |
-
).transpose(0, -1).contiguous()
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
def batch_multiply(float_or_vector, tensor):
|
| 90 |
-
"""
|
| 91 |
-
Multpliy a batch of tensors with a float or vector.
|
| 92 |
-
"""
|
| 93 |
-
if isinstance(float_or_vector, torch.Tensor):
|
| 94 |
-
assert len(float_or_vector) == len(tensor)
|
| 95 |
-
tensor = _batch_multiply_tensor_by_vector(float_or_vector, tensor)
|
| 96 |
-
elif isinstance(float_or_vector, float):
|
| 97 |
-
tensor *= float_or_vector
|
| 98 |
-
else:
|
| 99 |
-
raise TypeError("Value has to be float or torch.Tensor")
|
| 100 |
-
return tensor
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
def batch_clamp(float_or_vector, tensor):
|
| 104 |
-
"""
|
| 105 |
-
Clamp a batch of tensors.
|
| 106 |
-
"""
|
| 107 |
-
if isinstance(float_or_vector, torch.Tensor):
|
| 108 |
-
assert len(float_or_vector) == len(tensor)
|
| 109 |
-
tensor = _batch_clamp_tensor_by_vector(float_or_vector, tensor)
|
| 110 |
-
return tensor
|
| 111 |
-
elif isinstance(float_or_vector, float):
|
| 112 |
-
tensor = clamp(tensor, -float_or_vector, float_or_vector)
|
| 113 |
-
else:
|
| 114 |
-
raise TypeError("Value has to be float or torch.Tensor")
|
| 115 |
-
return tensor
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
def _get_norm_batch(x, p):
|
| 119 |
-
"""
|
| 120 |
-
Returns the Lp norm of batch x.
|
| 121 |
-
"""
|
| 122 |
-
batch_size = x.size(0)
|
| 123 |
-
return x.abs().pow(p).view(batch_size, -1).sum(dim=1).pow(1. / p)
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
def _thresh_by_magnitude(theta, x):
|
| 127 |
-
"""
|
| 128 |
-
Threshold by magnitude.
|
| 129 |
-
"""
|
| 130 |
-
return torch.relu(torch.abs(x) - theta) * x.sign()
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
def clamp_by_pnorm(x, p, r):
|
| 134 |
-
"""
|
| 135 |
-
Clamp tensor by its norm.
|
| 136 |
-
"""
|
| 137 |
-
assert isinstance(p, float) or isinstance(p, int)
|
| 138 |
-
norm = _get_norm_batch(x, p)
|
| 139 |
-
if isinstance(r, torch.Tensor):
|
| 140 |
-
assert norm.size() == r.size()
|
| 141 |
-
else:
|
| 142 |
-
assert isinstance(r, float)
|
| 143 |
-
factor = torch.min(r / norm, torch.ones_like(norm))
|
| 144 |
-
return batch_multiply(factor, x)
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
def is_float_or_torch_tensor(x):
|
| 148 |
-
"""
|
| 149 |
-
Return whether input x is a float or a torch.Tensor.
|
| 150 |
-
"""
|
| 151 |
-
return isinstance(x, torch.Tensor) or isinstance(x, float)
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
def normalize_by_pnorm(x, p=2, small_constant=1e-6):
|
| 155 |
-
"""
|
| 156 |
-
Normalize gradients for gradient (not gradient sign) attacks.
|
| 157 |
-
Arguments:
|
| 158 |
-
x (torch.Tensor): tensor containing the gradients on the input.
|
| 159 |
-
p (int): (optional) order of the norm for the normalization (1 or 2).
|
| 160 |
-
small_constant (float): (optional) to avoid dividing by zero.
|
| 161 |
-
Returns:
|
| 162 |
-
normalized gradients.
|
| 163 |
-
"""
|
| 164 |
-
assert isinstance(p, float) or isinstance(p, int)
|
| 165 |
-
norm = _get_norm_batch(x, p)
|
| 166 |
-
norm = torch.max(norm, torch.ones_like(norm) * small_constant)
|
| 167 |
-
return batch_multiply(1. / norm, x)
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
def rand_init_delta(delta, x, ord, eps, clip_min, clip_max):
|
| 171 |
-
"""
|
| 172 |
-
Randomly initialize the perturbation.
|
| 173 |
-
"""
|
| 174 |
-
if isinstance(eps, torch.Tensor):
|
| 175 |
-
assert len(eps) == len(delta)
|
| 176 |
-
|
| 177 |
-
if ord == np.inf:
|
| 178 |
-
delta.data.uniform_(-1, 1)
|
| 179 |
-
delta.data = batch_multiply(eps, delta.data)
|
| 180 |
-
elif ord == 2:
|
| 181 |
-
delta.data.uniform_(clip_min, clip_max)
|
| 182 |
-
delta.data = delta.data - x
|
| 183 |
-
delta.data = clamp_by_pnorm(delta.data, ord, eps)
|
| 184 |
-
elif ord == 1:
|
| 185 |
-
ini = laplace.Laplace(
|
| 186 |
-
loc=delta.new_tensor(0), scale=delta.new_tensor(1))
|
| 187 |
-
delta.data = ini.sample(delta.data.shape)
|
| 188 |
-
delta.data = normalize_by_pnorm(delta.data, p=1)
|
| 189 |
-
ray = uniform.Uniform(0, eps).sample()
|
| 190 |
-
delta.data *= ray
|
| 191 |
-
delta.data = clamp(x.data + delta.data, clip_min, clip_max) - x.data
|
| 192 |
-
else:
|
| 193 |
-
error = "Only ord = inf, ord = 1 and ord = 2 have been implemented"
|
| 194 |
-
raise NotImplementedError(error)
|
| 195 |
-
|
| 196 |
-
delta.data = clamp(
|
| 197 |
-
x + delta.data, min=clip_min, max=clip_max) - x
|
| 198 |
-
return delta.data
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
def CWLoss(output, target, confidence=0):
|
| 202 |
-
"""
|
| 203 |
-
CW loss (Marging loss).
|
| 204 |
-
"""
|
| 205 |
-
num_classes = output.shape[-1]
|
| 206 |
-
target = target.data
|
| 207 |
-
target_onehot = torch.zeros(target.size() + (num_classes,))
|
| 208 |
-
target_onehot = target_onehot.cuda()
|
| 209 |
-
target_onehot.scatter_(1, target.unsqueeze(1), 1.)
|
| 210 |
-
target_var = Variable(target_onehot, requires_grad=False)
|
| 211 |
-
real = (target_var * output).sum(1)
|
| 212 |
-
other = ((1. - target_var) * output - target_var * 10000.).max(1)[0]
|
| 213 |
-
loss = - torch.clamp(real - other + confidence, min=0.)
|
| 214 |
-
loss = torch.sum(loss)
|
| 215 |
-
return loss
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
class ctx_noparamgrad(object):
|
| 221 |
-
def __init__(self, module):
|
| 222 |
-
self.prev_grad_state = get_param_grad_state(module)
|
| 223 |
-
self.module = module
|
| 224 |
-
set_param_grad_off(module)
|
| 225 |
-
|
| 226 |
-
def __enter__(self):
|
| 227 |
-
pass
|
| 228 |
-
|
| 229 |
-
def __exit__(self, *args):
|
| 230 |
-
set_param_grad_state(self.module, self.prev_grad_state)
|
| 231 |
-
return False
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
class ctx_eval(object):
|
| 235 |
-
def __init__(self, module):
|
| 236 |
-
self.prev_training_state = get_module_training_state(module)
|
| 237 |
-
self.module = module
|
| 238 |
-
set_module_training_off(module)
|
| 239 |
-
|
| 240 |
-
def __enter__(self):
|
| 241 |
-
pass
|
| 242 |
-
|
| 243 |
-
def __exit__(self, *args):
|
| 244 |
-
set_module_training_state(self.module, self.prev_training_state)
|
| 245 |
-
return False
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
@contextmanager
|
| 249 |
-
def ctx_noparamgrad_and_eval(module):
|
| 250 |
-
with ctx_noparamgrad(module) as a, ctx_eval(module) as b:
|
| 251 |
-
yield (a, b)
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
def get_module_training_state(module):
|
| 255 |
-
return {mod: mod.training for mod in module.modules()}
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
def set_module_training_state(module, training_state):
|
| 259 |
-
for mod in module.modules():
|
| 260 |
-
mod.training = training_state[mod]
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
def set_module_training_off(module):
|
| 264 |
-
for mod in module.modules():
|
| 265 |
-
mod.training = False
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
def get_param_grad_state(module):
|
| 269 |
-
return {param: param.requires_grad for param in module.parameters()}
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
def set_param_grad_state(module, grad_state):
|
| 273 |
-
for param in module.parameters():
|
| 274 |
-
param.requires_grad = grad_state[param]
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
def set_param_grad_off(module):
|
| 278 |
-
for param in module.parameters():
|
| 279 |
-
param.requires_grad = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|