Update modeling_motif.py
modeling_motif.py  CHANGED  (+0 -27)
@@ -571,33 +571,6 @@ class MotifFlashAttention2(MotifAttention):
 
             bsz = query_states.shape[0]
 
-            if batch_num:
-                query_states = query_states.reshape(bsz*q_len, self.num_heads, self.head_dim)
-                key_states = key_states.reshape(bsz*q_len, self.num_heads, self.head_dim)
-                value_states = value_states.reshape(bsz*q_len, self.num_heads, self.head_dim)
-
-                attn_out = moreh_ops.flash_attention_varlen_dp(query_states,
-                                                               key_states,
-                                                               value_states,
-                                                               attention_mask,
-                                                               attention_mask,
-                                                               max_seqlen_q=q_len,
-                                                               max_seqlen_kv=q_len,
-                                                               dropout_p=dropout_rate,
-                                                               softmax_scale=scale_factor,
-                                                               is_causal=causal,
-                                                               batch_num=batch_num)
-                attn_out = attn_out.reshape(bsz, q_len, self.num_heads, -1)
-            else:
-                return MorehFlashAttention(query_states,
-                                           key_states,
-                                           value_states,
-                                           padding_mask=attention_mask,
-                                           dropout_p=dropout_rate,
-                                           softmax_scale=scale_factor,
-                                           causal=causal)
-            return attn_out
-        else:
             return _flash_attention_forward(query_states,
                                             key_states,
                                             value_states,
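After this commit the method no longer branches on the Moreh runtime: every call falls through to the stock Flash Attention helper. The diff truncates that call after value_states, so the sketch below fills in the remaining arguments from the signature of _flash_attention_forward in recent Transformers (transformers.modeling_flash_attention_utils); treat everything past value_states as an assumption, not the repository's code. It needs a CUDA build of flash-attn to actually run.

# Hedged sketch: argument names past value_states are assumed from the
# upstream helper, since the diff cuts off mid-call.
import torch
from transformers.modeling_flash_attention_utils import _flash_attention_forward

bsz, q_len, num_heads, head_dim = 2, 8, 4, 16
query_states = torch.randn(bsz, q_len, num_heads, head_dim,
                           dtype=torch.float16, device="cuda")
key_states = torch.randn_like(query_states)
value_states = torch.randn_like(query_states)

attn_output = _flash_attention_forward(
    query_states, key_states, value_states,
    attention_mask=None,  # no padding in this toy example
    query_length=q_len,
    is_causal=True,
)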