Upload modeling_bailing_moe_v2.py with huggingface_hub
modeling_bailing_moe_v2.py
CHANGED  +13 -11
@@ -25,9 +25,7 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
-import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
 
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
@@ -1157,11 +1155,11 @@ class BailingMoeV2Model(BailingMoeV2PreTrainedModel):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        self.
+        self.num_nextn_predict_layers = config.num_nextn_predict_layers
 
         self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         self.layers = []
-        for layer_idx in range(config.num_hidden_layers + config.
+        for layer_idx in range(config.num_hidden_layers + config.num_nextn_predict_layers):
             layer_cls = BailingMoeV2DecoderLayer if layer_idx < config.num_hidden_layers else BailingMoeV2MTPLayer
             self.layers.append(layer_cls(config, layer_idx))
 
@@ -1267,8 +1265,8 @@ class BailingMoeV2Model(BailingMoeV2PreTrainedModel):
         all_self_attns = () if output_attentions else None
         all_router_logits = () if output_router_logits else None
         next_decoder_cache = None
-        layers = self.layers[: -self.
-        mtp_layers = self.layers[-self.
+        layers = self.layers[: -self.num_nextn_predict_layers] if self.num_nextn_predict_layers > 0 else self.layers
+        mtp_layers = self.layers[-self.num_nextn_predict_layers :] if self.num_nextn_predict_layers > 0 else None
 
         for decoder_layer in layers:
             if output_hidden_states:
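The guard added in the hunk above is needed because a slice bound of -0 collapses to 0 in Python: a slice like layers[:-0] is empty and layers[-0:] is the whole list. A minimal, illustrative sketch (plain Python, not part of the model file) of the behaviour the num_nextn_predict_layers > 0 check avoids:

layers = ["dec0", "dec1", "dec2", "mtp0"]   # 3 decoder layers + 1 MTP layer

n_mtp = 1
print(layers[:-n_mtp])   # ['dec0', 'dec1', 'dec2']  -> decoder layers only
print(layers[-n_mtp:])   # ['mtp0']                  -> MTP layers only

n_mtp = 0
print(layers[:-n_mtp])   # []            -> every decoder layer would be dropped
print(layers[-n_mtp:])   # all 4 entries -> every layer would be treated as MTP

Hence the fallbacks to self.layers and None when no MTP layers are configured.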
@@ -1391,7 +1389,7 @@ class BailingMoeV2ForCausalLM(BailingMoeV2PreTrainedModel, GenerationMixin):
         self.model = BailingMoeV2Model(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-        self.
+        self.num_nextn_predict_layers = config.num_nextn_predict_layers
         self.mtp_loss_scaling_factor = config.mtp_loss_scaling_factor
 
         # Initialize weights and apply final processing
@@ -1491,18 +1489,21 @@ class BailingMoeV2ForCausalLM(BailingMoeV2PreTrainedModel, GenerationMixin):
             loss = self.loss_function(logits, labels, self.config.vocab_size, **kwargs)
 
         all_mtp_logits = None
-        if self.
+        if self.num_nextn_predict_layers > 0:
             mtp_hidden_states = outputs.mtp_hidden_states
-            shift_labels_mtp =
-            for i in range(self.
+            shift_labels_mtp = None
+            for i in range(self.num_nextn_predict_layers):
                 mtp_hidden_states = mtp_hidden_states[i]
                 mtp_logits = self.lm_head(mtp_hidden_states).float()
                 if all_mtp_logits is None:
                     all_mtp_logits = []
                 all_mtp_logits.append(mtp_logits)
                 if labels is not None:
+                    if shift_labels_mtp is None:
+                        shift_labels_mtp = labels.clone()
                     shift_labels_mtp, _ = roll_tensor(shift_labels_mtp, shifts=-1, dims=-1, fill_value=-100)
-
+                    mtp_logits_ = mtp_logits.view(-1, self.config.vocab_size)
+                    mtp_loss = self.loss_function(mtp_logits_, shift_labels_mtp.to(mtp_logits_.device).view(-1), self.config.vocab_size, **kwargs)
                     if loss is not None:
                         loss += self.mtp_loss_scaling_factor * mtp_loss
                     else:
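roll_tensor is a helper defined elsewhere in this file; the changed lines above call it once per MTP depth, shifting the labels one more position to the left each iteration and filling the vacated slot with the ignore index. A minimal sketch of the assumed effect of roll_tensor(labels, shifts=-1, dims=-1, fill_value=-100) on the label path (the helper's second return value, discarded as _ above, is omitted, and shift_labels_left is a hypothetical stand-in name):

import torch

def shift_labels_left(labels: torch.Tensor, fill_value: int = -100) -> torch.Tensor:
    # Hypothetical stand-in: roll the label ids one step left along the
    # sequence dimension and mark the vacated last position as ignored.
    rolled = torch.roll(labels, shifts=-1, dims=-1)
    rolled[..., -1] = fill_value
    return rolled

labels = torch.tensor([[11, 12, 13, 14]])
depth_1 = shift_labels_left(labels)    # tensor([[  12,   13,   14, -100]])
depth_2 = shift_labels_left(depth_1)   # tensor([[  13,   14, -100, -100]])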
@@ -1529,3 +1530,4 @@ class BailingMoeV2ForCausalLM(BailingMoeV2PreTrainedModel, GenerationMixin):
             attentions=outputs.attentions,
             router_logits=outputs.router_logits,
         )
+
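Taken together, the new lines add one scaled cross-entropy term per MTP depth on top of the main loss. A schematic restatement of that loop, illustrative only: it inlines the label shift and uses plain cross_entropy with ignore_index=-100 in place of self.loss_function, and add_mtp_losses is a hypothetical name.

import torch
import torch.nn.functional as F

def add_mtp_losses(loss, mtp_hidden_states, labels, lm_head, vocab_size, scaling_factor):
    # One hidden-state tensor per predict-ahead depth; each depth scores the
    # shared lm_head against labels shifted one extra step to the left.
    shifted = labels.clone()
    for hidden in mtp_hidden_states:
        shifted = torch.roll(shifted, shifts=-1, dims=-1)
        shifted[..., -1] = -100                      # ignore the vacated position
        logits = lm_head(hidden).float()
        mtp_loss = F.cross_entropy(
            logits.view(-1, vocab_size),
            shifted.view(-1).to(logits.device),
            ignore_index=-100,
        )
        loss = loss + scaling_factor * mtp_loss      # mtp_loss_scaling_factor in the diff
    return loss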