---
license: apache-2.0
---

To use Sum-of-Parts (SOP), you first need to install `exlib`. SOP is currently available only on the dev branch: https://github.com/BrachioLab/exlib/tree/dev
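One way to install directly from the dev branch, assuming the repository has a standard pip-installable layout:

```shell
# Install exlib from the dev branch (assumes a setup.py/pyproject.toml at the repo root)
pip install git+https://github.com/BrachioLab/exlib.git@dev
```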

To use the SOP model trained for `google/vit-base-patch16-224`, follow the steps below.

### Load the model
```python
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification

from exlib.modules.sop import WrappedModel, SOPImageCls

# Initialize the backbone model and its image processor
backbone_model = AutoModelForImageClassification.from_pretrained('google/vit-base-patch16-224')
processor = AutoImageProcessor.from_pretrained('google/vit-base-patch16-224')

# Wrap the backbone for the different output types SOP needs
original_model = WrappedModel(backbone_model, output_type='logits')
wrapped_backbone_model = WrappedModel(backbone_model, output_type='tuple')
projection_layer = WrappedModel(wrapped_backbone_model, output_type='hidden_states')

# Load the trained SOP model
model = SOPImageCls.from_pretrained('BrachioLab/sop-vit-base-patch16-224',
                                    blackbox_model=wrapped_backbone_model,
                                    projection_layer=projection_layer)
model.eval()
```

### Open an image
```python
from PIL import Image

# Open an example image
image_path = '../../examples/ILSVRC2012_val_00000247.JPEG'
image = Image.open(image_path)
image.show()
image_rgb = image.convert("RGB")
inputs = processor(image_rgb, return_tensors='pt')['pixel_values']
inputs.shape  # torch.Size([1, 3, 224, 224])
```
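For reference, the ViT processor for this checkpoint resizes images to 224×224 and normalizes each channel with mean and standard deviation 0.5 (the defaults for `google/vit-base-patch16-224`). A minimal torch-only sketch of the equivalent transform, assuming those defaults:

```python
import torch
import torch.nn.functional as F

def preprocess(img_uint8: torch.Tensor) -> torch.Tensor:
    """Approximate the ViT image processor: resize to 224x224,
    scale to [0, 1], then normalize with mean = std = 0.5 per channel."""
    x = img_uint8.float() / 255.0                       # (3, H, W) in [0, 1]
    x = F.interpolate(x.unsqueeze(0), size=(224, 224),
                      mode='bilinear', align_corners=False)
    return (x - 0.5) / 0.5                              # (1, 3, 224, 224) in [-1, 1]

# A random stand-in image, just to check the output shape
dummy = torch.randint(0, 256, (3, 500, 375), dtype=torch.uint8)
print(preprocess(dummy).shape)  # torch.Size([1, 3, 224, 224])
```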

### Get the output from SOP
```python
# Get the outputs from the model (no gradients needed for inference)
with torch.no_grad():
    outputs = model(inputs, return_tuple=True)
```

### Show the groups
```python
from exlib.modules.sop import show_masks_weights

# Visualize the group masks together with their group attribution scores
show_masks_weights(inputs, outputs, i=0)
```