Bug in attention mask

#58
by lucasjin - opened

The perceiver attention mask behaves differently between flash attention and eager attention.

Normally, the attention mask passed into the attention module is prepared like this:

if attention_mask is not None:
    latent_attention_mask = torch.ones(
        (attention_mask.size(0), latents.size(1)),
        dtype=attention_mask.dtype,
        device=attention_mask.device,
    )
    attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1)
    attention_mask = (
        _prepare_4d_attention_mask(
            attention_mask, latents.dtype, tgt_len=self.n_latents
        )
        if not self._use_flash_attention_2
        else attention_mask
    )
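
As a toy illustration of what that concatenation produces (made-up shapes and values, not code from the model):

import torch

attention_mask = torch.tensor([[1, 1, 0, 0]])  # (batch=1, seq_len=4); last two tokens are padding
latents = torch.zeros(1, 2, 8)                 # (batch=1, n_latents=2, hidden=8); values don't matter here
latent_attention_mask = torch.ones(
    (attention_mask.size(0), latents.size(1)), dtype=attention_mask.dtype
)
combined = torch.cat([attention_mask, latent_attention_mask], dim=-1)
print(combined)  # tensor([[1, 1, 0, 0, 1, 1]]) -- the latent positions are always marked as attendable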

Here 0 should represent padding and 1 should represent normal values. That convention is right when using flash attention, but _prepare_4d_attention_mask actually inverts it, making the attention mask use 0 for normal values and 1 for padding values.
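
For context, here is a rough sketch of what a _prepare_4d_attention_mask-style conversion does; this is my own simplified reimplementation for illustration, not the actual transformers helper:

import torch

def expand_mask_sketch(mask_2d, dtype, tgt_len):
    # mask_2d: (batch, src_len) with 1 = real token, 0 = padding
    bsz, src_len = mask_2d.shape
    expanded = mask_2d[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted = 1.0 - expanded
    # attended positions stay 0.0; padded positions are filled with the dtype's most negative value
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)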

Then in attention forward:

if attention_mask is not None:
    if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
        raise ValueError(
            f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
        )

    attn_weights = attn_weights + attention_mask

This actually makes the scores higher for padding values after softmax.

Have you guys tested the eager behavior?

Can you confirm whether this is a bug, or am I wrong? Thanks!

In eager mode, the attention mask values I get look like this:

[Image: Screenshot 2024-05-24 at 4.54.16 PM.png]

0 for "True" (i.e. attending to that position) and -inf for "False" (i.e. not attending to that position)
