eunhwanpark-motiftech committed
Commit c4b7f5e · verified · Parent: cce6844

Update modeling_motif.py

Files changed (1):
  modeling_motif.py (+12 -48)
modeling_motif.py CHANGED
@@ -328,23 +328,10 @@ class MotifMLP(nn.Module):
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
         self.act_fn = ACT2FN[config.hidden_act]
 
-        if config.wesar_weights:
-            self.gate_up_proj_alpha = nn.Parameter(torch.tensor(1) *config.gate_up_proj_alpha)
-            self.down_proj_alpha = nn.Parameter(torch.tensor(1) * config.down_proj_alpha)
-        else:
-            self.gate_up_proj_alpha=1
-            self.down_proj_alpha=1
-        if config.muP:
-            self.down_proj.__do_scale_tager__ = True
-            self.gate_proj.__do_scale_tager_mu_dim_model__ = True
-            self.up_proj.__do_scale_tager_mu_dim_model__ = True
-            self.down_proj.__do_scale_tager_mu_ffn__ = True
-
-
     def forward(self, hidden_state):
-        hidden_state = hidden_state*self.gate_up_proj_alpha
+        hidden_state = hidden_state
         #hidden_state = self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))*
-        return self.down_proj_alpha*self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
 
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -470,13 +457,6 @@ class MotifAttention(nn.Module):
                                                max_position_embeddings=self.max_position_embeddings,
                                                base=self.rope_theta)
 
-        for param in ["q_proj_alpha", "k_proj_alpha", "v_proj_alpha", "o_proj_alpha"]:
-            setattr(
-                self, param,
-                nn.Parameter(torch.tensor(getattr(config, param, 1.0), dtype=torch.float))
-                if config.wesar_weights else 1.0)
-
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -490,9 +470,9 @@
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
 
-        query_states = self.q_proj(hidden_states) * self.q_proj_alpha
-        key_states = self.k_proj(hidden_states) * self.k_proj_alpha
-        value_states = self.v_proj(hidden_states) * self.v_proj_alpha
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
 
         ## bsz, seq, n_heads, head_dim
 
@@ -685,9 +665,9 @@ class MotifFlashAttention2(MotifAttention):
     ):
         bsz, q_len, _ = hidden_states.size()
 
-        query_states = self.q_proj(hidden_states) * self.q_proj_alpha
-        key_states = self.k_proj(hidden_states) * self.k_proj_alpha
-        value_states = self.v_proj(hidden_states) * self.v_proj_alpha
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
 
         query_states = query_states.view(bsz, q_len, 2 * self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, 2 * self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -798,7 +778,7 @@ class MotifFlashAttention2(MotifAttention):
                             f" {attn_output.size()}")
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-        attn_output = self.o_proj(attn_output) * self.o_proj_alpha
+        attn_output = self.o_proj(attn_output)
 
         return attn_output, None, past_key_value
 
@@ -919,15 +899,6 @@ class MotifDecoderLayer(nn.Module):
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-        if config.wesar_weights and config.use_norm_alpha:
-            self.input_layernorm_alpha = nn.Parameter(torch.tensor(1).float())
-        else:
-            self.input_layernorm_alpha = 1
-
-        if config.wesar_weights and config.use_norm_alpha :
-            self.post_attention_layernorm_alpha = nn.Parameter(torch.tensor(1).float())
-        else:
-            self.post_attention_layernorm_alpha = 1
 
     def forward(
         self,
@@ -965,7 +936,7 @@
 
         residual = hidden_states
 
-        hidden_states = self.input_layernorm(hidden_states) * self.input_layernorm_alpha
+        hidden_states = self.input_layernorm(hidden_states)
 
         # Self Attention
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
@@ -982,7 +953,7 @@
 
         # Fully Connected
         residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states) * self.post_attention_layernorm_alpha
+        hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
 
@@ -1199,14 +1170,7 @@ class MotifModel(MotifPreTrainedModel):
         self.post_init()
 
         self.scale_emb = 1
-
-        # Reparameterization
-        if config.wesar_weights :
-            logger.info(f'config.wesar_weights {config.wesar_weights}')
-            self.norm_alpha = nn.Parameter(torch.tensor(1).float())
-            self.scale_emb = 10
-        else:
-            self.norm_alpha = 1
+        self.norm_alpha = 1
 
     def get_input_embeddings(self):
         return self.embed_tokens
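For context, the MLP path this commit simplifies can be summarized with a minimal, self-contained sketch. The attribute names (gate_proj, up_proj, down_proj, gate_up_proj_alpha, down_proj_alpha) come from the diff above; the two wrapper classes, the SiLU activation (standing in for ACT2FN[config.hidden_act]) and the toy dimensions are illustrative assumptions, not code from the repository.

```python
import torch
import torch.nn as nn


class MLPWithAlphas(nn.Module):
    """Pre-commit path (sketch): optional learnable scalar scalings from wesar_weights."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()  # stand-in for ACT2FN[config.hidden_act]
        # With wesar_weights enabled these were trainable scalars; otherwise plain 1.
        self.gate_up_proj_alpha = nn.Parameter(torch.tensor(1.0))
        self.down_proj_alpha = nn.Parameter(torch.tensor(1.0))

    def forward(self, hidden_state):
        hidden_state = hidden_state * self.gate_up_proj_alpha
        return self.down_proj_alpha * self.down_proj(
            self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))


class MLPSimplified(nn.Module):
    """Post-commit path (sketch): a plain SwiGLU-style MLP with no extra scalings."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, hidden_state):
        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
```

The same pattern runs through the other hunks: each removed *_alpha multiplier simply disappears from the attention projections, the output projection, and the layer norms.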
 
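A quick numerical check (again a standalone sketch, not repository code) shows why dropping the multipliers is a no-op when they sit at their default value of 1, which is what the non-wesar_weights branches set:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size, intermediate_size = 8, 16
x = torch.randn(2, 4, hidden_size)

gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
act_fn = nn.SiLU()

# Old forward with the alphas at their default of 1 (the non-wesar_weights branch).
gate_up_proj_alpha, down_proj_alpha = 1.0, 1.0
old = down_proj_alpha * down_proj(
    act_fn(gate_proj(x * gate_up_proj_alpha)) * up_proj(x * gate_up_proj_alpha))

# New forward as of this commit: no extra scalings.
new = down_proj(act_fn(gate_proj(x)) * up_proj(x))

print(torch.allclose(old, new))  # True: the scalings were identity at their defaults
```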