torchtitan/models/common/__init__.py (4 additions, 1 deletion)

@@ -24,7 +24,7 @@
from .embedding import Embedding
from .feed_forward import compute_ffn_hidden_dim, FeedForward
from .linear import Linear
from .moe import MoE
from .moe import BatchWiseAuxLoss, MoE, MoELoadBalanceAuxLoss, SequenceWiseAuxLoss
from .rmsnorm import RMSNorm
from .rope import (
apply_rotary_emb_complex,
@@ -48,7 +48,10 @@
"get_sliding_window_mask_mod",
"GQAttention",
"Linear",
"BatchWiseAuxLoss",
"MoELoadBalanceAuxLoss",
"MoE",
"SequenceWiseAuxLoss",
"QKVLinear",
"RMSNorm",
"RoPE",
torchtitan/models/common/config_utils.py (39 additions, 1 deletion)

@@ -11,6 +11,7 @@
"""

from collections.abc import Callable
from dataclasses import replace
from typing import Literal

from torchtitan.models.common.attention import (
@@ -23,7 +24,13 @@
)
from torchtitan.models.common.feed_forward import FeedForward
from torchtitan.models.common.linear import Linear
from torchtitan.models.common.moe import GroupedExperts, MoE, TokenChoiceTopKRouter
from torchtitan.models.common.moe import (
BatchWiseAuxLoss,
GroupedExperts,
MoE,
MoELoadBalanceAuxLoss,
TokenChoiceTopKRouter,
)
from torchtitan.models.common.rmsnorm import RMSNorm
from torchtitan.models.common.token_dispatcher import (
AllToAllTokenDispatcher,
@@ -152,14 +159,18 @@ def make_moe_config(
experts: GroupedExperts.Config,
shared_experts: FeedForward.Config | None = None,
load_balance_coeff: float | None = 1e-3,
aux_loss: MoELoadBalanceAuxLoss.Config | None = None,
) -> MoE.Config:
"""Build a fully-specified MoE.Config."""
if aux_loss:
aux_loss = replace(aux_loss, top_k=router.top_k)
return MoE.Config(
num_experts=num_experts,
load_balance_coeff=load_balance_coeff,
router=router,
experts=experts,
shared_experts=shared_experts,
aux_loss=aux_loss,
)
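
A minimal usage sketch (the concrete values and the pre-built router/experts configs are illustrative assumptions, not defaults from this PR):

```python
# Hypothetical example: router_config and experts_config are assumed to have
# been built elsewhere (e.g. experts via the make_experts_config helper below).
moe_config = make_moe_config(
    num_experts=64,
    router=router_config,
    experts=experts_config,
    aux_loss=SequenceWiseAuxLoss.Config(weight=1e-4),
)
# make_moe_config copies router.top_k into aux_loss via dataclasses.replace,
# so the two settings cannot drift apart.
```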


@@ -265,3 +276,30 @@ def make_experts_config(
non_blocking_capacity_factor=non_blocking_capacity_factor,
),
)


def update_moe_aux_loss_configs(
layers: list,
*,
pp_enabled: bool,
global_batch_size: int,
) -> None:
"""Update aux loss configs for all MoE layers from trainer config.

Call this from each model's ``update_from_config`` to set
``global_batch_size`` for gradient normalization across PP microbatches,
gradient accumulation steps, and DP ranks (FSDP sum). Also validates
that batch-wise aux loss is not used with PP.
"""
for layer_cfg in layers:
moe_cfg = getattr(layer_cfg, "moe", None)
if moe_cfg is None:
continue
aux_loss = moe_cfg.aux_loss
if aux_loss is not None and aux_loss.weight > 0:
if isinstance(aux_loss, BatchWiseAuxLoss.Config) and pp_enabled:
raise ValueError(
"batch_wise MoE aux loss is incompatible with pipeline "
"parallelism. Use sequence_wise instead."
)
aux_loss.global_batch_size = global_batch_size
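
The intended call site, per the docstring, is each model's ``update_from_config``. A minimal sketch (the surrounding method and the exact job-config field paths are assumptions about torchtitan's config layout, not part of this diff):

```python
# Hypothetical model-args method; field names on job_config are illustrative.
def update_from_config(self, job_config) -> None:
    update_moe_aux_loss_configs(
        self.layers,
        pp_enabled=job_config.parallelism.pipeline_parallel_degree > 1,
        global_batch_size=job_config.training.global_batch_size,
    )
```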
torchtitan/models/common/moe.py (204 additions, 4 deletions)

@@ -12,6 +12,7 @@
from torch import nn
from torch.distributed.tensor import DTensor, Partial

from torchtitan.config import Configurable
from torchtitan.models.common.feed_forward import FeedForward
from torchtitan.models.common.linear import Linear

@@ -140,6 +141,118 @@ def forward(
return self.token_dispatcher.combine(routed_output, metadata, x, shared_experts)


def _sequence_wise_aux_loss(
scores: torch.Tensor,
selected_experts_indices: torch.Tensor,
bs: int,
slen: int,
top_k: int,
aux_loss_weight: float,
) -> torch.Tensor:
"""Sequence-wise auxiliary load-balance loss (DeepSeek-V3 Eqs 17-20)."""
num_experts = scores.size(-1)
scores_per_seq = scores.view(bs, slen, num_experts)
denom = scores_per_seq.sum(dim=-1, keepdim=True) + 1e-20
probs_per_seq = scores_per_seq / denom
p_i = probs_per_seq.mean(dim=1)
indices_per_seq = selected_experts_indices.view(bs, -1)
offset = torch.arange(bs, device=indices_per_seq.device).unsqueeze(1) * num_experts
flat_indices = (indices_per_seq + offset).reshape(-1)
counts = torch.bincount(flat_indices.long(), minlength=bs * num_experts)

Contributor: Can you double check whether torch.bincount leads to a graph break under
torch.compile? Maybe we need .scatter_add on a sparse tensor.

Author: This runs under compile; I believe capture_scalar_outputs enables capturing the
output shape.
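
If ``bincount`` did turn out to graph-break, a dense ``scatter_add_`` count would be a
drop-in alternative; a sketch (my example, not part of this PR):

```python
# Produces the same counts as torch.bincount(flat_indices.long(), minlength=bs * num_experts)
counts = torch.zeros(
    bs * num_experts, dtype=torch.long, device=flat_indices.device
).scatter_add_(
    0, flat_indices.long(), torch.ones_like(flat_indices, dtype=torch.long)
)
```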

counts = counts.reshape(bs, num_experts).to(dtype=scores.dtype)
f_i = counts * (num_experts / (top_k * slen))
return (f_i * p_i).sum(dim=1).mean() * aux_loss_weight
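
As a reading aid, the per-sequence loss this function computes, in my paraphrase of the
cited DeepSeek-V3 equations (with N_r = num_experts, K_r = top_k, T = slen,
alpha = aux_loss_weight; the result is then averaged over the bs sequences):

```latex
\mathcal{L}_{\mathrm{Bal}} = \alpha \sum_{i=1}^{N_r} f_i \, P_i, \qquad
f_i = \frac{N_r}{K_r T} \sum_{t=1}^{T} \mathbb{1}\!\left[ i \in \mathrm{TopK}_t \right], \qquad
P_i = \frac{1}{T} \sum_{t=1}^{T} s'_{i,t}, \qquad
s'_{i,t} = \frac{s_{i,t}}{\sum_{j=1}^{N_r} s_{j,t}}
```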


def _batch_wise_aux_loss(
scores: torch.Tensor,
selected_experts_indices: torch.Tensor,
bs: int,
slen: int,
top_k: int,
aux_loss_weight: float,
) -> torch.Tensor:
"""Batch-wise auxiliary load-balance loss."""
num_experts = scores.size(-1)
total_tokens = scores.size(0)
num_tokens_per_expert = torch.histc(
selected_experts_indices.view(-1).float(),
bins=num_experts,
min=0,
max=num_experts,
)
p_i = scores.mean(dim=0)
f_i = num_tokens_per_expert.to(scores.dtype) * (
num_experts / (top_k * total_tokens)
)
return (f_i * p_i).sum() * aux_loss_weight
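
For contrast, the batch-wise variant pools the same statistic over all
T_total = bs * slen tokens of the microbatch, and takes P_i directly as the mean router
score with no per-sequence renormalization (my notation):

```latex
f_i = \frac{N_r}{K_r T_{\mathrm{total}}} \bigl|\{\, t : i \in \mathrm{TopK}_t \,\}\bigr|, \qquad
P_i = \frac{1}{T_{\mathrm{total}}} \sum_{t=1}^{T_{\mathrm{total}}} s_{i,t}, \qquad
\mathcal{L}_{\mathrm{Bal}} = \alpha \sum_{i=1}^{N_r} f_i \, P_i
```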


class _AuxLossBase(torch.autograd.Function):
"""Injects auxiliary load-balance loss gradients at the router scores level.

Identity in forward (returns ``top_scores`` unchanged). In backward,
computes ``d(aux_loss)/d(scores)`` via ``torch.func.grad`` and adds it
to ``scores``'s gradient. ``top_scores`` is a pass-through so this node
remains in the autograd graph.
"""

Contributor (on "Identity in forward (returns top_scores unchanged)"): Is this even a goal?

Author: This is so it stays on the autograd graph; will update the docstring.

@staticmethod
def forward( # pyrefly: ignore [bad-override]
ctx: torch.autograd.function.FunctionCtx,
top_scores: torch.Tensor,
scores: torch.Tensor,
selected_experts_indices: torch.Tensor,
bs: int,
slen: int,
top_k: int,
aux_loss_weight: float,
) -> torch.Tensor:
ctx.save_for_backward(scores, selected_experts_indices)
ctx.bs = bs # pyrefly: ignore [missing-attribute]
ctx.slen = slen # pyrefly: ignore [missing-attribute]
ctx.top_k = top_k # pyrefly: ignore [missing-attribute]
ctx.aux_loss_weight = aux_loss_weight # pyrefly: ignore [missing-attribute]
return top_scores

@staticmethod
def _backward_impl(loss_fn, ctx, grad_top_scores):
(
scores,
selected_experts_indices,
) = ctx.saved_tensors

Contributor: Maybe add a one-line comment if we want to be explicit.

Suggested change:
  ) = ctx.saved_tensors
+ # torch.func.grad avoids the graph break that torch.autograd.grad causes under torch.compile

# torch.func.grad avoids the graph break that torch.autograd.grad causes under torch.compile
aux_grad = torch.func.grad(loss_fn)(
scores,
selected_experts_indices,
ctx.bs,
ctx.slen,
ctx.top_k,
ctx.aux_loss_weight,
)
return grad_top_scores, aux_grad, None, None, None, None, None


class _SequenceWiseAuxLoss(_AuxLossBase):
@staticmethod
def backward( # pyrefly: ignore [bad-override]
ctx: torch.autograd.function.FunctionCtx,
grad_top_scores: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, None, None, None, None, None]:
return _AuxLossBase._backward_impl(
_sequence_wise_aux_loss, ctx, grad_top_scores
)


class _BatchWiseAuxLoss(_AuxLossBase):
@staticmethod
def backward( # pyrefly: ignore [bad-override]
ctx: torch.autograd.function.FunctionCtx,
grad_top_scores: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, None, None, None, None, None]:
return _AuxLossBase._backward_impl(_batch_wise_aux_loss, ctx, grad_top_scores)
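
The pattern both subclasses rely on, reduced to a self-contained toy (my example,
independent of this PR): the forward pass is an identity on one tensor, and the backward
pass injects an extra gradient into a second tensor.

```python
import torch

class _AddAuxGrad(torch.autograd.Function):
    """Identity on `passthrough`; backward also injects d(aux)/d(x) into x."""

    @staticmethod
    def forward(ctx, passthrough, x):
        ctx.save_for_backward(x)
        return passthrough

    @staticmethod
    def backward(ctx, grad_passthrough):
        (x,) = ctx.saved_tensors
        # Toy aux loss: sum(x**2). torch.func.grad differentiates it
        # functionally, mirroring _backward_impl above.
        aux_grad = torch.func.grad(lambda t: (t * t).sum())(x)
        return grad_passthrough, aux_grad

x = torch.randn(4, requires_grad=True)
y = _AddAuxGrad.apply(x.sum(), x)  # forward value is just x.sum()
y.backward()
# x.grad = 1 (from x.sum()) + 2*x (the injected aux gradient)
assert torch.allclose(x.grad, torch.ones_like(x) + 2 * x)
```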


class TokenChoiceTopKRouter(Module):
"""This class implements token-choice routing. In token-choice top-K routing, each token is
routed to top K experts based on the router scores.
@@ -233,16 +346,21 @@ def _get_node_limited_routing_scores(
return scores_for_choice

def forward(
self, x: torch.Tensor, expert_bias: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
self,
x: torch.Tensor,
expert_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``.
expert_bias (torch.Tensor | None, optional): Optional bias tensor for experts with shape ``(num_experts,)``.
Used for load balancing. Defaults to None.

Returns:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
- scores (torch.Tensor):
Full router scores for all experts with shape ``(bs*slen, num_experts)``.
Returned for load-balancing aux loss computation.
- top_scores (torch.Tensor):
Routing scores for selected experts with shape ``(bs*slen, top_k)``.
- selected_experts_indices (torch.Tensor):
@@ -297,7 +415,80 @@ def forward(
max=self.num_experts,
)

return top_scores, selected_experts_indices, num_tokens_per_expert
return scores, top_scores, selected_experts_indices, num_tokens_per_expert


class MoELoadBalanceAuxLoss(Configurable):
"""MoE auxiliary load-balance loss.

Injects aux-loss gradients into the router scores without changing the model
output (PP-safe). Call an instance with the router outputs to apply it.

Subclass to select loss variant (sequence-wise or batch-wise).
"""

_autograd_fn: type[_AuxLossBase]

@dataclass(kw_only=True, slots=True)
class Config(Configurable.Config):
weight: float = 0.0
"""Weight for the auxiliary load-balance loss. 0 disables it."""
top_k: int = 1
"""Number of experts per token. Must match the router's top_k."""
global_batch_size: int | None = None
"""Global batch size used to normalize the aux loss across microbatches.
torchtitan generally aggregates gradients by sum and divides by the global
batch size. Because BatchWiseAuxLoss does not commute over microbatches,
combining batch-wise aux loss with pipeline parallelism is rejected."""

def __init__(self, config: Config):
if config.weight < 0:
raise ValueError(f"aux_loss.weight must be >= 0, got {config.weight}")
self.weight = config.weight
self.global_batch_size = config.global_batch_size
self.top_k = config.top_k

def __call__(
self,
top_scores: torch.Tensor,
scores: torch.Tensor,
selected_experts_indices: torch.Tensor,
bs: int,
slen: int,
) -> torch.Tensor:
if self.weight == 0:
return top_scores
global_bs = self.global_batch_size or bs
scaled_weight = self.weight * bs / global_bs
return self._autograd_fn.apply(
top_scores,
scores,
selected_experts_indices,
bs,
slen,
self.top_k,
scaled_weight,
)
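
The ``bs / global_bs`` scaling makes the injected gradients sum correctly: with
gradients summed across microbatches, accumulation steps, and DP ranks (as the Config
docstring states), e.g. ``global_batch_size=256`` split into eight passes of ``bs=32``,
each pass injects ``weight * 32/256 = weight/8`` of its per-sequence-mean gradient, and
the eight contributions sum to exactly one ``weight``-scaled mean over all 256 sequences.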


class SequenceWiseAuxLoss(MoELoadBalanceAuxLoss):
"""Sequence-wise auxiliary load-balance loss (DeepSeek-V3 Eqs 17-20)."""

_autograd_fn = _SequenceWiseAuxLoss

@dataclass(kw_only=True, slots=True)
class Config(MoELoadBalanceAuxLoss.Config):
pass


class BatchWiseAuxLoss(MoELoadBalanceAuxLoss):
"""Batch-wise auxiliary load-balance loss."""

_autograd_fn = _BatchWiseAuxLoss

@dataclass(kw_only=True, slots=True)
class Config(MoELoadBalanceAuxLoss.Config):
pass


class MoE(Module):
@@ -325,6 +516,7 @@ class Config(Module.Config):
router: TokenChoiceTopKRouter.Config
load_balance_coeff: float | None = 1e-3
shared_experts: FeedForward.Config | None = None
aux_loss: MoELoadBalanceAuxLoss.Config | None = None

def __init__(self, config: Config):
super().__init__()
@@ -335,6 +527,7 @@ def __init__(self, config: Config):
self.shared_experts = (
config.shared_experts.build() if config.shared_experts is not None else None
)
self.aux_loss = config.aux_loss.build() if config.aux_loss is not None else None

# define fields for auxiliary-loss-free load balancing (https://arxiv.org/abs/2408.15664)
# NOTE: tokens_per_expert is accumulated in the model forward pass.
@@ -391,14 +584,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
bs, slen, dim = x.shape
x = x.view(-1, dim)

# scores shape (bs*slen, num_experts)
# top_scores and selected_experts_indices shape (bs*slen, top_k)
# num_tokens_per_expert shape (num_experts,)
(
scores,
top_scores,
selected_experts_indices,
num_tokens_per_expert,
) = self.router(x, self.expert_bias)

if self.training and self.aux_loss is not None:
top_scores = self.aux_loss(
top_scores, scores, selected_experts_indices, bs, slen
)

# tokens_per_expert will be used to update the expert bias for load balancing
# and also to count expert usage.
# TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert --