From b5222169c6057f210f482694feb286c445eaac10 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Thu, 30 Apr 2026 14:26:11 -0700 Subject: [PATCH 1/2] MoE auxiliary load-balance loss support Add PP-safe auxiliary load-balance loss for MoE models using a custom autograd.Function that injects gradients during backward without retain_graph. Supports sequence-wise (DeepSeek-V3 style) and batch-wise loss variants via MoELoadBalanceAuxLoss config. Applied across all MoE models: DeepSeek V3, Llama4, Qwen3, Qwen3-VL, GPT-OSS. --- torchtitan/models/common/__init__.py | 5 +- torchtitan/models/common/config_utils.py | 40 ++++- torchtitan/models/common/moe.py | 208 +++++++++++++++++++++- torchtitan/models/deepseek_v3/__init__.py | 13 +- torchtitan/models/deepseek_v3/model.py | 7 + torchtitan/models/gpt_oss/__init__.py | 3 +- torchtitan/models/gpt_oss/model.py | 7 + torchtitan/models/llama4/__init__.py | 2 + torchtitan/models/llama4/model.py | 7 + torchtitan/models/qwen3/__init__.py | 2 + torchtitan/models/qwen3/model.py | 7 + torchtitan/models/qwen3_vl/__init__.py | 2 + torchtitan/models/qwen3_vl/model.py | 7 + torchtitan/trainer.py | 28 ++- 14 files changed, 313 insertions(+), 25 deletions(-) diff --git a/torchtitan/models/common/__init__.py b/torchtitan/models/common/__init__.py index 0f8387c62d..c91f83e0e6 100644 --- a/torchtitan/models/common/__init__.py +++ b/torchtitan/models/common/__init__.py @@ -24,7 +24,7 @@ from .embedding import Embedding from .feed_forward import compute_ffn_hidden_dim, FeedForward from .linear import Linear -from .moe import MoE +from .moe import BatchWiseAuxLoss, MoE, MoELoadBalanceAuxLoss, SequenceWiseAuxLoss from .rmsnorm import RMSNorm from .rope import ( apply_rotary_emb_complex, @@ -48,7 +48,10 @@ "get_sliding_window_mask_mod", "GQAttention", "Linear", + "BatchWiseAuxLoss", + "MoELoadBalanceAuxLoss", "MoE", + "SequenceWiseAuxLoss", "QKVLinear", "RMSNorm", "RoPE", diff --git a/torchtitan/models/common/config_utils.py 
b/torchtitan/models/common/config_utils.py index 7d93cbec4c..4c32918883 100644 --- a/torchtitan/models/common/config_utils.py +++ b/torchtitan/models/common/config_utils.py @@ -11,6 +11,7 @@ """ from collections.abc import Callable +from dataclasses import replace from typing import Literal from torchtitan.models.common.attention import ( @@ -23,7 +24,13 @@ ) from torchtitan.models.common.feed_forward import FeedForward from torchtitan.models.common.linear import Linear -from torchtitan.models.common.moe import GroupedExperts, MoE, TokenChoiceTopKRouter +from torchtitan.models.common.moe import ( + BatchWiseAuxLoss, + GroupedExperts, + MoE, + MoELoadBalanceAuxLoss, + TokenChoiceTopKRouter, +) from torchtitan.models.common.rmsnorm import RMSNorm from torchtitan.models.common.token_dispatcher import ( AllToAllTokenDispatcher, @@ -152,14 +159,18 @@ def make_moe_config( experts: GroupedExperts.Config, shared_experts: FeedForward.Config | None = None, load_balance_coeff: float | None = 1e-3, + aux_loss: MoELoadBalanceAuxLoss.Config | None = None, ) -> MoE.Config: """Build a fully-specified MoE.Config.""" + if aux_loss: + aux_loss = replace(aux_loss, top_k=router.top_k) return MoE.Config( num_experts=num_experts, load_balance_coeff=load_balance_coeff, router=router, experts=experts, shared_experts=shared_experts, + aux_loss=aux_loss, ) @@ -265,3 +276,30 @@ def make_experts_config( non_blocking_capacity_factor=non_blocking_capacity_factor, ), ) + + +def update_moe_aux_loss_configs( + layers: list, + *, + pp_enabled: bool, + global_batch_size: int, +) -> None: + """Update aux loss configs for all MoE layers from trainer config. + + Call this from each model's ``update_from_config`` to set + ``global_batch_size`` for gradient normalization across PP microbatches, + gradient accumulation steps, and DP ranks (FSDP sum). Also validates + that batch-wise aux loss is not used with PP. 
+ """ + for layer_cfg in layers: + moe_cfg = getattr(layer_cfg, "moe", None) + if moe_cfg is None: + continue + aux_loss = moe_cfg.aux_loss + if aux_loss is not None and aux_loss.weight > 0: + if isinstance(aux_loss, BatchWiseAuxLoss.Config) and pp_enabled: + raise ValueError( + "batch_wise MoE aux loss is incompatible with pipeline " + "parallelism. Use sequence_wise instead." + ) + aux_loss.global_batch_size = global_batch_size diff --git a/torchtitan/models/common/moe.py b/torchtitan/models/common/moe.py index 0e3ac5bd73..f378e318ae 100644 --- a/torchtitan/models/common/moe.py +++ b/torchtitan/models/common/moe.py @@ -12,6 +12,7 @@ from torch import nn from torch.distributed.tensor import DTensor, Partial +from torchtitan.config import Configurable from torchtitan.models.common.feed_forward import FeedForward from torchtitan.models.common.linear import Linear @@ -140,6 +141,118 @@ def forward( return self.token_dispatcher.combine(routed_output, metadata, x, shared_experts) +def _sequence_wise_aux_loss( + scores: torch.Tensor, + selected_experts_indices: torch.Tensor, + bs: int, + slen: int, + top_k: int, + aux_loss_weight: float, +) -> torch.Tensor: + """Sequence-wise auxiliary load-balance loss (DeepSeek-V3 Eqs 17-20).""" + num_experts = scores.size(-1) + scores_per_seq = scores.view(bs, slen, num_experts) + denom = scores_per_seq.sum(dim=-1, keepdim=True) + 1e-20 + probs_per_seq = scores_per_seq / denom + p_i = probs_per_seq.mean(dim=1) + indices_per_seq = selected_experts_indices.view(bs, -1) + offset = torch.arange(bs, device=indices_per_seq.device).unsqueeze(1) * num_experts + flat_indices = (indices_per_seq + offset).reshape(-1) + counts = torch.bincount(flat_indices.long(), minlength=bs * num_experts) + counts = counts.reshape(bs, num_experts).to(dtype=scores.dtype) + f_i = counts * (num_experts / (top_k * slen)) + return (f_i * p_i).sum(dim=1).mean() * aux_loss_weight + + +def _batch_wise_aux_loss( + scores: torch.Tensor, + selected_experts_indices: torch.Tensor, + bs: int,
slen: int, + top_k: int, + aux_loss_weight: float, +) -> torch.Tensor: + """Batch-wise auxiliary load-balance loss.""" + num_experts = scores.size(-1) + total_tokens = scores.size(0) + num_tokens_per_expert = torch.histc( + selected_experts_indices.view(-1).float(), + bins=num_experts, + min=0, + max=num_experts, + ) + p_i = scores.mean(dim=0) + f_i = num_tokens_per_expert.to(scores.dtype) * ( + num_experts / (top_k * total_tokens) + ) + return (f_i * p_i).sum() * aux_loss_weight + + +class _AuxLossBase(torch.autograd.Function): + """Injects auxiliary load-balance loss gradients at the router scores level. + + Identity in forward (returns ``top_scores`` unchanged). In backward, + computes ``d(aux_loss)/d(scores)`` via ``torch.func.grad`` and adds it + to ``scores``'s gradient. ``top_scores`` is a pass-through so this node + remains in the autograd graph. + """ + + @staticmethod + def forward( # pyrefly: ignore [bad-override] + ctx: torch.autograd.function.FunctionCtx, + top_scores: torch.Tensor, + scores: torch.Tensor, + selected_experts_indices: torch.Tensor, + bs: int, + slen: int, + top_k: int, + aux_loss_weight: float, + ) -> torch.Tensor: + ctx.save_for_backward(scores, selected_experts_indices) + ctx.bs = bs # pyrefly: ignore [missing-attribute] + ctx.slen = slen # pyrefly: ignore [missing-attribute] + ctx.top_k = top_k # pyrefly: ignore [missing-attribute] + ctx.aux_loss_weight = aux_loss_weight # pyrefly: ignore [missing-attribute] + return top_scores + + @staticmethod + def _backward_impl(loss_fn, ctx, grad_top_scores): + ( + scores, + selected_experts_indices, + ) = ctx.saved_tensors + # torch.func.grad avoids the graph break that torch.autograd.grad causes under torch.compile + aux_grad = torch.func.grad(loss_fn)( + scores, + selected_experts_indices, + ctx.bs, + ctx.slen, + ctx.top_k, + ctx.aux_loss_weight, + ) + return grad_top_scores, aux_grad, None, None, None, None, None + + +class _SequenceWiseAuxLoss(_AuxLossBase): + @staticmethod + def backward( 
# pyrefly: ignore [bad-override] + ctx: torch.autograd.function.FunctionCtx, + grad_top_scores: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, None, None, None, None, None]: + return _AuxLossBase._backward_impl( + _sequence_wise_aux_loss, ctx, grad_top_scores + ) + + +class _BatchWiseAuxLoss(_AuxLossBase): + @staticmethod + def backward( # pyrefly: ignore [bad-override] + ctx: torch.autograd.function.FunctionCtx, + grad_top_scores: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, None, None, None, None, None]: + return _AuxLossBase._backward_impl(_batch_wise_aux_loss, ctx, grad_top_scores) + + class TokenChoiceTopKRouter(Module): """This class implements token-choice routing. In token-choice top-K routing, each token is routed to top K experts based on the router scores. @@ -233,8 +346,10 @@ def _get_node_limited_routing_scores( return scores_for_choice def forward( - self, x: torch.Tensor, expert_bias: torch.Tensor | None = None - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + self, + x: torch.Tensor, + expert_bias: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``. @@ -242,7 +357,10 @@ def forward( Used for load balancing. Defaults to None. Returns: - tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + - scores (torch.Tensor): + Full router scores for all experts with shape ``(bs*slen, num_experts)``. + Returned for load-balancing aux loss computation. - top_scores (torch.Tensor): Routing scores for selected experts with shape ``(bs*slen, top_k)``. 
- selected_experts_indices (torch.Tensor): @@ -297,7 +415,80 @@ def forward( max=self.num_experts, ) - return top_scores, selected_experts_indices, num_tokens_per_expert + return scores, top_scores, selected_experts_indices, num_tokens_per_expert + + +class MoELoadBalanceAuxLoss(Configurable): + """MoE auxiliary load-balance loss. + + Injects aux loss gradients into router scores without changing the model + output (PP-safe). Call instance with router outputs to apply. + + Subclass to select loss variant (sequence-wise or batch-wise). + """ + + _autograd_fn: type[_AuxLossBase] + + @dataclass(kw_only=True, slots=True) + class Config(Configurable.Config): + weight: float = 0.0 + """Weight for the auxiliary load-balance loss. 0 disables it.""" + top_k: int = 1 + """Number of experts per token. Must match the router's top_k.""" + global_batch_size: int | None = None + """Global batch size to normalize aux loss across all batches/microbatches. + Torchtitan generally uses sum aggregation & global_bs division. 
+ Because BatchWiseAuxLoss doesn't commute over microbatches, we ban + batch-wise + PP usage in any case.""" + + def __init__(self, config: Config): + if config.weight < 0: + raise ValueError(f"aux_loss.weight must be >= 0, got {config.weight}") + self.weight = config.weight + self.global_batch_size = config.global_batch_size + self.top_k = config.top_k + + def __call__( + self, + top_scores: torch.Tensor, + scores: torch.Tensor, + selected_experts_indices: torch.Tensor, + bs: int, + slen: int, + ) -> torch.Tensor: + if self.weight == 0: + return top_scores + global_bs = self.global_batch_size or bs + scaled_weight = self.weight * bs / global_bs + return self._autograd_fn.apply( + top_scores, + scores, + selected_experts_indices, + bs, + slen, + self.top_k, + scaled_weight, + ) + + +class SequenceWiseAuxLoss(MoELoadBalanceAuxLoss): + """Sequence-wise auxiliary load-balance loss (DeepSeek-V3 Eqs 17-20).""" + + _autograd_fn = _SequenceWiseAuxLoss + + @dataclass(kw_only=True, slots=True) + class Config(MoELoadBalanceAuxLoss.Config): + pass + + +class BatchWiseAuxLoss(MoELoadBalanceAuxLoss): + """Batch-wise auxiliary load-balance loss.""" + + _autograd_fn = _BatchWiseAuxLoss + + @dataclass(kw_only=True, slots=True) + class Config(MoELoadBalanceAuxLoss.Config): + pass class MoE(Module): @@ -325,6 +516,7 @@ class Config(Module.Config): router: TokenChoiceTopKRouter.Config load_balance_coeff: float | None = 1e-3 shared_experts: FeedForward.Config | None = None + aux_loss: MoELoadBalanceAuxLoss.Config | None = None def __init__(self, config: Config): super().__init__() @@ -335,6 +527,7 @@ def __init__(self, config: Config): self.shared_experts = ( config.shared_experts.build() if config.shared_experts is not None else None ) + self.aux_loss = config.aux_loss.build() if config.aux_loss is not None else None # define fields for auxiliary-loss-free load balancing (https://arxiv.org/abs/2408.15664) # NOTE: tokens_per_expert is accumulated in the model forward pass. 
@@ -391,14 +584,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: bs, slen, dim = x.shape x = x.view(-1, dim) + # scores shape (bs*slen, num_experts) # top_scores and selected_experts_indices shape (bs*slen, top_k) # num_tokens_per_expert shape (num_experts,) ( + scores, top_scores, selected_experts_indices, num_tokens_per_expert, ) = self.router(x, self.expert_bias) + if self.training and self.aux_loss is not None: + top_scores = self.aux_loss( + top_scores, scores, selected_experts_indices, bs, slen + ) + # tokens_per_expert will be used to update the expert bias for load balancing. # and also to count the expert usage # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert -- diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 0393c6aba2..2285c2749f 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -21,6 +21,7 @@ make_moe_config, make_router_config, ) +from torchtitan.models.common.moe import MoELoadBalanceAuxLoss, SequenceWiseAuxLoss from torchtitan.models.common.param_init import depth_scaled_std from torchtitan.protocols.model_spec import ModelSpec @@ -173,6 +174,7 @@ def _build_dsv3_layers( attn_backend: str, moe_comm_backend: str, non_blocking_capacity_factor: float | None, + aux_loss: MoELoadBalanceAuxLoss.Config | None = None, ) -> list[TransformerBlock.Config]: """Build the list of per-layer TransformerBlock configs. 
@@ -236,6 +238,7 @@ def _build_dsv3_layers( w1_param_init=_LINEAR_INIT, w2w3_param_init=_depth_init(layer_id), ), + aux_loss=aux_loss, ) layers.append( @@ -284,8 +287,9 @@ def _debugmodel( num_experts=num_experts, num_shared_experts=num_shared_experts, router_top_k=3, - router_score_func="softmax", + router_score_func="sigmoid", score_before_experts=False, + aux_loss=SequenceWiseAuxLoss.Config(weight=1e-4), attn_backend=attn_backend, moe_comm_backend=moe_comm_backend, non_blocking_capacity_factor=non_blocking_capacity_factor, @@ -349,8 +353,9 @@ def _16b( num_experts=num_experts, num_shared_experts=num_shared_experts, router_top_k=6, - router_score_func="softmax", + router_score_func="sigmoid", score_before_experts=False, + aux_loss=SequenceWiseAuxLoss.Config(weight=1e-4), attn_backend=attn_backend, moe_comm_backend=moe_comm_backend, non_blocking_capacity_factor=non_blocking_capacity_factor, @@ -415,11 +420,12 @@ def _236b( num_experts=num_experts, num_shared_experts=num_shared_experts, router_top_k=6, - router_score_func="softmax", + router_score_func="sigmoid", router_num_expert_groups=8, router_num_limited_groups=3, router_route_scale=16.0, score_before_experts=False, + aux_loss=SequenceWiseAuxLoss.Config(weight=1e-4), attn_backend=attn_backend, moe_comm_backend=moe_comm_backend, non_blocking_capacity_factor=non_blocking_capacity_factor, @@ -490,6 +496,7 @@ def _671b( router_route_scale=2.5, router_route_norm=True, score_before_experts=False, + aux_loss=SequenceWiseAuxLoss.Config(weight=1e-4), attn_backend=attn_backend, moe_comm_backend=moe_comm_backend, non_blocking_capacity_factor=non_blocking_capacity_factor, diff --git a/torchtitan/models/deepseek_v3/model.py b/torchtitan/models/deepseek_v3/model.py index 2876fde864..3bfda3c20d 100644 --- a/torchtitan/models/deepseek_v3/model.py +++ b/torchtitan/models/deepseek_v3/model.py @@ -16,6 +16,7 @@ BaseAttention, ScaledDotProductAttention, ) +from torchtitan.models.common.config_utils import 
update_moe_aux_loss_configs from torchtitan.models.common.decoder import Decoder, TransformerBlock from torchtitan.models.common.linear import Linear from torchtitan.models.common.rmsnorm import RMSNorm @@ -237,6 +238,12 @@ def update_from_config( "(expert_parallel_degree > 1)." ) + update_moe_aux_loss_configs( + self.layers, + pp_enabled=parallelism.pipeline_parallel_degree > 1, + global_batch_size=kwargs["global_batch_size"], + ) + if parallelism.context_parallel_degree > 1 and not isinstance( self.layers[0].attention.inner_attention, ScaledDotProductAttention.Config, diff --git a/torchtitan/models/gpt_oss/__init__.py b/torchtitan/models/gpt_oss/__init__.py index e1d6a01473..128a7acae6 100644 --- a/torchtitan/models/gpt_oss/__init__.py +++ b/torchtitan/models/gpt_oss/__init__.py @@ -14,7 +14,7 @@ from torchtitan.models.common import Embedding, Linear, RMSNorm, RoPE, TransformerBlock from torchtitan.models.common.attention import FusedQKVLinear, QKVLinear from torchtitan.models.common.config_utils import make_token_dispatcher_config -from torchtitan.models.common.moe import TokenChoiceTopKRouter +from torchtitan.models.common.moe import BatchWiseAuxLoss, TokenChoiceTopKRouter from torchtitan.models.common.param_init import depth_scaled_std from torchtitan.protocols.model_spec import ModelSpec @@ -199,6 +199,7 @@ def _build_gptoss_layers( ), top_k=top_k, ), + aux_loss=BatchWiseAuxLoss.Config(weight=1e-3, top_k=top_k), ) layer_cfg = GptOssTransformerBlock.Config( attention=attn_cfg, diff --git a/torchtitan/models/gpt_oss/model.py b/torchtitan/models/gpt_oss/model.py index f2d7b90f10..f6b12eee05 100644 --- a/torchtitan/models/gpt_oss/model.py +++ b/torchtitan/models/gpt_oss/model.py @@ -22,6 +22,7 @@ get_document_mask_mod, get_sliding_window_mask_mod, ) +from torchtitan.models.common.config_utils import update_moe_aux_loss_configs from torchtitan.models.common.decoder import Decoder, TransformerBlock from torchtitan.models.common.linear import Linear from
torchtitan.models.common.rope import apply_rotary_emb_cos_sin @@ -212,6 +213,12 @@ def update_from_config( ) layer_cfg.moe.experts.use_grouped_mm = False + update_moe_aux_loss_configs( + self.layers, + pp_enabled=parallelism.pipeline_parallel_degree > 1, + global_batch_size=kwargs["global_batch_size"], + ) + tp = parallelism.tensor_parallel_degree if tp > 1: n_heads = self.layers[0].attention.n_heads diff --git a/torchtitan/models/llama4/__init__.py b/torchtitan/models/llama4/__init__.py index da47175cef..6022ddf04d 100644 --- a/torchtitan/models/llama4/__init__.py +++ b/torchtitan/models/llama4/__init__.py @@ -28,6 +28,7 @@ make_moe_config, make_router_config, ) +from torchtitan.models.common.moe import BatchWiseAuxLoss from torchtitan.models.common.param_init import depth_scaled_std from torchtitan.protocols.model_spec import ModelSpec @@ -150,6 +151,7 @@ def _build_llama4_layers( router=router, experts=experts, shared_experts=shared_experts, + aux_loss=BatchWiseAuxLoss.Config(weight=1e-3), ) ffn_cfg = None else: diff --git a/torchtitan/models/llama4/model.py b/torchtitan/models/llama4/model.py index 49ca04bc5a..164c347680 100644 --- a/torchtitan/models/llama4/model.py +++ b/torchtitan/models/llama4/model.py @@ -18,6 +18,7 @@ get_document_mask_mod, get_fixed_block_mask_mod, ) +from torchtitan.models.common.config_utils import update_moe_aux_loss_configs from torchtitan.models.common.decoder import Decoder, TransformerBlock from torchtitan.models.utils import get_moe_model_nparams_and_flops from torchtitan.tools.logging import logger @@ -162,6 +163,12 @@ def update_from_config( "(expert_parallel_degree > 1)." 
) + update_moe_aux_loss_configs( + self.layers, + pp_enabled=parallelism.pipeline_parallel_degree > 1, + global_batch_size=kwargs["global_batch_size"], + ) + if parallelism.context_parallel_degree > 1: raise NotImplementedError( "Context Parallel is not supported for Llama4 " diff --git a/torchtitan/models/qwen3/__init__.py b/torchtitan/models/qwen3/__init__.py index bf00f9a050..c10ef52dfb 100644 --- a/torchtitan/models/qwen3/__init__.py +++ b/torchtitan/models/qwen3/__init__.py @@ -22,6 +22,7 @@ make_moe_config, make_router_config, ) +from torchtitan.models.common.moe import BatchWiseAuxLoss from torchtitan.models.common.param_init import depth_scaled_std, skip_param_init from torchtitan.models.common.rmsnorm import RMSNorm from torchtitan.protocols.model_spec import ModelSpec @@ -172,6 +173,7 @@ def _build_qwen3_moe_layers( comm_backend=moe_comm_backend, non_blocking_capacity_factor=non_blocking_capacity_factor, ), + aux_loss=BatchWiseAuxLoss.Config(weight=1e-3), ), ) ) diff --git a/torchtitan/models/qwen3/model.py b/torchtitan/models/qwen3/model.py index 223983c89a..c1b409c85e 100644 --- a/torchtitan/models/qwen3/model.py +++ b/torchtitan/models/qwen3/model.py @@ -17,6 +17,7 @@ GQAttention, VarlenAttention, ) +from torchtitan.models.common.config_utils import update_moe_aux_loss_configs from torchtitan.models.common.decoder import Decoder, TransformerBlock from torchtitan.models.utils import get_moe_model_nparams_and_flops from torchtitan.tools.logging import logger @@ -120,6 +121,12 @@ def update_from_config( debug.moe_force_load_balance ) + update_moe_aux_loss_configs( + self.layers, + pp_enabled=parallelism.pipeline_parallel_degree > 1, + global_batch_size=kwargs["global_batch_size"], + ) + if parallelism.context_parallel_degree > 1 and isinstance( self.layers[0].attention.inner_attention, VarlenAttention.Config ): diff --git a/torchtitan/models/qwen3_vl/__init__.py b/torchtitan/models/qwen3_vl/__init__.py index 57d83028a8..42473ab77a 100644 --- 
a/torchtitan/models/qwen3_vl/__init__.py +++ b/torchtitan/models/qwen3_vl/__init__.py @@ -20,6 +20,7 @@ make_moe_config, make_router_config, ) +from torchtitan.models.common.moe import BatchWiseAuxLoss from torchtitan.models.common.param_init import depth_scaled_std, skip_param_init from torchtitan.models.common.rmsnorm import RMSNorm from torchtitan.models.qwen3.model import Qwen3TransformerBlock @@ -232,6 +233,7 @@ def _build_qwen3_vl_moe_layers( comm_backend=moe_comm_backend, non_blocking_capacity_factor=non_blocking_capacity_factor, ), + aux_loss=BatchWiseAuxLoss.Config(weight=1e-3), ), ) ) diff --git a/torchtitan/models/qwen3_vl/model.py b/torchtitan/models/qwen3_vl/model.py index 7ad1dd0541..ab8f4aef2a 100644 --- a/torchtitan/models/qwen3_vl/model.py +++ b/torchtitan/models/qwen3_vl/model.py @@ -12,6 +12,7 @@ from torch import nn from torchtitan.models.common.attention import AttentionMasksType, GQAttention +from torchtitan.models.common.config_utils import update_moe_aux_loss_configs from torchtitan.models.qwen3.model import Qwen3Model from torchtitan.models.utils import get_moe_model_nparams_and_flops from torchtitan.tools.logging import logger @@ -81,6 +82,12 @@ def update_from_config( debug.moe_force_load_balance ) + update_moe_aux_loss_configs( + self.layers, + pp_enabled=parallelism.pipeline_parallel_degree > 1, + global_batch_size=kwargs["global_batch_size"], + ) + tp = parallelism.tensor_parallel_degree if tp > 1: n_heads = self.layers[0].attention.n_heads diff --git a/torchtitan/trainer.py b/torchtitan/trainer.py index 5aa6e5792c..18382b25e5 100644 --- a/torchtitan/trainer.py +++ b/torchtitan/trainer.py @@ -241,11 +241,24 @@ def __init__(self, config: Config): local_batch_size=config.training.local_batch_size, ) + global_batch_size = config.training.global_batch_size + if global_batch_size < 0: + global_batch_size = config.training.local_batch_size * batch_degree + assert global_batch_size > 0 + assert ( + global_batch_size % 
(config.training.local_batch_size * batch_degree) == 0 + ), ( + f"global batch size must be multiple of local batch size times " + f"data-parallel degree ({global_batch_size} " + f"% ({config.training.local_batch_size} * {batch_degree}) != 0)" + ) + # build model (using meta init) model_config = model_spec.model # set the model args from training job configs model_config.update_from_config( trainer_config=config, + global_batch_size=global_batch_size, ) self.model_config = model_config @@ -302,21 +315,6 @@ def __init__(self, config: Config): compile_config=config.compile, ) - # verify batch sizes - global_batch_size = config.training.global_batch_size - if global_batch_size < 0: - # This global batch size results in 1 gradient accumulation - # step. - global_batch_size = config.training.local_batch_size * batch_degree - assert global_batch_size > 0 - assert ( - global_batch_size % (config.training.local_batch_size * batch_degree) == 0 - ), ( - f"global batch size must be multiple of local batch size times " - f"data-parallel degree ({global_batch_size} " - f"% ({config.training.local_batch_size} * {batch_degree}) != 0)" - ) - # calculate gradient accumulation steps self.gradient_accumulation_steps = global_batch_size // ( config.training.local_batch_size * batch_degree From e8fe522b19e5898a588d9d030febb10424beb10f Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Wed, 6 May 2026 23:36:00 -0700 Subject: [PATCH 2/2] refs --- torchtitan/models/gpt_oss/__init__.py | 4 ++++ torchtitan/models/llama4/__init__.py | 3 +++ torchtitan/models/qwen3/__init__.py | 8 +++++++- torchtitan/models/qwen3_vl/__init__.py | 8 +++++++- 4 files changed, 21 insertions(+), 2 deletions(-) diff --git a/torchtitan/models/gpt_oss/__init__.py b/torchtitan/models/gpt_oss/__init__.py index 128a7acae6..96764f899c 100644 --- a/torchtitan/models/gpt_oss/__init__.py +++ b/torchtitan/models/gpt_oss/__init__.py @@ -167,6 +167,10 @@ def _build_gptoss_layers( Even-indexed layers (0, 2, 4, ...) 
use sliding window attention. All dimensional fields are set directly. + + Aux loss ref: + - load_balancing_loss_func (batch-wise): https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_oss/modeling_gpt_oss.py + - https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_oss/configuration_gpt_oss.py """ layers = [] for layer_id in range(n_layers): diff --git a/torchtitan/models/llama4/__init__.py b/torchtitan/models/llama4/__init__.py index 6022ddf04d..ee170f654f 100644 --- a/torchtitan/models/llama4/__init__.py +++ b/torchtitan/models/llama4/__init__.py @@ -94,6 +94,9 @@ def _build_llama4_layers( Handles iRoPE (NoPE on every N layers) and MoE interleaving. For each layer, depth-scaled inits are computed using the layer index. + + Aux loss ref: + - https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama4/configuration_llama4.py """ inner_attention, mask_type = get_attention_config(attn_backend) if every_n_layers_nope <= 1: diff --git a/torchtitan/models/qwen3/__init__.py b/torchtitan/models/qwen3/__init__.py index c10ef52dfb..85c3239a17 100644 --- a/torchtitan/models/qwen3/__init__.py +++ b/torchtitan/models/qwen3/__init__.py @@ -133,7 +133,13 @@ def _build_qwen3_moe_layers( moe_comm_backend: str, non_blocking_capacity_factor: float | None = None, ) -> list[TransformerBlock.Config]: - """Build per-layer configs for MoE Qwen3 models with depth-scaled inits.""" + """ + Build per-layer configs for MoE Qwen3 models with depth-scaled inits. 
+ + Aux loss ref: + - load_balancing_loss_func (batch-wise): https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py + - https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_moe/configuration_qwen3_moe.py + """ inner_attention, mask_type = get_attention_config(attn_backend) layers = [] for layer_id in range(n_layers): diff --git a/torchtitan/models/qwen3_vl/__init__.py b/torchtitan/models/qwen3_vl/__init__.py index 42473ab77a..d433da92e8 100644 --- a/torchtitan/models/qwen3_vl/__init__.py +++ b/torchtitan/models/qwen3_vl/__init__.py @@ -193,7 +193,13 @@ def _build_qwen3_vl_moe_layers( moe_comm_backend: str, non_blocking_capacity_factor: float | None = None, ) -> list[TransformerBlock.Config]: - """Build per-layer configs for MoE Qwen3-VL models with depth-scaled inits.""" + """ + Build per-layer configs for MoE Qwen3-VL models with depth-scaled inits. + + Aux loss ref: + - load_balancing_loss_func (batch-wise): https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py + - https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py + """ inner_attention, mask_type = get_attention_config(attn_backend) layers = [] for layer_id in range(n_layers):