Upload DogeForCausalLM
Browse files
- config.json +1 -1
- model.safetensors +1 -1
- modeling_doge.py +3 -2
config.json
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
{
|
2 |
-
"_name_or_path": "./data/Doge-20M-MoE-Instruct-SFT
|
3 |
"architectures": [
|
4 |
"DogeForCausalLM"
|
5 |
],
|
|
|
1 |
{
|
2 |
+
"_name_or_path": "./data/Doge-20M-MoE-Instruct-SFT",
|
3 |
"architectures": [
|
4 |
"DogeForCausalLM"
|
5 |
],
|
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 69786512
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:388975546a77884213c70f44574a087115a55456b28c53251d72954ff0214245
|
3 |
size 69786512
|
modeling_doge.py
CHANGED
@@ -502,12 +502,13 @@ class DogeCDMoE(DogeMLP):
|
|
502 |
routing_weights = self.router_gate(hidden_states).view(2, bsz * seq_len, -1)
|
503 |
|
504 |
# get experts with the highest routing weights
|
505 |
-
(scores_x, scores_y), (indices_x, indices_y) =
|
506 |
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
|
507 |
all_indices = indices_x.unsqueeze(-1) * self.num_keys + indices_y.unsqueeze(-2)
|
508 |
all_scores = all_scores.view(*all_scores.shape[:-2], -1)
|
509 |
all_indices = all_indices.view(*all_indices.shape[:-2], -1)
|
510 |
-
scores,
|
|
|
511 |
down_embed = self.down_embed(indices)
|
512 |
up_embed = self.up_embed(indices)
|
513 |
|
|
|
502 |
routing_weights = self.router_gate(hidden_states).view(2, bsz * seq_len, -1)
|
503 |
|
504 |
# get experts with the highest routing weights
|
505 |
+
(scores_x, scores_y), (indices_x, indices_y) = routing_weights.topk(self.num_keys, dim=-1)
|
506 |
all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
|
507 |
all_indices = indices_x.unsqueeze(-1) * self.num_keys + indices_y.unsqueeze(-2)
|
508 |
all_scores = all_scores.view(*all_scores.shape[:-2], -1)
|
509 |
all_indices = all_indices.view(*all_indices.shape[:-2], -1)
|
510 |
+
scores, position_indices = all_scores.topk(self.top_k, dim=-1)
|
511 |
+
indices = all_indices.gather(-1, position_indices)
|
512 |
down_embed = self.down_embed(indices)
|
513 |
up_embed = self.up_embed(indices)
|
514 |
|