Update meshrouter.py
Browse files- meshrouter.py +26 -2
meshrouter.py
CHANGED
@@ -1,3 +1,27 @@
|
|
1 |
-
|
2 |
-
|
|
|
|
|
|
|
|
|
3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig, PreTrainedModel, AutoModelForCausalLM # Import AutoModelForCausalLM
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import math
|
6 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast # Import the necessary output class
|
7 |
|
8 |
+
# Define the Router for dynamic routing
|
9 |
+
class MeshRouter(nn.Module):
|
10 |
+
def __init__(self, config: MeshConfig):
|
11 |
+
super().__init__()
|
12 |
+
self.gate = nn.Linear(config.hidden_size, config.mesh_grid_size[0] * config.mesh_grid_size[1])
|
13 |
+
self.softmax = nn.Softmax(dim=-1)
|
14 |
+
self.routing_k = config.routing_k
|
15 |
+
|
16 |
+
def forward(self, x):
|
17 |
+
# x shape: (batch_size, sequence_length, hidden_size)
|
18 |
+
gate_scores = self.gate(x) # shape: (batch_size, sequence_length, num_experts)
|
19 |
+
gate_weights = self.softmax(gate_scores)
|
20 |
+
|
21 |
+
# Select top-k experts
|
22 |
+
topk_weights, topk_indices = torch.topk(gate_weights, self.routing_k, dim=-1)
|
23 |
+
|
24 |
+
# Normalize top-k weights
|
25 |
+
topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-6)
|
26 |
+
|
27 |
+
return topk_weights, topk_indices # shapes: (batch_size, sequence_length, k), (batch_size, sequence_length, k)
|