[rl] Generator refactor #3001
Conversation
Force-pushed cd54a8f to 95e52a6
Force-pushed 95e52a6 to c4e9587
…o controller
- Split Episode into Completion / ScoredCompletion / Episode: Completion is the generator's output (no reward, no group). The grader returns ScoredCompletion. The controller builds Episode (flat) with advantages from grouped rewards. (A rough sketch of the three types follows this list.)
- generate(prompts, sampling=None) now returns list[Completion]. Removes expected_answers arg and group_id stamping. Drops prompt_logprobs=1 (no production consumer; saves compute). prompt_idx is pinned to int(output.request_id).
- Grader takes pre-grouped completions (list[list[Completion]]) + per-prompt expected_answers so reward_fn can be called with its natural batched shape. Grader derives no grouping itself.
- Controller groups by prompt_idx before scoring, computes mean-baseline advantages per prompt, builds Episodes in input order.
- evaluate() now uses greedy sampling (n=1, temperature=0.0) via per-call SamplingConfig override. Pre/post-refactor eval accuracy numbers may differ; the older path picked the first stochastic sample.
- Config: num_samples_per_prompt moves from VLLMGenerator.Config to SamplingConfig.n (matches vLLM). All callsites migrated, including integration_tests CLI flag.
- token_log_probs -> token_logprobs (matches TrainBatch's existing name).
- Stale docstring / TODO sweep: VLLMGenerator class + __init__, Grader class, evaluate(); removed the group_id GRPO-decoupling TODO.
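For orientation, a minimal sketch of the three-way split described above. Field names beyond prompt_idx / text / token_logprobs / reward / advantage are assumptions for illustration, not the actual definitions in types.py:

```python
from dataclasses import dataclass


@dataclass
class Completion:
    """Generator output: one sampled continuation; no reward, no grouping."""
    prompt_idx: int              # pinned to int(output.request_id)
    text: str
    token_ids: list[int]         # assumed field, for the trainer's token stream
    token_logprobs: list[float]


@dataclass
class ScoredCompletion:
    """Grader output: a Completion plus its scalar reward."""
    completion: Completion
    reward: float


@dataclass
class Episode:
    """What the trainer consumes: a flat scored sample with its advantage."""
    prompt_idx: int
    text: str
    token_ids: list[int]
    token_logprobs: list[float]
    reward: float
    advantage: float
```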
- Grader.score takes (completions: list[Completion], expected_answer: str) and scores one prompt's group per call. Controller loops over groups and issues one RPC per prompt. Aligns the code with the "pre-grouped so reward sees its batched shape" intent from 7921348.
- SumDigitsTask.reward_function signature: list[Completion] instead of list[str]; extracts .text internally. Matches the "completions" parameter name to the Completion dataclass.
- Controller folds grouping, scoring, mean-baseline advantage, and Episode construction into a single loop (sketched below). Removes the defaultdict rewards_by_prompt / group_mean dict-comp bookkeeping.
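A condensed sketch of that single controller loop, assuming the Completion / ScoredCompletion / Episode shapes sketched above, that prompt_idx indexes into expected_answers in input order, and treating grader.score as a plain call rather than the actual actor RPC. The helper name build_episodes is hypothetical:

```python
def build_episodes(
    completions: list[Completion],
    expected_answers: list[str],
    grader,
) -> list[Episode]:
    episodes: list[Episode] = []
    for prompt_idx, expected_answer in enumerate(expected_answers):
        # One group per prompt: every sample the generator produced for it.
        group = [c for c in completions if c.prompt_idx == prompt_idx]
        scored = grader.score(group, expected_answer)      # one call (RPC) per prompt
        mean_reward = sum(s.reward for s in scored) / len(scored)
        for s in scored:
            episodes.append(
                Episode(
                    prompt_idx=prompt_idx,
                    text=s.completion.text,
                    token_ids=s.completion.token_ids,
                    token_logprobs=s.completion.token_logprobs,
                    reward=s.reward,
                    advantage=s.reward - mean_reward,      # mean-baseline advantage
                )
            )
    return episodes
```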
Disambiguates from the `SamplingConfig` type used in the annotation and prevents positional misuse. No callers passed it by keyword, so no callsite updates needed.
Force-pushed c4e9587 to af2d06c
Args:
    config: Generator-specific configuration.
    model_spec: TorchTitan model specification.
This was undocumented before.
expected_answers: list[str],
) -> list[Episode]:
"""Generate completions and return a flat list of Episodes.
prompts: list[str],
This now matches the standard vLLM generate: https://docs.vllm.ai/en/stable/getting_started/quickstart/#offline-batched-inference
Just keeping "config" instead of "params" to be in line with Titan
| """Generate completions and return a flat list of Episodes. | ||
| prompts: list[str], | ||
| *, | ||
| sampling_config: SamplingConfig | None = None, |
Allowing this here b/c we eventually WILL want to support different sampling configs e.g. per task and also for eval versus training.
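For illustration, and assuming SamplingConfig exposes vLLM-style n and temperature fields as described in the commit message, the per-call override lets eval run greedy single-sample decoding while training keeps the generator's configured defaults:

```python
# Training rollouts: fall back to the generator's configured SamplingConfig.
train_completions = generator.generate(train_prompts)

# Eval: greedy, one sample per prompt, overridden per call (kw-only).
eval_completions = generator.generate(
    eval_prompts,
    sampling_config=SamplingConfig(n=1, temperature=0.0),
)
```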
c = completions[i * n]
extracted = extract_answer(c.text)
correct += int(extracted == int(answer))
format_ok += int(bool(re.search(r"\[ANSWER\]", c.text)))
In a follow-up this should be the same as the task itself. That's out of scope for this refactor.
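On the completions[i * n] indexing above: with n samples per prompt returned in input order, prompt i's group is the contiguous slice starting at i * n, so that index is its first sample. A sketch of the layout (variable names as in the diff):

```python
# n samples per prompt, returned grouped and in input order:
# prompt 0 -> completions[0 : n], prompt 1 -> completions[n : 2n], ...
for i, answer in enumerate(expected_answers):
    group = completions[i * n : (i + 1) * n]
    first = group[0]   # same object as completions[i * n]
    # Under the greedy eval override (n=1) this degenerates to completions[i].
```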
…pisode
- Add expected_answer: str field to Episode. Generator output (Completion) deliberately still lacks it -- this is a post-Grader convenience, not a regression of the narrowing.
- _log_samples now accepts either Episodes or Completions via duck-typing: reads expected_answer / reward from the item when present, falls back to the expected_answers arg for the Completion path. Eval's inline log block is replaced by a single call.
- sum_digits.reward_function drops the `text = c.text` local.
_log_samples is debug output, not an eval metric. Strip the correctness / mark / expected-answer logic out and just dump the first text per prompt (plus reward when available). Drops the `expected_answers` arg entirely and, with no remaining caller, removes the Episode.expected_answer field that was added one commit prior. Note: 79520fb and this commit collapse to a no-op on Episode/expected_answer; squash before PR if you care.
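A sketch of what the stripped-down helper could look like. The name _log_samples and the "first text per prompt, plus reward when available" behavior come from the commits above; the exact signature, truncation length, and log format here are assumptions:

```python
import logging

logger = logging.getLogger(__name__)


def _log_samples(items, n: int) -> None:
    """Debug output only: dump the first sampled text per prompt."""
    for i in range(0, len(items), n):
        item = items[i]
        reward = getattr(item, "reward", None)   # Episodes carry a reward; Completions do not
        tag = f" reward={reward:.3f}" if reward is not None else ""
        logger.info("prompt %d%s: %s", item.prompt_idx, tag, item.text[:200])
```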
The 14B config was added upstream after d8dad1d's original num_samples_per_prompt -> SamplingConfig.n migration and slipped through the audit: it still passed num_samples_per_prompt=8 to VLLMGenerator.Config, which no longer accepts that kwarg and would TypeError at config construction. Squash into d8dad1d before PR.
tianyu-l left a comment
Looks very neat. Had one comment.
Flat shape (rather than composition) because the trainer collate
path and logging read these fields directly.
This doesn't sound fundamental?
Also, even if we keep it flat, it may be better to inherit Completion for maintainability?
Personally, I would prefer not to inherit from Completion. I would consider adding it as an argument, e.g. Episode.completion.text, but my intuition is that leaving it flat is the best call.
People may have all sorts of logic in different places, and I don't think we should enforce the coupling, e.g. let's say the generator outputs a CompletionWithLargeMetadata that I don't want to add to my Episode.
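The two shapes being weighed here, side by side. A sketch for comparison, not code from this PR:

```python
from dataclasses import dataclass


@dataclass
class Completion:             # minimal stand-in for the generator's output type
    prompt_idx: int
    text: str


# Option A (what the PR does): flat Episode; the trainer collate path and
# logging read episode.text, episode.reward, etc. directly.
@dataclass
class Episode:
    prompt_idx: int
    text: str
    reward: float
    advantage: float


# Option B (composition): the Episode carries the Completion and consumers go
# through it, e.g. episode.completion.text. This couples Episode to whatever
# the generator emits -- the "CompletionWithLargeMetadata" concern above.
@dataclass
class ComposedEpisode:
    completion: Completion
    reward: float
    advantage: float
```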
In the output from your test command, why is the weight sync so slow? Are you not using GPU direct? Something seems off there; for 0.6B it should be ~0.1 seconds.
felipemello1 left a comment
left some comments, nothing blocking
ep.reward = rewards[0].item()
rewards = self.reward_fn(completions, expected_answer)
scored = [
    ScoredCompletion(completion=c, reward=r.item())
Not a huge fan of ScoredCompletion. I would prefer that after we get a completion, we start putting data in an Episode.
Would this dataclass be used only here? If so, then I think it might be too ephemeral.
I wouldn't block the PR on this though.
A simple solution could be to build episodes, pass them here, and have the function update and return them, or just return the rewards themselves, to reduce data transfer.
Ack - I wanted a data structure to hold the reward per completion, but with better handling of rewards in general this is not as important.
My general sense is Grader doesn't do a ton here overall.
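The lighter-weight variant suggested above (return only the rewards, so no wrapper crosses the RPC) would look roughly like this. A sketch of the alternative under discussion, not what the PR implements:

```python
class Grader:
    def __init__(self, reward_fn):
        self.reward_fn = reward_fn

    def score(self, completions: list["Completion"], expected_answer: str) -> list[float]:
        # Return one scalar per completion; the controller attaches rewards to
        # Episodes itself, so no ScoredCompletion objects travel back.
        rewards = self.reward_fn(completions, expected_answer)
        return [float(r) for r in rewards]
```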
@@ -75,30 +77,32 @@ def get_system_prompt(self) -> str:

def reward_function(
This is confusing. We have another reward function here. What's the difference between sum_digits and simple_sum_digits? Can we delete one of them?
Next refactor! Yeah I don't love this either but didn't want to change reward structure in this PR
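For reference, the reworked task-side signature from this PR takes Completion objects and pulls out .text itself (per the commit message). The body below is illustrative only; the real scoring lives in the task's own extract/compare logic:

```python
def reward_function(self, completions: list["Completion"], expected_answer: str) -> list[float]:
    # New shape: list[Completion] in, one reward per completion out.
    rewards = []
    for c in completions:
        extracted = extract_answer(c.text)   # task's own parsing helper
        rewards.append(1.0 if extracted == int(expected_answer) else 0.0)
    return rewards
```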
Summary
Narrow the RL VLLMGenerator to pure generation. Reward, advantage, and grouping now live in the controller + Grader, not in the generator's output.

Changes

Types (types.py)
- Split Episode into three: Completion (generator output: no reward, no advantage, no group), ScoredCompletion (Completion + scalar reward from Grader), and Episode (what the trainer consumes: a flat scored sample with advantage).
- token_log_probs → token_logprobs (matches TrainBatch).

Generator (actors/generator.py)
- generate(prompts) -> list[Completion]. No expected_answers arg, no group stamping. prompt_idx is pinned to int(output.request_id).
- prompt_logprobs=1 dropped (no consumer, saves compute).
- num_samples_per_prompt moved from VLLMGenerator.Config into SamplingConfig.n to match vLLM.
- generate(..., *, sampling_config=None): per-call sampling override is now kw-only and disambiguated from the SamplingConfig type.

Grader (actors/grader.py)
- score(completions: list[Completion], expected_answer: str) -> list[ScoredCompletion]. Scores one prompt's group per call so reward_fn sees its natural batched shape. Controller loops over prompts, issuing one RPC each.

Controller (simple_grpo_sum_digits.py)
- Groups Completions by prompt_idx, scores each group, computes mean-baseline advantage per group in a single pass, builds Episodes.
- evaluate() uses the new Completion type; the completions[i * n] stride picks the first sample per prompt.
- _log_samples refocused on debug output only: prints prompt_idx, optional reward, and truncated text. Accepts Episodes or Completions.

Task (sum_digits.py)
- reward_function(completions: list[Completion], expected_answer) consumes the new Completion type directly.

Configs (config_registry.py, tests/integration_tests.py)
- num_samples_per_prompt moved to SamplingConfig.n. Integration-test CLI flag updated.

Test Plan
- rl_grpo_qwen3_0_6b end-to-end (10 steps, TP=2 trainer + TP=4 generator on 6 H100s). Clean exit, eval accuracy 40% → 50%. https://www.internalfb.com/phabricator/paste/view/P2282561759