Fizzarolli committed (verified)
Commit 12caeca · 1 Parent(s): 3ec24e2

Upload folder using huggingface_hub
configuration_bailing_shared_moe_v2.py ADDED
@@ -0,0 +1,84 @@
"""Bailing MoE V2 model configuration"""

from transformers.configuration_utils import PretrainedConfig


class BailingSharedMoeV2Config(PretrainedConfig):

    def __init__(
        self,
        vocab_size=157184,
        hidden_size=2048,
        intermediate_size=5120,
        num_hidden_layers=20,
        num_attention_heads=16,
        num_key_value_heads=4,
        hidden_act="silu",
        use_qkv_bias=False,  # bailing only
        use_bias=False,  # bailing only
        rms_norm_eps=1e-06,
        tie_word_embeddings=False,  # PretrainedConfig key, here change default value.
        embedding_dropout=0.0,
        attention_dropout=0.0,
        output_dropout=0.0,
        initializer_range=0.02,
        max_position_embeddings=32768,
        rope_theta=600000.0,
        use_cache=True,
        max_window_layers=20,
        rope_scaling=None,
        pad_token_id=156892,
        eos_token_id=156892,
        num_experts=256,
        num_shared_experts=1,
        num_experts_per_tok=8,
        n_group=8,
        topk_group=4,
        moe_intermediate_size=512,
        first_k_dense_replace=1,
        head_dim=128,
        output_router_logits=False,
        use_qk_norm=True,
        num_nextn_predict_layers=0,
        mtp_loss_scaling_factor=0,
        moe_router_enable_expert_bias=True,
        routed_scaling_factor=1.0,
        **kwargs,
    ):
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.use_qkv_bias = use_qkv_bias
        self.use_bias = use_bias
        self.rms_norm_eps = rms_norm_eps
        self.embedding_dropout = embedding_dropout
        self.attention_dropout = attention_dropout
        self.output_dropout = output_dropout
        self.num_nextn_predict_layers = num_nextn_predict_layers
        self.mtp_loss_scaling_factor = mtp_loss_scaling_factor
        self.initializer_range = initializer_range
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.use_cache = use_cache
        self.max_window_layers = max_window_layers
        self.head_dim = head_dim or self.hidden_size // self.num_attention_heads
        self.rope_scaling = rope_scaling
        self.use_qk_norm = use_qk_norm
        self.moe_router_enable_expert_bias = moe_router_enable_expert_bias
        self.routed_scaling_factor = routed_scaling_factor

        # MoE configs
        self.num_experts = num_experts
        self.num_shared_experts = num_shared_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.n_group = n_group
        self.topk_group = topk_group
        self.moe_intermediate_size = moe_intermediate_size
        self.first_k_dense_replace = first_k_dense_replace
        self.output_router_logits = output_router_logits

        super().__init__(pad_token_id=pad_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs)
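
As a quick sanity check, the configuration can be instantiated directly with its defaults; this is a minimal sketch, not part of the commit, and the printed values simply reflect the defaults defined above:

    from configuration_bailing_shared_moe_v2 import BailingSharedMoeV2Config

    # Defaults: 256 routed experts, top-8 routing, 1 shared expert.
    cfg = BailingSharedMoeV2Config()
    print(cfg.num_experts, cfg.num_experts_per_tok, cfg.num_shared_experts)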
convert_hf_to_scm.py ADDED
@@ -0,0 +1,137 @@
import glob
import re
import shutil
import sys

import accelerate
import torch
from configuration_bailing_shared_moe_v2 import BailingSharedMoeV2Config
from modeling_bailing_shared_moe_v2 import BailingSharedMoeV2ForCausalLM
from configuration_bailing_moe_v2 import BailingMoeV2Config
from safetensors import safe_open

input_model = sys.argv[1]
output_model_path = sys.argv[2]

auto_map = {
    "AutoConfig": "configuration_bailing_shared_moe_v2.BailingSharedMoeV2Config",
    "AutoModel": "modeling_bailing_shared_moe_v2.BailingSharedMoeV2Model",
    "AutoModelForCausalLM": "modeling_bailing_shared_moe_v2.BailingSharedMoeV2ForCausalLM"
}

cfg_standard_moe = BailingMoeV2Config.from_pretrained(input_model)
cfg_shared_moe = BailingSharedMoeV2Config(
    auto_map=auto_map,
    model_type="bailing_shared_moe_v2",
    vocab_size=cfg_standard_moe.vocab_size,
    hidden_size=cfg_standard_moe.hidden_size,
    intermediate_size=cfg_standard_moe.intermediate_size,
    num_hidden_layers=cfg_standard_moe.num_hidden_layers,
    num_attention_heads=cfg_standard_moe.num_attention_heads,
    num_key_value_heads=cfg_standard_moe.num_key_value_heads,
    hidden_act=cfg_standard_moe.hidden_act,
    max_position_embeddings=cfg_standard_moe.max_position_embeddings,
    initializer_range=cfg_standard_moe.initializer_range,
    rms_norm_eps=cfg_standard_moe.rms_norm_eps,
    use_cache=cfg_standard_moe.use_cache,
    tie_word_embeddings=cfg_standard_moe.tie_word_embeddings,
    rope_theta=cfg_standard_moe.rope_theta,
    rope_scaling=cfg_standard_moe.rope_scaling,
    max_window_layers=cfg_standard_moe.max_window_layers,
    attention_dropout=cfg_standard_moe.attention_dropout,
    moe_intermediate_size=cfg_standard_moe.moe_intermediate_size,
    num_experts_per_tok=cfg_standard_moe.num_experts_per_tok,
    num_experts=cfg_standard_moe.num_experts,
    num_shared_experts=cfg_standard_moe.num_shared_experts,
    norm_topk_prob=cfg_standard_moe.norm_topk_prob,
    output_router_logits=cfg_standard_moe.output_router_logits,
    shared_expert_intermediate_size=None,
    head_dim=cfg_standard_moe.head_dim,
    embedding_dropout=cfg_standard_moe.embedding_dropout,
    eos_token_id=cfg_standard_moe.eos_token_id,
    first_k_dense_replace=cfg_standard_moe.first_k_dense_replace,
    output_dropout=cfg_standard_moe.output_dropout,
    pad_token_id=cfg_standard_moe.pad_token_id,
    torch_dtype=cfg_standard_moe.torch_dtype,
    use_bias=cfg_standard_moe.use_bias,
    use_qkv_bias=cfg_standard_moe.use_qkv_bias,
    moe_router_enable_expert_bias=cfg_standard_moe.moe_router_enable_expert_bias,
    routed_scaling_factor=cfg_standard_moe.routed_scaling_factor,
    n_group=cfg_standard_moe.n_group,
    topk_group=cfg_standard_moe.topk_group,
    use_qk_norm=cfg_standard_moe.use_qk_norm,
    moe_shared_expert_intermediate_size=cfg_standard_moe.moe_shared_expert_intermediate_size,
    num_nextn_predict_layers=cfg_standard_moe.num_nextn_predict_layers,
    score_function=cfg_standard_moe.score_function,
    router_dtype=cfg_standard_moe.router_dtype,
    use_rmsnorm=cfg_standard_moe.use_rmsnorm,
    partial_rotary_factor=cfg_standard_moe.partial_rotary_factor,
)

num_experts = cfg_standard_moe.num_experts

with accelerate.init_empty_weights():
    model_shared_moe = BailingSharedMoeV2ForCausalLM(cfg_shared_moe)

model_shared_moe = model_shared_moe.to(torch.bfloat16)
new_state_dict = {}
pattern = f"{input_model}/model-*-of-*.safetensors"
files = sorted(glob.glob(pattern))

if len(files) == 0:
    raise FileNotFoundError
tensors = {}

for file_path in files:
    print(f"processing {file_path}")
    with safe_open(file_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensor = f.get_tensor(key)
            tensors[key] = tensor

for key in tensors:
    if "experts" not in key or "shared_experts" in key:
        new_state_dict[key] = tensors[key]
    elif "experts.0" in key:
        layer_num = int(re.search(r"\d+", key).group())
        new_state_dict[
            f"model.layers.{layer_num}.mlp.moe_mlp.output_experts.weight"
        ] = torch.stack(
            [
                tensors[f"model.layers.{layer_num}.mlp.experts.{i}.down_proj.weight"]
                for i in range(num_experts)
            ]
        )
        new_state_dict[f"model.layers.{layer_num}.mlp.moe_mlp.experts.weight"] = (
            torch.stack(
                [
                    torch.cat(
                        [
                            tensors[
                                f"model.layers.{layer_num}.mlp.experts.{i}.up_proj.weight"
                            ],
                            tensors[
                                f"model.layers.{layer_num}.mlp.experts.{i}.gate_proj.weight"
                            ],
                        ],
                        dim=0,
                    )
                    for i in range(num_experts)
                ]
            )
        )
model_shared_moe.load_state_dict(new_state_dict, strict=True, assign=True)
model_shared_moe.save_pretrained(output_model_path)
cfg_shared_moe.save_pretrained(output_model_path)


shutil.copy(
    "modeling_bailing_shared_moe_v2.py",
    output_model_path + "/" + "modeling_bailing_shared_moe_v2.py",
)
shutil.copy(
    "configuration_bailing_shared_moe_v2.py",
    output_model_path + "/" + "configuration_bailing_shared_moe_v2.py",
)
for i in ["special_tokens_map.json", "tokenizer_config.json", "tokenizer.json"]:
    shutil.copy(input_model + "/" + i, output_model_path + "/" + i)
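
The script takes two positional arguments (input checkpoint directory, output directory), so a run looks roughly like the following; the paths are placeholders, and the configuration/modeling files from this repo must be importable from the working directory:

    # Convert a stock Bailing MoE V2 checkpoint into the scattermoe layout.
    python convert_hf_to_scm.py /path/to/bailing_moe_v2_hf /path/to/output_scm

It reads the sharded model-*-of-*.safetensors files from the input directory, stacks each layer's per-expert down_proj weights into moe_mlp.output_experts.weight and the concatenated up_proj/gate_proj weights into moe_mlp.experts.weight, then writes the converted weights, config, remote-code files, and tokenizer files to the output directory.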
convert_scm_to_hf.py ADDED
@@ -0,0 +1,124 @@
import glob
import re
import shutil
import sys

import accelerate
import torch
from safetensors import safe_open
from configuration_bailing_shared_moe_v2 import BailingSharedMoeV2Config
from modeling_bailing_moe_v2 import BailingMoeV2ForCausalLM
from configuration_bailing_moe_v2 import BailingMoeV2Config

input_model = sys.argv[1]
output_model_path = sys.argv[2]

auto_map = {
    "AutoConfig": "configuration_bailing_moe_v2.BailingMoeV2Config",
    "AutoModel": "modeling_bailing_moe_v2.BailingMoeV2Model",
    "AutoModelForCausalLM": "modeling_bailing_moe_v2.BailingMoeV2ForCausalLM"
}
cfg_shared_moe = BailingSharedMoeV2Config.from_pretrained(input_model)
cfg_standard_moe = BailingMoeV2Config(
    auto_map=auto_map,
    vocab_size=cfg_shared_moe.vocab_size,
    hidden_size=cfg_shared_moe.hidden_size,
    intermediate_size=cfg_shared_moe.intermediate_size,
    num_hidden_layers=cfg_shared_moe.num_hidden_layers,
    num_attention_heads=cfg_shared_moe.num_attention_heads,
    num_key_value_heads=cfg_shared_moe.num_key_value_heads,
    hidden_act=cfg_shared_moe.hidden_act,
    max_position_embeddings=cfg_shared_moe.max_position_embeddings,
    initializer_range=cfg_shared_moe.initializer_range,
    rms_norm_eps=cfg_shared_moe.rms_norm_eps,
    use_cache=cfg_shared_moe.use_cache,
    tie_word_embeddings=cfg_shared_moe.tie_word_embeddings,
    rope_theta=cfg_shared_moe.rope_theta,
    rope_scaling=cfg_shared_moe.rope_scaling,
    max_window_layers=cfg_shared_moe.max_window_layers,
    attention_dropout=cfg_shared_moe.attention_dropout,
    moe_intermediate_size=cfg_shared_moe.moe_intermediate_size,
    num_experts_per_tok=cfg_shared_moe.num_experts_per_tok,
    num_experts=cfg_shared_moe.num_experts,
    num_shared_experts=cfg_shared_moe.num_shared_experts,
    norm_topk_prob=cfg_shared_moe.norm_topk_prob,
    output_router_logits=cfg_shared_moe.output_router_logits,
    shared_expert_intermediate_size=None,
    head_dim=cfg_shared_moe.head_dim,
    embedding_dropout=cfg_shared_moe.embedding_dropout,
    eos_token_id=cfg_shared_moe.eos_token_id,
    first_k_dense_replace=cfg_shared_moe.first_k_dense_replace,
    output_dropout=cfg_shared_moe.output_dropout,
    pad_token_id=cfg_shared_moe.pad_token_id,
    torch_dtype=cfg_shared_moe.torch_dtype,
    use_bias=cfg_shared_moe.use_bias,
    use_qkv_bias=cfg_shared_moe.use_qkv_bias,
    moe_router_enable_expert_bias=cfg_shared_moe.moe_router_enable_expert_bias,
    routed_scaling_factor=cfg_shared_moe.routed_scaling_factor,
    n_group=cfg_shared_moe.n_group,
    topk_group=cfg_shared_moe.topk_group,
    use_qk_norm=cfg_shared_moe.use_qk_norm,
    moe_shared_expert_intermediate_size=cfg_shared_moe.moe_shared_expert_intermediate_size,
    num_nextn_predict_layers=cfg_shared_moe.num_nextn_predict_layers,
    score_function=cfg_shared_moe.score_function,
    router_dtype=cfg_shared_moe.router_dtype,
    use_rmsnorm=cfg_shared_moe.use_rmsnorm,
    partial_rotary_factor=cfg_shared_moe.partial_rotary_factor
)
num_experts = cfg_standard_moe.num_experts

with accelerate.init_empty_weights():
    model_standard_moe = BailingMoeV2ForCausalLM(cfg_standard_moe)

model_standard_moe = model_standard_moe.to(torch.bfloat16)
new_state_dict = {}
pattern = f"{input_model}/model-*-of-*.safetensors"
files = sorted(glob.glob(pattern))

if len(files) == 0:
    raise FileNotFoundError
tensors = {}

for file_path in files:
    print(f"processing {file_path}")
    with safe_open(file_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensor = f.get_tensor(key)
            tensors[key] = tensor

for key in tensors:
    if "moe_mlp" not in key:
        new_state_dict[key] = tensors[key]
    elif "moe_mlp.output_experts" in key:
        layer_num = int(re.search(r"\d+", key).group())
        for i, tensor in enumerate(torch.unbind(tensors[key])):
            new_state_dict[
                f"model.layers.{layer_num}.mlp.experts.{i}.down_proj.weight"
            ] = tensor.contiguous()
    elif "moe_mlp.experts" in key:
        layer_num = int(re.search(r"\d+", key).group())
        for i, tensor in enumerate(torch.unbind(tensors[key])):
            (
                new_state_dict[
                    f"model.layers.{layer_num}.mlp.experts.{i}.up_proj.weight"
                ],
                new_state_dict[
                    f"model.layers.{layer_num}.mlp.experts.{i}.gate_proj.weight"
                ],
            ) = torch.chunk(tensor, 2, dim=0)

model_standard_moe.load_state_dict(new_state_dict, strict=True, assign=True)
model_standard_moe.save_pretrained(output_model_path)
cfg_standard_moe.save_pretrained(output_model_path)

shutil.copy(
    "modeling_bailing_moe_v2.py",
    output_model_path + "/" + "modeling_bailing_moe_v2.py",
)
shutil.copy(
    "configuration_bailing_moe_v2.py",
    output_model_path + "/" + "configuration_bailing_moe_v2.py",
)

for i in ["special_tokens_map.json", "tokenizer_config.json", "tokenizer.json"]:
    shutil.copy(input_model + "/" + i, output_model_path + "/" + i)
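
Invocation mirrors the forward conversion, again with placeholder paths:

    # Convert a scattermoe-layout checkpoint back into the stock Bailing MoE V2 layout.
    python convert_scm_to_hf.py /path/to/scm_checkpoint /path/to/output_hf

Since convert_hf_to_scm.py concatenates each expert's up_proj and gate_proj along dim 0 before stacking, and this script splits them back apart with torch.chunk(tensor, 2, dim=0), a round trip through both scripts should reproduce the original per-expert tensors exactly.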
modeling_bailing_shared_moe_v2.py ADDED
@@ -0,0 +1,1552 @@
# coding=utf-8
# Copyright 2025 Antgroup and The HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch model."""

import math
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
    AttentionMaskConverter,
    _prepare_4d_attention_mask,
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import MoeModelOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from transformers.utils.import_utils import is_torch_fx_available
from .configuration_bailing_shared_moe_v2 import BailingSharedMoeV2Config
from transformers.generation.utils import GenerationMixin
from dataclasses import dataclass
from transformers.utils import ModelOutput

# from cut_cross_entropy.transformers.utils import apply_lce
from cut_cross_entropy import linear_cross_entropy
from typing import TypedDict


class CCEPreset(TypedDict):
    filter_eps: float | str | None
    accum_e_fp32: bool
    accum_c_fp32: bool
    filter_e_grad: bool
    filter_c_grad: bool


class CCEKwargs(CCEPreset):
    impl: str
    reduction: str


@dataclass
class PatchOptions:
    impl: str
    reduction: str
    filter_eps: float | str | None
    accum_e_fp32: bool
    accum_c_fp32: bool
    filter_e_grad: bool
    filter_c_grad: bool
    train_only: bool

    def to_kwargs(self) -> CCEKwargs:
        return CCEKwargs(
            impl=self.impl,
            reduction=self.reduction,
            filter_eps=self.filter_eps,
            accum_e_fp32=self.accum_e_fp32,
            accum_c_fp32=self.accum_c_fp32,
            filter_e_grad=self.filter_e_grad,
            filter_c_grad=self.filter_c_grad,
        )


_PATCH_OPTS = PatchOptions(
    impl="cce",
    reduction="mean",
    filter_eps="auto",
    accum_e_fp32=False,
    accum_c_fp32=False,
    filter_e_grad=True,
    filter_c_grad=True,
    train_only=False,
)


def apply_lce(
    e: torch.Tensor,
    c: torch.Tensor,
    labels: torch.Tensor,
    opts,
    bias: torch.Tensor | None = None,
    **loss_kwargs,
) -> torch.Tensor:
    num_items_in_batch = loss_kwargs.get("num_items_in_batch", None)
    cce_kwargs = opts.to_kwargs()
    if num_items_in_batch is not None and cce_kwargs["reduction"] == "mean":
        cce_kwargs["reduction"] = "sum"
    else:
        num_items_in_batch = None

    loss = linear_cross_entropy(
        e,
        c,
        labels.to(e.device),
        bias=bias,
        shift=True,
        **cce_kwargs,
    )

    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch

    return loss


import scattermoe

if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa


# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
# It means that the function will not be traced through and simply appear as a node in the graph.
if is_torch_fx_available():
    if not is_torch_greater_or_equal_than_1_13:
        import torch.fx

    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "BailingSharedMoeV2Config"


def roll_tensor(tensor, shifts=-1, dims=-1, fill_value=0):
    """Roll the tensor input along the given dimension(s).
    Inserted elements are set to be 0.0.
    """
    rolled_tensor = torch.roll(tensor, shifts=shifts, dims=dims)
    rolled_tensor.select(dims, shifts).fill_(fill_value)
    return rolled_tensor, rolled_tensor.sum()

@dataclass
class MoEV2CausalLMOutputWithPast(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs as well as Mixture of Expert's router hidden
    states terms, to train a MoE model.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        z_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
            z_loss for the sparse modules.
        aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
            aux_loss for the sparse modules.
        router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.

            Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse
            modules.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    z_loss: Optional[torch.FloatTensor] = None
    aux_loss: Optional[torch.FloatTensor] = None
    router_logits: Optional[tuple[torch.FloatTensor]] = None
    mtp_loss: Optional[torch.FloatTensor] = None
    mtp_logits: Optional[tuple[torch.FloatTensor, ...]] = None


class MoeV2ModelOutputWithPast(MoeModelOutputWithPast):

    def __init__(self, mtp_hidden_states=None, **kwargs):
        super().__init__(**kwargs)
        self.mtp_hidden_states = mtp_hidden_states


def _get_unpad_data(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )


def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    warnings.warn(
        "Calling `transformers.models.BailingSharedMoeV2.modeling_BailingSharedMoeV2._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask"
    )
    return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)


def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    warnings.warn(
        "Calling `transformers.models.BailingSharedMoeV2.modeling_BailingSharedMoeV2._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.BailingSharedMoeV2.modeling_BailingSharedMoeV2.AttentionMaskConverter._make_causal_mask"
    )
    return AttentionMaskConverter._make_causal_mask(
        input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length
    )


class BailingSharedMoeV2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        BailingSharedMoeV2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


ALL_LAYERNORM_LAYERS.append(BailingSharedMoeV2RMSNorm)

class BailingSharedMoeV2RotaryEmbedding(nn.Module):
    def __init__(self, config: BailingSharedMoeV2Config, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    # Keep half or full tensor for later concatenation
    rotary_dim = cos.shape[-1]
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]

    # Apply rotary embeddings on the first half or full tensor
    q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
    k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)

    # Concatenate back to full shape
    q_embed = torch.cat([q_embed, q_pass], dim=-1)
    k_embed = torch.cat([k_embed, k_pass], dim=-1)
    return q_embed, k_embed

class BailingSharedMoeV2MLP(nn.Module):
    def __init__(self, config: BailingSharedMoeV2Config, intermediate_size: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = intermediate_size

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class BailingSharedMoeV2Gate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.num_experts

        self.n_group = config.n_group
        self.topk_group = config.topk_group

        # topk selection algorithm
        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(torch.empty((self.num_experts, self.gating_dim)))
        self.routed_scaling_factor = config.routed_scaling_factor

        self.register_buffer("expert_bias", torch.zeros((self.num_experts)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        import torch.nn.init as init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def group_limited_topk(
        self,
        scores: torch.Tensor,
    ):
        num_tokens, _ = scores.size()
        # Organize the experts into groups
        group_scores = scores.view(num_tokens, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)

        # Mask the experts based on selection groups
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(num_tokens, self.n_group, self.num_experts // self.n_group)
            .reshape(num_tokens, -1)
        )

        masked_scores = scores.masked_fill(~score_mask.bool(), float('-inf'))
        probs, top_indices = torch.topk(masked_scores, k=self.top_k, dim=-1)

        return probs, top_indices

    def forward(self, hidden_states):
        # compute gating score
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))

        scores = torch.sigmoid(logits.float()).type_as(logits)

        scores_for_routing = scores + self.expert_bias.to(scores.device)
        _, topk_idx = self.group_limited_topk(scores_for_routing)

        scores = torch.gather(scores, dim=1, index=topk_idx).type_as(logits)

        topk_weight = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if self.top_k > 1 else scores
        topk_weight = topk_weight * self.routed_scaling_factor

        return topk_idx, topk_weight, logits


class BailingSharedMoeV2SparseMoeBlock(nn.Module):
    """
    A mixed expert module containing shared experts.
    """

    def __init__(self, config: BailingSharedMoeV2Config):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok
        self.moe_mlp = scattermoe.mlp.GLUMLP(
            input_size=self.config.hidden_size,
            hidden_size=self.config.moe_intermediate_size,
            num_experts=self.config.num_experts,
            top_k=self.config.num_experts_per_tok,
            activation=ACT2FN[config.hidden_act],
        )
        self.gate = BailingSharedMoeV2Gate(config)
        if config.num_shared_experts is not None:
            self.shared_experts = BailingSharedMoeV2MLP(
                config=config, intermediate_size=config.moe_intermediate_size * config.num_shared_experts
            )

    def forward(self, hidden_states):
        identity = hidden_states
        bsz, seq_len, h = hidden_states.shape
        topk_idx, topk_weight, router_logits = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        y = self.moe_mlp(hidden_states, topk_weight.to(torch.bfloat16), flat_topk_idx)
        if self.config.num_shared_experts is not None:
            y = y + self.shared_experts(identity)
        return y, (router_logits.view(bsz, seq_len, -1), topk_idx.view(bsz, seq_len, -1))

# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->BailingSharedMoeV2
class BailingSharedMoeV2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: BailingSharedMoeV2Config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead "
                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim or self.hidden_size // self.num_heads
        partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
        self.rope_dim = int(self.head_dim * partial_rotary_factor)
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        self.query_key_value = nn.Linear(
            self.hidden_size,
            (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
            bias=config.use_qkv_bias,
        )

        if self.config.use_qk_norm:
            self.query_layernorm = BailingSharedMoeV2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.key_layernorm = BailingSharedMoeV2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.use_bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

        bsz, q_len, _ = hidden_states.size()

        qkv = self.query_key_value(hidden_states)
        qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)

        query_states, key_states, value_states = qkv.split(
            [self.num_heads, self.num_key_value_heads, self.num_key_value_heads], dim=-2
        )
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        if self.config.use_qk_norm:
            query_states = self.query_layernorm(query_states)
            key_states = self.key_layernorm(key_states)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            cache_kwargs = {"sin": sin, "cos": cos}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        kv_seq_len = key_states.shape[-2]
        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, -1)

        attn_output = self.dense(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->BailingSharedMoeV2
class BailingSharedMoeV2FlashAttention2(BailingSharedMoeV2Attention):
    """
    BailingSharedMoeV2 flash attention module. This module inherits from `BailingSharedMoeV2Attention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # BailingSharedMoeV2FlashAttention2 attention does not support output_attentions
        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        # therefore we just need to keep the original shape

        qkv = self.query_key_value(hidden_states)
        qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)

        query_states, key_states, value_states = qkv.split(
            [self.num_heads, self.num_key_value_heads, self.num_key_value_heads], dim=-2
        )
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        if self.config.use_qk_norm:
            query_states = self.query_layernorm(query_states)
            key_states = self.key_layernorm(key_states)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently cast in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slow down training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (BailingSharedMoeV2RMSNorm handles it correctly)

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            # Handle the case where the model is quantized
            if hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            elif torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            else:
                target_dtype = self.query_key_value.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = self._flash_attention_forward(
            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
        )

        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.dense(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

    def _flash_attention_forward(
        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`int`, *optional*):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
            query_length (`int`):
                The length of the query sequence in terms of tokens. This represents the number of tokens in the
                `query_states` tensor along the sequence dimension. It is used to determine the effective sequence
                length for attention computations.
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in BailingSharedMoeV2FlashAttention2 __init__.
            causal = self.is_causal and query_length != 1

        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )

            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
            )

        return attn_output

    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
804
+
805
+
806
+ # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->BailingSharedMoeV2
807
+ class BailingSharedMoeV2SdpaAttention(BailingSharedMoeV2Attention):
808
+ """
809
+ BailingSharedMoeV2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
810
+ `BailingSharedMoeV2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
811
+ SDPA API.
812
+ """
813
+
814
+ # Adapted from BailingSharedMoeV2Attention.forward
815
+ def forward(
816
+ self,
817
+ hidden_states: torch.Tensor,
818
+ attention_mask: Optional[torch.Tensor] = None,
819
+ position_ids: Optional[torch.LongTensor] = None,
820
+ past_key_value: Optional[Cache] = None,
821
+ output_attentions: bool = False,
822
+ use_cache: bool = False,
823
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
824
+ **kwargs,
825
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
826
+ if output_attentions:
827
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
828
+ logger.warning_once(
829
+ "BailingSharedMoeV2Model is using BailingSharedMoeV2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
830
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
831
+ )
832
+ return super().forward(
833
+ hidden_states=hidden_states,
834
+ attention_mask=attention_mask,
835
+ position_ids=position_ids,
836
+ past_key_value=past_key_value,
837
+ output_attentions=output_attentions,
838
+ use_cache=use_cache,
839
+ )
840
+
841
+ bsz, q_len, _ = hidden_states.size()
842
+
843
+ qkv = self.query_key_value(hidden_states)
844
+ qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
845
+
846
+ query_states, key_states, value_states = qkv.split(
847
+ [self.num_heads, self.num_key_value_heads, self.num_key_value_heads], dim=-2
848
+ )
849
+ query_states = query_states.transpose(1, 2)
850
+ key_states = key_states.transpose(1, 2)
851
+ value_states = value_states.transpose(1, 2)
852
+
853
+ if self.config.use_qk_norm:
854
+ query_states = self.query_layernorm(query_states)
855
+ key_states = self.key_layernorm(key_states)
856
+
857
+ cos, sin = position_embeddings
858
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
859
+
860
+ if past_key_value is not None:
861
+ cache_kwargs = {"sin": sin, "cos": cos}
862
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
863
+
864
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
865
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
866
+
867
+ if attention_mask is not None:
868
+ kv_seq_len = key_states.shape[-2]
869
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
870
+ raise ValueError(
871
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
872
+ )
873
+
874
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
875
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
876
+ if query_states.device.type == "cuda" and attention_mask is not None:
877
+ query_states = query_states.contiguous()
878
+ key_states = key_states.contiguous()
879
+ value_states = value_states.contiguous()
880
+
881
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
882
+ query_states,
883
+ key_states,
884
+ value_states,
885
+ attn_mask=attention_mask,
886
+ dropout_p=self.attention_dropout if self.training else 0.0,
887
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
888
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
889
+ )
890
+
891
+ attn_output = attn_output.transpose(1, 2).contiguous()
892
+ attn_output = attn_output.reshape(bsz, q_len, -1)
893
+
894
+ attn_output = self.dense(attn_output)
895
+
896
+ return attn_output, None, past_key_value
897
+
898
+
899
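+ # Maps config._attn_implementation to the attention backend instantiated by the decoder and
+ # MTP layers below. A minimal usage sketch, assuming the converted checkpoint lives in a
+ # hypothetical local directory PATH_TO_CONVERTED_WEIGHTS (attn_implementation and
+ # trust_remote_code are standard from_pretrained kwargs):
+ #
+ #     from transformers import AutoModelForCausalLM
+ #     model = AutoModelForCausalLM.from_pretrained(
+ #         PATH_TO_CONVERTED_WEIGHTS,
+ #         trust_remote_code=True,
+ #         attn_implementation="sdpa",  # or "eager" / "flash_attention_2"
+ #     )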
+ ATTENTION_CLASSES = {
900
+ "eager": BailingSharedMoeV2Attention,
901
+ "flash_attention_2": BailingSharedMoeV2FlashAttention2,
902
+ "sdpa": BailingSharedMoeV2SdpaAttention,
903
+ }
904
+
905
+
906
+ class BailingSharedMoeV2MTPLayer(nn.Module):
907
+ def __init__(self, config: BailingSharedMoeV2Config, layer_idx: int):
908
+ super().__init__()
909
+ self.layer_idx = layer_idx
910
+ self.input_layernorm = BailingSharedMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
911
+ self.enorm = BailingSharedMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
912
+
913
+ self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
914
+ self.post_attention_layernorm = BailingSharedMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
915
+ self.attention = ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
916
+ self.mlp = BailingSharedMoeV2SparseMoeBlock(config)
917
+
918
+ self.hnorm = BailingSharedMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
919
+ self.final_layernorm = BailingSharedMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
920
+
921
+ def forward(
922
+ self,
923
+ input_embeds,
924
+ hidden_states: torch.Tensor,
925
+ attention_mask: Optional[torch.Tensor] = None,
926
+ position_ids: Optional[torch.LongTensor] = None,
927
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
928
+ output_attentions: Optional[bool] = False,
929
+ output_router_logits: Optional[bool] = False,
930
+ use_cache: Optional[bool] = False,
931
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
932
+ **kwargs,
933
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
934
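+ # Multi-token-prediction (MTP) fusion: the shifted next-token input embeddings and the hidden
+ # states from the preceding layer are normalized separately, concatenated, and projected back
+ # to hidden_size by eh_proj before running a regular attention + MoE block.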
+ input_embeds = self.enorm(input_embeds)
935
+ hidden_states = self.hnorm(hidden_states)
936
+ hidden_states = self.eh_proj(torch.cat([input_embeds, hidden_states], dim=-1))
937
+ residual = hidden_states
938
+
939
+ hidden_states = self.input_layernorm(hidden_states)
940
+
941
+ # Self Attention
942
+ hidden_states, self_attn_weights, present_key_value = self.attention(
943
+ hidden_states=hidden_states,
944
+ attention_mask=attention_mask,
945
+ position_ids=position_ids,
946
+ past_key_value=past_key_value,
947
+ output_attentions=output_attentions,
948
+ position_embeddings=position_embeddings,
949
+ use_cache=use_cache,
950
+ )
951
+ hidden_states = residual + hidden_states
952
+
953
+ # Fully Connected
954
+ residual = hidden_states
955
+ hidden_states = self.post_attention_layernorm(hidden_states)
956
+ hidden_states = self.mlp(hidden_states)
957
+ if isinstance(hidden_states, tuple):
958
+ hidden_states, router_logits = hidden_states
959
+ else:
960
+ router_logits = None
961
+ hidden_states = residual + hidden_states.to(residual.device)
962
+ hidden_states = self.final_layernorm(hidden_states)
963
+
964
+ outputs = (hidden_states,)
965
+
966
+ if output_attentions:
967
+ outputs += (self_attn_weights,)
968
+
969
+ if use_cache:
970
+ outputs += (present_key_value,)
971
+
972
+ if output_router_logits:
973
+ outputs += (router_logits,)
974
+
975
+ return outputs
976
+
977
+
978
+ class BailingSharedMoeV2DecoderLayer(nn.Module):
979
+ def __init__(self, config: BailingSharedMoeV2Config, layer_idx: int):
980
+ super().__init__()
981
+ self.hidden_size = config.hidden_size
982
+
983
+ self.attention = ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
984
+
985
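+ # Layers with index below first_k_dense_replace (or every layer when num_experts is None) use
+ # a dense MLP; all later layers use the shared-expert sparse MoE block.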
+ self.mlp = (
986
+ BailingSharedMoeV2SparseMoeBlock(config)
987
+ if (config.num_experts is not None and layer_idx >= config.first_k_dense_replace)
988
+ else BailingSharedMoeV2MLP(config=config, intermediate_size=config.intermediate_size)
989
+ )
990
+ self.input_layernorm = BailingSharedMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
991
+ self.post_attention_layernorm = BailingSharedMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
992
+
993
+ def forward(
994
+ self,
995
+ hidden_states: torch.Tensor,
996
+ attention_mask: Optional[torch.Tensor] = None,
997
+ position_ids: Optional[torch.LongTensor] = None,
998
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
999
+ output_attentions: Optional[bool] = False,
1000
+ output_router_logits: Optional[bool] = False,
1001
+ use_cache: Optional[bool] = False,
1002
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
1003
+ **kwargs,
1004
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1005
+ """
1006
+ Args:
1007
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
1008
+ attention_mask (`torch.FloatTensor`, *optional*):
1009
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
1010
+ query_sequence_length, key_sequence_length)` if default attention is used.
1011
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1012
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1013
+ config.n_positions - 1]`.
1014
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*):
1015
+ cached past key and value projection states
1016
+ output_attentions (`bool`, *optional*):
1017
+ Whether to return the attentions tensors of all attention layers. See `attentions` under
1018
+ returned tensors for more detail.
1019
+ output_router_logits (`bool`, *optional*):
1020
+ Whether or not to return the logits of all the routers. They are useful for computing the router loss,
1021
+ and should not be returned during inference.
1022
+ use_cache (`bool`, *optional*):
1023
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1024
+ (see `past_key_values`).
1025
+ """
1026
+ residual = hidden_states
1027
+
1028
+ hidden_states = self.input_layernorm(hidden_states)
1029
+
1030
+ # Self Attention
1031
+ hidden_states, self_attn_weights, present_key_value = self.attention(
1032
+ hidden_states=hidden_states,
1033
+ attention_mask=attention_mask,
1034
+ position_ids=position_ids,
1035
+ past_key_value=past_key_value,
1036
+ output_attentions=output_attentions,
1037
+ position_embeddings=position_embeddings,
1038
+ use_cache=use_cache,
1039
+ )
1040
+ hidden_states = residual + hidden_states
1041
+
1042
+ # Fully Connected
1043
+ residual = hidden_states
1044
+ hidden_states = self.post_attention_layernorm(hidden_states)
1045
+ hidden_states = self.mlp(hidden_states)
1046
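+ # The sparse MoE block returns a (hidden_states, router_logits) tuple, while the dense MLP
+ # returns a plain tensor, hence the isinstance check below.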
+ if isinstance(hidden_states, tuple):
1047
+ hidden_states, router_logits = hidden_states
1048
+ else:
1049
+ router_logits = None
1050
+ hidden_states = residual + hidden_states.to(residual.device)
1051
+
1052
+ outputs = (hidden_states,)
1053
+
1054
+ if output_attentions:
1055
+ outputs += (self_attn_weights,)
1056
+
1057
+ if use_cache:
1058
+ outputs += (present_key_value,)
1059
+
1060
+ if output_router_logits:
1061
+ outputs += (router_logits,)
1062
+
1063
+ return outputs
1064
+
1065
+
1066
+ BAILINGMOEV2_START_DOCSTRING = r"""
1067
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1068
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
1069
+ etc.)
1070
+
1071
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1072
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
1073
+ and behavior.
1074
+
1075
+ Parameters:
1076
+ config ([`BailingSharedMoeV2Config`]):
1077
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
1078
+ load the weights associated with the model, only the configuration. Check out the
1079
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1080
+ """
1081
+
1082
+
1083
+ @add_start_docstrings(
1084
+ "The bare BailingSharedMoeV2 Model outputting raw hidden-states without any specific head on top.",
1085
+ BAILINGMOEV2_START_DOCSTRING,
1086
+ )
1087
+ class BailingSharedMoeV2PreTrainedModel(PreTrainedModel):
1088
+ config_class = BailingSharedMoeV2Config
1089
+ base_model_prefix = "model"
1090
+ supports_gradient_checkpointing = True
1091
+ _no_split_modules = ["BailingSharedMoeV2DecoderLayer"]
1092
+ _skip_keys_device_placement = "past_key_values"
1093
+ _supports_flash_attn_2 = True
1094
+ _supports_sdpa = True
1095
+ _supports_cache_class = True
1096
+
1097
+ def _init_weights(self, module):
1098
+ std = self.config.initializer_range
1099
+ if isinstance(module, nn.Linear):
1100
+ module.weight.data.normal_(mean=0.0, std=std)
1101
+ if module.bias is not None:
1102
+ module.bias.data.zero_()
1103
+ elif isinstance(module, nn.Embedding):
1104
+ module.weight.data.normal_(mean=0.0, std=std)
1105
+ if module.padding_idx is not None:
1106
+ module.weight.data[module.padding_idx].zero_()
1107
+
1108
+
1109
+ BAILINGMOEV2_INPUTS_DOCSTRING = r"""
1110
+ Args:
1111
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1112
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1113
+ it.
1114
+
1115
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1116
+ [`PreTrainedTokenizer.__call__`] for details.
1117
+
1118
+ [What are input IDs?](../glossary#input-ids)
1119
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1120
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1121
+
1122
+ - 1 for tokens that are **not masked**,
1123
+ - 0 for tokens that are **masked**.
1124
+
1125
+ [What are attention masks?](../glossary#attention-mask)
1126
+
1127
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1128
+ [`PreTrainedTokenizer.__call__`] for details.
1129
+
1130
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
1131
+ `past_key_values`).
1132
+
1133
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1134
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1135
+ information on the default strategy.
1136
+
1137
+ - 1 indicates the head is **not masked**,
1138
+ - 0 indicates the head is **masked**.
1139
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1140
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1141
+ config.n_positions - 1]`.
1142
+
1143
+ [What are position IDs?](../glossary#position-ids)
1144
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
1145
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1146
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
1147
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
1148
+
1149
+ Two formats are allowed:
1150
+ - a [`~cache_utils.Cache`] instance;
1151
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1152
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1153
+ cache format.
1154
+
1155
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1156
+ legacy cache format will be returned.
1157
+
1158
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1159
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1160
+ of shape `(batch_size, sequence_length)`.
1161
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1162
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1163
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1164
+ model's internal embedding lookup matrix.
1165
+ use_cache (`bool`, *optional*):
1166
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1167
+ `past_key_values`).
1168
+ output_attentions (`bool`, *optional*):
1169
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1170
+ tensors for more detail.
1171
+ output_hidden_states (`bool`, *optional*):
1172
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1173
+ more detail.
1174
+ return_dict (`bool`, *optional*):
1175
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1176
+ """
1177
+
1178
+
1179
+ @add_start_docstrings(
1180
+ "The bare BailingSharedMoeV2 Model outputting raw hidden-states without any specific head on top.",
1181
+ BAILINGMOEV2_START_DOCSTRING,
1182
+ )
1183
+ class BailingSharedMoeV2Model(BailingSharedMoeV2PreTrainedModel):
1184
+ """
1185
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`BailingSharedMoeV2DecoderLayer`]
1186
+
1187
+ Args:
1188
+ config: BailingSharedMoeV2Config
1189
+ """
1190
+
1191
+ def __init__(self, config: BailingSharedMoeV2Config):
1192
+ super().__init__(config)
1193
+ self.padding_idx = config.pad_token_id
1194
+ self.vocab_size = config.vocab_size
1195
+ self.num_nextn_predict_layers = config.num_nextn_predict_layers
1196
+
1197
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1198
+ self.layers = []
1199
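+ # The last num_nextn_predict_layers layers of the stack are built as MTP layers; all
+ # preceding layers are regular decoder layers.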
+ for layer_idx in range(config.num_hidden_layers):
1200
+ layer_cls = BailingSharedMoeV2DecoderLayer if layer_idx < config.num_hidden_layers - config.num_nextn_predict_layers else BailingSharedMoeV2MTPLayer
1201
+ self.layers.append(layer_cls(config, layer_idx))
1202
+
1203
+ self.layers = nn.ModuleList(self.layers)
1204
+
1205
+ self._use_sdpa = config._attn_implementation == "sdpa"
1206
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1207
+ self.norm = BailingSharedMoeV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1208
+ self.rotary_emb = BailingSharedMoeV2RotaryEmbedding(config=config)
1209
+ self.gradient_checkpointing = False
1210
+ # Initialize weights and apply final processing
1211
+ self.post_init()
1212
+
1213
+ def get_input_embeddings(self):
1214
+ return self.word_embeddings
1215
+
1216
+ def set_input_embeddings(self, value):
1217
+ self.word_embeddings = value
1218
+
1219
+ @add_start_docstrings_to_model_forward(BAILINGMOEV2_INPUTS_DOCSTRING)
1220
+ def forward(
1221
+ self,
1222
+ input_ids: torch.LongTensor = None,
1223
+ attention_mask: Optional[torch.Tensor] = None,
1224
+ position_ids: Optional[torch.LongTensor] = None,
1225
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1226
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1227
+ use_cache: Optional[bool] = None,
1228
+ output_attentions: Optional[bool] = None,
1229
+ output_hidden_states: Optional[bool] = None,
1230
+ output_router_logits: Optional[bool] = None,
1231
+ return_dict: Optional[bool] = None,
1232
+ **kwargs,
1233
+ ) -> Union[Tuple, MoeV2ModelOutputWithPast]:
1234
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1235
+ output_hidden_states = (
1236
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1237
+ )
1238
+ output_router_logits = (
1239
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
1240
+ )
1241
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1242
+
1243
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1244
+
1245
+ # retrieve input_ids and inputs_embeds
1246
+ if input_ids is not None and inputs_embeds is not None:
1247
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1248
+ elif input_ids is not None:
1249
+ batch_size, seq_length = input_ids.shape[:2]
1250
+ elif inputs_embeds is not None:
1251
+ batch_size, seq_length = inputs_embeds.shape[:2]
1252
+ else:
1253
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1254
+
1255
+ if self.gradient_checkpointing and self.training:
1256
+ if use_cache:
1257
+ logger.warning_once(
1258
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers."
1259
+ )
1260
+ use_cache = False
1261
+
1262
+ if use_cache and past_key_values is None:
1263
+ past_key_values = DynamicCache()
1264
+
1265
+ if inputs_embeds is None:
1266
+ inputs_embeds = self.word_embeddings(input_ids)
1267
+
1268
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1269
+
1270
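+ # When no position_ids are given, they start after the tokens already stored in the KV cache
+ # so rotary positions remain correct during incremental decoding.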
+ if position_ids is None:
1271
+ position_ids = torch.arange(
1272
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1273
+ )
1274
+ position_ids = position_ids.unsqueeze(0)
1275
+
1276
+ if self._use_flash_attention_2:
1277
+ # 2d mask is passed through the layers
1278
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1279
+ elif self._use_sdpa and not output_attentions:
1280
+ # output_attentions=True cannot be supported when using SDPA, and we fall back on
1281
+ # the manual implementation that requires a 4D causal mask in all cases.
1282
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1283
+ attention_mask,
1284
+ (batch_size, seq_length),
1285
+ inputs_embeds,
1286
+ past_seen_tokens,
1287
+ )
1288
+ else:
1289
+ # 4d mask is passed through the layers
1290
+ attention_mask = _prepare_4d_causal_attention_mask(
1291
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_seen_tokens
1292
+ )
1293
+
1294
+ # embed positions
1295
+ hidden_states = inputs_embeds
1296
+
1297
+ # create position embeddings to be shared across the decoder layers
1298
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
1299
+
1300
+ # decoder layers
1301
+ all_hidden_states = () if output_hidden_states else None
1302
+ all_self_attns = () if output_attentions else None
1303
+ all_router_logits = () if output_router_logits else None
1304
+ next_decoder_cache = None
1305
+ layers = self.layers[: -self.num_nextn_predict_layers] if self.num_nextn_predict_layers > 0 else self.layers
1306
+ mtp_layers = self.layers[-self.num_nextn_predict_layers :] if self.num_nextn_predict_layers > 0 else None
1307
+
1308
+ for decoder_layer in layers:
1309
+ if output_hidden_states:
1310
+ all_hidden_states += (hidden_states,)
1311
+
1312
+ if self.gradient_checkpointing and self.training:
1313
+ layer_outputs = self._gradient_checkpointing_func(
1314
+ decoder_layer.__call__,
1315
+ hidden_states,
1316
+ attention_mask,
1317
+ position_ids,
1318
+ past_key_values,
1319
+ output_attentions,
1320
+ output_router_logits,
1321
+ use_cache,
1322
+ position_embeddings,
1323
+ )
1324
+ else:
1325
+ layer_outputs = decoder_layer(
1326
+ hidden_states,
1327
+ attention_mask=attention_mask,
1328
+ position_ids=position_ids,
1329
+ past_key_value=past_key_values,
1330
+ output_attentions=output_attentions,
1331
+ output_router_logits=output_router_logits,
1332
+ use_cache=use_cache,
1333
+ position_embeddings=position_embeddings,
1334
+ )
1335
+ hidden_states = layer_outputs[0]
1336
+
1337
+ if use_cache:
1338
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1339
+
1340
+ if output_attentions:
1341
+ all_self_attns += (layer_outputs[1],)
1342
+
1343
+ if output_router_logits and layer_outputs[-1] is not None:
1344
+ all_router_logits += (layer_outputs[-1],)
1345
+
1346
+ hidden_states = self.norm(hidden_states)
1347
+ main_hidden_states = hidden_states
1348
+
1349
+ # add hidden states from the last decoder layer
1350
+ if output_hidden_states:
1351
+ all_hidden_states += (main_hidden_states,)
1352
+
1353
+ mtp_hidden_states = None
1354
+
1355
+ if mtp_layers:
1356
+ for decoder_layer in mtp_layers:
1357
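+ # Each MTP step shifts input_ids one position to the left (roll_tensor) and re-embeds them,
+ # so successive MTP layers are conditioned on progressively later target tokens. Note that
+ # this branch requires input_ids; it cannot run from inputs_embeds alone.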
+ input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1)
1358
+ inputs_embeds = self.word_embeddings(input_ids)
1359
+
1360
+ if self.gradient_checkpointing and self.training:
1361
+ layer_outputs = self._gradient_checkpointing_func(
1362
+ decoder_layer.__call__,
1363
+ inputs_embeds,
1364
+ hidden_states,
1365
+ attention_mask,
1366
+ position_ids,
1367
+ past_key_values,
1368
+ output_attentions,
1369
+ output_router_logits,
1370
+ use_cache,
1371
+ position_embeddings,
1372
+ )
1373
+ else:
1374
+ layer_outputs = decoder_layer(
1375
+ inputs_embeds,
1376
+ hidden_states,
1377
+ attention_mask=attention_mask,
1378
+ position_ids=position_ids,
1379
+ past_key_value=past_key_values,
1380
+ output_attentions=output_attentions,
1381
+ output_router_logits=output_router_logits,
1382
+ use_cache=use_cache,
1383
+ position_embeddings=position_embeddings,
1384
+ )
1385
+ if mtp_hidden_states is None:
1386
+ mtp_hidden_states = []
1387
+ hidden_states = layer_outputs[0]
1388
+ mtp_hidden_states.append(hidden_states)
1389
+
1390
+ if output_hidden_states:
1391
+ all_hidden_states += (hidden_states,)
1392
+
1393
+ if use_cache:
1394
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1395
+
1396
+ if output_attentions:
1397
+ all_self_attns += (layer_outputs[1],)
1398
+
1399
+ if output_router_logits and layer_outputs[-1] is not None:
1400
+ all_router_logits += (layer_outputs[-1],)
1401
+
1402
+ next_cache = None
1403
+ if use_cache:
1404
+ next_cache = next_decoder_cache
1405
+ if not return_dict:
1406
+ return tuple(
1407
+ v
1408
+ for v in [main_hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
1409
+ if v is not None
1410
+ )
1411
+ return MoeV2ModelOutputWithPast(
1412
+ last_hidden_state=main_hidden_states,
1413
+ past_key_values=next_cache,
1414
+ hidden_states=all_hidden_states,
1415
+ mtp_hidden_states=mtp_hidden_states,
1416
+ attentions=all_self_attns,
1417
+ router_logits=all_router_logits,
1418
+ )
1419
+
1420
+
1421
+ class BailingSharedMoeV2ForCausalLM(BailingSharedMoeV2PreTrainedModel, GenerationMixin):
1422
+ _tied_weights_keys = ["lm_head.weight"]
1423
+
1424
+ def __init__(self, config: BailingSharedMoeV2Config):
1425
+ super().__init__(config)
1426
+ self.model = BailingSharedMoeV2Model(config)
1427
+ self.vocab_size = config.vocab_size
1428
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1429
+ self.num_nextn_predict_layers = config.num_nextn_predict_layers
1430
+ self.mtp_loss_scaling_factor = config.mtp_loss_scaling_factor
1431
+
1432
+ # Initialize weights and apply final processing
1433
+ self.post_init()
1434
+
1435
+ def get_input_embeddings(self):
1436
+ return self.model.word_embeddings
1437
+
1438
+ def set_input_embeddings(self, value):
1439
+ self.model.word_embeddings = value
1440
+
1441
+ def get_output_embeddings(self):
1442
+ return self.lm_head
1443
+
1444
+ def set_output_embeddings(self, new_embeddings):
1445
+ self.lm_head = new_embeddings
1446
+
1447
+ def set_decoder(self, decoder):
1448
+ self.model = decoder
1449
+
1450
+ def get_decoder(self):
1451
+ return self.model
1452
+
1453
+ @add_start_docstrings_to_model_forward(BAILINGMOEV2_INPUTS_DOCSTRING)
1454
+ @replace_return_docstrings(output_type=MoEV2CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1455
+ def forward(
1456
+ self,
1457
+ input_ids: torch.LongTensor = None,
1458
+ attention_mask: Optional[torch.Tensor] = None,
1459
+ position_ids: Optional[torch.LongTensor] = None,
1460
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1461
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1462
+ labels: Optional[torch.LongTensor] = None,
1463
+ use_cache: Optional[bool] = None,
1464
+ output_attentions: Optional[bool] = None,
1465
+ output_hidden_states: Optional[bool] = None,
1466
+ output_router_logits: Optional[bool] = None,
1467
+ return_dict: Optional[bool] = None,
1468
+ **kwargs,
1469
+ ) -> Union[Tuple, MoEV2CausalLMOutputWithPast]:
1470
+ r"""
1471
+ Args:
1472
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1473
+ Labels for computing the language modeling loss. Indices should either be in `[0, ...,
1474
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1475
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1476
+
1477
+ Returns:
1478
+
1479
+ Example:
1480
+
1481
+ ```python
1482
+ >>> from transformers import AutoTokenizer
1483
+
1484
+ >>> model = BailingSharedMoeV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
1485
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1486
+
1487
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
1488
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1489
+
1490
+ >>> # Generate
1491
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1492
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1493
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1494
+ ```"""
1495
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1496
+ output_hidden_states = (
1497
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1498
+ )
1499
+ output_router_logits = (
1500
+ output_router_logits if output_router_logits is not None else self.config.output_router_logits
1501
+ )
1502
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1503
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1504
+ outputs = self.model(
1505
+ input_ids=input_ids,
1506
+ attention_mask=attention_mask,
1507
+ position_ids=position_ids,
1508
+ past_key_values=past_key_values,
1509
+ inputs_embeds=inputs_embeds,
1510
+ use_cache=use_cache,
1511
+ output_attentions=output_attentions,
1512
+ output_hidden_states=output_hidden_states,
1513
+ output_router_logits=output_router_logits,
1514
+ return_dict=return_dict,
1515
+ **kwargs,
1516
+ )
1517
+
1518
+ loss = None
1519
+ all_mtp_loss = []
1520
+ aux_loss = None
1521
+ hidden_states = outputs[0]
1522
+ logits = self.lm_head(hidden_states)
1523
+ logits = logits.float()
1524
+
1525
+ if labels is not None:
1526
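+ # The loss is computed with apply_lce (presumably a fused linear cross-entropy path configured
+ # by the module-level _PATCH_OPTS), which operates on the hidden states and the lm_head weight
+ # instead of the materialized logits; the stock self.loss_function call is left commented out
+ # for reference.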
+ #loss = self.loss_function(logits, labels, self.config.vocab_size, **kwargs)
1527
+ loss = apply_lce(
1528
+ hidden_states,
1529
+ self.lm_head.weight,
1530
+ labels,
1531
+ _PATCH_OPTS,
1532
+ **kwargs,
1533
+ )
1534
+
1535
+ if not return_dict:
1536
+ output = (logits,) + outputs[1:]
1537
+ if output_router_logits:
1538
+ output = (aux_loss,) + output
1539
+ return (loss,) + output if loss is not None else output
1540
+
1541
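+ # all_mtp_loss is never populated and mtp_logits is returned empty, so the MTP fields of the
+ # output are placeholders in this forward pass.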
+ return MoEV2CausalLMOutputWithPast(
1542
+ loss=loss,
1543
+ mtp_loss=all_mtp_loss,
1544
+ aux_loss=aux_loss,
1545
+ logits=logits,
1546
+ mtp_logits=[],
1547
+ past_key_values=outputs.past_key_values,
1548
+ hidden_states=outputs.hidden_states,
1549
+ attentions=outputs.attentions,
1550
+ router_logits=outputs.router_logits,
1551
+ )
1552
+