Aratako commited on
Commit
381b397
1 Parent(s): 6c4d051

Update modeling_tanuki.py

Browse files
Files changed (1) hide show
  1. modeling_tanuki.py +2 -0
modeling_tanuki.py CHANGED
@@ -857,6 +857,8 @@ class TanukiSparseMoeBlock(nn.Module):
857
  for expert_idx in range(self.num_experts):
858
  expert_layer = self.experts[expert_idx]
859
  idx, top_x = torch.where(expert_mask[expert_idx])
 
 
860
 
861
  # Index the correct hidden states and compute the expert hidden state for
862
  # the current expert. We need to make sure to multiply the output hidden
 
857
  for expert_idx in range(self.num_experts):
858
  expert_layer = self.experts[expert_idx]
859
  idx, top_x = torch.where(expert_mask[expert_idx])
860
+ if top_x.shape[0] == 0:
861
+ continue
862
 
863
  # Index the correct hidden states and compute the expert hidden state for
864
  # the current expert. We need to make sure to multiply the output hidden