[moe] load-balancing aux loss #3000

Open
pianpwk wants to merge 2 commits into main from moe_aux_loss_v2

Conversation

@pianpwk
Contributor

@pianpwk pianpwk commented Apr 16, 2026

Adds an MoE load-balancing aux loss following #2061, using a custom autograd function that is a no-op in the forward pass and updates per-layer router score gradients in the backward pass, for PP/compile compatibility.
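
For context, a minimal self-contained sketch of that pattern (illustrative only, not the PR's actual code; the names AuxLossGradInjector and _batch_wise_aux_loss are placeholders):

import torch

def _batch_wise_aux_loss(scores, num_tokens_per_expert, top_k, weight):
    # f_i: scaled fraction of routed tokens per expert, p_i: mean router prob
    num_tokens, num_experts = scores.shape
    f_i = num_tokens_per_expert * num_experts / (top_k * num_tokens)
    p_i = scores.mean(dim=0)
    return (f_i * p_i).sum() * weight

class AuxLossGradInjector(torch.autograd.Function):
    @staticmethod
    def forward(ctx, scores, selected_experts_indices, num_experts, top_k, weight):
        ctx.save_for_backward(scores, selected_experts_indices)
        ctx.num_experts, ctx.top_k, ctx.weight = num_experts, top_k, weight
        return scores  # no-op in FWD; keeps this node on the autograd graph

    @staticmethod
    def backward(ctx, grad_scores):
        scores, indices = ctx.saved_tensors
        with torch.enable_grad():
            scores_d = scores.detach().requires_grad_(True)
            counts = torch.histc(indices.view(-1).float(),
                                 bins=ctx.num_experts, min=0, max=ctx.num_experts)
            aux_loss = _batch_wise_aux_loss(scores_d, counts, ctx.top_k, ctx.weight)
            (aux_grad,) = torch.autograd.grad(aux_loss, scores_d)
        # inject the aux-loss gradient on top of the main-loss gradient;
        # non-tensor inputs receive None
        return grad_scores + aux_grad, None, None, None, None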

seq-wise aux loss (aux loss settings aren't exposed in the CLI): MODULE=deepseek_v3 CONFIG=deepseek_v3_debugmodel NGPU=4 ./run_train.sh

step:  1  loss:  8.03737  grad_norm:  3.6433  memory: 22.21GiB(23.36%)  tps: 7,976  tflops: 4.33  mfu: 0.44%
step:  2  loss:  6.00046  grad_norm:  4.5564  memory: 22.26GiB(23.42%)  tps: 42,761  tflops: 23.22  mfu: 2.35%
step:  3  loss:  4.77488  grad_norm:  3.4097  memory: 22.26GiB(23.42%)  tps: 43,228  tflops: 23.47  mfu: 2.37%
step:  4  loss:  4.76890  grad_norm:  2.6735  memory: 22.26GiB(23.42%)  tps: 43,139  tflops: 23.42  mfu: 2.37%
step:  5  loss:  4.45792  grad_norm:  2.4215  memory: 22.26GiB(23.42%)  tps: 43,368  tflops: 23.55  mfu: 2.38%
step:  6  loss:  4.24030  grad_norm:  2.1336  memory: 22.26GiB(23.42%)  tps: 42,783  tflops: 23.23  mfu: 2.35%
step:  7  loss:  4.04367  grad_norm:  1.7707  memory: 22.26GiB(23.42%)  tps: 43,472  tflops: 23.60  mfu: 2.39%
step:  8  loss:  3.96764  grad_norm:  2.0263  memory: 22.26GiB(23.42%)  tps: 43,404  tflops: 23.57  mfu: 2.38%
step:  9  loss:  4.06103  grad_norm:  1.9714  memory: 22.26GiB(23.42%)  tps: 43,361  tflops: 23.54  mfu: 2.38%
step: 10  loss:  3.88263  grad_norm:  1.7082  memory: 22.26GiB(23.42%)  tps: 42,772  tflops: 23.22  mfu: 2.35%

batch-wise aux loss: MODULE=deepseek_v3 CONFIG=deepseek_v3_debugmodel NGPU=4 ./run_train.sh

step:  1  loss:  7.97961  grad_norm:  3.4346  memory: 22.21GiB(23.36%)  tps: 9,803  tflops: 5.32  mfu: 0.54%
step:  2  loss:  6.08554  grad_norm:  4.4377  memory: 22.26GiB(23.42%)  tps: 43,106  tflops: 23.40  mfu: 2.37%
step:  3  loss:  4.78368  grad_norm:  2.6392  memory: 22.26GiB(23.42%)  tps: 43,587  tflops: 23.67  mfu: 2.39%
step:  4  loss:  4.62059  grad_norm:  2.3105  memory: 22.26GiB(23.42%)  tps: 43,353  tflops: 23.54  mfu: 2.38%
step:  5  loss:  4.34816  grad_norm:  2.1437  memory: 22.26GiB(23.42%)  tps: 43,634  tflops: 23.69  mfu: 2.40%
step:  6  loss:  4.20917  grad_norm:  2.0145  memory: 22.26GiB(23.42%)  tps: 42,859  tflops: 23.27  mfu: 2.35%
step:  7  loss:  4.07222  grad_norm:  1.9364  memory: 22.26GiB(23.42%)  tps: 43,686  tflops: 23.72  mfu: 2.40%
step:  8  loss:  3.96189  grad_norm:  1.8112  memory: 22.26GiB(23.42%)  tps: 43,670  tflops: 23.71  mfu: 2.40%
step:  9  loss:  4.03467  grad_norm:  1.5550  memory: 22.26GiB(23.42%)  tps: 43,576  tflops: 23.66  mfu: 2.39%
step: 10  loss:  3.86282  grad_norm:  1.5751  memory: 22.26GiB(23.42%)  tps: 43,043  tflops: 23.37  mfu: 2.36%

seq-wise, compile: MODULE=deepseek_v3 CONFIG=deepseek_v3_debugmodel NGPU=4 ./run_train.sh --compile.enable

step:  1  loss:  8.04938  grad_norm:  3.6607  memory: 16.94GiB(17.81%)  tps: 2,219  tflops: 1.20  mfu: 0.12%
step:  2  loss:  6.02174  grad_norm:  4.5543  memory: 17.12GiB(18.01%)  tps: 69,787  tflops: 37.89  mfu: 3.83%
step:  3  loss:  4.71873  grad_norm:  3.1150  memory: 17.12GiB(18.01%)  tps: 71,874  tflops: 39.02  mfu: 3.95%
step:  4  loss:  4.66633  grad_norm:  2.6679  memory: 17.12GiB(18.01%)  tps: 71,459  tflops: 38.80  mfu: 3.92%
step:  5  loss:  4.35741  grad_norm:  2.4411  memory: 17.12GiB(18.01%)  tps: 72,135  tflops: 39.16  mfu: 3.96%
step:  6  loss:  4.13684  grad_norm:  1.9808  memory: 17.12GiB(18.01%)  tps: 69,997  tflops: 38.00  mfu: 3.84%
step:  7  loss:  4.00600  grad_norm:  1.9502  memory: 17.12GiB(18.01%)  tps: 72,220  tflops: 39.21  mfu: 3.96%
step:  8  loss:  3.91741  grad_norm:  2.0644  memory: 17.12GiB(18.01%)  tps: 71,965  tflops: 39.07  mfu: 3.95%
step:  9  loss:  3.99697  grad_norm:  1.6953  memory: 17.12GiB(18.01%)  tps: 71,561  tflops: 38.85  mfu: 3.93%
step: 10  loss:  3.81696  grad_norm:  1.5733  memory: 17.12GiB(18.01%)  tps: 70,282  tflops: 38.16  mfu: 3.86%

seq-wise, compile + PP=4 (1F1B): LOG_RANK=3 MODULE=deepseek_v3 CONFIG=deepseek_v3_debugmodel NGPU=4 ./run_train.sh --compile.enable --parallelism.pipeline_parallel_degree=4 --parallelism.pipeline_parallel_schedule="1F1B" --parallelism.pipeline_parallel_microbatch_size=1

step:  1  loss:  8.08122  grad_norm: 64178.5312  memory:  1.03GiB(1.08%)  tps: 303  tflops: 0.16  mfu: 0.02%
step:  2  loss:  5.97057  grad_norm: 76623.7969  memory:  1.03GiB(1.09%)  tps: 23,904  tflops: 12.98  mfu: 1.31%
step:  3  loss:  4.87226  grad_norm: 70649.4062  memory:  1.03GiB(1.09%)  tps: 29,812  tflops: 16.19  mfu: 1.64%
step:  4  loss:  4.71646  grad_norm: 44206.9141  memory:  1.03GiB(1.09%)  tps: 30,406  tflops: 16.51  mfu: 1.67%
step:  5  loss:  4.38501  grad_norm: 40740.8125  memory:  1.03GiB(1.09%)  tps: 29,821  tflops: 16.19  mfu: 1.64%
step:  6  loss:  4.18589  grad_norm: 36898.0781  memory:  1.03GiB(1.09%)  tps: 28,992  tflops: 15.74  mfu: 1.59%
step:  7  loss:  4.00744  grad_norm: 29890.0352  memory:  1.03GiB(1.09%)  tps: 29,994  tflops: 16.29  mfu: 1.65%
step:  8  loss:  3.95164  grad_norm: 32042.1719  memory:  1.03GiB(1.09%)  tps: 30,250  tflops: 16.42  mfu: 1.66%
step:  9  loss:  4.34464  grad_norm: 29014.9023  memory:  1.03GiB(1.09%)  tps: 30,358  tflops: 16.48  mfu: 1.67%
step: 10  loss:  3.86752  grad_norm: 25947.7754  memory:  1.03GiB(1.09%)  tps: 29,918  tflops: 16.24  mfu: 1.64%

seq-wise, compile + FSDP=2, PP=2 (Interleaved1F1B): LOG_RANK=3 MODULE=deepseek_v3 CONFIG=deepseek_v3_debugmodel NGPU=4 ./run_train.sh --compile.enable --parallelism.pipeline_parallel_degree=2 --parallelism.data_parallel_shard_degree=2 --parallelism.pipeline_parallel_schedule="Interleaved1F1B" --parallelism.pipeline_parallel_microbatch_size=1

step:  1  loss:  8.23562  grad_norm: 121093.8516  memory:  2.50GiB(2.63%)  tps: 677  tflops: 0.37  mfu: 0.04%
step:  2  loss:  6.15458  grad_norm: 150805.9688  memory:  2.77GiB(2.91%)  tps: 27,050  tflops: 14.69  mfu: 1.48%
step:  3  loss:  4.85215  grad_norm: 86998.0312  memory:  2.77GiB(2.91%)  tps: 42,324  tflops: 22.98  mfu: 2.32%
step:  4  loss:  4.64139  grad_norm: 78084.5391  memory:  2.77GiB(2.91%)  tps: 41,776  tflops: 22.68  mfu: 2.29%
step:  5  loss:  4.33879  grad_norm: 73486.1719  memory:  2.77GiB(2.91%)  tps: 42,432  tflops: 23.04  mfu: 2.33%
step:  6  loss:  4.17014  grad_norm: 65393.6758  memory:  2.77GiB(2.91%)  tps: 39,736  tflops: 21.57  mfu: 2.18%
step:  7  loss:  4.06022  grad_norm: 60990.1641  memory:  2.77GiB(2.91%)  tps: 42,637  tflops: 23.15  mfu: 2.34%
step:  8  loss:  3.97168  grad_norm: 62127.4062  memory:  2.77GiB(2.91%)  tps: 42,623  tflops: 23.14  mfu: 2.34%
step:  9  loss:  4.21506  grad_norm: 50417.2227  memory:  2.77GiB(2.91%)  tps: 42,231  tflops: 22.93  mfu: 2.32%
step: 10  loss:  3.85770  grad_norm: 51002.9922  memory:  2.77GiB(2.91%)  tps: 41,349  tflops: 22.45  mfu: 2.27%

seq-wise, FSDP=2, PP=2 (ZBV). no compile, no AC (ZBV + compile/AC fails): MODULE=deepseek_v3 CONFIG=deepseek_v3_debugmodel NGPU=4 ./run_train.sh --parallelism.pipeline_parallel_degree=2 --parallelism.data_parallel_shard_degree=2 --parallelism.pipeline_parallel_schedule="ZBVZeroBubble" --parallelism.pipeline_parallel_microbatch_size=1 --activation_checkpoint.mode=none

step:  1  loss:  8.12332  grad_norm: 108232.8359  memory:  2.73GiB(2.87%)  tps: 1,655  tflops: 0.90  mfu: 0.09%
step:  2  loss:  6.30996  grad_norm: 147877.1250  memory:  2.87GiB(3.01%)  tps: 19,140  tflops: 10.39  mfu: 1.05%
step:  3  loss:  5.03262  grad_norm: 101205.5703  memory:  2.87GiB(3.01%)  tps: 25,487  tflops: 13.84  mfu: 1.40%
step:  4  loss:  4.82529  grad_norm: 93716.3203  memory:  2.87GiB(3.01%)  tps: 25,745  tflops: 13.98  mfu: 1.41%
step:  5  loss:  4.50799  grad_norm: 90023.4297  memory:  2.87GiB(3.01%)  tps: 26,018  tflops: 14.13  mfu: 1.43%
step:  6  loss:  4.28772  grad_norm: 77841.2422  memory:  2.87GiB(3.01%)  tps: 25,657  tflops: 13.93  mfu: 1.41%
step:  7  loss:  4.15718  grad_norm: 68140.4766  memory:  2.87GiB(3.01%)  tps: 26,227  tflops: 14.24  mfu: 1.44%
step:  8  loss:  4.07489  grad_norm: 68516.4844  memory:  2.87GiB(3.01%)  tps: 25,949  tflops: 14.09  mfu: 1.42%
step:  9  loss:  4.31573  grad_norm: 59333.2773  memory:  2.87GiB(3.01%)  tps: 25,900  tflops: 14.06  mfu: 1.42%
step: 10  loss:  3.95042  grad_norm: 59229.8789  memory:  2.87GiB(3.01%)  tps: 25,688  tflops: 13.95  mfu: 1.41%

batch-wise + PP errors out, as the aux loss is not additive over microbatches: MODULE=deepseek_v3 CONFIG=deepseek_v3_debugmodel NGPU=4 ./run_train.sh --parallelism.pipeline_parallel_degree=4 --parallelism.pipeline_parallel_schedule="1F1B" --parallelism.pipeline_parallel_microbatch_size=1

ValueError: batch_wise MoE aux loss is incompatible with pipeline parallelism because per-microbatch token-to-expert counts do not reflect the full batch. Use sequence_wise instead.

@meta-cla Bot added the CLA Signed label Apr 16, 2026
@pianpwk pianpwk changed the title pp auxloss [moe] load-balancing aux loss Apr 21, 2026
@tianyu-l tianyu-l requested a review from shuhuayu April 22, 2026 23:06
Comment thread torchtitan/models/common/moe.py Outdated
ctx.aux_loss_weight, # pyrefly: ignore [missing-attribute]
)
(aux_grad,) = torch.autograd.grad(aux_loss, scores_detached)
return grad_scores + aux_grad, None, None, None, None, None, None, None
Contributor Author

note (@tianyu-l): prefer computing aux_loss in fwd & saving it

Contributor

iirc, this is something (nested autograd) @xmfan wanted to avoid because of graph breaks? We can 1) write the grad function analytically, which can be tedious, 2) use torch.func.grad to transform the aux loss function, which is perhaps the simplest approach, or 3) use torch.library.custom_op + register_autograd as in deepep.py. I think all of these approaches can avoid the global config added in this PR, torch._dynamo.config.trace_autograd_ops = True, which can have global side effects. Maybe Simon can correct me if anything doesn't make sense.
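
For illustration, option 2 could look roughly like this (a sketch assuming the _batch_wise_aux_loss signature from the sketch under the PR description; names are placeholders):

import torch

# differentiate the aux loss w.r.t. its first argument (the router scores)
aux_grad_fn = torch.func.grad(_batch_wise_aux_loss, argnums=0)

def inject_aux_grad(grad_scores, scores, counts, top_k, weight):
    # functional transform: no enable_grad()/autograd.grad inside backward,
    # so Dynamo can trace it without torch._dynamo.config.trace_autograd_ops
    return grad_scores + aux_grad_fn(scores, counts, top_k, weight)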

Contributor Author

I went with torch.func.grad to remove the flag, but the loss computation still happens in the backward. I don't think we'd save any memory by computing the loss in FWD.

Contributor

@shuhuayu shuhuayu Apr 23, 2026

Yeah, that's by design. We don't compute the loss in the forward to avoid coupling it with the modeling loss, which breaks PP; all we need is its gradients in the backward. On the other hand, if we want to log the load-balancing loss to monitor the health of training, we can do it under torch.no_grad() in the forward.
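
A rough sketch of that monitoring idea (hypothetical helper, assuming the aux-loss function and a buffer registered on the MoE layer; not part of this PR):

import torch

def record_load_balance_loss(moe_layer, scores, num_tokens_per_expert, top_k, weight):
    # compute the load-balance loss detached from autograd, purely for logging,
    # so it never couples with the modeling loss (which would break PP)
    with torch.no_grad():
        aux = _batch_wise_aux_loss(scores, num_tokens_per_expert, top_k, weight)
    moe_layer.load_balance_loss.copy_(aux)  # hypothetical registered buffer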

Comment thread torchtitan/models/common/moe.py Outdated
Comment on lines +181 to +206
with torch.enable_grad():
scores_detached = scores.detach().requires_grad_(True)
if (
ctx.aux_loss_type == "sequence_wise"
): # pyrefly: ignore [missing-attribute]
aux_loss = MoE._sequence_wise_aux_loss(
scores_detached,
selected_experts_indices,
ctx.bs, # pyrefly: ignore [missing-attribute]
ctx.slen, # pyrefly: ignore [missing-attribute]
ctx.top_k, # pyrefly: ignore [missing-attribute]
ctx.aux_loss_weight, # pyrefly: ignore [missing-attribute]
)
else:
num_tokens_per_expert = torch.histc(
selected_experts_indices.view(-1).float(),
bins=ctx.num_experts, # pyrefly: ignore [missing-attribute]
min=0,
max=ctx.num_experts, # pyrefly: ignore [missing-attribute]
)
aux_loss = MoE._batch_wise_aux_loss(
scores_detached,
num_tokens_per_expert,
ctx.top_k, # pyrefly: ignore [missing-attribute]
ctx.aux_loss_weight, # pyrefly: ignore [missing-attribute]
)
Contributor

why don't we compute these in forward?

Comment thread torchtitan/models/qwen3_vl/model.py Outdated
layer_cfg.moe.router._debug_force_load_balance = (
debug.moe_force_load_balance
)
apply_moe_load_balance_config(
Contributor

put this at config time, not run time

Comment thread torchtitan/config/configs.py Outdated
Comment on lines +82 to +90
moe_aux_loss_weight: float = 0.0
"""Weight for the MoE auxiliary load-balance loss. 0 disables it."""

moe_aux_loss_type: Literal["sequence_wise", "batch_wise"] = "sequence_wise"
"""Type of MoE auxiliary load-balance loss."""

moe_load_balance_coeff: float | None = 1e-3
"""Coefficient for aux-loss-free bias-based MoE load balancing.
Overrides the model config default when set. None disables it."""
Contributor

do this in model_registry; don't allow overriding it per CLI job

Comment thread torchtitan/models/common/moe.py Outdated
@rakkit
Contributor

rakkit commented Apr 23, 2026

would be good if we add a buffer of load_balance_loss together with tokens_per_expert so one can log it out in pre-optim step hook.
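
A possible shape for that, as a hedged sketch (hypothetical names; the actual logging design is deferred, see the reply below):

import torch

class MoEWithStats(torch.nn.Module):
    def __init__(self, num_experts: int):
        super().__init__()
        # buffers survive state_dict/device moves and are readable from hooks
        self.register_buffer("tokens_per_expert", torch.zeros(num_experts))
        self.register_buffer("load_balance_loss", torch.zeros(()))

def log_moe_stats(model: torch.nn.Module) -> None:
    # call from a pre-optimizer-step hook, e.g. Optimizer.register_step_pre_hook
    for name, module in model.named_modules():
        if isinstance(module, MoEWithStats):
            print(f"{name}: load_balance_loss={module.load_balance_loss.item():.4f}")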

@tianyu-l
Contributor

tianyu-l commented Apr 23, 2026

would be good if we add a buffer of load_balance_loss together with tokens_per_expert so one can log it out in pre-optim step hook.

@rakkit Logging capability will be separately addressed by @felipemello1 in a PR similar to #2607. Will tag you for review.

Comment thread torchtitan/models/common/moe.py Outdated
return (f_i * p_i).sum() * aux_loss_weight


class _AuxLossBackward(torch.autograd.Function):
Contributor

I think we can just use some regular name like LoadBalanceAuxLoss? I didn't get why we need to put Backward in the name.

@pianpwk pianpwk marked this pull request as ready for review April 28, 2026 17:52
@pianpwk pianpwk requested a review from tianyu-l April 28, 2026 17:52
router_top_k=3,
router_score_func="softmax",
score_before_experts=False,
aux_loss=MoELoadBalanceAuxLoss.Config(weight=1e-4),
Contributor

check the DSV3 technical report and see if it's actually used.

IIRC Qwen3 used it, so let's enable it for Qwen3. Need to check gpt-oss as well.

Contributor Author

yes this is from the hyperparams section; I'll check the other 2

Comment thread torchtitan/trainer.py Outdated
"batch_wise MoE aux loss is incompatible with pipeline "
"parallelism. Use sequence_wise instead."
)
aux_loss.local_batch_size = config.training.local_batch_size
Contributor

should be done at config time not here.

Comment thread torchtitan/trainer.py Outdated
aux_loss = layer_cfg.moe.aux_loss
if aux_loss.weight > 0:
if aux_loss.type == "batch_wise" and pp_enabled:
raise ValueError(
Contributor

should be done in update_from_configs fn, not here

Comment thread torchtitan/models/common/moe.py Outdated
config.shared_experts.build() if config.shared_experts is not None else None
)
self.aux_loss = MoELoadBalanceAuxLoss(config.aux_loss)
self.aux_loss.top_k = config.router.top_k
Contributor

should be set at config time

Comment thread torchtitan/models/common/moe.py Outdated
self.shared_experts = (
config.shared_experts.build() if config.shared_experts is not None else None
)
self.aux_loss = MoELoadBalanceAuxLoss(config.aux_loss)
Contributor

use build, not constructor

Comment thread torchtitan/models/common/moe.py Outdated
class Config:
weight: float = 0.0
"""Weight for the auxiliary load-balance loss. 0 disables it."""
type: Literal["sequence_wise", "batch_wise"] = "sequence_wise"
Contributor

Let's have two subclasses, rather than using a str factory
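
Something along these lines, perhaps (illustrative sketch; later snippets in this thread, e.g. BatchWiseAuxLoss.Config, suggest the direction eventually taken):

from dataclasses import dataclass

@dataclass
class AuxLossConfig:
    weight: float = 0.0

@dataclass
class SequenceWiseAuxLossConfig(AuxLossConfig):
    pass

@dataclass
class BatchWiseAuxLossConfig(AuxLossConfig):
    pass

# the variant is selected by the config's type rather than a
# "sequence_wise"/"batch_wise" string field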

Comment thread torchtitan/models/common/moe.py Outdated
bs: int,
slen: int,
) -> torch.Tensor:
if self.weight <= 0:
Contributor

why would it be < 0?

Comment thread torchtitan/models/common/moe.py Outdated
local_batch_size: int | None = None
"""Total local batch size (before microbatching). Used to normalize aux
loss gradients across pipeline-parallel microbatches so they match the
non-PP case. Set automatically by the trainer; None means use the
Contributor

oh, could you remind me why we're not using the global batch size?

Contributor Author

IIUC this is to match the non-microbatch-pipelining variant; the loss should be averaged across all tokens this fn sees, but this autograd fn accumulates gradients across microbatches, so we normalize by micro_bs / local_bs.
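
A tiny numeric illustration of that bookkeeping (assumed values, not from the PR):

# with PP microbatching, each microbatch's aux-loss gradient is accumulated,
# so scaling each one by micro_bs / local_bs (= 1 / n_microbatches) makes the
# accumulated sum match the gradient of the full-local-batch (non-PP) loss
micro_bs, local_bs = 2, 8
n_microbatches = local_bs // micro_bs  # 4
scale = micro_bs / local_bs            # 0.25
assert n_microbatches * scale == 1.0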

Comment thread torchtitan/models/common/moe.py Outdated
- num_tokens_per_expert (torch.Tensor):
Number of tokens assigned to each expert with shape ``(num_experts,)``.
- scores (torch.Tensor):
Full router scores for all experts with shape ``(bs*slen, num_experts)``.
Contributor

add reasons to have this field returned (for load balance loss)

class _AuxLossBase(torch.autograd.Function):
"""Injects auxiliary load-balance loss gradients at the router scores level.

Identity in forward (returns ``top_scores`` unchanged). In backward,
Contributor

Identity in forward (returns top_scores unchanged)

Is this even a goal?

Contributor Author

this is so it stays on the autograd graph, will update the docstring

@tianyu-l tianyu-l requested a review from shuhuayu April 29, 2026 19:49
Comment thread torchtitan/models/common/moe.py Outdated
Expert indices selected for each token with shape ``(bs*slen, top_k)``.
- num_tokens_per_expert (torch.Tensor):
Number of tokens assigned to each expert with shape ``(num_experts,)``.
- scores (torch.Tensor):
Contributor

nit: can we put scores before top_scores? just put scores together.

Comment thread torchtitan/models/qwen3/model.py Outdated
debug.moe_force_load_balance
)

if training is not None:
Contributor

why does qwen3 have this condition check? what's the difference from other models?

Contributor Author

gonna remove this

),
top_k=top_k,
),
aux_loss=BatchWiseAuxLoss.Config(weight=1e-3),
Contributor

can you put some reference here showing that it uses batch-wise?

Contributor Author

I actually couldn't find arXiv references, so I left HF URLs - maybe that's inaccurate/weird. I also don't have a reference for llama4 yet.

router=router,
experts=experts,
shared_experts=shared_experts,
aux_loss=BatchWiseAuxLoss.Config(weight=1e-3),
Contributor

similar

non_blocking_capacity_factor=non_blocking_capacity_factor,
),
aux_loss=BatchWiseAuxLoss.Config(weight=1e-3),
),
Contributor

similar, we need some references; these two may be helpful:

  1. https://arxiv.org/pdf/2505.09388
  2. https://arxiv.org/pdf/2501.11873

Is our batch-wise aux loss equivalent to the global-batch LBL? How do we compute the average loss over DP ranks, like equation (4) in paper 2 above?

Contributor Author

this is the batch-wise aux loss, yes. Currently we ban it with PP because it's not additive across microbatches; maybe we should ban it with DP too until the count accumulator is implemented.
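
For intuition, a small self-contained example of the non-additivity (assumed shapes and a simplified f_i * p_i loss, not the PR's exact function):

import torch

torch.manual_seed(0)
num_experts, top_k = 4, 1
scores = torch.softmax(torch.randn(8, num_experts), dim=-1)  # 8 tokens
indices = scores.argmax(dim=-1, keepdim=True)

def batch_wise(s, idx):
    counts = torch.bincount(idx.view(-1), minlength=num_experts).float()
    f_i = counts * num_experts / (top_k * s.shape[0])
    return (f_i * s.mean(dim=0)).sum()

full = batch_wise(scores, indices)                  # loss over the full batch
per_mb = (batch_wise(scores[:4], indices[:4]) +
          batch_wise(scores[4:], indices[4:])) / 2  # averaged over 2 microbatches
# f_i * p_i mixes per-microbatch statistics non-linearly, so these generally differ
print(full.item(), per_mb.item())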

Comment thread torchtitan/models/common/moe.py Outdated
to ``scores``'s gradient. ``top_scores`` is a pass-through so this node
remains in the autograd graph.

Uses ``torch.func.grad`` instead of nested ``torch.autograd.grad`` to
Contributor

I think avoiding this global config is part of our code iteration process, and there's no need to reveal it to the users.

(
scores,
selected_experts_indices,
) = ctx.saved_tensors
Contributor

maybe add a one-line comment if we want to be explicit.

Suggested change
) = ctx.saved_tensors
) = ctx.saved_tensors
# torch.func.grad avoids the graph break that torch.autograd.grad causes under torch.compile

@pianpwk pianpwk force-pushed the moe_aux_loss_v2 branch 2 times, most recently from fcb7eff to 4f80bc3 on April 30, 2026 21:36
indices_per_seq = selected_experts_indices.view(bs, -1)
offset = torch.arange(bs, device=indices_per_seq.device).unsqueeze(1) * num_experts
flat_indices = (indices_per_seq + offset).reshape(-1)
counts = torch.bincount(flat_indices.long(), minlength=bs * num_experts)
Contributor

Can you double check whether torch.bincount leads to a graph break under torch.compile? Maybe we need .scatter_add on a sparse tensor.

Contributor Author

This runs under compile; I believe capture_scalar_outputs enables capturing the output shape.
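
For reference, a dense-tensor variant of the scatter_add idea mentioned above (hypothetical helper, in case bincount ever becomes a problem under compile):

import torch

def count_tokens_per_bucket(flat_indices: torch.Tensor, num_buckets: int) -> torch.Tensor:
    counts = torch.zeros(num_buckets, device=flat_indices.device)
    return counts.scatter_add_(
        0, flat_indices.long(), torch.ones_like(flat_indices, dtype=counts.dtype)
    )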

pianpwk added 2 commits May 6, 2026 22:48
Add PP-safe auxiliary load-balance loss for MoE models using a custom
autograd.Function that injects gradients during backward without
retain_graph. Supports sequence-wise (DeepSeek-V3 style) and batch-wise
loss variants via MoELoadBalanceAuxLoss config. Applied across all MoE
models: DeepSeek V3, Llama4, Qwen3, Qwen3-VL, GPT-OSS.
@pianpwk pianpwk force-pushed the moe_aux_loss_v2 branch from 4f80bc3 to e8fe522 on May 7, 2026 06:36
Labels: ciflow/8gpu, CLA Signed