[rl] Enable TP2EP for MoE inference in vLLM wrapper#3142

Open

wwwjn wants to merge 23 commits into gh/wwwjn/14/base from gh/wwwjn/14/head

Conversation

@wwwjn (Contributor) commented Apr 28, 2026

Stack from ghstack (oldest at bottom):

Enable Expert Parallelism for MoE models in the vLLM inference path, with TP on dense layers. Includes meta-device init, init_states() fix for RoPE cache, EP mesh creation, enable_sp-based output layout, and MoE debug configs.

Key changes:

  • Meta-device init + to_empty() + init_states() for large MoE models
  • EP mesh: ep = tp_size (TP ranks become EP ranks for experts)
  • Use enable_sp (not inference flag) for output layout and SP splitting
  • enable_sequence_parallel=True, disable_loss_parallel=True for inference
  • Remove stale ModelConvertersContainer references
  • Add MoE debug and 30B-A3B RL configs

Verified: TP=2, EP+TP=2, EP+TP=4 all produce correct output matching native vLLM on Qwen3-30B-A3B.

Similar to #3057.
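
For the meta-device bullet in the key changes above, here is a minimal sketch of that init flow under stated assumptions: the toy module, the buffer name, and the body of init_states() are illustrative stand-ins for the torchtitan model; only the torch meta-device and to_empty() APIs are standard.

```python
import torch
import torch.nn as nn

class TinyMoE(nn.Module):
    """Illustrative stand-in for a large MoE model."""
    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        # Non-parameter state (e.g. a RoPE cache) that must be rebuilt after
        # materialization, since to_empty() leaves it uninitialized.
        self.register_buffer("rope_cache", torch.empty(dim), persistent=False)

    def init_states(self):
        # Recompute caches that to_empty() cannot restore.
        self.rope_cache.copy_(torch.arange(self.rope_cache.numel(), dtype=torch.float))

with torch.device("meta"):
    model = TinyMoE()          # no real memory allocated yet

model.to_empty(device="cuda" if torch.cuda.is_available() else "cpu")
model.init_states()            # fix up the RoPE cache after materialization
# Real weights then come from the HF checkpoint load or the first weight sync.
```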

@meta-cla bot added the CLA Signed label Apr 28, 2026
@wwwjn changed the title from "Enable EP+TP for MoE inference in vLLM wrapper" to "Enable TP2EP for MoE inference in vLLM wrapper" Apr 28, 2026
Comment thread torchtitan/experiments/rl/models/vllm_wrapper.py Outdated
dp_shard=1,
cp=parallel_config.decode_context_parallel_size,
tp=parallel_config.tensor_parallel_size,
cp=1,
Contributor:

why hardcoding?

Contributor Author (@wwwjn), Apr 30, 2026:

This is not enabled now; previously it was there just for testing purposes. We also lack knowledge of decode_context_parallel_size in vLLM, so I think hardcoding it to 1 is safer until we officially support CP in inference.

Contributor:

we should assert vllm CP / PP degrees to be 1
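
A minimal sketch of what such a guard could look like, assuming it lives in the wrapper where vLLM's parallel_config is available; the function name and messages are illustrative, while the field names follow the diff snippets in this thread.

```python
def _check_unsupported_vllm_parallelism(parallel_config) -> None:
    # Hypothetical guard: reject vLLM parallel degrees the wrapper does not support yet.
    if parallel_config.pipeline_parallel_size != 1:
        raise ValueError("torchtitan vLLM wrapper does not support pipeline parallelism yet")
    if getattr(parallel_config, "decode_context_parallel_size", 1) != 1:
        raise ValueError("torchtitan vLLM wrapper does not support decode context parallelism yet")
```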

@wwwjn changed the title from "Enable TP2EP for MoE inference in vLLM wrapper" to "[rl] Enable TP2EP for MoE inference in vLLM wrapper" Apr 30, 2026
wwwjn added 2 commits April 30, 2026 20:25
@github-actions bot mentioned this pull request May 1, 2026
Contributor Author (@wwwjn), May 1, 2026:

This file is a soft link

# Initial load model weights from HuggingFace checkpoint path.
import os as _os

if _os.environ.get("TORCHTITAN_SKIP_INITIAL_HF_LOAD") != "1":
Contributor Author (@wwwjn):

I used an env variable to trigger random init in vllm_wrapper because it's impossible to change the vllm_wrapper init signature, so this is hard to pass as a kwarg or config.

storage_reader.read_metadata().state_dict_metadata.keys()
)
missing = set(hf_state_dict.keys()) - hf_keys_in_checkpoint
unexpected_missing = {k for k in missing if not k.endswith(".expert_bias")}
Contributor Author (@wwwjn), May 1, 2026:

We need to remove expert_bias from this check after the loss-based load balancing (#3000) lands. If expert_bias is in our model, we should load it from the checkpoint.

wwwjn added 3 commits May 4, 2026 16:14
wwwjn added a commit that referenced this pull request May 5, 2026
…cher (#3193)

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* #3172
* #3171
* #3142
* __->__ #3193

Today the dispatcher's _split_along_sp() raises when num_tokens (bs *
slen) is not divisible by sp_size (= TP degree). Real workloads with
varlen prompts can land on non-divisible totals and crash the MoE
forward.

Pad inside dispatch(): round num_tokens up to the next multiple of
sp_size, padding x and top_scores with zeros and
selected_experts_indices with 0 (so pad rows route deterministically to
expert 0 with zero score). combine() reads metadata.original_num_tokens
to size the scatter_add buffer at the padded length and slices the pad
rows off before returning. When sp_size == 1 or input is already
divisible, behavior is bitwise identical to today.

Pad tokens are numerically inert:
- Zero scores -> contribution to scatter_add is exactly zero either
before or after expert compute (independent of score_before_experts).
- Pad indices fall in [original, padded), which is sliced off after
scatter_add, so they never appear in the returned output.

Trainer/generator can pad by different amounts depending on their batch
shapes; the unpadded portions remain bitwise identical.

- TorchAOTokenDispatcher inherits dispatch/combine and gets the fix for
free.
- DeepEPTokenDispatcher uses a separate metadata type and is unaffected.
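
A rough sketch of the rounding and padding described above, with zero-padded scores and expert index 0 for pad rows. The names mirror the description (x, top_scores, selected_experts_indices, sp_size); this is illustrative, not the dispatcher's actual code.

```python
import torch

def pad_for_sp(x, top_scores, selected_experts_indices, sp_size):
    """Pad token-major tensors so num_tokens is divisible by sp_size."""
    num_tokens = x.shape[0]
    padded = ((num_tokens + sp_size - 1) // sp_size) * sp_size  # round up
    pad = padded - num_tokens
    if pad == 0:
        return x, top_scores, selected_experts_indices, num_tokens
    # Zero scores make the pad rows numerically inert in scatter_add.
    x = torch.cat([x, x.new_zeros(pad, *x.shape[1:])])
    top_scores = torch.cat([top_scores, top_scores.new_zeros(pad, *top_scores.shape[1:])])
    # Pad rows route deterministically to expert 0 with zero score.
    selected_experts_indices = torch.cat(
        [selected_experts_indices,
         selected_experts_indices.new_zeros(pad, *selected_experts_indices.shape[1:])]
    )
    return x, top_scores, selected_experts_indices, num_tokens

# combine() would size its scatter_add buffer at the padded length and slice
# output[:original_num_tokens] before returning.
```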
wwwjn added 2 commits May 4, 2026 21:23
wwwjn added 2 commits May 5, 2026 20:07
Comment thread torchtitan/experiments/rl/config_registry.py
"""Top-level config for RL training."""

model_spec: ModelSpec | None = None
model_spec: Annotated[ModelSpec | None, tyro.conf.Suppress] = None
Contributor:

why we didn't need this before?

model=model,
model_state_dict=torchtitan_state_dict,
options=StateDictOptions(strict=True),
# strict=False: HF MoE checkpoints don't carry expert_bias buffers
Contributor:

Then expert_bias should be a non-persistent buffer?
https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/common/moe.py#L349

I believe Qwen is not using loss-free load balancing, so we should enable the load-balancing loss for it.
This is blocked by @pianpwk's #3000.
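
For reference, a non-persistent buffer is excluded from the module's state_dict, so a strict checkpoint load would not expect it. This is a generic PyTorch sketch, not the moe.py code; the module and buffer sizes are illustrative.

```python
import torch
import torch.nn as nn

class Router(nn.Module):
    def __init__(self, num_experts: int = 8):
        super().__init__()
        # persistent=False keeps expert_bias out of state_dict(), so strict
        # loading from an HF checkpoint that lacks it would not complain.
        self.register_buffer("expert_bias", torch.zeros(num_experts), persistent=False)

r = Router()
assert "expert_bias" not in r.state_dict()
```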

import os as _os

if _os.environ.get("TORCHTITAN_SKIP_INITIAL_HF_LOAD") != "1":
self._initial_load_weights(checkpoint_path=vllm_config.model_config.model)
Contributor:

I'm confused here. Why do we need to do an initial weight load at all?
https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/rl/grpo.py#L453
already initializes the weights on the generator. The flow should be:

  • trainer builds the model and loads weights from HF (which we can now skip with a flag for debugging)
  • trainer pushes the model state dict to TS
  • generator pulls the model state dict from TS and loads it into the model

Contributor Author (@wwwjn):

Yes, you are right. I guess it's mostly a historical reason: we started from inference only, so we needed a way to load weights. Now initial_load_weights is unconditionally called when the generator is initialized. I agree we should either 1) skip initial_load_weights (but passing an extra flag to vllm_wrapper requires an env variable), or 2) skip the first weight sync.

I'd prefer to go with 1).

Contributor:

What's 1) -- is it on the generator side or the trainer side?

I think we can just remove any initial weight load from the generator (and rely only on the first weight sync to get the proper weights).

Contributor Author (@wwwjn), May 7, 2026:

Replied here: #3142 (comment)

I think we can just remove any initial load weight from generator (and only rely on the first weight sync to get proper weight)

In RL yes.

tp=parallel_config.tensor_parallel_size,
cp=1,
tp=tp_size,
pp=parallel_config.pipeline_parallel_size,
Contributor:

then you should hardcode this to 1 as well?


Comment thread torchtitan/experiments/rl/models/vllm_wrapper.py Outdated
Comment on lines +317 to +321
# All-to-all dispatch tokens to EP ranks.
# Use the non-autograd version under inference (vLLM), since
# _c10d_functional_autograd ops don't dispatch correctly without
# an active autograd context. Gated by a Python bool so the choice
# is stable at trace time.
Contributor:

I'm quite uncomfortable with this. An autograd function should be able to run fwd-only under inference mode, unless there is a strong reason why it can't.
cc @ezyang @fegin

Contributor Author (@wwwjn):

I dug into more details with Claude: under inference mode, it (`all_to_all_single_autograd`) never gets dispatched to the autograd kernel.

My question is: why do we use all_to_all_single_autograd for training? Can we use all_to_all_single for training?

Member (@xmfan), May 7, 2026:

Sorry, why do we think all_to_all_autograd didn't dispatch properly? autograd.Functions have some early-exit conditions in inference mode, but I don't think that prevents you from using it in inference.

UPDATE: we need to revive pytorch/pytorch#149411

Contributor Author (@wwwjn), May 7, 2026:

For my current PR, I will leave the current branching as is and create an issue in torchtitan to track the progress. The ultimate goal is to consolidate on a single op for pretraining and RL.

#3268
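
For context, the branching being discussed looks roughly like the sketch below: a plain Python flag picks the non-autograd functional collective under inference so the choice is stable at trace time. The function, flag name, and surrounding code are illustrative; only the two collectives come from the diff, and the snippet assumes an already-initialized process group.

```python
from torch.distributed._functional_collectives import (
    all_to_all_single,
    all_to_all_single_autograd,
)

def dispatch_tokens(x, output_splits, input_splits, group, run_inference: bool):
    # Gated by a Python bool so the choice is stable at trace time.
    if run_inference:
        # Non-autograd variant: sidesteps the autograd-op dispatch issue
        # under vLLM inference discussed above.
        routed = all_to_all_single(x, output_splits, input_splits, group)
    else:
        # Autograd-aware variant used for training so gradients flow back
        # through the all-to-all.
        routed = all_to_all_single_autograd(x, output_splits, input_splits, group)
    return routed
```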

wwwjn added 4 commits May 6, 2026 13:10
@pytorch-bot bot added the ciflow/rl label May 7, 2026
f"got cp={p.context_parallel_degree}"
)
if p.expert_parallel_degree > 1:
# vLLM ties EP to TP (enable_expert_parallel=True reuses the TP
Contributor:

So vLLM doesn't support EP > TP, e.g. by also reusing the DP group if we turn on FSDP?

Contributor Author (@wwwjn):

When EP > TP, we will borrow DP degrees, which is done in the next PR. And it's pure DP, not FSDP.
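
For the TP2EP case in this PR, the experts are sharded over the TP ranks, so ep == tp while CP and PP stay at 1. A hypothetical helper sketching that degree mapping; the field names follow the diff snippets in this thread, but the function itself is illustrative, not the wrapper's actual code.

```python
def infer_degrees(tensor_parallel_size: int, enable_expert_parallel: bool) -> dict:
    tp = tensor_parallel_size
    ep = tp if enable_expert_parallel else 1   # TP ranks double as EP ranks (TP2EP)
    return {
        "dp_shard": 1,   # borrowing DP degrees (dp2ep) comes in the next PR
        "cp": 1,         # hardcoded until CP is supported in inference
        "tp": tp,
        "pp": 1,
        "ep": ep,
    }

# e.g. infer_degrees(4, True) -> {"dp_shard": 1, "cp": 1, "tp": 4, "pp": 1, "ep": 4}
```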

config_format=TORCHTITAN_CONFIG_FORMAT,
dtype=config.model_dtype,
tensor_parallel_size=config.parallelism.tensor_parallel_degree,
enable_expert_parallel=enable_ep,
Contributor:

make this inline?

dp_shard=1,
cp=parallel_config.decode_context_parallel_size,
tp=parallel_config.tensor_parallel_size,
cp=1,
Contributor:

Maybe leave a comment on why cp is hardcoded to 1? Also, for FSDP, I wonder if we could turn it on and later have both tp2ep and dp2ep.

moe = getattr(layer, "moe", None)
if moe is None:
continue
dispatcher = getattr(moe.experts, "token_dispatcher", None)
Contributor:

can you raise ValueError when it's None?

model_spec=model_registry("30B-A3B", attn_backend="varlen"),
hf_assets_path="torchtitan/experiments/rl/example_checkpoint/Qwen3-30B-A3B",
num_steps=10,
compile=CompileConfig(enable=True, backend="aot_eager"),
Contributor:

why this doesn't need to be False?

"""Top-level config for RL training."""

model_spec: ModelSpec | None = None
model_spec: Annotated[ModelSpec | None, tyro.conf.Suppress] = None
Contributor:

why this?

routed_input = all_to_all_single_autograd(
# All-to-all dispatch tokens to EP ranks.
# Use the non-autograd version under inference (vLLM), since
# _c10d_functional_autograd ops don't dispatch correctly without
Contributor:

just curious why can't torch.inference_mode automatically switch between these two if this is the only difference.

# Replicate (no-op) and Shard(-1) (all-gather) lm_head output placements.
if isinstance(logits, DTensor):
logits = logits.to_local()
logits = logits.full_tensor()
Contributor:

should work with disable_loss_parallel already?

wwwjn added 2 commits May 7, 2026 15:15