[rl] Generator refactor #3001

Merged: joecummings merged 7 commits into pytorch:main from joecummings:worktree-gen-refactor on Apr 22, 2026

Conversation

joecummings (Member) commented Apr 16, 2026

Summary

Narrow the RL VLLMGenerator to pure generation. Reward, advantage, and grouping
now live in the controller + Grader, not in the generator's output.

Changes

Types (types.py)

  • Split the old Episode into three: Completion (generator output — no reward,
    no advantage, no group), ScoredCompletion (Completion + scalar reward from
    Grader), and Episode (what the trainer consumes — flat scored sample with
    advantage). Renamed token_log_probs -> token_logprobs (matches TrainBatch).
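
A minimal sketch of the split; only prompt_idx, text, token_logprobs, the
completion/reward composition, and the flat Episode shape are attested in this
PR, the remaining fields and types are assumptions:

```python
from dataclasses import dataclass


@dataclass
class Completion:
    """Generator output: no reward, no advantage, no group."""
    prompt_idx: int
    text: str
    token_ids: list[int]         # assumed field
    token_logprobs: list[float]  # renamed from token_log_probs; type assumed


@dataclass
class ScoredCompletion:
    """A Completion plus the scalar reward assigned by the Grader."""
    completion: Completion
    reward: float


@dataclass
class Episode:
    """Flat scored sample the trainer consumes. Flat (not composed)
    because the trainer collate path and logging read fields directly."""
    prompt_idx: int
    text: str
    token_ids: list[int]
    token_logprobs: list[float]
    reward: float
    advantage: float
```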

Generator (actors/generator.py)

  • generate(prompts) -> list[Completion]. No expected_answers arg, no group
    stamping. prompt_idx is pinned to int(output.request_id).
  • prompt_logprobs=1 dropped (no consumer, saves compute).
  • num_samples_per_prompt moved from VLLMGenerator.Config into
    SamplingConfig.n to match vLLM.
  • generate(... *, sampling_config=None): per-call sampling override is now
    kw-only and disambiguated from the SamplingConfig type.
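
A sketch of the narrowed surface; the engine plumbing and output field names
are assumptions, while the signature, the kw-only override, and the
request-id pinning come from this PR:

```python
class VLLMGenerator:
    def generate(
        self,
        prompts: list[str],
        *,
        # kw-only: a per-call override can't be passed positionally, and the
        # parameter is disambiguated from the SamplingConfig type itself.
        sampling_config: SamplingConfig | None = None,
    ) -> list[Completion]:
        sampling = sampling_config or self.config.sampling  # attribute name assumed
        outputs = self._engine_generate(prompts, sampling)  # hypothetical helper
        return [
            Completion(
                # prompt_idx is pinned to the vLLM request id, so the
                # controller can re-group samples without group stamping here.
                prompt_idx=int(output.request_id),
                text=output.text,                    # output fields assumed
                token_ids=output.token_ids,
                token_logprobs=output.token_logprobs,
            )
            for output in outputs
        ]
```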

Grader (actors/grader.py)

  • score(completions: list[Completion], expected_answer: str) -> list[ScoredCompletion].
    Scores one prompt's group per call so reward_fn sees its natural batched
    shape. Controller loops over prompts, issuing one RPC each.
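
Roughly the per-group scoring path; reward_fn returning a tensor matches the
r.item() call visible in the diff further down, the zip construction is an
assumption:

```python
class Grader:
    def score(
        self, completions: list[Completion], expected_answer: str
    ) -> list[ScoredCompletion]:
        # One prompt's group per call: reward_fn sees all n samples for a
        # single prompt at once, its natural batched shape.
        rewards = self.reward_fn(completions, expected_answer)
        return [
            ScoredCompletion(completion=c, reward=r.item())
            for c, r in zip(completions, rewards)
        ]
```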

Controller (simple_grpo_sum_digits.py)

  • Groups Completions by prompt_idx, scores each group, computes mean-baseline
    advantage per group in a single pass, builds Episodes.
  • evaluate() uses the new Completion type; completions[i * n] stride picks
    the first sample per prompt.
  • _log_samples refocused on debug output only: prints prompt_idx, optional
    reward, and truncated text. Accepts Episodes or Completions.
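
Roughly what the single controller pass now does; the grader handle and exact
Episode fields are assumptions, the grouping and mean baseline are as
described:

```python
from collections import defaultdict

# Group generator output by prompt; prompt_idx came from the request id.
groups: dict[int, list[Completion]] = defaultdict(list)
for c in completions:
    groups[c.prompt_idx].append(c)

episodes: list[Episode] = []
for prompt_idx, group in groups.items():
    # One Grader RPC per prompt's group.
    scored = grader.score(group, expected_answers[prompt_idx])
    # Mean-baseline advantage, computed in the same pass.
    baseline = sum(s.reward for s in scored) / len(scored)
    for s in scored:
        episodes.append(
            Episode(
                prompt_idx=prompt_idx,
                text=s.completion.text,
                token_ids=s.completion.token_ids,
                token_logprobs=s.completion.token_logprobs,
                reward=s.reward,
                advantage=s.reward - baseline,
            )
        )
```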

Task (sum_digits.py)

  • reward_function(completions: list[Completion], expected_answer) consumes the
    new Completion type directly.
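
Sketch of the adjusted signature; the 1.0/0.0 scoring rule is illustrative,
while the signature, the .text extraction, and extract_answer (visible in the
evaluate() excerpt below) come from the PR:

```python
import torch


class SumDigitsTask:
    def reward_function(
        self, completions: list[Completion], expected_answer: str
    ) -> torch.Tensor:
        rewards = []
        for c in completions:
            # Consumes Completion directly; extracts .text internally.
            extracted = extract_answer(c.text)
            rewards.append(1.0 if extracted == int(expected_answer) else 0.0)
        return torch.tensor(rewards)
```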

Configs (config_registry.py, tests/integration_tests.py)

  • All five configs migrated from num_samples_per_prompt to
    SamplingConfig.n. Integration-test CLI flag updated.
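
Schematically, the migration in each config; attribute names beyond
SamplingConfig.n are assumptions:

```python
# Before: the sample count lived on the generator's own config.
#   VLLMGenerator.Config(num_samples_per_prompt=8, ...)

# After: it rides on SamplingConfig.n, mirroring vLLM's SamplingParams.n.
sampling = SamplingConfig(n=8)
generator_config = VLLMGenerator.Config(sampling=sampling)  # field name assumed
```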

Test Plan

meta-cla bot added the CLA Signed label on Apr 16, 2026
joecummings force-pushed the worktree-gen-refactor branch from cd54a8f to 95e52a6 on April 17, 2026 at 17:59
joecummings changed the title from "[rl] [not ready for review] Generator Refactor" to "[rl] [not ready for review] Generator refactor" on Apr 17, 2026
joecummings force-pushed the worktree-gen-refactor branch from 95e52a6 to c4e9587 on April 21, 2026 at 19:39
…o controller

- Split Episode into Completion / ScoredCompletion / Episode:
  Completion is the generator's output (no reward, no group). The grader
  returns ScoredCompletion. The controller builds Episode (flat) with
  advantages from grouped rewards.
- generate(prompts, sampling=None) now returns list[Completion]. Removes
  expected_answers arg and group_id stamping. Drops prompt_logprobs=1
  (no production consumer; saves compute). prompt_idx is pinned to
  int(output.request_id).
- Grader takes pre-grouped completions (list[list[Completion]]) + per-
  prompt expected_answers so reward_fn can be called with its natural
  batched shape. Grader derives no grouping itself.
- Controller groups by prompt_idx before scoring, computes mean-baseline
  advantages per prompt, builds Episodes in input order.
- evaluate() now uses greedy sampling (n=1, temperature=0.0) via per-call
  SamplingConfig override. Pre/post-refactor eval accuracy numbers may
  differ; the older path picked the first stochastic sample.
- Config: num_samples_per_prompt moves from VLLMGenerator.Config to
  SamplingConfig.n (matches vLLM). All callsites migrated, including
  integration_tests CLI flag.
- token_log_probs -> token_logprobs (matches TrainBatch's existing name).
- Stale docstring / TODO sweep: VLLMGenerator class + __init__, Grader
  class, evaluate(); removed the group_id GRPO-decoupling TODO.
- Grader.score takes (completions: list[Completion], expected_answer: str)
  and scores one prompt's group per call. Controller loops over groups
  and issues one RPC per prompt. Aligns the code with the "pre-grouped
  so reward sees its batched shape" intent from 7921348.
- SumDigitsTask.reward_function signature: list[Completion] instead of
  list[str]; extracts .text internally. Matches the "completions"
  parameter name to the Completion dataclass.
- Controller folds grouping, scoring, mean-baseline advantage, and
  Episode construction into a single loop. Removes the defaultdict
  rewards_by_prompt / group_mean dict-comp bookkeeping.
Disambiguates it from the `SamplingConfig` type used in the annotation and
prevents positional misuse. No callers passed it positionally, so no
callsite updates were needed.
joecummings force-pushed the worktree-gen-refactor branch from c4e9587 to af2d06c on April 21, 2026 at 19:47

Args:
config: Generator-specific configuration.
model_spec: TorchTitan model specification.

joecummings (Member, Author):

This was undocumented before.

expected_answers: list[str],
) -> list[Episode]:
"""Generate completions and return a flat list of Episodes.
prompts: list[str],

joecummings (Member, Author):

Just keeping "config" instead of "params" to be in line with Titan

"""Generate completions and return a flat list of Episodes.
prompts: list[str],
*,
sampling_config: SamplingConfig | None = None,

joecummings (Member, Author):

Allowing this here because we eventually WILL want to support different sampling configs, e.g. per task and also for eval versus training.
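
The eval-versus-training case from the commit message, as a usage sketch
(the generator handle name is illustrative):

```python
# Training: stochastic sampling from the configured defaults.
train_completions = generator.generate(prompts)

# Eval: greedy, one sample per prompt, via the kw-only per-call override.
eval_completions = generator.generate(
    prompts,
    sampling_config=SamplingConfig(n=1, temperature=0.0),
)
```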

c = completions[i * n]
extracted = extract_answer(c.text)
correct += int(extracted == int(answer))
format_ok += int(bool(re.search(r"\[ANSWER\]", c.text)))

joecummings (Member, Author):

In a follow-up this should be the same as the task itself. That's out of scope for this refactor.

…pisode

- Add expected_answer: str field to Episode. Generator output
  (Completion) deliberately still lacks it -- this is a post-Grader
  convenience, not a regression of the narrowing.
- _log_samples now accepts either Episodes or Completions via
  duck-typing: reads expected_answer / reward from the item when
  present, falls back to the expected_answers arg for the Completion
  path. Eval's inline log block is replaced by a single call.
- sum_digits.reward_function drops the `text = c.text` local.

_log_samples is debug output, not an eval metric. Strip the
correctness / mark / expected-answer logic out and just dump the
first text per prompt (plus reward when available). Drops the
`expected_answers` arg entirely and, with no remaining caller,
removes the Episode.expected_answer field that was added one commit
prior.

Note: 79520fb and this commit collapse to a no-op on
Episode/expected_answer; squash before PR if you care.

The 14B config was added upstream after d8dad1d's original
num_samples_per_prompt -> SamplingConfig.n migration and slipped
through the audit: it still passed num_samples_per_prompt=8 to
VLLMGenerator.Config, which no longer accepts that kwarg and would
raise a TypeError at config construction.

Squash into d8dad1d before PR.

Line-wrap adjustments in two spots that ufmt flagged. No behavior change.
Squash into the originating commits (af2d06c, 79cb7d2) before PR.
joecummings marked this pull request as ready for review on April 21, 2026 at 21:24
joecummings changed the title from "[rl] [not ready for review] Generator refactor" to "[rl] Generator refactor" on Apr 21, 2026
@tianyu-l (Contributor) left a comment:

Looks very neat. Had one comment.

Comment on lines +40 to +41
Flat shape (rather than composition) because the trainer collate
path and logging read these fields directly.

Contributor:

This doesn't sound fundamental?

Also, even if we keep it flat, it may be better to inherit from Completion for maintainability?


Contributor:

Personally, I would prefer not to inherit from Completion. I would consider adding it as an attribute, e.g. Episode.completion.text, but my intuition is that leaving it flat is the best call.

People may have all sorts of logic in different places, and I don't think we should enforce the coupling. Say, for example, that the generator outputs a CompletionWithLargeMetadata that I don't want to add to my Episode.

@daniellepintz (Contributor):

In the output from your test command, why is the weight sync so slow? Are you not using GPUDirect? Something seems off there; for 0.6B it should be ~0.1 seconds.

@felipemello1 (Contributor) left a comment:

Left some comments, nothing blocking.

ep.reward = rewards[0].item()
rewards = self.reward_fn(completions, expected_answer)
scored = [
ScoredCompletion(completion=c, reward=r.item())

felipemello1 (Contributor):

Not a huge fan of ScoredCompletion. I would prefer that after we get a Completion, we start putting data in an Episode.

Would this dataclass be used only here? If so, then I think it might be too ephemeral.

I wouldn't block the PR on this, though.


felipemello1 (Contributor):

A simple solution could be to build episodes, pass them here, and have the function update and return them; or just return the rewards themselves, to reduce data transfer.
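
The rewards-only variant of that suggestion might look like this
(make_episode is hypothetical):

```python
# Grader returns just the rewards; the controller attaches them,
# skipping the ScoredCompletion wrapper and shrinking the RPC payload.
rewards = grader.score(group, expected_answer)  # -> torch.Tensor
for c, r in zip(group, rewards):
    episodes.append(make_episode(c, reward=r.item()))
```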


joecummings (Member, Author):

Ack - I wanted a data structure to hold the reward per completion, but with better handling of rewards in general this is not as important.

My general sense is that Grader doesn't do a ton here overall.

@@ -75,30 +77,32 @@ def get_system_prompt(self) -> str:

def reward_function(

felipemello1 (Contributor):

This is confusing. We have another reward function here. What's the difference between sum_digits and simple_sum_digits? Can we delete one of them?


joecummings (Member, Author):

Next refactor! Yeah, I don't love this either, but I didn't want to change the reward structure in this PR.


joecummings merged commit d52d247 into pytorch:main on Apr 22, 2026
11 of 14 checks passed
joecummings deleted the worktree-gen-refactor branch on April 22, 2026 at 15:59

Labels: ciflow/8gpu, CLA Signed

4 participants