[rl] Register customized config parser to vllm + less vllm config dependency #3242

wwwjn wants to merge 5 commits into gh/wwwjn/20/base from
Conversation
# Dynamic config parser class capturing ModelSpec (and any registration-
# time custom fields) in the closure.
@register_config_parser(TORCHTITAN_CONFIG_FORMAT)
class TorchTitanConfigParserForSpec(ConfigParserBase):
Nested in the registry() function because we need to access model_spec via a closure when registering the config_parser.
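A minimal sketch of the closure pattern, using the names that appear in this diff (`register_config_parser`, `ConfigParserBase`, `TORCHTITAN_CONFIG_FORMAT`, `model_spec_to_hf_config_dict`); the vLLM import paths and the exact `parse` signature are assumptions, not the actual interface:

```python
from transformers import PretrainedConfig

# Assumption: register_config_parser / ConfigParserBase come from vLLM's
# config-parser registry; their import paths are omitted in this sketch.


def register(model_spec):
    @register_config_parser(TORCHTITAN_CONFIG_FORMAT)
    class TorchTitanConfigParserForSpec(ConfigParserBase):
        # model_spec is captured from the enclosing register() call; vLLM only
        # instantiates the class, so the closure is the only way to pass it in.
        def parse(self, *args, **kwargs):
            config_dict = model_spec_to_hf_config_dict(model_spec)
            return config_dict, PretrainedConfig.from_dict(config_dict)
```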
Not related to this PR; should be cleaned up.
    during model construction.
    """
    from torchtitan.experiments.rl.models.vllm_wrapper import TorchTitanVLLMModelWrapper
    from transformers import PretrainedConfig
We (the RL side) have to depend on the PretrainedConfig definition from transformers because it is the required return type of ConfigParser. I think ConfigParser is a clean abstraction, but unfortunately we would need to depend on transformers.
vllm has this customized config parser registry support so we can plug in TorchTitan's config parser. It serves two purposes: 1. get rid of the dependency on an HF-format checkpoint folder when initializing; 2. pass customized args to VLLMModelWrapper, e.g. CompileConfig, skip_init_load_weights. [ghstack-poisoned]
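A rough sketch of how the customized args could flow through the closure; `TorchTitanVLLMModelWrapper` is the wrapper imported in this diff, while the constructor keyword names here are illustrative rather than the exact signature:

```python
from torchtitan.experiments.rl.models.vllm_wrapper import TorchTitanVLLMModelWrapper


def make_model_class(model_spec, compile_config, skip_init_load_weights):
    # The subclass closes over the torchtitan-side arguments, so vLLM can
    # construct the model with its usual (vllm_config, prefix) interface
    # without knowing about CompileConfig or skip_init_load_weights.
    class TorchTitanVLLMModelFromSpec(TorchTitanVLLMModelWrapper):
        def __init__(self, *, vllm_config, prefix: str = ""):
            super().__init__(
                vllm_config=vllm_config,
                prefix=prefix,
                model_spec=model_spec,
                compile_config=compile_config,
                skip_init_load_weights=skip_init_load_weights,
            )

    return TorchTitanVLLMModelFromSpec
```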
    compile_config=compile_config,
)

# Set the class name so vLLM can identify it
why remove these comments?
"… config dependency"

vllm has this customized config parser registry support so we can plug in TorchTitan's config parser. Why we need this:

- get rid of the dependency on an HF-format checkpoint folder when initializing; don't implicitly depend on `config.json` as the config source of truth

Other changes in this PR:

- remove the round-trip translation from torchtitan config -> vllm config -> torchtitan config, using a closure to bypass it.

[ghstack-poisoned]
device_id: torch.device | None = None
if comm_config.mode == "torchcomms":
    try:
        import torchcomms  # noqa: F401  # pyrefly: ignore [missing-import]
Will be reverted, not related to this PR
from torchtitan.components.optimizer import OptimizersContainer
from torchtitan.config import CommConfig, Configurable, TORCH_DTYPE_MAP
from torchtitan.config.configs import (
from torchtitan.config import (
This change just consolidates the import path.
if p.enable_sequence_parallel:
    logger.warning(
        "Generator enable_sequence_parallel=True hurts inference "
        "throughput; prefer SP=False."
    )
This won't be supported by spmd_types erasure mode I think, so I don't mind if we ban it.
cc @pianpwk
hmm, are we not supporting SP in spmd_types?
We support SP in spmd_types when the sequence length is evenly divisible; we don't support SP when the sequence length is not evenly divisible (see the sketch below).
- spmd_types doesn't handle padding & unpadding like DTensor does.
- Uneven sequence lengths usually show up only in inference.
- For inference we always use non-SP (vanilla TP).
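A tiny illustration of the constraint, assuming SP shards the sequence dimension evenly across the TP group (the helper name is hypothetical):

```python
def can_use_sequence_parallel(seq_len: int, tp_degree: int) -> bool:
    # SP splits the sequence dimension across TP ranks; without the
    # padding/unpadding that DTensor provides, every rank must receive an
    # equally sized shard, so seq_len must divide evenly.
    return seq_len % tp_degree == 0


# Even split: 4096 tokens over 8 ranks -> 512 tokens per rank, OK.
assert can_use_sequence_parallel(4096, 8)
# Uneven split, typical of inference-time prompt lengths: fall back to
# vanilla TP (SP disabled).
assert not can_use_sequence_parallel(4097, 8)
```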
@@ -199,14 +214,17 @@ def __init__(
engine_kwargs = dict(
    model=model_path,

assert vllm_config is not None, "vllm_config is required"

# PP and CP are not supported on this inference path
this "raise ValueError" may better happen at grpo trainer post_init, to be consistent
here we only need assert
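A hedged sketch of what moving the check into the trainer config's `__post_init__` might look like; the dataclass and field names here are illustrative, not the actual ones in the repo:

```python
from dataclasses import dataclass


@dataclass
class GRPOTrainerConfig:
    pipeline_parallel_degree: int = 1
    context_parallel_degree: int = 1

    def __post_init__(self) -> None:
        # Validate unsupported parallelisms up front, so the inference path
        # can rely on a plain assert instead of raising ValueError itself.
        if self.pipeline_parallel_degree > 1 or self.context_parallel_degree > 1:
            raise ValueError("PP and CP are not supported on this inference path")
```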
    wrapper's parallelize step.
    """
    from torchtitan.experiments.rl.models.vllm_wrapper import TorchTitanVLLMModelWrapper
    from transformers import PretrainedConfig
oh, can we not depend on this? I know it's already implicit via the vllm dependency, but I'm trying to see if we can avoid the explicit dependency (someday we could remove it).
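One possible way to keep the dependency soft, sketched here only as an idea: defer the transformers import into the function that needs it, so torchtitan's RL module imports cleanly even without transformers installed (vLLM already pulls it in transitively when the parser actually runs):

```python
def _load_pretrained_config_cls():
    # Import transformers lazily so the RL code does not require it at module
    # import time; it is only needed when the vLLM config parser executes.
    try:
        from transformers import PretrainedConfig
    except ImportError as e:
        raise ImportError(
            "transformers is required to build the vLLM config parser"
        ) from e
    return PretrainedConfig
```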
# parser only produces HF-shaped fields; torchtitan-specific config is
# delivered through the model-class closure above.
@register_config_parser(TORCHTITAN_CONFIG_FORMAT)
class TorchTitanConfigParserForSpec(ConfigParserBase):
-class TorchTitanConfigParserForSpec(ConfigParserBase):
+class TorchTitanConfigParser(ConfigParserBase):
# Create dynamic model class capturing ModelSpec in the closure
# Dynamic model class capturing torchtitan config in the closure.
class TorchTitanVLLMModelFromSpec(TorchTitanVLLMModelWrapper):
maybe simplify, from a titan-centric view
-class TorchTitanVLLMModelFromSpec(TorchTitanVLLMModelWrapper):
+class VLLMModelFromSpec(VLLMModelWrapper):
    **kwargs,
):
    config_dict = model_spec_to_hf_config_dict(model_spec)
    return config_dict, PretrainedConfig.from_dict(config_dict)
It's actually very weird that the contract is both a config_dict and a cls(config_dict); that sounds redundant to me.
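A small illustration of the redundancy, assuming standard transformers behavior: the dict can essentially be recovered from the `PretrainedConfig` itself.

```python
from transformers import PretrainedConfig

config_dict = {"hidden_size": 4096, "num_attention_heads": 32}
config = PretrainedConfig.from_dict(config_dict)

# to_dict() carries the same fields (plus transformers defaults), which is
# why returning both the dict and the object feels redundant.
roundtrip = config.to_dict()
assert all(roundtrip[k] == v for k, v in config_dict.items())
```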
if moe is not None:
    hf[
        "num_experts"
    ] = moe.experts.num_experts  # presence required: >0 toggles MoE/EP branches
nit: put # presence required above this field
}

def register_model_to_vllm_model_registry(

ffn = getattr(layer0, "feed_forward", None)
Using layer0 is not robust? What if a transformer has MoE in the 1st layer and a dense FFN in the 2nd?
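A sketch of the more robust alternative this comment hints at, scanning all blocks instead of only `layer0` (the attribute names `feed_forward`/`moe` are taken from the diff; the iteration over layers is assumed):

```python
def find_ffn_and_moe(layers):
    """Scan every transformer block instead of trusting layer 0, since a model
    can mix dense FFN blocks and MoE blocks (e.g. MoE only in some layers)."""
    ffn, moe = None, None
    for layer in layers:
        if ffn is None:
            ffn = getattr(layer, "feed_forward", None)
        if moe is None:
            moe = getattr(layer, "moe", None)
        if ffn is not None and moe is not None:
            break
    return ffn, moe
```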
# Unused: only v1/metrics/perf.py reads it (off by default). SwiGLU hidden == w1.out_features.
hf["intermediate_size"] = ffn.w1.out_features

moe = getattr(layer0, "moe", None)
Stack from ghstack (oldest at bottom):
vllm has this customized config parser registry support so we can plug in TorchTitan's config parser. Why we need this:

- get rid of the dependency on an HF-format checkpoint folder when initializing; don't implicitly depend on `config.json` as the config source of truth

Other changes in this PR:

- remove the round-trip translation from torchtitan config -> vllm config -> torchtitan config, using a closure to bypass it.