Update modeling_motif.py
modeling_motif.py  CHANGED  (+1, -27)
@@ -34,8 +34,7 @@ from transformers.activations import ClassInstantier
 class PolyNorm(torch.nn.Module):
     """
     A trainable activation function introduced in https://arxiv.org/html/2411.03884v1.
-    The code is copied from https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md
-    with the change `* torch.rsqrt` => `/ torch.sqrt` for potential MAF incompatibility.
+    The code is copied from https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md
     """

     def __init__(self, eps=1e-6):
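Note: the PolyNorm body itself is unchanged by this commit and does not appear above. For readers following the docstring reference, here is a minimal sketch of a PolyNorm-style activation, written from the PolyCom description (not copied from modeling_motif.py) and using the `/ torch.sqrt` form that the removed comment mentions:

```python
import torch

class PolyNormSketch(torch.nn.Module):
    """Sketch of a PolyNorm-style trainable activation (order 3).

    Each power of the input is RMS-normalized, then the three branches are
    mixed with learnable weights plus a learnable bias, following the PolyCom
    description. Not the exact body shipped in modeling_motif.py.
    """

    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3) / 3)
        self.bias = torch.nn.Parameter(torch.zeros(1))
        self.eps = eps

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # `/ torch.sqrt` variant from the removed comment
        # (mathematically equivalent to `* torch.rsqrt`).
        return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (self.weight[0] * self._norm(x ** 3)
                + self.weight[1] * self._norm(x ** 2)
                + self.weight[2] * self._norm(x)
                + self.bias)
```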
@@ -117,7 +116,6 @@ class MotifRotaryEmbeddingWithCache(nn.Module):
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)

-        # Build here to make `torch.jit.trace` work.
         self._set_cos_sin_cache(seq_len=max_position_embeddings,
                                 device=self.inv_freq.device,
                                 dtype=torch.get_default_dtype())
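Only the call site appears in this hunk; `_set_cos_sin_cache` itself lives elsewhere in the file. As a rough sketch of what such a cache builder usually does in rotary-embedding modules of this style (the helper name and exact shapes below are assumptions, not taken from this diff):

```python
import torch

def set_cos_sin_cache_sketch(inv_freq: torch.Tensor, seq_len: int,
                             device=None, dtype=torch.float32):
    """Sketch of the usual cos/sin cache construction for rotary embeddings.

    `inv_freq` has shape (dim // 2,). The outer product with the position ids
    gives per-position angles; duplicating them yields (seq_len, dim) cos/sin
    tables that forward() can slice instead of recomputing every step.
    """
    t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
    freqs = torch.outer(t, inv_freq)          # (seq_len, dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)   # (seq_len, dim)
    return emb.cos().to(dtype), emb.sin().to(dtype)

inv_freq = 1.0 / (10000 ** (torch.arange(0, 64, 2).float() / 64))
cos, sin = set_cos_sin_cache_sketch(inv_freq, seq_len=128)
print(cos.shape, sin.shape)  # torch.Size([128, 64]) torch.Size([128, 64])
```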
@@ -173,7 +171,6 @@ class MotifRotaryEmbedding(nn.Module):
             self.max_seq_len_cached = max_position_embeddings
             self.original_max_seq_len = max_position_embeddings
         else:
-            # BC: "rope_type" was originally "type"
             if config.rope_scaling is not None:
                 self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
             else:
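The removed "BC" note refers to configs written before the key was renamed from `type` to `rope_type`; both spellings resolve to the same value:

```python
# Old-style rope_scaling config (illustrative values).
rope_scaling = {"type": "linear", "factor": 2.0}
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
print(rope_type)  # "linear"
```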
@@ -364,18 +361,15 @@ class MotifAttention(nn.Module):
         self.num_key_value_heads //= 2
         self.n_rep = self.num_heads // self.num_key_value_heads

-        # re-init projections
         self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
         self.k_proj = nn.Linear(self.hidden_size, self.hidden_size // self.n_rep, bias=False)
         self.v_proj = nn.Linear(self.hidden_size, self.hidden_size // self.n_rep, bias=False)
         self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

-        # init lambdas
         for name in ["lambda_q1", "lambda_k1", "lambda_q2", "lambda_k2"]:
             setattr(self, name, nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32)))
             getattr(self, name).data.normal_(mean=0.0, std=0.1)

-        # Uses same norm as motif norm, without elementwise_affine option
         self.subln = MotifRMSNorm(2 * self.head_dim, eps=1e-5)
         self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * (layer_idx - 1))

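These parameters follow the Differential Transformer recipe: each head computes two attention maps and subtracts one from the other, with a depth-dependent starting weight `lambda_init`. A quick arithmetic check of the schedule defined in this hunk:

```python
import math

# Deeper layers start with a larger lambda_init; this just evaluates the
# formula from the hunk for a few layer indices.
for layer_idx in (1, 4, 16):
    lambda_init = 0.8 - 0.6 * math.exp(-0.3 * (layer_idx - 1))
    print(layer_idx, round(lambda_init, 4))
# 1  -> 0.2
# 4  -> 0.5561   (0.8 - 0.6 * e^-0.9)
# 16 -> 0.7933
```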
@@ -400,8 +394,6 @@ class MotifAttention(nn.Module):
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)

-        ## bsz, seq, n_heads, head_dim
-
         query_states = query_states.view(bsz, q_len, 2 * self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, 2 * self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, 2 * self.head_dim).transpose(1, 2)
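A concrete shape walkthrough of the three `view(...).transpose(1, 2)` calls above, using illustrative numbers rather than the real Motif configuration: queries and keys are split into twice as many heads of width `head_dim`, while values keep `num_key_value_heads` heads of width `2 * head_dim`.

```python
import torch

# Illustrative sizes only: hidden_size = 2 * num_heads * head_dim.
bsz, q_len, hidden_size = 2, 5, 64
num_heads, num_key_value_heads, head_dim = 4, 2, 8
n_rep = num_heads // num_key_value_heads

query_states = torch.randn(bsz, q_len, hidden_size)
key_states = torch.randn(bsz, q_len, hidden_size // n_rep)
value_states = torch.randn(bsz, q_len, hidden_size // n_rep)

q = query_states.view(bsz, q_len, 2 * num_heads, head_dim).transpose(1, 2)
k = key_states.view(bsz, q_len, 2 * num_key_value_heads, head_dim).transpose(1, 2)
v = value_states.view(bsz, q_len, num_key_value_heads, 2 * head_dim).transpose(1, 2)
print(q.shape, k.shape, v.shape)
# torch.Size([2, 8, 5, 8]) torch.Size([2, 4, 5, 8]) torch.Size([2, 2, 5, 16])
```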
@@ -428,11 +420,9 @@ class MotifAttention(nn.Module):
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

-        # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

-        ## bsz, #haead, q_len, head_dim -> bsz, head, q_len, q_len
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

         kv_seq_len = key_states.shape[-2]
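`repeat_kv` is defined elsewhere in the file; in transformers-style modeling code it is normally the small expand-and-reshape helper sketched below (assumed to match the signature used above):

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand (batch, num_kv_heads, seq_len, head_dim) to
    (batch, num_kv_heads * n_rep, seq_len, head_dim) so grouped K/V heads
    line up with the larger number of query heads."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
```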
@@ -442,24 +432,19 @@ class MotifAttention(nn.Module):
                 torch.full((q_len, kv_seq_len), float("-inf"), dtype=attn_weights.dtype, device=attn_weights.device),
                 1 + offset)

-        ###add attn
         attn_weights = attn_weights + attention_mask

-        # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)

-        # differential transformer lambdas
         lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(attn_weights)
         lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(attn_weights)
         lambda_full = lambda_1 - lambda_2 + self.lambda_init
         attn_weights = attn_weights.view(bsz, self.num_heads, 2, q_len, -1)
         attn_weights = attn_weights[:, :, 0] - lambda_full * attn_weights[:, :, 1]

-        ##shape : bsz, #heads, seq, head_dim
         attn_output = torch.matmul(attn_weights, value_states)

-
         attn_output = self.subln(attn_output)
         attn_output = attn_output * (1 - self.lambda_init)

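To see the differential-attention combine in isolation, here is a toy-sized sketch of the same steps (random inputs, `self.subln` omitted, `lambda_init` fixed to its layer-1 value); it is illustrative only, not the modeling_motif.py forward pass:

```python
import torch

# Toy shapes: bsz=1, num_heads=2, q_len=3, kv_len=3, head_dim=4.
# The 2*num_heads attention maps are paired up; the second map of each pair
# is subtracted with weight lambda_full, cancelling common "noise" mass.
bsz, num_heads, q_len, kv_len, head_dim = 1, 2, 3, 3, 4
lambda_init = 0.2  # value for layer_idx == 1, see the __init__ hunk

attn = torch.rand(bsz, 2 * num_heads, q_len, kv_len).softmax(dim=-1)
lambda_q1, lambda_k1 = torch.randn(head_dim) * 0.1, torch.randn(head_dim) * 0.1
lambda_q2, lambda_k2 = torch.randn(head_dim) * 0.1, torch.randn(head_dim) * 0.1

lambda_full = (torch.exp((lambda_q1 * lambda_k1).sum())
               - torch.exp((lambda_q2 * lambda_k2).sum())
               + lambda_init)

attn = attn.view(bsz, num_heads, 2, q_len, kv_len)
diff_attn = attn[:, :, 0] - lambda_full * attn[:, :, 1]

value_states = torch.randn(bsz, num_heads, kv_len, 2 * head_dim)
attn_output = diff_attn @ value_states          # (bsz, num_heads, q_len, 2*head_dim)
attn_output = attn_output * (1 - lambda_init)   # rescale as in the hunk (subln omitted)
print(attn_output.shape)                        # torch.Size([1, 2, 3, 8])
```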
@@ -487,10 +472,8 @@ class MotifFlashAttention2(MotifAttention):
     config.max_window_layers layers.
     """

-    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-
         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
         # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
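The attribute this TODO refers to is assigned outside the lines shown here. In other transformers Flash Attention 2 classes the pattern is usually the following (attribute name and helper are assumptions based on those files, not on this diff):

```python
from transformers.utils import is_flash_attn_greater_or_equal_2_10


class _FlashAttnFlagSketch:
    def __init__(self):
        # flash_attn >= 2.1 builds a bottom-right aligned causal mask by
        # default; older versions align it top-left, so callers consult this
        # flag before passing `causal=True` with unequal q/k lengths.
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
```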
@@ -572,7 +555,6 @@ class MotifFlashAttention2(MotifAttention):
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

-        # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
         dropout_rate = 0.0 if not self.training else self.attention_dropout
@@ -665,7 +647,6 @@ class MotifSdpaAttention(MotifAttention):
     SDPA API.
     """

-    # Adapted from MotifAttention.forward
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -678,7 +659,6 @@ class MotifSdpaAttention(MotifAttention):
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
-            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
             logger.warning_once(
                 "MotifModel is using MotifSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
@@ -882,7 +862,6 @@ class MotifPreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.Embedding):
             module.weight.data.normal_(mean=0.0, std=module_std)
             module.weight.data = torch.where(abs(module.weight.data) > module_std*3, 0, module.weight.data)
-            #torch.nn.init.trunc_normal_(module.weight.data, mean=0.0, std=module_std, a=-3*module_std, b=3*module_std)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()

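The removed comment pointed at the `trunc_normal_` alternative to the `torch.where` clamp kept above. A small side-by-side sketch of the two initializations on a throwaway tensor (the value of `module_std` is illustrative):

```python
import torch

module_std = 0.02

# Approach kept in the model: draw a normal, then zero out anything beyond 3 sigma.
w_zeroed = torch.empty(1000, 64).normal_(mean=0.0, std=module_std)
w_zeroed = torch.where(abs(w_zeroed) > module_std * 3, 0, w_zeroed)

# Commented-out alternative: truncated normal re-draws outliers inside the interval.
w_truncated = torch.empty(1000, 64)
torch.nn.init.trunc_normal_(w_truncated, mean=0.0, std=module_std,
                            a=-3 * module_std, b=3 * module_std)

print(w_zeroed.abs().max() <= 3 * module_std)     # True: outliers set to 0
print(w_truncated.abs().max() <= 3 * module_std)  # True: outliers resampled
```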
@@ -1048,7 +1027,6 @@ class MotifModel(MotifPreTrainedModel):
                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
            use_cache = False

-        # kept for BC (non `Cache` `past_key_values` inputs)
         return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):
             return_legacy_cache = True
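For context on the `return_legacy_cache` flag: in transformers models this check is normally followed by wrapping a legacy tuple cache into the `Cache` API, roughly as sketched below (the conversion line itself is outside the lines shown in this hunk):

```python
from transformers.cache_utils import Cache, DynamicCache

use_cache = True
past_key_values = None  # or a legacy tuple-of-tuples from an older caller

return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
    return_legacy_cache = True
    # Wrap the legacy format so the rest of forward() can use the unified
    # cache API; the flag lets the model convert back before returning.
    past_key_values = DynamicCache.from_legacy_cache(past_key_values)
```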
@@ -1077,10 +1055,8 @@ class MotifModel(MotifPreTrainedModel):

         hidden_states = inputs_embeds
         bsz, q_len, _ = hidden_states.size()
-        # create position embeddings to be shared across the decoder layers
         position_embeddings = self.rotary_emb(hidden_states, seq_len=q_len)

-        # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
@@ -1123,7 +1099,6 @@ class MotifModel(MotifPreTrainedModel):

         hidden_states = self.norm(hidden_states)

-        # add hidden states from the last decoder layer
         if output_hidden_states:
             all_hidden_states += (hidden_states, )

@@ -1289,7 +1264,6 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
         self.post_init()

         if getattr(config, "tie_word_embeddings", True):
-            logger.info('tie embeddings')
             self.tie_weights()

     def get_input_embeddings(self):
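`tie_weights()` is inherited from `PreTrainedModel`; at its core, tying means the output projection reuses the input embedding matrix rather than learning a separate `(vocab_size, hidden_size)` weight, as in this minimal sketch:

```python
import torch.nn as nn

vocab_size, hidden_size = 1000, 64
embed_tokens = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# Share storage: no extra output-projection parameters are learned.
lm_head.weight = embed_tokens.weight
print(lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr())  # True
```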