
[rl] Register customized config parser to vllm + less vllm config dependency #3242

Open
wwwjn wants to merge 5 commits into gh/wwwjn/20/base from gh/wwwjn/20/head

Conversation

@wwwjn
Contributor

@wwwjn wwwjn commented May 6, 2026

Stack from ghstack (oldest at bottom):

vLLM has customized config-parser registry support, so we can plug in TorchTitan's config parser. Why we need this:

  • Get rid of the dependency on an HF-format checkpoint folder when initializing; don't implicitly depend on config.json as the config source of truth.

Other changes in this PR:

  • Remove the round-trip translation from torchtitan config -> vllm config -> torchtitan config, using a closure to bypass it.

[ghstack-poisoned]
@wwwjn wwwjn requested review from fegin, tianyu-l and wconstab as code owners May 6, 2026 19:07
@meta-cla meta-cla Bot added the CLA Signed label May 6, 2026
@wwwjn wwwjn changed the title config parser [rl] Register customized config parser to vllm May 6, 2026
Comment thread torchtitan/experiments/rl/actors/generator.py Outdated
# Dynamic config parser class capturing ModelSpec (and any registration-
# time custom fields) in the closure.
@register_config_parser(TORCHTITAN_CONFIG_FORMAT)
class TorchTitanConfigParserForSpec(ConfigParserBase):
Contributor Author

Nested in the registry() function because we need to access model_spec via a closure when registering the config_parser.

Contributor Author

Not related to this PR; should be cleaned up.

Comment thread torchtitan/experiments/rl/models/vllm_config_parser.py Outdated
during model construction.
"""
from torchtitan.experiments.rl.models.vllm_wrapper import TorchTitanVLLMModelWrapper
from transformers import PretrainedConfig
Contributor Author

We (the RL side) have to depend on the PretrainedConfig definition from transformers, as it's the required return type of ConfigParser. I think ConfigParser is a clean abstraction, but unfortunately we would need to depend on transformers.

Comment thread torchtitan/experiments/rl/models/vllm_registry.py Outdated
vLLM has customized config-parser registry support, so we can plug in TorchTitan's config parser. It serves two purposes:
1. Get rid of the dependency on an HF-format checkpoint folder when initializing.
2. Pass customized args to VLLMModelWrapper, e.g. CompileConfig, skip_init_load_weights.



[ghstack-poisoned]
Comment thread torchtitan/experiments/rl/actors/generator.py Outdated
Comment thread torchtitan/experiments/rl/models/vllm_wrapper.py Outdated
Comment thread torchtitan/experiments/rl/actors/generator.py Outdated
Comment thread torchtitan/experiments/rl/models/vllm_registry.py Outdated
compile_config=compile_config,
)

# Set the class name so vLLM can identify it
Contributor

Why remove these comments?

wwwjn added 2 commits May 6, 2026 13:33
@pytorch-bot pytorch-bot Bot added the ciflow/rl label May 7, 2026
@wwwjn wwwjn changed the title [rl] Register customized config parser to vllm [rl] Register customized config parser to vllm + less vllm config dependency May 7, 2026
device_id: torch.device | None = None
if comm_config.mode == "torchcomms":
try:
import torchcomms # noqa: F401 # pyrefly: ignore [missing-import]
Contributor Author

Will be reverted, not related to this PR

from torchtitan.components.optimizer import OptimizersContainer
from torchtitan.config import CommConfig, Configurable, TORCH_DTYPE_MAP
from torchtitan.config.configs import (
from torchtitan.config import (
Contributor Author

This change just consolidates the import path.

Comment on lines +167 to +171
if p.enable_sequence_parallel:
logger.warning(
"Generator enable_sequence_parallel=True hurts inference "
"throughput; prefer SP=False."
)
Contributor

I don't think this will be supported by spmd_types erasure mode, so I don't mind banning it.
cc @pianpwk

Contributor

@pianpwk pianpwk May 9, 2026

Hmm, are we not supporting SP in spmd_types?

Contributor

@tianyu-l tianyu-l May 9, 2026

We support SP in spmd_types when the sequence length is evenly divisible; we don't support SP when it is not, because:

  • spmd_types doesn't handle padding & unpadding like DTensor does.
  • Uneven sequence lengths usually show up only in inference.
  • For inference we always use non-SP (vanilla TP).
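The divisibility constraint can be illustrated with a quick sharding check (plain Python, illustrative only; not the actual spmd_types code):

```python
# Illustrative only: sequence parallelism shards the sequence dimension
# across TP ranks, so each rank must receive an equal slice.
def sp_shard_sizes(seq_len: int, tp_degree: int) -> list[int]:
    if seq_len % tp_degree != 0:
        # DTensor would pad/unpad here; spmd_types does not, so the
        # uneven case is unsupported (and it is common in inference,
        # where prompt lengths are arbitrary).
        raise ValueError(f"seq_len={seq_len} not divisible by tp={tp_degree}")
    return [seq_len // tp_degree] * tp_degree

print(sp_shard_sizes(4096, 8))  # [512, 512, 512, 512, 512, 512, 512, 512]
try:
    sp_shard_sizes(4097, 8)
except ValueError as e:
    print(e)
```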

@@ -199,14 +214,17 @@ def __init__(
engine_kwargs = dict(
model=model_path,
Contributor

what is this path for?


assert vllm_config is not None, "vllm_config is required"

# PP and CP are not supported on this inference path
Contributor

This "raise ValueError" may be better placed in the GRPO trainer's post_init, for consistency.

Here we only need an assert.
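One way to express that split, assuming the trainer config is a dataclass (field names below are made up for illustration):

```python
# Hypothetical sketch: fail fast on unsupported parallelism degrees in
# the trainer config's __post_init__, so the generator path can keep a
# bare assert. Field names are illustrative, not the real config.
from dataclasses import dataclass

@dataclass
class GRPOTrainerConfig:
    pipeline_parallel_degree: int = 1
    context_parallel_degree: int = 1

    def __post_init__(self):
        if self.pipeline_parallel_degree != 1:
            raise ValueError("PP is not supported on this inference path")
        if self.context_parallel_degree != 1:
            raise ValueError("CP is not supported on this inference path")

GRPOTrainerConfig()  # valid config passes
try:
    GRPOTrainerConfig(pipeline_parallel_degree=2)
except ValueError as e:
    print(e)  # PP is not supported on this inference path
```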

wrapper's parallelize step.
"""
from torchtitan.experiments.rl.models.vllm_wrapper import TorchTitanVLLMModelWrapper
from transformers import PretrainedConfig
Contributor

Oh, can we avoid depending on this? I know it's already implicit via the vllm dependency, but I'm trying to see if we can avoid an explicit dependency (so someday we can remove it).

# parser only produces HF-shaped fields; torchtitan-specific config is
# delivered through the model-class closure above.
@register_config_parser(TORCHTITAN_CONFIG_FORMAT)
class TorchTitanConfigParserForSpec(ConfigParserBase):
Contributor

Suggested change
class TorchTitanConfigParserForSpec(ConfigParserBase):
class TorchTitanConfigParser(ConfigParserBase):


# Create dynamic model class capturing ModelSpec in the closure
# Dynamic model class capturing torchtitan config in the closure.
class TorchTitanVLLMModelFromSpec(TorchTitanVLLMModelWrapper):
Contributor

maybe simplify, from a titan-centric view

Suggested change
class TorchTitanVLLMModelFromSpec(TorchTitanVLLMModelWrapper):
class VLLMModelFromSpec(VLLMModelWrapper):

**kwargs,
):
config_dict = model_spec_to_hf_config_dict(model_spec)
return config_dict, PretrainedConfig.from_dict(config_dict)
Contributor

It's actually very weird that the contract is both a config_dict and a cls(config_dict); that sounds redundant to me.
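The redundancy is easy to see: the second return value is fully determined by the first. The sketch below uses a toy stand-in for PretrainedConfig so it runs without transformers:

```python
# Toy stand-in for transformers.PretrainedConfig, just to show that
# returning both the dict and the object built from it is redundant:
# the object can always be rebuilt from the dict alone.
class ToyConfig:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    @classmethod
    def from_dict(cls, d):
        return cls(**d)

def parse(model_spec):
    config_dict = {"hidden_size": model_spec["hidden_size"]}
    # The tuple contract: (config_dict, cls(config_dict)).
    return config_dict, ToyConfig.from_dict(config_dict)

d, cfg = parse({"hidden_size": 4096})
assert cfg.hidden_size == d["hidden_size"]  # second value derivable from first
```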

if moe is not None:
hf[
"num_experts"
] = moe.experts.num_experts # presence required: >0 toggles MoE/EP branches
Contributor

nit: put the `# presence required` comment above this field

}

def register_model_to_vllm_model_registry(
ffn = getattr(layer0, "feed_forward", None)
Contributor

Isn't using layer0 non-robust? What if a transformer has MoE in the 1st layer and FFN in the 2nd?
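A more robust variant would probe every layer rather than assuming layer 0 is representative (sketch with dummy layer objects; attribute names follow the snippet above, everything else is made up):

```python
# Sketch: search all layers for a feed_forward / moe attribute instead
# of trusting layer 0, so hybrid stacks (MoE first, dense FFN later)
# still work. SimpleNamespace objects stand in for transformer blocks.
from types import SimpleNamespace

def first_attr(layers, name):
    """Return the first non-None `name` attribute found across layers."""
    for layer in layers:
        value = getattr(layer, name, None)
        if value is not None:
            return value
    return None

moe_layer = SimpleNamespace(moe=SimpleNamespace(num_experts=8))
ffn_layer = SimpleNamespace(feed_forward=SimpleNamespace(out_dim=11008))
layers = [moe_layer, ffn_layer]  # MoE in layer 0, FFN in layer 1

assert first_attr(layers, "feed_forward").out_dim == 11008
assert first_attr(layers, "moe").num_experts == 8
```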

# Unused: only v1/metrics/perf.py reads it (off by default). SwiGLU hidden == w1.out_features.
hf["intermediate_size"] = ffn.w1.out_features

moe = getattr(layer0, "moe", None)
Contributor

vice versa


Labels

ciflow/rl ciflow/8gpu CLA Signed
