Update modeling_motif.py
modeling_motif.py (+3 -12)
@@ -558,7 +558,7 @@ class MotifAttention(nn.Module):
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

-        attn_output = self.o_proj(attn_output)
+        attn_output = self.o_proj(attn_output)

         if not output_attentions:
             attn_weights = None
@@ -1285,7 +1285,7 @@ class MotifModel(MotifPreTrainedModel):
                 all_self_attns += (layer_outputs[1], )

         # <|_2_|>
-        hidden_states = self.norm(hidden_states)
+        hidden_states = self.norm(hidden_states)

         # add hidden states from the last decoder layer
         if output_hidden_states:
@@ -1461,15 +1461,6 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
         # Initialize weights and apply final processing
         self.post_init()

-        # <|_3_|>
-        if config.muP:
-            self.lm_head.__do_scale_tager_mu_dim_base_model__=True
-
-        # <|_4_|>
-        self.lm_head_alpha = 1
-        if config.wesar_weights:
-            self.lm_head_alpha = nn.Parameter(torch.tensor(1).float())
-
         if getattr(config, "tie_word_embeddings", True):
             logger.info('tie embeddings')
             self.tie_weights()
@@ -1676,7 +1667,7 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
             num_logits_to_keep=num_logits_to_keep)

         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-        hidden_states = hidden_states
+        hidden_states = hidden_states
         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
         logits = logits.float()
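The block dropped from MotifForCausalLM.__init__ set a muP scaling flag on lm_head and, under config.wesar_weights, created lm_head_alpha as a trainable scalar initialized to 1.0. Below is a minimal, self-contained sketch of that learnable-scalar pattern; the ScaledLMHead module and the way the scalar is applied in forward are assumptions for illustration, not the model's actual code.

import torch
import torch.nn as nn

class ScaledLMHead(nn.Module):
    """Illustrative sketch only (hypothetical module, not MotifForCausalLM):
    an output head with an optional trainable scalar, mirroring the
    `lm_head_alpha = nn.Parameter(torch.tensor(1).float())` line removed here."""

    def __init__(self, hidden_size: int, vocab_size: int, learnable_alpha: bool = False):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        # Trainable scalar initialized to 1.0 when enabled, plain constant otherwise.
        self.lm_head_alpha = nn.Parameter(torch.tensor(1.0)) if learnable_alpha else 1.0

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Assumed usage: scale the final hidden states before projecting to logits.
        return self.lm_head(hidden_states * self.lm_head_alpha)

head = ScaledLMHead(hidden_size=16, vocab_size=32, learnable_alpha=True)
print(head(torch.randn(2, 5, 16)).shape)  # torch.Size([2, 5, 32])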
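For reference, the unchanged line just below the edited one, logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]), computes logits only for the last num_logits_to_keep positions. A small standalone sketch of that slicing follows; the shapes and names are chosen here for illustration.

import torch
import torch.nn as nn

hidden_size, vocab_size = 16, 32
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

hidden_states = torch.randn(2, 7, hidden_size)  # (batch, seq_len, hidden)
num_logits_to_keep = 1  # during generation only the last position is needed

# Slice before the projection so logits are computed just for the kept positions.
logits = lm_head(hidden_states[:, -num_logits_to_keep:, :])
print(logits.shape)  # torch.Size([2, 1, 32])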