Conversation
| ctx.aux_loss_weight, # pyrefly: ignore [missing-attribute]
| )
| (aux_grad,) = torch.autograd.grad(aux_loss, scores_detached)
| return grad_scores + aux_grad, None, None, None, None, None, None, None
note (@tianyu-l): prefer computing aux_loss in fwd & saving it
iirc, this is something (nested autograd) @xmfan wanted to avoid because of graph breaks? We can 1) write the grad function analytically, which can be tedious, 2) use torch.func.grad to transform the aux loss function, which is perhaps the simplest approach, or 3) use torch.library.custom_op + register_autograd as in deepep.py. I think all of these approaches can avoid the global config added in this PR (torch._dynamo.config.trace_autograd_ops = True), which can have side effects globally. Maybe Simon can correct me if anything here doesn't make sense.
I went with torch.func.grad to remove the flag, but the loss computation is still in the backward. I don't think we save any memory by computing the loss in FWD.
Yeah, that's by design. We do not compute the loss in the forward, to avoid coupling with the modeling loss, which breaks PP. All we need is its gradient in the backward. On the other hand, if we want to log the load-balancing loss to monitor the health of training, we can do it under torch.no_grad() in the forward.
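For readers following the thread, a condensed, self-contained sketch of the resolved approach: an autograd.Function that is an identity on `top_scores` in forward and uses `torch.func.grad` in backward to inject the load-balance gradient into `scores`. The loss body, class name, and signature below are illustrative stand-ins, not the PR's exact code.

```python
import torch
from torch.func import grad


def _toy_aux_loss(scores, indices, num_experts, weight):
    # Illustrative stand-in for the sequence-/batch-wise load-balance losses in this PR.
    counts = torch.bincount(indices.view(-1), minlength=num_experts).to(scores.dtype)
    f_i = counts / counts.sum().clamp(min=1)
    p_i = scores.mean(dim=0)
    return (f_i * p_i).sum() * weight


class LoadBalanceAuxLoss(torch.autograd.Function):
    """Identity on top_scores in forward; injects d(aux_loss)/d(scores) in backward."""

    @staticmethod
    def forward(ctx, top_scores, scores, indices, num_experts, weight):
        ctx.save_for_backward(scores, indices)
        ctx.num_experts, ctx.weight = num_experts, weight
        return top_scores

    @staticmethod
    def backward(ctx, grad_top_scores):
        scores, indices = ctx.saved_tensors
        # torch.func.grad differentiates the pure loss fn w.r.t. its first argument,
        # so no torch.enable_grad() + nested torch.autograd.grad is needed here.
        aux_grad = grad(_toy_aux_loss)(scores, indices, ctx.num_experts, ctx.weight)
        return grad_top_scores, aux_grad, None, None, None


# usage sketch
scores = torch.rand(16, 8, requires_grad=True)   # router scores (tokens, num_experts)
top_scores, top_idx = scores.topk(2, dim=-1)
top_scores = LoadBalanceAuxLoss.apply(top_scores, scores, top_idx, 8, 1e-3)
top_scores.sum().backward()                      # scores.grad now includes the aux term
```

Forward-time logging of the loss value (under torch.no_grad(), as suggested above) would live in the model's forward, outside this Function.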
| with torch.enable_grad():
| scores_detached = scores.detach().requires_grad_(True)
| if (
| ctx.aux_loss_type == "sequence_wise"
| ): # pyrefly: ignore [missing-attribute]
| aux_loss = MoE._sequence_wise_aux_loss(
| scores_detached,
| selected_experts_indices,
| ctx.bs, # pyrefly: ignore [missing-attribute]
| ctx.slen, # pyrefly: ignore [missing-attribute]
| ctx.top_k, # pyrefly: ignore [missing-attribute]
| ctx.aux_loss_weight, # pyrefly: ignore [missing-attribute]
| )
| else:
| num_tokens_per_expert = torch.histc(
| selected_experts_indices.view(-1).float(),
| bins=ctx.num_experts, # pyrefly: ignore [missing-attribute]
| min=0,
| max=ctx.num_experts, # pyrefly: ignore [missing-attribute]
| )
| aux_loss = MoE._batch_wise_aux_loss(
| scores_detached,
| num_tokens_per_expert,
| ctx.top_k, # pyrefly: ignore [missing-attribute]
| ctx.aux_loss_weight, # pyrefly: ignore [missing-attribute]
| )
why don't we compute these in forward?
| layer_cfg.moe.router._debug_force_load_balance = (
| debug.moe_force_load_balance
| )
| apply_moe_load_balance_config(
put this at config time, not run time
| moe_aux_loss_weight: float = 0.0
| """Weight for the MoE auxiliary load-balance loss. 0 disables it."""
|
| moe_aux_loss_type: Literal["sequence_wise", "batch_wise"] = "sequence_wise"
| """Type of MoE auxiliary load-balance loss."""
|
| moe_load_balance_coeff: float | None = 1e-3
| """Coefficient for aux-loss-free bias-based MoE load balancing.
| Overrides the model config default when set. None disables it."""
do this in model_registry; don't allow overriding it per CLI job
would be good if we add a buffer of
@rakkit Logging capability will be separately addressed by @felipemello1 in a PR similar to #2607. Will tag you for review.
| return (f_i * p_i).sum() * aux_loss_weight
|
|
| class _AuxLossBackward(torch.autograd.Function):
I think we can just use a regular name like LoadBalanceAuxLoss? I didn't get why we need to put only Backward in the name.
| router_top_k=3,
| router_score_func="softmax",
| score_before_experts=False,
| aux_loss=MoELoadBalanceAuxLoss.Config(weight=1e-4),
check the technical report of dsv3 and see if it's actually used.
IIRC qwen3 used it, so let's enable it for qwen3. Need to check gpt-oss as well.
yes this is from the hyperparams section; I'll check the other 2
gpt-oss seems to set 1e-3: https://github.com/huggingface/transformers/blob/6ffbb07f93d9e44457450d1150136309b0dc966b/src/transformers/models/gpt_oss/configuration_gpt_oss.py#L61, though I need to look closely at whether the loss is equivalent; it might be the batch-wise one but with softmax scores: https://github.com/huggingface/transformers/blob/a8f43eca15b8d1c63deb33f6b97dfab30419e5da/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L504
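For context in this comparison, the DeepSeek-V3-style balance loss that both variants in this PR implement (paraphrasing from the `(f_i * p_i).sum() * aux_loss_weight` code quoted above; constants per the DeepSeek report, the exact normalization in the PR's code may differ) is roughly:

$$\mathcal{L}_{\text{bal}} = \alpha \sum_{i=1}^{N_e} f_i \, P_i, \qquad f_i = \frac{N_e}{K\,T}\sum_{t=1}^{T}\mathbb{1}\{\text{token } t \text{ routes to expert } i\}, \qquad P_i = \frac{1}{T}\sum_{t=1}^{T} s_{i,t},$$

where $T$ is the number of tokens in the normalization window (one sequence for sequence_wise, the whole local batch for batch_wise) and $s_{i,t}$ are the normalized router scores. The gpt-oss question above is whether using softmax scores for $s_{i,t}$ is the only difference.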
| "batch_wise MoE aux loss is incompatible with pipeline " | ||
| "parallelism. Use sequence_wise instead." | ||
| ) | ||
| aux_loss.local_batch_size = config.training.local_batch_size |
should be done at config time not here.
| aux_loss = layer_cfg.moe.aux_loss
| if aux_loss.weight > 0:
| if aux_loss.type == "batch_wise" and pp_enabled:
| raise ValueError(
should be done in update_from_configs fn, not here
| config.shared_experts.build() if config.shared_experts is not None else None
| )
| self.aux_loss = MoELoadBalanceAuxLoss(config.aux_loss)
| self.aux_loss.top_k = config.router.top_k
should be set at config time
| self.shared_experts = (
| config.shared_experts.build() if config.shared_experts is not None else None
| )
| self.aux_loss = MoELoadBalanceAuxLoss(config.aux_loss)
use build, not constructor
| class Config:
| weight: float = 0.0
| """Weight for the auxiliary load-balance loss. 0 disables it."""
| type: Literal["sequence_wise", "batch_wise"] = "sequence_wise"
Let's have two subclasses, rather than using a str factory
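A minimal sketch of what the two-subclass shape suggested here could look like. The class names follow the `BatchWiseAuxLoss.Config(...)` usages later in this thread; everything else (fields, `build()`) is illustrative and also folds in the "use build, not constructor" comment below.

```python
from dataclasses import dataclass


class MoELoadBalanceAuxLoss:
    @dataclass
    class Config:
        weight: float = 0.0
        """Weight for the auxiliary load-balance loss. 0 disables it."""

        def build(self) -> "MoELoadBalanceAuxLoss":
            raise NotImplementedError

    def __init__(self, config: "MoELoadBalanceAuxLoss.Config"):
        self.weight = config.weight


class SequenceWiseAuxLoss(MoELoadBalanceAuxLoss):
    @dataclass
    class Config(MoELoadBalanceAuxLoss.Config):
        def build(self) -> "SequenceWiseAuxLoss":
            return SequenceWiseAuxLoss(self)


class BatchWiseAuxLoss(MoELoadBalanceAuxLoss):
    @dataclass
    class Config(MoELoadBalanceAuxLoss.Config):
        def build(self) -> "BatchWiseAuxLoss":
            return BatchWiseAuxLoss(self)


# Call sites then select the variant by type rather than a string:
aux_loss = BatchWiseAuxLoss.Config(weight=1e-3).build()
```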
| bs: int,
| slen: int,
| ) -> torch.Tensor:
| if self.weight <= 0:
| local_batch_size: int | None = None
| """Total local batch size (before microbatching). Used to normalize aux
| loss gradients across pipeline-parallel microbatches so they match the
| non-PP case. Set automatically by the trainer; None means use the
oh, could you remind me why we're not using the global batch size?
IIUC this is to match the non-microbatch-pipelining variant: the loss should be averaged across all tokens it sees, but this autograd fn will accumulate gradients across microbatches, so we normalize by micro_bs / local_bs.
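Spelling out the comment above for the sequence-wise loss (which averages per-sequence losses): with $n$ microbatches of micro_bs sequences each, gradient accumulation sums the microbatch contributions, so scaling each one by micro_bs / local_bs recovers the non-PP average:

$$\sum_{m=1}^{n} \frac{\text{micro\_bs}}{\text{local\_bs}} \cdot \frac{1}{\text{micro\_bs}} \sum_{s \in m} \mathcal{L}_s \;=\; \frac{1}{\text{local\_bs}} \sum_{s=1}^{\text{local\_bs}} \mathcal{L}_s .$$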
| - num_tokens_per_expert (torch.Tensor):
| Number of tokens assigned to each expert with shape ``(num_experts,)``.
| - scores (torch.Tensor):
| Full router scores for all experts with shape ``(bs*slen, num_experts)``.
add the reason this field is returned (for the load-balance loss)
| class _AuxLossBase(torch.autograd.Function):
| """Injects auxiliary load-balance loss gradients at the router scores level.
|
| Identity in forward (returns ``top_scores`` unchanged). In backward,
"Identity in forward (returns ``top_scores`` unchanged)"
Is this even a goal?
this is so it stays in the autograd graph; will update the docstring
| Expert indices selected for each token with shape ``(bs*slen, top_k)``.
| - num_tokens_per_expert (torch.Tensor):
| Number of tokens assigned to each expert with shape ``(num_experts,)``.
| - scores (torch.Tensor):
nit: can we put scores before top_scores? just to keep the scores together.
| debug.moe_force_load_balance
| )
|
| if training is not None:
why does qwen3 have this condition check? what's the difference from other models?
| ),
| top_k=top_k,
| ),
| aux_loss=BatchWiseAuxLoss.Config(weight=1e-3),
can you put a reference here showing that it uses batch-wise?
I actually couldn't find arXiv references, so I left HF URLs; maybe that's inaccurate/weird. I also don't have anything for llama4 yet.
| router=router,
| experts=experts,
| shared_experts=shared_experts,
| aux_loss=BatchWiseAuxLoss.Config(weight=1e-3),
| non_blocking_capacity_factor=non_blocking_capacity_factor,
| ),
| aux_loss=BatchWiseAuxLoss.Config(weight=1e-3),
| ),
similar, we need some references; these two may be helpful.
Is our batch-wise aux loss equivalent to the global-batch LBL? How do we compute the average loss over DP ranks? Like equation (4) in paper 2 above.
Yes, this is the batch-wise aux loss. Currently we ban it with PP because it's not additive across microbatches; maybe we should ban it with DP too until the count accumulator is implemented.
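A tiny self-contained check of the non-additivity point; the f_i / p_i definitions here are illustrative, not the PR's exact normalization:

```python
import torch

torch.manual_seed(0)
num_experts, top_k = 4, 1
scores = torch.softmax(torch.randn(8, num_experts), dim=-1)   # 8 tokens
idx = scores.argmax(dim=-1)                                    # top-1 routing


def batch_wise(scores, idx):
    # Both factors are averages over the window, so the loss is not a sum of per-token terms.
    counts = torch.bincount(idx, minlength=num_experts).float()
    f_i = counts * num_experts / (top_k * idx.numel())
    p_i = scores.mean(dim=0)
    return (f_i * p_i).sum()


full = batch_wise(scores, idx)                                  # loss over the whole batch
per_microbatch_avg = 0.5 * (batch_wise(scores[:4], idx[:4]) + batch_wise(scores[4:], idx[4:]))
print(full.item(), per_microbatch_avg.item())
# These generally differ: counts from one microbatch never meet scores from the other,
# so no fixed rescaling of per-microbatch losses recovers the full-batch value.
```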
| to ``scores``'s gradient. ``top_scores`` is a pass-through so this node
| remains in the autograd graph.
|
| Uses ``torch.func.grad`` instead of nested ``torch.autograd.grad`` to
I think avoiding this global config is part of our code iteration process, and there's no need to reveal it to the users.
| (
| scores,
| selected_experts_indices,
| ) = ctx.saved_tensors
maybe add a one-line comment if we want to be explicit.
| ) = ctx.saved_tensors
| # torch.func.grad avoids the graph break that torch.autograd.grad causes under torch.compile
fcb7eff to 4f80bc3
| indices_per_seq = selected_experts_indices.view(bs, -1)
| offset = torch.arange(bs, device=indices_per_seq.device).unsqueeze(1) * num_experts
| flat_indices = (indices_per_seq + offset).reshape(-1)
| counts = torch.bincount(flat_indices.long(), minlength=bs * num_experts)
Can you double-check whether torch.bincount leads to a graph break under torch.compile? Maybe we need .scatter_add on a sparse tensor.
This runs under compile; I believe capture_scalar_outputs enables capturing the output shape.
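In case bincount ever regresses under compile, a dense scatter_add_ formulation of the same per-sequence count (equivalent to the offset + bincount trick after reshaping to (bs, num_experts), with a static output shape; the sizes below are made up for the demo):

```python
import torch

bs, slen, top_k, num_experts = 2, 3, 2, 4
selected_experts_indices = torch.randint(0, num_experts, (bs * slen, top_k))

indices_per_seq = selected_experts_indices.view(bs, -1)          # (bs, slen * top_k)
counts = torch.zeros(bs, num_experts, dtype=torch.long)
counts.scatter_add_(1, indices_per_seq, torch.ones_like(indices_per_seq))
# counts[b, e] == number of (token, slot) assignments of expert e within sequence b
```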
Add PP-safe auxiliary load-balance loss for MoE models using a custom autograd.Function that injects gradients during backward without retain_graph. Supports sequence-wise (DeepSeek-V3 style) and batch-wise loss variants via MoELoadBalanceAuxLoss config. Applied across all MoE models: DeepSeek V3, Llama4, Qwen3, Qwen3-VL, GPT-OSS.
Adds MoE load-balancing aux loss following #2061, using a custom autograd function that no-ops in FWD and updates per-layer score gradients in BWD, for PP/compile compatibility.
seq-wise aux loss (aux loss settings aren't in CLI):
MODULE=deepseek_v3 CONFIG=deepseek_v3_debugmodel NGPU=4 ./run_train.sh

batch-wise aux loss:
MODULE=deepseek_v3 CONFIG=deepseek_v3_debugmodel NGPU=4 ./run_train.sh

seq-wise, compile:
MODULE=deepseek_v3 CONFIG=deepseek_v3_debugmodel NGPU=4 ./run_train.sh --compile.enable

seq-wise, compile + PP=4 (1F1B):
LOG_RANK=3 MODULE=deepseek_v3 CONFIG=deepseek_v3_debugmodel NGPU=4 ./run_train.sh --compile.enable --parallelism.pipeline_parallel_degree=4 --parallelism.pipeline_parallel_schedule="1F1B" --parallelism.pipeline_parallel_microbatch_size=1

seq-wise, compile + FSDP=2, PP=2 (Interleaved1F1B):
LOG_RANK=3 MODULE=deepseek_v3 CONFIG=deepseek_v3_debugmodel NGPU=4 ./run_train.sh --compile.enable --parallelism.pipeline_parallel_degree=2 --parallelism.data_parallel_shard_degree=2 --parallelism.pipeline_parallel_schedule="Interleaved1F1B" --parallelism.pipeline_parallel_microbatch_size=1

seq-wise, FSDP=2, PP=2 (ZBV), no compile, no AC (ZBV + compile/AC fails):
MODULE=deepseek_v3 CONFIG=deepseek_v3_debugmodel NGPU=4 ./run_train.sh --parallelism.pipeline_parallel_degree=2 --parallelism.data_parallel_shard_degree=2 --parallelism.pipeline_parallel_schedule="ZBVZeroBubble" --parallelism.pipeline_parallel_microbatch_size=1 --activation_checkpoint.mode=none

batch-wise + PP errors, as aux loss is not additive over microbatches:
MODULE=deepseek_v3 CONFIG=deepseek_v3_debugmodel NGPU=4 ./run_train.sh --parallelism.pipeline_parallel_degree=4 --parallelism.pipeline_parallel_schedule="1F1B" --parallelism.pipeline_parallel_microbatch_size=1