helblazer811 committed
Commit 55866f4 · 0 Parent(s)

"Orphan branch commit with a readme"

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. .gitignore +5 -0
  2. README.md +7 -0
  3. app.py +140 -0
  4. concept_attention/__init__.py +2 -0
  5. concept_attention/binary_segmentation_baselines/__init__.py +0 -0
  6. concept_attention/binary_segmentation_baselines/__pycache__/__init__.cpython-310.pyc +0 -0
  7. concept_attention/binary_segmentation_baselines/__pycache__/chefer_clip_vit_baselines.cpython-310.pyc +0 -0
  8. concept_attention/binary_segmentation_baselines/__pycache__/clip_text_span_baseline.cpython-310.pyc +0 -0
  9. concept_attention/binary_segmentation_baselines/__pycache__/daam.cpython-310.pyc +0 -0
  10. concept_attention/binary_segmentation_baselines/__pycache__/daam_sd2.cpython-310.pyc +0 -0
  11. concept_attention/binary_segmentation_baselines/__pycache__/daam_sdxl.cpython-310.pyc +0 -0
  12. concept_attention/binary_segmentation_baselines/__pycache__/dino.cpython-310.pyc +0 -0
  13. concept_attention/binary_segmentation_baselines/__pycache__/raw_cross_attention.cpython-310.pyc +0 -0
  14. concept_attention/binary_segmentation_baselines/__pycache__/raw_output_space.cpython-310.pyc +0 -0
  15. concept_attention/binary_segmentation_baselines/__pycache__/raw_value_space.cpython-310.pyc +0 -0
  16. concept_attention/binary_segmentation_baselines/chefer_clip_vit_baselines.py +272 -0
  17. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/ViT_LRP.py +437 -0
  18. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/ViT_explanation_generator.py +83 -0
  19. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/ViT_new.py +238 -0
  20. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/ViT_orig_LRP.py +425 -0
  21. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/ViT_LRP.cpython-310.pyc +0 -0
  22. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/ViT_explanation_generator.cpython-310.pyc +0 -0
  23. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/ViT_new.cpython-310.pyc +0 -0
  24. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/ViT_orig_LRP.cpython-310.pyc +0 -0
  25. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/helpers.cpython-310.pyc +0 -0
  26. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/layer_helpers.cpython-310.pyc +0 -0
  27. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/weight_init.cpython-310.pyc +0 -0
  28. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/VOC.py +395 -0
  29. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/__init__.py +0 -0
  30. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/__pycache__/Imagenet.cpython-310.pyc +0 -0
  31. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/__pycache__/VOC.cpython-310.pyc +0 -0
  32. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/__pycache__/__init__.cpython-310.pyc +0 -0
  33. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/__pycache__/imagenet.cpython-310.pyc +0 -0
  34. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/imagenet.py +200 -0
  35. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/imagenet_utils.py +1002 -0
  36. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/transforms.py +442 -0
  37. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/generate_visualizations.py +208 -0
  38. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/helpers.py +295 -0
  39. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/layer_helpers.py +21 -0
  40. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/misc_functions.py +68 -0
  41. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/modules/__init__.py +0 -0
  42. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/modules/__pycache__/__init__.cpython-310.pyc +0 -0
  43. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/modules/__pycache__/layers_lrp.cpython-310.pyc +0 -0
  44. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/modules/__pycache__/layers_ours.cpython-310.pyc +0 -0
  45. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/modules/layers_lrp.py +261 -0
  46. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/modules/layers_ours.py +280 -0
  47. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/pertubation_eval_from_hdf5.py +232 -0
  48. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/utils/__init__.py +0 -0
  49. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  50. concept_attention/binary_segmentation_baselines/chefer_vit_explainability/utils/__pycache__/confusionmatrix.cpython-310.pyc +0 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
+ *.png
+ *.pyc
+ concept_attention.egg-info
+ concept_attention/flux/src/flux.egg-info/PKG-INFO
+ *.pyc
README.md ADDED
@@ -0,0 +1,7 @@
+ ---
+ title: ConceptAttention
+ sdk: gradio
+ sdk_version: "5.15.0"
+ app_file: app.py
+ pinned: false
+ ---
app.py ADDED
@@ -0,0 +1,140 @@
+ import base64
+ import io
+
+ import spaces
+ import gradio as gr
+ from PIL import Image
+
+ from concept_attention import ConceptAttentionFluxPipeline
+
+ concept_attention_default_args = {
+     "model_name": "flux-schnell",
+     "device": "cuda",
+     "layer_indices": list(range(10, 19)),
+     "timesteps": list(range(4)),
+     "num_samples": 4,
+     "num_inference_steps": 4
+ }
+ IMG_SIZE = 250
+
+ EXAMPLES = [
+     [
+         "A fluffy cat sitting on a windowsill",  # prompt
+         "cat.jpg",  # image
+         "fur, whiskers, eyes",  # words
+         42,  # seed
+     ],
+     # ["Mountain landscape with lake", "cat.jpg", "sky, trees, water", 123],
+     # ["Portrait of a young woman", "monkey.png", "face, hair, eyes", 456],
+ ]
+
+
+ pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", device="cuda")
+
+
+ @spaces.GPU(duration=60)
+ def process_inputs(prompt, input_image, word_list, seed):
+     prompt = prompt.strip()
+     if not word_list.strip():
+         return None, "Please enter comma-separated words"
+
+     concepts = [w.strip() for w in word_list.split(",")]
+
+     if input_image is not None:
+         input_image = Image.fromarray(input_image)
+         input_image = input_image.convert("RGB")
+         input_image = input_image.resize((1024, 1024))
+
+         pipeline_output = pipeline.encode_image(
+             image=input_image,
+             concepts=concepts,
+             prompt=prompt,
+             width=1024,
+             height=1024,
+             seed=seed,
+             num_samples=concept_attention_default_args["num_samples"]
+         )
+     else:
+         pipeline_output = pipeline.generate_image(
+             prompt=prompt,
+             concepts=concepts,
+             width=1024,
+             height=1024,
+             seed=seed,
+             timesteps=concept_attention_default_args["timesteps"],
+             num_inference_steps=concept_attention_default_args["num_inference_steps"],
+         )
+
+     output_image = pipeline_output.image
+     concept_heatmaps = pipeline_output.concept_heatmaps
+
+     html_elements = []
+     for concept, heatmap in zip(concepts, concept_heatmaps):
+         img = heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
+         buffered = io.BytesIO()
+         img.save(buffered, format="PNG")
+         img_str = base64.b64encode(buffered.getvalue()).decode()
+
+         html = f"""
+         <div style='text-align: center; margin: 5px; padding: 5px; overflow-x: auto; white-space: nowrap;'>
+             <h1 style='margin-bottom: 10px;'>{concept}</h1>
+             <img src='data:image/png;base64,{img_str}' style='width: {IMG_SIZE}px; display: inline-block; height: {IMG_SIZE}px;'>
+         </div>
+         """
+         html_elements.append(html)
+
+     combined_html = "<div style='display: flex; flex-wrap: wrap; justify-content: center;'>" + "".join(html_elements) + "</div>"
+     return output_image, combined_html
+
+
+ with gr.Blocks(
+     css="""
+     .container { max-width: 1200px; margin: 0 auto; padding: 20px; }
+     .title { text-align: center; margin-bottom: 10px; }
+     .authors { text-align: center; margin-bottom: 20px; }
+     .affiliations { text-align: center; color: #666; margin-bottom: 40px; }
+     .content { display: grid; grid-template-columns: 1fr 1fr; gap: 20px; }
+     .section { border: 2px solid #ddd; border-radius: 10px; padding: 20px; }
+     """
+ ) as demo:
+     with gr.Column(elem_classes="container"):
+         gr.Markdown("# ConceptAttention: Diffusion Transformers Learn Highly Interpretable Features", elem_classes="title")
+         gr.Markdown("**Alec Helbling**¹, **Tuna Meral**², **Ben Hoover**¹³, **Pinar Yanardag**², **Duen Horng (Polo) Chau**¹", elem_classes="authors")
+         gr.Markdown("¹Georgia Tech · ²Virginia Tech · ³IBM Research", elem_classes="affiliations")
+
+         with gr.Row(elem_classes="content"):
+             with gr.Column(elem_classes="section"):
+                 gr.Markdown("### Input")
+                 prompt = gr.Textbox(label="Enter your prompt")
+                 words = gr.Textbox(label="Enter words (comma-separated)")
+                 seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
+                 gr.HTML("<div style='text-align: center;'> <h1> Or </h1> </div>")
+                 image_input = gr.Image(type="numpy", label="Upload image (optional)")
+
+             with gr.Column(elem_classes="section"):
+                 gr.Markdown("### Output")
+                 output_image = gr.Image(type="numpy", label="Output image")
+
+         with gr.Row():
+             submit_btn = gr.Button("Process")
+
+         with gr.Row(elem_classes="section"):
+             saliency_display = gr.HTML(label="Saliency Maps")
+
+         submit_btn.click(
+             fn=process_inputs,
+             inputs=[prompt, image_input, words, seed], outputs=[output_image, saliency_display]
+         )
+
+         gr.Examples(examples=EXAMPLES, inputs=[prompt, image_input, words, seed], outputs=[output_image, saliency_display], fn=process_inputs, cache_examples=False)
+
+ if __name__ == "__main__":
+     demo.launch(
+         share=True,
+         server_name="0.0.0.0",
+         inbrowser=True,
+         # share=False,
+         server_port=6754,
+         quiet=True,
+         max_threads=1
+     )
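For reference, a minimal sketch of how the pipeline API used by app.py can be exercised outside of Gradio. It relies only on the calls visible above (generate_image returning .image and .concept_heatmaps); the CUDA device, the flux-schnell checkpoint download, and the output file names are assumptions.

```python
# Minimal sketch based on the calls in app.py above; assumes a CUDA GPU
# and that the flux-schnell weights can be downloaded.
from concept_attention import ConceptAttentionFluxPipeline

pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell", device="cuda")

concepts = ["fur", "whiskers", "eyes"]
output = pipeline.generate_image(
    prompt="A fluffy cat sitting on a windowsill",
    concepts=concepts,
    width=1024,
    height=1024,
    seed=42,
    timesteps=list(range(4)),
    num_inference_steps=4,
)

output.image.save("generated.png")  # hypothetical output path
for concept, heatmap in zip(concepts, output.concept_heatmaps):
    # app.py treats each heatmap as a PIL image (resize/save), so we do the same
    heatmap.save(f"heatmap_{concept}.png")
```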
concept_attention/__init__.py ADDED
@@ -0,0 +1,2 @@
+
+ from concept_attention.concept_attention_pipeline import ConceptAttentionFluxPipeline
concept_attention/binary_segmentation_baselines/__init__.py ADDED
File without changes
concept_attention/binary_segmentation_baselines/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (214 Bytes).
 
concept_attention/binary_segmentation_baselines/__pycache__/chefer_clip_vit_baselines.cpython-310.pyc ADDED
Binary file (7.18 kB).
 
concept_attention/binary_segmentation_baselines/__pycache__/clip_text_span_baseline.cpython-310.pyc ADDED
Binary file (3.66 kB).
 
concept_attention/binary_segmentation_baselines/__pycache__/daam.cpython-310.pyc ADDED
Binary file (2.52 kB).
 
concept_attention/binary_segmentation_baselines/__pycache__/daam_sd2.cpython-310.pyc ADDED
Binary file (3.81 kB).
 
concept_attention/binary_segmentation_baselines/__pycache__/daam_sdxl.cpython-310.pyc ADDED
Binary file (4.69 kB).
 
concept_attention/binary_segmentation_baselines/__pycache__/dino.cpython-310.pyc ADDED
Binary file (2.93 kB).
 
concept_attention/binary_segmentation_baselines/__pycache__/raw_cross_attention.cpython-310.pyc ADDED
Binary file (6.26 kB).
 
concept_attention/binary_segmentation_baselines/__pycache__/raw_output_space.cpython-310.pyc ADDED
Binary file (7 kB).
 
concept_attention/binary_segmentation_baselines/__pycache__/raw_value_space.cpython-310.pyc ADDED
Binary file (6.64 kB).
 
concept_attention/binary_segmentation_baselines/chefer_clip_vit_baselines.py ADDED
@@ -0,0 +1,272 @@
+ """
+ This is just a wrapper around the various baselines implemented in the
+ Chefer et. al. Transformer Explainability repository.
+
+ Implements
+     - CheferLRPSegmentationModel
+     - CheferRolloutSegmentationModel
+     - CheferLastLayerAttentionSegmentationModel
+     - CheferAttentionGradCAMSegmentationModel
+     - CheferTransformerAttributionSegmentationModel
+     - CheferFullLRPSegmentationModel
+     - CheferLastLayerLRPSegmentationModel
+ """
+
+ # # segmentation test for the rollout baseline
+ # if args.method == 'rollout':
+ #     Res = baselines.generate_rollout(image.cuda(), start_layer=1).reshape(batch_size, 1, 14, 14)
+
+ # # segmentation test for the LRP baseline (this is full LRP, not partial)
+ # elif args.method == 'full_lrp':
+ #     Res = orig_lrp.generate_LRP(image.cuda(), method="full").reshape(batch_size, 1, 224, 224)
+
+ # # segmentation test for our method
+ # elif args.method == 'transformer_attribution':
+ #     Res = lrp.generate_LRP(image.cuda(), start_layer=1, method="transformer_attribution").reshape(batch_size, 1, 14, 14)
+
+ # # segmentation test for the partial LRP baseline (last attn layer)
+ # elif args.method == 'lrp_last_layer':
+ #     Res = orig_lrp.generate_LRP(image.cuda(), method="last_layer", is_ablation=args.is_ablation)\
+ #         .reshape(batch_size, 1, 14, 14)
+
+ # # segmentation test for the raw attention baseline (last attn layer)
+ # elif args.method == 'attn_last_layer':
+ #     Res = orig_lrp.generate_LRP(image.cuda(), method="last_layer_attn", is_ablation=args.is_ablation)\
+ #         .reshape(batch_size, 1, 14, 14)
+
+ # # segmentation test for the GradCam baseline (last attn layer)
+ # elif args.method == 'attn_gradcam':
+ #     Res = baselines.generate_cam_attn(image.cuda()).reshape(batch_size, 1, 14, 14)
+
+ # if args.method != 'full_lrp':
+ #     # interpolate to full image size (224,224)
+ #     Res = torch.nn.functional.interpolate(Res, scale_factor=16, mode='bilinear').cuda()
+
+ import torch
+ import PIL
+
+ from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.ViT_explanation_generator import LRP
+ from concept_attention.segmentation import SegmentationAbstractClass
+ from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.ViT_explanation_generator import Baselines, LRP
+ from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.ViT_new import vit_base_patch16_224
+ from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.ViT_LRP import vit_base_patch16_224 as vit_LRP
+ from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP
+
+
+ # # Model
+ # model = vit_base_patch16_224(pretrained=True).cuda()
+ # baselines = Baselines(model)
+
+ # # LRP
+ # model_LRP = vit_LRP(pretrained=True).cuda()
+ # model_LRP.eval()
+ # lrp = LRP(model_LRP)
+
+ # # orig LRP
+ # model_orig_LRP = vit_orig_LRP(pretrained=True).cuda()
+ # model_orig_LRP.eval()
+ # orig_lrp = LRP(model_orig_LRP)
+
+ # model.eval()
+
+ class CheferLRPSegmentationModel(SegmentationAbstractClass):
+
+     def __init__(
+         self,
+         device: str = "cuda",
+         width: int = 224,
+         height: int = 224,
+     ):
+         """
+         Initialize the segmentation model.
+         """
+         super(CheferLRPSegmentationModel, self).__init__()
+         self.width = width
+         self.height = height
+         self.device = device
+         # Load the LRP model
+         model_orig_LRP = vit_orig_LRP(pretrained=True).to(self.device)
+         model_orig_LRP.eval()
+         self.orig_lrp = LRP(model_orig_LRP)
+
+     def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, **kwargs):
+         """
+         Takes a real image and generates a concept segmentation map
+         it by adding noise and running the DiT on it.
+         """
+         if len(image.shape) == 3:
+             image = image.unsqueeze(0)
+
+         prediction_map = self.orig_lrp.generate_LRP(
+             image.to(self.device),
+             method="full"
+         )
+         prediction_map = prediction_map.unsqueeze(0)
+         # Rescale the prediction map to 64x64
+         prediction_map = torch.nn.functional.interpolate(
+             prediction_map,
+             size=(self.width, self.height),
+             mode="nearest"
+         ).reshape(1, self.width, self.height)
+
+         return prediction_map, None
+
+ class CheferRolloutSegmentationModel(SegmentationAbstractClass):
+
+     def __init__(self, device: str = "cuda", width: int = 224, height: int = 224):
+         super(CheferRolloutSegmentationModel, self).__init__()
+         self.width = width
+         self.height = height
+         self.device = device
+         model = vit_base_patch16_224(pretrained=True).to(device)
+         self.baselines = Baselines(model)
+
+     def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, **kwargs):
+         if len(image.shape) == 3:
+             image = image.unsqueeze(0)
+         prediction_map = self.baselines.generate_rollout(
+             image.to(self.device), start_layer=1
+         ).reshape(1, 1, 14, 14)
+         # Rescale the prediction map to 64x64
+         prediction_map = torch.nn.functional.interpolate(
+             prediction_map,
+             size=(self.width, self.height),
+             mode="nearest"
+         ).reshape(1, self.width, self.height)
+
+         return prediction_map, None
+
+
+ class CheferLastLayerAttentionSegmentationModel(SegmentationAbstractClass):
+
+     def __init__(self, device: str = "cuda", width: int = 224, height: int = 224):
+         super(CheferLastLayerAttentionSegmentationModel, self).__init__()
+         self.width = width
+         self.height = height
+         self.device = device
+         model_orig_LRP = vit_orig_LRP(pretrained=True).to(device)
+         model_orig_LRP.eval()
+         self.orig_lrp = LRP(model_orig_LRP)
+
+     def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, **kwargs):
+         if len(image.shape) == 3:
+             image = image.unsqueeze(0)
+
+         prediction_map = self.orig_lrp.generate_LRP(
+             image.to(self.device), method="last_layer_attn"
+         ).reshape(1, 1, 14, 14)
+         # Rescale the prediction map to 64x64
+         prediction_map = torch.nn.functional.interpolate(
+             prediction_map,
+             size=(self.width, self.height),
+             mode="nearest"
+         ).reshape(1, self.width, self.height)
+
+         return prediction_map, None
+
+
+ class CheferAttentionGradCAMSegmentationModel(SegmentationAbstractClass):
+
+     def __init__(self, device: str = "cuda", width: int = 224, height: int = 224):
+         super(CheferAttentionGradCAMSegmentationModel, self).__init__()
+         self.width = width
+         self.height = height
+         self.device = device
+         model = vit_base_patch16_224(pretrained=True).to(device)
+         self.baselines = Baselines(model)
+
+     def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, **kwargs):
+         if len(image.shape) == 3:
+             image = image.unsqueeze(0)
+         prediction_map = self.baselines.generate_cam_attn(
+             image.to(self.device)
+         ).reshape(1, 1, 14, 14)
+         # Rescale the prediction map to 64x64
+         prediction_map = torch.nn.functional.interpolate(
+             prediction_map,
+             size=(self.width, self.height),
+             mode="nearest"
+         ).reshape(1, self.width, self.height)
+
+         return prediction_map, None
+
+
+ class CheferTransformerAttributionSegmentationModel(SegmentationAbstractClass):
+
+     def __init__(self, device: str = "cuda", width: int = 224, height: int = 224):
+         super(CheferTransformerAttributionSegmentationModel, self).__init__()
+         self.width = width
+         self.height = height
+         self.device = device
+         model_LRP = vit_LRP(pretrained=True).to(device)
+         model_LRP.eval()
+         self.lrp = LRP(model_LRP)
+
+     def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, **kwargs):
+         if len(image.shape) == 3:
+             image = image.unsqueeze(0)
+         prediction_map = self.lrp.generate_LRP(
+             image.to(self.device), start_layer=1, method="transformer_attribution"
+         ).reshape(1, 1, 14, 14)
+         # Rescale the prediction map to 64x64
+         prediction_map = torch.nn.functional.interpolate(
+             prediction_map,
+             size=(self.width, self.height),
+             mode="nearest"
+         ).reshape(1, self.width, self.height)
+
+         return prediction_map, None
+
+
+ class CheferFullLRPSegmentationModel(SegmentationAbstractClass):
+
+     def __init__(self, device: str = "cuda", width: int = 224, height: int = 224):
+         super(CheferFullLRPSegmentationModel, self).__init__()
+         self.width = width
+         self.height = height
+         self.device = device
+         model_LRP = vit_LRP(pretrained=True).to(device)
+         model_LRP.eval()
+         self.lrp = LRP(model_LRP)
+
+     def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, **kwargs):
+         if len(image.shape) == 3:
+             image = image.unsqueeze(0)
+         prediction_map = self.lrp.generate_LRP(
+             image.to(self.device), method="full"
+         ).reshape(1, 1, 224, 224)
+         # Rescale the prediction map to 64x64
+         prediction_map = torch.nn.functional.interpolate(
+             prediction_map,
+             size=(self.width, self.height),
+             mode="nearest"
+         ).reshape(1, self.width, self.height)
+
+         return prediction_map, None
+
+
+ class CheferLastLayerLRPSegmentationModel(SegmentationAbstractClass):
+
+     def __init__(self, device: str = "cuda", width: int = 224, height: int = 224):
+         super(CheferLastLayerLRPSegmentationModel, self).__init__()
+         self.width = width
+         self.height = height
+         self.device = device
+         model_LRP = vit_LRP(pretrained=True).to(device)
+         model_LRP.eval()
+         self.lrp = LRP(model_LRP)
+
+     def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, **kwargs):
+         if len(image.shape) == 3:
+             image = image.unsqueeze(0)
+         prediction_map = self.lrp.generate_LRP(
+             image.to(self.device), method="last_layer"
+         ).reshape(1, 1, 14, 14)
+         # Rescale the prediction map to 64x64
+         prediction_map = torch.nn.functional.interpolate(
+             prediction_map,
+             size=(self.width, self.height),
+             mode="nearest"
+         ).reshape(1, self.width, self.height)
+
+         return prediction_map, None
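The module docstring above lists the wrapper classes this file adds around the Chefer et al. baselines. A minimal sketch of invoking one of them, based only on the constructor and segment_individual_image signatures shown in the diff; the random input tensor, its preprocessing, and the availability of pretrained ViT weights are assumptions.

```python
# Sketch of calling one wrapper class from the file above; the input tensor
# is a stand-in for a normalized 224x224 image batch, not real data.
import torch
from concept_attention.binary_segmentation_baselines.chefer_clip_vit_baselines import (
    CheferRolloutSegmentationModel,
)

model = CheferRolloutSegmentationModel(device="cuda", width=224, height=224)
image = torch.randn(1, 3, 224, 224)  # placeholder preprocessed image
prediction_map, _ = model.segment_individual_image(image, concepts=["cat"], caption="a cat")
print(prediction_map.shape)  # (1, 224, 224): the 14x14 rollout map upsampled to width x height
```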
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/ViT_LRP.py ADDED
@@ -0,0 +1,437 @@
+ """ Vision Transformer (ViT) in PyTorch
+ Hacked together by / Copyright 2020 Ross Wightman
+ """
+ import torch
+ import torch.nn as nn
+ from einops import rearrange
+
+ from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.modules.layers_ours import *
+ from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.helpers import load_pretrained
+ from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.weight_init import trunc_normal_
+ from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.layer_helpers import to_2tuple
+
+
+ def _cfg(url='', **kwargs):
+     return {
+         'url': url,
+         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+         'crop_pct': .9, 'interpolation': 'bicubic',
+         'first_conv': 'patch_embed.proj', 'classifier': 'head',
+         **kwargs
+     }
+
+
+ default_cfgs = {
+     # patch models
+     'vit_small_patch16_224': _cfg(
+         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
+     ),
+     'vit_base_patch16_224': _cfg(
+         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
+         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+     ),
+     'vit_large_patch16_224': _cfg(
+         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
+         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+ }
+
+ def compute_rollout_attention(all_layer_matrices, start_layer=0):
+     # adding residual consideration
+     num_tokens = all_layer_matrices[0].shape[1]
+     batch_size = all_layer_matrices[0].shape[0]
+     eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device)
+     all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
+     # all_layer_matrices = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
+     #                       for i in range(len(all_layer_matrices))]
+     joint_attention = all_layer_matrices[start_layer]
+     for i in range(start_layer+1, len(all_layer_matrices)):
+         joint_attention = all_layer_matrices[i].bmm(joint_attention)
+     return joint_attention
+
+ class Mlp(nn.Module):
+     def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = Linear(in_features, hidden_features)
+         self.act = GELU()
+         self.fc2 = Linear(hidden_features, out_features)
+         self.drop = Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+     def relprop(self, cam, **kwargs):
+         cam = self.drop.relprop(cam, **kwargs)
+         cam = self.fc2.relprop(cam, **kwargs)
+         cam = self.act.relprop(cam, **kwargs)
+         cam = self.fc1.relprop(cam, **kwargs)
+         return cam
+
+
+ class Attention(nn.Module):
+     def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+         self.scale = head_dim ** -0.5
+
+         # A = Q*K^T
+         self.matmul1 = einsum('bhid,bhjd->bhij')
+         # attn = A*V
+         self.matmul2 = einsum('bhij,bhjd->bhid')
+
+         self.qkv = Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = Dropout(attn_drop)
+         self.proj = Linear(dim, dim)
+         self.proj_drop = Dropout(proj_drop)
+         self.softmax = Softmax(dim=-1)
+
+         self.attn_cam = None
+         self.attn = None
+         self.v = None
+         self.v_cam = None
+         self.attn_gradients = None
+
+     def get_attn(self):
+         return self.attn
+
+     def save_attn(self, attn):
+         self.attn = attn
+
+     def save_attn_cam(self, cam):
+         self.attn_cam = cam
+
+     def get_attn_cam(self):
+         return self.attn_cam
+
+     def get_v(self):
+         return self.v
+
+     def save_v(self, v):
+         self.v = v
+
+     def save_v_cam(self, cam):
+         self.v_cam = cam
+
+     def get_v_cam(self):
+         return self.v_cam
+
+     def save_attn_gradients(self, attn_gradients):
+         self.attn_gradients = attn_gradients
+
+     def get_attn_gradients(self):
+         return self.attn_gradients
+
+     def forward(self, x):
+         b, n, _, h = *x.shape, self.num_heads
+         qkv = self.qkv(x)
+         q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=h)
+
+         self.save_v(v)
+
+         dots = self.matmul1([q, k]) * self.scale
+
+         attn = self.softmax(dots)
+         attn = self.attn_drop(attn)
+
+         self.save_attn(attn)
+         attn.register_hook(self.save_attn_gradients)
+
+         out = self.matmul2([attn, v])
+         out = rearrange(out, 'b h n d -> b n (h d)')
+
+         out = self.proj(out)
+         out = self.proj_drop(out)
+         return out
+
+     def relprop(self, cam, **kwargs):
+         cam = self.proj_drop.relprop(cam, **kwargs)
+         cam = self.proj.relprop(cam, **kwargs)
+         cam = rearrange(cam, 'b n (h d) -> b h n d', h=self.num_heads)
+
+         # attn = A*V
+         (cam1, cam_v) = self.matmul2.relprop(cam, **kwargs)
+         cam1 /= 2
+         cam_v /= 2
+
+         self.save_v_cam(cam_v)
+         self.save_attn_cam(cam1)
+
+         cam1 = self.attn_drop.relprop(cam1, **kwargs)
+         cam1 = self.softmax.relprop(cam1, **kwargs)
+
+         # A = Q*K^T
+         (cam_q, cam_k) = self.matmul1.relprop(cam1, **kwargs)
+         cam_q /= 2
+         cam_k /= 2
+
+         cam_qkv = rearrange([cam_q, cam_k, cam_v], 'qkv b h n d -> b n (qkv h d)', qkv=3, h=self.num_heads)
+
+         return self.qkv.relprop(cam_qkv, **kwargs)
+
+
+ class Block(nn.Module):
+
+     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.):
+         super().__init__()
+         self.norm1 = LayerNorm(dim, eps=1e-6)
+         self.attn = Attention(
+             dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+         self.norm2 = LayerNorm(dim, eps=1e-6)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)
+
+         self.add1 = Add()
+         self.add2 = Add()
+         self.clone1 = Clone()
+         self.clone2 = Clone()
+
+     def forward(self, x):
+         x1, x2 = self.clone1(x, 2)
+         x = self.add1([x1, self.attn(self.norm1(x2))])
+         x1, x2 = self.clone2(x, 2)
+         x = self.add2([x1, self.mlp(self.norm2(x2))])
+         return x
+
+     def relprop(self, cam, **kwargs):
+         (cam1, cam2) = self.add2.relprop(cam, **kwargs)
+         cam2 = self.mlp.relprop(cam2, **kwargs)
+         cam2 = self.norm2.relprop(cam2, **kwargs)
+         cam = self.clone2.relprop((cam1, cam2), **kwargs)
+
+         (cam1, cam2) = self.add1.relprop(cam, **kwargs)
+         cam2 = self.attn.relprop(cam2, **kwargs)
+         cam2 = self.norm1.relprop(cam2, **kwargs)
+         cam = self.clone1.relprop((cam1, cam2), **kwargs)
+         return cam
+
+
+ class PatchEmbed(nn.Module):
+     """ Image to Patch Embedding
+     """
+     def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+         super().__init__()
+         img_size = to_2tuple(img_size)
+         patch_size = to_2tuple(patch_size)
+         num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.num_patches = num_patches
+
+         self.proj = Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         # FIXME look at relaxing size constraints
+         assert H == self.img_size[0] and W == self.img_size[1], \
+             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+         x = self.proj(x).flatten(2).transpose(1, 2)
+         return x
+
+     def relprop(self, cam, **kwargs):
+         cam = cam.transpose(1, 2)
+         cam = cam.reshape(cam.shape[0], cam.shape[1],
+                           (self.img_size[0] // self.patch_size[0]), (self.img_size[1] // self.patch_size[1]))
+         return self.proj.relprop(cam, **kwargs)
+
+
+ class VisionTransformer(nn.Module):
+     """ Vision Transformer with support for patch or hybrid CNN input stage
+     """
+     def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+                  num_heads=12, mlp_ratio=4., qkv_bias=False, mlp_head=False, drop_rate=0., attn_drop_rate=0.):
+         super().__init__()
+         self.num_classes = num_classes
+         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+         self.patch_embed = PatchEmbed(
+             img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+         num_patches = self.patch_embed.num_patches
+
+         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+
+         self.blocks = nn.ModuleList([
+             Block(
+                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
+                 drop=drop_rate, attn_drop=attn_drop_rate)
+             for i in range(depth)])
+
+         self.norm = LayerNorm(embed_dim)
+         if mlp_head:
+             # paper diagram suggests 'MLP head', but results in 4M extra parameters vs paper
+             self.head = Mlp(embed_dim, int(embed_dim * mlp_ratio), num_classes)
+         else:
+             # with a single Linear layer as head, the param count within rounding of paper
+             self.head = Linear(embed_dim, num_classes)
+
+         # FIXME not quite sure what the proper weight init is supposed to be,
+         # normal / trunc normal w/ std == .02 similar to other Bert like transformers
+         trunc_normal_(self.pos_embed, std=.02)  # embeddings same as weights?
+         trunc_normal_(self.cls_token, std=.02)
+         self.apply(self._init_weights)
+
+         self.pool = IndexSelect()
+         self.add = Add()
+
+         self.inp_grad = None
+
+     def save_inp_grad(self, grad):
+         self.inp_grad = grad
+
+     def get_inp_grad(self):
+         return self.inp_grad
+
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     @property
+     def no_weight_decay(self):
+         return {'pos_embed', 'cls_token'}
+
+     def forward(self, x):
+         B = x.shape[0]
+         x = self.patch_embed(x)
+
+         cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+         x = torch.cat((cls_tokens, x), dim=1)
+         x = self.add([x, self.pos_embed])
+
+         x.register_hook(self.save_inp_grad)
+
+         for blk in self.blocks:
+             x = blk(x)
+
+         x = self.norm(x)
+         x = self.pool(x, dim=1, indices=torch.tensor(0, device=x.device))
+         x = x.squeeze(1)
+         x = self.head(x)
+         return x
+
+     def relprop(self, cam=None, method="transformer_attribution", is_ablation=False, start_layer=0, **kwargs):
+         # print(kwargs)
+         # print("conservation 1", cam.sum())
+         cam = self.head.relprop(cam, **kwargs)
+         cam = cam.unsqueeze(1)
+         cam = self.pool.relprop(cam, **kwargs)
+         cam = self.norm.relprop(cam, **kwargs)
+         for blk in reversed(self.blocks):
+             cam = blk.relprop(cam, **kwargs)
+
+         # print("conservation 2", cam.sum())
+         # print("min", cam.min())
+
+         if method == "full":
+             (cam, _) = self.add.relprop(cam, **kwargs)
+             cam = cam[:, 1:]
+             cam = self.patch_embed.relprop(cam, **kwargs)
+             # sum on channels
+             cam = cam.sum(dim=1)
+             return cam
+
+         elif method == "rollout":
+             # cam rollout
+             attn_cams = []
+             for blk in self.blocks:
+                 attn_heads = blk.attn.get_attn_cam().clamp(min=0)
+                 avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
+                 attn_cams.append(avg_heads)
+             cam = compute_rollout_attention(attn_cams, start_layer=start_layer)
+             cam = cam[:, 0, 1:]
+             return cam
+
+         # our method, method name grad is legacy
+         elif method == "transformer_attribution" or method == "grad":
+             cams = []
+             for blk in self.blocks:
+                 grad = blk.attn.get_attn_gradients()
+                 cam = blk.attn.get_attn_cam()
+                 cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
+                 grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
+                 cam = grad * cam
+                 cam = cam.clamp(min=0).mean(dim=0)
+                 cams.append(cam.unsqueeze(0))
+             rollout = compute_rollout_attention(cams, start_layer=start_layer)
+             cam = rollout[:, 0, 1:]
+             return cam
+
+         elif method == "last_layer":
+             cam = self.blocks[-1].attn.get_attn_cam()
+             cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
+             if is_ablation:
+                 grad = self.blocks[-1].attn.get_attn_gradients()
+                 grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
+                 cam = grad * cam
+             cam = cam.clamp(min=0).mean(dim=0)
+             cam = cam[0, 1:]
+             return cam
+
+         elif method == "last_layer_attn":
+             cam = self.blocks[-1].attn.get_attn()
+             cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
+             cam = cam.clamp(min=0).mean(dim=0)
+             cam = cam[0, 1:]
+             return cam
+
+         elif method == "second_layer":
+             cam = self.blocks[1].attn.get_attn_cam()
+             cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
+             if is_ablation:
+                 grad = self.blocks[1].attn.get_attn_gradients()
+                 grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
+                 cam = grad * cam
+             cam = cam.clamp(min=0).mean(dim=0)
+             cam = cam[0, 1:]
+             return cam
+
+
+ def _conv_filter(state_dict, patch_size=16):
+     """ convert patch embedding weight from manual patchify + linear proj to conv"""
+     out_dict = {}
+     for k, v in state_dict.items():
+         if 'patch_embed.proj.weight' in k:
+             v = v.reshape((v.shape[0], 3, patch_size, patch_size))
+         out_dict[k] = v
+     return out_dict
+
+ def vit_base_patch16_224(pretrained=False, **kwargs):
+     model = VisionTransformer(
+         patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, **kwargs)
+     model.default_cfg = default_cfgs['vit_base_patch16_224']
+     if pretrained:
+         load_pretrained(
+             model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
+     return model
+
+ def vit_large_patch16_224(pretrained=False, **kwargs):
+     model = VisionTransformer(
+         patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, **kwargs)
+     model.default_cfg = default_cfgs['vit_large_patch16_224']
+     if pretrained:
+         load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
+     return model
+
+ def deit_base_patch16_224(pretrained=False, **kwargs):
+     model = VisionTransformer(
+         patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, **kwargs)
+     model.default_cfg = _cfg()
+     if pretrained:
+         checkpoint = torch.hub.load_state_dict_from_url(
+             url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
+             map_location="cpu", check_hash=True
+         )
+         model.load_state_dict(checkpoint["model"])
+     return model
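A tiny, self-contained illustration of the compute_rollout_attention helper defined in the file above: it adds the identity matrix (residual connection) to each layer's averaged attention and chains the layers with batched matrix multiplication. The random matrices are placeholders, not real model attentions.

```python
# Placeholder demonstration of attention rollout as implemented above.
import torch
from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.ViT_LRP import (
    compute_rollout_attention,
)

# Three fake attention layers over 5 tokens (batch size 1), rows normalized.
layers = [torch.rand(1, 5, 5).softmax(dim=-1) for _ in range(3)]
rollout = compute_rollout_attention(layers, start_layer=0)
print(rollout.shape)  # torch.Size([1, 5, 5])
```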
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/ViT_explanation_generator.py ADDED
@@ -0,0 +1,83 @@
+ import argparse
+ import torch
+ import numpy as np
+ from numpy import *
+
+ # compute rollout between attention layers
+ def compute_rollout_attention(all_layer_matrices, start_layer=0):
+     # adding residual consideration- code adapted from https://github.com/samiraabnar/attention_flow
+     num_tokens = all_layer_matrices[0].shape[1]
+     batch_size = all_layer_matrices[0].shape[0]
+     eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device)
+     all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
+     matrices_aug = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
+                     for i in range(len(all_layer_matrices))]
+     joint_attention = matrices_aug[start_layer]
+     for i in range(start_layer+1, len(matrices_aug)):
+         joint_attention = matrices_aug[i].bmm(joint_attention)
+     return joint_attention
+
+ class LRP:
+     def __init__(self, model):
+         self.model = model
+         self.model.eval()
+
+     def generate_LRP(self, input, index=None, method="transformer_attribution", is_ablation=False, start_layer=0):
+         output = self.model(input)
+         kwargs = {"alpha": 1}
+         if index == None:
+             index = np.argmax(output.cpu().data.numpy(), axis=-1)
+
+         one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
+         one_hot[0, index] = 1
+         one_hot_vector = one_hot
+         one_hot = torch.from_numpy(one_hot).requires_grad_(True)
+         one_hot = torch.sum(one_hot.to(input.device) * output)
+
+         self.model.zero_grad()
+         one_hot.backward(retain_graph=True)
+
+         return self.model.relprop(torch.tensor(one_hot_vector).to(input.device), method=method, is_ablation=is_ablation,
+                                   start_layer=start_layer, **kwargs)
+
+
+
+ class Baselines:
+     def __init__(self, model):
+         self.model = model
+         self.model.eval()
+
+     def generate_cam_attn(self, input, index=None):
+         output = self.model(input, register_hook=True)
+         if index == None:
+             index = np.argmax(output.cpu().data.numpy())
+
+         one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
+         one_hot[0][index] = 1
+         one_hot = torch.from_numpy(one_hot).requires_grad_(True)
+         one_hot = torch.sum(one_hot.to(output.device) * output)
+
+         self.model.zero_grad()
+         one_hot.backward(retain_graph=True)
+         #################### attn
+         grad = self.model.blocks[-1].attn.get_attn_gradients()
+         cam = self.model.blocks[-1].attn.get_attention_map()
+         cam = cam[0, :, 0, 1:].reshape(-1, 14, 14)
+         grad = grad[0, :, 0, 1:].reshape(-1, 14, 14)
+         grad = grad.mean(dim=[1, 2], keepdim=True)
+         cam = (cam * grad).mean(0).clamp(min=0)
+         cam = (cam - cam.min()) / (cam.max() - cam.min())
+
+         return cam
+         #################### attn
+
+     def generate_rollout(self, input, start_layer=0):
+         self.model(input)
+         blocks = self.model.blocks
+         all_layer_attentions = []
+         for blk in blocks:
+             attn_heads = blk.attn.get_attention_map()
+             avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
+             all_layer_attentions.append(avg_heads)
+         rollout = compute_rollout_attention(all_layer_attentions, start_layer=start_layer)
+         return rollout[:,0, 1:]
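A sketch of generating a transformer-attribution relevance map with the LRP wrapper above, mirroring how CheferTransformerAttributionSegmentationModel uses it earlier in this commit; the random input tensor and the CUDA device are assumptions.

```python
# Sketch of the LRP usage pattern shown in chefer_clip_vit_baselines.py.
import torch
from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.ViT_LRP import (
    vit_base_patch16_224 as vit_LRP,
)
from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.ViT_explanation_generator import LRP

model = vit_LRP(pretrained=True).cuda()
model.eval()
lrp = LRP(model)

image = torch.randn(1, 3, 224, 224).cuda()  # stand-in for a normalized image
relevance = lrp.generate_LRP(image, start_layer=1, method="transformer_attribution")
relevance = relevance.reshape(1, 1, 14, 14)  # one relevance value per 16x16 patch
```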
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/ViT_new.py ADDED
@@ -0,0 +1,238 @@
+ """ Vision Transformer (ViT) in PyTorch
+ Hacked together by / Copyright 2020 Ross Wightman
+ """
+ import torch
+ import torch.nn as nn
+ from functools import partial
+ from einops import rearrange
+
+ from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.helpers import load_pretrained
+ from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.weight_init import trunc_normal_
+ from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.layer_helpers import to_2tuple
+
+
+ def _cfg(url='', **kwargs):
+     return {
+         'url': url,
+         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+         'crop_pct': .9, 'interpolation': 'bicubic',
+         'first_conv': 'patch_embed.proj', 'classifier': 'head',
+         **kwargs
+     }
+
+
+ default_cfgs = {
+     # patch models
+     'vit_small_patch16_224': _cfg(
+         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
+     ),
+     'vit_base_patch16_224': _cfg(
+         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
+         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+     ),
+     'vit_large_patch16_224': _cfg(
+         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
+         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+ }
+
+ class Mlp(nn.Module):
+     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ class Attention(nn.Module):
+     def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+         self.scale = head_dim ** -0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+         self.attn_gradients = None
+         self.attention_map = None
+
+     def save_attn_gradients(self, attn_gradients):
+         self.attn_gradients = attn_gradients
+
+     def get_attn_gradients(self):
+         return self.attn_gradients
+
+     def save_attention_map(self, attention_map):
+         self.attention_map = attention_map
+
+     def get_attention_map(self):
+         return self.attention_map
+
+     def forward(self, x, register_hook=False):
+         b, n, _, h = *x.shape, self.num_heads
+
+         # self.save_output(x)
+         # x.register_hook(self.save_output_grad)
+
+         qkv = self.qkv(x)
+         q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)
+
+         dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
+
+         attn = dots.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+
+         out = torch.einsum('bhij,bhjd->bhid', attn, v)
+
+         self.save_attention_map(attn)
+         if register_hook:
+             attn.register_hook(self.save_attn_gradients)
+
+         out = rearrange(out, 'b h n d -> b n (h d)')
+         out = self.proj(out)
+         out = self.proj_drop(out)
+         return out
+
+
+ class Block(nn.Module):
+
+     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+         super().__init__()
+         self.norm1 = norm_layer(dim)
+         self.attn = Attention(
+             dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+     def forward(self, x, register_hook=False):
+         x = x + self.attn(self.norm1(x), register_hook=register_hook)
+         x = x + self.mlp(self.norm2(x))
+         return x
+
+
+ class PatchEmbed(nn.Module):
+     """ Image to Patch Embedding
+     """
+     def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+         super().__init__()
+         img_size = to_2tuple(img_size)
+         patch_size = to_2tuple(patch_size)
+         num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.num_patches = num_patches
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         # FIXME look at relaxing size constraints
+         assert H == self.img_size[0] and W == self.img_size[1], \
+             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+         x = self.proj(x).flatten(2).transpose(1, 2)
+         return x
+
+ class VisionTransformer(nn.Module):
+     """ Vision Transformer
+     """
+     def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+                  num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., norm_layer=nn.LayerNorm):
+         super().__init__()
+         self.num_classes = num_classes
+         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+         self.patch_embed = PatchEmbed(
+             img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+         num_patches = self.patch_embed.num_patches
+
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+         self.pos_drop = nn.Dropout(p=drop_rate)
+
+         self.blocks = nn.ModuleList([
+             Block(
+                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
+                 drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer)
+             for i in range(depth)])
+         self.norm = norm_layer(embed_dim)
+
+         # Classifier head
+         self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+         trunc_normal_(self.pos_embed, std=.02)
+         trunc_normal_(self.cls_token, std=.02)
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         return {'pos_embed', 'cls_token'}
+
+     def forward(self, x, register_hook=False):
+         B = x.shape[0]
+         x = self.patch_embed(x)
+
+         cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+         x = torch.cat((cls_tokens, x), dim=1)
+         x = x + self.pos_embed
+         x = self.pos_drop(x)
+
+         for blk in self.blocks:
+             x = blk(x, register_hook=register_hook)
+
+         x = self.norm(x)
+         x = x[:, 0]
+         x = self.head(x)
+         return x
+
+
+ def _conv_filter(state_dict, patch_size=16):
+     """ convert patch embedding weight from manual patchify + linear proj to conv"""
+     out_dict = {}
+     for k, v in state_dict.items():
+         if 'patch_embed.proj.weight' in k:
+             v = v.reshape((v.shape[0], 3, patch_size, patch_size))
+         out_dict[k] = v
+     return out_dict
+
+
+ def vit_base_patch16_224(pretrained=False, **kwargs):
+     model = VisionTransformer(
+         patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     model.default_cfg = default_cfgs['vit_base_patch16_224']
+     if pretrained:
+         load_pretrained(
+             model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
+     return model
+
+ def vit_large_patch16_224(pretrained=False, **kwargs):
+     model = VisionTransformer(
+         patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     model.default_cfg = default_cfgs['vit_large_patch16_224']
+     if pretrained:
+         load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
+     return model
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/ViT_orig_LRP.py ADDED
@@ -0,0 +1,425 @@
+ """ Vision Transformer (ViT) in PyTorch
+ Hacked together by / Copyright 2020 Ross Wightman
+ """
+ import torch
+ import torch.nn as nn
+ from einops import rearrange
+
+ from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.modules.layers_lrp import *
+ from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.helpers import load_pretrained
+ from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.weight_init import trunc_normal_
+ from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.layer_helpers import to_2tuple
+
+
+ def _cfg(url='', **kwargs):
+     return {
+         'url': url,
+         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+         'crop_pct': .9, 'interpolation': 'bicubic',
+         'first_conv': 'patch_embed.proj', 'classifier': 'head',
+         **kwargs
+     }
+
+
+ default_cfgs = {
+     # patch models
+     'vit_small_patch16_224': _cfg(
+         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
+     ),
+     'vit_base_patch16_224': _cfg(
+         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
+         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+     ),
+     'vit_large_patch16_224': _cfg(
+         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
+         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+ }
+
+ def compute_rollout_attention(all_layer_matrices, start_layer=0):
+     # adding residual consideration
+     num_tokens = all_layer_matrices[0].shape[1]
+     batch_size = all_layer_matrices[0].shape[0]
+     eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device)
+     all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
+     # all_layer_matrices = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
+     #                       for i in range(len(all_layer_matrices))]
+     joint_attention = all_layer_matrices[start_layer]
+     for i in range(start_layer+1, len(all_layer_matrices)):
+         joint_attention = all_layer_matrices[i].bmm(joint_attention)
+     return joint_attention
+
+ class Mlp(nn.Module):
+     def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = Linear(in_features, hidden_features)
+         self.act = GELU()
+         self.fc2 = Linear(hidden_features, out_features)
+         self.drop = Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+     def relprop(self, cam, **kwargs):
+         cam = self.drop.relprop(cam, **kwargs)
+         cam = self.fc2.relprop(cam, **kwargs)
+         cam = self.act.relprop(cam, **kwargs)
+         cam = self.fc1.relprop(cam, **kwargs)
+         return cam
+
+
+ class Attention(nn.Module):
+     def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+         self.scale = head_dim ** -0.5
+
+         # A = Q*K^T
+         self.matmul1 = einsum('bhid,bhjd->bhij')
+         # attn = A*V
+         self.matmul2 = einsum('bhij,bhjd->bhid')
+
+         self.qkv = Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = Dropout(attn_drop)
+         self.proj = Linear(dim, dim)
+         self.proj_drop = Dropout(proj_drop)
+         self.softmax = Softmax(dim=-1)
+
+         self.attn_cam = None
+         self.attn = None
+         self.v = None
+         self.v_cam = None
+         self.attn_gradients = None
+
+     def get_attn(self):
+         return self.attn
+
+     def save_attn(self, attn):
+         self.attn = attn
+
+     def save_attn_cam(self, cam):
+         self.attn_cam = cam
+
+     def get_attn_cam(self):
+         return self.attn_cam
+
+     def get_v(self):
+         return self.v
+
+     def save_v(self, v):
+         self.v = v
+
+     def save_v_cam(self, cam):
+         self.v_cam = cam
+
+     def get_v_cam(self):
+         return self.v_cam
+
+     def save_attn_gradients(self, attn_gradients):
+         self.attn_gradients = attn_gradients
+
+     def get_attn_gradients(self):
+         return self.attn_gradients
+
+     def forward(self, x):
+         b, n, _, h = *x.shape, self.num_heads
+         qkv = self.qkv(x)
+         q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=h)
+
+         self.save_v(v)
+
+         dots = self.matmul1([q, k]) * self.scale
+
+         attn = self.softmax(dots)
+         attn = self.attn_drop(attn)
+
+         self.save_attn(attn)
+         attn.register_hook(self.save_attn_gradients)
+
+         out = self.matmul2([attn, v])
+         out = rearrange(out, 'b h n d -> b n (h d)')
+
+         out = self.proj(out)
+         out = self.proj_drop(out)
+         return out
+
+     def relprop(self, cam, **kwargs):
+         cam = self.proj_drop.relprop(cam, **kwargs)
+         cam = self.proj.relprop(cam, **kwargs)
+         cam = rearrange(cam, 'b n (h d) -> b h n d', h=self.num_heads)
+
+         # attn = A*V
+         (cam1, cam_v) = self.matmul2.relprop(cam, **kwargs)
+         cam1 /= 2
+         cam_v /= 2
+
+         self.save_v_cam(cam_v)
+         self.save_attn_cam(cam1)
+
+         cam1 = self.attn_drop.relprop(cam1, **kwargs)
+         cam1 = self.softmax.relprop(cam1, **kwargs)
169
+
170
+ # A = Q*K^T
171
+ (cam_q, cam_k) = self.matmul1.relprop(cam1, **kwargs)
172
+ cam_q /= 2
173
+ cam_k /= 2
174
+
175
+ cam_qkv = rearrange([cam_q, cam_k, cam_v], 'qkv b h n d -> b n (qkv h d)', qkv=3, h=self.num_heads)
176
+
177
+ return self.qkv.relprop(cam_qkv, **kwargs)
178
+
179
+
180
+ class Block(nn.Module):
181
+
182
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.):
183
+ super().__init__()
184
+ self.norm1 = LayerNorm(dim, eps=1e-6)
185
+ self.attn = Attention(
186
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
187
+ self.norm2 = LayerNorm(dim, eps=1e-6)
188
+ mlp_hidden_dim = int(dim * mlp_ratio)
189
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)
190
+
191
+ self.add1 = Add()
192
+ self.add2 = Add()
193
+ self.clone1 = Clone()
194
+ self.clone2 = Clone()
195
+
196
+ def forward(self, x):
197
+ x1, x2 = self.clone1(x, 2)
198
+ x = self.add1([x1, self.attn(self.norm1(x2))])
199
+ x1, x2 = self.clone2(x, 2)
200
+ x = self.add2([x1, self.mlp(self.norm2(x2))])
201
+ return x
202
+
203
+ def relprop(self, cam, **kwargs):
204
+ (cam1, cam2) = self.add2.relprop(cam, **kwargs)
205
+ cam2 = self.mlp.relprop(cam2, **kwargs)
206
+ cam2 = self.norm2.relprop(cam2, **kwargs)
207
+ cam = self.clone2.relprop((cam1, cam2), **kwargs)
208
+
209
+ (cam1, cam2) = self.add1.relprop(cam, **kwargs)
210
+ cam2 = self.attn.relprop(cam2, **kwargs)
211
+ cam2 = self.norm1.relprop(cam2, **kwargs)
212
+ cam = self.clone1.relprop((cam1, cam2), **kwargs)
213
+ return cam
214
+
215
+
216
+ class PatchEmbed(nn.Module):
217
+ """ Image to Patch Embedding
218
+ """
219
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
220
+ super().__init__()
221
+ img_size = to_2tuple(img_size)
222
+ patch_size = to_2tuple(patch_size)
223
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
224
+ self.img_size = img_size
225
+ self.patch_size = patch_size
226
+ self.num_patches = num_patches
227
+
228
+ self.proj = Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
229
+
230
+ def forward(self, x):
231
+ B, C, H, W = x.shape
232
+ # FIXME look at relaxing size constraints
233
+ assert H == self.img_size[0] and W == self.img_size[1], \
234
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
235
+ x = self.proj(x).flatten(2).transpose(1, 2)
236
+ return x
237
+
238
+ def relprop(self, cam, **kwargs):
239
+ cam = cam.transpose(1,2)
240
+ cam = cam.reshape(cam.shape[0], cam.shape[1],
241
+ (self.img_size[0] // self.patch_size[0]), (self.img_size[1] // self.patch_size[1]))
242
+ return self.proj.relprop(cam, **kwargs)
243
+
244
+
245
+ class VisionTransformer(nn.Module):
246
+ """ Vision Transformer with support for patch or hybrid CNN input stage
247
+ """
248
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
249
+ num_heads=12, mlp_ratio=4., qkv_bias=False, mlp_head=False, drop_rate=0., attn_drop_rate=0.):
250
+ super().__init__()
251
+ self.num_classes = num_classes
252
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
253
+ self.patch_embed = PatchEmbed(
254
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
255
+ num_patches = self.patch_embed.num_patches
256
+
257
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
258
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
259
+
260
+ self.blocks = nn.ModuleList([
261
+ Block(
262
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
263
+ drop=drop_rate, attn_drop=attn_drop_rate)
264
+ for i in range(depth)])
265
+
266
+ self.norm = LayerNorm(embed_dim)
267
+ if mlp_head:
268
+ # paper diagram suggests 'MLP head', but results in 4M extra parameters vs paper
269
+ self.head = Mlp(embed_dim, int(embed_dim * mlp_ratio), num_classes)
270
+ else:
271
+ # with a single Linear layer as head, the param count is within rounding of the paper
272
+ self.head = Linear(embed_dim, num_classes)
273
+
274
+ # FIXME not quite sure what the proper weight init is supposed to be,
275
+ # normal / trunc normal w/ std == .02 similar to other Bert like transformers
276
+ trunc_normal_(self.pos_embed, std=.02) # embeddings same as weights?
277
+ trunc_normal_(self.cls_token, std=.02)
278
+ self.apply(self._init_weights)
279
+
280
+ self.pool = IndexSelect()
281
+ self.add = Add()
282
+
283
+ self.inp_grad = None
284
+
285
+ def save_inp_grad(self,grad):
286
+ self.inp_grad = grad
287
+
288
+ def get_inp_grad(self):
289
+ return self.inp_grad
290
+
291
+
292
+ def _init_weights(self, m):
293
+ if isinstance(m, nn.Linear):
294
+ trunc_normal_(m.weight, std=.02)
295
+ if isinstance(m, nn.Linear) and m.bias is not None:
296
+ nn.init.constant_(m.bias, 0)
297
+ elif isinstance(m, nn.LayerNorm):
298
+ nn.init.constant_(m.bias, 0)
299
+ nn.init.constant_(m.weight, 1.0)
300
+
301
+ @property
302
+ def no_weight_decay(self):
303
+ return {'pos_embed', 'cls_token'}
304
+
305
+ def forward(self, x):
306
+ B = x.shape[0]
307
+ x = self.patch_embed(x)
308
+
309
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
310
+ x = torch.cat((cls_tokens, x), dim=1)
311
+ x = self.add([x, self.pos_embed])
312
+
313
+ x.register_hook(self.save_inp_grad)
314
+
315
+ for blk in self.blocks:
316
+ x = blk(x)
317
+
318
+ x = self.norm(x)
319
+ x = self.pool(x, dim=1, indices=torch.tensor(0, device=x.device))
320
+ x = x.squeeze(1)
321
+ x = self.head(x)
322
+ return x
323
+
324
+ def relprop(self, cam=None,method="grad", is_ablation=False, start_layer=0, **kwargs):
325
+ # print(kwargs)
326
+ # print("conservation 1", cam.sum())
327
+ cam = self.head.relprop(cam, **kwargs)
328
+ cam = cam.unsqueeze(1)
329
+ cam = self.pool.relprop(cam, **kwargs)
330
+ cam = self.norm.relprop(cam, **kwargs)
331
+ for blk in reversed(self.blocks):
332
+ cam = blk.relprop(cam, **kwargs)
333
+
334
+ # print("conservation 2", cam.sum())
335
+ # print("min", cam.min())
336
+
337
+ if method == "full":
338
+ (cam, _) = self.add.relprop(cam, **kwargs)
339
+ cam = cam[:, 1:]
340
+ cam = self.patch_embed.relprop(cam, **kwargs)
341
+ # sum on channels
342
+ cam = cam.sum(dim=1)
343
+ return cam
344
+
345
+ elif method == "rollout":
346
+ # cam rollout
347
+ attn_cams = []
348
+ for blk in self.blocks:
349
+ attn_heads = blk.attn.get_attn_cam().clamp(min=0)
350
+ avg_heads = (attn_heads.sum(dim=1) / attn_heads.shape[1]).detach()
351
+ attn_cams.append(avg_heads)
352
+ cam = compute_rollout_attention(attn_cams, start_layer=start_layer)
353
+ cam = cam[:, 0, 1:]
354
+ return cam
355
+
356
+ elif method == "grad":
357
+ cams = []
358
+ for blk in self.blocks:
359
+ grad = blk.attn.get_attn_gradients()
360
+ cam = blk.attn.get_attn_cam()
361
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
362
+ grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
363
+ cam = grad * cam
364
+ cam = cam.clamp(min=0).mean(dim=0)
365
+ cams.append(cam.unsqueeze(0))
366
+ rollout = compute_rollout_attention(cams, start_layer=start_layer)
367
+ cam = rollout[:, 0, 1:]
368
+ return cam
369
+
370
+ elif method == "last_layer":
371
+ cam = self.blocks[-1].attn.get_attn_cam()
372
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
373
+ if is_ablation:
374
+ grad = self.blocks[-1].attn.get_attn_gradients()
375
+ grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
376
+ cam = grad * cam
377
+ cam = cam.clamp(min=0).mean(dim=0)
378
+ cam = cam[0, 1:]
379
+ return cam
380
+
381
+ elif method == "last_layer_attn":
382
+ cam = self.blocks[-1].attn.get_attn()
383
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
384
+ cam = cam.clamp(min=0).mean(dim=0)
385
+ cam = cam[0, 1:]
386
+ return cam
387
+
388
+ elif method == "second_layer":
389
+ cam = self.blocks[1].attn.get_attn_cam()
390
+ cam = cam[0].reshape(-1, cam.shape[-1], cam.shape[-1])
391
+ if is_ablation:
392
+ grad = self.blocks[1].attn.get_attn_gradients()
393
+ grad = grad[0].reshape(-1, grad.shape[-1], grad.shape[-1])
394
+ cam = grad * cam
395
+ cam = cam.clamp(min=0).mean(dim=0)
396
+ cam = cam[0, 1:]
397
+ return cam
398
+
399
+
400
+ def _conv_filter(state_dict, patch_size=16):
401
+ """ convert patch embedding weight from manual patchify + linear proj to conv"""
402
+ out_dict = {}
403
+ for k, v in state_dict.items():
404
+ if 'patch_embed.proj.weight' in k:
405
+ v = v.reshape((v.shape[0], 3, patch_size, patch_size))
406
+ out_dict[k] = v
407
+ return out_dict
408
+
409
+
410
+ def vit_base_patch16_224(pretrained=False, **kwargs):
411
+ model = VisionTransformer(
412
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, **kwargs)
413
+ model.default_cfg = default_cfgs['vit_base_patch16_224']
414
+ if pretrained:
415
+ load_pretrained(
416
+ model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
417
+ return model
418
+
419
+ def vit_large_patch16_224(pretrained=False, **kwargs):
420
+ model = VisionTransformer(
421
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, **kwargs)
422
+ model.default_cfg = default_cfgs['vit_large_patch16_224']
423
+ if pretrained:
424
+ load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
425
+ return model
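For orientation, here is a minimal usage sketch for the LRP-instrumented ViT defined in `ViT_orig_LRP.py` above. It is a sketch under assumptions, not part of the repository: it assumes the `modules.layers_lrp` relprop implementations accept an `alpha` keyword (as in the upstream Chefer code), uses random weights (`pretrained=False`) so nothing needs to be downloaded, and feeds a random tensor in place of a normalized 224x224 image.

```python
import torch
from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.ViT_orig_LRP import (
    vit_base_patch16_224,
)

# Set pretrained=True to pull the timm checkpoint listed in default_cfgs.
model = vit_base_patch16_224(pretrained=False).eval()
image = torch.randn(1, 3, 224, 224)  # stand-in for a normalized ImageNet image

logits = model(image)
target_class = logits.argmax(dim=-1).item()

# Backpropagate a one-hot signal so the hooks above record attention gradients.
one_hot = torch.zeros_like(logits)
one_hot[0, target_class] = 1.0
(logits * one_hot).sum().backward(retain_graph=True)

# "grad" multiplies each block's attention CAM by its gradient and rolls the result out.
# alpha=1 follows the alpha-beta LRP convention used by the upstream explanation generator (assumption).
relevance = model.relprop(cam=one_hot, method="grad", start_layer=0, alpha=1)
patch_relevance = relevance.detach().reshape(1, 14, 14)  # 224 / 16 = 14 patches per side
```

The resulting 14x14 map assigns a relevance score to each image patch for the predicted class and can be upsampled to the input resolution for visualization.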
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/ViT_LRP.cpython-310.pyc ADDED
Binary file (14.4 kB). View file
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/ViT_explanation_generator.cpython-310.pyc ADDED
Binary file (3.49 kB). View file
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/ViT_new.cpython-310.pyc ADDED
Binary file (9.15 kB). View file
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/ViT_orig_LRP.cpython-310.pyc ADDED
Binary file (13.9 kB). View file
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/helpers.cpython-310.pyc ADDED
Binary file (7.28 kB). View file
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/layer_helpers.cpython-310.pyc ADDED
Binary file (810 Bytes). View file
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/__pycache__/weight_init.cpython-310.pyc ADDED
Binary file (1.98 kB). View file
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/VOC.py ADDED
@@ -0,0 +1,395 @@
1
+ import os
2
+ import tarfile
3
+ import torch
4
+ import torch.utils.data as data
5
+ import numpy as np
6
+ import h5py
7
+
8
+ from PIL import Image
9
+ from scipy import io
10
+ from torchvision.datasets.utils import download_url
11
+
12
+ DATASET_YEAR_DICT = {
13
+ '2012': {
14
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
15
+ 'filename': 'VOCtrainval_11-May-2012.tar',
16
+ 'md5': '6cd6e144f989b92b3379bac3b3de84fd',
17
+ 'base_dir': 'VOCdevkit/VOC2012'
18
+ },
19
+ '2011': {
20
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
21
+ 'filename': 'VOCtrainval_25-May-2011.tar',
22
+ 'md5': '6c3384ef61512963050cb5d687e5bf1e',
23
+ 'base_dir': 'TrainVal/VOCdevkit/VOC2011'
24
+ },
25
+ '2010': {
26
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
27
+ 'filename': 'VOCtrainval_03-May-2010.tar',
28
+ 'md5': 'da459979d0c395079b5c75ee67908abb',
29
+ 'base_dir': 'VOCdevkit/VOC2010'
30
+ },
31
+ '2009': {
32
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
33
+ 'filename': 'VOCtrainval_11-May-2009.tar',
34
+ 'md5': '59065e4b188729180974ef6572f6a212',
35
+ 'base_dir': 'VOCdevkit/VOC2009'
36
+ },
37
+ '2008': {
38
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
39
+ 'filename': 'VOCtrainval_11-May-2012.tar',
40
+ 'md5': '2629fa636546599198acfcfbfcf1904a',
41
+ 'base_dir': 'VOCdevkit/VOC2008'
42
+ },
43
+ '2007': {
44
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
45
+ 'filename': 'VOCtrainval_06-Nov-2007.tar',
46
+ 'md5': 'c52e279531787c972589f7e41ab4ae64',
47
+ 'base_dir': 'VOCdevkit/VOC2007'
48
+ }
49
+ }
50
+
51
+
52
+ class VOCSegmentation(data.Dataset):
53
+ """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
54
+
55
+ Args:
56
+ root (string): Root directory of the VOC Dataset.
57
+ year (string, optional): The dataset year, supports years 2007 to 2012.
58
+ image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
59
+ download (bool, optional): If true, downloads the dataset from the internet and
60
+ puts it in root directory. If dataset is already downloaded, it is not
61
+ downloaded again.
62
+ transform (callable, optional): A function/transform that takes in an PIL image
63
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
64
+ target_transform (callable, optional): A function/transform that takes in the
65
+ target and transforms it.
66
+ """
67
+
68
+ CLASSES = 20
69
+ # CLASSES_NAMES = [
70
+ # "background", 'airplane', 'bicycle', 'bird', 'boat', 'bottle',
71
+ # 'bus', 'car', 'cat', 'chair', 'cow', 'table', 'dog', 'horse',
72
+ # 'motorcycle', 'person', 'pot', 'sheep', 'sofa', 'train',
73
+ # 'monitor'
74
+ # # 'ambigious'
75
+ # ]
76
+ CLASSES_NAMES = [
77
+ "background", 'plane', 'bike', 'bird', 'boat', 'bottle',
78
+ 'bus', 'car', 'cat', 'chair', 'cow', 'table', 'dog', 'horse',
79
+ 'motorcycle', 'person', 'pot', 'sheep', 'sofa', 'train',
80
+ 'monitor'
81
+ # 'ambigious'
82
+ ]
83
+
84
+ def __init__(
85
+ self,
86
+ root,
87
+ year='2012',
88
+ image_set='train',
89
+ download=False,
90
+ transform=None,
91
+ target_transform=None,
92
+ binary_class=False
93
+ ):
94
+ self.root = os.path.expanduser(root)
95
+ self.binary_class = binary_class
96
+ self.year = year
97
+ self.url = DATASET_YEAR_DICT[year]['url']
98
+ self.filename = DATASET_YEAR_DICT[year]['filename']
99
+ self.md5 = DATASET_YEAR_DICT[year]['md5']
100
+ self.transform = transform
101
+ self.target_transform = target_transform
102
+ self.image_set = image_set
103
+ base_dir = DATASET_YEAR_DICT[year]['base_dir']
104
+ voc_root = os.path.join(self.root, base_dir)
105
+ image_dir = os.path.join(voc_root, 'JPEGImages')
106
+ mask_dir = os.path.join(voc_root, 'SegmentationClass')
107
+
108
+ if download:
109
+ download_extract(self.url, self.root, self.filename, self.md5)
110
+
111
+ if not os.path.isdir(voc_root):
112
+ raise RuntimeError('Dataset not found or corrupted.' +
113
+ ' You can use download=True to download it')
114
+
115
+ splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
116
+
117
+ split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
118
+
119
+ if not os.path.exists(split_f):
120
+ raise ValueError(
121
+ 'Wrong image_set entered! Please use image_set="train" '
122
+ 'or image_set="trainval" or image_set="val"')
123
+
124
+ with open(os.path.join(split_f), "r") as f:
125
+ file_names = [x.strip() for x in f.readlines()]
126
+
127
+ self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
128
+ self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
129
+ assert (len(self.images) == len(self.masks))
130
+
131
+ def __getitem__(self, index):
132
+ """
133
+ Args:
134
+ index (int): Index
135
+
136
+ Returns:
137
+ tuple: (image, target) where target is the image segmentation.
138
+ """
139
+ img = Image.open(self.images[index]).convert('RGB')
140
+ target = Image.open(self.masks[index])
141
+
142
+ if self.transform is not None:
143
+ img = self.transform(img)
144
+
145
+ if self.target_transform is not None:
146
+ target = np.array(self.target_transform(target)).astype('int32')
147
+ target[target == 255] = -1
148
+ target = torch.from_numpy(target).long()
149
+
150
+ # # Convert target to (2, height, width)
151
+ # target = torch.stack([target, 1 - target], dim=0)
152
+ # Get a list of the classes that are present in the image
153
+ visible_classes = np.unique(target)
154
+ # Convert these to class names
155
+ present_classes = [self.CLASSES_NAMES[i] for i in visible_classes if i != -1]
156
+
157
+ if self.binary_class:
158
+ # Take all classes that aren't zero or -1 and mkae them 1
159
+ target[target >= 1] = 1
160
+
161
+ return img, target, present_classes
162
+
163
+ @staticmethod
164
+ def _mask_transform(mask):
165
+ target = np.array(mask).astype('int32')
166
+ target[target == 255] = -1
167
+ return torch.from_numpy(target).long()
168
+
169
+ def __len__(self):
170
+ return len(self.images)
171
+
172
+ @property
173
+ def pred_offset(self):
174
+ return 0
175
+
176
+
177
+ class VOCClassification(data.Dataset):
178
+ """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
179
+
180
+ Args:
181
+ root (string): Root directory of the VOC Dataset.
182
+ year (string, optional): The dataset year, supports years 2007 to 2012.
183
+ image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
184
+ download (bool, optional): If true, downloads the dataset from the internet and
185
+ puts it in root directory. If dataset is already downloaded, it is not
186
+ downloaded again.
187
+ transform (callable, optional): A function/transform that takes in an PIL image
188
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
189
+ """
190
+ CLASSES = 20
191
+
192
+ def __init__(self,
193
+ root,
194
+ year='2012',
195
+ image_set='train',
196
+ download=False,
197
+ transform=None):
198
+ self.root = os.path.expanduser(root)
199
+ self.year = year
200
+ self.url = DATASET_YEAR_DICT[year]['url']
201
+ self.filename = DATASET_YEAR_DICT[year]['filename']
202
+ self.md5 = DATASET_YEAR_DICT[year]['md5']
203
+ self.transform = transform
204
+ self.image_set = image_set
205
+ base_dir = DATASET_YEAR_DICT[year]['base_dir']
206
+ voc_root = os.path.join(self.root, base_dir)
207
+ image_dir = os.path.join(voc_root, 'JPEGImages')
208
+ mask_dir = os.path.join(voc_root, 'SegmentationClass')
209
+
210
+ if download:
211
+ download_extract(self.url, self.root, self.filename, self.md5)
212
+
213
+ if not os.path.isdir(voc_root):
214
+ raise RuntimeError('Dataset not found or corrupted.' +
215
+ ' You can use download=True to download it')
216
+
217
+ splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
218
+
219
+ split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
220
+
221
+ if not os.path.exists(split_f):
222
+ raise ValueError(
223
+ 'Wrong image_set entered! Please use image_set="train" '
224
+ 'or image_set="trainval" or image_set="val"')
225
+
226
+ with open(os.path.join(split_f), "r") as f:
227
+ file_names = [x.strip() for x in f.readlines()]
228
+
229
+ self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
230
+ self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
231
+ assert (len(self.images) == len(self.masks))
232
+
233
+ def __getitem__(self, index):
234
+ """
235
+ Args:
236
+ index (int): Index
237
+
238
+ Returns:
239
+ tuple: (image, target) where target is the image segmentation.
240
+ """
241
+ img = Image.open(self.images[index]).convert('RGB')
242
+ target = Image.open(self.masks[index])
243
+
244
+ # if self.transform is not None:
245
+ # img = self.transform(img)
246
+ if self.transform is not None:
247
+ img, target = self.transform(img, target)
248
+
249
+ visible_classes = np.unique(target)
250
+ labels = torch.zeros(self.CLASSES)
251
+ for id in visible_classes:
252
+ if id not in (0, 255):
253
+ labels[id - 1].fill_(1)
254
+
255
+ return img, labels
256
+
257
+ def __len__(self):
258
+ return len(self.images)
259
+
260
+
261
+ class VOCSBDClassification(data.Dataset):
262
+ """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
263
+
264
+ Args:
265
+ root (string): Root directory of the VOC Dataset.
266
+ year (string, optional): The dataset year, supports years 2007 to 2012.
267
+ image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
268
+ download (bool, optional): If true, downloads the dataset from the internet and
269
+ puts it in root directory. If dataset is already downloaded, it is not
270
+ downloaded again.
271
+ transform (callable, optional): A function/transform that takes in an PIL image
272
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
273
+ """
274
+ CLASSES = 20
275
+
276
+ def __init__(self,
277
+ root,
278
+ sbd_root,
279
+ year='2012',
280
+ image_set='train',
281
+ download=False,
282
+ transform=None):
283
+ self.root = os.path.expanduser(root)
284
+ self.sbd_root = os.path.expanduser(sbd_root)
285
+ self.year = year
286
+ self.url = DATASET_YEAR_DICT[year]['url']
287
+ self.filename = DATASET_YEAR_DICT[year]['filename']
288
+ self.md5 = DATASET_YEAR_DICT[year]['md5']
289
+ self.transform = transform
290
+ self.image_set = image_set
291
+ base_dir = DATASET_YEAR_DICT[year]['base_dir']
292
+ voc_root = os.path.join(self.root, base_dir)
293
+ image_dir = os.path.join(voc_root, 'JPEGImages')
294
+ mask_dir = os.path.join(voc_root, 'SegmentationClass')
295
+ sbd_image_dir = os.path.join(sbd_root, 'img')
296
+ sbd_mask_dir = os.path.join(sbd_root, 'cls')
297
+
298
+ if download:
299
+ download_extract(self.url, self.root, self.filename, self.md5)
300
+
301
+ if not os.path.isdir(voc_root):
302
+ raise RuntimeError('Dataset not found or corrupted.' +
303
+ ' You can use download=True to download it')
304
+
305
+ splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
306
+
307
+ split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
308
+ sbd_split = os.path.join(sbd_root, 'train.txt')
309
+
310
+ if not os.path.exists(split_f):
311
+ raise ValueError(
312
+ 'Wrong image_set entered! Please use image_set="train" '
313
+ 'or image_set="trainval" or image_set="val"')
314
+
315
+ with open(os.path.join(split_f), "r") as f:
316
+ voc_file_names = [x.strip() for x in f.readlines()]
317
+
318
+ with open(os.path.join(sbd_split), "r") as f:
319
+ sbd_file_names = [x.strip() for x in f.readlines()]
320
+
321
+ self.images = [os.path.join(image_dir, x + ".jpg") for x in voc_file_names]
322
+ self.images += [os.path.join(sbd_image_dir, x + ".jpg") for x in sbd_file_names]
323
+ self.masks = [os.path.join(mask_dir, x + ".png") for x in voc_file_names]
324
+ self.masks += [os.path.join(sbd_mask_dir, x + ".mat") for x in sbd_file_names]
325
+ assert (len(self.images) == len(self.masks))
326
+
327
+ def __getitem__(self, index):
328
+ """
329
+ Args:
330
+ index (int): Index
331
+
332
+ Returns:
333
+ tuple: (image, target) where target is the image segmentation.
334
+ """
335
+ img = Image.open(self.images[index]).convert('RGB')
336
+ mask_path = self.masks[index]
337
+ if mask_path[-3:] == 'mat':
338
+ target = io.loadmat(mask_path, struct_as_record=False, squeeze_me=True)['GTcls'].Segmentation
339
+ target = Image.fromarray(target, mode='P')
340
+ else:
341
+ target = Image.open(self.masks[index])
342
+
343
+ if self.transform is not None:
344
+ img, target = self.transform(img, target)
345
+
346
+ visible_classes = np.unique(target)
347
+ labels = torch.zeros(self.CLASSES)
348
+ for id in visible_classes:
349
+ if id not in (0, 255):
350
+ labels[id - 1].fill_(1)
351
+
352
+ return img, labels
353
+
354
+ def __len__(self):
355
+ return len(self.images)
356
+
357
+
358
+ def download_extract(url, root, filename, md5):
359
+ download_url(url, root, filename, md5)
360
+ with tarfile.open(os.path.join(root, filename), "r") as tar:
361
+ tar.extractall(path=root)
362
+
363
+
364
+ class VOCResults(data.Dataset):
365
+ CLASSES = 20
366
+ CLASSES_NAMES = [
367
+ 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
368
+ 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
369
+ 'motorbike', 'person', 'potted-plant', 'sheep', 'sofa', 'train',
370
+ 'tvmonitor', 'ambigious'
371
+ ]
372
+
373
+ def __init__(self, path):
374
+ super(VOCResults, self).__init__()
375
+
376
+ self.path = os.path.join(path, 'results.hdf5')
377
+ self.data = None
378
+
379
+ print('Reading dataset length...')
380
+ with h5py.File(self.path , 'r') as f:
381
+ self.data_length = len(f['/image'])
382
+
383
+ def __len__(self):
384
+ return self.data_length
385
+
386
+ def __getitem__(self, item):
387
+ if self.data is None:
388
+ self.data = h5py.File(self.path, 'r')
389
+
390
+ image = torch.tensor(self.data['image'][item])
391
+ vis = torch.tensor(self.data['vis'][item])
392
+ target = torch.tensor(self.data['target'][item])
393
+ class_pred = torch.tensor(self.data['class_pred'][item])
394
+
395
+ return image, vis, target, class_pred
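As a quick orientation for the dataset classes above, a minimal sketch of loading `VOCSegmentation` follows. The `./data` root and the transform choices are illustrative assumptions, not values taken from the repository.

```python
import torchvision.transforms as transforms
from PIL import Image

from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.data.VOC import VOCSegmentation

img_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Nearest-neighbour resize keeps the mask labels integral.
target_transform = transforms.Resize((224, 224), Image.NEAREST)

dataset = VOCSegmentation(
    root="./data",            # illustrative path to the extracted VOCdevkit
    year="2012",
    image_set="val",
    download=False,
    transform=img_transform,
    target_transform=target_transform,
    binary_class=True,        # collapse all foreground classes to a single label
)

image, target, present_classes = dataset[0]
print(image.shape, target.shape, present_classes)
```

Each item yields the transformed image, a long tensor mask (with 255 remapped to -1 for the ignore region), and the list of class names visible in that mask.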
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/__init__.py ADDED
File without changes
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/__pycache__/Imagenet.cpython-310.pyc ADDED
Binary file (5.25 kB). View file
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/__pycache__/VOC.cpython-310.pyc ADDED
Binary file (12.1 kB). View file
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (220 Bytes). View file
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/__pycache__/imagenet.cpython-310.pyc ADDED
Binary file (5.37 kB). View file
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/imagenet.py ADDED
@@ -0,0 +1,200 @@
1
+ import os
2
+ import torch
3
+ import torch.utils.data as data
4
+ import numpy as np
5
+ import cv2
6
+
7
+ from torchvision.datasets import ImageNet
8
+
9
+ from PIL import Image, ImageFilter
10
+ import h5py
11
+ from glob import glob
12
+
13
+
14
+ class ImageNet_blur(ImageNet):
15
+ def __getitem__(self, index):
16
+ """
17
+ Args:
18
+ index (int): Index
19
+
20
+ Returns:
21
+ tuple: (sample, target) where target is class_index of the target class.
22
+ """
23
+ path, target = self.samples[index]
24
+ sample = self.loader(path)
25
+
26
+ gauss_blur = ImageFilter.GaussianBlur(11)
27
+ median_blur = ImageFilter.MedianFilter(11)
28
+
29
+ blurred_img1 = sample.filter(gauss_blur)
30
+ blurred_img2 = sample.filter(median_blur)
31
+ blurred_img = Image.blend(blurred_img1, blurred_img2, 0.5)
32
+
33
+ if self.transform is not None:
34
+ sample = self.transform(sample)
35
+ blurred_img = self.transform(blurred_img)
36
+ if self.target_transform is not None:
37
+ target = self.target_transform(target)
38
+
39
+ return (sample, blurred_img), target
40
+
41
+
42
+ class Imagenet_Segmentation(data.Dataset):
43
+ CLASSES = 2
44
+
45
+ def __init__(self,
46
+ path,
47
+ transform=None,
48
+ target_transform=None):
49
+ self.path = path
50
+ self.transform = transform
51
+ self.target_transform = target_transform
52
+ # self.h5py = h5py.File(path, 'r+')
53
+ self.h5py = None
54
+ with h5py.File(path, 'r') as tmp:
55
+ self.data_length = len(tmp['/value/img'])
56
+
57
+ def __getitem__(self, index):
58
+
59
+ if self.h5py is None:
60
+ self.h5py = h5py.File(self.path, 'r')
61
+
62
+ img = np.array(self.h5py[self.h5py['/value/img'][index, 0]]).transpose((2, 1, 0))
63
+ target = np.array(self.h5py[self.h5py[self.h5py['/value/gt'][index, 0]][0, 0]]).transpose((1, 0))
64
+
65
+ img = Image.fromarray(img).convert('RGB')
66
+ target = Image.fromarray(target)
67
+
68
+ if self.transform is not None:
69
+ img = self.transform(img)
70
+
71
+ if self.target_transform is not None:
72
+ target = np.array(self.target_transform(target)).astype('int32')
73
+ target = torch.from_numpy(target).long()
74
+
75
+ return img, target
76
+
77
+ def __len__(self):
78
+ # return len(self.h5py['/value/img'])
79
+ return self.data_length
80
+
81
+
82
+ class Imagenet_Segmentation_Blur(data.Dataset):
83
+ CLASSES = 2
84
+
85
+ def __init__(self,
86
+ path,
87
+ transform=None,
88
+ target_transform=None):
89
+ self.path = path
90
+ self.transform = transform
91
+ self.target_transform = target_transform
92
+ # self.h5py = h5py.File(path, 'r+')
93
+ self.h5py = None
94
+ tmp = h5py.File(path, 'r')
95
+ self.data_length = len(tmp['/value/img'])
96
+ tmp.close()
97
+ del tmp
98
+
99
+ def __getitem__(self, index):
100
+
101
+ if self.h5py is None:
102
+ self.h5py = h5py.File(self.path, 'r')
103
+
104
+ img = np.array(self.h5py[self.h5py['/value/img'][index, 0]]).transpose((2, 1, 0))
105
+ target = np.array(self.h5py[self.h5py[self.h5py['/value/gt'][index, 0]][0, 0]]).transpose((1, 0))
106
+
107
+ img = Image.fromarray(img).convert('RGB')
108
+ target = Image.fromarray(target)
109
+
110
+ gauss_blur = ImageFilter.GaussianBlur(11)
111
+ median_blur = ImageFilter.MedianFilter(11)
112
+
113
+ blurred_img1 = img.filter(gauss_blur)
114
+ blurred_img2 = img.filter(median_blur)
115
+ blurred_img = Image.blend(blurred_img1, blurred_img2, 0.5)
116
+
117
+ # blurred_img1 = cv2.GaussianBlur(img, (11, 11), 5)
118
+ # blurred_img2 = np.float32(cv2.medianBlur(img, 11))
119
+ # blurred_img = (blurred_img1 + blurred_img2) / 2
120
+
121
+ if self.transform is not None:
122
+ img = self.transform(img)
123
+ blurred_img = self.transform(blurred_img)
124
+
125
+ if self.target_transform is not None:
126
+ target = np.array(self.target_transform(target)).astype('int32')
127
+ target = torch.from_numpy(target).long()
128
+
129
+ return (img, blurred_img), target
130
+
131
+ def __len__(self):
132
+ # return len(self.h5py['/value/img'])
133
+ return self.data_length
134
+
135
+
136
+ class Imagenet_Segmentation_eval_dir(data.Dataset):
137
+ CLASSES = 2
138
+
139
+ def __init__(self,
140
+ path,
141
+ eval_path,
142
+ transform=None,
143
+ target_transform=None):
144
+ self.transform = transform
145
+ self.target_transform = target_transform
146
+ self.h5py = h5py.File(path, 'r+')
147
+
148
+ # 500 each file
149
+ self.results = glob(os.path.join(eval_path, '*.npy'))
150
+
151
+ def __getitem__(self, index):
152
+
153
+ img = np.array(self.h5py[self.h5py['/value/img'][index, 0]]).transpose((2, 1, 0))
154
+ target = np.array(self.h5py[self.h5py[self.h5py['/value/gt'][index, 0]][0, 0]]).transpose((1, 0))
155
+ res = np.load(self.results[index])
156
+
157
+ img = Image.fromarray(img).convert('RGB')
158
+ target = Image.fromarray(target)
159
+
160
+ if self.transform is not None:
161
+ img = self.transform(img)
162
+
163
+ if self.target_transform is not None:
164
+ target = np.array(self.target_transform(target)).astype('int32')
165
+ target = torch.from_numpy(target).long()
166
+
167
+ return img, target
168
+
169
+ def __len__(self):
170
+ return len(self.h5py['/value/img'])
171
+
172
+
173
+ if __name__ == '__main__':
174
+ import torchvision.transforms as transforms
175
+ from tqdm import tqdm
176
+ from imageio import imsave
177
+ import scipy.io as sio
178
+
179
+ # meta = sio.loadmat('/home/shirgur/ext/Data/Datasets/temp/ILSVRC2012_devkit_t12/data/meta.mat', squeeze_me=True)['synsets']
180
+
181
+ # Data
182
+ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
183
+ std=[0.229, 0.224, 0.225])
184
+ test_img_trans = transforms.Compose([
185
+ transforms.Resize((224, 224)),
186
+ transforms.ToTensor(),
187
+ normalize,
188
+ ])
189
+ test_lbl_trans = transforms.Compose([
190
+ transforms.Resize((224, 224), Image.NEAREST),
191
+ ])
192
+
193
+ ds = Imagenet_Segmentation('/home/shirgur/ext/Data/Datasets/imagenet-seg/other/gtsegs_ijcv.mat',
194
+ transform=test_img_trans, target_transform=test_lbl_trans)
195
+
196
+ for i, (img, tgt) in enumerate(tqdm(ds)):
197
+ tgt = (tgt.numpy() * 255).astype(np.uint8)
198
+ imsave('/home/shirgur/ext/Code/C2S/run/imagenet/gt/{}.png'.format(i), tgt)
199
+
200
+ print('here')
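The `__main__` block above is tied to the original author's local paths. A more generic sketch of wiring `Imagenet_Segmentation` into a `DataLoader` is shown below; the `gtsegs_ijcv.mat` path is an assumed local location, and the transforms mirror the ones used in the script above.

```python
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader

from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.data.imagenet import (
    Imagenet_Segmentation,
)

img_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
target_transform = transforms.Resize((224, 224), Image.NEAREST)

dataset = Imagenet_Segmentation(
    "path/to/gtsegs_ijcv.mat",   # assumed local copy of the ImageNet-Segmentation annotations
    transform=img_transform,
    target_transform=target_transform,
)
# The HDF5 handle is opened lazily in __getitem__, so worker processes each get their own handle.
loader = DataLoader(dataset, batch_size=8, shuffle=False, num_workers=2)

images, targets = next(iter(loader))  # images: (8, 3, 224, 224), targets: (8, 224, 224)
```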
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/imagenet_utils.py ADDED
@@ -0,0 +1,1002 @@
1
+ CLS2IDX = {
2
+ 0: 'tench, Tinca tinca',
3
+ 1: 'goldfish, Carassius auratus',
4
+ 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
5
+ 3: 'tiger shark, Galeocerdo cuvieri',
6
+ 4: 'hammerhead, hammerhead shark',
7
+ 5: 'electric ray, crampfish, numbfish, torpedo',
8
+ 6: 'stingray',
9
+ 7: 'cock',
10
+ 8: 'hen',
11
+ 9: 'ostrich, Struthio camelus',
12
+ 10: 'brambling, Fringilla montifringilla',
13
+ 11: 'goldfinch, Carduelis carduelis',
14
+ 12: 'house finch, linnet, Carpodacus mexicanus',
15
+ 13: 'junco, snowbird',
16
+ 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
17
+ 15: 'robin, American robin, Turdus migratorius',
18
+ 16: 'bulbul',
19
+ 17: 'jay',
20
+ 18: 'magpie',
21
+ 19: 'chickadee',
22
+ 20: 'water ouzel, dipper',
23
+ 21: 'kite',
24
+ 22: 'bald eagle, American eagle, Haliaeetus leucocephalus',
25
+ 23: 'vulture',
26
+ 24: 'great grey owl, great gray owl, Strix nebulosa',
27
+ 25: 'European fire salamander, Salamandra salamandra',
28
+ 26: 'common newt, Triturus vulgaris',
29
+ 27: 'eft',
30
+ 28: 'spotted salamander, Ambystoma maculatum',
31
+ 29: 'axolotl, mud puppy, Ambystoma mexicanum',
32
+ 30: 'bullfrog, Rana catesbeiana',
33
+ 31: 'tree frog, tree-frog',
34
+ 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
35
+ 33: 'loggerhead, loggerhead turtle, Caretta caretta',
36
+ 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea',
37
+ 35: 'mud turtle',
38
+ 36: 'terrapin',
39
+ 37: 'box turtle, box tortoise',
40
+ 38: 'banded gecko',
41
+ 39: 'common iguana, iguana, Iguana iguana',
42
+ 40: 'American chameleon, anole, Anolis carolinensis',
43
+ 41: 'whiptail, whiptail lizard',
44
+ 42: 'agama',
45
+ 43: 'frilled lizard, Chlamydosaurus kingi',
46
+ 44: 'alligator lizard',
47
+ 45: 'Gila monster, Heloderma suspectum',
48
+ 46: 'green lizard, Lacerta viridis',
49
+ 47: 'African chameleon, Chamaeleo chamaeleon',
50
+ 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis',
51
+ 49: 'African crocodile, Nile crocodile, Crocodylus niloticus',
52
+ 50: 'American alligator, Alligator mississipiensis',
53
+ 51: 'triceratops',
54
+ 52: 'thunder snake, worm snake, Carphophis amoenus',
55
+ 53: 'ringneck snake, ring-necked snake, ring snake',
56
+ 54: 'hognose snake, puff adder, sand viper',
57
+ 55: 'green snake, grass snake',
58
+ 56: 'king snake, kingsnake',
59
+ 57: 'garter snake, grass snake',
60
+ 58: 'water snake',
61
+ 59: 'vine snake',
62
+ 60: 'night snake, Hypsiglena torquata',
63
+ 61: 'boa constrictor, Constrictor constrictor',
64
+ 62: 'rock python, rock snake, Python sebae',
65
+ 63: 'Indian cobra, Naja naja',
66
+ 64: 'green mamba',
67
+ 65: 'sea snake',
68
+ 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
69
+ 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus',
70
+ 68: 'sidewinder, horned rattlesnake, Crotalus cerastes',
71
+ 69: 'trilobite',
72
+ 70: 'harvestman, daddy longlegs, Phalangium opilio',
73
+ 71: 'scorpion',
74
+ 72: 'black and gold garden spider, Argiope aurantia',
75
+ 73: 'barn spider, Araneus cavaticus',
76
+ 74: 'garden spider, Aranea diademata',
77
+ 75: 'black widow, Latrodectus mactans',
78
+ 76: 'tarantula',
79
+ 77: 'wolf spider, hunting spider',
80
+ 78: 'tick',
81
+ 79: 'centipede',
82
+ 80: 'black grouse',
83
+ 81: 'ptarmigan',
84
+ 82: 'ruffed grouse, partridge, Bonasa umbellus',
85
+ 83: 'prairie chicken, prairie grouse, prairie fowl',
86
+ 84: 'peacock',
87
+ 85: 'quail',
88
+ 86: 'partridge',
89
+ 87: 'African grey, African gray, Psittacus erithacus',
90
+ 88: 'macaw',
91
+ 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
92
+ 90: 'lorikeet',
93
+ 91: 'coucal',
94
+ 92: 'bee eater',
95
+ 93: 'hornbill',
96
+ 94: 'hummingbird',
97
+ 95: 'jacamar',
98
+ 96: 'toucan',
99
+ 97: 'drake',
100
+ 98: 'red-breasted merganser, Mergus serrator',
101
+ 99: 'goose',
102
+ 100: 'black swan, Cygnus atratus',
103
+ 101: 'tusker',
104
+ 102: 'echidna, spiny anteater, anteater',
105
+ 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus',
106
+ 104: 'wallaby, brush kangaroo',
107
+ 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus',
108
+ 106: 'wombat',
109
+ 107: 'jellyfish',
110
+ 108: 'sea anemone, anemone',
111
+ 109: 'brain coral',
112
+ 110: 'flatworm, platyhelminth',
113
+ 111: 'nematode, nematode worm, roundworm',
114
+ 112: 'conch',
115
+ 113: 'snail',
116
+ 114: 'slug',
117
+ 115: 'sea slug, nudibranch',
118
+ 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore',
119
+ 117: 'chambered nautilus, pearly nautilus, nautilus',
120
+ 118: 'Dungeness crab, Cancer magister',
121
+ 119: 'rock crab, Cancer irroratus',
122
+ 120: 'fiddler crab',
123
+ 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica',
124
+ 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus',
125
+ 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish',
126
+ 124: 'crayfish, crawfish, crawdad, crawdaddy',
127
+ 125: 'hermit crab',
128
+ 126: 'isopod',
129
+ 127: 'white stork, Ciconia ciconia',
130
+ 128: 'black stork, Ciconia nigra',
131
+ 129: 'spoonbill',
132
+ 130: 'flamingo',
133
+ 131: 'little blue heron, Egretta caerulea',
134
+ 132: 'American egret, great white heron, Egretta albus',
135
+ 133: 'bittern',
136
+ 134: 'crane',
137
+ 135: 'limpkin, Aramus pictus',
138
+ 136: 'European gallinule, Porphyrio porphyrio',
139
+ 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana',
140
+ 138: 'bustard',
141
+ 139: 'ruddy turnstone, Arenaria interpres',
142
+ 140: 'red-backed sandpiper, dunlin, Erolia alpina',
143
+ 141: 'redshank, Tringa totanus',
144
+ 142: 'dowitcher',
145
+ 143: 'oystercatcher, oyster catcher',
146
+ 144: 'pelican',
147
+ 145: 'king penguin, Aptenodytes patagonica',
148
+ 146: 'albatross, mollymawk',
149
+ 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus',
150
+ 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca',
151
+ 149: 'dugong, Dugong dugon',
152
+ 150: 'sea lion',
153
+ 151: 'Chihuahua',
154
+ 152: 'Japanese spaniel',
155
+ 153: 'Maltese dog, Maltese terrier, Maltese',
156
+ 154: 'Pekinese, Pekingese, Peke',
157
+ 155: 'Shih-Tzu',
158
+ 156: 'Blenheim spaniel',
159
+ 157: 'papillon',
160
+ 158: 'toy terrier',
161
+ 159: 'Rhodesian ridgeback',
162
+ 160: 'Afghan hound, Afghan',
163
+ 161: 'basset, basset hound',
164
+ 162: 'beagle',
165
+ 163: 'bloodhound, sleuthhound',
166
+ 164: 'bluetick',
167
+ 165: 'black-and-tan coonhound',
168
+ 166: 'Walker hound, Walker foxhound',
169
+ 167: 'English foxhound',
170
+ 168: 'redbone',
171
+ 169: 'borzoi, Russian wolfhound',
172
+ 170: 'Irish wolfhound',
173
+ 171: 'Italian greyhound',
174
+ 172: 'whippet',
175
+ 173: 'Ibizan hound, Ibizan Podenco',
176
+ 174: 'Norwegian elkhound, elkhound',
177
+ 175: 'otterhound, otter hound',
178
+ 176: 'Saluki, gazelle hound',
179
+ 177: 'Scottish deerhound, deerhound',
180
+ 178: 'Weimaraner',
181
+ 179: 'Staffordshire bullterrier, Staffordshire bull terrier',
182
+ 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier',
183
+ 181: 'Bedlington terrier',
184
+ 182: 'Border terrier',
185
+ 183: 'Kerry blue terrier',
186
+ 184: 'Irish terrier',
187
+ 185: 'Norfolk terrier',
188
+ 186: 'Norwich terrier',
189
+ 187: 'Yorkshire terrier',
190
+ 188: 'wire-haired fox terrier',
191
+ 189: 'Lakeland terrier',
192
+ 190: 'Sealyham terrier, Sealyham',
193
+ 191: 'Airedale, Airedale terrier',
194
+ 192: 'cairn, cairn terrier',
195
+ 193: 'Australian terrier',
196
+ 194: 'Dandie Dinmont, Dandie Dinmont terrier',
197
+ 195: 'Boston bull, Boston terrier',
198
+ 196: 'miniature schnauzer',
199
+ 197: 'giant schnauzer',
200
+ 198: 'standard schnauzer',
201
+ 199: 'Scotch terrier, Scottish terrier, Scottie',
202
+ 200: 'Tibetan terrier, chrysanthemum dog',
203
+ 201: 'silky terrier, Sydney silky',
204
+ 202: 'soft-coated wheaten terrier',
205
+ 203: 'West Highland white terrier',
206
+ 204: 'Lhasa, Lhasa apso',
207
+ 205: 'flat-coated retriever',
208
+ 206: 'curly-coated retriever',
209
+ 207: 'golden retriever',
210
+ 208: 'Labrador retriever',
211
+ 209: 'Chesapeake Bay retriever',
212
+ 210: 'German short-haired pointer',
213
+ 211: 'vizsla, Hungarian pointer',
214
+ 212: 'English setter',
215
+ 213: 'Irish setter, red setter',
216
+ 214: 'Gordon setter',
217
+ 215: 'Brittany spaniel',
218
+ 216: 'clumber, clumber spaniel',
219
+ 217: 'English springer, English springer spaniel',
220
+ 218: 'Welsh springer spaniel',
221
+ 219: 'cocker spaniel, English cocker spaniel, cocker',
222
+ 220: 'Sussex spaniel',
223
+ 221: 'Irish water spaniel',
224
+ 222: 'kuvasz',
225
+ 223: 'schipperke',
226
+ 224: 'groenendael',
227
+ 225: 'malinois',
228
+ 226: 'briard',
229
+ 227: 'kelpie',
230
+ 228: 'komondor',
231
+ 229: 'Old English sheepdog, bobtail',
232
+ 230: 'Shetland sheepdog, Shetland sheep dog, Shetland',
233
+ 231: 'collie',
234
+ 232: 'Border collie',
235
+ 233: 'Bouvier des Flandres, Bouviers des Flandres',
236
+ 234: 'Rottweiler',
237
+ 235: 'German shepherd, German shepherd dog, German police dog, alsatian',
238
+ 236: 'Doberman, Doberman pinscher',
239
+ 237: 'miniature pinscher',
240
+ 238: 'Greater Swiss Mountain dog',
241
+ 239: 'Bernese mountain dog',
242
+ 240: 'Appenzeller',
243
+ 241: 'EntleBucher',
244
+ 242: 'boxer',
245
+ 243: 'bull mastiff',
246
+ 244: 'Tibetan mastiff',
247
+ 245: 'French bulldog',
248
+ 246: 'Great Dane',
249
+ 247: 'Saint Bernard, St Bernard',
250
+ 248: 'Eskimo dog, husky',
251
+ 249: 'malamute, malemute, Alaskan malamute',
252
+ 250: 'Siberian husky',
253
+ 251: 'dalmatian, coach dog, carriage dog',
254
+ 252: 'affenpinscher, monkey pinscher, monkey dog',
255
+ 253: 'basenji',
256
+ 254: 'pug, pug-dog',
257
+ 255: 'Leonberg',
258
+ 256: 'Newfoundland, Newfoundland dog',
259
+ 257: 'Great Pyrenees',
260
+ 258: 'Samoyed, Samoyede',
261
+ 259: 'Pomeranian',
262
+ 260: 'chow, chow chow',
263
+ 261: 'keeshond',
264
+ 262: 'Brabancon griffon',
265
+ 263: 'Pembroke, Pembroke Welsh corgi',
266
+ 264: 'Cardigan, Cardigan Welsh corgi',
267
+ 265: 'toy poodle',
268
+ 266: 'miniature poodle',
269
+ 267: 'standard poodle',
270
+ 268: 'Mexican hairless',
271
+ 269: 'timber wolf, grey wolf, gray wolf, Canis lupus',
272
+ 270: 'white wolf, Arctic wolf, Canis lupus tundrarum',
273
+ 271: 'red wolf, maned wolf, Canis rufus, Canis niger',
274
+ 272: 'coyote, prairie wolf, brush wolf, Canis latrans',
275
+ 273: 'dingo, warrigal, warragal, Canis dingo',
276
+ 274: 'dhole, Cuon alpinus',
277
+ 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus',
278
+ 276: 'hyena, hyaena',
279
+ 277: 'red fox, Vulpes vulpes',
280
+ 278: 'kit fox, Vulpes macrotis',
281
+ 279: 'Arctic fox, white fox, Alopex lagopus',
282
+ 280: 'grey fox, gray fox, Urocyon cinereoargenteus',
283
+ 281: 'tabby, tabby cat',
284
+ 282: 'tiger cat',
285
+ 283: 'Persian cat',
286
+ 284: 'Siamese cat, Siamese',
287
+ 285: 'Egyptian cat',
288
+ 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor',
289
+ 287: 'lynx, catamount',
290
+ 288: 'leopard, Panthera pardus',
291
+ 289: 'snow leopard, ounce, Panthera uncia',
292
+ 290: 'jaguar, panther, Panthera onca, Felis onca',
293
+ 291: 'lion, king of beasts, Panthera leo',
294
+ 292: 'tiger, Panthera tigris',
295
+ 293: 'cheetah, chetah, Acinonyx jubatus',
296
+ 294: 'brown bear, bruin, Ursus arctos',
297
+ 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus',
298
+ 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
299
+ 297: 'sloth bear, Melursus ursinus, Ursus ursinus',
300
+ 298: 'mongoose',
301
+ 299: 'meerkat, mierkat',
302
+ 300: 'tiger beetle',
303
+ 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
304
+ 302: 'ground beetle, carabid beetle',
305
+ 303: 'long-horned beetle, longicorn, longicorn beetle',
306
+ 304: 'leaf beetle, chrysomelid',
307
+ 305: 'dung beetle',
308
+ 306: 'rhinoceros beetle',
309
+ 307: 'weevil',
310
+ 308: 'fly',
311
+ 309: 'bee',
312
+ 310: 'ant, emmet, pismire',
313
+ 311: 'grasshopper, hopper',
314
+ 312: 'cricket',
315
+ 313: 'walking stick, walkingstick, stick insect',
316
+ 314: 'cockroach, roach',
317
+ 315: 'mantis, mantid',
318
+ 316: 'cicada, cicala',
319
+ 317: 'leafhopper',
320
+ 318: 'lacewing, lacewing fly',
321
+ 319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
322
+ 320: 'damselfly',
323
+ 321: 'admiral',
324
+ 322: 'ringlet, ringlet butterfly',
325
+ 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
326
+ 324: 'cabbage butterfly',
327
+ 325: 'sulphur butterfly, sulfur butterfly',
328
+ 326: 'lycaenid, lycaenid butterfly',
329
+ 327: 'starfish, sea star',
330
+ 328: 'sea urchin',
331
+ 329: 'sea cucumber, holothurian',
332
+ 330: 'wood rabbit, cottontail, cottontail rabbit',
333
+ 331: 'hare',
334
+ 332: 'Angora, Angora rabbit',
335
+ 333: 'hamster',
336
+ 334: 'porcupine, hedgehog',
337
+ 335: 'fox squirrel, eastern fox squirrel, Sciurus niger',
338
+ 336: 'marmot',
339
+ 337: 'beaver',
340
+ 338: 'guinea pig, Cavia cobaya',
341
+ 339: 'sorrel',
342
+ 340: 'zebra',
343
+ 341: 'hog, pig, grunter, squealer, Sus scrofa',
344
+ 342: 'wild boar, boar, Sus scrofa',
345
+ 343: 'warthog',
346
+ 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius',
347
+ 345: 'ox',
348
+ 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
349
+ 347: 'bison',
350
+ 348: 'ram, tup',
351
+ 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis',
352
+ 350: 'ibex, Capra ibex',
353
+ 351: 'hartebeest',
354
+ 352: 'impala, Aepyceros melampus',
355
+ 353: 'gazelle',
356
+ 354: 'Arabian camel, dromedary, Camelus dromedarius',
357
+ 355: 'llama',
358
+ 356: 'weasel',
359
+ 357: 'mink',
360
+ 358: 'polecat, fitch, foulmart, foumart, Mustela putorius',
361
+ 359: 'black-footed ferret, ferret, Mustela nigripes',
362
+ 360: 'otter',
363
+ 361: 'skunk, polecat, wood pussy',
364
+ 362: 'badger',
365
+ 363: 'armadillo',
366
+ 364: 'three-toed sloth, ai, Bradypus tridactylus',
367
+ 365: 'orangutan, orang, orangutang, Pongo pygmaeus',
368
+ 366: 'gorilla, Gorilla gorilla',
369
+ 367: 'chimpanzee, chimp, Pan troglodytes',
370
+ 368: 'gibbon, Hylobates lar',
371
+ 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus',
372
+ 370: 'guenon, guenon monkey',
373
+ 371: 'patas, hussar monkey, Erythrocebus patas',
374
+ 372: 'baboon',
375
+ 373: 'macaque',
376
+ 374: 'langur',
377
+ 375: 'colobus, colobus monkey',
378
+ 376: 'proboscis monkey, Nasalis larvatus',
379
+ 377: 'marmoset',
380
+ 378: 'capuchin, ringtail, Cebus capucinus',
381
+ 379: 'howler monkey, howler',
382
+ 380: 'titi, titi monkey',
383
+ 381: 'spider monkey, Ateles geoffroyi',
384
+ 382: 'squirrel monkey, Saimiri sciureus',
385
+ 383: 'Madagascar cat, ring-tailed lemur, Lemur catta',
386
+ 384: 'indri, indris, Indri indri, Indri brevicaudatus',
387
+ 385: 'Indian elephant, Elephas maximus',
388
+ 386: 'African elephant, Loxodonta africana',
389
+ 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
390
+ 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
391
+ 389: 'barracouta, snoek',
392
+ 390: 'eel',
393
+ 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch',
394
+ 392: 'rock beauty, Holocanthus tricolor',
395
+ 393: 'anemone fish',
396
+ 394: 'sturgeon',
397
+ 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus',
398
+ 396: 'lionfish',
399
+ 397: 'puffer, pufferfish, blowfish, globefish',
400
+ 398: 'abacus',
401
+ 399: 'abaya',
402
+ 400: "academic gown, academic robe, judge's robe",
403
+ 401: 'accordion, piano accordion, squeeze box',
404
+ 402: 'acoustic guitar',
405
+ 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier',
406
+ 404: 'airliner',
407
+ 405: 'airship, dirigible',
408
+ 406: 'altar',
409
+ 407: 'ambulance',
410
+ 408: 'amphibian, amphibious vehicle',
411
+ 409: 'analog clock',
412
+ 410: 'apiary, bee house',
413
+ 411: 'apron',
414
+ 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin',
415
+ 413: 'assault rifle, assault gun',
416
+ 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack',
417
+ 415: 'bakery, bakeshop, bakehouse',
418
+ 416: 'balance beam, beam',
419
+ 417: 'balloon',
420
+ 418: 'ballpoint, ballpoint pen, ballpen, Biro',
421
+ 419: 'Band Aid',
422
+ 420: 'banjo',
423
+ 421: 'bannister, banister, balustrade, balusters, handrail',
424
+ 422: 'barbell',
425
+ 423: 'barber chair',
426
+ 424: 'barbershop',
427
+ 425: 'barn',
428
+ 426: 'barometer',
429
+ 427: 'barrel, cask',
430
+ 428: 'barrow, garden cart, lawn cart, wheelbarrow',
431
+ 429: 'baseball',
432
+ 430: 'basketball',
433
+ 431: 'bassinet',
434
+ 432: 'bassoon',
435
+ 433: 'bathing cap, swimming cap',
436
+ 434: 'bath towel',
437
+ 435: 'bathtub, bathing tub, bath, tub',
438
+ 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon',
439
+ 437: 'beacon, lighthouse, beacon light, pharos',
440
+ 438: 'beaker',
441
+ 439: 'bearskin, busby, shako',
442
+ 440: 'beer bottle',
443
+ 441: 'beer glass',
444
+ 442: 'bell cote, bell cot',
445
+ 443: 'bib',
446
+ 444: 'bicycle-built-for-two, tandem bicycle, tandem',
447
+ 445: 'bikini, two-piece',
448
+ 446: 'binder, ring-binder',
449
+ 447: 'binoculars, field glasses, opera glasses',
450
+ 448: 'birdhouse',
451
+ 449: 'boathouse',
452
+ 450: 'bobsled, bobsleigh, bob',
453
+ 451: 'bolo tie, bolo, bola tie, bola',
454
+ 452: 'bonnet, poke bonnet',
455
+ 453: 'bookcase',
456
+ 454: 'bookshop, bookstore, bookstall',
457
+ 455: 'bottlecap',
458
+ 456: 'bow',
459
+ 457: 'bow tie, bow-tie, bowtie',
460
+ 458: 'brass, memorial tablet, plaque',
461
+ 459: 'brassiere, bra, bandeau',
462
+ 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty',
463
+ 461: 'breastplate, aegis, egis',
464
+ 462: 'broom',
465
+ 463: 'bucket, pail',
466
+ 464: 'buckle',
467
+ 465: 'bulletproof vest',
468
+ 466: 'bullet train, bullet',
469
+ 467: 'butcher shop, meat market',
470
+ 468: 'cab, hack, taxi, taxicab',
471
+ 469: 'caldron, cauldron',
472
+ 470: 'candle, taper, wax light',
473
+ 471: 'cannon',
474
+ 472: 'canoe',
475
+ 473: 'can opener, tin opener',
476
+ 474: 'cardigan',
477
+ 475: 'car mirror',
478
+ 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig',
479
+ 477: "carpenter's kit, tool kit",
480
+ 478: 'carton',
481
+ 479: 'car wheel',
482
+ 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM',
483
+ 481: 'cassette',
484
+ 482: 'cassette player',
485
+ 483: 'castle',
486
+ 484: 'catamaran',
487
+ 485: 'CD player',
488
+ 486: 'cello, violoncello',
489
+ 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
490
+ 488: 'chain',
491
+ 489: 'chainlink fence',
492
+ 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour',
493
+ 491: 'chain saw, chainsaw',
494
+ 492: 'chest',
495
+ 493: 'chiffonier, commode',
496
+ 494: 'chime, bell, gong',
497
+ 495: 'china cabinet, china closet',
498
+ 496: 'Christmas stocking',
499
+ 497: 'church, church building',
500
+ 498: 'cinema, movie theater, movie theatre, movie house, picture palace',
501
+ 499: 'cleaver, meat cleaver, chopper',
502
+ 500: 'cliff dwelling',
503
+ 501: 'cloak',
504
+ 502: 'clog, geta, patten, sabot',
505
+ 503: 'cocktail shaker',
506
+ 504: 'coffee mug',
507
+ 505: 'coffeepot',
508
+ 506: 'coil, spiral, volute, whorl, helix',
509
+ 507: 'combination lock',
510
+ 508: 'computer keyboard, keypad',
511
+ 509: 'confectionery, confectionary, candy store',
512
+ 510: 'container ship, containership, container vessel',
513
+ 511: 'convertible',
514
+ 512: 'corkscrew, bottle screw',
515
+ 513: 'cornet, horn, trumpet, trump',
516
+ 514: 'cowboy boot',
517
+ 515: 'cowboy hat, ten-gallon hat',
518
+ 516: 'cradle',
519
+ 517: 'crane',
520
+ 518: 'crash helmet',
521
+ 519: 'crate',
522
+ 520: 'crib, cot',
523
+ 521: 'Crock Pot',
524
+ 522: 'croquet ball',
525
+ 523: 'crutch',
526
+ 524: 'cuirass',
527
+ 525: 'dam, dike, dyke',
528
+ 526: 'desk',
529
+ 527: 'desktop computer',
530
+ 528: 'dial telephone, dial phone',
531
+ 529: 'diaper, nappy, napkin',
532
+ 530: 'digital clock',
533
+ 531: 'digital watch',
534
+ 532: 'dining table, board',
535
+ 533: 'dishrag, dishcloth',
536
+ 534: 'dishwasher, dish washer, dishwashing machine',
537
+ 535: 'disk brake, disc brake',
538
+ 536: 'dock, dockage, docking facility',
539
+ 537: 'dogsled, dog sled, dog sleigh',
540
+ 538: 'dome',
541
+ 539: 'doormat, welcome mat',
542
+ 540: 'drilling platform, offshore rig',
543
+ 541: 'drum, membranophone, tympan',
544
+ 542: 'drumstick',
545
+ 543: 'dumbbell',
546
+ 544: 'Dutch oven',
547
+ 545: 'electric fan, blower',
548
+ 546: 'electric guitar',
549
+ 547: 'electric locomotive',
550
+ 548: 'entertainment center',
551
+ 549: 'envelope',
552
+ 550: 'espresso maker',
553
+ 551: 'face powder',
554
+ 552: 'feather boa, boa',
555
+ 553: 'file, file cabinet, filing cabinet',
556
+ 554: 'fireboat',
557
+ 555: 'fire engine, fire truck',
558
+ 556: 'fire screen, fireguard',
559
+ 557: 'flagpole, flagstaff',
560
+ 558: 'flute, transverse flute',
561
+ 559: 'folding chair',
562
+ 560: 'football helmet',
563
+ 561: 'forklift',
564
+ 562: 'fountain',
565
+ 563: 'fountain pen',
566
+ 564: 'four-poster',
567
+ 565: 'freight car',
568
+ 566: 'French horn, horn',
569
+ 567: 'frying pan, frypan, skillet',
570
+ 568: 'fur coat',
571
+ 569: 'garbage truck, dustcart',
572
+ 570: 'gasmask, respirator, gas helmet',
573
+ 571: 'gas pump, gasoline pump, petrol pump, island dispenser',
574
+ 572: 'goblet',
575
+ 573: 'go-kart',
576
+ 574: 'golf ball',
577
+ 575: 'golfcart, golf cart',
578
+ 576: 'gondola',
579
+ 577: 'gong, tam-tam',
580
+ 578: 'gown',
581
+ 579: 'grand piano, grand',
582
+ 580: 'greenhouse, nursery, glasshouse',
583
+ 581: 'grille, radiator grille',
584
+ 582: 'grocery store, grocery, food market, market',
585
+ 583: 'guillotine',
586
+ 584: 'hair slide',
587
+ 585: 'hair spray',
588
+ 586: 'half track',
589
+ 587: 'hammer',
590
+ 588: 'hamper',
591
+ 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier',
592
+ 590: 'hand-held computer, hand-held microcomputer',
593
+ 591: 'handkerchief, hankie, hanky, hankey',
594
+ 592: 'hard disc, hard disk, fixed disk',
595
+ 593: 'harmonica, mouth organ, harp, mouth harp',
596
+ 594: 'harp',
597
+ 595: 'harvester, reaper',
598
+ 596: 'hatchet',
599
+ 597: 'holster',
600
+ 598: 'home theater, home theatre',
601
+ 599: 'honeycomb',
602
+ 600: 'hook, claw',
603
+ 601: 'hoopskirt, crinoline',
604
+ 602: 'horizontal bar, high bar',
605
+ 603: 'horse cart, horse-cart',
606
+ 604: 'hourglass',
607
+ 605: 'iPod',
608
+ 606: 'iron, smoothing iron',
609
+ 607: "jack-o'-lantern",
610
+ 608: 'jean, blue jean, denim',
611
+ 609: 'jeep, landrover',
612
+ 610: 'jersey, T-shirt, tee shirt',
613
+ 611: 'jigsaw puzzle',
614
+ 612: 'jinrikisha, ricksha, rickshaw',
615
+ 613: 'joystick',
616
+ 614: 'kimono',
617
+ 615: 'knee pad',
618
+ 616: 'knot',
619
+ 617: 'lab coat, laboratory coat',
620
+ 618: 'ladle',
621
+ 619: 'lampshade, lamp shade',
622
+ 620: 'laptop, laptop computer',
623
+ 621: 'lawn mower, mower',
624
+ 622: 'lens cap, lens cover',
625
+ 623: 'letter opener, paper knife, paperknife',
626
+ 624: 'library',
627
+ 625: 'lifeboat',
628
+ 626: 'lighter, light, igniter, ignitor',
629
+ 627: 'limousine, limo',
630
+ 628: 'liner, ocean liner',
631
+ 629: 'lipstick, lip rouge',
632
+ 630: 'Loafer',
633
+ 631: 'lotion',
634
+ 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system',
635
+ 633: "loupe, jeweler's loupe",
636
+ 634: 'lumbermill, sawmill',
637
+ 635: 'magnetic compass',
638
+ 636: 'mailbag, postbag',
639
+ 637: 'mailbox, letter box',
640
+ 638: 'maillot',
641
+ 639: 'maillot, tank suit',
642
+ 640: 'manhole cover',
643
+ 641: 'maraca',
644
+ 642: 'marimba, xylophone',
645
+ 643: 'mask',
646
+ 644: 'matchstick',
647
+ 645: 'maypole',
648
+ 646: 'maze, labyrinth',
649
+ 647: 'measuring cup',
650
+ 648: 'medicine chest, medicine cabinet',
651
+ 649: 'megalith, megalithic structure',
652
+ 650: 'microphone, mike',
653
+ 651: 'microwave, microwave oven',
654
+ 652: 'military uniform',
655
+ 653: 'milk can',
656
+ 654: 'minibus',
657
+ 655: 'miniskirt, mini',
658
+ 656: 'minivan',
659
+ 657: 'missile',
660
+ 658: 'mitten',
661
+ 659: 'mixing bowl',
662
+ 660: 'mobile home, manufactured home',
663
+ 661: 'Model T',
664
+ 662: 'modem',
665
+ 663: 'monastery',
666
+ 664: 'monitor',
667
+ 665: 'moped',
668
+ 666: 'mortar',
669
+ 667: 'mortarboard',
670
+ 668: 'mosque',
671
+ 669: 'mosquito net',
672
+ 670: 'motor scooter, scooter',
673
+ 671: 'mountain bike, all-terrain bike, off-roader',
674
+ 672: 'mountain tent',
675
+ 673: 'mouse, computer mouse',
676
+ 674: 'mousetrap',
677
+ 675: 'moving van',
678
+ 676: 'muzzle',
679
+ 677: 'nail',
680
+ 678: 'neck brace',
681
+ 679: 'necklace',
682
+ 680: 'nipple',
683
+ 681: 'notebook, notebook computer',
684
+ 682: 'obelisk',
685
+ 683: 'oboe, hautboy, hautbois',
686
+ 684: 'ocarina, sweet potato',
687
+ 685: 'odometer, hodometer, mileometer, milometer',
688
+ 686: 'oil filter',
689
+ 687: 'organ, pipe organ',
690
+ 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO',
691
+ 689: 'overskirt',
692
+ 690: 'oxcart',
693
+ 691: 'oxygen mask',
694
+ 692: 'packet',
695
+ 693: 'paddle, boat paddle',
696
+ 694: 'paddlewheel, paddle wheel',
697
+ 695: 'padlock',
698
+ 696: 'paintbrush',
699
+ 697: "pajama, pyjama, pj's, jammies",
700
+ 698: 'palace',
701
+ 699: 'panpipe, pandean pipe, syrinx',
702
+ 700: 'paper towel',
703
+ 701: 'parachute, chute',
704
+ 702: 'parallel bars, bars',
705
+ 703: 'park bench',
706
+ 704: 'parking meter',
707
+ 705: 'passenger car, coach, carriage',
708
+ 706: 'patio, terrace',
709
+ 707: 'pay-phone, pay-station',
710
+ 708: 'pedestal, plinth, footstall',
711
+ 709: 'pencil box, pencil case',
712
+ 710: 'pencil sharpener',
713
+ 711: 'perfume, essence',
714
+ 712: 'Petri dish',
715
+ 713: 'photocopier',
716
+ 714: 'pick, plectrum, plectron',
717
+ 715: 'pickelhaube',
718
+ 716: 'picket fence, paling',
719
+ 717: 'pickup, pickup truck',
720
+ 718: 'pier',
721
+ 719: 'piggy bank, penny bank',
722
+ 720: 'pill bottle',
723
+ 721: 'pillow',
724
+ 722: 'ping-pong ball',
725
+ 723: 'pinwheel',
726
+ 724: 'pirate, pirate ship',
727
+ 725: 'pitcher, ewer',
728
+ 726: "plane, carpenter's plane, woodworking plane",
729
+ 727: 'planetarium',
730
+ 728: 'plastic bag',
731
+ 729: 'plate rack',
732
+ 730: 'plow, plough',
733
+ 731: "plunger, plumber's helper",
734
+ 732: 'Polaroid camera, Polaroid Land camera',
735
+ 733: 'pole',
736
+ 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria',
737
+ 735: 'poncho',
738
+ 736: 'pool table, billiard table, snooker table',
739
+ 737: 'pop bottle, soda bottle',
740
+ 738: 'pot, flowerpot',
741
+ 739: "potter's wheel",
742
+ 740: 'power drill',
743
+ 741: 'prayer rug, prayer mat',
744
+ 742: 'printer',
745
+ 743: 'prison, prison house',
746
+ 744: 'projectile, missile',
747
+ 745: 'projector',
748
+ 746: 'puck, hockey puck',
749
+ 747: 'punching bag, punch bag, punching ball, punchball',
750
+ 748: 'purse',
751
+ 749: 'quill, quill pen',
752
+ 750: 'quilt, comforter, comfort, puff',
753
+ 751: 'racer, race car, racing car',
754
+ 752: 'racket, racquet',
755
+ 753: 'radiator',
756
+ 754: 'radio, wireless',
757
+ 755: 'radio telescope, radio reflector',
758
+ 756: 'rain barrel',
759
+ 757: 'recreational vehicle, RV, R.V.',
760
+ 758: 'reel',
761
+ 759: 'reflex camera',
762
+ 760: 'refrigerator, icebox',
763
+ 761: 'remote control, remote',
764
+ 762: 'restaurant, eating house, eating place, eatery',
765
+ 763: 'revolver, six-gun, six-shooter',
766
+ 764: 'rifle',
767
+ 765: 'rocking chair, rocker',
768
+ 766: 'rotisserie',
769
+ 767: 'rubber eraser, rubber, pencil eraser',
770
+ 768: 'rugby ball',
771
+ 769: 'rule, ruler',
772
+ 770: 'running shoe',
773
+ 771: 'safe',
774
+ 772: 'safety pin',
775
+ 773: 'saltshaker, salt shaker',
776
+ 774: 'sandal',
777
+ 775: 'sarong',
778
+ 776: 'sax, saxophone',
779
+ 777: 'scabbard',
780
+ 778: 'scale, weighing machine',
781
+ 779: 'school bus',
782
+ 780: 'schooner',
783
+ 781: 'scoreboard',
784
+ 782: 'screen, CRT screen',
785
+ 783: 'screw',
786
+ 784: 'screwdriver',
787
+ 785: 'seat belt, seatbelt',
788
+ 786: 'sewing machine',
789
+ 787: 'shield, buckler',
790
+ 788: 'shoe shop, shoe-shop, shoe store',
791
+ 789: 'shoji',
792
+ 790: 'shopping basket',
793
+ 791: 'shopping cart',
794
+ 792: 'shovel',
795
+ 793: 'shower cap',
796
+ 794: 'shower curtain',
797
+ 795: 'ski',
798
+ 796: 'ski mask',
799
+ 797: 'sleeping bag',
800
+ 798: 'slide rule, slipstick',
801
+ 799: 'sliding door',
802
+ 800: 'slot, one-armed bandit',
803
+ 801: 'snorkel',
804
+ 802: 'snowmobile',
805
+ 803: 'snowplow, snowplough',
806
+ 804: 'soap dispenser',
807
+ 805: 'soccer ball',
808
+ 806: 'sock',
809
+ 807: 'solar dish, solar collector, solar furnace',
810
+ 808: 'sombrero',
811
+ 809: 'soup bowl',
812
+ 810: 'space bar',
813
+ 811: 'space heater',
814
+ 812: 'space shuttle',
815
+ 813: 'spatula',
816
+ 814: 'speedboat',
817
+ 815: "spider web, spider's web",
818
+ 816: 'spindle',
819
+ 817: 'sports car, sport car',
820
+ 818: 'spotlight, spot',
821
+ 819: 'stage',
822
+ 820: 'steam locomotive',
823
+ 821: 'steel arch bridge',
824
+ 822: 'steel drum',
825
+ 823: 'stethoscope',
826
+ 824: 'stole',
827
+ 825: 'stone wall',
828
+ 826: 'stopwatch, stop watch',
829
+ 827: 'stove',
830
+ 828: 'strainer',
831
+ 829: 'streetcar, tram, tramcar, trolley, trolley car',
832
+ 830: 'stretcher',
833
+ 831: 'studio couch, day bed',
834
+ 832: 'stupa, tope',
835
+ 833: 'submarine, pigboat, sub, U-boat',
836
+ 834: 'suit, suit of clothes',
837
+ 835: 'sundial',
838
+ 836: 'sunglass',
839
+ 837: 'sunglasses, dark glasses, shades',
840
+ 838: 'sunscreen, sunblock, sun blocker',
841
+ 839: 'suspension bridge',
842
+ 840: 'swab, swob, mop',
843
+ 841: 'sweatshirt',
844
+ 842: 'swimming trunks, bathing trunks',
845
+ 843: 'swing',
846
+ 844: 'switch, electric switch, electrical switch',
847
+ 845: 'syringe',
848
+ 846: 'table lamp',
849
+ 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle',
850
+ 848: 'tape player',
851
+ 849: 'teapot',
852
+ 850: 'teddy, teddy bear',
853
+ 851: 'television, television system',
854
+ 852: 'tennis ball',
855
+ 853: 'thatch, thatched roof',
856
+ 854: 'theater curtain, theatre curtain',
857
+ 855: 'thimble',
858
+ 856: 'thresher, thrasher, threshing machine',
859
+ 857: 'throne',
860
+ 858: 'tile roof',
861
+ 859: 'toaster',
862
+ 860: 'tobacco shop, tobacconist shop, tobacconist',
863
+ 861: 'toilet seat',
864
+ 862: 'torch',
865
+ 863: 'totem pole',
866
+ 864: 'tow truck, tow car, wrecker',
867
+ 865: 'toyshop',
868
+ 866: 'tractor',
869
+ 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi',
870
+ 868: 'tray',
871
+ 869: 'trench coat',
872
+ 870: 'tricycle, trike, velocipede',
873
+ 871: 'trimaran',
874
+ 872: 'tripod',
875
+ 873: 'triumphal arch',
876
+ 874: 'trolleybus, trolley coach, trackless trolley',
877
+ 875: 'trombone',
878
+ 876: 'tub, vat',
879
+ 877: 'turnstile',
880
+ 878: 'typewriter keyboard',
881
+ 879: 'umbrella',
882
+ 880: 'unicycle, monocycle',
883
+ 881: 'upright, upright piano',
884
+ 882: 'vacuum, vacuum cleaner',
885
+ 883: 'vase',
886
+ 884: 'vault',
887
+ 885: 'velvet',
888
+ 886: 'vending machine',
889
+ 887: 'vestment',
890
+ 888: 'viaduct',
891
+ 889: 'violin, fiddle',
892
+ 890: 'volleyball',
893
+ 891: 'waffle iron',
894
+ 892: 'wall clock',
895
+ 893: 'wallet, billfold, notecase, pocketbook',
896
+ 894: 'wardrobe, closet, press',
897
+ 895: 'warplane, military plane',
898
+ 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
899
+ 897: 'washer, automatic washer, washing machine',
900
+ 898: 'water bottle',
901
+ 899: 'water jug',
902
+ 900: 'water tower',
903
+ 901: 'whiskey jug',
904
+ 902: 'whistle',
905
+ 903: 'wig',
906
+ 904: 'window screen',
907
+ 905: 'window shade',
908
+ 906: 'Windsor tie',
909
+ 907: 'wine bottle',
910
+ 908: 'wing',
911
+ 909: 'wok',
912
+ 910: 'wooden spoon',
913
+ 911: 'wool, woolen, woollen',
914
+ 912: 'worm fence, snake fence, snake-rail fence, Virginia fence',
915
+ 913: 'wreck',
916
+ 914: 'yawl',
917
+ 915: 'yurt',
918
+ 916: 'web site, website, internet site, site',
919
+ 917: 'comic book',
920
+ 918: 'crossword puzzle, crossword',
921
+ 919: 'street sign',
922
+ 920: 'traffic light, traffic signal, stoplight',
923
+ 921: 'book jacket, dust cover, dust jacket, dust wrapper',
924
+ 922: 'menu',
925
+ 923: 'plate',
926
+ 924: 'guacamole',
927
+ 925: 'consomme',
928
+ 926: 'hot pot, hotpot',
929
+ 927: 'trifle',
930
+ 928: 'ice cream, icecream',
931
+ 929: 'ice lolly, lolly, lollipop, popsicle',
932
+ 930: 'French loaf',
933
+ 931: 'bagel, beigel',
934
+ 932: 'pretzel',
935
+ 933: 'cheeseburger',
936
+ 934: 'hotdog, hot dog, red hot',
937
+ 935: 'mashed potato',
938
+ 936: 'head cabbage',
939
+ 937: 'broccoli',
940
+ 938: 'cauliflower',
941
+ 939: 'zucchini, courgette',
942
+ 940: 'spaghetti squash',
943
+ 941: 'acorn squash',
944
+ 942: 'butternut squash',
945
+ 943: 'cucumber, cuke',
946
+ 944: 'artichoke, globe artichoke',
947
+ 945: 'bell pepper',
948
+ 946: 'cardoon',
949
+ 947: 'mushroom',
950
+ 948: 'Granny Smith',
951
+ 949: 'strawberry',
952
+ 950: 'orange',
953
+ 951: 'lemon',
954
+ 952: 'fig',
955
+ 953: 'pineapple, ananas',
956
+ 954: 'banana',
957
+ 955: 'jackfruit, jak, jack',
958
+ 956: 'custard apple',
959
+ 957: 'pomegranate',
960
+ 958: 'hay',
961
+ 959: 'carbonara',
962
+ 960: 'chocolate sauce, chocolate syrup',
963
+ 961: 'dough',
964
+ 962: 'meat loaf, meatloaf',
965
+ 963: 'pizza, pizza pie',
966
+ 964: 'potpie',
967
+ 965: 'burrito',
968
+ 966: 'red wine',
969
+ 967: 'espresso',
970
+ 968: 'cup',
971
+ 969: 'eggnog',
972
+ 970: 'alp',
973
+ 971: 'bubble',
974
+ 972: 'cliff, drop, drop-off',
975
+ 973: 'coral reef',
976
+ 974: 'geyser',
977
+ 975: 'lakeside, lakeshore',
978
+ 976: 'promontory, headland, head, foreland',
979
+ 977: 'sandbar, sand bar',
980
+ 978: 'seashore, coast, seacoast, sea-coast',
981
+ 979: 'valley, vale',
982
+ 980: 'volcano',
983
+ 981: 'ballplayer, baseball player',
984
+ 982: 'groom, bridegroom',
985
+ 983: 'scuba diver',
986
+ 984: 'rapeseed',
987
+ 985: 'daisy',
988
+ 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
989
+ 987: 'corn',
990
+ 988: 'acorn',
991
+ 989: 'hip, rose hip, rosehip',
992
+ 990: 'buckeye, horse chestnut, conker',
993
+ 991: 'coral fungus',
994
+ 992: 'agaric',
995
+ 993: 'gyromitra',
996
+ 994: 'stinkhorn, carrion fungus',
997
+ 995: 'earthstar',
998
+ 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa',
999
+ 997: 'bolete',
1000
+ 998: 'ear, spike, capitulum',
1001
+ 999: 'toilet tissue, toilet paper, bathroom tissue'
1002
+ }
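The dictionary above is the standard 1000-class ImageNet index-to-label table. A minimal usage sketch (not part of the commit), assuming the dictionary is bound to a name such as CLS2IDX earlier in this file (the actual variable name is defined above this excerpt):

```python
import torch

# CLS2IDX is a stand-in name for the index -> label dictionary defined above.
logits = torch.randn(1, 1000)             # stand-in for a classifier's output
top5 = logits.topk(5, dim=-1).indices[0]
for idx in top5.tolist():
    print(idx, CLS2IDX[idx])              # e.g. 609 -> 'jeep, landrover'
```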
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/data/transforms.py ADDED
@@ -0,0 +1,442 @@
1
+ from __future__ import division
2
+ import sys
3
+ import random
4
+ from PIL import Image
5
+
6
+ try:
7
+ import accimage
8
+ except ImportError:
9
+ accimage = None
10
+ import numbers
11
+ import collections
12
+
13
+ from torchvision.transforms import functional as F
14
+
15
+ if sys.version_info < (3, 3):
16
+ Sequence = collections.Sequence
17
+ Iterable = collections.Iterable
18
+ else:
19
+ Sequence = collections.abc.Sequence
20
+ Iterable = collections.abc.Iterable
21
+
22
+ _pil_interpolation_to_str = {
23
+ Image.NEAREST: 'PIL.Image.NEAREST',
24
+ Image.BILINEAR: 'PIL.Image.BILINEAR',
25
+ Image.BICUBIC: 'PIL.Image.BICUBIC',
26
+ Image.LANCZOS: 'PIL.Image.LANCZOS',
27
+ Image.HAMMING: 'PIL.Image.HAMMING',
28
+ Image.BOX: 'PIL.Image.BOX',
29
+ }
30
+
31
+
32
+ class Compose(object):
33
+ """Composes several transforms together.
34
+
35
+ Args:
36
+ transforms (list of ``Transform`` objects): list of transforms to compose.
37
+
38
+ Example:
39
+ >>> transforms.Compose([
40
+ >>> transforms.CenterCrop(10),
41
+ >>> transforms.ToTensor(),
42
+ >>> ])
43
+ """
44
+
45
+ def __init__(self, transforms):
46
+ self.transforms = transforms
47
+
48
+ def __call__(self, img, tgt):
49
+ for t in self.transforms:
50
+ img, tgt = t(img, tgt)
51
+ return img, tgt
52
+
53
+ def __repr__(self):
54
+ format_string = self.__class__.__name__ + '('
55
+ for t in self.transforms:
56
+ format_string += '\n'
57
+ format_string += ' {0}'.format(t)
58
+ format_string += '\n)'
59
+ return format_string
60
+
61
+
62
+ class Resize(object):
63
+ """Resize the input PIL Image to the given size.
64
+
65
+ Args:
66
+ size (sequence or int): Desired output size. If size is a sequence like
67
+ (h, w), output size will be matched to this. If size is an int,
68
+ smaller edge of the image will be matched to this number.
69
+ i.e., if height > width, then image will be rescaled to
70
+ (size * height / width, size)
71
+ interpolation (int, optional): Desired interpolation. Default is
72
+ ``PIL.Image.BILINEAR``
73
+ """
74
+
75
+ def __init__(self, size, interpolation=Image.BILINEAR):
76
+ assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
77
+ self.size = size
78
+ self.interpolation = interpolation
79
+
80
+ def __call__(self, img, tgt):
81
+ """
82
+ Args:
83
+ img (PIL Image): Image to be scaled.
84
+
85
+ Returns:
86
+ PIL Image: Rescaled image.
87
+ """
88
+ return F.resize(img, self.size, self.interpolation), F.resize(tgt, self.size, Image.NEAREST)
89
+
90
+ def __repr__(self):
91
+ interpolate_str = _pil_interpolation_to_str[self.interpolation]
92
+ return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
93
+
94
+
95
+ class CenterCrop(object):
96
+ """Crops the given PIL Image at the center.
97
+
98
+ Args:
99
+ size (sequence or int): Desired output size of the crop. If size is an
100
+ int instead of sequence like (h, w), a square crop (size, size) is
101
+ made.
102
+ """
103
+
104
+ def __init__(self, size):
105
+ if isinstance(size, numbers.Number):
106
+ self.size = (int(size), int(size))
107
+ else:
108
+ self.size = size
109
+
110
+ def __call__(self, img, tgt):
111
+ """
112
+ Args:
113
+ img (PIL Image): Image to be cropped.
114
+
115
+ Returns:
116
+ PIL Image: Cropped image.
117
+ """
118
+ return F.center_crop(img, self.size), F.center_crop(tgt, self.size)
119
+
120
+ def __repr__(self):
121
+ return self.__class__.__name__ + '(size={0})'.format(self.size)
122
+
123
+
124
+ class RandomCrop(object):
125
+ """Crop the given PIL Image at a random location.
126
+
127
+ Args:
128
+ size (sequence or int): Desired output size of the crop. If size is an
129
+ int instead of sequence like (h, w), a square crop (size, size) is
130
+ made.
131
+ padding (int or sequence, optional): Optional padding on each border
132
+ of the image. Default is None, i.e. no padding. If a sequence of length
133
+ 4 is provided, it is used to pad left, top, right, bottom borders
134
+ respectively. If a sequence of length 2 is provided, it is used to
135
+ pad left/right, top/bottom borders, respectively.
136
+ pad_if_needed (boolean): It will pad the image if smaller than the
137
+ desired size to avoid raising an exception.
138
+ fill: Pixel fill value for constant fill. Default is 0. If a tuple of
139
+ length 3, it is used to fill R, G, B channels respectively.
140
+ This value is only used when the padding_mode is constant
141
+ padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
142
+
143
+ - constant: pads with a constant value, this value is specified with fill
144
+
145
+ - edge: pads with the last value on the edge of the image
146
+
147
+ - reflect: pads with reflection of image (without repeating the last value on the edge)
148
+
149
+ padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
150
+ will result in [3, 2, 1, 2, 3, 4, 3, 2]
151
+
152
+ - symmetric: pads with reflection of image (repeating the last value on the edge)
153
+
154
+ padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
155
+ will result in [2, 1, 1, 2, 3, 4, 4, 3]
156
+
157
+ """
158
+
159
+ def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):
160
+ if isinstance(size, numbers.Number):
161
+ self.size = (int(size), int(size))
162
+ else:
163
+ self.size = size
164
+ self.padding = padding
165
+ self.pad_if_needed = pad_if_needed
166
+ self.fill = fill
167
+ self.padding_mode = padding_mode
168
+
169
+ @staticmethod
170
+ def get_params(img, output_size):
171
+ """Get parameters for ``crop`` for a random crop.
172
+
173
+ Args:
174
+ img (PIL Image): Image to be cropped.
175
+ output_size (tuple): Expected output size of the crop.
176
+
177
+ Returns:
178
+ tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
179
+ """
180
+ w, h = img.size
181
+ th, tw = output_size
182
+ if w == tw and h == th:
183
+ return 0, 0, h, w
184
+
185
+ i = random.randint(0, h - th)
186
+ j = random.randint(0, w - tw)
187
+ return i, j, th, tw
188
+
189
+ def __call__(self, img, tgt):
190
+ """
191
+ Args:
192
+ img (PIL Image): Image to be cropped.
193
+
194
+ Returns:
195
+ PIL Image: Cropped image.
196
+ """
197
+ if self.padding is not None:
198
+ img = F.pad(img, self.padding, self.fill, self.padding_mode)
199
+ tgt = F.pad(tgt, self.padding, self.fill, self.padding_mode)
200
+
201
+ # pad the width if needed
202
+ if self.pad_if_needed and img.size[0] < self.size[1]:
203
+ img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
204
+ tgt = F.pad(tgt, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
205
+ # pad the height if needed
206
+ if self.pad_if_needed and img.size[1] < self.size[0]:
207
+ img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
208
+ tgt = F.pad(tgt, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
209
+
210
+ i, j, h, w = self.get_params(img, self.size)
211
+
212
+ return F.crop(img, i, j, h, w), F.crop(tgt, i, j, h, w)
213
+
214
+ def __repr__(self):
215
+ return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
216
+
217
+
218
+ class RandomHorizontalFlip(object):
219
+ """Horizontally flip the given PIL Image randomly with a given probability.
220
+
221
+ Args:
222
+ p (float): probability of the image being flipped. Default value is 0.5
223
+ """
224
+
225
+ def __init__(self, p=0.5):
226
+ self.p = p
227
+
228
+ def __call__(self, img, tgt):
229
+ """
230
+ Args:
231
+ img (PIL Image): Image to be flipped.
232
+
233
+ Returns:
234
+ PIL Image: Randomly flipped image.
235
+ """
236
+ if random.random() < self.p:
237
+ return F.hflip(img), F.hflip(tgt)
238
+
239
+ return img, tgt
240
+
241
+ def __repr__(self):
242
+ return self.__class__.__name__ + '(p={})'.format(self.p)
243
+
244
+
245
+ class RandomVerticalFlip(object):
246
+ """Vertically flip the given PIL Image randomly with a given probability.
247
+
248
+ Args:
249
+ p (float): probability of the image being flipped. Default value is 0.5
250
+ """
251
+
252
+ def __init__(self, p=0.5):
253
+ self.p = p
254
+
255
+ def __call__(self, img, tgt):
256
+ """
257
+ Args:
258
+ img (PIL Image): Image to be flipped.
259
+
260
+ Returns:
261
+ PIL Image: Randomly flipped image.
262
+ """
263
+ if random.random() < self.p:
264
+ return F.vflip(img), F.vflip(tgt)
265
+ return img, tgt
266
+
267
+ def __repr__(self):
268
+ return self.__class__.__name__ + '(p={})'.format(self.p)
269
+
270
+
271
+ class Lambda(object):
272
+ """Apply a user-defined lambda as a transform.
273
+
274
+ Args:
275
+ lambd (function): Lambda/function to be used for transform.
276
+ """
277
+
278
+ def __init__(self, lambd):
279
+ assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
280
+ self.lambd = lambd
281
+
282
+ def __call__(self, img, tgt):
283
+ return self.lambd(img, tgt)
284
+
285
+ def __repr__(self):
286
+ return self.__class__.__name__ + '()'
287
+
288
+
289
+ class ColorJitter(object):
290
+ """Randomly change the brightness, contrast and saturation of an image.
291
+
292
+ Args:
293
+ brightness (float or tuple of float (min, max)): How much to jitter brightness.
294
+ brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
295
+ or the given [min, max]. Should be non negative numbers.
296
+ contrast (float or tuple of float (min, max)): How much to jitter contrast.
297
+ contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
298
+ or the given [min, max]. Should be non negative numbers.
299
+ saturation (float or tuple of float (min, max)): How much to jitter saturation.
300
+ saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
301
+ or the given [min, max]. Should be non negative numbers.
302
+ hue (float or tuple of float (min, max)): How much to jitter hue.
303
+ hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
304
+ Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
305
+ """
306
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
307
+ self.brightness = self._check_input(brightness, 'brightness')
308
+ self.contrast = self._check_input(contrast, 'contrast')
309
+ self.saturation = self._check_input(saturation, 'saturation')
310
+ self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
311
+ clip_first_on_zero=False)
312
+
313
+ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
314
+ if isinstance(value, numbers.Number):
315
+ if value < 0:
316
+ raise ValueError("If {} is a single number, it must be non negative.".format(name))
317
+ value = [center - value, center + value]
318
+ if clip_first_on_zero:
319
+ value[0] = max(value[0], 0)
320
+ elif isinstance(value, (tuple, list)) and len(value) == 2:
321
+ if not bound[0] <= value[0] <= value[1] <= bound[1]:
322
+ raise ValueError("{} values should be between {}".format(name, bound))
323
+ else:
324
+ raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))
325
+
326
+ # if value is 0 or (1., 1.) for brightness/contrast/saturation
327
+ # or (0., 0.) for hue, do nothing
328
+ if value[0] == value[1] == center:
329
+ value = None
330
+ return value
331
+
332
+ @staticmethod
333
+ def get_params(brightness, contrast, saturation, hue):
334
+ """Get a randomized transform to be applied on image.
335
+
336
+ Arguments are same as that of __init__.
337
+
338
+ Returns:
339
+ Transform which randomly adjusts brightness, contrast and
340
+ saturation in a random order.
341
+ """
342
+ transforms = []
343
+
344
+ if brightness is not None:
345
+ brightness_factor = random.uniform(brightness[0], brightness[1])
346
+ transforms.append(Lambda(lambda img, tgt: (F.adjust_brightness(img, brightness_factor), tgt)))
347
+
348
+ if contrast is not None:
349
+ contrast_factor = random.uniform(contrast[0], contrast[1])
350
+ transforms.append(Lambda(lambda img, tgt: (F.adjust_contrast(img, contrast_factor), tgt)))
351
+
352
+ if saturation is not None:
353
+ saturation_factor = random.uniform(saturation[0], saturation[1])
354
+ transforms.append(Lambda(lambda img, tgt: (F.adjust_saturation(img, saturation_factor), tgt)))
355
+
356
+ if hue is not None:
357
+ hue_factor = random.uniform(hue[0], hue[1])
358
+ transforms.append(Lambda(lambda img, tgt: (F.adjust_hue(img, hue_factor), tgt)))
359
+
360
+ random.shuffle(transforms)
361
+ transform = Compose(transforms)
362
+
363
+ return transform
364
+
365
+ def __call__(self, img, tgt):
366
+ """
367
+ Args:
368
+ img (PIL Image): Input image.
369
+
370
+ Returns:
371
+ PIL Image: Color jittered image.
372
+ """
373
+ transform = self.get_params(self.brightness, self.contrast,
374
+ self.saturation, self.hue)
375
+ return transform(img, tgt)
376
+
377
+ def __repr__(self):
378
+ format_string = self.__class__.__name__ + '('
379
+ format_string += 'brightness={0}'.format(self.brightness)
380
+ format_string += ', contrast={0}'.format(self.contrast)
381
+ format_string += ', saturation={0}'.format(self.saturation)
382
+ format_string += ', hue={0})'.format(self.hue)
383
+ return format_string
384
+
385
+
386
+ class Normalize(object):
387
+ """Normalize a tensor image with mean and standard deviation.
388
+ Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
389
+ will normalize each channel of the input ``torch.*Tensor`` i.e.
390
+ ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
391
+
392
+ .. note::
393
+ This transform acts out of place, i.e., it does not mutate the input tensor.
394
+
395
+ Args:
396
+ mean (sequence): Sequence of means for each channel.
397
+ std (sequence): Sequence of standard deviations for each channel.
398
+ """
399
+
400
+ def __init__(self, mean, std, inplace=False):
401
+ self.mean = mean
402
+ self.std = std
403
+ self.inplace = inplace
404
+
405
+ def __call__(self, img, tgt):
406
+ """
407
+ Args:
408
+ tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
409
+
410
+ Returns:
411
+ Tensor: Normalized Tensor image.
412
+ """
413
+ # return F.normalize(img, self.mean, self.std, self.inplace), tgt
414
+ return F.normalize(img, self.mean, self.std), tgt
415
+
416
+ def __repr__(self):
417
+ return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
418
+
419
+
420
+ class ToTensor(object):
421
+ """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
422
+
423
+ Converts a PIL Image or numpy.ndarray (H x W x C) in the range
424
+ [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
425
+ if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
426
+ or if the numpy.ndarray has dtype = np.uint8
427
+
428
+ In the other cases, tensors are returned without scaling.
429
+ """
430
+
431
+ def __call__(self, img, tgt):
432
+ """
433
+ Args:
434
+ pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
435
+
436
+ Returns:
437
+ Tensor: Converted image.
438
+ """
439
+ return F.to_tensor(img), tgt
440
+
441
+ def __repr__(self):
442
+ return self.__class__.__name__ + '()'
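A minimal usage sketch for the paired transforms above (not part of the committed file; the import path is assumed). Every op takes and returns an (image, target) pair, so geometric changes stay aligned between the image and its segmentation mask while photometric ops touch only the image:

```python
from PIL import Image

import transforms as joint_transforms  # assumed import path for the module above

pipeline = joint_transforms.Compose([
    joint_transforms.Resize(256),
    joint_transforms.CenterCrop(224),
    joint_transforms.RandomHorizontalFlip(p=0.5),
    joint_transforms.ToTensor(),
    joint_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

img = Image.new('RGB', (320, 480))  # stand-ins for a real image / mask pair
tgt = Image.new('L', (320, 480))
img_t, tgt_t = pipeline(img, tgt)   # img_t: normalized tensor; tgt_t: cropped/flipped PIL mask
```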
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/generate_visualizations.py ADDED
@@ -0,0 +1,208 @@
1
+ import os
2
+ from tqdm import tqdm
3
+ import h5py
4
+
5
+ import argparse
6
+
7
+ # Import saliency methods and models
8
+ from misc_functions import *
9
+
10
+ from ViT_explanation_generator import Baselines, LRP
11
+ from ViT_new import vit_base_patch16_224
12
+ from ViT_LRP import vit_base_patch16_224 as vit_LRP
13
+ from ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP
14
+
15
+ from torchvision.datasets import ImageNet
16
+
17
+
18
+ def normalize(tensor,
19
+ mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
20
+ dtype = tensor.dtype
21
+ mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
22
+ std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
23
+ tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
24
+ return tensor
25
+
26
+
27
+ def compute_saliency_and_save(args):
28
+ first = True
29
+ with h5py.File(os.path.join(args.method_dir, 'results.hdf5'), 'a') as f:
30
+ data_cam = f.create_dataset('vis',
31
+ (1, 1, 224, 224),
32
+ maxshape=(None, 1, 224, 224),
33
+ dtype=np.float32,
34
+ compression="gzip")
35
+ data_image = f.create_dataset('image',
36
+ (1, 3, 224, 224),
37
+ maxshape=(None, 3, 224, 224),
38
+ dtype=np.float32,
39
+ compression="gzip")
40
+ data_target = f.create_dataset('target',
41
+ (1,),
42
+ maxshape=(None,),
43
+ dtype=np.int32,
44
+ compression="gzip")
45
+ for batch_idx, (data, target) in enumerate(tqdm(sample_loader)):
46
+ if first:
47
+ first = False
48
+ data_cam.resize(data_cam.shape[0] + data.shape[0] - 1, axis=0)
49
+ data_image.resize(data_image.shape[0] + data.shape[0] - 1, axis=0)
50
+ data_target.resize(data_target.shape[0] + data.shape[0] - 1, axis=0)
51
+ else:
52
+ data_cam.resize(data_cam.shape[0] + data.shape[0], axis=0)
53
+ data_image.resize(data_image.shape[0] + data.shape[0], axis=0)
54
+ data_target.resize(data_target.shape[0] + data.shape[0], axis=0)
55
+
56
+ # Add data
57
+ data_image[-data.shape[0]:] = data.data.cpu().numpy()
58
+ data_target[-data.shape[0]:] = target.data.cpu().numpy()
59
+
60
+ target = target.to(device)
61
+
62
+ data = normalize(data)
63
+ data = data.to(device)
64
+ data.requires_grad_()
65
+
66
+ index = None
67
+ if args.vis_class == 'target':
68
+ index = target
69
+
70
+ if args.method == 'rollout':
71
+ Res = baselines.generate_rollout(data, start_layer=1).reshape(data.shape[0], 1, 14, 14)
72
+ # Res = Res - Res.mean()
73
+
74
+ elif args.method == 'lrp':
75
+ Res = lrp.generate_LRP(data, start_layer=1, index=index).reshape(data.shape[0], 1, 14, 14)
76
+ # Res = Res - Res.mean()
77
+
78
+ elif args.method == 'transformer_attribution':
79
+ Res = lrp.generate_LRP(data, start_layer=1, method="grad", index=index).reshape(data.shape[0], 1, 14, 14)
80
+ # Res = Res - Res.mean()
81
+
82
+ elif args.method == 'full_lrp':
83
+ Res = orig_lrp.generate_LRP(data, method="full", index=index).reshape(data.shape[0], 1, 224, 224)
84
+ # Res = Res - Res.mean()
85
+
86
+ elif args.method == 'lrp_last_layer':
87
+ Res = orig_lrp.generate_LRP(data, method="last_layer", is_ablation=args.is_ablation, index=index) \
88
+ .reshape(data.shape[0], 1, 14, 14)
89
+ # Res = Res - Res.mean()
90
+
91
+ elif args.method == 'attn_last_layer':
92
+ Res = lrp.generate_LRP(data, method="last_layer_attn", is_ablation=args.is_ablation) \
93
+ .reshape(data.shape[0], 1, 14, 14)
94
+
95
+ elif args.method == 'attn_gradcam':
96
+ Res = baselines.generate_cam_attn(data, index=index).reshape(data.shape[0], 1, 14, 14)
97
+
98
+ if args.method != 'full_lrp' and args.method != 'input_grads':
99
+ Res = torch.nn.functional.interpolate(Res, scale_factor=16, mode='bilinear').cuda()
100
+ Res = (Res - Res.min()) / (Res.max() - Res.min())
101
+
102
+ data_cam[-data.shape[0]:] = Res.data.cpu().numpy()
103
+
104
+
105
+ if __name__ == "__main__":
106
+ parser = argparse.ArgumentParser(description='Generate ViT attribution visualizations')
107
+ parser.add_argument('--batch-size', type=int,
108
+ default=1,
109
+ help='')
110
+ parser.add_argument('--method', type=str,
111
+ default='grad_rollout',
112
+ choices=['rollout', 'lrp', 'transformer_attribution', 'full_lrp', 'lrp_last_layer',
113
+ 'attn_last_layer', 'attn_gradcam'],
114
+ help='')
115
+ parser.add_argument('--lmd', type=float,
116
+ default=10,
117
+ help='')
118
+ parser.add_argument('--vis-class', type=str,
119
+ default='top',
120
+ choices=['top', 'target', 'index'],
121
+ help='')
122
+ parser.add_argument('--class-id', type=int,
123
+ default=0,
124
+ help='')
125
+ parser.add_argument('--cls-agn', action='store_true',
126
+ default=False,
127
+ help='')
128
+ parser.add_argument('--no-ia', action='store_true',
129
+ default=False,
130
+ help='')
131
+ parser.add_argument('--no-fx', action='store_true',
132
+ default=False,
133
+ help='')
134
+ parser.add_argument('--no-fgx', action='store_true',
135
+ default=False,
136
+ help='')
137
+ parser.add_argument('--no-m', action='store_true',
138
+ default=False,
139
+ help='')
140
+ parser.add_argument('--no-reg', action='store_true',
141
+ default=False,
142
+ help='')
143
+ parser.add_argument('--is-ablation', type=bool,
144
+ default=False,
145
+ help='')
146
+ parser.add_argument('--imagenet-validation-path', type=str,
147
+ required=True,
148
+ help='')
149
+ args = parser.parse_args()
150
+
151
+ # PATH variables
152
+ PATH = os.path.dirname(os.path.abspath(__file__)) + '/'
153
+ os.makedirs(os.path.join(PATH, 'visualizations'), exist_ok=True)
154
+
155
+ try:
156
+ os.remove(os.path.join(PATH, 'visualizations/{}/{}/results.hdf5'.format(args.method,
157
+ args.vis_class)))
158
+ except OSError:
159
+ pass
160
+
161
+
162
+ os.makedirs(os.path.join(PATH, 'visualizations/{}'.format(args.method)), exist_ok=True)
163
+ if args.vis_class == 'index':
164
+ os.makedirs(os.path.join(PATH, 'visualizations/{}/{}_{}'.format(args.method,
165
+ args.vis_class,
166
+ args.class_id)), exist_ok=True)
167
+ args.method_dir = os.path.join(PATH, 'visualizations/{}/{}_{}'.format(args.method,
168
+ args.vis_class,
169
+ args.class_id))
170
+ else:
171
+ ablation_fold = 'ablation' if args.is_ablation else 'not_ablation'
172
+ os.makedirs(os.path.join(PATH, 'visualizations/{}/{}/{}'.format(args.method,
173
+ args.vis_class, ablation_fold)), exist_ok=True)
174
+ args.method_dir = os.path.join(PATH, 'visualizations/{}/{}/{}'.format(args.method,
175
+ args.vis_class, ablation_fold))
176
+
177
+ cuda = torch.cuda.is_available()
178
+ device = torch.device("cuda" if cuda else "cpu")
179
+
180
+ # Model
181
+ model = vit_base_patch16_224(pretrained=True).cuda()
182
+ baselines = Baselines(model)
183
+
184
+ # LRP
185
+ model_LRP = vit_LRP(pretrained=True).cuda()
186
+ model_LRP.eval()
187
+ lrp = LRP(model_LRP)
188
+
189
+ # orig LRP
190
+ model_orig_LRP = vit_orig_LRP(pretrained=True).cuda()
191
+ model_orig_LRP.eval()
192
+ orig_lrp = LRP(model_orig_LRP)
193
+
194
+ # Dataset loader for sample images
195
+ transform = transforms.Compose([
196
+ transforms.Resize((224, 224)),
197
+ transforms.ToTensor(),
198
+ ])
199
+
200
+ imagenet_ds = ImageNet(args.imagenet_validation_path, split='val', download=False, transform=transform)
201
+ sample_loader = torch.utils.data.DataLoader(
202
+ imagenet_ds,
203
+ batch_size=args.batch_size,
204
+ shuffle=False,
205
+ num_workers=4
206
+ )
207
+
208
+ compute_saliency_and_save(args)
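An example invocation of the script above (paths are illustrative). With these flags it resolves `method_dir` to `visualizations/transformer_attribution/target/not_ablation/` and writes the saliency maps, input images, and targets into `results.hdf5` there:

```
python generate_visualizations.py \
    --method transformer_attribution \
    --vis-class target \
    --imagenet-validation-path /path/to/imagenet \
    --batch-size 1
```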
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/helpers.py ADDED
@@ -0,0 +1,295 @@
1
+ """ Model creation / weight loading / state_dict helpers
2
+
3
+ Hacked together by / Copyright 2020 Ross Wightman
4
+ """
5
+ import logging
6
+ import os
7
+ import math
8
+ from collections import OrderedDict
9
+ from copy import deepcopy
10
+ from typing import Callable
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.utils.model_zoo as model_zoo
15
+
16
+ _logger = logging.getLogger(__name__)
17
+
18
+
19
+ def load_state_dict(checkpoint_path, use_ema=False):
20
+ if checkpoint_path and os.path.isfile(checkpoint_path):
21
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
22
+ state_dict_key = 'state_dict'
23
+ if isinstance(checkpoint, dict):
24
+ if use_ema and 'state_dict_ema' in checkpoint:
25
+ state_dict_key = 'state_dict_ema'
26
+ if state_dict_key and state_dict_key in checkpoint:
27
+ new_state_dict = OrderedDict()
28
+ for k, v in checkpoint[state_dict_key].items():
29
+ # strip `module.` prefix
30
+ name = k[7:] if k.startswith('module') else k
31
+ new_state_dict[name] = v
32
+ state_dict = new_state_dict
33
+ else:
34
+ state_dict = checkpoint
35
+ _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
36
+ return state_dict
37
+ else:
38
+ _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
39
+ raise FileNotFoundError()
40
+
41
+
42
+ def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
43
+ state_dict = load_state_dict(checkpoint_path, use_ema)
44
+ model.load_state_dict(state_dict, strict=strict)
45
+
46
+
47
+ def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
48
+ resume_epoch = None
49
+ if os.path.isfile(checkpoint_path):
50
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
51
+ if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
52
+ if log_info:
53
+ _logger.info('Restoring model state from checkpoint...')
54
+ new_state_dict = OrderedDict()
55
+ for k, v in checkpoint['state_dict'].items():
56
+ name = k[7:] if k.startswith('module') else k
57
+ new_state_dict[name] = v
58
+ model.load_state_dict(new_state_dict)
59
+
60
+ if optimizer is not None and 'optimizer' in checkpoint:
61
+ if log_info:
62
+ _logger.info('Restoring optimizer state from checkpoint...')
63
+ optimizer.load_state_dict(checkpoint['optimizer'])
64
+
65
+ if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
66
+ if log_info:
67
+ _logger.info('Restoring AMP loss scaler state from checkpoint...')
68
+ loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
69
+
70
+ if 'epoch' in checkpoint:
71
+ resume_epoch = checkpoint['epoch']
72
+ if 'version' in checkpoint and checkpoint['version'] > 1:
73
+ resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
74
+
75
+ if log_info:
76
+ _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
77
+ else:
78
+ model.load_state_dict(checkpoint)
79
+ if log_info:
80
+ _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
81
+ return resume_epoch
82
+ else:
83
+ _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
84
+ raise FileNotFoundError()
85
+
86
+
87
+ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True):
88
+ if cfg is None:
89
+ cfg = getattr(model, 'default_cfg')
90
+ if cfg is None or 'url' not in cfg or not cfg['url']:
91
+ _logger.warning("Pretrained model URL is invalid, using random initialization.")
92
+ return
93
+
94
+ state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu')
95
+
96
+ if filter_fn is not None:
97
+ state_dict = filter_fn(state_dict)
98
+
99
+ if in_chans == 1:
100
+ conv1_name = cfg['first_conv']
101
+ _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name)
102
+ conv1_weight = state_dict[conv1_name + '.weight']
103
+ # Some weights are in torch.half, ensure it's float for sum on CPU
104
+ conv1_type = conv1_weight.dtype
105
+ conv1_weight = conv1_weight.float()
106
+ O, I, J, K = conv1_weight.shape
107
+ if I > 3:
108
+ assert conv1_weight.shape[1] % 3 == 0
109
+ # For models with space2depth stems
110
+ conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
111
+ conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
112
+ else:
113
+ conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
114
+ conv1_weight = conv1_weight.to(conv1_type)
115
+ state_dict[conv1_name + '.weight'] = conv1_weight
116
+ elif in_chans != 3:
117
+ conv1_name = cfg['first_conv']
118
+ conv1_weight = state_dict[conv1_name + '.weight']
119
+ conv1_type = conv1_weight.dtype
120
+ conv1_weight = conv1_weight.float()
121
+ O, I, J, K = conv1_weight.shape
122
+ if I != 3:
123
+ _logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name)
124
+ del state_dict[conv1_name + '.weight']
125
+ strict = False
126
+ else:
127
+ # NOTE this strategy should be better than random init, but there could be other combinations of
128
+ # the original RGB input layer weights that'd work better for specific cases.
129
+ _logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name)
130
+ repeat = int(math.ceil(in_chans / 3))
131
+ conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
132
+ conv1_weight *= (3 / float(in_chans))
133
+ conv1_weight = conv1_weight.to(conv1_type)
134
+ state_dict[conv1_name + '.weight'] = conv1_weight
135
+
136
+ classifier_name = cfg['classifier']
137
+ if num_classes == 1000 and cfg['num_classes'] == 1001:
138
+ # special case for imagenet trained models with extra background class in pretrained weights
139
+ classifier_weight = state_dict[classifier_name + '.weight']
140
+ state_dict[classifier_name + '.weight'] = classifier_weight[1:]
141
+ classifier_bias = state_dict[classifier_name + '.bias']
142
+ state_dict[classifier_name + '.bias'] = classifier_bias[1:]
143
+ elif num_classes != cfg['num_classes']:
144
+ # completely discard fully connected for all other differences between pretrained and created model
145
+ del state_dict[classifier_name + '.weight']
146
+ del state_dict[classifier_name + '.bias']
147
+ strict = False
148
+
149
+ model.load_state_dict(state_dict, strict=strict)
150
+
151
+
152
+ def extract_layer(model, layer):
153
+ layer = layer.split('.')
154
+ module = model
155
+ if hasattr(model, 'module') and layer[0] != 'module':
156
+ module = model.module
157
+ if not hasattr(model, 'module') and layer[0] == 'module':
158
+ layer = layer[1:]
159
+ for l in layer:
160
+ if hasattr(module, l):
161
+ if not l.isdigit():
162
+ module = getattr(module, l)
163
+ else:
164
+ module = module[int(l)]
165
+ else:
166
+ return module
167
+ return module
168
+
169
+
170
+ def set_layer(model, layer, val):
171
+ layer = layer.split('.')
172
+ module = model
173
+ if hasattr(model, 'module') and layer[0] != 'module':
174
+ module = model.module
175
+ lst_index = 0
176
+ module2 = module
177
+ for l in layer:
178
+ if hasattr(module2, l):
179
+ if not l.isdigit():
180
+ module2 = getattr(module2, l)
181
+ else:
182
+ module2 = module2[int(l)]
183
+ lst_index += 1
184
+ lst_index -= 1
185
+ for l in layer[:lst_index]:
186
+ if not l.isdigit():
187
+ module = getattr(module, l)
188
+ else:
189
+ module = module[int(l)]
190
+ l = layer[lst_index]
191
+ setattr(module, l, val)
192
+
193
+
194
+ def adapt_model_from_string(parent_module, model_string):
195
+ separator = '***'
196
+ state_dict = {}
197
+ lst_shape = model_string.split(separator)
198
+ for k in lst_shape:
199
+ k = k.split(':')
200
+ key = k[0]
201
+ shape = k[1][1:-1].split(',')
202
+ if shape[0] != '':
203
+ state_dict[key] = [int(i) for i in shape]
204
+
205
+ new_module = deepcopy(parent_module)
206
+ for n, m in parent_module.named_modules():
207
+ old_module = extract_layer(parent_module, n)
208
+ if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
209
+ if isinstance(old_module, Conv2dSame):
210
+ conv = Conv2dSame
211
+ else:
212
+ conv = nn.Conv2d
213
+ s = state_dict[n + '.weight']
214
+ in_channels = s[1]
215
+ out_channels = s[0]
216
+ g = 1
217
+ if old_module.groups > 1:
218
+ in_channels = out_channels
219
+ g = in_channels
220
+ new_conv = conv(
221
+ in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
222
+ bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
223
+ groups=g, stride=old_module.stride)
224
+ set_layer(new_module, n, new_conv)
225
+ if isinstance(old_module, nn.BatchNorm2d):
226
+ new_bn = nn.BatchNorm2d(
227
+ num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
228
+ affine=old_module.affine, track_running_stats=True)
229
+ set_layer(new_module, n, new_bn)
230
+ if isinstance(old_module, nn.Linear):
231
+ # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
232
+ num_features = state_dict[n + '.weight'][1]
233
+ new_fc = nn.Linear(
234
+ in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
235
+ set_layer(new_module, n, new_fc)
236
+ if hasattr(new_module, 'num_features'):
237
+ new_module.num_features = num_features
238
+ new_module.eval()
239
+ parent_module.eval()
240
+
241
+ return new_module
242
+
243
+
244
+ def adapt_model_from_file(parent_module, model_variant):
245
+ adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt')
246
+ with open(adapt_file, 'r') as f:
247
+ return adapt_model_from_string(parent_module, f.read().strip())
248
+
249
+
250
+ def build_model_with_cfg(
251
+ model_cls: Callable,
252
+ variant: str,
253
+ pretrained: bool,
254
+ default_cfg: dict,
255
+ model_cfg: dict = None,
256
+ feature_cfg: dict = None,
257
+ pretrained_strict: bool = True,
258
+ pretrained_filter_fn: Callable = None,
259
+ **kwargs):
260
+ pruned = kwargs.pop('pruned', False)
261
+ features = False
262
+ feature_cfg = feature_cfg or {}
263
+
264
+ if kwargs.pop('features_only', False):
265
+ features = True
266
+ feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
267
+ if 'out_indices' in kwargs:
268
+ feature_cfg['out_indices'] = kwargs.pop('out_indices')
269
+
270
+ model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
271
+ model.default_cfg = deepcopy(default_cfg)
272
+
273
+ if pruned:
274
+ model = adapt_model_from_file(model, variant)
275
+
276
+ if pretrained:
277
+ load_pretrained(
278
+ model,
279
+ num_classes=kwargs.get('num_classes', 0),
280
+ in_chans=kwargs.get('in_chans', 3),
281
+ filter_fn=pretrained_filter_fn, strict=pretrained_strict)
282
+
283
+ if features:
284
+ feature_cls = FeatureListNet
285
+ if 'feature_cls' in feature_cfg:
286
+ feature_cls = feature_cfg.pop('feature_cls')
287
+ if isinstance(feature_cls, str):
288
+ feature_cls = feature_cls.lower()
289
+ if 'hook' in feature_cls:
290
+ feature_cls = FeatureHookNet
291
+ else:
292
+ assert False, f'Unknown feature class {feature_cls}'
293
+ model = feature_cls(model, **feature_cfg)
294
+
295
+ return model
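A hypothetical sketch of how `build_model_with_cfg` is driven; every name below is made up for illustration. With `pretrained=False` the `load_pretrained()` branch, and therefore the `url` / `first_conv` / `classifier` entries of the config, are never exercised:

```python
import torch.nn as nn

from helpers import build_model_with_cfg  # assumed import path for the module above

_cfg = dict(url='', num_classes=10, first_conv='conv', classifier='fc')


class TinyNet(nn.Module):
    def __init__(self, num_classes=10, in_chans=3):
        super().__init__()
        self.conv = nn.Conv2d(in_chans, 8, kernel_size=3, padding=1)
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        x = self.conv(x).mean(dim=(2, 3))  # global average pool
        return self.fc(x)


# pretrained=False skips load_pretrained(), so no checkpoint URL is required.
model = build_model_with_cfg(TinyNet, 'tiny_net', pretrained=False,
                             default_cfg=_cfg, num_classes=10)
```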
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/layer_helpers.py ADDED
@@ -0,0 +1,21 @@
+ """ Layer/Module Helpers
+ Hacked together by / Copyright 2020 Ross Wightman
+ """
+ from itertools import repeat
+ import collections.abc
+
+
+ # From PyTorch internals
+ def _ntuple(n):
+     def parse(x):
+         if isinstance(x, collections.abc.Iterable):
+             return x
+         return tuple(repeat(x, n))
+     return parse
+
+
+ to_1tuple = _ntuple(1)
+ to_2tuple = _ntuple(2)
+ to_3tuple = _ntuple(3)
+ to_4tuple = _ntuple(4)
+ to_ntuple = _ntuple
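For reference, a quick illustration of what the `_ntuple` helpers above return (import path assumed):

```python
from layer_helpers import to_2tuple  # assumed import path for the module above

print(to_2tuple(7))       # (7, 7)  -- a scalar is repeated n times
print(to_2tuple((3, 5)))  # (3, 5)  -- iterables pass through unchanged
```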
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/misc_functions.py ADDED
@@ -0,0 +1,68 @@
+ #
+ # Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/
+ # Written by Suraj Srinivas <[email protected]>
+ #
+
+ """ Misc helper functions """
+
+ import cv2
+ import numpy as np
+ import subprocess
+
+ import torch
+ import torchvision.transforms as transforms
+
+
+ class NormalizeInverse(transforms.Normalize):
+     # Undo normalization on images
+
+     def __init__(self, mean, std):
+         mean = torch.as_tensor(mean)
+         std = torch.as_tensor(std)
+         std_inv = 1 / (std + 1e-7)
+         mean_inv = -mean * std_inv
+         super(NormalizeInverse, self).__init__(mean=mean_inv, std=std_inv)
+
+     def __call__(self, tensor):
+         return super(NormalizeInverse, self).__call__(tensor.clone())
+
+
+ def create_folder(folder_name):
+     try:
+         subprocess.call(['mkdir', '-p', folder_name])
+     except OSError:
+         None
+
+
+ def save_saliency_map(image, saliency_map, filename):
+     """
+     Save saliency map on image.
+
+     Args:
+         image: Tensor of size (3,H,W)
+         saliency_map: Tensor of size (1,H,W)
+         filename: string with complete path and file extension
+
+     """
+
+     image = image.data.cpu().numpy()
+     saliency_map = saliency_map.data.cpu().numpy()
+
+     saliency_map = saliency_map - saliency_map.min()
+     saliency_map = saliency_map / saliency_map.max()
+     saliency_map = saliency_map.clip(0, 1)
+
+     saliency_map = np.uint8(saliency_map * 255).transpose(1, 2, 0)
+     saliency_map = cv2.resize(saliency_map, (224, 224))
+
+     image = np.uint8(image * 255).transpose(1, 2, 0)
+     image = cv2.resize(image, (224, 224))
+
+     # Apply JET colormap
+     color_heatmap = cv2.applyColorMap(saliency_map, cv2.COLORMAP_JET)
+
+     # Combine image with heatmap
+     img_with_heatmap = np.float32(color_heatmap) + np.float32(image)
+     img_with_heatmap = img_with_heatmap / np.max(img_with_heatmap)
+
+     cv2.imwrite(filename, np.uint8(255 * img_with_heatmap))
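A small self-contained sketch (import path assumed) of the intended round trip: `NormalizeInverse` undoes the training-time normalization so that `save_saliency_map` can overlay an attribution map on an image with natural colors:

```python
import torch
import torchvision.transforms as transforms

from misc_functions import NormalizeInverse, save_saliency_map  # assumed import path

norm = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
denorm = NormalizeInverse(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

img = torch.rand(3, 224, 224)        # stand-in for a preprocessed input image
restored = denorm(norm(img))         # recovers img (up to the 1e-7 epsilon)

saliency = torch.rand(1, 224, 224)   # stand-in for a real attribution map
save_saliency_map(restored, saliency, 'saliency_overlay.png')
```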
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/modules/__init__.py ADDED
File without changes
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/modules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (223 Bytes). View file
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/modules/__pycache__/layers_lrp.cpython-310.pyc ADDED
Binary file (9.31 kB). View file
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/modules/__pycache__/layers_ours.cpython-310.pyc ADDED
Binary file (9.75 kB).
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/modules/layers_lrp.py ADDED
@@ -0,0 +1,261 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ __all__ = ['forward_hook', 'Clone', 'Add', 'Cat', 'ReLU', 'GELU', 'Dropout', 'BatchNorm2d', 'Linear', 'MaxPool2d',
6
+ 'AdaptiveAvgPool2d', 'AvgPool2d', 'Conv2d', 'Sequential', 'safe_divide', 'einsum', 'Softmax', 'IndexSelect',
7
+ 'LayerNorm', 'AddEye']
8
+
9
+
10
+ def safe_divide(a, b):
11
+ den = b.clamp(min=1e-9) + b.clamp(max=1e-9)
12
+ den = den + den.eq(0).type(den.type()) * 1e-9
13
+ return a / den * b.ne(0).type(b.type())
14
+
15
+
16
+ def forward_hook(self, input, output):
17
+ if type(input[0]) in (list, tuple):
18
+ self.X = []
19
+ for i in input[0]:
20
+ x = i.detach()
21
+ x.requires_grad = True
22
+ self.X.append(x)
23
+ else:
24
+ self.X = input[0].detach()
25
+ self.X.requires_grad = True
26
+
27
+ self.Y = output
28
+
29
+
30
+ def backward_hook(self, grad_input, grad_output):
31
+ self.grad_input = grad_input
32
+ self.grad_output = grad_output
33
+
34
+
35
+ class RelProp(nn.Module):
36
+ def __init__(self):
37
+ super(RelProp, self).__init__()
38
+ # if not self.training:
39
+ self.register_forward_hook(forward_hook)
40
+
41
+ def gradprop(self, Z, X, S):
42
+ C = torch.autograd.grad(Z, X, S, retain_graph=True)
43
+ return C
44
+
45
+ def relprop(self, R, alpha):
46
+ return R
47
+
48
+
49
+ class RelPropSimple(RelProp):
50
+ def relprop(self, R, alpha):
51
+ Z = self.forward(self.X)
52
+ S = safe_divide(R, Z)
53
+ C = self.gradprop(Z, self.X, S)
54
+
55
+ if torch.is_tensor(self.X) == False:
56
+ outputs = []
57
+ outputs.append(self.X[0] * C[0])
58
+ outputs.append(self.X[1] * C[1])
59
+ else:
60
+ outputs = self.X * (C[0])
61
+ return outputs
62
+
63
+ class AddEye(RelPropSimple):
64
+ # input of shape B, C, seq_len, seq_len
65
+ def forward(self, input):
66
+ return input + torch.eye(input.shape[2]).expand_as(input).to(input.device)
67
+
68
+ class ReLU(nn.ReLU, RelProp):
69
+ pass
70
+
71
+ class GELU(nn.GELU, RelProp):
72
+ pass
73
+
74
+ class Softmax(nn.Softmax, RelProp):
75
+ pass
76
+
77
+ class LayerNorm(nn.LayerNorm, RelProp):
78
+ pass
79
+
80
+ class Dropout(nn.Dropout, RelProp):
81
+ pass
82
+
83
+
84
+ class MaxPool2d(nn.MaxPool2d, RelPropSimple):
85
+ pass
86
+
87
+ class LayerNorm(nn.LayerNorm, RelProp):
88
+ pass
89
+
90
+ class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple):
91
+ pass
92
+
93
+
94
+ class AvgPool2d(nn.AvgPool2d, RelPropSimple):
95
+ pass
96
+
97
+
98
+ class Add(RelPropSimple):
99
+ def forward(self, inputs):
100
+ return torch.add(*inputs)
101
+
102
+ class einsum(RelPropSimple):
103
+ def __init__(self, equation):
104
+ super().__init__()
105
+ self.equation = equation
106
+ def forward(self, *operands):
107
+ return torch.einsum(self.equation, *operands)
108
+
109
+ class IndexSelect(RelProp):
110
+ def forward(self, inputs, dim, indices):
111
+ self.__setattr__('dim', dim)
112
+ self.__setattr__('indices', indices)
113
+
114
+ return torch.index_select(inputs, dim, indices)
115
+
116
+ def relprop(self, R, alpha):
117
+ Z = self.forward(self.X, self.dim, self.indices)
118
+ S = safe_divide(R, Z)
119
+ C = self.gradprop(Z, self.X, S)
120
+
121
+ if torch.is_tensor(self.X) == False:
122
+ outputs = []
123
+ outputs.append(self.X[0] * C[0])
124
+ outputs.append(self.X[1] * C[1])
125
+ else:
126
+ outputs = self.X * (C[0])
127
+ return outputs
128
+
129
+
130
+
131
+ class Clone(RelProp):
132
+ def forward(self, input, num):
133
+ self.__setattr__('num', num)
134
+ outputs = []
135
+ for _ in range(num):
136
+ outputs.append(input)
137
+
138
+ return outputs
139
+
140
+ def relprop(self, R, alpha):
141
+ Z = []
142
+ for _ in range(self.num):
143
+ Z.append(self.X)
144
+ S = [safe_divide(r, z) for r, z in zip(R, Z)]
145
+ C = self.gradprop(Z, self.X, S)[0]
146
+
147
+ R = self.X * C
148
+
149
+ return R
150
+
151
+ class Cat(RelProp):
152
+ def forward(self, inputs, dim):
153
+ self.__setattr__('dim', dim)
154
+ return torch.cat(inputs, dim)
155
+
156
+ def relprop(self, R, alpha):
157
+ Z = self.forward(self.X, self.dim)
158
+ S = safe_divide(R, Z)
159
+ C = self.gradprop(Z, self.X, S)
160
+
161
+ outputs = []
162
+ for x, c in zip(self.X, C):
163
+ outputs.append(x * c)
164
+
165
+ return outputs
166
+
167
+
168
+ class Sequential(nn.Sequential):
169
+ def relprop(self, R, alpha):
170
+ for m in reversed(self._modules.values()):
171
+ R = m.relprop(R, alpha)
172
+ return R
173
+
174
+
175
+ class BatchNorm2d(nn.BatchNorm2d, RelProp):
176
+ def relprop(self, R, alpha):
177
+ X = self.X
178
+ beta = 1 - alpha
179
+ weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
180
+ (self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5))
181
+ Z = X * weight + 1e-9
182
+ S = R / Z
183
+ Ca = S * weight
184
+ R = self.X * (Ca)
185
+ return R
186
+
187
+
188
+ class Linear(nn.Linear, RelProp):
189
+ def relprop(self, R, alpha):
190
+ beta = alpha - 1
191
+ pw = torch.clamp(self.weight, min=0)
192
+ nw = torch.clamp(self.weight, max=0)
193
+ px = torch.clamp(self.X, min=0)
194
+ nx = torch.clamp(self.X, max=0)
195
+
196
+ def f(w1, w2, x1, x2):
197
+ Z1 = F.linear(x1, w1)
198
+ Z2 = F.linear(x2, w2)
199
+ S1 = safe_divide(R, Z1)
200
+ S2 = safe_divide(R, Z2)
201
+ C1 = x1 * torch.autograd.grad(Z1, x1, S1)[0]
202
+ C2 = x2 * torch.autograd.grad(Z2, x2, S2)[0]
203
+
204
+ return C1 + C2
205
+
206
+ activator_relevances = f(pw, nw, px, nx)
207
+ inhibitor_relevances = f(nw, pw, px, nx)
208
+
209
+ R = alpha * activator_relevances - beta * inhibitor_relevances
210
+
211
+ return R
212
+
213
+
214
+ class Conv2d(nn.Conv2d, RelProp):
215
+ def gradprop2(self, DY, weight):
216
+ Z = self.forward(self.X)
217
+
218
+ output_padding = self.X.size()[2] - (
219
+ (Z.size()[2] - 1) * self.stride[0] - 2 * self.padding[0] + self.kernel_size[0])
220
+
221
+ return F.conv_transpose2d(DY, weight, stride=self.stride, padding=self.padding, output_padding=output_padding)
222
+
223
+ def relprop(self, R, alpha):
224
+ if self.X.shape[1] == 3:
225
+ pw = torch.clamp(self.weight, min=0)
226
+ nw = torch.clamp(self.weight, max=0)
227
+ X = self.X
228
+ L = self.X * 0 + \
229
+ torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
230
+ keepdim=True)[0]
231
+ H = self.X * 0 + \
232
+ torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
233
+ keepdim=True)[0]
234
+ Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \
235
+ torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \
236
+ torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9
237
+
238
+ S = R / Za
239
+ C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw)
240
+ R = C
241
+ else:
242
+ beta = alpha - 1
243
+ pw = torch.clamp(self.weight, min=0)
244
+ nw = torch.clamp(self.weight, max=0)
245
+ px = torch.clamp(self.X, min=0)
246
+ nx = torch.clamp(self.X, max=0)
247
+
248
+ def f(w1, w2, x1, x2):
249
+ Z1 = F.conv2d(x1, w1, bias=None, stride=self.stride, padding=self.padding)
250
+ Z2 = F.conv2d(x2, w2, bias=None, stride=self.stride, padding=self.padding)
251
+ S1 = safe_divide(R, Z1)
252
+ S2 = safe_divide(R, Z2)
253
+ C1 = x1 * self.gradprop(Z1, x1, S1)[0]
254
+ C2 = x2 * self.gradprop(Z2, x2, S2)[0]
255
+ return C1 + C2
256
+
257
+ activator_relevances = f(pw, nw, px, nx)
258
+ inhibitor_relevances = f(nw, pw, px, nx)
259
+
260
+ R = alpha * activator_relevances - beta * inhibitor_relevances
261
+ return R
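Each class above pairs a standard layer with a relprop method, so relevance can be walked backwards through a network after one forward pass (the forward hook caches every layer's input as .X). A minimal sketch with a toy network; the shapes and the alpha value are illustrative, and this is not the full ViT attribution pipeline.

import torch
from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.modules.layers_lrp import (
    Linear, ReLU, Sequential,
)

model = Sequential(Linear(8, 16), ReLU(), Linear(16, 4))
x = torch.rand(2, 8)
out = model(x)                   # forward hooks cache each layer's input

R = out.clone()                  # start relevance from the raw outputs
R = model.relprop(R, alpha=1)    # alpha=1 / beta=0 rule, applied layer by layer in reverse
print(R.shape)                   # torch.Size([2, 8]): relevance per input feature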
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/modules/layers_ours.py ADDED
@@ -0,0 +1,280 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ __all__ = ['forward_hook', 'Clone', 'Add', 'Cat', 'ReLU', 'GELU', 'Dropout', 'BatchNorm2d', 'Linear', 'MaxPool2d',
6
+ 'AdaptiveAvgPool2d', 'AvgPool2d', 'Conv2d', 'Sequential', 'safe_divide', 'einsum', 'Softmax', 'IndexSelect',
7
+ 'LayerNorm', 'AddEye']
8
+
9
+
10
+ def safe_divide(a, b):
11
+ den = b.clamp(min=1e-9) + b.clamp(max=1e-9)
12
+ den = den + den.eq(0).type(den.type()) * 1e-9
13
+ return a / den * b.ne(0).type(b.type())
14
+
15
+
16
+ def forward_hook(self, input, output):
17
+ if type(input[0]) in (list, tuple):
18
+ self.X = []
19
+ for i in input[0]:
20
+ x = i.detach()
21
+ x.requires_grad = True
22
+ self.X.append(x)
23
+ else:
24
+ self.X = input[0].detach()
25
+ self.X.requires_grad = True
26
+
27
+ self.Y = output
28
+
29
+
30
+ def backward_hook(self, grad_input, grad_output):
31
+ self.grad_input = grad_input
32
+ self.grad_output = grad_output
33
+
34
+
35
+ class RelProp(nn.Module):
36
+ def __init__(self):
37
+ super(RelProp, self).__init__()
38
+ # if not self.training:
39
+ self.register_forward_hook(forward_hook)
40
+
41
+ def gradprop(self, Z, X, S):
42
+ C = torch.autograd.grad(Z, X, S, retain_graph=True)
43
+ return C
44
+
45
+ def relprop(self, R, alpha):
46
+ return R
47
+
48
+ class RelPropSimple(RelProp):
49
+ def relprop(self, R, alpha):
50
+ Z = self.forward(self.X)
51
+ S = safe_divide(R, Z)
52
+ C = self.gradprop(Z, self.X, S)
53
+
54
+ if torch.is_tensor(self.X) == False:
55
+ outputs = []
56
+ outputs.append(self.X[0] * C[0])
57
+ outputs.append(self.X[1] * C[1])
58
+ else:
59
+ outputs = self.X * (C[0])
60
+ return outputs
61
+
62
+ class AddEye(RelPropSimple):
63
+ # input of shape B, C, seq_len, seq_len
64
+ def forward(self, input):
65
+ return input + torch.eye(input.shape[2]).expand_as(input).to(input.device)
66
+
67
+ class ReLU(nn.ReLU, RelProp):
68
+ pass
69
+
70
+ class GELU(nn.GELU, RelProp):
71
+ pass
72
+
73
+ class Softmax(nn.Softmax, RelProp):
74
+ pass
75
+
76
+ class LayerNorm(nn.LayerNorm, RelProp):
77
+ pass
78
+
79
+ class Dropout(nn.Dropout, RelProp):
80
+ pass
81
+
82
+
83
+ class MaxPool2d(nn.MaxPool2d, RelPropSimple):
84
+ pass
85
+
86
+ class LayerNorm(nn.LayerNorm, RelProp):
87
+ pass
88
+
89
+ class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, RelPropSimple):
90
+ pass
91
+
92
+
93
+ class AvgPool2d(nn.AvgPool2d, RelPropSimple):
94
+ pass
95
+
96
+
97
+ class Add(RelPropSimple):
98
+ def forward(self, inputs):
99
+ return torch.add(*inputs)
100
+
101
+ def relprop(self, R, alpha):
102
+ Z = self.forward(self.X)
103
+ S = safe_divide(R, Z)
104
+ C = self.gradprop(Z, self.X, S)
105
+
106
+ a = self.X[0] * C[0]
107
+ b = self.X[1] * C[1]
108
+
109
+ a_sum = a.sum()
110
+ b_sum = b.sum()
111
+
112
+ a_fact = safe_divide(a_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
113
+ b_fact = safe_divide(b_sum.abs(), a_sum.abs() + b_sum.abs()) * R.sum()
114
+
115
+ a = a * safe_divide(a_fact, a.sum())
116
+ b = b * safe_divide(b_fact, b.sum())
117
+
118
+ outputs = [a, b]
119
+
120
+ return outputs
121
+
122
+ class einsum(RelPropSimple):
123
+ def __init__(self, equation):
124
+ super().__init__()
125
+ self.equation = equation
126
+ def forward(self, *operands):
127
+ return torch.einsum(self.equation, *operands)
128
+
129
+ class IndexSelect(RelProp):
130
+ def forward(self, inputs, dim, indices):
131
+ self.__setattr__('dim', dim)
132
+ self.__setattr__('indices', indices)
133
+
134
+ return torch.index_select(inputs, dim, indices)
135
+
136
+ def relprop(self, R, alpha):
137
+ Z = self.forward(self.X, self.dim, self.indices)
138
+ S = safe_divide(R, Z)
139
+ C = self.gradprop(Z, self.X, S)
140
+
141
+ if torch.is_tensor(self.X) == False:
142
+ outputs = []
143
+ outputs.append(self.X[0] * C[0])
144
+ outputs.append(self.X[1] * C[1])
145
+ else:
146
+ outputs = self.X * (C[0])
147
+ return outputs
148
+
149
+
150
+
151
+ class Clone(RelProp):
152
+ def forward(self, input, num):
153
+ self.__setattr__('num', num)
154
+ outputs = []
155
+ for _ in range(num):
156
+ outputs.append(input)
157
+
158
+ return outputs
159
+
160
+ def relprop(self, R, alpha):
161
+ Z = []
162
+ for _ in range(self.num):
163
+ Z.append(self.X)
164
+ S = [safe_divide(r, z) for r, z in zip(R, Z)]
165
+ C = self.gradprop(Z, self.X, S)[0]
166
+
167
+ R = self.X * C
168
+
169
+ return R
170
+
171
+ class Cat(RelProp):
172
+ def forward(self, inputs, dim):
173
+ self.__setattr__('dim', dim)
174
+ return torch.cat(inputs, dim)
175
+
176
+ def relprop(self, R, alpha):
177
+ Z = self.forward(self.X, self.dim)
178
+ S = safe_divide(R, Z)
179
+ C = self.gradprop(Z, self.X, S)
180
+
181
+ outputs = []
182
+ for x, c in zip(self.X, C):
183
+ outputs.append(x * c)
184
+
185
+ return outputs
186
+
187
+
188
+ class Sequential(nn.Sequential):
189
+ def relprop(self, R, alpha):
190
+ for m in reversed(self._modules.values()):
191
+ R = m.relprop(R, alpha)
192
+ return R
193
+
194
+ class BatchNorm2d(nn.BatchNorm2d, RelProp):
195
+ def relprop(self, R, alpha):
196
+ X = self.X
197
+ beta = 1 - alpha
198
+ weight = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) / (
199
+ (self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3).pow(2) + self.eps).pow(0.5))
200
+ Z = X * weight + 1e-9
201
+ S = R / Z
202
+ Ca = S * weight
203
+ R = self.X * (Ca)
204
+ return R
205
+
206
+
207
+ class Linear(nn.Linear, RelProp):
208
+ def relprop(self, R, alpha):
209
+ beta = alpha - 1
210
+ pw = torch.clamp(self.weight, min=0)
211
+ nw = torch.clamp(self.weight, max=0)
212
+ px = torch.clamp(self.X, min=0)
213
+ nx = torch.clamp(self.X, max=0)
214
+
215
+ def f(w1, w2, x1, x2):
216
+ Z1 = F.linear(x1, w1)
217
+ Z2 = F.linear(x2, w2)
218
+ S1 = safe_divide(R, Z1 + Z2)
219
+ S2 = safe_divide(R, Z1 + Z2)
220
+ C1 = x1 * torch.autograd.grad(Z1, x1, S1)[0]
221
+ C2 = x2 * torch.autograd.grad(Z2, x2, S2)[0]
222
+
223
+ return C1 + C2
224
+
225
+ activator_relevances = f(pw, nw, px, nx)
226
+ inhibitor_relevances = f(nw, pw, px, nx)
227
+
228
+ R = alpha * activator_relevances - beta * inhibitor_relevances
229
+
230
+ return R
231
+
232
+
233
+ class Conv2d(nn.Conv2d, RelProp):
234
+ def gradprop2(self, DY, weight):
235
+ Z = self.forward(self.X)
236
+
237
+ output_padding = self.X.size()[2] - (
238
+ (Z.size()[2] - 1) * self.stride[0] - 2 * self.padding[0] + self.kernel_size[0])
239
+
240
+ return F.conv_transpose2d(DY, weight, stride=self.stride, padding=self.padding, output_padding=output_padding)
241
+
242
+ def relprop(self, R, alpha):
243
+ if self.X.shape[1] == 3:
244
+ pw = torch.clamp(self.weight, min=0)
245
+ nw = torch.clamp(self.weight, max=0)
246
+ X = self.X
247
+ L = self.X * 0 + \
248
+ torch.min(torch.min(torch.min(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
249
+ keepdim=True)[0]
250
+ H = self.X * 0 + \
251
+ torch.max(torch.max(torch.max(self.X, dim=1, keepdim=True)[0], dim=2, keepdim=True)[0], dim=3,
252
+ keepdim=True)[0]
253
+ Za = torch.conv2d(X, self.weight, bias=None, stride=self.stride, padding=self.padding) - \
254
+ torch.conv2d(L, pw, bias=None, stride=self.stride, padding=self.padding) - \
255
+ torch.conv2d(H, nw, bias=None, stride=self.stride, padding=self.padding) + 1e-9
256
+
257
+ S = R / Za
258
+ C = X * self.gradprop2(S, self.weight) - L * self.gradprop2(S, pw) - H * self.gradprop2(S, nw)
259
+ R = C
260
+ else:
261
+ beta = alpha - 1
262
+ pw = torch.clamp(self.weight, min=0)
263
+ nw = torch.clamp(self.weight, max=0)
264
+ px = torch.clamp(self.X, min=0)
265
+ nx = torch.clamp(self.X, max=0)
266
+
267
+ def f(w1, w2, x1, x2):
268
+ Z1 = F.conv2d(x1, w1, bias=None, stride=self.stride, padding=self.padding)
269
+ Z2 = F.conv2d(x2, w2, bias=None, stride=self.stride, padding=self.padding)
270
+ S1 = safe_divide(R, Z1)
271
+ S2 = safe_divide(R, Z2)
272
+ C1 = x1 * self.gradprop(Z1, x1, S1)[0]
273
+ C2 = x2 * self.gradprop(Z2, x2, S2)[0]
274
+ return C1 + C2
275
+
276
+ activator_relevances = f(pw, nw, px, nx)
277
+ inhibitor_relevances = f(nw, pw, px, nx)
278
+
279
+ R = alpha * activator_relevances - beta * inhibitor_relevances
280
+ return R
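This file mirrors layers_lrp.py but changes how some layers redistribute relevance; Add, for example, rescales the two branch relevances so that their total matches the incoming relevance, which matters for residual-style additions. A small illustrative check with arbitrary tensors (alpha is accepted but unused by this rule):

import torch
from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.modules.layers_ours import Add

add = Add()
a, b = torch.rand(2, 5), torch.rand(2, 5)
out = add([a, b])                  # forward hook caches both inputs

R = torch.ones_like(out)           # pretend each output unit carries relevance 1
Ra, Rb = add.relprop(R, alpha=1)   # split and renormalize across the two branches
print(R.sum().item(), (Ra.sum() + Rb.sum()).item())   # approximately equal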
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/pertubation_eval_from_hdf5.py ADDED
@@ -0,0 +1,232 @@
1
+ import torch
2
+ import os
3
+ from tqdm import tqdm
4
+ import numpy as np
5
+ import argparse
6
+
7
+ # Import saliency methods and models
8
+ from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.ViT_explanation_generator import Baselines
9
+ from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.ViT_new import vit_base_patch16_224
10
+ # from models.vgg import vgg19
11
+ import glob
12
+
13
+ from dataset.expl_hdf5 import ImagenetResults
14
+
15
+
16
+ def normalize(tensor,
17
+ mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
18
+ dtype = tensor.dtype
19
+ mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
20
+ std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
21
+ tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
22
+ return tensor
23
+
24
+
25
+ def eval(args):
26
+ num_samples = 0
27
+ num_correct_model = np.zeros((len(imagenet_ds,)))
28
+ dissimilarity_model = np.zeros((len(imagenet_ds,)))
29
+ model_index = 0
30
+
31
+ if args.scale == 'per':
32
+ base_size = 224 * 224
33
+ perturbation_steps = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
34
+ elif args.scale == '100':
35
+ base_size = 100
36
+ perturbation_steps = [5, 10, 15, 20, 25, 30, 35, 40, 45]
37
+ else:
38
+ raise Exception('scale not valid')
39
+
40
+ num_correct_pertub = np.zeros((9, len(imagenet_ds)))
41
+ dissimilarity_pertub = np.zeros((9, len(imagenet_ds)))
42
+ logit_diff_pertub = np.zeros((9, len(imagenet_ds)))
43
+ prob_diff_pertub = np.zeros((9, len(imagenet_ds)))
44
+ perturb_index = 0
45
+
46
+ for batch_idx, (data, vis, target) in enumerate(tqdm(sample_loader)):
47
+ # Update the number of samples
48
+ num_samples += len(data)
49
+
50
+ data = data.to(device)
51
+ vis = vis.to(device)
52
+ target = target.to(device)
53
+ norm_data = normalize(data.clone())
54
+
55
+ # Compute model accuracy
56
+ pred = model(norm_data)
57
+ pred_probabilities = torch.softmax(pred, dim=1)
58
+ pred_org_logit = pred.data.max(1, keepdim=True)[0].squeeze(1)
59
+ pred_org_prob = pred_probabilities.data.max(1, keepdim=True)[0].squeeze(1)
60
+ pred_class = pred.data.max(1, keepdim=True)[1].squeeze(1)
61
+ tgt_pred = (target == pred_class).type(target.type()).data.cpu().numpy()
62
+ num_correct_model[model_index:model_index+len(tgt_pred)] = tgt_pred
63
+
64
+ probs = torch.softmax(pred, dim=1)
65
+ target_probs = torch.gather(probs, 1, target[:, None])[:, 0]
66
+ second_probs = probs.data.topk(2, dim=1)[0][:, 1]
67
+ temp = torch.log(target_probs / second_probs).data.cpu().numpy()
68
+ dissimilarity_model[model_index:model_index+len(temp)] = temp
69
+
70
+ if args.wrong:
71
+ wid = np.argwhere(tgt_pred == 0).flatten()
72
+ if len(wid) == 0:
73
+ continue
74
+ wid = torch.from_numpy(wid).to(vis.device)
75
+ vis = vis.index_select(0, wid)
76
+ data = data.index_select(0, wid)
77
+ target = target.index_select(0, wid)
78
+
79
+ # Save original shape
80
+ org_shape = data.shape
81
+
82
+ if args.neg:
83
+ vis = -vis
84
+
85
+ vis = vis.reshape(org_shape[0], -1)
86
+
87
+ for i in range(len(perturbation_steps)):
88
+ _data = data.clone()
89
+
90
+ _, idx = torch.topk(vis, int(base_size * perturbation_steps[i]), dim=-1)
91
+ idx = idx.unsqueeze(1).repeat(1, org_shape[1], 1)
92
+ _data = _data.reshape(org_shape[0], org_shape[1], -1)
93
+ _data = _data.scatter_(-1, idx, 0)
94
+ _data = _data.reshape(*org_shape)
95
+
96
+ _norm_data = normalize(_data)
97
+
98
+ out = model(_norm_data)
99
+
100
+ pred_probabilities = torch.softmax(out, dim=1)
101
+ pred_prob = pred_probabilities.data.max(1, keepdim=True)[0].squeeze(1)
102
+ diff = (pred_prob - pred_org_prob).data.cpu().numpy()
103
+ prob_diff_pertub[i, perturb_index:perturb_index+len(diff)] = diff
104
+
105
+ pred_logit = out.data.max(1, keepdim=True)[0].squeeze(1)
106
+ diff = (pred_logit - pred_org_logit).data.cpu().numpy()
107
+ logit_diff_pertub[i, perturb_index:perturb_index+len(diff)] = diff
108
+
109
+ target_class = out.data.max(1, keepdim=True)[1].squeeze(1)
110
+ temp = (target == target_class).type(target.type()).data.cpu().numpy()
111
+ num_correct_pertub[i, perturb_index:perturb_index+len(temp)] = temp
112
+
113
+ probs_pertub = torch.softmax(out, dim=1)
114
+ target_probs = torch.gather(probs_pertub, 1, target[:, None])[:, 0]
115
+ second_probs = probs_pertub.data.topk(2, dim=1)[0][:, 1]
116
+ temp = torch.log(target_probs / second_probs).data.cpu().numpy()
117
+ dissimilarity_pertub[i, perturb_index:perturb_index+len(temp)] = temp
118
+
119
+ model_index += len(target)
120
+ perturb_index += len(target)
121
+
122
+ np.save(os.path.join(args.experiment_dir, 'model_hits.npy'), num_correct_model)
123
+ np.save(os.path.join(args.experiment_dir, 'model_dissimilarities.npy'), dissimilarity_model)
124
+ np.save(os.path.join(args.experiment_dir, 'perturbations_hits.npy'), num_correct_pertub[:, :perturb_index])
125
+ np.save(os.path.join(args.experiment_dir, 'perturbations_dissimilarities.npy'), dissimilarity_pertub[:, :perturb_index])
126
+ np.save(os.path.join(args.experiment_dir, 'perturbations_logit_diff.npy'), logit_diff_pertub[:, :perturb_index])
127
+ np.save(os.path.join(args.experiment_dir, 'perturbations_prob_diff.npy'), prob_diff_pertub[:, :perturb_index])
128
+
129
+ print(np.mean(num_correct_model), np.std(num_correct_model))
130
+ print(np.mean(dissimilarity_model), np.std(dissimilarity_model))
131
+ print(perturbation_steps)
132
+ print(np.mean(num_correct_pertub, axis=1), np.std(num_correct_pertub, axis=1))
133
+ print(np.mean(dissimilarity_pertub, axis=1), np.std(dissimilarity_pertub, axis=1))
134
+
135
+
136
+ if __name__ == "__main__":
137
+ parser = argparse.ArgumentParser(description='Perturbation evaluation from HDF5 visualizations')
138
+ parser.add_argument('--batch-size', type=int,
139
+ default=16,
140
+ help='')
141
+ parser.add_argument('--neg', type=bool,
142
+ default=True,
143
+ help='')
144
+ parser.add_argument('--value', action='store_true',
145
+ default=False,
146
+ help='')
147
+ parser.add_argument('--scale', type=str,
148
+ default='per',
149
+ choices=['per', '100'],
150
+ help='')
151
+ parser.add_argument('--method', type=str,
152
+ default='grad_rollout',
153
+ choices=['rollout', 'lrp', 'transformer_attribution', 'full_lrp', 'v_gradcam', 'lrp_last_layer',
154
+ 'lrp_second_layer', 'gradcam',
155
+ 'attn_last_layer', 'attn_gradcam', 'input_grads'],
156
+ help='')
157
+ parser.add_argument('--vis-class', type=str,
158
+ default='top',
159
+ choices=['top', 'target', 'index'],
160
+ help='')
161
+ parser.add_argument('--wrong', action='store_true',
162
+ default=False,
163
+ help='')
164
+ parser.add_argument('--class-id', type=int,
165
+ default=0,
166
+ help='')
167
+ parser.add_argument('--is-ablation', type=bool,
168
+ default=False,
169
+ help='')
170
+ args = parser.parse_args()
171
+
172
+ torch.multiprocessing.set_start_method('spawn')
173
+
174
+ # PATH variables
175
+ PATH = os.path.dirname(os.path.abspath(__file__)) + '/'
176
+ dataset = PATH + 'dataset/'
177
+ os.makedirs(os.path.join(PATH, 'experiments'), exist_ok=True)
178
+ os.makedirs(os.path.join(PATH, 'experiments/perturbations'), exist_ok=True)
179
+
180
+ exp_name = args.method
181
+ exp_name += '_neg' if args.neg else '_pos'
182
+ print(exp_name)
183
+
184
+ if args.vis_class == 'index':
185
+ args.runs_dir = os.path.join(PATH, 'experiments/perturbations/{}/{}_{}'.format(exp_name,
186
+ args.vis_class,
187
+ args.class_id))
188
+ else:
189
+ ablation_fold = 'ablation' if args.is_ablation else 'not_ablation'
190
+ args.runs_dir = os.path.join(PATH, 'experiments/perturbations/{}/{}/{}'.format(exp_name,
191
+ args.vis_class, ablation_fold))
192
+ # args.runs_dir = os.path.join(PATH, 'experiments/perturbations/{}/{}'.format(exp_name,
193
+ # args.vis_class))
194
+
195
+ if args.wrong:
196
+ args.runs_dir += '_wrong'
197
+
198
+ experiments = sorted(glob.glob(os.path.join(args.runs_dir, 'experiment_*')))
199
+ experiment_id = int(experiments[-1].split('_')[-1]) + 1 if experiments else 0
200
+ args.experiment_dir = os.path.join(args.runs_dir, 'experiment_{}'.format(str(experiment_id)))
201
+ os.makedirs(args.experiment_dir, exist_ok=True)
202
+
203
+ cuda = torch.cuda.is_available()
204
+ device = torch.device("cuda" if cuda else "cpu")
205
+
206
+ if args.vis_class == 'index':
207
+ vis_method_dir = os.path.join(PATH,'visualizations/{}/{}_{}'.format(args.method,
208
+ args.vis_class,
209
+ args.class_id))
210
+ else:
211
+ ablation_fold = 'ablation' if args.is_ablation else 'not_ablation'
212
+ vis_method_dir = os.path.join(PATH,'visualizations/{}/{}/{}'.format(args.method,
213
+ args.vis_class, ablation_fold))
214
+ # vis_method_dir = os.path.join(PATH, 'visualizations/{}/{}'.format(args.method,
215
+ # args.vis_class))
216
+
217
+ # imagenet_ds = ImagenetResults('visualizations/{}'.format(args.method))
218
+ imagenet_ds = ImagenetResults(vis_method_dir)
219
+
220
+ # Model
221
+ model = vit_base_patch16_224(pretrained=True).to(device)
222
+ model.eval()
223
+
224
+ save_path = PATH + 'results/'
225
+
226
+ sample_loader = torch.utils.data.DataLoader(
227
+ imagenet_ds,
228
+ batch_size=args.batch_size,
229
+ num_workers=2,
230
+ shuffle=False)
231
+
232
+ eval(args)
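The core of the evaluation loop above is the top-k masking step: for each perturbation ratio, the pixels with the highest relevance (after the optional sign flip controlled by --neg) are zeroed out before the image is re-classified. A standalone sketch of just that step, with an illustrative batch size and a 10% ratio:

import torch

data = torch.rand(4, 3, 224, 224)        # batch of images
vis = torch.rand(4, 224 * 224)           # one flattened relevance map per image

k = int(224 * 224 * 0.1)                 # perturb 10% of the pixels
_, idx = torch.topk(vis, k, dim=-1)      # indices of the most relevant pixels
idx = idx.unsqueeze(1).repeat(1, 3, 1)   # reuse the same indices for every channel

flat = data.reshape(4, 3, -1)
flat.scatter_(-1, idx, 0)                # zero the selected pixels in place
perturbed = flat.reshape(4, 3, 224, 224)
print(perturbed.shape)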
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/utils/__init__.py ADDED
File without changes
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (221 Bytes).
 
concept_attention/binary_segmentation_baselines/chefer_vit_explainability/utils/__pycache__/confusionmatrix.cpython-310.pyc ADDED
Binary file (3.55 kB).