eunhwanpark-motiftech committed (verified)
Commit 27913c0
1 Parent(s): eb7a67b

Update modeling_motif.py


remove is_moreh_attention, batch_num

Files changed (1)
  1. modeling_motif.py +8 -45
modeling_motif.py CHANGED
@@ -396,11 +396,7 @@ class MotifAttention(nn.Module):
         self.rope_theta = config.rope_theta
         self.is_causal = True
         self.attention_dropout = config.attention_dropout
-        try:
-            self.batch_num = config.batch_num
-            logger.info(f'self.batcn_num : {self.batch_num}')
-        except:
-            self.batch_num = None
+
         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                              f" and `num_heads`: {self.num_heads}).")
@@ -556,7 +552,7 @@ class MotifFlashAttention2(MotifAttention):
         return tensor.reshape(batch_size, seq_len, self.num_heads, self.head_dim)

     def _compute_attention(self, query_states, key_states, value_states, attention_mask, q_len, position_ids,
-                           dropout_rate, sliding_window, is_moreh_attention, batch_num):
+                           dropout_rate, sliding_window):
         """Flash Attention 2 implements"""

         scale_factor = 1.0 / math.sqrt(self.head_dim)
@@ -566,37 +562,7 @@ class MotifFlashAttention2(MotifAttention):
         else:
             causal = self.is_causal and q_len != 1

-        if is_moreh_attention:
-            bsz = query_states.shape[0]
-
-            if batch_num:
-                query_states = query_states.reshape(bsz*q_len,self.num_heads,self.head_dim)
-                key_states = key_states.reshape(bsz*q_len,self.num_heads,self.head_dim)
-                value_states = value_states.reshape(bsz*q_len,self.num_heads,self.head_dim)
-
-                attn_out = moreh_ops.flash_attention_varlen_dp(query_states,
-                                                                key_states,
-                                                                value_states,
-                                                                attention_mask,
-                                                                attention_mask,
-                                                                max_seqlen_q=q_len,
-                                                                max_seqlen_kv=q_len,
-                                                                dropout_p=dropout_rate,
-                                                                softmax_scale=scale_factor,
-                                                                is_causal=causal,
-                                                                batch_num=batch_num)
-                attn_out = attn_out.reshape(bsz, q_len, self.num_heads, -1)
-            else:
-                return MorehFlashAttention(query_states,
-                                           key_states,
-                                           value_states,
-                                           padding_mask=attention_mask,
-                                           dropout_p=dropout_rate,
-                                           softmax_scale=scale_factor,
-                                           causal=causal)
-            return attn_out
-        else:
-            attn_out = _flash_attention_forward(query_states.bfloat16(),
+        attn_out = _flash_attention_forward(query_states.bfloat16(),
                                             key_states.bfloat16(),
                                             value_states.bfloat16(),
                                             attention_mask,
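For reference, the removed Moreh branch packed the batch and sequence dimensions into one before calling its varlen kernel, then unpacked the result. A shape-only sketch of that round trip (illustrative; the Moreh kernel itself is not invoked here, and the sizes are made up):

```python
# Illustrative only: the (bsz, q_len, heads, head_dim) <-> (bsz * q_len, heads, head_dim)
# round trip performed by the removed Moreh varlen branch.
import torch

bsz, q_len, num_heads, head_dim = 2, 4, 8, 64  # arbitrary example sizes
query_states = torch.randn(bsz, q_len, num_heads, head_dim)

# Pack batch and sequence into one dimension, as done before the varlen call.
packed = query_states.reshape(bsz * q_len, num_heads, head_dim)

# Unpack the kernel output back to (bsz, q_len, heads, head_dim), as in the removed code.
unpacked = packed.reshape(bsz, q_len, num_heads, -1)
assert unpacked.shape == query_states.shape
```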
@@ -607,8 +573,7 @@ class MotifFlashAttention2(MotifAttention):
                                             is_causal=True,
                                             softmax_scale=scale_factor,
                                             use_top_left_mask=self._flash_attn_uses_top_left_mask)
-            #logger.info(attn_out)
-            return attn_out.float()
+        return attn_out.float()

     def forward(
         self,
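After these two hunks, `_compute_attention` has a single code path: cast to bfloat16, run the flash-attention kernel, return float32. A minimal, self-contained sketch of that pattern, using PyTorch's built-in scaled_dot_product_attention as a stand-in for `_flash_attention_forward` (the function name and shapes below are assumptions, not the model's actual helper):

```python
# Minimal sketch (assumes PyTorch >= 2.1 for the `scale` argument).
# SDPA stands in for _flash_attention_forward; tensors are (batch, heads, seq_len, head_dim).
import math
import torch
import torch.nn.functional as F


def compute_attention_sketch(q, k, v, causal: bool, dropout_p: float = 0.0):
    scale_factor = 1.0 / math.sqrt(q.shape[-1])   # same scaling as in the diff
    out = F.scaled_dot_product_attention(
        q.bfloat16(), k.bfloat16(), v.bfloat16(),  # kernel runs in bfloat16
        dropout_p=dropout_p, is_causal=causal, scale=scale_factor,
    )
    return out.float()                             # caller gets float32 back
```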
@@ -709,12 +674,10 @@ class MotifFlashAttention2(MotifAttention):
         k1, k2 = k1.contiguous(), k2.contiguous()
         v1, v2 = v1.contiguous(), v2.contiguous()

-        is_moreh_attention = MorehFlashAttention is not None
-
-        attn11, attn12 = self._compute_attention(q1, k1, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window, is_moreh_attention, self.batch_num), \
-                         self._compute_attention(q1, k1, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window, is_moreh_attention, self.batch_num)
-        attn21, attn22 = self._compute_attention(q2, k2, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window, is_moreh_attention, self.batch_num), \
-                         self._compute_attention(q2, k2, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window, is_moreh_attention, self.batch_num)
+        attn11, attn12 = self._compute_attention(q1, k1, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window), \
+                         self._compute_attention(q1, k1, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window)
+        attn21, attn22 = self._compute_attention(q2, k2, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window), \
+                         self._compute_attention(q2, k2, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window)

         attn1, attn2 = torch.cat([attn11, attn12], dim=-1), torch.cat([attn21, attn22], dim=-1)

 
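The last hunk only drops the two extra arguments; the four pairings of the split query/key/value halves are unchanged. An equivalent, hypothetical way to write that pairing as a loop (illustrative refactor, not what the commit does; `compute` stands in for `self._compute_attention` with the non-tensor arguments already bound):

```python
# Illustrative only: the four attention pairings from the diff, written as generators.
# `compute(q, k, v)` stands in for self._compute_attention with attention_mask, q_len,
# position_ids, dropout_rate and sliding_window already bound (e.g. via functools.partial).
import torch


def paired_attention(compute, q1, q2, k1, k2, v1, v2):
    attn11, attn12 = (compute(q1, k1, v) for v in (v1, v2))
    attn21, attn22 = (compute(q2, k2, v) for v in (v1, v2))
    # Same concatenation as the surrounding code in the diff.
    attn1 = torch.cat([attn11, attn12], dim=-1)
    attn2 = torch.cat([attn21, attn22], dim=-1)
    return attn1, attn2
```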