Gemma-3-27B-it not loading on multi-GPU setup
BLUF: Gemma3ForCausalLM doesn't map the model across multiple GPUs
I'm trying to run Gemma-3-27B-it on a Linux cluster with 4 A10 GPUs. I don't have access to an A100 cluster and need this to work on a 4-GPU setup.
My code is a cut-and-paste from the Gemma blog, section "For LLM-only model inference, ...", except I replace the checkpoint with 27B:
import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM
ckpt = "google/gemma-3-27b-it" # <-- this is "google/gemma-3-4b-it" in the blog post
model = Gemma3ForCausalLM.from_pretrained(
    ckpt, torch_dtype=torch.bfloat16, device_map="auto"
)
The model gets as far as loading the first 4 checkpoint shards, then throws an OOM error because it keeps trying to use GPU 0:
OutOfMemoryError: CUDA out of memory. Tried to allocate 222.00 MiB. GPU 0 has a total capacity of 21.99 GiB of which 209.38 MiB is free.
...
File /local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.12/site-packages/transformers/integrations/tensor_parallel.py:312, in get_tensor_shard(param, empty_param, device_mesh, rank, dim)
309 slice_indices = [slice(None)] * param_dim
310 slice_indices[dim] = slice(start, end)
--> 312 return param[tuple(slice_indices)]
Why doesn't the model map over all 4 GPUs? Isn't that what the device_map="auto" flag is for?
transformers version 4.52.4, torch version 2.6.0+cu124
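To see why a single card overflows while four should not, here is a rough, weights-only estimate. It assumes a nominal 27e9 parameters (taken from the model name), uses the 21.99 GiB per-GPU capacity reported in the OOM message, and ignores activations, the KV cache and the CUDA context, so real usage will be somewhat higher. Treat it as a sketch, not a measurement.

```python
# Weights-only footprint estimate. Assumptions: nominal 27e9 parameters,
# no activations / KV cache / CUDA context counted.
GIB = 1024 ** 3

def weight_gib(num_params: float, bits_per_param: int) -> float:
    """Approximate size of the model weights in GiB at a given precision."""
    return num_params * bits_per_param / 8 / GIB

PARAMS = 27e9
PER_GPU_GIB = 21.99  # total capacity of GPU 0 as reported in the OOM message
NUM_GPUS = 4

for name, bits in [("bfloat16", 16), ("int8", 8), ("4-bit", 4)]:
    size = weight_gib(PARAMS, bits)
    print(
        f"{name:>8}: {size:5.1f} GiB | fits on one GPU: {size < PER_GPU_GIB} "
        f"| fits across {NUM_GPUS}: {size < NUM_GPUS * PER_GPU_GIB}"
    )
```

By this estimate the bfloat16 weights come to about 50.3 GiB: far more than one A10 can hold, but well under the roughly 88 GiB of four cards combined, so a working multi-GPU split should have room. The quantized figures are only indicative, since not every module gets quantized.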
Hi @glawyer,
The device_map="auto" functionality relies heavily on the accelerate library. If accelerate is not installed or not properly configured, device_map="auto" might fall back to loading the entire model onto the first available GPU (GPU 0), which leads to an OOM error.
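A quick way to check this is to list which of the relevant packages are installed in the same Python environment (notebook kernel or cluster library path) that runs the loading code. The sketch below uses only the standard library's importlib.metadata; the package list is the one from this thread, with bitsandbytes included only because the quantization step needs it.

```python
from importlib import metadata

# Packages discussed in this thread; bitsandbytes matters only if you quantize.
PACKAGES = ("accelerate", "transformers", "torch", "bitsandbytes")

def installed_versions(packages=PACKAGES) -> dict:
    """Map each package name to its installed version, or None if it is missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions

if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'NOT INSTALLED'}")
```

If accelerate shows up as NOT INSTALLED here, install it into that same environment and restart the kernel before retrying.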
Kindly try these steps:
1. Ensure accelerate is installed and configured. If you haven't already, install it:
pip install accelerate -U
Then run the accelerate configuration wizard to set up your environment:
accelerate config
Follow the prompts. For a multi-GPU setup, you'll typically select multi-GPU and make sure it detects all your devices.
2. Use quantization (highly recommended for A10 GPUs). For a 27B model on A10s (24 GB VRAM each), memory is tight even in bfloat16: the weights alone are about 54 GB, so they only fit if they are split across all four cards. Quantization significantly reduces the memory footprint, and the bitsandbytes library provides 4-bit and 8-bit quantization options.
3. If quantization isn't desired or isn't sufficient, try specifying max_memory together with device_map="auto" to cap how much each GPU may be assigned.
4. Ensure all your libraries (transformers, torch, accelerate) are up to date.
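For the max_memory step, from_pretrained takes a dict mapping GPU indices (and optionally "cpu") to size strings such as "18GiB". The helper below is a hypothetical convenience for building that dict; the 18 GiB cap and the extra 2 GiB reserve on GPU 0 are illustrative assumptions meant to leave headroom on ~22 GiB cards, not tuned values.

```python
from typing import Dict, Optional, Union

def build_max_memory(num_gpus: int, per_gpu_gib: int, gpu0_reserve_gib: int = 0,
                     cpu_gib: Optional[int] = None) -> Dict[Union[int, str], str]:
    """Hypothetical helper: per-device caps in the format max_memory accepts.

    gpu0_reserve_gib optionally leaves GPU 0 extra headroom (an assumption, tune as needed).
    """
    if num_gpus < 1:
        raise ValueError("num_gpus must be >= 1")
    if per_gpu_gib - gpu0_reserve_gib <= 0:
        raise ValueError("gpu0_reserve_gib must be smaller than per_gpu_gib")
    caps: Dict[Union[int, str], str] = {
        i: f"{per_gpu_gib - (gpu0_reserve_gib if i == 0 else 0)}GiB"
        for i in range(num_gpus)
    }
    if cpu_gib is not None:
        caps["cpu"] = f"{cpu_gib}GiB"  # allow overflow into CPU RAM
    return caps

# Illustrative caps for 4 x A10: {0: '16GiB', 1: '18GiB', 2: '18GiB', 3: '18GiB'}
max_memory = build_max_memory(num_gpus=4, per_gpu_gib=18, gpu0_reserve_gib=2)

# Usage (not run here; needs the GPUs and the checkpoint):
# model = Gemma3ForCausalLM.from_pretrained(
#     "google/gemma-3-27b-it",
#     torch_dtype=torch.bfloat16,
#     device_map="auto",
#     max_memory=max_memory,
#     # For step 2 instead (or as well), something like:
#     # quantization_config=BitsAndBytesConfig(load_in_4bit=True,
#     #                                        bnb_4bit_compute_dtype=torch.bfloat16),
# )
```

The four caps above total 70 GiB, which leaves room for the roughly 50 GiB of bfloat16 weights while keeping each card below its physical limit, so accelerate has to place the layers across devices within those budgets.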
Start with steps 1 and 2 (accelerate setup and quantization), as they are the most likely to resolve the OOM issue on your A10 setup. Please try them and let us know if you are still facing any issues; we'll be glad to assist. Thank you.