Commit 8855d03 (verified) by leejunhyeok · Parent(s): 38eae03

Update modeling_motif.py

Files changed (1):
  1. modeling_motif.py +3 -25
modeling_motif.py CHANGED

@@ -545,9 +545,9 @@ class MotifFlashAttention2(MotifAttention):
 
         bsz = query_states.shape[0]
 
-        return _flash_attention_forward(query_states,
-                                        key_states,
-                                        value_states,
+        return _flash_attention_forward(query_states.bfloat16(),
+                                        key_states.bfloat16(),
+                                        value_states.bfloat16(),
                                         attention_mask,
                                         q_len,
                                         position_ids=position_ids,

@@ -604,28 +604,6 @@ class MotifFlashAttention2(MotifAttention):
         value_states = repeat_kv(value_states, self.num_key_value_groups)
         dropout_rate = 0.0 if not self.training else self.attention_dropout
 
-        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
-        # therefore the input hidden states gets silently casted in float32. Hence, we need
-        # cast them back in float16 just to be sure everything works as expected.
-        input_dtype = query_states.dtype
-        if input_dtype == torch.float32:
-            if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
-            # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
-            else:
-                target_dtype = self.q_proj.weight.dtype
-
-            logger.warning_once(
-                f"The input hidden states seems to be silently casted in float32, this might be related to"
-                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
-                f" {target_dtype}.")
-
-            query_states = query_states.to(target_dtype)
-            key_states = key_states.to(target_dtype)
-            value_states = value_states.to(target_dtype)
-
         q_len = query_states.shape[-2]
         kv_seq_len = key_states.shape[-2]
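
Net effect of the change: instead of conditionally re-casting float32 activations back to an inferred target dtype, the query/key/value states are now cast to bfloat16 unconditionally right before the flash-attention call, since flash-attention kernels accept only fp16/bf16 inputs. Below is a minimal, self-contained sketch of the same pattern; it uses PyTorch's scaled_dot_product_attention as a stand-in for the repo's _flash_attention_forward helper, and the wrapper name attention_in_bf16 is hypothetical, not part of the modified file.

# Illustrative only: not the repo's _flash_attention_forward, just the
# "cast to bfloat16 before the attention kernel" pattern this commit adopts.
import torch
import torch.nn.functional as F

def attention_in_bf16(query, key, value, attention_mask=None, dropout_p=0.0):
    # Unconditionally downcast q/k/v, mirroring query_states.bfloat16() above;
    # fp32 activations (e.g. from layer norms kept in fp32 for PEFT stability)
    # would otherwise be rejected by kernels that require fp16/bf16.
    out = F.scaled_dot_product_attention(query.bfloat16(),
                                         key.bfloat16(),
                                         value.bfloat16(),
                                         attn_mask=attention_mask,  # assumed boolean mask or None
                                         dropout_p=dropout_p)
    # Hand the result back in the caller's dtype so downstream layers are unaffected.
    return out.to(query.dtype)

# Example with float32 inputs of shape (batch, heads, seq_len, head_dim)
q = torch.randn(1, 4, 16, 64)
k = torch.randn(1, 4, 16, 64)
v = torch.randn(1, 4, 16, 64)
print(attention_in_bf16(q, k, v).dtype)  # torch.float32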