diff --git a/docs/source/async_grpo_trainer.md b/docs/source/async_grpo_trainer.md index e1e49c11933..2a6aaf5aae7 100644 --- a/docs/source/async_grpo_trainer.md +++ b/docs/source/async_grpo_trainer.md @@ -31,6 +31,10 @@ Because generation and training run concurrently, the training samples may have The number of concurrent requests sent to the vLLM server is controlled by `max_inflight_tasks`. By default it is set automatically to `max_staleness × per_device_train_batch_size × gradient_accumulation_steps × num_processes` — the maximum number of samples the trainer can consume before they become stale. Generating more than this is wasteful since the excess samples will be discarded. +For generation, [`AsyncGRPOConfig`] supports the same primary sampling controls as [`GRPOConfig`]: `temperature`, +`top_p`, `top_k`, `min_p`, `repetition_penalty`, and `generation_kwargs`. As with [`GRPOTrainer`], keys provided in +`generation_kwargs` take precedence over the named sampling arguments. + ## Quick start ```python diff --git a/tests/experimental/test_async_grpo_trainer.py b/tests/experimental/test_async_grpo_trainer.py index cc5413a6213..128d5a665b6 100644 --- a/tests/experimental/test_async_grpo_trainer.py +++ b/tests/experimental/test_async_grpo_trainer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import itertools import queue @@ -21,6 +22,8 @@ from transformers import AutoTokenizer from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer +from trl.experimental.async_grpo import async_grpo_trainer as async_grpo_trainer_module +from trl.experimental.async_grpo import async_rollout_worker as async_rollout_worker_module from trl.experimental.async_grpo.async_rollout_worker import RolloutSample from ..testing_utils import TrlTestCase @@ -91,7 +94,110 @@ def send_weights(self, iterator): pass +class _CapturingRolloutWorker: + last_init_kwargs = None + + def __init__(self, **kwargs): + type(self).last_init_kwargs = kwargs + self.rollout_buffer = queue.Queue() + + def start(self): + pass + + def update_model_version(self, version): + pass + + def stop(self): + pass + + def pause(self): + pass + + def resume(self): + pass + + def send_weights(self, iterator): + pass + + class TestAsyncGRPOTrainer(TrlTestCase): + def test_init_passes_sampling_config_to_rollout_worker(self, monkeypatch): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_completion", split="train") + tokenizer = AutoTokenizer.from_pretrained(model_id) + training_args = AsyncGRPOConfig( + output_dir=self.tmp_dir, + num_generations=3, + max_completion_length=8, + temperature=0.7, + top_p=0.9, + top_k=10, + min_p=0.01, + repetition_penalty=1.1, + generation_kwargs={"top_k": 50, "seed": 7}, + report_to="none", + ) + + monkeypatch.setattr(async_grpo_trainer_module, "AsyncRolloutWorker", _CapturingRolloutWorker) + + AsyncGRPOTrainer( + model=model_id, + reward_funcs=dummy_reward_func, + args=training_args, + train_dataset=dataset, + processing_class=tokenizer, + ) + + assert _CapturingRolloutWorker.last_init_kwargs["max_tokens"] == 8 + assert _CapturingRolloutWorker.last_init_kwargs["temperature"] == 0.7 + assert _CapturingRolloutWorker.last_init_kwargs["top_p"] == 0.9 + assert _CapturingRolloutWorker.last_init_kwargs["top_k"] == 10 + assert _CapturingRolloutWorker.last_init_kwargs["min_p"] == 0.01 + assert _CapturingRolloutWorker.last_init_kwargs["repetition_penalty"] == 1.1 + assert _CapturingRolloutWorker.last_init_kwargs["generation_kwargs"] == {"top_k": 50, "seed": 7} + + def test_rollout_worker_generation_kwargs_override_named_sampling_params(self): + worker = async_rollout_worker_module.AsyncRolloutWorker.__new__(async_rollout_worker_module.AsyncRolloutWorker) + worker.model_name = "Qwen/Qwen3-4B" + worker.max_tokens = 32 + worker.temperature = 0.9 + worker.top_p = 0.95 + worker.top_k = 10 + worker.min_p = None + worker.repetition_penalty = 1.1 + worker.generation_kwargs = {"top_k": 50, "temperature": 0.2, "seed": 123} + worker.request_timeout = 17 + captured = {} + + async def fake_post(path, payload, timeout): + captured["path"] = path + captured["payload"] = payload + captured["timeout"] = timeout + return {"choices": [{"token_ids": [42], "logprobs": {"token_logprobs": [-0.5]}}]} + + worker._post = fake_post + + completion_ids, completion_logprobs = asyncio.run(worker._generate_one_turn([1, 2, 3])) + + assert completion_ids == [42] + assert completion_logprobs == [-0.5] + assert captured["path"] == "/v1/completions" + assert captured["timeout"] == 17 + assert captured["payload"] == { + "model": "Qwen/Qwen3-4B", + "prompt": [1, 2, 3], + "max_tokens": 32, + "temperature": 0.2, + "top_p": 0.95, + "top_k": 50, + "min_p": 0.0, + "repetition_penalty": 1.1, + "n": 1, + "return_token_ids": True, + "logprobs": 0, + "seed": 123, + } + def test_init_minimal(self): # Test that AsyncGRPOTrainer can be instantiated with only model, reward_model and train_dataset model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" diff --git a/trl/experimental/async_grpo/async_grpo_config.py b/trl/experimental/async_grpo/async_grpo_config.py index 2afd760e7fc..9b8d67f622c 100644 --- a/trl/experimental/async_grpo/async_grpo_config.py +++ b/trl/experimental/async_grpo/async_grpo_config.py @@ -13,6 +13,7 @@ # limitations under the License. from dataclasses import dataclass, field +from typing import Any from trl.trainer.base_config import _BaseConfig @@ -36,8 +37,24 @@ class AsyncGRPOConfig(_BaseConfig): Maximum number of tokens to generate per completion. temperature (`float`, *optional*, defaults to `1.0`): Temperature for sampling. The higher the temperature, the more random the completions. + top_p (`float`, *optional*, defaults to `1.0`): + Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to + `1.0` to consider all tokens. + top_k (`int`, *optional*, defaults to `0`): + Number of highest probability vocabulary tokens to keep for top-k-filtering. If `0`, top-k-filtering is + disabled and all tokens are considered. + min_p (`float`, *optional*): + Minimum token probability, which will be scaled by the probability of the most likely token. It must be a + value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range. + generation_kwargs (`dict[str, Any]`, *optional*): + Additional keyword arguments to pass to the vLLM generation request. If it contains keys that conflict + with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them. chat_template_kwargs (`dict[str, Any]`, *optional*): Additional keyword arguments to pass to the `apply_chat_template` function when generating completions. + repetition_penalty (`float`, *optional*, defaults to `1.0`): + Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far. + Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat + tokens. max_tool_calling_iterations (`int`, *optional*): Maximum number of tool-calling turns when training an agent. If `None`, there is no limit and generation stops when the model generates a response turn with no tool calls or when the total response length reaches @@ -115,6 +132,34 @@ class AsyncGRPOConfig(_BaseConfig): default=1.0, metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."}, ) + top_p: float = field( + default=1.0, + metadata={ + "help": "Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. " + "Set to 1.0 to consider all tokens." + }, + ) + top_k: int = field( + default=0, + metadata={ + "help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `0`, " + "top-k-filtering is disabled and all tokens are considered." + }, + ) + min_p: float | None = field( + default=None, + metadata={ + "help": "Minimum token probability, which will be scaled by the probability of the most likely token. It " + "must be a value between 0.0 and 1.0. Typical values are in the 0.01-0.2 range." + }, + ) + generation_kwargs: dict[str, Any] | None = field( + default=None, + metadata={ + "help": "Additional keyword arguments to pass to the vLLM generation request. If it contains keys that " + "conflict with the other generation parameters (like `min_p`, `top_p`, etc.), they will override them." + }, + ) chat_template_kwargs: dict | None = field( default=None, metadata={ @@ -122,6 +167,14 @@ class AsyncGRPOConfig(_BaseConfig): "completions." }, ) + repetition_penalty: float = field( + default=1.0, + metadata={ + "help": "Float that penalizes new tokens based on whether they appear in the prompt and the generated " + "text so far. Values > 1.0 encourage the model to use new tokens, while values < 1.0 encourage the model " + "to repeat tokens." + }, + ) max_tool_calling_iterations: int | None = field( default=None, metadata={ diff --git a/trl/experimental/async_grpo/async_grpo_trainer.py b/trl/experimental/async_grpo/async_grpo_trainer.py index aca72c73596..f8f765f2b25 100644 --- a/trl/experimental/async_grpo/async_grpo_trainer.py +++ b/trl/experimental/async_grpo/async_grpo_trainer.py @@ -286,6 +286,11 @@ def __init__( self.epsilon_low = self.args.epsilon self.epsilon_high = self.args.epsilon_high self.temperature = self.args.temperature + self.top_p = self.args.top_p + self.top_k = self.args.top_k + self.min_p = self.args.min_p + self.repetition_penalty = self.args.repetition_penalty + self.generation_kwargs = self.args.generation_kwargs or {} # Model model_name = model @@ -370,6 +375,11 @@ def __init__( vllm_server_url=self.args.vllm_server_base_url, max_tokens=self.args.max_completion_length, temperature=self.args.temperature, + top_p=self.args.top_p, + top_k=self.args.top_k, + min_p=self.args.min_p, + repetition_penalty=self.args.repetition_penalty, + generation_kwargs=self.args.generation_kwargs, request_timeout=self.args.request_timeout, server_timeout=self.args.vllm_server_timeout, chat_template_kwargs=self.args.chat_template_kwargs, diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index 3d7350d5a71..c68b606f3dd 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -93,6 +93,11 @@ def __init__( vllm_server_url: str = "http://localhost:8000", max_tokens: int = 32, temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = 0, + min_p: float | None = None, + repetition_penalty: float = 1.0, + generation_kwargs: dict[str, Any] | None = None, request_timeout: int = 120, server_timeout: float = 240.0, chat_template_kwargs: dict[str, Any] | None = None, @@ -155,6 +160,11 @@ def __init__( self.model_update_group = None self.max_tokens = max_tokens self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.min_p = min_p + self.repetition_penalty = repetition_penalty + self.generation_kwargs = generation_kwargs or {} self.request_timeout = request_timeout self.server_timeout = server_timeout self.chat_template_kwargs = chat_template_kwargs or {} @@ -611,10 +621,15 @@ async def _generate_one_turn(self, prompt_ids: list[int]) -> tuple[list[int], li "prompt": prompt_ids, "max_tokens": self.max_tokens, "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "min_p": 0.0 if self.min_p is None else self.min_p, + "repetition_penalty": self.repetition_penalty, "n": 1, "return_token_ids": True, "logprobs": 0, } + payload.update(self.generation_kwargs) while True: try: output = await self._post("/v1/completions", payload, self.request_timeout)