leejunhyeok committed
Commit 097873e · verified · 1 Parent(s): 8bdf2ec

Update modeling_motif.py

Files changed (1):
  1. modeling_motif.py +2 -2
modeling_motif.py CHANGED
@@ -493,7 +493,7 @@ class MotifFlashAttention2(MotifAttention):
     def _compute_attention(self, query_states, key_states, value_states, attention_mask, q_len, position_ids,
                            dropout_rate, sliding_window):
         """Flash Attention 2 implements"""
-
+        _input_type = query_states.dtype
         scale_factor = 1.0 / math.sqrt(self.head_dim)
         if not self._flash_attn_uses_top_left_mask:
             causal = self.is_causal
@@ -511,7 +511,7 @@ class MotifFlashAttention2(MotifAttention):
             is_causal=True,
             softmax_scale=scale_factor,
             use_top_left_mask=self._flash_attn_uses_top_left_mask)
-        return attn_out.float()
+        return attn_out.to(_input_type)
 
     def forward(
         self,
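
The two-line change replaces the unconditional upcast to float32 with a cast back to the dtype of the incoming query states, so a model running in bfloat16 or float16 keeps its dtype through the attention path. A minimal sketch of the same pattern, assuming a generic PyTorch attention call in place of the repository's flash-attention helper (compute_attention and its arguments here are illustrative, not the actual Motif code):

import math
import torch
import torch.nn.functional as F

def compute_attention(query_states, key_states, value_states, head_dim):
    # Capture the caller's dtype (e.g. torch.bfloat16) before running the kernel.
    _input_type = query_states.dtype

    scale_factor = 1.0 / math.sqrt(head_dim)

    # Stand-in for the flash-attention call; scaled_dot_product_attention is used
    # here only to illustrate the dtype round-trip, not the actual kernel.
    attn_out = F.scaled_dot_product_attention(
        query_states, key_states, value_states, is_causal=True, scale=scale_factor)

    # Cast back to the input dtype instead of forcing float32, mirroring the
    # change from `attn_out.float()` to `attn_out.to(_input_type)`.
    return attn_out.to(_input_type)

With query/key/value tensors of shape (batch, heads, seq_len, head_dim) in bfloat16, the returned tensor stays in bfloat16 rather than being promoted to float32.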