---
license: apache-2.0
---

To use Sum-of-Parts (SOP), you need to install exlib. SOP is currently available only on the `dev` branch: https://github.com/BrachioLab/exlib/tree/dev
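
If you are unsure whether your installed copy of exlib includes SOP (for example, because it was installed from a branch other than `dev`), you can check whether the `exlib.modules.sop` module that the loading code imports can be found. This is a minimal standard-library sketch; the helper name `sop_available` is hypothetical and not part of exlib.

```python
import importlib.util


def sop_available(module_name: str = "exlib.modules.sop") -> bool:
    """Return True if the given module (by default exlib's SOP module) can be located.

    Hypothetical helper, not part of exlib. Locating a dotted submodule imports
    its parent packages first, so a missing parent package is reported as False.
    """
    try:
        return importlib.util.find_spec(module_name) is not None
    except (ImportError, ValueError):
        # ImportError (including ModuleNotFoundError): a parent package such as
        # `exlib` or `exlib.modules` is not installed or cannot be imported.
        return False


if __name__ == "__main__":
    if sop_available():
        print("exlib with SOP support is installed.")
    else:
        print("SOP not found: install exlib from the dev branch.")
```

A `False` result usually means exlib is missing entirely or was installed from a branch without the SOP module.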

To use the SOP model trained for `google/vit-base-patch16-224`, run the following code.

### Load the model

```python
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification

from exlib.modules.sop import WrappedModel, SOPImageCls

# Initialize the backbone model and its image processor
backbone_model = AutoModelForImageClassification.from_pretrained('google/vit-base-patch16-224')
processor = AutoImageProcessor.from_pretrained('google/vit-base-patch16-224')

# Wrap the backbone with the output types that SOP needs
original_model = WrappedModel(backbone_model, output_type='logits')
wrapped_backbone_model = WrappedModel(backbone_model, output_type='tuple')
projection_layer = WrappedModel(wrapped_backbone_model, output_type='hidden_states')

# Load the trained SOP model and switch it to evaluation mode
model = SOPImageCls.from_pretrained('BrachioLab/sop-vit-base-patch16-224',
                                    blackbox_model=wrapped_backbone_model,
                                    projection_layer=projection_layer)
model.eval()
```

### Open an image

```python
from PIL import Image

# Open an example image (an alternative example is commented out)
# image_path = '../../examples/ILSVRC2012_val_00000873.JPEG'
image_path = '../../examples/ILSVRC2012_val_00000247.JPEG'
image = Image.open(image_path)
image.show()

# Convert to RGB and preprocess into a batched PyTorch tensor
image_rgb = image.convert("RGB")
inputs = processor(image_rgb, return_tensors='pt')['pixel_values']
inputs.shape  # (1, 3, 224, 224)
```
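
For intuition about the `(1, 3, 224, 224)` shape, the sketch below mimics the final steps of ViT-style preprocessing in NumPy: rescale pixel values to [0, 1], normalize each channel, move channels first, and add a batch dimension. It assumes the image has already been resized to 224x224, and it assumes a per-channel mean and std of 0.5, which is what the `google/vit-base-patch16-224` preprocessor config is understood to use; check `processor.image_mean` and `processor.image_std` rather than relying on these values. The helper `vit_style_normalize` is hypothetical and is not a replacement for `processor`.

```python
import numpy as np


def vit_style_normalize(image_hwc: np.ndarray,
                        mean=(0.5, 0.5, 0.5),
                        std=(0.5, 0.5, 0.5)) -> np.ndarray:
    """Turn an (H, W, 3) uint8 RGB array into a (1, 3, H, W) float32 batch.

    Hypothetical helper; assumes resizing has already been done and that
    mean/std match the model's preprocessor config.
    """
    x = image_hwc.astype(np.float32) / 255.0      # rescale to [0, 1]
    mean = np.asarray(mean, dtype=np.float32)
    std = np.asarray(std, dtype=np.float32)
    x = (x - mean) / std                           # per-channel normalization
    x = np.transpose(x, (2, 0, 1))                 # (H, W, C) -> (C, H, W)
    return x[np.newaxis, ...]                      # add batch dim -> (1, C, H, W)


if __name__ == "__main__":
    fake_image = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)
    batch = vit_style_normalize(fake_image)
    print(batch.shape, batch.dtype)  # (1, 3, 224, 224) float32
    # With mean = std = 0.5, every value lies in [-1, 1]
    print(batch.min() >= -1.0, batch.max() <= 1.0)
```

Comparing this shape with `inputs.shape` from the processor is a quick sanity check that the input layout is what the model expects.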

### Get the output from SOP

```python
# Get the outputs from the model
outputs = model(inputs, return_tuple=True)
```

### Show the groups

```python
from exlib.modules.sop import show_masks_weights

# Show the group masks together with their group attribution scores
show_masks_weights(inputs, outputs, i=0)
```
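
As a rough mental model of what this visualization conveys, each group is a mask over the image paired with an attribution score. The NumPy sketch below combines such masks into one per-pixel heatmap by summing each mask scaled by its score. It is only an illustration under assumed inputs (masks of shape `(G, H, W)` and scores of shape `(G,)`, with a hypothetical helper `masks_to_heatmap`); it is not how `show_masks_weights` is implemented, and it does not read the actual structure of `outputs`.

```python
import numpy as np


def masks_to_heatmap(masks: np.ndarray, scores: np.ndarray) -> np.ndarray:
    """Combine G group masks (G, H, W) and G scores (G,) into an (H, W) heatmap.

    Hypothetical helper: each pixel receives the sum of the scores of all
    groups whose masks cover it, so overlapping groups both contribute.
    """
    if masks.ndim != 3 or scores.shape != (masks.shape[0],):
        raise ValueError("expected masks of shape (G, H, W) and scores of shape (G,)")
    # Weighted sum over the group axis: sum_g scores[g] * masks[g]
    return np.tensordot(scores, masks.astype(np.float32), axes=1)


if __name__ == "__main__":
    # Two toy groups on a 2x3 "image": the left column and the right two columns
    masks = np.array([
        [[1, 0, 0],
         [1, 0, 0]],
        [[0, 1, 1],
         [0, 1, 1]],
    ])
    scores = np.array([0.75, 0.25])
    print(masks_to_heatmap(masks, scores))
    # [[0.75 0.25 0.25]
    #  [0.75 0.25 0.25]]
```

Summing (rather than, say, taking the maximum) is a simple choice that keeps the contribution of every overlapping group visible in the heatmap.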