leejunhyeok committed (verified)
Commit 03a80eb · Parent(s): 7d405b9

Update modeling_motif.py

Files changed (1):
  1. modeling_motif.py +5 -17
modeling_motif.py CHANGED
@@ -284,14 +284,12 @@ class MotifMLP(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
         self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(self, hidden_state):
-        hidden_state = hidden_state
-        #hidden_state = self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))*
         return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
 
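Note on this hunk: after the change, MotifMLP is a standard gated (SwiGLU-style) block whose projection biases are controlled by the config rather than hard-coded to False. A minimal self-contained sketch of the same pattern, assuming only that the config exposes hidden_size, intermediate_size, hidden_act, and the use_bias flag referenced above (ACT2FN is the Transformers activation lookup):

import torch
import torch.nn as nn
from types import SimpleNamespace
from transformers.activations import ACT2FN

class GatedMLPSketch(nn.Module):
    """Gated MLP: down_proj(act(gate_proj(x)) * up_proj(x)); biases are optional."""

    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.use_bias)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.use_bias)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.use_bias)
        self.act_fn = ACT2FN[config.hidden_act]  # e.g. "silu"

    def forward(self, hidden_state):
        # The gate path goes through the activation and scales the up path elementwise.
        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))

cfg = SimpleNamespace(hidden_size=64, intermediate_size=256, hidden_act="silu", use_bias=False)
out = GatedMLPSketch(cfg)(torch.randn(2, 8, 64))  # -> shape (2, 8, 64)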
 
@@ -394,7 +392,7 @@ class MotifAttention(nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
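Note on this hunk: in this family of Transformers decoders, position_embeddings usually carries the rotary (cos, sin) pair computed once per forward pass in the model and handed to every attention layer. The sketch below only illustrates how such a pair can be produced; the function name and formulation are assumptions for the example, not code from modeling_motif.py:

import torch

def rotary_cos_sin(position_ids: torch.LongTensor, head_dim: int, base: float = 10000.0):
    """Illustrative rotary embedding: returns (cos, sin), each shaped (batch, seq_len, head_dim)."""
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    freqs = position_ids[..., None].float() * inv_freq  # (batch, seq_len, head_dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)             # (batch, seq_len, head_dim)
    return emb.cos(), emb.sin()

cos, sin = rotary_cos_sin(torch.arange(8).unsqueeze(0), head_dim=64)
position_embeddings = (cos, sin)  # the Tuple[torch.Tensor, torch.Tensor] accepted by the attention forward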
 
@@ -493,8 +491,6 @@ class MotifFlashAttention2(MotifAttention):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-
-
         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
         # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
@@ -516,7 +512,6 @@ class MotifFlashAttention2(MotifAttention):
         """Flash Attention 2 implements"""
 
         scale_factor = 1.0 / math.sqrt(self.head_dim)
-        # Copied from _flash_attention_forward
         if not self._flash_attn_uses_top_left_mask:
             causal = self.is_causal
         else:
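Note on the two hunks above: the comments explain that flash_attn < 2.1 builds a top-left aligned causal mask while the model needs the bottom-right alignment that became the default in flash_attn >= 2.1, and _flash_attn_uses_top_left_mask switches the causal flag accordingly. A small sketch of the usual pattern behind that flag, assuming the stock Transformers version helper (the common upstream idiom, not code copied from modeling_motif.py):

from transformers.utils import is_flash_attn_greater_or_equal_2_10

uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

def pick_causal_flag(is_causal: bool, query_length: int) -> bool:
    # flash_attn >= 2.1: bottom-right aligned causal mask, pass the flag through unchanged.
    if not uses_top_left_mask:
        return is_causal
    # flash_attn < 2.1: top-left aligned mask; drop the causal flag when q_len == 1,
    # otherwise the single query row is masked against the wrong keys.
    return is_causal and query_length != 1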
@@ -881,7 +876,6 @@ class MotifPreTrainedModel(PreTrainedModel):
         if isinstance(module, nn.Linear):
             module.weight.data.normal_(mean=0.0, std=module_std)
             module.weight.data = torch.where(abs(module.weight.data) > module_std*3, 0, module.weight.data)
-            #torch.nn.init.trunc_normal_(module.weight.data, mean=0.0, std=module_std, a=-3*module_std, b=3*module_std)
             if module.bias is not None:
                 module.bias.data.zero_()
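Note on this hunk: the retained initializer samples from N(0, std) and then hard-zeroes any weight whose magnitude exceeds 3*std, whereas the removed comment pointed at torch.nn.init.trunc_normal_, which resamples into [-3*std, 3*std] instead of zeroing. A minimal sketch contrasting the two (the std value here is illustrative, not taken from the Motif config):

import torch
import torch.nn as nn

std = 0.02  # illustrative value
linear = nn.Linear(128, 128)

# Variant kept in this commit: sample N(0, std), then zero out entries beyond 3*std.
linear.weight.data.normal_(mean=0.0, std=std)
linear.weight.data = torch.where(linear.weight.data.abs() > 3 * std,
                                 torch.zeros_like(linear.weight.data),
                                 linear.weight.data)

# Variant referenced by the removed comment: truncated normal, which resamples
# out-of-range draws rather than setting them to zero.
nn.init.trunc_normal_(linear.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)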
 
@@ -1001,12 +995,8 @@ class MotifModel(MotifPreTrainedModel):
         self.vocab_size = config.vocab_size
 
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
-        # NOTE: For multi-token models, the last decoder layers (one for each token index)
-        # are implemented as a part of `MotifModelForCausalLM` to enable a custom forward-backward procedure.
-
         num_hidden_layers = config.num_hidden_layers
         self.layers = nn.ModuleList([MotifDecoderLayer(config = config, layer_idx=layer_idx) for layer_idx in range(num_hidden_layers)])
-        self._attn_implementation = config._attn_implementation
         self.norm = MotifRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
@@ -1079,7 +1069,6 @@ class MotifModel(MotifPreTrainedModel):
             cache_position = torch.arange(past_seen_tokens,
                                           past_seen_tokens + inputs_embeds.shape[1],
                                           device=inputs_embeds.device)
-        #position_ids = None
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
@@ -1132,7 +1121,6 @@ class MotifModel(MotifPreTrainedModel):
             if output_attentions:
                 all_self_attns += (layer_outputs[1], )
 
-        # <|_2_|>
         hidden_states = self.norm(hidden_states)
 
         # add hidden states from the last decoder layer
@@ -1192,6 +1180,7 @@ class MotifModel(MotifPreTrainedModel):
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
+
         # SlidingWindowCache or StaticCache
         if using_sliding_window_cache or using_static_cache:
             target_length = past_key_values.get_max_cache_shape()
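Note on this hunk: only a blank line is added, but the surrounding helper prepares a dense attention bias filled with the input dtype's most negative value (min_dtype), with a key length (target_length) taken from the cache's maximum shape when a SlidingWindowCache or StaticCache is used. An illustrative sketch of that masking idea, with shapes and the function name chosen for the example rather than taken from modeling_motif.py:

import torch

def causal_bias(sequence_length: int, target_length: int, dtype=torch.float32):
    """Dense causal bias: 0 where attention is allowed, dtype-min where masked."""
    min_dtype = torch.finfo(dtype).min
    bias = torch.full((sequence_length, target_length), min_dtype, dtype=dtype)
    # Bottom-right alignment: query i may attend to keys up to i + (target_length - sequence_length).
    allowed = torch.arange(target_length) <= (
        torch.arange(sequence_length).unsqueeze(-1) + (target_length - sequence_length))
    return bias.masked_fill(allowed, 0.0)

print(causal_bias(3, 5, torch.float16))  # 3 queries against a 5-slot cache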
@@ -1407,7 +1396,6 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
             loss_fct = CrossEntropyLoss()
             shift_logits = shift_logits.view(-1, self.config.vocab_size)
             shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
             shift_labels = shift_labels.to(shift_logits.device)
             loss = loss_fct(shift_logits, shift_labels)
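Note on this hunk: the removed line was only a comment; the kept code is the standard next-token cross-entropy, where labels are flattened and moved to the logits' device before the loss is computed. A self-contained sketch of that computation with illustrative shapes and vocabulary size:

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 32                                  # illustrative
logits = torch.randn(2, 10, vocab_size)          # (batch, seq_len, vocab)
labels = torch.randint(0, vocab_size, (2, 10))   # (batch, seq_len)

# Predict token t+1 from position t: drop the last logit and the first label.
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)

loss = CrossEntropyLoss()(shift_logits, shift_labels)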
 
 