zhanghanxiao committed
Commit dc506a5 · verified · 1 Parent(s): e854153

Upload modeling_bailing_moe_v2.py with huggingface_hub

Files changed (1):
  1. modeling_bailing_moe_v2.py +13 -11
modeling_bailing_moe_v2.py CHANGED

@@ -25,9 +25,7 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
-import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
 
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
@@ -1157,11 +1155,11 @@ class BailingMoeV2Model(BailingMoeV2PreTrainedModel):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        self.num_mtp_layers = config.num_mtp_layers
+        self.num_nextn_predict_layers = config.num_nextn_predict_layers
 
         self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
         self.layers = []
-        for layer_idx in range(config.num_hidden_layers + config.num_mtp_layers):
+        for layer_idx in range(config.num_hidden_layers + config.num_nextn_predict_layers):
             layer_cls = BailingMoeV2DecoderLayer if layer_idx < config.num_hidden_layers else BailingMoeV2MTPLayer
             self.layers.append(layer_cls(config, layer_idx))
 
@@ -1267,8 +1265,8 @@ class BailingMoeV2Model(BailingMoeV2PreTrainedModel):
         all_self_attns = () if output_attentions else None
         all_router_logits = () if output_router_logits else None
         next_decoder_cache = None
-        layers = self.layers[: -self.num_mtp_layers] if self.num_mtp_layers > 0 else self.layers
-        mtp_layers = self.layers[-self.num_mtp_layers :] if self.num_mtp_layers > 0 else None
+        layers = self.layers[: -self.num_nextn_predict_layers] if self.num_nextn_predict_layers > 0 else self.layers
+        mtp_layers = self.layers[-self.num_nextn_predict_layers :] if self.num_nextn_predict_layers > 0 else None
 
         for decoder_layer in layers:
             if output_hidden_states:
@@ -1391,7 +1389,7 @@ class BailingMoeV2ForCausalLM(BailingMoeV2PreTrainedModel, GenerationMixin):
         self.model = BailingMoeV2Model(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-        self.num_mtp_layers = config.num_mtp_layers
+        self.num_nextn_predict_layers = config.num_nextn_predict_layers
         self.mtp_loss_scaling_factor = config.mtp_loss_scaling_factor
 
         # Initialize weights and apply final processing
@@ -1491,18 +1489,21 @@ class BailingMoeV2ForCausalLM(BailingMoeV2PreTrainedModel, GenerationMixin):
             loss = self.loss_function(logits, labels, self.config.vocab_size, **kwargs)
 
         all_mtp_logits = None
-        if self.num_mtp_layers > 0:
+        if self.num_nextn_predict_layers > 0:
             mtp_hidden_states = outputs.mtp_hidden_states
-            shift_labels_mtp = labels.clone()
-            for i in range(self.num_mtp_layers):
+            shift_labels_mtp = None
+            for i in range(self.num_nextn_predict_layers):
                 mtp_hidden_states = mtp_hidden_states[i]
                 mtp_logits = self.lm_head(mtp_hidden_states).float()
                 if all_mtp_logits is None:
                     all_mtp_logits = []
                 all_mtp_logits.append(mtp_logits)
                 if labels is not None:
+                    if shift_labels_mtp is None:
+                        shift_labels_mtp = labels.clone()
                     shift_labels_mtp, _ = roll_tensor(shift_labels_mtp, shifts=-1, dims=-1, fill_value=-100)
-                    mtp_loss = self.loss_function(mtp_logits, shift_labels_mtp, self.config.vocab_size, **kwargs)
+                    mtp_logits_ = mtp_logits.view(-1, self.config.vocab_size)
+                    mtp_loss = self.loss_function(mtp_logits_, shift_labels_mtp.to(mtp_logits_.device).view(-1), self.config.vocab_size, **kwargs)
                     if loss is not None:
                         loss += self.mtp_loss_scaling_factor * mtp_loss
                     else:
@@ -1529,3 +1530,4 @@ class BailingMoeV2ForCausalLM(BailingMoeV2PreTrainedModel, GenerationMixin):
             attentions=outputs.attentions,
             router_logits=outputs.router_logits,
         )
+
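
Note on the multi-token-prediction (MTP) loss path, the substantive part of this diff: shift_labels_mtp is now created lazily inside the loop and re-rolled once per MTP layer, and both logits and labels are flattened before being passed to loss_function. roll_tensor itself is defined elsewhere in modeling_bailing_moe_v2.py and does not appear in these hunks; the snippet below is only a minimal, hypothetical sketch of the label-shifting behavior the loss computation relies on (the function body, masking logic, and example values are assumptions, not the repository's implementation).

import torch

def roll_tensor(labels, shifts=-1, dims=-1, fill_value=-100):
    # Hypothetical sketch: shift labels one step left along `dims` and fill
    # the vacated trailing positions with `fill_value` so the loss ignores
    # them. The real roll_tensor in modeling_bailing_moe_v2.py may differ.
    rolled = torch.roll(labels, shifts=shifts, dims=dims)
    rolled[..., shifts:] = fill_value           # mask the wrapped-around tail
    num_valid = (rolled != fill_value).sum()    # second return value is discarded in the diff
    return rolled, num_valid

# Each successive MTP layer predicts one token further ahead, so the labels
# are rolled once more per layer before that layer's loss is computed.
labels = torch.tensor([[10, 11, 12, 13]])
step1, _ = roll_tensor(labels)   # tensor([[  11,   12,   13, -100]])
step2, _ = roll_tensor(step1)    # tensor([[  12,   13, -100, -100]])

Each per-layer loss is then scaled by mtp_loss_scaling_factor and added to the main loss, as the last code hunk above shows.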