Update modeling_motif.py
modeling_motif.py CHANGED (+12 -48)
@@ -328,23 +328,10 @@ class MotifMLP(nn.Module):
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]

-        if config.wesar_weights:
-            self.gate_up_proj_alpha = nn.Parameter(torch.tensor(1) *config.gate_up_proj_alpha)
-            self.down_proj_alpha = nn.Parameter(torch.tensor(1) * config.down_proj_alpha)
-        else:
-            self.gate_up_proj_alpha=1
-            self.down_proj_alpha=1
-        if config.muP:
-            self.down_proj.__do_scale_tager__ = True
-            self.gate_proj.__do_scale_tager_mu_dim_model__ = True
-            self.up_proj.__do_scale_tager_mu_dim_model__ = True
-            self.down_proj.__do_scale_tager_mu_ffn__ = True
-
-
     def forward(self, hidden_state):
-        hidden_state = hidden_state
+        hidden_state = hidden_state
         #hidden_state = self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))*
-        return self.
+        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))


 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -470,13 +457,6 @@ class MotifAttention(nn.Module):
             max_position_embeddings=self.max_position_embeddings,
             base=self.rope_theta)

-        for param in ["q_proj_alpha", "k_proj_alpha", "v_proj_alpha", "o_proj_alpha"]:
-            setattr(
-                self, param,
-                nn.Parameter(torch.tensor(getattr(config, param, 1.0), dtype=torch.float))
-                if config.wesar_weights else 1.0)
-
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -490,9 +470,9 @@ class MotifAttention(nn.Module):
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()

-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)

         ## bsz, seq, n_heads, head_dim

@@ -685,9 +665,9 @@ class MotifFlashAttention2(MotifAttention):
     ):
         bsz, q_len, _ = hidden_states.size()

-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)

         query_states = query_states.view(bsz, q_len, 2 * self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, 2 * self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -798,7 +778,7 @@ class MotifFlashAttention2(MotifAttention):
                              f" {attn_output.size()}")

         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-        attn_output = self.o_proj(attn_output)
+        attn_output = self.o_proj(attn_output)

         return attn_output, None, past_key_value

@@ -919,15 +899,6 @@ class MotifDecoderLayer(nn.Module):
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-        if config.wesar_weights and config.use_norm_alpha:
-            self.input_layernorm_alpha = nn.Parameter(torch.tensor(1).float())
-        else:
-            self.input_layernorm_alpha = 1
-
-        if config.wesar_weights and config.use_norm_alpha :
-            self.post_attention_layernorm_alpha = nn.Parameter(torch.tensor(1).float())
-        else:
-            self.post_attention_layernorm_alpha = 1

     def forward(
         self,
@@ -965,7 +936,7 @@ class MotifDecoderLayer(nn.Module):

         residual = hidden_states

-        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.input_layernorm(hidden_states)

         # Self Attention
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
@@ -982,7 +953,7 @@ class MotifDecoderLayer(nn.Module):

         # Fully Connected
         residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states

@@ -1199,14 +1170,7 @@ class MotifModel(MotifPreTrainedModel):
         self.post_init()

         self.scale_emb = 1
-
-        # Reparameterization
-        if config.wesar_weights :
-            logger.info(f'config.wesar_weights {config.wesar_weights}')
-            self.norm_alpha = nn.Parameter(torch.tensor(1).float())
-            self.scale_emb = 10
-        else:
-            self.norm_alpha = 1
+        self.norm_alpha = 1

     def get_input_embeddings(self):
         return self.embed_tokens
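For reference, the MLP hunk above leaves MotifMLP as a plain gated (SwiGLU-style) feed-forward block with no gate_up_proj_alpha / down_proj_alpha scaling. A minimal sketch of the resulting module follows; only down_proj and ACT2FN[config.hidden_act] appear in the hunk, so the gate_proj/up_proj shapes and the SiLU activation are assumptions, and this is illustrative rather than the shipped class.

```python
import torch
import torch.nn as nn


class MotifMLPSketch(nn.Module):
    """Illustrative stand-in for the simplified MotifMLP (not the shipped class)."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # Assumed: gate/up projections mirror down_proj in the opposite direction.
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        # Shown in the diff: intermediate -> hidden, no bias.
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        # Assumed activation; the real code uses ACT2FN[config.hidden_act].
        self.act_fn = nn.SiLU()

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        # Matches the new return statement: no alpha rescaling of input or output.
        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))


# Example forward pass with toy sizes: (batch, seq, hidden) -> same shape.
mlp = MotifMLPSketch(hidden_size=64, intermediate_size=256)
out = mlp(torch.randn(2, 8, 64))
```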
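The two attention hunks make the same simplification: q_proj, k_proj, v_proj, and o_proj are applied directly, without the q_proj_alpha / k_proj_alpha / v_proj_alpha / o_proj_alpha parameters that the removed loop used to register. Below is a rough sketch of the projection-and-reshape step using only what the hunks show; the value reshape and the rest of the attention math are not visible there and are omitted, and the factor of 2 in the head count is copied verbatim from the diff (it appears to split each head into two halves).

```python
import torch
import torch.nn as nn


def project_qkv_sketch(hidden_states: torch.Tensor,
                       q_proj: nn.Linear,
                       k_proj: nn.Linear,
                       v_proj: nn.Linear,
                       num_heads: int,
                       num_key_value_heads: int,
                       head_dim: int):
    """Hypothetical helper mirroring the start of MotifFlashAttention2.forward."""
    bsz, q_len, _ = hidden_states.size()

    # Plain projections; previously each output was scaled by a learnable
    # *_proj_alpha when config.wesar_weights was enabled.
    query_states = q_proj(hidden_states)
    key_states = k_proj(hidden_states)
    value_states = v_proj(hidden_states)

    # Reshapes as in the hunk:
    # (bsz, q_len, 2 * heads, head_dim) -> (bsz, 2 * heads, q_len, head_dim)
    query_states = query_states.view(bsz, q_len, 2 * num_heads, head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, 2 * num_key_value_heads, head_dim).transpose(1, 2)
    return query_states, key_states, value_states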
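After the decoder-layer hunks, the normalization outputs feed the sub-blocks directly: the input_layernorm_alpha and post_attention_layernorm_alpha branches are gone, and at the model level self.norm_alpha and self.scale_emb are now fixed to 1 instead of switching on config.wesar_weights. A sketch of the resulting pre-norm residual flow is below; the attention residual add and the self_attn call signature are assumptions, since only fragments of MotifDecoderLayer.forward appear in the diff.

```python
import torch


def decoder_layer_flow_sketch(layer, hidden_states: torch.Tensor, **attn_kwargs) -> torch.Tensor:
    """Illustrative walk through MotifDecoderLayer.forward's structure after this change."""
    # Self-attention sub-block (pre-norm): the layernorm output is used as-is,
    # with no input_layernorm_alpha multiplier.
    residual = hidden_states
    hidden_states = layer.input_layernorm(hidden_states)
    hidden_states, _attn_weights, _present_kv = layer.self_attn(hidden_states=hidden_states, **attn_kwargs)
    hidden_states = residual + hidden_states  # assumed standard residual add (not in the shown hunks)

    # Feed-forward sub-block, as shown in the diff.
    residual = hidden_states
    hidden_states = layer.post_attention_layernorm(hidden_states)
    hidden_states = layer.mlp(hidden_states)
    hidden_states = residual + hidden_states
    return hidden_states
```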