Output features are different compared to timm
Thanks for making the model available!
It looks like the model gives different output features compared to the timm model. The difference is actually very large, even when using the exact same pre-processed image. Which one is the right output?
# Reproduction: compare HF SigLIP vision features against the timm model.
from PIL import Image
import requests
from transformers import AutoProcessor, AutoModel
import torch, timm

model_hf = AutoModel.from_pretrained("google/siglip-so400m-patch14-384").vision_model
processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-384")
# NOTE(review): created WITHOUT pretrained=True, so this model has randomly
# initialized weights — this is the cause of the large discrepancy reported here.
# Neither model is put in eval mode either, so dropout (if any) stays active.
model_timm = timm.create_model("vit_so400m_patch14_siglip_384")

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

with torch.no_grad():
    image_p = processor(images=[image], return_tensors='pt')['pixel_values']
    out_hf = model_hf(pixel_values=image_p).pooler_output
    out_timm = model_timm(image_p)

print('out_hf', out_hf)  # tensor([[ 0.0042, -0.3039, -0.2630, ..., -0.0633, 0.1298, 0.2565]])
print('out_timm', out_timm)  # tensor([[ 0.0442, -0.0952, 0.3953, ..., 0.7811, 0.1384, -0.6316]])
# Mean absolute error between the two pooled feature vectors.
print('MAE', torch.abs(out_hf - out_timm).mean())  # 0.6619
Hi,
Note that timm doesn't use evaluation mode by default — do you get deterministic outputs from it?
Note that I've only compared logits against the original JAX implementation, not timm. Also, the pixel values are slightly different because the original implementation uses tf.image.resize, which has no exact Pillow/NumPy/PyTorch equivalent.
I did this one by one for each model, so I'm sure the implementation is correct, but could take a look in the next few days
Thanks for your reply!
Looks like timm doesn't load the pre-trained weights by default; that was the main issue. After the fix, the outputs are close enough, but not exactly the same.
Closing this issue.
# Corrected comparison: load pretrained timm weights and put both models in
# eval mode, then compare HF SigLIP vision features against timm.
from PIL import Image
import requests
from transformers import AutoProcessor, AutoModel
import torch, timm

model_hf = AutoModel.from_pretrained("google/siglip-so400m-patch14-384").vision_model.eval()
processor = AutoProcessor.from_pretrained("google/siglip-so400m-patch14-384")
# pretrained=True is the fix: without it timm returns a randomly initialized model.
model_timm = timm.create_model("vit_so400m_patch14_siglip_384", pretrained=True).eval()
# Alternative preprocessing via timm's own transform pipeline (kept for reference):
# data_config = timm.data.resolve_model_data_config(model)
# transforms = timm.data.create_transform(**data_config, is_training=False)

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

with torch.no_grad():
    image_p = processor(images=[image], return_tensors='pt')['pixel_values']
    out_hf = model_hf(pixel_values=image_p).pooler_output
    out_timm = model_timm(image_p)

print('out_hf', out_hf)  # tensor([[ 0.0042, -0.3039, -0.2630, ..., -0.0633, 0.1298, 0.2565]])
print('out_timm', out_timm)  # tensor([[-0.0024, -0.3045, -0.2662, ..., -0.0611, 0.1315, 0.2575]])
# Mean absolute error — small residual expected from differing resize implementations.
print('MAE', torch.abs(out_hf - out_timm).mean())  # tensor(0.0026)