eunhwanpark-motiftech committed on
Commit 8bdf2ec · verified · 1 Parent(s): c8cabde

Update modeling_motif.py

Files changed (1):
  modeling_motif.py  +1 -27
modeling_motif.py CHANGED
@@ -34,8 +34,7 @@ from transformers.activations import ClassInstantier
 class PolyNorm(torch.nn.Module):
     """
     A trainable activation function introduced in https://arxiv.org/html/2411.03884v1.
-    The code is copied from https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md,
-    with the change `* torch.rsqrt` => `/ torch.sqrt` for potential MAF incompatibility.
+    The code is copied from https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md
     """
 
     def __init__(self, eps=1e-6):
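
For readers following the docstring change above: a minimal sketch of what PolyNorm computes, based on the PolyCom reference repository the docstring links to (not Motif's in-tree copy; parameter names and the exact normalization are assumptions). The `_norm` line is where the dropped `* torch.rsqrt` => `/ torch.sqrt` note applied.

```python
import torch


class PolyNormSketch(torch.nn.Module):
    """Sketch of PolyNorm (arXiv:2411.03884): a learned mix of normalized powers of x."""

    def __init__(self, eps=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3) / 3)
        self.bias = torch.nn.Parameter(torch.zeros(1))
        self.eps = eps

    def _norm(self, x):
        # The removed docstring note referred to this step:
        # `x * torch.rsqrt(...)` in the reference repo vs. `x / torch.sqrt(...)` here.
        return x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x):
        return (self.weight[0] * self._norm(x ** 3)
                + self.weight[1] * self._norm(x ** 2)
                + self.weight[2] * self._norm(x)
                + self.bias)
```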
@@ -117,7 +116,6 @@ class MotifRotaryEmbeddingWithCache(nn.Module):
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
 
-        # Build here to make `torch.jit.trace` work.
         self._set_cos_sin_cache(seq_len=max_position_embeddings,
                                 device=self.inv_freq.device,
                                 dtype=torch.get_default_dtype())
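
`_set_cos_sin_cache` itself is outside this diff; as context, a rotary-embedding cache builder of this kind typically looks like the sketch below (the common Hugging Face pattern, assumed rather than taken from Motif's code).

```python
import torch

def _set_cos_sin_cache(self, seq_len, device, dtype):
    # Precompute cos/sin tables for positions [0, seq_len) from the `inv_freq` buffer.
    self.max_seq_len_cached = seq_len
    t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
    freqs = torch.outer(t, self.inv_freq)    # (seq_len, dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, dim)
    self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
    self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
```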
@@ -173,7 +171,6 @@ class MotifRotaryEmbedding(nn.Module):
             self.max_seq_len_cached = max_position_embeddings
             self.original_max_seq_len = max_position_embeddings
         else:
-            # BC: "rope_type" was originally "type"
             if config.rope_scaling is not None:
                 self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
             else:
@@ -364,18 +361,15 @@ class MotifAttention(nn.Module):
         self.num_key_value_heads //= 2
         self.n_rep = self.num_heads // self.num_key_value_heads
 
-        # re-init projections
         self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
         self.k_proj = nn.Linear(self.hidden_size, self.hidden_size // self.n_rep, bias=False)
         self.v_proj = nn.Linear(self.hidden_size, self.hidden_size // self.n_rep, bias=False)
         self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
 
-        # init lambdas
         for name in ["lambda_q1", "lambda_k1", "lambda_q2", "lambda_k2"]:
             setattr(self, name, nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32)))
             getattr(self, name).data.normal_(mean=0.0, std=0.1)
 
-        # Uses same norm as motif norm, without elementwise_affine option
         self.subln = MotifRMSNorm(2 * self.head_dim, eps=1e-5)
         self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * (layer_idx - 1))
 
@@ -400,8 +394,6 @@ class MotifAttention(nn.Module):
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        ## bsz, seq, n_heads, head_dim
-
         query_states = query_states.view(bsz, q_len, 2 * self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, 2 * self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, 2 * self.head_dim).transpose(1, 2)
@@ -428,11 +420,9 @@ class MotifAttention(nn.Module):
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
-        # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        ## bsz, #haead, q_len, head_dim -> bsz, head, q_len, q_len
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
         kv_seq_len = key_states.shape[-2]
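
`repeat_kv` (called in the hunk above) is the usual grouped-query-attention helper that tiles the K/V heads to match the query heads; a sketch of the standard Llama-style implementation, assumed to match what this file uses:

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq, head_dim) -> (batch, num_kv_heads * n_rep, seq, head_dim)
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)
```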
@@ -442,24 +432,19 @@ class MotifAttention(nn.Module):
                 torch.full((q_len, kv_seq_len), float("-inf"), dtype=attn_weights.dtype, device=attn_weights.device),
                 1 + offset)
 
-            ###add attn
             attn_weights = attn_weights + attention_mask
 
-        # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
 
-        # differential transformer lambdas
         lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(attn_weights)
         lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(attn_weights)
         lambda_full = lambda_1 - lambda_2 + self.lambda_init
         attn_weights = attn_weights.view(bsz, self.num_heads, 2, q_len, -1)
         attn_weights = attn_weights[:, :, 0] - lambda_full * attn_weights[:, :, 1]
 
-        ##shape : bsz, #heads, seq, head_dim
         attn_output = torch.matmul(attn_weights, value_states)
 
-
         attn_output = self.subln(attn_output)
         attn_output = attn_output * (1 - self.lambda_init)
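
The lambda arithmetic in this hunk is the differential-attention recombination: each head carries two softmax maps, and the second is subtracted with a learned, layer-dependent weight. A self-contained sketch with illustrative shapes (random values, hypothetical sizes):

```python
import math
import torch

bsz, num_heads, q_len, kv_len, head_dim = 1, 4, 8, 8, 16
layer_idx = 1

# Two stacked attention maps per head (hence 2 * num_heads), already softmaxed.
attn_weights = torch.softmax(torch.randn(bsz, 2 * num_heads, q_len, kv_len), dim=-1)

# Learned lambda vectors (initialized with std=0.1 in the diff) and the layer-dependent init.
lambda_q1, lambda_k1, lambda_q2, lambda_k2 = (torch.randn(head_dim) * 0.1 for _ in range(4))
lambda_init = 0.8 - 0.6 * math.exp(-0.3 * (layer_idx - 1))
lambda_full = (torch.exp((lambda_q1 * lambda_k1).sum())
               - torch.exp((lambda_q2 * lambda_k2).sum())
               + lambda_init)

# Split the doubled head axis and subtract the second map, scaled by lambda_full.
attn_weights = attn_weights.view(bsz, num_heads, 2, q_len, kv_len)
diff_attn = attn_weights[:, :, 0] - lambda_full * attn_weights[:, :, 1]
print(diff_attn.shape)  # torch.Size([1, 4, 8, 8])
```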
 
@@ -487,10 +472,8 @@ class MotifFlashAttention2(MotifAttention):
     config.max_window_layers layers.
     """
 
-    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-
         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
         # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
@@ -572,7 +555,6 @@ class MotifFlashAttention2(MotifAttention):
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
-        # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
         dropout_rate = 0.0 if not self.training else self.attention_dropout
@@ -665,7 +647,6 @@ class MotifSdpaAttention(MotifAttention):
     SDPA API.
     """
 
-    # Adapted from MotifAttention.forward
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -678,7 +659,6 @@ class MotifSdpaAttention(MotifAttention):
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
-            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
             logger.warning_once(
                 "MotifModel is using MotifSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
@@ -882,7 +862,6 @@ class MotifPreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.Embedding):
             module.weight.data.normal_(mean=0.0, std=module_std)
             module.weight.data = torch.where(abs(module.weight.data) > module_std*3, 0, module.weight.data)
-            #torch.nn.init.trunc_normal_(module.weight.data, mean=0.0, std=module_std, a=-3*module_std, b=3*module_std)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
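
The commented-out call being deleted here was the library alternative to the clamp kept above: the kept code draws normally and zeroes weights beyond three standard deviations, while `trunc_normal_` resamples inside the bounds. For comparison (illustrative shapes and std, not the model's real values):

```python
import torch

module_std = 0.02                 # illustrative; the real value comes from the config
weight = torch.empty(1000, 64)

# Resampling-based truncated normal, as in the deleted comment.
torch.nn.init.trunc_normal_(weight, mean=0.0, std=module_std,
                            a=-3 * module_std, b=3 * module_std)
```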
 
@@ -1048,7 +1027,6 @@ class MotifModel(MotifPreTrainedModel):
                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
             use_cache = False
 
-        # kept for BC (non `Cache` `past_key_values` inputs)
         return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):
             return_legacy_cache = True
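
The `return_legacy_cache` branch above handles tuple-style `past_key_values`; in current Transformers that conversion is normally done with `DynamicCache`. A small sketch using the standard `transformers` API (tensor shapes are hypothetical):

```python
import torch
from transformers import DynamicCache

# One layer's (key, value) pair with shape (batch, kv_heads, seq, head_dim).
legacy_past = ((torch.zeros(1, 2, 4, 8), torch.zeros(1, 2, 4, 8)),)

cache = DynamicCache.from_legacy_cache(legacy_past)  # wrap for the forward pass
roundtrip = cache.to_legacy_cache()                  # convert back before returning
assert roundtrip[0][0].shape == legacy_past[0][0].shape
```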
@@ -1077,10 +1055,8 @@ class MotifModel(MotifPreTrainedModel):
 
         hidden_states = inputs_embeds
         bsz, q_len, _ = hidden_states.size()
-        # create position embeddings to be shared across the decoder layers
         position_embeddings = self.rotary_emb(hidden_states, seq_len=q_len)
 
-        # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
@@ -1123,7 +1099,6 @@ class MotifModel(MotifPreTrainedModel):
 
         hidden_states = self.norm(hidden_states)
 
-        # add hidden states from the last decoder layer
         if output_hidden_states:
             all_hidden_states += (hidden_states, )
 
@@ -1289,7 +1264,6 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
         self.post_init()
 
         if getattr(config, "tie_word_embeddings", True):
-            logger.info('tie embeddings')
             self.tie_weights()
 
     def get_input_embeddings(self):
 