eunhwanpark-motiftech committed
Commit f76fc65 · verified · 1 Parent(s): 498f8bc

Update modeling_motif.py

Files changed (1)
  1. modeling_motif.py +3 -12
modeling_motif.py CHANGED

@@ -558,7 +558,7 @@ class MotifAttention(nn.Module):
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

-        attn_output = self.o_proj(attn_output) * self.o_proj_alpha
+        attn_output = self.o_proj(attn_output)

         if not output_attentions:
             attn_weights = None
@@ -1285,7 +1285,7 @@ class MotifModel(MotifPreTrainedModel):
                 all_self_attns += (layer_outputs[1], )

         # <|_2_|>
-        hidden_states = self.norm(hidden_states)* self.norm_alpha
+        hidden_states = self.norm(hidden_states)

         # add hidden states from the last decoder layer
         if output_hidden_states:
@@ -1461,15 +1461,6 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
         # Initialize weights and apply final processing
         self.post_init()

-        # <|_3_|>
-        if config.muP:
-            self.lm_head.__do_scale_tager_mu_dim_base_model__=True
-
-        # <|_4_|>
-        self.lm_head_alpha = 1
-        if config.wesar_weights:
-            self.lm_head_alpha = nn.Parameter(torch.tensor(1).float())
-
         if getattr(config, "tie_word_embeddings", True):
             logger.info('tie embeddings')
             self.tie_weights()
@@ -1676,7 +1667,7 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
            num_logits_to_keep=num_logits_to_keep)

         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
-        hidden_states = hidden_states * self.lm_head_alpha
+        hidden_states = hidden_states
         logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
         logits = logits.float()
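For context on the removed lines: the old code multiplied several module outputs by scalars (`o_proj_alpha`, `norm_alpha`, `lm_head_alpha`), and for the LM head that scalar was only made trainable when `config.wesar_weights` was set, staying a constant 1 otherwise. The sketch below illustrates that scaling pattern only; the names `ScaledOutputProj`, `hidden_size`, and the standalone-module structure are illustrative assumptions, not code from this repository.

import torch
import torch.nn as nn


class ScaledOutputProj(nn.Module):
    """Sketch: a projection whose output is gated by an optional trainable scalar."""

    def __init__(self, hidden_size: int, wesar_weights: bool = False):
        super().__init__()
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        # With wesar_weights enabled, the scale is a learnable parameter initialized
        # to 1.0 (as in the removed lines); otherwise it is a constant no-op factor.
        self.o_proj_alpha = nn.Parameter(torch.tensor(1.0)) if wesar_weights else 1.0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.o_proj(x) * self.o_proj_alpha

With the multiplication removed, as in this commit, the forward pass matches the wesar_weights=False branch of the sketch (a fixed scale of 1).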