Warning "It is strongly recommended to train Gemma2 models with the `eager` attention implementation instead of `sdpa`"
Hello,
I am trying to fine-tune PaliGemma2. I have the following code:
model_id = "google/paligemma2-3b-pt-224"
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
device_map="auto",
)
processor = AutoProcessor.from_pretrained(model_id)
processor.tokenizer.padding_side = "left"
processor.tokenizer.pad_token = processor.tokenizer.eos_token
return model, processor
When I start training, I get the following warning:
It is strongly recommended to train Gemma2 models with the `eager` attention implementation instead of `sdpa`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`.
I assume this might impact the model's performance and convergence. Does anyone know how to properly change this setting?
Thanks in advance!
Hi @jesusgs01 ,
Using `eager` attention helps avoid numerical issues (such as NaN losses) when training Gemma2-based models, which is why the warning recommends it. The `sdpa` and `flash_attention_2` kernels are generally faster, but they may not fully support Gemma2's attention logit soft-capping during training.
To switch between these modes, set the `attn_implementation` parameter when loading the model with `PaliGemmaForConditionalGeneration.from_pretrained()`:

- Use `attn_implementation="eager"` for eager attention (as shown in the sketch below)
- Use `attn_implementation="sdpa"` to enable SDPA (scaled dot product attention)
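
A minimal sketch of how this looks for your setup, reusing the same `model_id` and loading options from your snippet above:

```python
import torch
from transformers import PaliGemmaForConditionalGeneration

model_id = "google/paligemma2-3b-pt-224"

# Load PaliGemma2 with eager attention, as the warning recommends for training.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map="auto",
    attn_implementation="eager",  # switch to "sdpa" to use scaled dot product attention
)
```

With this argument set, the warning should no longer appear at the start of training.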
For more details, please check this reference.
Thank you.