Update modeling_motif.py
remove is_moreh_attention, batch_num
- modeling_motif.py +8 -45
modeling_motif.py CHANGED
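The hunks below drop the Moreh-specific kernels (moreh_ops.flash_attention_varlen_dp and MorehFlashAttention) together with the batch_num bookkeeping, so _compute_attention always takes the stock _flash_attention_forward path. As a rough standalone sketch of the surviving call shape, assuming the helper exported by transformers.modeling_flash_attention_utils; the wrapper function and parameter plumbing below are illustrative, not the module's actual method:

# Hedged sketch of the retained flash-attention path; not the class method itself.
import math
import torch
from transformers.modeling_flash_attention_utils import _flash_attention_forward

def compute_attention_sketch(query_states, key_states, value_states, attention_mask,
                             q_len, dropout_rate, head_dim, use_top_left_mask=False):
    # Inputs are assumed to be laid out as (batch, seq_len, num_heads, head_dim).
    scale_factor = 1.0 / math.sqrt(head_dim)   # same scaling as in the diff
    attn_out = _flash_attention_forward(
        query_states.bfloat16(),               # FA2 kernels run in bf16/fp16
        key_states.bfloat16(),
        value_states.bfloat16(),
        attention_mask,
        q_len,                                 # query_length
        is_causal=True,
        dropout=dropout_rate,
        softmax_scale=scale_factor,
        use_top_left_mask=use_top_left_mask,
    )
    return attn_out.float()                    # cast back to fp32, as in the diff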
@@ -396,11 +396,7 @@ class MotifAttention(nn.Module):
         self.rope_theta = config.rope_theta
         self.is_causal = True
         self.attention_dropout = config.attention_dropout
-
-            self.batch_num = config.batch_num
-            logger.info(f'self.batcn_num : {self.batch_num}')
-        except:
-            self.batch_num = None
+
         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                              f" and `num_heads`: {self.num_heads}).")
@@ -556,7 +552,7 @@ class MotifFlashAttention2(MotifAttention):
         return tensor.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
 
     def _compute_attention(self, query_states, key_states, value_states, attention_mask, q_len, position_ids,
-                           dropout_rate, sliding_window
+                           dropout_rate, sliding_window):
         """Flash Attention 2 implements"""
 
         scale_factor = 1.0 / math.sqrt(self.head_dim)
@@ -566,37 +562,7 @@ class MotifFlashAttention2(MotifAttention):
         else:
             causal = self.is_causal and q_len != 1
 
-
-        bsz = query_states.shape[0]
-
-        if batch_num:
-            query_states = query_states.reshape(bsz*q_len,self.num_heads,self.head_dim)
-            key_states = key_states.reshape(bsz*q_len,self.num_heads,self.head_dim)
-            value_states = value_states.reshape(bsz*q_len,self.num_heads,self.head_dim)
-
-            attn_out = moreh_ops.flash_attention_varlen_dp(query_states,
-                                                           key_states,
-                                                           value_states,
-                                                           attention_mask,
-                                                           attention_mask,
-                                                           max_seqlen_q=q_len,
-                                                           max_seqlen_kv=q_len,
-                                                           dropout_p=dropout_rate,
-                                                           softmax_scale=scale_factor,
-                                                           is_causal=causal,
-                                                           batch_num=batch_num)
-            attn_out = attn_out.reshape(bsz, q_len, self.num_heads, -1)
-        else:
-            return MorehFlashAttention(query_states,
-                                       key_states,
-                                       value_states,
-                                       padding_mask=attention_mask,
-                                       dropout_p=dropout_rate,
-                                       softmax_scale=scale_factor,
-                                       causal=causal)
-            return attn_out
-        else:
-            attn_out = _flash_attention_forward(query_states.bfloat16(),
+        attn_out = _flash_attention_forward(query_states.bfloat16(),
                                             key_states.bfloat16(),
                                             value_states.bfloat16(),
                                             attention_mask,
@@ -607,8 +573,7 @@ class MotifFlashAttention2(MotifAttention):
                                             is_causal=True,
                                             softmax_scale=scale_factor,
                                             use_top_left_mask=self._flash_attn_uses_top_left_mask)
-
-            return attn_out.float()
+        return attn_out.float()
 
     def forward(
         self,
@@ -709,12 +674,10 @@ class MotifFlashAttention2(MotifAttention):
         k1, k2 = k1.contiguous(), k2.contiguous()
         v1, v2 = v1.contiguous(), v2.contiguous()
 
-
-
-
-        self._compute_attention(
-        attn21, attn22 = self._compute_attention(q2, k2, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window, is_moreh_attention, self.batch_num), \
-                         self._compute_attention(q2, k2, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window, is_moreh_attention, self.batch_num)
+        attn11, attn12 = self._compute_attention(q1, k1, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window), \
+                         self._compute_attention(q1, k1, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window)
+        attn21, attn22 = self._compute_attention(q2, k2, v1, attention_mask, q_len, position_ids, dropout_rate, sliding_window), \
+                         self._compute_attention(q2, k2, v2, attention_mask, q_len, position_ids, dropout_rate, sliding_window)
 
         attn1, attn2 = torch.cat([attn11, attn12], dim=-1), torch.cat([attn21, attn22], dim=-1)
 
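For context on the last hunk: forward splits the heads into two groups (q1/q2, k1/k2, v1/v2), computes the four cross pairings, and concatenates the results along the last dimension. Below is a toy sketch of that combination pattern, using torch.nn.functional.scaled_dot_product_attention as a stand-in for self._compute_attention; the helper name and the (batch, heads, seq, head_dim) layout are assumptions, not the repository's actual code:

# Toy illustration of the attn11/attn12/attn21/attn22 pairing from the last hunk.
# SDPA stands in for MotifFlashAttention2._compute_attention; shapes are illustrative.
import torch
import torch.nn.functional as F

def paired_attention_sketch(q1, q2, k1, k2, v1, v2, causal=True):
    attn = lambda q, k, v: F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    attn11, attn12 = attn(q1, k1, v1), attn(q1, k1, v2)
    attn21, attn22 = attn(q2, k2, v1), attn(q2, k2, v2)
    # Mirror the torch.cat calls in the diff: glue the two value halves back together.
    attn1 = torch.cat([attn11, attn12], dim=-1)
    attn2 = torch.cat([attn21, attn22], dim=-1)
    return attn1, attn2

# Example (illustrative sizes): batch=2, heads=8, seq=16, half head_dim=32.
b, h, s, d = 2, 8, 16, 32
tensors = [torch.randn(b, h, s, d) for _ in range(6)]
a1, a2 = paired_attention_sketch(*tensors)
print(a1.shape, a2.shape)  # torch.Size([2, 8, 16, 64]) twice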