aquiffoo commited on
Commit
fc24706
·
verified ·
1 Parent(s): c11d27f

Update meshrouter.py

Browse files
Files changed (1) hide show
  1. meshrouter.py +26 -2
meshrouter.py CHANGED
@@ -1,3 +1,27 @@
1
- # Source code for MeshRouter from cell 39f2782d
2
- # Please replace this with the actual code from the notebook cell.
 
 
 
 
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM # Import AutoModelForCausalLM
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import math
6
+ from transformers.modeling_outputs import CausalLMOutputWithPast # Import the necessary output class
7
 
8
+ # Define the Router for dynamic routing
9
+ class MeshRouter(nn.Module):
10
+ def __init__(self, config: MeshConfig):
11
+ super().__init__()
12
+ self.gate = nn.Linear(config.hidden_size, config.mesh_grid_size[0] * config.mesh_grid_size[1])
13
+ self.softmax = nn.Softmax(dim=-1)
14
+ self.routing_k = config.routing_k
15
+
16
+ def forward(self, x):
17
+ # x shape: (batch_size, sequence_length, hidden_size)
18
+ gate_scores = self.gate(x) # shape: (batch_size, sequence_length, num_experts)
19
+ gate_weights = self.softmax(gate_scores)
20
+
21
+ # Select top-k experts
22
+ topk_weights, topk_indices = torch.topk(gate_weights, self.routing_k, dim=-1)
23
+
24
+ # Normalize top-k weights
25
+ topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-6)
26
+
27
+ return topk_weights, topk_indices # shapes: (batch_size, sequence_length, k), (batch_size, sequence_length, k)