Output features are different compared to timm

#2
by mobicham - opened

Thanks for making the model available!
It looks like the model produces different output features than the timm model. The difference is actually very large, even when using the exact same pre-processed image. Which one is the correct output?

from PIL import Image
import requests
from transformers import AutoProcessor, AutoModel
import torch, timm

model_hf  = AutoModel.from_pretrained("google/siglip-so400m-patch14-384").vision_model
processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-384")

model_timm = timm.create_model("vit_so400m_patch14_siglip_384")

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

with torch.no_grad():
    image_p  = processor(images=[image], return_tensors='pt')['pixel_values']
    out_hf   = model_hf(pixel_values=image_p).pooler_output
    out_timm = model_timm(image_p)

print('out_hf', out_hf) #tensor([[ 0.0042, -0.3039, -0.2630,  ..., -0.0633,  0.1298,  0.2565]])
print('out_timm', out_timm) #tensor([[ 0.0442, -0.0952,  0.3953,  ...,  0.7811,  0.1384, -0.6316]])
print('MAE', torch.abs(out_hf - out_timm).mean()) #0.6619

Hi,

Note that timm doesn't put models in evaluation mode by default; do you get deterministic outputs from it?
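
For example, a minimal check, reusing model_timm and image_p from the snippet above (a sketch):

# timm.create_model returns the module in training mode, like any torch.nn.Module.
print(model_timm.training)             # True right after timm.create_model(...)
with torch.no_grad():
    out_a = model_timm(image_p)
    out_b = model_timm(image_p)
print(torch.abs(out_a - out_b).max())  # a non-zero value here would indicate stochastic layers
model_timm.eval()                      # switch to evaluation mode before comparing features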

Note that I've only compared logits against the original JAX implementation, not timm. Also, the pixel values are slightly different because the original implementation uses tf.image.resize, which has no exact Pillow/NumPy/PyTorch equivalent.

I did this one by one for each model, so I'm confident the implementation is correct, but I could take a look in the next few days.
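
Regarding the tf.image.resize point, a rough way to see how much the preprocessing alone differs is to compare the processor's pixel values with those produced by timm's own transform. A sketch, reusing processor, model_timm, and image from the snippet above:

data_config    = timm.data.resolve_model_data_config(model_timm)
timm_transform = timm.data.create_transform(**data_config, is_training=False)

pixels_hf   = processor(images=[image], return_tensors='pt')['pixel_values']  # (1, 3, 384, 384)
pixels_timm = timm_transform(image).unsqueeze(0)                              # (1, 3, 384, 384)

# Any difference here comes purely from the two preprocessing pipelines
# (resize/crop/normalize), not from the model weights.
print(torch.abs(pixels_hf - pixels_timm).mean())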

Thanks for your reply!
It looks like timm doesn't load the pre-trained weights by default; that was the main issue. After the fix, the outputs are close enough, but not exactly the same.
Closing this issue.

from PIL import Image
import requests
from transformers import AutoProcessor, AutoModel
import torch, timm

model_hf  = AutoModel.from_pretrained("google/siglip-so400m-patch14-384").vision_model.eval()
processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-384")

model_timm = timm.create_model("vit_so400m_patch14_siglip_384", pretrained=True).eval()

# Alternative: use timm's own preprocessing pipeline instead of the HF processor
# data_config = timm.data.resolve_model_data_config(model_timm)
# transforms  = timm.data.create_transform(**data_config, is_training=False)

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

with torch.no_grad():
    image_p  = processor(images=[image], return_tensors='pt')['pixel_values']
    out_hf   = model_hf(pixel_values=image_p).pooler_output
    out_timm = model_timm(image_p)

print('out_hf', out_hf) #tensor([[ 0.0042, -0.3039, -0.2630,  ..., -0.0633,  0.1298,  0.2565]])
print('out_timm', out_timm) #tensor([[-0.0024, -0.3045, -0.2662,  ..., -0.0611,  0.1315,  0.2575]])
print('MAE', torch.abs(out_hf - out_timm).mean()) #tensor(0.0026)
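
As a sanity check on "close enough", a cosine similarity between the two pooled feature vectors (a sketch, reusing out_hf and out_timm from above) is often easier to interpret than the raw mean absolute difference:

import torch.nn.functional as F

# Per-sample cosine similarity between the two pooled outputs; values near 1.0
# mean the two implementations produce essentially the same feature direction.
print(F.cosine_similarity(out_hf, out_timm, dim=-1))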
mobicham changed discussion status to closed
