RuntimeError: number of heads in query/key/value should match
Hello! I'm very new to running models locally. Here are the steps I followed:
Hardware: 2023 MacBook Pro with M3 Pro chip
- set up a Python environment with the Hugging Face CLI and an access token
- installed these requirements in the venv: transformers, langchain, langchain-huggingface, torch, torchvision, torchaudio, pillow, timm
- main.py in the venv loads the model with pipeline from the transformers library (I just used the code from "Use this model > Transformers" on the model page)
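For reference, a main.py along the lines described above might look like the sketch below. It is an assumption, not the poster's actual file: the "image-text-to-text" task name, the chat-message layout, and the image path and prompt are placeholders modeled on the usual Transformers pipeline pattern. The model is gated, so the Hugging Face token must already be configured.

```python
# main.py (sketch): load google/gemma-3n-E4B-it through a transformers pipeline.
# Assumptions: the "image-text-to-text" task, the message layout, and the
# image/prompt passed to run() are placeholders, not the original poster's code.
from typing import List, Optional

MODEL_ID = "google/gemma-3n-E4B-it"


def build_messages(image: str, prompt: str) -> List[dict]:
    """Return a single user turn containing one image part and one text part."""
    return [
        {
            "role": "user",
            "content": [
                {"type": "image", "url": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]


def run(image: str, prompt: str, device: Optional[str] = None):
    """Build the pipeline (downloads the model on first use) and generate a reply."""
    # Imported here so build_messages() can be used without transformers installed.
    from transformers import pipeline

    # device=None leaves the choice to transformers, which may pick mps on Apple Silicon.
    pipe = pipeline("image-text-to-text", model=MODEL_ID, device=device)
    return pipe(text=build_messages(image, prompt), max_new_tokens=64)
```

Calling something like `print(run("candy.jpg", "What animal is on the candy?"))` would then download the model and print the pipeline output; "candy.jpg" is a hypothetical local file.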
When I run main.py, google/gemma-3n-E4B-it seems to download to my system fine, but then I get:
"RuntimeError: number of heads in query/key/value should match".
I'm still learning how to build with models, but from what I understand, is this caused by a mismatch between the model and my setup? Any tips?
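When asking whether the model and the setup are mismatched, it helps to post the exact versions involved. The sketch below is a hypothetical helper (not part of any library) that collects them with the standard library, plus torch if it is installed; the package list is an assumption based on the requirements above.

```python
# Collect the versions that matter when reporting this error.
# environment_report() is a hypothetical helper; torch is optional here.
import importlib.metadata
import platform


def environment_report() -> dict:
    """Return Python/platform info, package versions, and MPS status."""
    report = {"python": platform.python_version(), "machine": platform.machine()}
    for pkg in ("torch", "transformers", "timm"):
        try:
            report[pkg] = importlib.metadata.version(pkg)
        except importlib.metadata.PackageNotFoundError:
            report[pkg] = None  # package not installed in this environment
    try:
        import torch
    except ImportError:
        report["mps_built"] = report["mps_available"] = None
    else:
        # is_built(): torch was compiled with MPS; is_available(): MPS is usable now.
        report["mps_built"] = torch.backends.mps.is_built()
        report["mps_available"] = torch.backends.mps.is_available()
    return report


for key, value in environment_report().items():
    print(f"{key}: {value}")
```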
Greetings,
I came here to report this same problem. After some testing, it is specific to the mps device type: if I specify cpu, the same code works fine, albeit very slowly, and if I use cuda on an appropriate system, it also works. For some reason, this model (and potentially timm, which is where the exception seems to come from) does not work on the mps device type.
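Given that the failure is specific to mps, one stopgap is to choose the device explicitly and skip mps, accepting the much slower CPU path. The sketch below is an assumption about how to do that, not code from the thread; `pick_device` is a hypothetical helper.

```python
# Choose a device for the pipeline while steering around the mps failure.
# pick_device() is a hypothetical helper, not part of torch or transformers.


def pick_device(avoid_mps: bool = True) -> str:
    """Return "cuda" if available, else "mps" (only if allowed), else "cpu"."""
    try:
        import torch
    except ImportError:
        return "cpu"  # without torch there is nothing to probe
    if torch.cuda.is_available():
        return "cuda"
    if not avoid_mps and torch.backends.mps.is_available():
        return "mps"
    return "cpu"


print(pick_device())
```

The result can be passed as the `device` argument, e.g. `pipeline("image-text-to-text", model="google/gemma-3n-E4B-it", device=pick_device())`. On a Mac without CUDA this means running on the CPU.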
I'm happy to file a bug wherever the problem turns out to be (if I can figure that out), but this is definitely an issue with running this model on Mac hardware. I'm not saying the issue is unique to this model, but it's the first place I've run into it.
-- Morgan
Greetings,
After further investigation, this is a problem in the mps code inside PyTorch. It is not present in the preview (nightly) builds, which can be installed with:
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
A future PyTorch release should therefore no longer have this problem. (Despite the 'cpu' in the URL above, this installs a build that includes the MPS kernels.)
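To confirm that the venv actually picked up a preview build after running the install command, one can inspect the installed version string. PyTorch nightly wheels carry a PEP 440 `.dev` segment in their version. The helpers below are a hypothetical sketch, and the example version strings are illustrative rather than taken from the thread; which stable release will include the fix is not stated here.

```python
# Check whether the installed torch is a nightly (pre-release) build.
# Assumption: nightly wheels use versions like "2.8.0.dev20250601".
import importlib.metadata
import re
from typing import Optional


def is_dev_build(version: str) -> bool:
    """True if the version string contains a PEP 440 ".devN" segment."""
    return re.search(r"\.dev\d+", version) is not None


def installed_torch_version() -> Optional[str]:
    """Read torch's version from package metadata without importing torch."""
    try:
        return importlib.metadata.version("torch")
    except importlib.metadata.PackageNotFoundError:
        return None


version = installed_torch_version()
if version is None:
    print("torch is not installed in this environment")
else:
    kind = "nightly/pre-release" if is_dev_build(version) else "release"
    print(f"torch {version}: {kind} build")
```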
Hope that helps others who run into these problems!
-- Morgan
Speed difference between mps and cpu:
(multimodal) mrs@nightfox multimodal % time python multimodal.py
Loading checkpoint shards: 100%|███████████████| 4/4 [00:00<00:00, 31.77it/s]
Device set to use mps:0
The animal on the candy is a **frog**.
python multimodal.py 3.47s user 2.90s system 84% cpu 7.540 total
...
(multimodal) mrs@nightfox multimodal % time python multimodal.py
Loading checkpoint shards: 100%|███████████████| 4/4 [00:00<00:00, 30.72it/s]
Device set to use cpu
The animal on the candy is a **frog**. You can see the distinctive shape of a frog's head and eyes on the green and teal candies.
python multimodal.py 282.30s user 12.24s system 137% cpu 3:33.83 total
Notably, the animal on the candy is not a frog; it's a turtle. But accuracy isn't the issue here. 🤣
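To put the two runs side by side, the "total" wall-clock figures from zsh's `time` output can be converted to seconds and divided. This is only a rough end-to-end comparison. Each total covers the whole script, including imports and model loading, and the CPU run also generated a longer answer. `parse_total` is a hypothetical helper written for these two figures.

```python
# Compare the wall-clock "total" figures from the two `time` runs above.


def parse_total(total: str) -> float:
    """Convert a zsh time total such as "7.540" or "3:33.83" to seconds."""
    seconds = 0.0
    for part in total.split(":"):
        # Each field left of a ":" counts 60x the field to its right.
        seconds = seconds * 60 + float(part)
    return seconds


mps_seconds = parse_total("7.540")    # run with "Device set to use mps:0"
cpu_seconds = parse_total("3:33.83")  # run with "Device set to use cpu"
print(f"mps {mps_seconds:.2f}s, cpu {cpu_seconds:.2f}s, "
      f"cpu/mps ≈ {cpu_seconds / mps_seconds:.1f}x")
# prints: mps 7.54s, cpu 213.83s, cpu/mps ≈ 28.4x
```

For these two runs the CPU path took roughly 28 times longer end to end.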