-
Notifications
You must be signed in to change notification settings - Fork 809
[moe] load-balancing aux loss #3000
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -12,6 +12,7 @@ | |||||||
| from torch import nn | ||||||||
| from torch.distributed.tensor import DTensor, Partial | ||||||||
|
|
||||||||
| from torchtitan.config import Configurable | ||||||||
| from torchtitan.models.common.feed_forward import FeedForward | ||||||||
| from torchtitan.models.common.linear import Linear | ||||||||
|
|
||||||||
|
|
@@ -140,6 +141,118 @@ def forward( | |||||||
| return self.token_dispatcher.combine(routed_output, metadata, x, shared_experts) | ||||||||
|
|
||||||||
|
|
||||||||
| def _sequence_wise_aux_loss( | ||||||||
| scores: torch.Tensor, | ||||||||
| selected_experts_indices: torch.Tensor, | ||||||||
| bs: int, | ||||||||
| slen: int, | ||||||||
| top_k: int, | ||||||||
| aux_loss_weight: float, | ||||||||
| ) -> torch.Tensor: | ||||||||
| """Sequence-wise auxiliary load-balance loss (DeepSeek-V3 Eqs 17-20).""" | ||||||||
| num_experts = scores.size(-1) | ||||||||
| scores_per_seq = scores.view(bs, slen, num_experts) | ||||||||
| denom = scores_per_seq.sum(dim=-1, keepdim=True) + 1e-20 | ||||||||
| probs_per_seq = scores_per_seq / denom | ||||||||
| p_i = probs_per_seq.mean(dim=1) | ||||||||
| indices_per_seq = selected_experts_indices.view(bs, -1) | ||||||||
| offset = torch.arange(bs, device=indices_per_seq.device).unsqueeze(1) * num_experts | ||||||||
| flat_indices = (indices_per_seq + offset).reshape(-1) | ||||||||
| counts = torch.bincount(flat_indices.long(), minlength=bs * num_experts) | ||||||||
| counts = counts.reshape(bs, num_experts).to(dtype=scores.dtype) | ||||||||
| f_i = counts * (num_experts / (top_k * slen)) | ||||||||
| return (f_i * p_i).sum(dim=1).mean() * aux_loss_weight | ||||||||
|
|
||||||||
|
|
||||||||
| def _batch_wise_aux_loss( | ||||||||
| scores: torch.Tensor, | ||||||||
| selected_experts_indices: torch.Tensor, | ||||||||
| bs: int, | ||||||||
| slen: int, | ||||||||
| top_k: int, | ||||||||
| aux_loss_weight: float, | ||||||||
| ) -> torch.Tensor: | ||||||||
| """Batch-wise auxiliary load-balance loss.""" | ||||||||
| num_experts = scores.size(-1) | ||||||||
| total_tokens = scores.size(0) | ||||||||
| num_tokens_per_expert = torch.histc( | ||||||||
| selected_experts_indices.view(-1).float(), | ||||||||
| bins=num_experts, | ||||||||
| min=0, | ||||||||
| max=num_experts, | ||||||||
| ) | ||||||||
| p_i = scores.mean(dim=0) | ||||||||
| f_i = num_tokens_per_expert.to(scores.dtype) * ( | ||||||||
| num_experts / (top_k * total_tokens) | ||||||||
| ) | ||||||||
| return (f_i * p_i).sum() * aux_loss_weight | ||||||||
|
|
||||||||
|
|
||||||||
| class _AuxLossBase(torch.autograd.Function): | ||||||||
| """Injects auxiliary load-balance loss gradients at the router scores level. | ||||||||
|
|
||||||||
| Identity in forward (returns ``top_scores`` unchanged). In backward, | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Is this even a goal?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is so it stays on the autograd graph, will update the docstring |
||||||||
| computes ``d(aux_loss)/d(scores)`` via ``torch.func.grad`` and adds it | ||||||||
| to ``scores``'s gradient. ``top_scores`` is a pass-through so this node | ||||||||
| remains in the autograd graph. | ||||||||
| """ | ||||||||
|
|
||||||||
| @staticmethod | ||||||||
| def forward( # pyrefly: ignore [bad-override] | ||||||||
| ctx: torch.autograd.function.FunctionCtx, | ||||||||
| top_scores: torch.Tensor, | ||||||||
| scores: torch.Tensor, | ||||||||
| selected_experts_indices: torch.Tensor, | ||||||||
| bs: int, | ||||||||
| slen: int, | ||||||||
| top_k: int, | ||||||||
| aux_loss_weight: float, | ||||||||
| ) -> torch.Tensor: | ||||||||
| ctx.save_for_backward(scores, selected_experts_indices) | ||||||||
| ctx.bs = bs # pyrefly: ignore [missing-attribute] | ||||||||
| ctx.slen = slen # pyrefly: ignore [missing-attribute] | ||||||||
| ctx.top_k = top_k # pyrefly: ignore [missing-attribute] | ||||||||
| ctx.aux_loss_weight = aux_loss_weight # pyrefly: ignore [missing-attribute] | ||||||||
| return top_scores | ||||||||
|
|
||||||||
| @staticmethod | ||||||||
| def _backward_impl(loss_fn, ctx, grad_top_scores): | ||||||||
| ( | ||||||||
| scores, | ||||||||
| selected_experts_indices, | ||||||||
| ) = ctx.saved_tensors | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe add one line comment if we want to be explicit.
Suggested change
|
||||||||
| # torch.func.grad avoids the graph break that torch.autograd.grad causes under torch.compile | ||||||||
| aux_grad = torch.func.grad(loss_fn)( | ||||||||
| scores, | ||||||||
| selected_experts_indices, | ||||||||
| ctx.bs, | ||||||||
| ctx.slen, | ||||||||
| ctx.top_k, | ||||||||
| ctx.aux_loss_weight, | ||||||||
| ) | ||||||||
| return grad_top_scores, aux_grad, None, None, None, None, None | ||||||||
|
|
||||||||
|
|
||||||||
class _SequenceWiseAuxLoss(_AuxLossBase):
    """Autograd node wiring the sequence-wise aux loss into the base class."""

    @staticmethod
    def backward(  # pyrefly: ignore [bad-override]
        ctx: torch.autograd.function.FunctionCtx,
        grad_top_scores: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, None, None, None, None, None]:
        # Delegate to the shared backward with the sequence-wise loss fn.
        return _AuxLossBase._backward_impl(
            _sequence_wise_aux_loss, ctx, grad_top_scores
        )
|
|
||||||||
|
|
||||||||
class _BatchWiseAuxLoss(_AuxLossBase):
    """Autograd node wiring the batch-wise aux loss into the base class."""

    @staticmethod
    def backward(  # pyrefly: ignore [bad-override]
        ctx: torch.autograd.function.FunctionCtx,
        grad_top_scores: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, None, None, None, None, None]:
        # Delegate to the shared backward with the batch-wise loss fn.
        return _AuxLossBase._backward_impl(
            _batch_wise_aux_loss, ctx, grad_top_scores
        )
|
|
||||||||
|
|
||||||||
| class TokenChoiceTopKRouter(Module): | ||||||||
| """This class implements token-choice routing. In token-choice top-K routing, each token is | ||||||||
| routed to top K experts based on the router scores. | ||||||||
|
|
@@ -233,16 +346,21 @@ def _get_node_limited_routing_scores( | |||||||
| return scores_for_choice | ||||||||
|
|
||||||||
| def forward( | ||||||||
| self, x: torch.Tensor, expert_bias: torch.Tensor | None = None | ||||||||
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | ||||||||
| self, | ||||||||
| x: torch.Tensor, | ||||||||
| expert_bias: torch.Tensor | None = None, | ||||||||
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | ||||||||
| """ | ||||||||
| Args: | ||||||||
| x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``. | ||||||||
| expert_bias (torch.Tensor | None, optional): Optional bias tensor for experts with shape ``(num_experts,)``. | ||||||||
| Used for load balancing. Defaults to None. | ||||||||
|
|
||||||||
| Returns: | ||||||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | ||||||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | ||||||||
| - scores (torch.Tensor): | ||||||||
| Full router scores for all experts with shape ``(bs*slen, num_experts)``. | ||||||||
| Returned for load-balancing aux loss computation. | ||||||||
| - top_scores (torch.Tensor): | ||||||||
| Routing scores for selected experts with shape ``(bs*slen, top_k)``. | ||||||||
| - selected_experts_indices (torch.Tensor): | ||||||||
|
|
@@ -297,7 +415,80 @@ def forward( | |||||||
| max=self.num_experts, | ||||||||
| ) | ||||||||
|
|
||||||||
| return top_scores, selected_experts_indices, num_tokens_per_expert | ||||||||
| return scores, top_scores, selected_experts_indices, num_tokens_per_expert | ||||||||
|
|
||||||||
|
|
||||||||
class MoELoadBalanceAuxLoss(Configurable):
    """MoE auxiliary load-balance loss.

    Injects aux loss gradients into router scores without changing the model
    output (PP-safe). Call the instance with the router outputs to apply it.

    Subclasses select the loss variant (sequence-wise or batch-wise) by
    setting ``_autograd_fn``.
    """

    # Autograd function implementing the chosen loss variant; set by subclasses.
    _autograd_fn: type[_AuxLossBase]

    @dataclass(kw_only=True, slots=True)
    class Config(Configurable.Config):
        weight: float = 0.0
        """Weight for the auxiliary load-balance loss. 0 disables it."""
        top_k: int = 1
        """Number of experts per token. Must match the router's top_k."""
        global_batch_size: int | None = None
        """Global batch size to normalize aux loss across all batches/microbatches.
        Torchtitan generally uses sum aggregation & global_bs division.
        Because BatchWiseAuxLoss doesn't commute over microbatches, we ban
        batch-wise + PP usage in any case."""

    def __init__(self, config: Config):
        # Negative weights are nonsensical; fail fast at construction time.
        if config.weight < 0:
            raise ValueError(f"aux_loss.weight must be >= 0, got {config.weight}")
        self.top_k = config.top_k
        self.weight = config.weight
        self.global_batch_size = config.global_batch_size

    def __call__(
        self,
        top_scores: torch.Tensor,
        scores: torch.Tensor,
        selected_experts_indices: torch.Tensor,
        bs: int,
        slen: int,
    ) -> torch.Tensor:
        # Zero weight disables the loss; skip inserting the autograd node.
        if self.weight == 0:
            return top_scores
        # Rescale so that summing per-microbatch losses and dividing by the
        # global batch size yields the intended overall weight.
        norm_bs = self.global_batch_size or bs
        effective_weight = self.weight * bs / norm_bs
        return self._autograd_fn.apply(
            top_scores,
            scores,
            selected_experts_indices,
            bs,
            slen,
            self.top_k,
            effective_weight,
        )
|
|
||||||||
|
|
||||||||
class SequenceWiseAuxLoss(MoELoadBalanceAuxLoss):
    """Sequence-wise auxiliary load-balance loss (DeepSeek-V3 Eqs 17-20)."""

    # Bind the sequence-wise autograd node; everything else is inherited.
    _autograd_fn = _SequenceWiseAuxLoss

    @dataclass(kw_only=True, slots=True)
    class Config(MoELoadBalanceAuxLoss.Config):
        pass
|
|
||||||||
|
|
||||||||
class BatchWiseAuxLoss(MoELoadBalanceAuxLoss):
    """Batch-wise auxiliary load-balance loss."""

    # Bind the batch-wise autograd node; everything else is inherited.
    _autograd_fn = _BatchWiseAuxLoss

    @dataclass(kw_only=True, slots=True)
    class Config(MoELoadBalanceAuxLoss.Config):
        pass
|
|
||||||||
|
|
||||||||
| class MoE(Module): | ||||||||
|
|
@@ -325,6 +516,7 @@ class Config(Module.Config): | |||||||
| router: TokenChoiceTopKRouter.Config | ||||||||
| load_balance_coeff: float | None = 1e-3 | ||||||||
| shared_experts: FeedForward.Config | None = None | ||||||||
| aux_loss: MoELoadBalanceAuxLoss.Config | None = None | ||||||||
|
|
||||||||
| def __init__(self, config: Config): | ||||||||
| super().__init__() | ||||||||
|
|
@@ -335,6 +527,7 @@ def __init__(self, config: Config): | |||||||
| self.shared_experts = ( | ||||||||
| config.shared_experts.build() if config.shared_experts is not None else None | ||||||||
| ) | ||||||||
| self.aux_loss = config.aux_loss.build() if config.aux_loss is not None else None | ||||||||
|
|
||||||||
| # define fields for auxiliary-loss-free load balancing (https://arxiv.org/abs/2408.15664) | ||||||||
| # NOTE: tokens_per_expert is accumulated in the model forward pass. | ||||||||
|
|
@@ -391,14 +584,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: | |||||||
| bs, slen, dim = x.shape | ||||||||
| x = x.view(-1, dim) | ||||||||
|
|
||||||||
| # scores shape (bs*slen, num_experts) | ||||||||
| # top_scores and selected_experts_indices shape (bs*slen, top_k) | ||||||||
| # num_tokens_per_expert shape (num_experts,) | ||||||||
| ( | ||||||||
| scores, | ||||||||
| top_scores, | ||||||||
| selected_experts_indices, | ||||||||
| num_tokens_per_expert, | ||||||||
| ) = self.router(x, self.expert_bias) | ||||||||
|
|
||||||||
| if self.training and self.aux_loss is not None: | ||||||||
| top_scores = self.aux_loss( | ||||||||
| top_scores, scores, selected_experts_indices, bs, slen | ||||||||
| ) | ||||||||
|
|
||||||||
| # tokens_per_expert will be used to update the expert bias for load balancing. | ||||||||
| # and also to count the expert usage | ||||||||
| # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert -- | ||||||||
|
|
||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you double check whether
`torch.bincount` leads to a graph break under torch.compile? Maybe we need `.scatter_add` on a sparse tensor. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This runs under compile, I believe —
`capture_scalar_outputs` enables capturing the output shape.