[rl] Trainer refactor #2985
Conversation
Force-pushed 7b46b2e to fbf1662
| compile: CompileConfig = field(default_factory=CompileConfig)
| debug: DebugConfig = field(default_factory=DebugConfig)
| loss: Configurable.Config = field(default_factory=Configurable.Config)
| ac_config: ActivationCheckpointConfig = field(
These additions are needed to apply the parallelisms properly.
felipemello1 left a comment
First pass, left some comments and nits.
Regarding testing, you said that the different setting "passes". What's the criterion? Should we expect the same final value? The same speed?
| num_steps=10,
| trainer=PolicyTrainer.Config(
| model_spec=model_spec,
| transfer_dtype="bfloat16",
nit: should this be hardcoded?
I'd rather this be explicit, but maybe @daniellepintz has thoughts
Would this ever be different from generator.model_dtype? If not, we shouldn't give the user this option only to let the program fail.
Another option is to allow the transfer dtype and generator dtype to differ, so that the synced weights are converted one more time on the receiver side.
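If we did allow them to differ, a minimal sketch of that receiver-side conversion (hypothetical helper; only `model_dtype` mirrors the existing generator config field):

```python
import torch

def load_synced_weights(model: torch.nn.Module,
                        synced_state: dict[str, torch.Tensor],
                        model_dtype: torch.dtype) -> None:
    # Weights arrive in the (possibly lower-precision) transfer dtype;
    # cast them back to the generator's model dtype before loading.
    converted = {name: t.to(model_dtype) for name, t in synced_state.items()}
    model.load_state_dict(converted, strict=False)
```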
| training=TrainingConfig(),
| parallelism=ParallelismConfig(
| tensor_parallel_degree=2,
| disable_loss_parallel=True,
nit: should this be hardcoded?
When we were using the TP-only parallelize plan, it hardcoded the output to return plain tensors. For now, we can keep behavior the same by disabling loss_parallel in the standard parallelize plan.
Because this PR switches the trainer's parallel plan to core torchtitan's Qwen3 parallelism plan, yes, we need this change to make sure the trainer's parallelism is applied the same before and after.
However, I'm a little doubtful that loss parallel won't work for the trainer. Would need to dig more into how loss parallel works in RL settings.
FYI, once #2149 lands, we can switch both the trainer's and the generator's parallel plans to core torchtitan's Qwen3 plan.
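For context on why the current RL loss path wants plain (non-loss-parallel) logits, here is a minimal sketch of a gather-based per-token log-prob computation (illustrative only, not the trainer's actual loss code):

```python
import torch
import torch.nn.functional as F

def per_token_logprobs(logits: torch.Tensor, target_ids: torch.Tensor) -> torch.Tensor:
    # logits: [batch, seq, vocab] as a plain tensor (full vocab on every rank).
    # With loss parallel, logits stay sharded over the vocab dim as DTensors,
    # so this direct gather would need a distributed lookup instead.
    logprobs = F.log_softmax(logits.float(), dim=-1)
    return torch.gather(logprobs, dim=-1, index=target_ids.unsqueeze(-1)).squeeze(-1)
```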
Updated the description to be clearer.
Force-pushed 5c2b5c4 to 832919f
Force-pushed 686e9a6 to 3e9ee64
| model_converters: ModelConvertersContainer.Config = field(
| default_factory=ModelConvertersContainer.Config
| )
| transfer_dtype: str = ""
Shall the type be https://github.com/pytorch/torchtitan/blob/main/torchtitan/config/configs.py#L50? We can use None for "not requested".
This should be merged with #2942 into a combined WeightSyncConfig, depending on which lands first. cc @daniellepintz
| # Conditionally build frozen reference model for KL penalty
| # TODO: @joecummings remove ref entirely, this is hacky and we don't need it
| if getattr(config.loss, "kl_coef", 0) > 0:
IIRC @wwwjn had some concerns here, e.g. if we remove the ref model entirely, is it still GRPO?
I'm OK with removing it entirely for now to focus on the core part.
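For reference, a minimal sketch of the KL penalty term the frozen reference model feeds (the k3-style estimator commonly used with GRPO; illustrative, not the actual GRPOLoss code):

```python
import torch

def kl_penalty(logprobs: torch.Tensor, ref_logprobs: torch.Tensor, kl_coef: float) -> torch.Tensor:
    # Per-token k3 KL estimate between the policy and the frozen reference model.
    # When kl_coef == 0, the reference model (and this whole term) can be skipped.
    log_ratio = ref_logprobs - logprobs
    return kl_coef * (torch.exp(log_ratio) - log_ratio - 1.0)
```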
| # Post optimizer step model converters hook.
| # e.g. calculate float8 dynamic amax/scale for FSDP2
| self.optimizers.register_step_post_hook(
| lambda *args, **kwargs: model_converters.post_optimizer_hook(
@acisseJZhong in your refactor, I don't think we need this hook anymore -- let's discuss this later.
| Returns:
| dict: Training metrics (loss, policy version, etc.).
| dict with loss metrics, advantage stats, and logprob verification.
Do you think we should make this into a dataclass as well? Please leave a TODO if so.
It depends on the direction we go with Observability. We could record all these metrics here, directly with record_metrics (or something similar). If we end up forgoing that and returning metrics here (not my preferred option), then we might consider a dataclass. Punting on this to @felipemello1's work
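If we do end up returning metrics and want a dataclass, a hypothetical shape (field names are placeholders, not a committed interface):

```python
from dataclasses import dataclass

@dataclass
class ForwardBackwardMetrics:
    # Hypothetical typed return for forward_backward(); currently a plain dict.
    loss: float
    advantage_mean: float
    advantage_std: float
    logprob_diff: float  # generator-vs-trainer logprob verification
```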
| # Generator uses RL-specific parallelize (TP-only, no FSDP, vLLM-compatible)
| from torchtitan.experiments.rl.models.parallelize import parallelize_qwen3
|
| gen_model_spec = replace(config.model_spec, parallelize_fn=parallelize_qwen3)
Wait, do we still need this because the full-DTensor change has not landed? @wwwjn can we just land it, regardless of the PP work? We've waited too long.
The PR works and is under review; it should land soon.
@joecummings After this change, you are using core torchtitan's Qwen3 parallel plan for the trainer and only patching in parallelize_qwen3 for the generator, which should work. I take my words back.
| )
|
| # 5. Optimizer step
| optim_metrics = self._get_rank_0_value(self.trainer.optim_step.call().get())
Curious, what are some example benefits of separating optim and fwd/bwd for the trainer?
There are a few:
- controller control over gradient accumulation across multiple rollout batches (see the sketch below)
- a controller-owned LR scheduler (a future direction where we can pass in optimizer params per step, which is helpful for interesting things in RL)
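To make the gradient-accumulation point concrete, a rough controller-side sketch using the split endpoints from this PR (the accumulation policy itself is hypothetical):

```python
# Accumulate gradients over several rollout batches, then take a single
# optimizer step; optim_step() also bumps the policy version.
for rollout_batch in rollout_batches:
    fb_metrics = trainer.forward_backward.call(rollout_batch).get()
optim_metrics = trainer.optim_step.call().get()
```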
Force-pushed 3e9ee64 to b2fac1a
| model_converters: ModelConvertersContainer.Config = field(
| default_factory=ModelConvertersContainer.Config
| )
| transfer_dtype: str | None = None
Still not sure if we should have this! https://github.com/pytorch/torchtitan/pull/2985/changes#r3097562813
At least we should assert that the dtype is the same as generator.model_dtype.
Okay, no response yet from the TS team, but my understanding is that we may eventually want to be able to do quantized weight transfer. However, this seems like exposing a parameter way before it's needed, so let's hardcode to the generator dtype for now (similar to what we were doing originally, but now explicitly referring to it as the generator dtype).
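If the knob were kept instead of hardcoded, the assertion suggested above would be a small validation step (sketch; `generator.model_dtype` mirrors the existing config field, the validation hook itself is hypothetical):

```python
def validate_weight_sync_dtypes(config) -> None:
    # Until quantized weight transfer is actually needed, any explicitly
    # requested transfer dtype must match the dtype the generator runs in.
    if config.trainer.transfer_dtype is not None:
        assert config.trainer.transfer_dtype == config.generator.model_dtype, (
            f"transfer_dtype={config.trainer.transfer_dtype!r} must match "
            f"generator.model_dtype={config.generator.model_dtype!r}"
        )
```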
Force-pushed b2fac1a to 644b8e4
| model_spec=config.model_spec,
| hf_assets_path=config.hf_assets_path,
| transfer_dtype=config.generator.model_dtype,
| generator_dtype=config.generator.model_dtype,
…parallelize
Split the monolithic step() endpoint into forward_backward() and optim_step(), giving the controller explicit control over the training loop. Consolidate model_spec and transfer_dtype into PolicyTrainer.Config, and add ac_config and model_converters fields to support the canonical Qwen3 parallelize function (with FSDP, AC, and float8 support).
Key changes:
- forward_backward(train_data) returns loss metrics; optim_step() returns grad_norm + policy_version
- Add save_checkpoint(path) endpoint using dcp.save()
- Use get_model_state_dict() in push_model_state_dict for FSDP compatibility
- Remove reference model and KL divergence from PolicyTrainer and GRPOLoss
- Remove the parallelize monkey-patch; the trainer uses the canonical parallelize, and the generator gets the RL-specific parallelize via the controller
- Set disable_loss_parallel=True in RL configs since the RL loss path uses a gather on full logits, not distributed cross-entropy
- Remove dead code: compute_policy_gradient_loss, compute_token_log_probs
- Register model_converters.post_optimizer_hook for float8 support
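A minimal sketch of what the dcp-based save_checkpoint endpoint listed above amounts to (illustrative; the actual endpoint may include more state and differ in naming):

```python
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_model_state_dict

def save_checkpoint(model: torch.nn.Module, path: str) -> None:
    # get_model_state_dict returns a sharded (DTensor) state dict that
    # dcp.save can write collectively across ranks.
    state = {"model": get_model_state_dict(model)}
    dcp.save(state, checkpoint_id=path)
```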
Force-pushed 644b8e4 to c764de1
| # TODO: Switch to canonical Qwen3 parallel plan
| from torchtitan.experiments.rl.models.parallelize import parallelize_qwen3
|
| config.model_spec.parallelize_fn = parallelize_qwen3
Have you verified that switching to the core Qwen3 parallel plan works for the trainer and generator, especially when the generator has an odd number of tokens? My impression is that it won't work on the generator side with an odd number of tokens when TP is applied. Did CI catch this?
Note that the major difference between parallelize_qwen3 and core torchtitan's Qwen3 parallel plan is whether DTensor is applied within the TP region, not only the loss_parallel.
The trainer refactor (#2985) changed push_model_state_dict to call get_model_state_dict() instead of `model.state_dict()`. While the end result of these calls is the same (both return a sharded state_dict with DTensors), calling `get_model_state_dict` adds some overhead.
Average push time over 10 steps for Qwen3 0.6B:
- with `get_model_state_dict()`: 0.0170s
- with `model.state_dict()`: 0.008s
Revert to the previous approach of calling self.model.state_dict() directly.
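For anyone re-checking those numbers, a rough sketch of how the push time comparison could be measured (the measurement harness is hypothetical; only the push_model_state_dict endpoint comes from this PR):

```python
import time

def average_push_time(trainer, steps: int = 10) -> float:
    # Average wall-clock time of a full weight push over a few steps.
    start = time.perf_counter()
    for _ in range(steps):
        trainer.push_model_state_dict.call().get()
    return (time.perf_counter() - start) / steps
```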
Context
Main goal: refactor and clean up the Trainer interface to be flexible and similar to the familiar Tinker API
Other things done:
Added init_logging to the trainer so you can actually see the logs from PolicyTrainer - before there was no output
Testing