import torch
import torch.nn as nn
import torch.nn.functional as F

# Note: this is a simplified version of the DeepSeek auxiliary balance losses.
# Device-limited routing (restricting each token's experts to at most M
# devices) is NOT implemented here; M only enters the communication balance
# loss as a normalisation factor. Routed experts are assigned to devices in
# contiguous chunks via torch.tensor_split.
# For the complete implementation with proper token-device mapping,
# the device-limited routing implementation
# and more efficient calculations, please contact the author.


class Expert(nn.Module):
    """Position-wise feed-forward network used as a single expert.

    Two linear transformations with a ReLU activation in between:

        FFN(x) = max(0, x W1 + b1) W2 + b2

    Args:
        d_model: embedding dimension (e.g., 512).
        d_expert: hidden (expert) dimension (e.g., 256).
    """

    def __init__(self, d_model, d_expert):
        super().__init__()
        self.d_model = d_model
        self.d_expert = d_expert

        # Linear transformations y = xW + b
        self.fc1 = nn.Linear(self.d_model, self.d_expert, bias=True)
        self.fc2 = nn.Linear(self.d_expert, self.d_model, bias=True)

        # Xavier/Glorot initialisation of the weight matrices; biases keep
        # the nn.Linear default initialisation.
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, input):
        """Apply the FFN over the last dimension of ``input``.

        Args:
            input: tensor of shape ``[..., d_model]`` (typically
                ``[batch, seq, d_model]``).

        Returns:
            Tensor with the same shape as ``input``.
        """
        # Only the feature dimension has to match; leading dims are free.
        d_input = input.size(-1)
        assert self.d_model == d_input, "d_model must be the same dimension as the input"
        # max(0, xW_1 + b_1)W_2 + b_2
        return self.fc2(F.relu(self.fc1(input)))


class MixtureOfExperts(nn.Module):
    r"""Mixture of Experts layer in the style of DeepSeekMoE.

        MoE(x) = x + \sum_i Expert^s_i(x) + \sum_i gate_i(x; K) * Expert^r_i(x)

    Every token goes through all shared experts. In addition, each token is
    routed to its top-K routed experts, scored by the dot product between
    the token and one learnable centroid per routed expert.

    Args:
        d_model: embedding dimension (e.g., 512).
        d_expert: expert hidden dimension (e.g., 256).
        K: number of routed experts selected per token (top-K gate).
        N_s: number of shared experts.
        N_r: number of routed experts.
        alpha1: expert-level balance loss factor.
        alpha2: device-level balance loss factor.
        alpha3: communication balance loss factor.
        D: number of devices the routed experts are partitioned over.
        M: number of devices for device-limited routing (only used to
            normalise the communication balance loss here).
    """

    def __init__(self, d_model, d_expert, K, N_s, N_r, alpha1, alpha2, alpha3, D=4, M=3):
        super().__init__()
        assert D < N_r, "Number of partitions needs to be less than number of routed experts"
        assert M <= D, (
            "Number of devices for Device-Limited Routing must not exceed "
            "the total number of devices"
        )

        self.d_model = d_model
        self.d_expert = d_expert
        self.K = K
        self.N_s = N_s
        self.N_r = N_r
        self.alpha1 = alpha1
        self.alpha2 = alpha2
        self.alpha3 = alpha3
        self.D = D  # number of devices available
        self.M = M  # for Device-Limited Routing

        # Shared experts (applied to every token) and routed experts.
        self.shared_experts = nn.ModuleList(
            [Expert(self.d_model, self.d_expert) for _ in range(N_s)]
        )
        self.routed_experts = nn.ModuleList(
            [Expert(self.d_model, self.d_expert) for _ in range(N_r)]
        )

        # Learnable centroids: one vector per routed expert.
        self.expert_centroids = nn.Parameter(
            torch.randn(N_r, d_model)  # [num_routed_experts, d_model]
        )
        nn.init.xavier_uniform_(self.expert_centroids)

    def forward(self, input):
        """Run shared and routed experts and compute the auxiliary balance losses.

        Args:
            input: tensor of shape ``[batch, seq, d_model]``.

        Returns:
            Tuple ``(output, expert_loss, device_loss, commu_loss)``.
            ``output`` has the shape of ``input``. The three losses are scalar
            tensors on ``input``'s device.

        Side effects:
            Stores ``self.similarities`` (token/centroid scores,
            ``[batch, seq, N_r]``) and ``self.last_gate`` (the sparse gate,
            ``[batch, seq, N_r]``).
        """
        batch_size, seq_length, d_input = input.size()
        assert self.d_model == d_input, "d_model must be the same dimension as the input"

        shared_output = torch.zeros_like(input)
        for expert in self.shared_experts:
            shared_output += expert(input)  # [batch, seq, d_model]

        # Similarity between input tokens and expert centroids.
        self.similarities = torch.matmul(
            input, self.expert_centroids.transpose(0, 1)
        )  # [batch, seq, N_r]
        assert self.similarities.size(dim=-1) == self.N_r, \
            "last dimension of similarities must be the same as the number of routed expert"
        affinity = F.softmax(self.similarities, dim=-1)  # [batch, seq, N_r]

        # Top-K gating.
        values, indexes = torch.topk(affinity, self.K, dim=-1)
        # Renormalise the selected affinities so each token's gates sum to 1.
        # Dividing by the sum is equivalent to a softmax over the top-K
        # similarities. Re-applying softmax to probabilities would instead
        # flatten the weights toward uniform.
        values = values / values.sum(dim=-1, keepdim=True)
        gate = torch.zeros_like(affinity).scatter_(-1, indexes, values)  # [batch, seq, N_r]
        # 0/1 routing mask built from the selected indices (independent of the
        # gate magnitudes). It carries no gradient, so the frequencies below
        # are constants, as in the balance-loss formulation.
        routed_mask = torch.zeros_like(affinity).scatter_(-1, indexes, 1.0)

        # Kept for testing / inspection.
        self.last_gate = gate

        routed_output = torch.zeros_like(input)
        for i, expert in enumerate(self.routed_experts):
            routed_output += gate[..., i:i + 1] * expert(input)

        # ---- Auxiliary losses for load balance ----
        # T is the number of tokens in the batch.
        T = batch_size * seq_length

        # Expert-level balance loss.
        # f_i = N_r / (K T) * (#tokens routed to expert i)
        # P_i = mean affinity of expert i over all tokens
        f = self.N_r / (self.K * T) * routed_mask.sum((0, 1))
        P = affinity.sum((0, 1)) / T
        expert_loss = self.alpha1 * torch.matmul(f, P)

        # Device-level balance loss over D contiguous expert groups.
        # torch.stack (rather than torch.tensor) keeps the autograd graph and
        # the device, so this loss actually back-propagates into the router.
        f1 = torch.stack([part.mean() for part in torch.tensor_split(f, self.D)])
        P1 = torch.stack([part.sum() for part in torch.tensor_split(P, self.D)])
        device_loss = self.alpha2 * torch.matmul(f1, P1)

        # Communication balance loss.
        # f2_i = D / (M T) * (#tokens sent to device i)
        # A token is "sent" to a device if at least one of its selected experts
        # lives there. A token with several experts on the same device counts
        # once.
        tokens_per_device = torch.stack([
            part.amax(dim=-1).sum()
            for part in torch.tensor_split(routed_mask, self.D, dim=-1)
        ])
        f2 = self.D / (self.M * T) * tokens_per_device
        P2 = P1
        commu_loss = self.alpha3 * torch.matmul(f2, P2)

        return input + shared_output + routed_output, expert_loss, device_loss, commu_loss