
[WIP] Enable DP-to-EP for MoE inference #3171

Closed
wwwjn wants to merge 9 commits into gh/wwwjn/16/base from gh/wwwjn/16/head

Conversation

@wwwjn
Contributor

@wwwjn wwwjn commented Apr 30, 2026

Stack from ghstack (oldest at bottom):

Map vLLM's data_parallel_size to dp_shard in TorchTitan's mesh math,
enabling EP to span both DP and TP ranks (ep = dp * tp). For inference,
the skip_dp path returns the model directly — no FSDP wrapping needed
since there's no backward pass.

Changes:

- parallel_dims: dp_replicate mesh always exists
- vllm_wrapper: ep_size = dp_size * tp_size, dp_shard = dp_size
- vllm_wrapper: weight loading uses [Replicate()] * mesh.ndim
- qwen3/parallelize: skip_dp returns model (no fully_shard)
- llama4/parallelize: clarify shard_placement_fn comments
- actors/utils: full_tensor() instead of to_local() for TP-sharded logits
- actors/trainer: strict=False (expert_bias not in HF), ep_enabled clip
- grpo: tyro.conf.Suppress on model_spec to bypass CLI parsing of Placement
- experiments/rl: TORCHTITAN_SKIP_INITIAL_HF_LOAD env var
- config_registry: env/validation_env in MoE configs (fix tyro CLI)
- qwen3/__init__: debugmodel_moe vocab_size 2048 → 151936 (match tokenizer)
- scripts/rl/create_debug_moe_ckpt.py: helper to generate debug checkpoint

Verified end-to-end RL with rl_grpo_qwen3_moe_debug_ep on debug model.
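The dp_shard / ep mapping described above is plain multiplicative mesh arithmetic. A minimal sketch under the PR's stated mapping (the helper name is hypothetical, not the actual vllm_wrapper code):

```python
def ep_mesh_sizes(data_parallel_size: int, tensor_parallel_size: int) -> tuple[int, int]:
    """Sketch of the DP-to-EP mapping from the summary above.

    vLLM's data_parallel_size is mapped onto TorchTitan's dp_shard
    dimension, and EP spans both DP and TP ranks, so ep = dp * tp.
    """
    dp_shard = data_parallel_size
    ep_size = data_parallel_size * tensor_parallel_size
    return dp_shard, ep_size

# e.g. 4 DP replicas x 2 TP ranks -> dp_shard = 4, experts sharded 8 ways
```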

meta-cla bot added the CLA Signed label Apr 30, 2026
wwwjn changed the title from "Enable DP-to-EP for MoE inference" to "[WIP] Enable DP-to-EP for MoE inference" Apr 30, 2026
wwwjn added 6 commits April 30, 2026 20:25
wwwjn added a commit that referenced this pull request May 5, 2026
…cher (#3193)

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* #3172
* #3171
* #3142
* __->__ #3193

Today the dispatcher's _split_along_sp() raises when num_tokens (bs *
slen) is not divisible by sp_size (= TP degree). Real workloads with
varlen prompts can land on non-divisible totals and crash the MoE
forward.

Pad inside dispatch(): round num_tokens up to the next multiple of
sp_size, padding x and top_scores with zeros and
selected_experts_indices with 0 (so pad rows route deterministically to
expert 0 with zero score). combine() reads metadata.original_num_tokens
to size the scatter_add buffer at the padded length and slices the pad
rows off before returning. When sp_size == 1 or input is already
divisible, behavior is bitwise identical to today.

Pad tokens are numerically inert:
- Zero scores -> contribution to scatter_add is exactly zero either
before or after expert compute (independent of score_before_experts).
- Pad indices fall in [original, padded), which is sliced off after
scatter_add, so they never appear in the returned output.

Trainer/generator can pad by different amounts depending on their batch
shapes; the unpadded portions remain bitwise identical.

- TorchAOTokenDispatcher inherits dispatch/combine and gets the fix for
free.
- DeepEPTokenDispatcher uses a separate metadata type and is unaffected.
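The round-up-and-slice scheme above can be sketched in plain Python (helper names like padded_length and pad_rows are illustrative, not the dispatcher's actual API):

```python
def padded_length(num_tokens: int, sp_size: int) -> int:
    """Round num_tokens up to the next multiple of sp_size.

    With sp_size == 1 or an already-divisible input the length is
    unchanged, matching the "bitwise identical" guarantee above.
    """
    if sp_size <= 1:
        return num_tokens
    return -(-num_tokens // sp_size) * sp_size  # ceiling division

def pad_rows(rows: list, sp_size: int, fill) -> tuple[list, int]:
    """Append `fill` rows until len(rows) is divisible by sp_size.

    Returns the padded rows plus the original length (analogous to
    metadata.original_num_tokens, which combine() uses to slice the
    pad rows back off).
    """
    original = len(rows)
    target = padded_length(original, sp_size)
    return rows + [fill] * (target - original), original
```

Per the description, x and top_scores are padded with zeros and selected_experts_indices with 0, so pad rows route to expert 0 with zero score and contribute exactly nothing to scatter_add before being sliced off.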
@wwwjn wwwjn closed this May 6, 2026

Labels

ciflow/8gpu, CLA Signed
