Potential Issue: Load Balancing Loss May Mask Per-Layer Expert Imbalances

#13
by HuggingJerry - opened

Hi team,

A quick observation on the load_balancing_loss_func in qwen3_moe/modeling_qwen3_moe.py:

The current implementation calculates load balancing loss by first concatenating gate_logits from all layers if they are provided as a tuple:

if isinstance(gate_logits, tuple):
    concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)

All subsequent metrics (tokens_per_expert, router_prob_per_expert) are then computed on this concatenated_gate_logits.

Potential Issue:
Because tokens_per_expert and router_prob_per_expert are means over the concatenated tensor, this approach averages expert utilization across all layers. Consequently, the loss value could be small (indicating good balance) even if expert utilization is highly uneven within individual layers, as long as these imbalances "cancel out" in the global average.

For example:

  • Layer 0 might heavily use Expert A and ignore Expert B.
  • Layer 1 might heavily use Expert B and ignore Expert A.

Globally, both A and B appear utilized, leading to a low loss, but each layer has a severe imbalance. This could lead to suboptimal expert specialization or training inefficiencies at the layer level.
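To make the cancellation concrete, here is a minimal sketch in plain Python (no torch, so the arithmetic stays visible). It assumes top-1 routing, two experts, two layers, and four tokens per layer, and it takes softmax probabilities directly instead of logits. The helper name balance_loss and the probability values are made up for illustration; only the formula num_experts * sum_e(f_e * P_e) follows the function discussed above.

```python
def balance_loss(probs, num_experts, top_k=1):
    """num_experts * sum_e f_e * P_e over a flat list of per-token router probabilities."""
    n = len(probs)
    counts = [0] * num_experts
    for p in probs:
        # Indices of the top_k largest probabilities for this token.
        chosen = sorted(range(num_experts), key=lambda e: p[e], reverse=True)[:top_k]
        for e in chosen:
            counts[e] += 1
    frac_tokens = [c / n for c in counts]  # f_e: fraction of tokens routed to expert e
    mean_prob = [sum(p[e] for p in probs) / n for e in range(num_experts)]  # P_e
    return num_experts * sum(f * m for f, m in zip(frac_tokens, mean_prob))


# Hypothetical router outputs: layer 0 always prefers expert 0, layer 1 always prefers expert 1.
layer0 = [[0.9, 0.1]] * 4
layer1 = [[0.1, 0.9]] * 4

# Concatenating the layers first (what the current function does) hides the imbalance.
global_loss = balance_loss(layer0 + layer1, num_experts=2)
# Computing the loss per layer exposes it.
per_layer = [balance_loss(layer, num_experts=2) for layer in (layer0, layer1)]

print(round(global_loss, 6))              # 1.0, the value of a perfectly balanced top-1 router
print([round(x, 6) for x in per_layer])   # [1.8, 1.8]
```

The concatenated loss sits at 1.0, the same value a perfectly balanced top-1 router would produce, while each layer on its own scores 1.8.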

If per-layer expert balance is a design goal, the loss calculation might need to be performed for each layer's gate_logits individually before aggregation.

The following function may help address the issue:

from typing import Optional, Tuple, Union

import torch


def load_balancing_loss_func_per_layers(
    gate_logits: Union[torch.Tensor, Tuple[torch.Tensor, ...], None],
    num_experts: Optional[int] = None,
    top_k: int = 2,
    attention_mask: Optional[torch.Tensor] = None,
    batch_size: Optional[int] = None,
    seq_len: Optional[int] = None,
) -> Union[torch.Tensor, int]:
    # Per-layer balancing needs one gate_logits tensor per layer.
    if gate_logits is None or not isinstance(gate_logits, tuple):
        return 0
    num_layers = len(gate_logits)
    # Concatenate along the batch*seq_len dimension.
    # Input: one [batch_size*seq_len, num_experts] tensor per layer
    # Output: [num_layers*batch_size*seq_len, num_experts]
    compute_device = gate_logits[0].device
    concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)

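    # Softmax over experts: [num_layers*batch_size*seq_len, num_experts]
    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
    # Split the layer dimension back out: [num_layers, batch_size*seq_len, num_experts].
    # Inferring the token dimension avoids failing when batch_size/seq_len are left as None.
    routing_weights_per_layers = routing_weights.view(num_layers, -1, routing_weights.shape[-1])
    # Top-k experts per token, per layer: [num_layers, batch_size*seq_len, top_k]
    _, selected_experts = torch.topk(routing_weights_per_layers, top_k, dim=-1)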

    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Compute the percentage of tokens routed to each expert
        tokens_per_expert = torch.mean(expert_mask.float(), dim=1)
        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.mean(routing_weights_per_layers, dim=1)

    else:
        batch_size, sequence_length = attention_mask.shape
        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)

        # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
            .reshape(num_hidden_layers, -1, top_k, num_experts)
            .to(compute_device)
        )

        # Compute the percentage of tokens routed to each expert
        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=1) / torch.sum(
            expert_attention_mask, dim=1
        )

        # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(num_hidden_layers, -1, num_experts)
            .to(compute_device)
        )

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.sum(routing_weights_per_layers * router_per_expert_attention_mask, dim=1) / torch.sum(
            router_per_expert_attention_mask, dim=1
        )
    # [num_layers, top_k, num_experts] * [num_layers, 1, num_experts], summed over
    # all dimensions, i.e. the sum (not the mean) of the per-layer losses
    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(1))
    return overall_loss * num_experts
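
One way to sanity-check a vectorized version like the one above is to compare it against a plain loop over layers. The sketch below is only an illustration, not the library's implementation: the name per_layer_balance_losses is hypothetical, and it assumes each layer's gate_logits has shape [batch_size * seq_len, num_experts], flattened in the same batch-major order as attention_mask. It returns one loss per layer, so the aggregation (sum or mean) is left to the caller.

```python
from typing import Optional, Tuple

import torch
import torch.nn.functional as F


def per_layer_balance_losses(
    gate_logits: Tuple[torch.Tensor, ...],
    num_experts: int,
    top_k: int = 2,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Return one load-balancing loss per layer, shape [num_layers]."""
    losses = []
    for layer_logits in gate_logits:  # each: [batch_size * seq_len, num_experts]
        probs = F.softmax(layer_logits, dim=-1)
        _, selected = torch.topk(probs, top_k, dim=-1)
        # [tokens, top_k, num_experts]
        expert_mask = F.one_hot(selected, num_experts).float()
        if attention_mask is None:
            tokens_per_expert = expert_mask.mean(dim=0)  # [top_k, num_experts]
            router_prob = probs.mean(dim=0)              # [num_experts]
        else:
            # Assumption: the mask flattens to the same token order as the logits.
            keep = attention_mask.reshape(-1).to(probs.device, probs.dtype)  # [tokens]
            denom = keep.sum()
            tokens_per_expert = (expert_mask * keep[:, None, None]).sum(dim=0) / denom
            router_prob = (probs * keep[:, None]).sum(dim=0) / denom
        losses.append(num_experts * torch.sum(tokens_per_expert * router_prob.unsqueeze(0)))
    return torch.stack(losses)


# Hypothetical logits: layer 0 always prefers expert 0, layer 1 always prefers expert 1.
layer0 = torch.tensor([[4.0, 0.0]] * 4)
layer1 = torch.tensor([[0.0, 4.0]] * 4)
losses = per_layer_balance_losses((layer0, layer1), num_experts=2, top_k=1)
print(losses)  # both entries are well above 1.0, the balanced value for top_k=1
```

Since the function above sums over layers, losses.sum() should agree with its output on the same inputs, while losses.mean() would keep the scale independent of model depth.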
