Warning "It is strongly recommended to train Gemma2 models with the `eager` attention implementation instead of `sdpa`"

#7
by jesusgs01 - opened

Hello,

I am trying to fine-tune PaliGemma 2. I have the following code:

import torch
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

model_id = "google/paligemma2-3b-pt-224"

def load_model_and_processor():
    # Load the model in bfloat16 and let device_map="auto" place it on the available devices
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        device_map="auto",
    )

    # Left padding and pad_token = eos_token for batched training
    processor = AutoProcessor.from_pretrained(model_id)
    processor.tokenizer.padding_side = "left"
    processor.tokenizer.pad_token = processor.tokenizer.eos_token

    return model, processor

When I start training, I get the following warning:

It is strongly recommended to train Gemma2 models with the `eager` attention implementation instead of `sdpa`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`.

I assume this might impact the model's performance and convergence. Does anyone know how to properly change this setting?

Thanks in advance!

Google org

Hi @jesusgs01 ,

Using eager attention can help avoid NaN issues during training, but it may cause the model to converge faster, which sometimes leads to overfitting. On the other hand, flash_attention_2 (or sdpa) may train more slowly but tends to generalize better.

To switch between these modes, set the attn_implementation parameter when loading the model with PaliGemmaForConditionalGeneration.from_pretrained(), as shown in the sketch after the list below:

  • Use attn_implementation="eager" for eager attention
  • Use attn_implementation="sdpa" to enable SDPA (scaled dot product attention)
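For example, here is a minimal sketch of loading the model with eager attention, reusing the model_id and dtype from the snippet above (swap "eager" for "sdpa" to return to the default):

import torch
from transformers import PaliGemmaForConditionalGeneration

model = PaliGemmaForConditionalGeneration.from_pretrained(
    "google/paligemma2-3b-pt-224",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",  # or "sdpa" / "flash_attention_2"
    device_map="auto",
)

With attn_implementation="eager", the warning you quoted should no longer appear during training.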

For more details, please check this reference.

Thank you.
