[rl] Trainer refactor #2985

Merged
joecummings merged 1 commit into pytorch:main from joecummings:worktree-trainer-refactor
Apr 20, 2026
Conversation

Member

@joecummings commented Apr 15, 2026

Context

Main goal: Refactor and clean up the Trainer interface to be flexible and similar to the familiar Tinker API.

Other things done:

  • Added init_logging to the trainer so you can actually see the logs from PolicyTrainer - previously there was no output
  • Switched PolicyTrainer to the default Qwen3 parallelize function, which means you can now actually use FSDP, TP, DDP, etc. - before, you could only use TP
  • Added the ability to save a checkpoint - just a standard dcp.save() (see the sketch after this list)
  • Removed the KL calculation - we weren't using it
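
The checkpoint bullet maps to a small endpoint; a minimal sketch using PyTorch's dcp.save, where the method body and attribute names are illustrative rather than the PR's exact code:

# Minimal sketch of the new save_checkpoint endpoint, per the commit
# message ("Add save_checkpoint(path) endpoint using dcp.save()"). The
# surrounding method/attribute names are assumptions for illustration.
import torch.distributed.checkpoint as dcp

def save_checkpoint(self, path: str) -> None:
    # dcp.save handles DTensor-sharded state dicts, so each rank writes
    # only its own shards to the checkpoint directory.
    state_dict = {"model": self.model.state_dict()}
    dcp.save(state_dict, checkpoint_id=path)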

Testing

  • CI
  • Ran 0.6B script - same speed, same result
  • Ran 0.6B script with FSDP - faster speed, same result
  • Ran 0.6B script and saved a DCP checkpoint - checkpoint saves :)

@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Apr 15, 2026
@joecummings force-pushed the worktree-trainer-refactor branch 2 times, most recently from 7b46b2e to fbf1662, on April 16, 2026 at 18:03
@joecummings marked this pull request as ready for review on April 16, 2026 at 18:07
compile: CompileConfig = field(default_factory=CompileConfig)
debug: DebugConfig = field(default_factory=DebugConfig)
loss: Configurable.Config = field(default_factory=Configurable.Config)
ac_config: ActivationCheckpointConfig = field(
Member Author
These additions are needed to properly support the parallelisms.

Comment thread torchtitan/experiments/rl/tests/test_bitwise_identity.py Outdated
Contributor

@felipemello1 left a comment

First pass, left some comments and nits.

Regarding testing, you said that the different settings "pass". What's the criterion? Should we expect the same final value? The same speed?

Comment thread torchtitan/experiments/rl/actors/trainer.py Outdated
Comment thread torchtitan/experiments/rl/actors/trainer.py Outdated
Comment thread torchtitan/experiments/rl/actors/trainer.py
Comment thread torchtitan/experiments/rl/actors/trainer.py
Comment thread torchtitan/experiments/rl/actors/trainer.py Outdated
Comment thread torchtitan/experiments/rl/actors/trainer.py
Comment thread torchtitan/experiments/rl/tests/test_bitwise_identity.py Outdated
num_steps=10,
trainer=PolicyTrainer.Config(
model_spec=model_spec,
transfer_dtype="bfloat16",
Contributor

nit: should this be hardcoded?

Member Author

I'd rather this be explicit, but maybe @daniellepintz has thoughts

Contributor

Would this ever be different from generator.model_dtype? If not, we shouldn't give the user an option whose only effect is to let the program fail.

Another option is to allow the transfer dtype and generator dtype to differ, so that the synced weights are converted one more time on the receiver side.
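
A sketch of that second option, with the conversion happening on the receiver side; the function and its arguments are hypothetical:

# Hypothetical receiver-side cast if transfer_dtype were allowed to differ
# from the generator's model dtype; names here are illustrative only.
import torch

def convert_received_weights(
    state_dict: dict[str, torch.Tensor],
    generator_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
    # One extra cast per tensor after the weight sync.
    return {name: t.to(generator_dtype) for name, t in state_dict.items()}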

training=TrainingConfig(),
parallelism=ParallelismConfig(
tensor_parallel_degree=2,
disable_loss_parallel=True,
Contributor

nit: should this be hardcoded?

Member Author

When we were using the TP-only parallelize plan, it hardcoded the output to return plain tensors. For now, we can keep the behavior the same by disabling loss_parallel in the standard parallelize plan.

Contributor

@wwwjn pls confirm this

Contributor

Because this PR switches the trainer's parallel plan to core torchtitan's qwen3 parallelism plan, yes, we need this change to make sure the trainer's parallelism is applied the same before and after.

However, I'm a little doubtful that loss parallel won't work for the trainer. We'd need to dig more into how loss parallel works in RL settings.

FYI, after this PR (#2149) lands, we can switch both the trainer's and generator's parallel plans to core torchtitan's qwen3 plan.
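
For context, the commit message below notes that the RL loss path uses a gather on full logits rather than distributed cross-entropy, which is why disable_loss_parallel is set. A minimal sketch of that gather, with assumed tensor shapes:

# Sketch of the full-logits gather the RL loss path relies on (shapes
# assumed: logits [batch, seq, vocab], tokens [batch, seq]). Loss parallel
# would instead keep the logits vocab-sharded as DTensors, which this
# indexing pattern does not support out of the box.
import torch
import torch.nn.functional as F

def token_log_probs(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    log_probs = F.log_softmax(logits, dim=-1)
    return log_probs.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)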

Comment thread torchtitan/experiments/rl/simple_grpo_sum_digits.py
@joecummings
Member Author

Regarding testing, you said that the different settings "pass". What's the criterion? Should we expect the same final value? The same speed?

Updated the description to be clearer.

@joecummings force-pushed the worktree-trainer-refactor branch from 5c2b5c4 to 832919f on April 16, 2026 at 21:54
Comment thread torchtitan/experiments/rl/actors/trainer.py
Comment thread torchtitan/experiments/rl/actors/trainer.py
Comment thread torchtitan/experiments/rl/actors/trainer.py Outdated
@joecummings force-pushed the worktree-trainer-refactor branch 2 times, most recently from 686e9a6 to 3e9ee64, on April 16, 2026 at 23:09
model_converters: ModelConvertersContainer.Config = field(
default_factory=ModelConvertersContainer.Config
)
transfer_dtype: str = ""
Contributor

Shall the type be https://github.com/pytorch/torchtitan/blob/main/torchtitan/config/configs.py#L50?

We can use None for "not requested".
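
Sketched as a config field; whether the linked configs.py line defines exactly this type is an assumption:

# Sketch of the suggestion: type the field as an optional torch.dtype
# instead of str, with None meaning "not requested".
from dataclasses import dataclass
import torch

@dataclass
class Config:
    transfer_dtype: torch.dtype | None = None  # None = no dtype conversion requested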

Contributor

This should be merged with #2942 into a combined WeightSyncConfig, depending on which lands first. cc @daniellepintz

Comment thread torchtitan/experiments/rl/actors/trainer.py

# Conditionally build frozen reference model for KL penalty
# TODO: @joecummings remove ref entirely, this is hacky and we don't need it
if getattr(config.loss, "kl_coef", 0) > 0:
Contributor

IIRC @wwwjn had some concerns here, e.g. if we remove the ref model entirely, is it still GRPO?

I'm OK with removing it entirely for now to focus on the core part.
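
For readers following the thread: the KL term under discussion penalizes divergence from a frozen reference policy. One common per-token estimator in GRPO implementations is sketched below; whether this repo used exactly this form is an assumption:

# Illustrative per-token KL penalty against a frozen reference model, in
# the unbiased "k3" form common in GRPO implementations; it is added to
# the loss scaled by kl_coef. Exact correspondence to the removed code is
# an assumption.
import torch

def kl_penalty(logp: torch.Tensor, ref_logp: torch.Tensor) -> torch.Tensor:
    log_ratio = ref_logp - logp
    return log_ratio.exp() - log_ratio - 1.0  # >= 0 per token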

# Post optimizer step model converters hook.
# e.g. calculate float8 dynamic amax/scale for FSDP2
self.optimizers.register_step_post_hook(
lambda *args, **kwargs: model_converters.post_optimizer_hook(
Contributor

@acisseJZhong in your refactor, I don't think we need this hook anymore -- let's discuss this later.


Returns:
dict: Training metrics (loss, policy version, etc.).
dict with loss metrics, advantage stats, and logprob verification.
Contributor

Do you think we should make this into a dataclass as well? Please leave a TODO if so.

Member Author

It depends on the direction we go with Observability. We could record all these metrics here directly with record_metrics (or something similar). If we end up forgoing that and returning metrics here (not my preferred option), then we might consider a dataclass. Punting on this to @felipemello1's work.

Comment thread torchtitan/experiments/rl/actors/trainer.py
# Generator uses RL-specific parallelize (TP-only, no FSDP, vLLM-compatible)
from torchtitan.experiments.rl.models.parallelize import parallelize_qwen3

gen_model_spec = replace(config.model_spec, parallelize_fn=parallelize_qwen3)
Contributor

Wait, do we still need this because the full DTensor change hasn't landed? @wwwjn can we just land it, regardless of the PP work? We've waited too long.

Contributor

The PR works and is under review; it should land soon.

Contributor

@joecummings After this change, you are using core torchtitan's qwen3 parallel plan for the trainer and only patching parallelize_qwen3 for the generator, which should work. I take my words back.

)

# 5. Optimizer step
optim_metrics = self._get_rank_0_value(self.trainer.optim_step.call().get())
Contributor

Curious: what are the benefits of separating optim and fwd/bwd for the trainer?

Member Author

There are a few (see the sketch below):

  • controller control over gradient accumulation across multiple rollout batches
  • a controller-owned LR scheduler (a future direction where we can pass in optimizer params per step - helpful for interesting things in RL)
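
A hypothetical controller loop illustrating the first point; the endpoint names follow this PR, while the loop structure and variable names are assumed:

# Hypothetical controller loop: accumulate gradients over several rollout
# batches via forward_backward(), then apply a single optimizer update via
# optim_step(). Endpoint names follow the PR; the rest is illustrative.
for step in range(num_steps):
    for batch in rollout_batches:
        # Each call runs fwd/bwd and accumulates gradients on trainer ranks.
        loss_metrics = trainer.forward_backward.call(batch).get()
    # One optimizer update after the accumulation window.
    optim_metrics = trainer.optim_step.call().get()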

@joecummings force-pushed the worktree-trainer-refactor branch from 3e9ee64 to b2fac1a on April 17, 2026 at 14:46
Comment thread torchtitan/experiments/rl/actors/trainer.py Outdated
model_converters: ModelConvertersContainer.Config = field(
default_factory=ModelConvertersContainer.Config
)
transfer_dtype: str | None = None
Contributor

Still not sure if we should have this! https://github.com/pytorch/torchtitan/pull/2985/changes#r3097562813

At least we should assert that the dtype is the same as generator.model_dtype.
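
The suggested guard, sketched; the config attribute paths are assumptions based on the surrounding snippets:

# Sketch of the suggested assertion (attribute paths assumed from the
# surrounding config snippets).
if config.trainer.transfer_dtype is not None:
    assert config.trainer.transfer_dtype == config.generator.model_dtype, (
        "transfer_dtype must match generator.model_dtype"
    )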

Member Author

@joecummings Apr 20, 2026

Okay, no response yet from the TS team, but my understanding is that we may eventually want to support quantized weight transfer. However, this seems like exposing a parameter way before it's needed, so let's hardcode to the generator dtype for now (similar to what we were doing originally, but now explicitly referring to it as the generator dtype).

@joecummings force-pushed the worktree-trainer-refactor branch from b2fac1a to 644b8e4 on April 20, 2026 at 14:35
model_spec=config.model_spec,
hf_assets_path=config.hf_assets_path,
transfer_dtype=config.generator.model_dtype,
generator_dtype=config.generator.model_dtype,
Contributor

Yeah, this looks good.

…parallelize

Split monolithic step() endpoint into forward_backward() and optim_step(),
giving the controller explicit control over the training loop. Consolidate
model_spec and transfer_dtype into PolicyTrainer.Config, add ac_config and
model_converters fields to support the canonical Qwen3 parallelize function
(with FSDP, AC, float8 support).

Key changes:
- forward_backward(train_data) returns loss metrics, optim_step() returns
  grad_norm + policy_version
- Add save_checkpoint(path) endpoint using dcp.save()
- Use get_model_state_dict() in push_model_state_dict for FSDP compatibility
- Remove reference model and KL divergence from PolicyTrainer and GRPOLoss
- Remove parallelize monkey-patch; trainer uses canonical parallelize,
  generator gets RL-specific parallelize via controller
- Set disable_loss_parallel=True in RL configs since the RL loss path
  uses gather on full logits, not distributed cross-entropy
- Remove dead code: compute_policy_gradient_loss, compute_token_log_probs
- Register model_converters.post_optimizer_hook for float8 support
@joecummings force-pushed the worktree-trainer-refactor branch from 644b8e4 to c764de1 on April 20, 2026 at 20:56
@joecummings added the ciflow/h100.8 (Trigger H100.8 CI) label on Apr 20, 2026
@joecummings merged commit 627f4a3 into pytorch:main on Apr 20, 2026
9 of 13 checks passed
@joecummings deleted the worktree-trainer-refactor branch on April 20, 2026 at 21:06
# TODO: Switch to canonical Qwen3 parallel plan
from torchtitan.experiments.rl.models.parallelize import parallelize_qwen3

config.model_spec.parallelize_fn = parallelize_qwen3
Contributor

Have you verified that switching to the core Qwen3 parallel plan works for both the trainer and the generator, especially when the generator has an odd number of tokens? My impression is that it won't work on the generator side with an odd number of tokens when TP is applied. Did CI catch this?

Note that the major difference between parallelize_qwen3 and core torchtitan's qwen3 parallel plan is whether DTensor is applied within the TP region, not only the loss_parallel.

daniellepintz added a commit that referenced this pull request Apr 23, 2026

The trainer refactor (#2985) changed push_model_state_dict to call
get_model_state_dict() instead of `model.state_dict()`. While the end
result of these calls is the same (a sharded state_dict with DTensors),
calling `get_model_state_dict` adds some overhead.

Average push time over 10 steps for Qwen3 0.6B:
with `get_model_state_dict()`: 0.0170s
with `model.state_dict()`: 0.008s

Revert to previous way of calling self.model.state_dict() directly.
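
For reference, the two calls being compared; both yield the DTensor-sharded state dict, and the difference is overhead:

# The two ways of fetching the sharded state dict compared above. Both
# return DTensor-sharded tensors under FSDP2; get_model_state_dict adds
# extra processing, hence the measured overhead.
from torch.distributed.checkpoint.state_dict import get_model_state_dict

sd_a = get_model_state_dict(model)  # ~0.017s avg push time (Qwen3 0.6B)
sd_b = model.state_dict()           # ~0.008s avg push time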