Update modeling_motif.py

modeling_motif.py  CHANGED  (+5 −17)
@@ -284,14 +284,12 @@ class MotifMLP(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
         self.act_fn = ACT2FN[config.hidden_act]

     def forward(self, hidden_state):
-        hidden_state = hidden_state
-        #hidden_state = self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))*
         return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
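For context on this first hunk: the MLP keeps the usual gated (SwiGLU-style) formulation and now routes the projection bias through config.use_bias instead of a hard-coded value. A minimal, self-contained sketch of that pattern (the config fields below are toy stand-ins for whatever MotifConfig actually defines):

from types import SimpleNamespace
import torch
import torch.nn as nn
from transformers.activations import ACT2FN

# Toy stand-in for the model config; the real MotifConfig supplies these fields.
cfg = SimpleNamespace(hidden_size=16, intermediate_size=64, hidden_act="silu", use_bias=False)

gate_proj = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=cfg.use_bias)
up_proj = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=cfg.use_bias)
down_proj = nn.Linear(cfg.intermediate_size, cfg.hidden_size, bias=cfg.use_bias)
act_fn = ACT2FN[cfg.hidden_act]

x = torch.randn(2, 5, cfg.hidden_size)                # (batch, seq, hidden)
y = down_proj(act_fn(gate_proj(x)) * up_proj(x))      # same expression as the kept forward()
assert y.shape == x.shape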
@@ -394,7 +392,7 @@ class MotifAttention(nn.Module):
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
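The position_embeddings parameter follows the common Transformers convention of handing each attention layer a precomputed (cos, sin) pair from the model's rotary embedding. A rough sketch of how such a tuple is usually produced; this is an assumed standard RoPE recipe, not code taken from modeling_motif.py:

import torch

def rotary_cos_sin(position_ids, head_dim, base=10000.0, dtype=torch.float32):
    # Standard RoPE tables: one (cos, sin) pair per position, later applied to q and k.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    freqs = position_ids[..., None].float() * inv_freq      # (batch, seq, head_dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)                  # (batch, seq, head_dim)
    return emb.cos().to(dtype), emb.sin().to(dtype)

position_ids = torch.arange(8).unsqueeze(0)                  # (1, seq)
cos, sin = rotary_cos_sin(position_ids, head_dim=64)
position_embeddings = (cos, sin)                             # the Tuple[torch.Tensor, torch.Tensor] above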
@@ -493,8 +491,6 @@ class MotifFlashAttention2(MotifAttention):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

-
-
         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
         # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
@@ -516,7 +512,6 @@ class MotifFlashAttention2(MotifAttention):
         """Flash Attention 2 implements"""

         scale_factor = 1.0 / math.sqrt(self.head_dim)
-        # Copied from _flash_attention_forward
         if not self._flash_attn_uses_top_left_mask:
             causal = self.is_causal
         else:
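For reference, the scale factor and the causal flag computed in this hunk are the two knobs that matter once the call reaches FlashAttention. A hedged sketch using the public flash_attn_func API (assumes the flash-attn package >= 2.1 and a supported GPU; not the exact wrapper this file goes through):

import math
import torch
from flash_attn import flash_attn_func

bsz, seqlen, num_heads, head_dim = 1, 128, 8, 64
q = torch.randn(bsz, seqlen, num_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

scale_factor = 1.0 / math.sqrt(head_dim)   # same scaling as in the hunk above
causal = True                              # plays the role of self.is_causal
out = flash_attn_func(q, k, v, softmax_scale=scale_factor, causal=causal)
print(out.shape)                           # (bsz, seqlen, num_heads, head_dim)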
@@ -881,7 +876,6 @@ class MotifPreTrainedModel(PreTrainedModel):
         if isinstance(module, nn.Linear):
             module.weight.data.normal_(mean=0.0, std=module_std)
             module.weight.data = torch.where(abs(module.weight.data) > module_std*3, 0, module.weight.data)
-            #torch.nn.init.trunc_normal_(module.weight.data, mean=0.0, std=module_std, a=-3*module_std, b=3*module_std)
             if module.bias is not None:
                 module.bias.data.zero_()
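Note that the initializer kept here is not equivalent to the commented-out trunc_normal_ it drops: the kept code zeroes any weight whose magnitude exceeds 3*std, whereas a truncated normal draws those tail values inside the interval instead of zeroing them. A small sketch contrasting the two, with module_std as a placeholder value:

import torch
import torch.nn as nn

module_std = 0.02                                  # placeholder; the real value comes from the config
linear = nn.Linear(512, 512)

# Kept approach: sample N(0, std), then zero out anything beyond 3 standard deviations.
linear.weight.data.normal_(mean=0.0, std=module_std)
linear.weight.data = torch.where(linear.weight.data.abs() > 3 * module_std,
                                 torch.zeros_like(linear.weight.data),
                                 linear.weight.data)

# Removed alternative: truncated normal, which keeps tail draws inside [-3*std, 3*std].
reference = torch.empty(512, 512)
torch.nn.init.trunc_normal_(reference, mean=0.0, std=module_std,
                            a=-3 * module_std, b=3 * module_std)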
@@ -1001,12 +995,8 @@ class MotifModel(MotifPreTrainedModel):
         self.vocab_size = config.vocab_size

         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
-        # NOTE: For multi-token models, the last decoder layers (one for each token index)
-        # are implemented as a part of `MotifModelForCausalLM` to enable a custom forward-backward procedure.
-
         num_hidden_layers = config.num_hidden_layers
         self.layers = nn.ModuleList([MotifDecoderLayer(config = config, layer_idx=layer_idx) for layer_idx in range(num_hidden_layers)])
-        self._attn_implementation = config._attn_implementation
         self.norm = MotifRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
@@ -1079,7 +1069,6 @@ class MotifModel(MotifPreTrainedModel):
             cache_position = torch.arange(past_seen_tokens,
                                           past_seen_tokens + inputs_embeds.shape[1],
                                           device=inputs_embeds.device)
-        #position_ids = None
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
@@ -1132,7 +1121,6 @@ class MotifModel(MotifPreTrainedModel):
             if output_attentions:
                 all_self_attns += (layer_outputs[1], )

-        # <|_2_|>
         hidden_states = self.norm(hidden_states)

         # add hidden states from the last decoder layer
@@ -1192,6 +1180,7 @@ class MotifModel(MotifPreTrainedModel):
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
+
         # SlidingWindowCache or StaticCache
         if using_sliding_window_cache or using_static_cache:
             target_length = past_key_values.get_max_cache_shape()
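This region is choosing target_length, the key-axis size of the causal mask: with a SlidingWindowCache or StaticCache it comes from the cache's preallocated size, otherwise from the running sequence length. A generic sketch of the bottom-right-aligned mask such code typically builds (illustrative only, not copied from this file):

import torch

dtype = torch.float16
min_dtype = torch.finfo(dtype).min
sequence_length, target_length = 4, 8    # new query tokens vs. total key slots (e.g. static cache size)

# Bottom-right alignment: the i-th new query may attend to all cached keys plus itself.
mask = torch.full((sequence_length, target_length), min_dtype, dtype=dtype)
mask = torch.triu(mask, diagonal=target_length - sequence_length + 1)
print(mask)   # 0 where attention is allowed, min_dtype where it is masked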
@@ -1407,7 +1396,6 @@ class MotifForCausalLM(MotifPreTrainedModel, GenerationMixin):
             loss_fct = CrossEntropyLoss()
             shift_logits = shift_logits.view(-1, self.config.vocab_size)
             shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
             shift_labels = shift_labels.to(shift_logits.device)
             loss = loss_fct(shift_logits, shift_labels)
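The loss hunk only drops a comment; the computation is the usual causal-LM cross-entropy over shifted logits and labels. A self-contained sketch of that pattern (the shifting itself happens just above the lines shown; vocab_size here is a toy value):

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 32                                   # toy value; the model uses self.config.vocab_size
logits = torch.randn(2, 6, vocab_size)            # (batch, seq, vocab)
labels = torch.randint(0, vocab_size, (2, 6))     # (batch, seq)

# Predict token t+1 from position t.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, vocab_size)
shift_labels = shift_labels.view(-1).to(shift_logits.device)   # keep labels on the logits' device
loss = loss_fct(shift_logits, shift_labels)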