JingzeShi committed on
Commit
ac7a5a6
verified
1 Parent(s): ebae748

Upload DogeForCausalLM

Browse files
Files changed (3) hide show
  1. config.json +1 -1
  2. model.safetensors +1 -1
  3. modeling_doge.py +3 -2
config.json CHANGED
@@ -1,5 +1,5 @@
1
  {
2
- "_name_or_path": "./data/Doge-20M-MoE-Instruct-SFT/checkpoint-7258",
3
  "architectures": [
4
  "DogeForCausalLM"
5
  ],
 
1
  {
2
+ "_name_or_path": "./data/Doge-20M-MoE-Instruct-SFT",
3
  "architectures": [
4
  "DogeForCausalLM"
5
  ],
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1bc235190b0ddce6cb8b0e6c1e86db0a6c7d5c4e4b23656024e6f6cfdb52221d
3
  size 69786512
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:388975546a77884213c70f44574a087115a55456b28c53251d72954ff0214245
3
  size 69786512
modeling_doge.py CHANGED
@@ -502,12 +502,13 @@ class DogeCDMoE(DogeMLP):
502
  routing_weights = self.router_gate(hidden_states).view(2, bsz * seq_len, -1)
503
 
504
  # get experts with the highest routing weights
505
- (scores_x, scores_y), (indices_x, indices_y) = [w.topk(self.num_keys, dim=-1) for w in routing_weights]
506
  all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
507
  all_indices = indices_x.unsqueeze(-1) * self.num_keys + indices_y.unsqueeze(-2)
508
  all_scores = all_scores.view(*all_scores.shape[:-2], -1)
509
  all_indices = all_indices.view(*all_indices.shape[:-2], -1)
510
- scores, indices = all_scores.topk(self.top_k, dim=-1)
 
511
  down_embed = self.down_embed(indices)
512
  up_embed = self.up_embed(indices)
513
 
 
502
  routing_weights = self.router_gate(hidden_states).view(2, bsz * seq_len, -1)
503
 
504
  # get experts with the highest routing weights
505
+ (scores_x, scores_y), (indices_x, indices_y) = routing_weights.topk(self.num_keys, dim=-1)
506
  all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
507
  all_indices = indices_x.unsqueeze(-1) * self.num_keys + indices_y.unsqueeze(-2)
508
  all_scores = all_scores.view(*all_scores.shape[:-2], -1)
509
  all_indices = all_indices.view(*all_indices.shape[:-2], -1)
510
+ scores, position_indices = all_scores.topk(self.top_k, dim=-1)
511
+ indices = all_indices.gather(-1, position_indices)
512
  down_embed = self.down_embed(indices)
513
  up_embed = self.up_embed(indices)
514