leejunhyeok committed
Commit a55dcfd · verified · 1 Parent(s): 91c40ce

Update modeling_motif.py

Files changed (1)
  1. modeling_motif.py +2 -13
modeling_motif.py CHANGED
@@ -839,7 +839,7 @@ MOTIF_ATTENTION_CLASSES = {
 
 class MotifDecoderLayer(nn.Module):
 
-    def __init__(self, config: MotifConfig, moe_layer: bool, layer_idx: int):
+    def __init__(self, config: MotifConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
         if config.use_moreh_attention:
@@ -853,10 +853,6 @@ class MotifDecoderLayer(nn.Module):
         else:
             self.self_attn = MOTIF_ATTENTION_CLASSES["eager"](config, layer_idx)
         self.mlp = MotifMLP(config)
-        ### moe
-        self.moe = None
-        if moe_layer:
-            self.moe = MotifMoE(config)
 
         RMSNorm = MorehRMSNorm if MorehRMSNorm is not None else MotifRMSNorm
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -927,13 +923,7 @@ class MotifDecoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states) * self.post_attention_layernorm_alpha
 
-        if self.moe is not None:
-            hidden_states, identity = self.moe(hidden_states)
-            ## add output of shared expert and output of small moe experts.
-            ## hidden state must be zero tensor (for first forward)
-            hidden_states += self.mlp(identity)
-        else:
-            hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states)
 
         hidden_states = residual + hidden_states
 
@@ -1114,7 +1104,6 @@ class MotifModel(MotifPreTrainedModel):
 
         num_hidden_layers = config.num_hidden_layers if self.multi_token_heads is None else config.num_hidden_layers - 1
 
-        logger.info(f'current_moe layer { moe_layer }')
        self.layers = nn.ModuleList([
            MotifDecoderLayer(config = config, layer_idx=layer_idx) for layer_idx in range(num_hidden_layers)
        ])
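
For context, the net effect of this commit is that the feed-forward half of MotifDecoderLayer becomes a plain pre-norm residual MLP path: normalize, scale by the learned alpha, run the dense MLP, and add the residual. The sketch below illustrates only that control flow; ToyMLP, DenseFFNSubBlock, the LayerNorm stand-in, and the tensor sizes are placeholders for illustration, not the repository's MotifMLP, RMSNorm, or MotifConfig.

import torch
import torch.nn as nn


class ToyMLP(nn.Module):
    # Stand-in for MotifMLP: a simple up-projection, activation, down-projection.
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.up_proj = nn.Linear(hidden_size, intermediate_size)
        self.down_proj = nn.Linear(intermediate_size, hidden_size)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act(self.up_proj(x)))


class DenseFFNSubBlock(nn.Module):
    # Pre-norm residual feed-forward sub-block, mirroring the shape of the
    # post-commit forward path: norm * alpha -> mlp -> residual add.
    def __init__(self, hidden_size: int = 64, intermediate_size: int = 256):
        super().__init__()
        # LayerNorm here is a stand-in for the repo's RMSNorm variant.
        self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.post_attention_layernorm_alpha = nn.Parameter(torch.ones(1))
        self.mlp = ToyMLP(hidden_size, intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states) * self.post_attention_layernorm_alpha
        hidden_states = self.mlp(hidden_states)  # single dense MLP; no MoE branch
        return residual + hidden_states


if __name__ == "__main__":
    block = DenseFFNSubBlock()
    x = torch.randn(2, 8, 64)  # (batch, seq_len, hidden_size)
    print(block(x).shape)      # torch.Size([2, 8, 64])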