Potential Issue: Load Balancing Loss May Mask Per-Layer Expert Imbalances
Hi team,
A quick observation on the `load_balancing_loss_func` in `qwen3_moe/modeling_qwen3_moe.py`:
The current implementation calculates the load balancing loss by first concatenating `gate_logits` from all layers when they are provided as a tuple:

```python
if isinstance(gate_logits, tuple):
    concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
```
All subsequent metrics (`tokens_per_expert`, `router_prob_per_expert`) are then computed on this `concatenated_gate_logits`.
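For reference, the remainder of the computation operates on that concatenated tensor. Ignoring the attention-mask branch, it looks roughly like this (paraphrased from the upstream Mixtral-style implementation, so treat it as a sketch rather than a verbatim quote):

```python
routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
_, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
# Both statistics are averaged over the tokens of ALL layers at once,
# so any per-layer structure is lost at this point:
tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
router_prob_per_expert = torch.mean(routing_weights, dim=0)
overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
return overall_loss * num_experts
```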
Potential Issue:
This approach averages expert utilization across all layers. Consequently, the loss value could be small (indicating good balance) even if expert utilization is highly uneven within individual layers, as long as these imbalances "cancel out" in the global average.
For example:
- Layer 0 might heavily use Expert A and ignore Expert B.
- Layer 1 might heavily use Expert B and ignore Expert A.
Globally, both A and B appear well utilized, yielding a low loss, yet each layer has a severe imbalance. This could lead to suboptimal expert specialization or training inefficiencies at the layer level, as the sketch below demonstrates.
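Here is a minimal, self-contained sketch of the cancellation. The `aux_loss` helper is hypothetical, implementing the standard Switch-style loss `num_experts * Σ_i f_i · P_i`; two experts and top-1 routing keep the numbers easy to read:

```python
import torch
import torch.nn.functional as F

def aux_loss(logits: torch.Tensor, top_k: int = 1) -> torch.Tensor:
    # Switch-style auxiliary loss: num_experts * sum_i f_i * P_i, where f_i is the
    # fraction of top-k assignments to expert i and P_i is its mean router probability.
    num_experts = logits.shape[-1]
    probs = F.softmax(logits, dim=-1)
    _, selected = torch.topk(probs, top_k, dim=-1)
    mask = F.one_hot(selected, num_experts).float()  # [tokens, top_k, num_experts]
    f = mask.mean(dim=0)                             # [top_k, num_experts]
    p = probs.mean(dim=0)                            # [num_experts]
    return num_experts * torch.sum(f * p.unsqueeze(0))

tokens, num_experts = 1024, 2
# Layer 0 routes every token to expert A (index 0), layer 1 to expert B (index 1).
layer0 = torch.full((tokens, num_experts), -5.0)
layer0[:, 0] = 5.0
layer1 = torch.full((tokens, num_experts), -5.0)
layer1[:, 1] = 5.0

global_loss = aux_loss(torch.cat([layer0, layer1], dim=0))  # ~1.0: the balanced optimum
per_layer_loss = (aux_loss(layer0) + aux_loss(layer1)) / 2  # ~2.0: maximal imbalance
print(global_loss.item(), per_layer_loss.item())
```

The concatenated loss sits at its balanced optimum (1.0 for two experts) even though every individual layer is maximally imbalanced (2.0).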
If per-layer expert balance is a design goal, the loss calculation might need to be performed on each layer's `gate_logits` individually before aggregation.
The following function may help address the issue:
```python
import torch
from typing import Optional, Tuple, Union


def load_balancing_loss_func_per_layers(
    gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
    num_experts: Optional[int] = None,
    top_k: int = 2,
    attention_mask: Optional[torch.Tensor] = None,
    batch_size: Optional[int] = None,
    seq_len: Optional[int] = None,
) -> Union[torch.Tensor, int]:
    if gate_logits is None or not isinstance(gate_logits, tuple):
        return 0
    num_layers = len(gate_logits)
    # Concatenate along the batch*seq_len dimension:
    # input: num_layers tensors of shape [batch_size*seq_len, num_experts]
    # output: [num_layers*batch_size*seq_len, num_experts]
    compute_device = gate_logits[0].device
    concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
    if num_experts is None:
        num_experts = concatenated_gate_logits.shape[-1]
    # Softmax over experts: [num_layers*batch_size*seq_len, num_experts]
    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
    # Re-split the layer dimension: [num_layers, batch_size*seq_len, num_experts].
    # The per-layer token count can be inferred, so batch_size/seq_len are optional.
    tokens_per_layer = batch_size * seq_len if batch_size and seq_len else -1
    routing_weights_per_layers = routing_weights.view(num_layers, tokens_per_layer, num_experts)
    _, selected_experts = torch.topk(routing_weights_per_layers, top_k, dim=-1)
    # [num_layers, batch_size*seq_len, top_k, num_experts]
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
    if attention_mask is None:
        # Fraction of tokens routed to each expert, per layer: [num_layers, top_k, num_experts]
        tokens_per_expert = torch.mean(expert_mask.float(), dim=1)
        # Average routing probability per expert, per layer: [num_layers, num_experts]
        router_prob_per_expert = torch.mean(routing_weights_per_layers, dim=1)
    else:
        batch_size, sequence_length = attention_mask.shape
        # Mask that zeroes out padding tokens, broadcast to the shape of expert_mask
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand((num_layers, batch_size, sequence_length, top_k, num_experts))
            .reshape(num_layers, -1, top_k, num_experts)
            .to(compute_device)
        )
        # Fraction of non-padding tokens routed to each expert, per layer
        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=1) / torch.sum(
            expert_attention_mask, dim=1
        )
        # Mask that zeroes out padding tokens, broadcast to the shape of routing_weights_per_layers
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_layers, batch_size, sequence_length, num_experts))
            .reshape(num_layers, -1, num_experts)
            .to(compute_device)
        )
        # Average routing probability per expert over non-padding tokens, per layer
        router_prob_per_expert = torch.sum(
            routing_weights_per_layers * router_per_expert_attention_mask, dim=1
        ) / torch.sum(router_per_expert_attention_mask, dim=1)
    # Per-layer losses, summed over layers:
    # [num_layers, top_k, num_experts] * [num_layers, 1, num_experts]
    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(1))
    return overall_loss * num_experts
```
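A quick usage example (hypothetical shapes; random logits just to exercise the function):

```python
num_layers, bsz, seqlen, n_exp = 4, 2, 8, 8
gate_logits = tuple(torch.randn(bsz * seqlen, n_exp) for _ in range(num_layers))
loss = load_balancing_loss_func_per_layers(gate_logits, num_experts=n_exp, top_k=2)
```

One caveat: because the per-layer terms are summed, the returned value is roughly `num_layers` times larger than the global version, so `router_aux_loss_coef` may need rescaling (or the sum could be divided by `num_layers`) to keep the two variants comparable.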