leejunhyeok committed
Commit 2a76ec8 · verified · 1 Parent(s): a55dcfd

Update modeling_motif.py

Files changed (1):
  1. modeling_motif.py +0 -27
modeling_motif.py CHANGED
@@ -571,33 +571,6 @@ class MotifFlashAttention2(MotifAttention):
 
             bsz = query_states.shape[0]
 
-            if batch_num:
-                query_states = query_states.reshape(bsz*q_len,self.num_heads,self.head_dim)
-                key_states = key_states.reshape(bsz*q_len,self.num_heads,self.head_dim)
-                value_states = value_states.reshape(bsz*q_len,self.num_heads,self.head_dim)
-
-                attn_out = moreh_ops.flash_attention_varlen_dp(query_states,
-                                                               key_states,
-                                                               value_states,
-                                                               attention_mask,
-                                                               attention_mask,
-                                                               max_seqlen_q=q_len,
-                                                               max_seqlen_kv=q_len,
-                                                               dropout_p=dropout_rate,
-                                                               softmax_scale=scale_factor,
-                                                               is_causal=causal,
-                                                               batch_num=batch_num)
-                attn_out = attn_out.reshape(bsz, q_len, self.num_heads, -1)
-            else:
-                return MorehFlashAttention(query_states,
-                                           key_states,
-                                           value_states,
-                                           padding_mask=attention_mask,
-                                           dropout_p=dropout_rate,
-                                           softmax_scale=scale_factor,
-                                           causal=causal)
-            return attn_out
-        else:
             return _flash_attention_forward(query_states,
                                             key_states,
                                             value_states,
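For context, the removed batch_num branch packed the (bsz, q_len, num_heads, head_dim) tensors into a (bsz*q_len, num_heads, head_dim) layout before calling the Moreh varlen kernel, then restored the batch dimension afterward. A minimal torch-only sketch of that reshape round-trip (the kernel itself, moreh_ops.flash_attention_varlen_dp, is Moreh-specific and not reproduced here; the shapes are illustrative assumptions):

    import torch

    bsz, q_len, num_heads, head_dim = 2, 16, 8, 64
    query_states = torch.randn(bsz, q_len, num_heads, head_dim)

    # Pack batch and sequence into one dimension, as the removed branch did
    # before handing the tensors to the varlen kernel.
    packed = query_states.reshape(bsz * q_len, num_heads, head_dim)

    # ... moreh_ops.flash_attention_varlen_dp ran on the packed tensors here ...

    # Restore the batched layout, matching the removed
    # attn_out.reshape(bsz, q_len, self.num_heads, -1).
    restored = packed.reshape(bsz, q_len, num_heads, head_dim)
    assert torch.equal(query_states, restored)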
 
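After this change, MotifFlashAttention2 always delegates to Transformers' _flash_attention_forward helper rather than branching into the Moreh-specific kernels. A minimal sketch of the resulting call, assuming the variable names visible in the diff (q_len, dropout_rate, scale_factor, causal) and the helper's signature in recent transformers releases; the exact keyword names vary across versions:

    from transformers.modeling_flash_attention_utils import _flash_attention_forward

    # query/key/value_states: (bsz, q_len, num_heads, head_dim)
    attn_output = _flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        q_len,                      # query_length
        is_causal=causal,
        dropout=dropout_rate,       # assumed mapping from the removed dropout_p
        softmax_scale=scale_factor,
    )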