[MoE] Pad token count to a multiple of sp_size in AllToAllTokenDispatcher#3193
Conversation
…cher

Today the dispatcher's _split_along_sp() raises when num_tokens (bs * slen) is not divisible by sp_size (= TP degree). Real workloads with varlen prompts can land on non-divisible totals and crash the MoE forward.

Pad inside dispatch(): round num_tokens up to the next multiple of sp_size, padding x and top_scores with zeros and selected_experts_indices with 0 (so pad rows route deterministically to expert 0 with zero score). combine() reads metadata.original_num_tokens to size the scatter_add buffer at the padded length and slices the pad rows off before returning. When sp_size == 1 or the input is already divisible, behavior is bitwise identical to today.

Pad tokens are numerically inert:
- Zero scores -> their contribution to scatter_add is exactly zero either before or after expert compute (independent of score_before_experts).
- Pad indices fall in [original, padded), which is sliced off after scatter_add, so they never appear in the returned output.

Trainer/generator can pad by different amounts depending on their batch shapes; the unpadded portions remain bitwise identical.

- TorchAOTokenDispatcher inherits dispatch/combine and gets the fix for free.
- DeepEPTokenDispatcher uses a separate metadata type and is unaffected.

[ghstack-poisoned]
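The dispatch-side padding described above can be sketched as follows. This is a minimal standalone illustration; `pad_for_sp` and its signature are illustrative, not the real dispatcher API.

```python
import torch

def pad_for_sp(x, top_scores, selected_experts_indices, sp_size):
    """Round the token count up to a multiple of sp_size.

    Illustrative helper mirroring the dispatch() change: pad rows get
    zero hidden states, zero scores, and expert index 0.
    """
    num_tokens = x.shape[0]
    padded_num_tokens = -(-num_tokens // sp_size) * sp_size  # ceil to multiple
    pad = padded_num_tokens - num_tokens
    if pad:
        x = torch.cat([x, x.new_zeros(pad, x.shape[1])])
        top_scores = torch.cat([top_scores, top_scores.new_zeros(pad, top_scores.shape[1])])
        # pad rows route deterministically to expert 0 with zero score
        selected_experts_indices = torch.cat(
            [selected_experts_indices,
             selected_experts_indices.new_zeros(pad, selected_experts_indices.shape[1])]
        )
    return x, top_scores, selected_experts_indices, num_tokens

x = torch.randn(10, 4)                      # 10 tokens, not divisible by sp_size=4
scores = torch.rand(10, 2)
idx = torch.randint(0, 8, (10, 2))
px, ps, pi, orig = pad_for_sp(x, scores, idx, sp_size=4)
# px now has 12 rows; rows [10:12] are zeros and route to expert 0
```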
local_num_tokens = padded_num_tokens // self.sp_size
token_indices_experts_sorted = (
    metadata.token_indices_experts_sorted + local_num_tokens * self.sp_rank
)
I understand that for computing indices you need the padding info / original_num_tokens, but why do the other places in combine need the padding?
This is because with padded tokens, in out.scatter_add(dim, index, src) the index values range over [0, num_tokens + num_padded_tokens), but the size of out would still be (num_tokens, hidden_dim). With padded tokens, an index might point to a padded token instead of a valid token, so the buffer has to be sized at the padded length.
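A tiny standalone illustration of that sizing issue (tensor names and values are illustrative): the scatter indices can point into the pad range, so the buffer must be allocated at the padded length and sliced afterwards.

```python
import torch

num_tokens, padded_num_tokens, hidden = 6, 8, 4
src = torch.randn(6, hidden)
# after dispatch, indices may fall anywhere in [0, padded_num_tokens);
# here 6 and 7 refer to pad tokens
index = torch.tensor([0, 1, 5, 6, 7, 2])

# size the buffer at the padded length so pad indices stay in range...
out = torch.zeros(padded_num_tokens, hidden)
out.scatter_add_(0, index.unsqueeze(1).expand(-1, hidden), src)

# ...then slice the pad rows off before returning
out = out[:num_tokens]
```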
if padded_num_tokens != original_num_tokens:
    mask = token_indices_experts_sorted < original_num_tokens
    token_indices_experts_sorted = token_indices_experts_sorted[mask]
    routed_output = routed_output[mask]
This indexing may not be efficient? I would like us to survey other implementations that handle uneven seqlen, e.g. vLLM, Megatron, etc.
The comparison is:
- torchtitan: pad → split-along-SP → a2a → grouped GEMM → a2a → scatter_add into a per-token buffer → drop pad
- vLLM: sequence_parallel_chunk (pad + slice) → dispatch via the a2a manager → the fused MoE kernel does the token-weighted sum internally and routes results back across EP ranks → combine via all_reduce or reduce_scatterv on the SP dim → drop pad

vLLM takes a similar padding and unpadding approach to this PR.
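The pad-then-slice step at the front of the vLLM pipeline can be sketched like this. This is only an illustration of the idea; `sp_chunk` is a stand-in name, not vLLM's actual function.

```python
import torch
import torch.nn.functional as F

def sp_chunk(x: torch.Tensor, sp_size: int, sp_rank: int) -> torch.Tensor:
    """Pad the token dim to a multiple of sp_size, then take this rank's slice."""
    num_tokens = x.shape[0]
    padded = -(-num_tokens // sp_size) * sp_size  # ceil to multiple of sp_size
    if padded != num_tokens:
        # zero-pad trailing token rows; (0, 0, 0, n) pads the second-to-last
        # (token) dim of a 2D tensor by n rows at the end
        x = F.pad(x, (0, 0, 0, padded - num_tokens))
    chunk = padded // sp_size
    return x[sp_rank * chunk : (sp_rank + 1) * chunk]

x = torch.randn(10, 4)                      # 10 tokens, sp_size=4 -> pad to 12
chunks = [sp_chunk(x, 4, r) for r in range(4)]
```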
shuhuayu
left a comment
I think we can use this to fix the uneven inference case for now. For tp -> ep, the current code path has two unnecessary all2alls when the EP mesh is a subset of the TP mesh, but it works in general when dp + tp is used. A next refactor could remove the all-gather and shard on seq_len, finishing the MoE block with two all2alls for dispatch and combine.
This PR for sequence-parallel MoE sounds OK to me. Please at least add a test.
To prepare for spmd_types integration, let's add an assertion for when both
- Sequence Parallel is used in the dense region
- seq_len can't be evenly divided by the TP degree
I feel this should be done in update_from_config in each model (we can do common things in Decoder and let subclasses inherit).
Can be done in a separate PR.
I added a unit test for the dispatcher. E2E tests will be covered in the next PR. I will create another PR to handle the assertion and make sure the number of tokens is divisible by the TP degree.
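As a sanity check of the "pad tokens are numerically inert" claim, a minimal test along these lines (identity routing, illustrative only — not the actual unit test in the PR) confirms that zero-score pad rows contribute nothing to the sliced output:

```python
import torch

num_tokens, pad, hidden = 10, 2, 4
x = torch.randn(num_tokens, hidden)
scores = torch.rand(num_tokens, 1)

# pad hidden states and scores with zeros, as dispatch() does
xp = torch.cat([x, torch.zeros(pad, hidden)])
sp = torch.cat([scores, torch.zeros(pad, 1)])

# scatter into a buffer sized at the padded length (identity routing here),
# then slice the pad rows off, as combine() does
out = torch.zeros(num_tokens + pad, hidden)
idx = torch.arange(num_tokens + pad).unsqueeze(1).expand(-1, hidden)
out.scatter_add_(0, idx, xp * sp)
result = out[:num_tokens]
# result matches the unpadded weighted tokens; pad rows carried exactly zero
```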
Stack from ghstack (oldest at bottom):
Today the dispatcher's _split_along_sp() raises when num_tokens (bs * slen) is not divisible by sp_size (= TP degree). Real workloads with varlen prompts can land on non-divisible totals and crash the MoE forward.
Pad inside dispatch(): round num_tokens up to the next multiple of sp_size, padding x and top_scores with zeros and selected_experts_indices with 0 (so pad rows route deterministically to expert 0 with zero score). combine() reads metadata.original_num_tokens to size the scatter_add buffer at the padded length and slices the pad rows off before returning. When sp_size == 1 or input is already divisible, behavior is bitwise identical to today.
Pad tokens are numerically inert:
- Zero scores -> their contribution to scatter_add is exactly zero either before or after expert compute (independent of score_before_experts).
- Pad indices fall in [original, padded), which is sliced off after scatter_add, so they never appear in the returned output.

Trainer/generator can pad by different amounts depending on their batch shapes; the unpadded portions remain bitwise identical.
- TorchAOTokenDispatcher inherits dispatch/combine and gets the fix for free.
- DeepEPTokenDispatcher uses a separate metadata type and is unaffected.