Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/async_grpo_trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ Because generation and training run concurrently, the training samples may have

The number of concurrent requests sent to the vLLM server is controlled by `max_inflight_tasks`. By default it is set automatically to `max_staleness × per_device_train_batch_size × gradient_accumulation_steps × num_processes` — the maximum number of samples the trainer can consume before they become stale. Generating more than this is wasteful since the excess samples will be discarded.

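For generation, [`AsyncGRPOConfig`] supports the same primary sampling controls as [`GRPOConfig`]: `temperature`,
`top_p`, `top_k`, `min_p`, `repetition_penalty`, and `generation_kwargs`. As with [`GRPOTrainer`], keys provided in
`generation_kwargs` take precedence over the named sampling arguments. Because `generation_kwargs` is merged into the
request payload last, it can also replace fields that the rollout worker sets itself (`model`, `prompt`, `max_tokens`,
`n`, `return_token_ids`, `logprobs`), so avoid those keys unless you intend to override them.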

## Quick start

```python
Expand Down
106 changes: 106 additions & 0 deletions tests/experimental/test_async_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import itertools
import queue

Expand All @@ -21,6 +22,8 @@
from transformers import AutoTokenizer

from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer
from trl.experimental.async_grpo import async_grpo_trainer as async_grpo_trainer_module
from trl.experimental.async_grpo import async_rollout_worker as async_rollout_worker_module
from trl.experimental.async_grpo.async_rollout_worker import RolloutSample

from ..testing_utils import TrlTestCase
Expand Down Expand Up @@ -91,7 +94,110 @@ def send_weights(self, iterator):
pass


class _CapturingRolloutWorker:
last_init_kwargs = None

def __init__(self, **kwargs):
type(self).last_init_kwargs = kwargs
self.rollout_buffer = queue.Queue()

def start(self):
pass

def update_model_version(self, version):
pass

def stop(self):
pass

def pause(self):
pass

def resume(self):
pass

def send_weights(self, iterator):
pass


class TestAsyncGRPOTrainer(TrlTestCase):
def test_init_passes_sampling_config_to_rollout_worker(self, monkeypatch):
    """Check that `AsyncGRPOTrainer` forwards the sampling settings of its config to the rollout worker.

    `AsyncRolloutWorker` is swapped for `_CapturingRolloutWorker` in the trainer module, and the keyword arguments
    recorded by that stand-in are compared with the values set on the config.
    """
    model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
    dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_completion", split="train")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    training_args = AsyncGRPOConfig(
        output_dir=self.tmp_dir,
        num_generations=3,
        max_completion_length=8,
        temperature=0.7,
        top_p=0.9,
        top_k=10,
        min_p=0.01,
        repetition_penalty=1.1,
        generation_kwargs={"top_k": 50, "seed": 7},
        report_to="none",
    )

    monkeypatch.setattr(async_grpo_trainer_module, "AsyncRolloutWorker", _CapturingRolloutWorker)
    # `last_init_kwargs` lives on the class and therefore survives across tests. Clear it so the assertions below
    # can only see kwargs recorded by this test; monkeypatch restores the previous value at teardown.
    monkeypatch.setattr(_CapturingRolloutWorker, "last_init_kwargs", None)

    AsyncGRPOTrainer(
        model=model_id,
        reward_funcs=dummy_reward_func,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    worker_kwargs = _CapturingRolloutWorker.last_init_kwargs
    assert worker_kwargs is not None, "AsyncGRPOTrainer did not construct a rollout worker"
    assert worker_kwargs["max_tokens"] == 8
    assert worker_kwargs["temperature"] == 0.7
    assert worker_kwargs["top_p"] == 0.9
    # The named `top_k` is forwarded as-is even though `generation_kwargs` also carries a `top_k`.
    assert worker_kwargs["top_k"] == 10
    assert worker_kwargs["min_p"] == 0.01
    assert worker_kwargs["repetition_penalty"] == 1.1
    assert worker_kwargs["generation_kwargs"] == {"top_k": 50, "seed": 7}

def test_rollout_worker_generation_kwargs_override_named_sampling_params(self):
    """Check that `generation_kwargs` entries win over the named sampling attributes in the request payload.

    The worker is created with `__new__`, so `__init__` does not run; only the attributes assigned below are set.
    `_post` is replaced with a coroutine that records its arguments and returns a canned one-token completion.
    """
    worker_cls = async_rollout_worker_module.AsyncRolloutWorker
    worker = worker_cls.__new__(worker_cls)
    worker_state = {
        "model_name": "Qwen/Qwen3-4B",
        "max_tokens": 32,
        "temperature": 0.9,
        "top_p": 0.95,
        "top_k": 10,
        "min_p": None,
        "repetition_penalty": 1.1,
        "generation_kwargs": {"top_k": 50, "temperature": 0.2, "seed": 123},
        "request_timeout": 17,
    }
    for attr_name, attr_value in worker_state.items():
        setattr(worker, attr_name, attr_value)

    recorded = {}

    async def fake_post(path, payload, timeout):
        # Remember what the worker tried to send, then answer with a single generated token.
        recorded.update(path=path, payload=payload, timeout=timeout)
        return {"choices": [{"token_ids": [42], "logprobs": {"token_logprobs": [-0.5]}}]}

    worker._post = fake_post

    completion_ids, completion_logprobs = asyncio.run(worker._generate_one_turn([1, 2, 3]))

    # `temperature` and `top_k` come from `generation_kwargs`, `seed` is added by it, and the unset `min_p`
    # appears as 0.0.
    expected_payload = {
        "model": "Qwen/Qwen3-4B",
        "prompt": [1, 2, 3],
        "max_tokens": 32,
        "temperature": 0.2,
        "top_p": 0.95,
        "top_k": 50,
        "min_p": 0.0,
        "repetition_penalty": 1.1,
        "n": 1,
        "return_token_ids": True,
        "logprobs": 0,
        "seed": 123,
    }

    assert completion_ids == [42]
    assert completion_logprobs == [-0.5]
    assert recorded["path"] == "/v1/completions"
    assert recorded["timeout"] == 17
    assert recorded["payload"] == expected_payload

def test_init_minimal(self):
# Test that AsyncGRPOTrainer can be instantiated with only model, reward_model and train_dataset
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
Expand Down
53 changes: 53 additions & 0 deletions trl/experimental/async_grpo/async_grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from dataclasses import dataclass, field
from typing import Any

from trl.trainer.base_config import _BaseConfig

Expand All @@ -36,8 +37,24 @@ class AsyncGRPOConfig(_BaseConfig):
Maximum number of tokens to generate per completion.
temperature (`float`, *optional*, defaults to `1.0`):
Temperature for sampling. The higher the temperature, the more random the completions.
top_p (`float`, *optional*, defaults to `1.0`):
Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to
`1.0` to consider all tokens.
top_k (`int`, *optional*, defaults to `0`):
Number of highest probability vocabulary tokens to keep for top-k-filtering. If `0`, top-k-filtering is
disabled and all tokens are considered.
min_p (`float`, *optional*):
Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range.
generation_kwargs (`dict[str, Any]`, *optional*):
Additional keyword arguments to pass to the vLLM generation request. If it contains keys that conflict
with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them.
chat_template_kwargs (`dict[str, Any]`, *optional*):
Additional keyword arguments to pass to the `apply_chat_template` function when generating completions.
repetition_penalty (`float`, *optional*, defaults to `1.0`):
Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far.
Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat
tokens.
max_tool_calling_iterations (`int`, *optional*):
Maximum number of tool-calling turns when training an agent. If `None`, there is no limit and generation
stops when the model generates a response turn with no tool calls or when the total response length reaches
Expand Down Expand Up @@ -115,13 +132,49 @@ class AsyncGRPOConfig(_BaseConfig):
default=1.0,
metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
)
top_p: float = field(
default=1.0,
metadata={
"help": "Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. "
"Set to 1.0 to consider all tokens."
},
)
top_k: int = field(
default=0,
metadata={
"help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `0`, "
"top-k-filtering is disabled and all tokens are considered."
},
)
min_p: float | None = field(
default=None,
metadata={
"help": "Minimum token probability, which will be scaled by the probability of the most likely token. It "
"must be a value between 0.0 and 1.0. Typical values are in the 0.01-0.2 range."
},
)
generation_kwargs: dict[str, Any] | None = field(
default=None,
metadata={
"help": "Additional keyword arguments to pass to the vLLM generation request. If it contains keys that "
"conflict with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them."
},
)
chat_template_kwargs: dict | None = field(
default=None,
metadata={
"help": "Additional keyword arguments to pass to the `apply_chat_template` function when generating "
"completions."
},
)
repetition_penalty: float = field(
default=1.0,
metadata={
"help": "Float that penalizes new tokens based on whether they appear in the prompt and the generated "
"text so far. Values > 1.0 encourage the model to use new tokens, while values < 1.0 encourage the model "
"to repeat tokens."
},
)
max_tool_calling_iterations: int | None = field(
default=None,
metadata={
Expand Down
10 changes: 10 additions & 0 deletions trl/experimental/async_grpo/async_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,11 @@ def __init__(
self.epsilon_low = self.args.epsilon
self.epsilon_high = self.args.epsilon_high
self.temperature = self.args.temperature
self.top_p = self.args.top_p
self.top_k = self.args.top_k
self.min_p = self.args.min_p
self.repetition_penalty = self.args.repetition_penalty
self.generation_kwargs = self.args.generation_kwargs or {}

# Model
model_name = model
Expand Down Expand Up @@ -370,6 +375,11 @@ def __init__(
vllm_server_url=self.args.vllm_server_base_url,
max_tokens=self.args.max_completion_length,
temperature=self.args.temperature,
top_p=self.args.top_p,
top_k=self.args.top_k,
min_p=self.args.min_p,
repetition_penalty=self.args.repetition_penalty,
generation_kwargs=self.args.generation_kwargs,
request_timeout=self.args.request_timeout,
server_timeout=self.args.vllm_server_timeout,
chat_template_kwargs=self.args.chat_template_kwargs,
Expand Down
15 changes: 15 additions & 0 deletions trl/experimental/async_grpo/async_rollout_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ def __init__(
vllm_server_url: str = "http://localhost:8000",
max_tokens: int = 32,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 0,
min_p: float | None = None,
repetition_penalty: float = 1.0,
generation_kwargs: dict[str, Any] | None = None,
request_timeout: int = 120,
server_timeout: float = 240.0,
chat_template_kwargs: dict[str, Any] | None = None,
Expand Down Expand Up @@ -155,6 +160,11 @@ def __init__(
self.model_update_group = None
self.max_tokens = max_tokens
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.min_p = min_p
self.repetition_penalty = repetition_penalty
self.generation_kwargs = generation_kwargs or {}
self.request_timeout = request_timeout
self.server_timeout = server_timeout
self.chat_template_kwargs = chat_template_kwargs or {}
Expand Down Expand Up @@ -611,10 +621,15 @@ async def _generate_one_turn(self, prompt_ids: list[int]) -> tuple[list[int], li
"prompt": prompt_ids,
"max_tokens": self.max_tokens,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"min_p": 0.0 if self.min_p is None else self.min_p,
"repetition_penalty": self.repetition_penalty,
"n": 1,
"return_token_ids": True,
"logprobs": 0,
}
payload.update(self.generation_kwargs)
while True:
try:
output = await self._post("/v1/completions", payload, self.request_timeout)
Expand Down