Unable to use the model as described in the model card

#2
by Siddhant230 - opened

Hi,
I am using the following code to load the model.

# Load model directly

from transformers import AutoImageProcessor, ViTForSketchClassification

processor = AutoImageProcessor.from_pretrained("WinKawaks/SketchXAI-Tiny-QuickDraw345")
model = ViTForSketchClassification.from_pretrained("WinKawaks/SketchXAI-Tiny-QuickDraw345")

I got the class and opts file from your repo here : https://github.com/WinKawaks/SketchXAI/tree/main
Though I am able to load the model, I get this error from AutoImageProcessor:

OSError: WinKawaks/SketchXAI-Tiny-QuickDraw345 does not appear to have a file named preprocessor_config.json. Checkout 'https://huggingface.co/WinKawaks/SketchXAI-Tiny-QuickDraw345/tree/main' for available files.

Could you please suggest an inference method where I can load the model and processor, send in an image, and perform a forward pass? It would be really helpful. Thanks

@Siddhant230 Thank you for your interest in this work. For SketchXAI, we take strokes as tokens instead of patches, so we do not need an "image processor". You can follow https://github.com/WinKawaks/SketchXAI/blob/main/main.py#L216 to load the model on its own. If you do need an image processor, you can load the processor and the model like this:

processor = AutoImageProcessor.from_pretrained("WinKawaks/vit-tiny-patch16-224")
# or
# processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")

model = ViTForSketchClassification.from_pretrained("WinKawaks/SketchXAI-Tiny-QuickDraw345")
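For context on "strokes as tokens": in the public QuickDraw format, a vector sketch is a list of strokes, each stored as a pair of coordinate lists `[x_coords, y_coords]`. A minimal, self-contained sketch (the helper name is illustrative, not from the SketchXAI code) of turning that into per-stroke offset sequences, i.e. token-like input rather than image patches:

```python
# Illustrative only: converts QuickDraw-style vector strokes
# (each stroke = [x_coords, y_coords]) into per-stroke offset
# sequences, the kind of token-level input a stroke-based model
# consumes instead of pixel patches.

def strokes_to_offsets(strokes):
    """Return one (dx, dy) sequence per stroke, relative to the
    stroke's starting point."""
    tokens = []
    for xs, ys in strokes:
        offsets = [(x - xs[0], y - ys[0]) for x, y in zip(xs, ys)]
        tokens.append(offsets)
    return tokens

# A toy two-stroke sketch in QuickDraw format.
sketch = [
    [[0, 10, 20], [0, 5, 0]],  # stroke 1: three points
    [[5, 15], [20, 20]],       # stroke 2: two points
]

print(strokes_to_offsets(sketch))
# [[(0, 0), (10, 5), (20, 0)], [(0, 0), (10, 0)]]
```

The exact embedding SketchXAI builds from strokes (shape, location, order) is defined in the repository linked above; this snippet only shows the raw vector format such models start from.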

Sure, thank you for the help. I am able to load the model, but since it takes strokes as tokens, how can I extract strokes? I only have images of sketches. Could you please help me with the pipeline (specifically, how to extract strokes from an image)?

I am looking for something like

1. required_model_inputs = stroke_extractor(image)
2. prediction = model.predict(required_model_inputs) # predictions are potentially class index as per config.json

@Siddhant230 Our model is based on vector sketches, not raster sketches. If your dataset is in raster format only, you can train a ResNet classifier instead of using a Transformer. Converting raster sketches to vector sketches is a difficult problem. Here is one possible solution: https://github.com/MarkMoHR/virtual_sketching. However, its performance is not very stable, so you will need to try it on your own dataset.
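To see why the two formats are not interchangeable: going from vector to raster is trivial (just draw each stroke's line segments onto a pixel grid), while the reverse, recovering ordered strokes from pixels, is the hard inverse problem mentioned above. A minimal self-contained rasteriser sketch (helper name and grid size are illustrative):

```python
# Illustrative only: rasterise QuickDraw-style vector strokes
# (each stroke = [x_coords, y_coords]) onto a binary pixel grid
# by interpolating along each line segment.

def rasterize(strokes, size=16):
    grid = [[0] * size for _ in range(size)]
    for xs, ys in strokes:
        points = list(zip(xs, ys))
        for (x0, y0), (x1, y1) in zip(points, points[1:]):
            steps = max(abs(x1 - x0), abs(y1 - y0), 1)
            for t in range(steps + 1):
                x = round(x0 + (x1 - x0) * t / steps)
                y = round(y0 + (y1 - y0) * t / steps)
                if 0 <= x < size and 0 <= y < size:
                    grid[y][x] = 1
    return grid

# One diagonal stroke from (0, 0) to (15, 15).
grid = rasterize([[[0, 15], [0, 15]]])
print(sum(map(sum, grid)))  # 16 pixels set along the diagonal
```

Note that the stroke order and direction are discarded here; that lost information is exactly what a raster-to-vector method such as virtual_sketching has to reconstruct.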
