from typing import Dict

import numpy as np
import torch
import torch.nn as nn
from autoattack import AutoAttack
from torchpack import distributed as dist
from tqdm import tqdm

from utils import accuracy
from utils.distributed import DistributedMetric

from .base import Attack, LabelMixin
from .utils import ctx_noparamgrad_and_eval


class AutoAttacks(Attack, LabelMixin):
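    """Wrapper that runs the standard AutoAttack (Linf) suite through the Attack/LabelMixin interface."""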
    
    def __init__(
            self, predict, loss_fn=None, eps=0.3, nb_iter=40, eps_iter=0.01,
            rand_init=True, clip_min=0., clip_max=1., ord=np.inf,
            targeted=False, rand_init_type='uniform'):
        super(AutoAttacks, self).__init__(predict, loss_fn, clip_min, clip_max)
        self.eps = eps
        # AutoAttack runs its own fixed attack schedule; the arguments below are only
        # stored to keep the constructor signature consistent with the other attacks.
        self.nb_iter = nb_iter
        self.eps_iter = eps_iter
        self.rand_init = rand_init
        self.rand_init_type = rand_init_type
        self.ord = ord
        self.targeted = targeted
        if self.loss_fn is None:
            self.loss_fn = nn.CrossEntropyLoss(reduction="sum")
        # Standard AutoAttack suite (APGD-CE, APGD-T, FAB-T, Square) under the L-inf norm.
        self.adversary = AutoAttack(predict, norm='Linf', eps=self.eps, version='standard')

    def perturb(self, x, y=None):
        # Run the standard AutoAttack evaluation on the batch; `bs` is the batch size
        # AutoAttack uses internally. The adversarial batch is returned twice so that
        # callers expecting a pair (see eval_AutoAttack below) can unpack it.
        adversarial_examples = self.adversary.run_standard_evaluation(x, y, bs=100)
        return adversarial_examples, adversarial_examples

    def eval_AutoAttack(self, data_loader_dict: Dict) -> Dict:
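        """Evaluate clean and AutoAttack accuracy/loss on data_loader_dict["val"]."""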

        test_criterion = nn.CrossEntropyLoss().cuda()
        val_loss = DistributedMetric()
        val_top1 = DistributedMetric()
        val_top5 = DistributedMetric()
        val_advloss = DistributedMetric()
        val_advtop1 = DistributedMetric()
        val_advtop5 = DistributedMetric()
        self.predict.eval()
        with tqdm(
                total=len(data_loader_dict["val"]),
                desc="Eval",
                disable=not dist.is_master(),
            ) as t:
                for images, labels in data_loader_dict["val"]:
                    images, labels = images.cuda(), labels.cuda()
                    # compute output
                    output = self.predict(images)
                    loss = test_criterion(output, labels)
                    val_loss.update(loss, images.shape[0])
                    acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                    val_top5.update(acc5[0], images.shape[0])
                    val_top1.update(acc1[0], images.shape[0])
                    # Craft adversarial examples without tracking parameter gradients.
                    with ctx_noparamgrad_and_eval(self.predict):
                        images_adv, _ = self.perturb(images, labels)
                    output_adv = self.predict(images_adv)
                    loss_adv = test_criterion(output_adv, labels)
                    val_advloss.update(loss_adv, images.shape[0])
                    acc1_adv, acc5_adv = accuracy(output_adv, labels, topk=(1, 5))
                    val_advtop1.update(acc1_adv[0], images.shape[0])
                    val_advtop5.update(acc5_adv[0], images.shape[0])
                    t.set_postfix(
                        {
                            "loss": val_loss.avg.item(),
                            "top1": val_top1.avg.item(),
                            "top5": val_top5.avg.item(),
                            "adv_loss": val_advloss.avg.item(),
                            "adv_top1": val_advtop1.avg.item(),
                            "adv_top5": val_advtop5.avg.item(),
                            "#samples": val_top1.count.item(),
                            "batch_size": images.shape[0],
                            "img_size": images.shape[2],
                        }
                    )
                    t.update()

        val_results = {
            "val_top1": val_top1.avg.item(),
            "val_top5": val_top5.avg.item(),
            "val_loss": val_loss.avg.item(),
            "val_advtop1": val_advtop1.avg.item(),
            "val_advtop5": val_advtop5.avg.item(),
            "val_advloss": val_advloss.avg.item(),
        }
        return val_results
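

# Example usage (sketch only; `build_model` and `val_loader` are placeholders for
# whatever model constructor and validation DataLoader the surrounding codebase provides):
#
#     model = build_model().cuda()
#     attack = AutoAttacks(model, eps=8 / 255)
#     results = attack.eval_AutoAttack({"val": val_loader})
#     print(results["val_top1"], results["val_advtop1"])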