Update modeling_tanuki.py
Browse files- modeling_tanuki.py +2 -0
modeling_tanuki.py
CHANGED
@@ -857,6 +857,8 @@ class TanukiSparseMoeBlock(nn.Module):
|
|
857 | for expert_idx in range(self.num_experts):
858 |     expert_layer = self.experts[expert_idx]
859 |     idx, top_x = torch.where(expert_mask[expert_idx])
860 |
861 |     # Index the correct hidden states and compute the expert hidden state for
862 |     # the current expert. We need to make sure to multiply the output hidden
|
|
|
857 |   for expert_idx in range(self.num_experts):
858 |       expert_layer = self.experts[expert_idx]
859 |       idx, top_x = torch.where(expert_mask[expert_idx])
860 | +     if top_x.shape[0] == 0:
861 | +         continue
862 |
863 |       # Index the correct hidden states and compute the expert hidden state for
864 |       # the current expert. We need to make sure to multiply the output hidden
|