leejunhyeok committed
Commit 20b97f1 · verified · 1 Parent(s): 8855d03

Update modeling_motif.py

Files changed (1)
  1. modeling_motif.py +13 -10
modeling_motif.py CHANGED
@@ -545,16 +545,19 @@ class MotifFlashAttention2(MotifAttention):
 
         bsz = query_states.shape[0]
 
-        return _flash_attention_forward(query_states.bfloat16(),
-                                        key_states.bfloat16(),
-                                        value_states.bfloat16(),
-                                        attention_mask,
-                                        q_len,
-                                        position_ids=position_ids,
-                                        dropout=dropout_rate,
-                                        sliding_window=sliding_window,
-                                        is_causal=self.is_causal,
-                                        use_top_left_mask=self._flash_attn_uses_top_left_mask)
+        return map(
+            lambda x: x.float32(),
+            _flash_attention_forward(query_states.bfloat16(),
+                                     key_states.bfloat16(),
+                                     value_states.bfloat16(),
+                                     attention_mask,
+                                     q_len,
+                                     position_ids=position_ids,
+                                     dropout=dropout_rate,
+                                     sliding_window=sliding_window,
+                                     is_causal=self.is_causal,
+                                     use_top_left_mask=self._flash_attn_uses_top_left_mask)
+        )
 
     def forward(
         self,
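
The change casts the query/key/value projections to bfloat16 for the FlashAttention kernel and converts the result back to float32 before returning it. Below is a minimal sketch of that dtype round-trip, assuming the _flash_attention_forward helper from transformers.modeling_flash_attention_utils and a hypothetical wrapper name; note that PyTorch tensors expose .float() / .to(torch.float32) rather than a .float32() method, and that map() yields a lazy iterator rather than a tensor, so a direct cast of the returned tensor is the usual way to express this.

import torch
from transformers.modeling_flash_attention_utils import _flash_attention_forward

def flash_attention_bf16_fp32_out(query_states, key_states, value_states,
                                  attention_mask, q_len, **flash_kwargs):
    # Hypothetical helper: run FlashAttention in bfloat16 (the kernel does not
    # accept float32 inputs), then cast the attention output back to float32
    # so the rest of the model keeps computing in its original dtype.
    attn_output = _flash_attention_forward(query_states.bfloat16(),
                                           key_states.bfloat16(),
                                           value_states.bfloat16(),
                                           attention_mask,
                                           q_len,
                                           **flash_kwargs)
    return attn_output.to(torch.float32)

In the diff above, position_ids, dropout, sliding_window, is_causal, and use_top_left_mask would be passed through flash_kwargs.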