Update modeling_motif.py
modeling_motif.py  CHANGED  (+2 -13)
@@ -839,7 +839,7 @@ MOTIF_ATTENTION_CLASSES = {
 
 class MotifDecoderLayer(nn.Module):
 
-    def __init__(self, config: MotifConfig,
+    def __init__(self, config: MotifConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
         if config.use_moreh_attention:
@@ -853,10 +853,6 @@ class MotifDecoderLayer(nn.Module):
         else:
             self.self_attn = MOTIF_ATTENTION_CLASSES["eager"](config, layer_idx)
         self.mlp = MotifMLP(config)
-        ### moe
-        self.moe = None
-        if moe_layer:
-            self.moe = MotifMoE(config)
 
         RMSNorm = MorehRMSNorm if MorehRMSNorm is not None else MotifRMSNorm
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
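Taken together, the first two hunks drop the per-layer MoE wiring from MotifDecoderLayer.__init__: the constructor no longer takes the moe_layer flag referenced by the removed "if moe_layer:" check, and it no longer creates the optional self.moe = MotifMoE(config) member. Below is a minimal runnable sketch of the resulting constructor shape; ToyAttention, ToyMLP and nn.LayerNorm are illustrative stand-ins for the Motif attention classes, MotifMLP and the RMSNorm variants, not the real implementations.

import torch.nn as nn

class ToyAttention(nn.Module):
    """Stand-in for the Motif attention classes; a single linear map here."""
    def __init__(self, hidden_size: int, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        return self.proj(x)

class ToyMLP(nn.Module):
    """Stand-in for MotifMLP; a plain two-layer feed-forward block."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.up = nn.Linear(hidden_size, 4 * hidden_size)
        self.down = nn.Linear(4 * hidden_size, hidden_size)

    def forward(self, x):
        return self.down(nn.functional.silu(self.up(x)))

class ToyDecoderLayer(nn.Module):
    # After this commit: no moe_layer argument and no self.moe attribute;
    # each layer owns one attention block, one dense MLP, and its norms.
    def __init__(self, hidden_size: int, layer_idx: int, rms_norm_eps: float = 1e-6):
        super().__init__()
        self.hidden_size = hidden_size
        self.self_attn = ToyAttention(hidden_size, layer_idx)
        self.mlp = ToyMLP(hidden_size)
        self.input_layernorm = nn.LayerNorm(hidden_size, eps=rms_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=rms_norm_eps)

layer = ToyDecoderLayer(hidden_size=16, layer_idx=0)
print(layer)  # every sublayer is dense; no MoE module appears in the tree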
@@ -927,13 +923,7 @@ class MotifDecoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states) * self.post_attention_layernorm_alpha
 
-
-            hidden_states, identity = self.moe(hidden_states)
-            ## add output of shared expert and output of small moe experts.
-            ## hidden state must be zero tensor (for first forward)
-            hidden_states += self.mlp(identity)
-        else:
-            hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states)
 
         hidden_states = residual + hidden_states
 
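The third hunk replaces the MoE branch in the post-attention block with the plain dense path. Per the removed lines, MoE layers ran hidden_states through self.moe, which returned the routed experts' output together with an identity tensor, and then added self.mlp(identity) on top (apparently the shared-expert output, going by the removed comments), while non-MoE layers fell through to self.mlp alone; after this change every layer takes the dense path. A runnable sketch of the simplified residual block follows, with nn.LayerNorm and a toy MLP standing in for the model's RMSNorm variant and MotifMLP, and with hidden_size, the tensor shape and post_attention_layernorm_alpha reduced to toy values rather than anything read from the real MotifConfig.

import torch
import torch.nn as nn

hidden_size = 16

class ToyMLP(nn.Module):
    """Stand-in for MotifMLP; a plain two-layer feed-forward block."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.up = nn.Linear(hidden_size, 4 * hidden_size)
        self.down = nn.Linear(4 * hidden_size, hidden_size)

    def forward(self, x):
        return self.down(nn.functional.silu(self.up(x)))

mlp = ToyMLP(hidden_size)
post_attention_layernorm = nn.LayerNorm(hidden_size)
post_attention_layernorm_alpha = 1.0

hidden_states = torch.randn(2, 5, hidden_size)  # (batch, seq_len, hidden)

# Dense path only: norm, scale, MLP, residual add. No routing and no shared-expert sum.
residual = hidden_states
hidden_states = post_attention_layernorm(hidden_states) * post_attention_layernorm_alpha
hidden_states = mlp(hidden_states)
hidden_states = residual + hidden_states
print(hidden_states.shape)  # torch.Size([2, 5, 16])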
@@ -1114,7 +1104,6 @@ class MotifModel(MotifPreTrainedModel):
 
         num_hidden_layers = config.num_hidden_layers if self.multi_token_heads is None else config.num_hidden_layers - 1
 
-        logger.info(f'current_moe layer { moe_layer }')
         self.layers = nn.ModuleList([
             MotifDecoderLayer(config = config, layer_idx=layer_idx) for layer_idx in range(num_hidden_layers)
         ])
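The final hunk removes a leftover debug log in MotifModel that printed a moe_layer value while the layer stack was being built; the stack construction itself is unchanged, and every layer is now created with the same (config, layer_idx) call. A toy illustration of that uniform construction, with ToyConfig and ToyLayer as stand-ins for MotifConfig and MotifDecoderLayer:

import torch.nn as nn

class ToyConfig:
    # Illustrative stand-in for MotifConfig with only the fields this sketch needs.
    hidden_size = 16
    num_hidden_layers = 4

class ToyLayer(nn.Module):
    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.proj = nn.Linear(config.hidden_size, config.hidden_size)

config = ToyConfig()
# Same construction pattern as the diff: one call shape for every layer, no per-layer flag.
layers = nn.ModuleList([ToyLayer(config=config, layer_idx=layer_idx)
                        for layer_idx in range(config.num_hidden_layers)])
print(len(layers))  # 4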