
[MoE] Pad token count to a multiple of sp_size in AllToAllTokenDispatcher#3193

Merged
wwwjn merged 6 commits into main from gh/wwwjn/18/head
May 5, 2026

Conversation

@wwwjn
Contributor

@wwwjn wwwjn commented May 1, 2026

Stack from ghstack (oldest at bottom):

Today the dispatcher's _split_along_sp() raises when num_tokens (bs * slen) is not divisible by sp_size (= TP degree). Real workloads with varlen prompts can land on non-divisible totals and crash the MoE forward.

Pad inside dispatch(): round num_tokens up to the next multiple of sp_size, padding x and top_scores with zeros and selected_experts_indices with 0 (so pad rows route deterministically to expert 0 with zero score). combine() reads metadata.original_num_tokens to size the scatter_add buffer at the padded length and slices the pad rows off before returning. When sp_size == 1 or input is already divisible, behavior is bitwise identical to today.
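The round-up arithmetic is straightforward; a minimal sketch in plain Python (the helper name is illustrative, not the PR's actual function):

```python
def pad_to_multiple(num_tokens: int, sp_size: int) -> int:
    """Round num_tokens up to the next multiple of sp_size."""
    return (num_tokens + sp_size - 1) // sp_size * sp_size

# A varlen batch of 10 tokens on sp_size=4 pads to 12; already-divisible
# inputs and sp_size == 1 are left untouched, matching today's behavior.
print(pad_to_multiple(10, 4))  # 12
print(pad_to_multiple(8, 4))   # 8
print(pad_to_multiple(7, 1))   # 7
```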

Pad tokens are numerically inert:

  • Zero scores -> contribution to scatter_add is exactly zero either before or after expert compute (independent of score_before_experts).
  • Pad indices fall in [original, padded), which is sliced off after scatter_add, so they never appear in the returned output.
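A toy accumulation in plain Python (hypothetical shapes and values, not the dispatcher's real tensors) illustrates both points: the pad row carries zero score, and its index falls in [original, padded), so it vanishes after the final slice:

```python
original_num_tokens, padded_num_tokens = 2, 3
hidden_dim = 2

# Expert outputs for 3 dispatched rows; the last row is a pad row.
expert_out = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
scores = [0.5, 0.25, 0.0]   # pad row has zero score
token_indices = [0, 1, 2]   # pad row's index falls in [2, 3)

# scatter_add-style accumulation into a buffer sized at the *padded* length.
out = [[0.0] * hidden_dim for _ in range(padded_num_tokens)]
for row, dst in enumerate(token_indices):
    for d in range(hidden_dim):
        out[dst][d] += scores[row] * expert_out[row][d]

out = out[:original_num_tokens]  # slice the pad rows off
print(out)  # [[0.5, 1.0], [0.75, 1.0]] -- the pad row never appears
```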

Trainer/generator can pad by different amounts depending on their batch shapes; the unpadded portions remain bitwise identical.

  • TorchAOTokenDispatcher inherits dispatch/combine and gets the fix for free.
  • DeepEPTokenDispatcher uses a separate metadata type and is unaffected.

@wwwjn wwwjn requested review from fegin, tianyu-l and wconstab as code owners May 1, 2026 01:44
@meta-cla meta-cla Bot added the CLA Signed label May 1, 2026
@wwwjn wwwjn requested a review from acisseJZhong May 1, 2026 01:54
Comment on lines +472 to 475
local_num_tokens = padded_num_tokens // self.sp_size
token_indices_experts_sorted = (
    metadata.token_indices_experts_sorted + local_num_tokens * self.sp_rank
)
Contributor


I understand that you need the padding info / original_num_tokens for computing indices, but why do you need the padding in the other places in combine?

Contributor Author


This is because, with padded tokens, in out.scatter_add(dim, index, src) the range of index values is [0, num_tokens + num_padded_tokens). But the size of out would still be (num_tokens, hidden_dim), so an index might point to a padded token instead of a valid token.
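A minimal numeric sketch of the bound (plain Python, values assumed): after padding 10 tokens up for sp_size = 4, scatter indices can reach 11, past an output buffer sized at the original 10 rows, which is why combine() sizes the buffer at the padded length and slices afterwards.

```python
original_num_tokens, sp_size = 10, 4
padded_num_tokens = (original_num_tokens + sp_size - 1) // sp_size * sp_size  # 12

# Indices produced under padding range over [0, padded_num_tokens) ...
max_index = padded_num_tokens - 1  # 11
# ... so a buffer with only original_num_tokens rows would be indexed
# out of bounds:
assert max_index >= original_num_tokens
# Hence combine() allocates (padded_num_tokens, hidden_dim) for scatter_add
# and returns out[:original_num_tokens].
```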

if padded_num_tokens != original_num_tokens:
    mask = token_indices_experts_sorted < original_num_tokens
    token_indices_experts_sorted = token_indices_experts_sorted[mask]
    routed_output = routed_output[mask]
Contributor


This indexing may not be efficient. I would like us to survey other implementations that handle uneven seqlen, e.g. vLLM, Megatron, etc.

Contributor Author

@wwwjn wwwjn May 1, 2026


The comparison is:

vLLM takes a similar padding-and-unpadding approach to the one in this PR.

@wwwjn wwwjn changed the title [MoE] Pad token count to a multiple of sp_size in AllToAllTokenDispatcher [Don't review][MoE] Pad token count to a multiple of sp_size in AllToAllTokenDispatcher May 1, 2026
@wwwjn wwwjn changed the title [Don't review][MoE] Pad token count to a multiple of sp_size in AllToAllTokenDispatcher [MoE] Pad token count to a multiple of sp_size in AllToAllTokenDispatcher May 4, 2026
@wwwjn wwwjn requested a review from tianyu-l May 4, 2026 17:56
Contributor

@shuhuayu shuhuayu left a comment


I think we can use this to fix the uneven inference case for now. For tp -> ep, the current code path has two unnecessary all-to-alls when the EP mesh is a subset of the TP mesh, but it works in general when dp + tp is used. A next refactor could remove the all-gather and use sharding on seq_len to finish the MoE block with two all-to-alls for dispatch and combine.

Contributor

@tianyu-l tianyu-l left a comment


This PR for sequence-parallel MoE sounds OK to me. Please at least add a test.

To get prepared for spmd_types integration, let's add an assertion when both

  • Sequence Parallel used in dense region
  • seq len can't be evenly divided by TP degree

I feel this should be done in update_from_config in each model (we can do common things in Decoder and let subclasses inherit).

Can be done in a separate PR.
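The suggested guard might look like the following sketch (the function name, its location, and the parameter names are assumptions for illustration, not the repo's actual API):

```python
def assert_sp_divisible(seq_len: int, tp_degree: int, use_seq_parallel: bool) -> None:
    """Fail fast when Sequence Parallel in the dense region would split
    seq_len unevenly across the TP degree."""
    if use_seq_parallel and seq_len % tp_degree != 0:
        raise ValueError(
            f"seq_len={seq_len} is not divisible by TP degree={tp_degree}; "
            "Sequence Parallel in the dense region requires even division"
        )

assert_sp_divisible(4096, 8, True)  # ok: 4096 % 8 == 0
```

Such a check could live in update_from_config, with the common logic in Decoder and subclasses inheriting it, as suggested above.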

wwwjn added 2 commits May 4, 2026 16:14
@wwwjn
Contributor Author

wwwjn commented May 4, 2026

To get prepared for spmd_types integration, let's add an assertion when both

  • Sequence Parallel used in dense region
  • seq len can't be evenly divided by TP degree

I feel this should be done in update_from_config in each model (we can do common things in Decoder and let subclasses inherit).

Can be done in a separate PR.

I added a unit test for the dispatcher. E2E tests will be covered in the next PR.

I will create another PR to handle the assertion and make sure the number of tokens is divisible by the TP degree.

@wwwjn wwwjn changed the base branch from gh/wwwjn/18/base to main May 5, 2026 01:34
@wwwjn wwwjn merged commit d341407 into main May 5, 2026
9 of 11 checks passed

Labels

ciflow/8gpu, CLA Signed
