[rl] Register customized config parser to vllm + less vllm config dependency #3242
base: gh/wwwjn/20/base
Changes from all commits
45c13fb
bfe2166
9414f66
e3de100
e1d5de8
285429e
@@ -20,12 +20,14 @@
)
from torchtitan.components.lr_scheduler import LRSchedulersContainer
from torchtitan.components.optimizer import OptimizersContainer
from torchtitan.config import CommConfig, Configurable, TORCH_DTYPE_MAP
from torchtitan.config.configs import (
from torchtitan.config import (
Contributor (Author)
This change just consolidates the import path.
    ActivationCheckpointConfig,
    CommConfig,
    CompileConfig,
    Configurable,
    DebugConfig,
    ParallelismConfig,
    TORCH_DTYPE_MAP,
    TrainingConfig,
)
from torchtitan.distributed import ParallelDims, utils as dist_utils

@@ -37,9 +37,12 @@
from monarch.actor import this_host
from monarch.spmd import setup_torch_elastic_env_async

from torchtitan.config import Configurable, ParallelismConfig
from torchtitan.config.configs import CompileConfig
from torchtitan.config.manager import ConfigManager
from torchtitan.config import (
    CompileConfig,
    ConfigManager,
    Configurable,
    ParallelismConfig,
)
from torchtitan.experiments.rl.actors.generator import SamplingConfig, VLLMGenerator
from torchtitan.experiments.rl.actors.trainer import PolicyTrainer
from torchtitan.experiments.rl.types import (

@@ -234,6 +237,41 @@ def __post_init__(self):
                "and has not been validated for determinism."
            )

        # VLLMGenerator only supports TP. vLLM handles its own parallelism;
Contributor
Sorry, my earlier comment introduced some noise. I meant to say that we should put some checks in post_init, but not ParallelismConfig's post_init, because it will be shared by the core trainer, policy trainer, and generator. I think the proper place to put these checks is the generator config's post_init (see the sketch after this diff).
        # we only apply TP via the core parallelize function.
        if self.generator.parallelism.data_parallel_replicate_degree != 1:
            raise ValueError(
                f"Generator does not support data parallel replication, "
                f"got dp_replicate={self.generator.parallelism.data_parallel_replicate_degree}"
            )
        if self.generator.parallelism.pipeline_parallel_degree > 1:
            raise ValueError(
                f"Generator does not support pipeline parallelism, "
                f"got pp={self.generator.parallelism.pipeline_parallel_degree}"
            )
        if self.generator.parallelism.context_parallel_degree > 1:
            raise ValueError(
                f"Generator does not support context parallelism, "
                f"got cp={self.generator.parallelism.context_parallel_degree}"
            )
        if self.generator.parallelism.expert_parallel_degree > 1:
            raise ValueError(
                f"Generator does not support expert parallelism, "
                f"got ep={self.generator.parallelism.expert_parallel_degree}"
            )
        if self.generator.parallelism.enable_sequence_parallel:
            raise ValueError(
                "Generator does not support sequence parallelism: "
                "spmd_types erasure mode requires sequence length to be "
                "evenly divisible by TP, which doesn't hold for inference "
                "(uneven batches). Set enable_sequence_parallel=False."
            )
        if not self.generator.parallelism.disable_loss_parallel:
            raise ValueError(
                "Generator requires disable_loss_parallel=True, "
                f"got disable_loss_parallel={self.generator.parallelism.disable_loss_parallel}"
            )

    def __init__(self, config: Config):
        self.config = config
        self._proc_meshes = []
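A minimal sketch of what moving these checks into a generator-side config's post_init could look like, as the review comment above suggests. This assumes a GeneratorConfig dataclass that owns a ParallelismConfig; the class name and layout are illustrative, not the PR's actual API:

from dataclasses import dataclass, field

from torchtitan.config import ParallelismConfig


@dataclass
class GeneratorConfig:
    """Hypothetical generator config; only sketches where the validation could live."""

    parallelism: ParallelismConfig = field(default_factory=ParallelismConfig)

    def __post_init__(self) -> None:
        # VLLMGenerator only supports TP, so reject the other parallelism knobs here
        # rather than in ParallelismConfig, which is shared by the core trainer,
        # policy trainer, and generator.
        p = self.parallelism
        if p.data_parallel_replicate_degree != 1:
            raise ValueError(
                f"Generator does not support data parallel replication, "
                f"got dp_replicate={p.data_parallel_replicate_degree}"
            )
        if p.pipeline_parallel_degree > 1:
            raise ValueError(
                f"Generator does not support pipeline parallelism, "
                f"got pp={p.pipeline_parallel_degree}"
            )
        # The remaining cp/ep/sequence-parallel/loss-parallel checks from the diff
        # above would follow the same pattern.

This keeps ParallelismConfig itself neutral, while the generator-specific constraints fail fast at config construction time.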
what is this path for?
Now it serves 2 purposes:
1. Tokenizer: we pass tokenizer=tokenizer_path to EngineArgs.
2. Initial weight loading: will sort out the weight loading part for both trainer and generator in the next PR.
After lifting both, we can pass some fake path, say "torchtitan", to vllm.
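A minimal sketch of how that path flows into vLLM today, assuming a local HF-style checkpoint directory; the model_path and tokenizer_path names are illustrative:

from vllm import EngineArgs

# Illustrative path: today the same checkpoint directory serves both purposes
# described above (tokenizer files and initial weight loading).
model_path = "/path/to/hf_checkpoint"
tokenizer_path = model_path

engine_args = EngineArgs(
    model=model_path,          # used by vLLM for initial weight loading (to be lifted in a follow-up PR)
    tokenizer=tokenizer_path,  # vLLM builds its tokenizer from this path
)

Once both dependencies are lifted as described, the model argument could become a placeholder string while the tokenizer and weights are supplied by torchtitan directly.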